Initial commit
.gitattributes (vendored, Normal file, 1 line)
@@ -0,0 +1 @@
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
.gitignore (vendored, Executable file, 72 lines)
@@ -0,0 +1,72 @@
raw_data/
scaling_out/
scaling_out_old/
sanity_check/
demo/indices/
# .vscode/
*.log
*pycache*
outputs/
*.pkl
.history/
scripts/
lm_eval.egg-info/
demo/experiment_results/**/*.json
*.jsonl
*.sh
*.txt
!CMakeLists.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/
.venv/
.cursor/rules/
*.egg-info/
skip_reorder_comparison/
analysis_results/
build/
.cache/
nprobe_logs/
micro/results
micro/contriever-INT8
*.qdstrm
benchmark_results/
results/
frac_*.png
final_in_*.png
embedding_comparison_results/
*.ind
*.gz
*.fvecs
*.ivecs
*.index
*.bin

read_graph
analyze_diskann_graph
degree_distribution.png
micro/degree_distribution.png

policy_results_*
results_*/
experiment_results/
.DS_Store

# The above are inherited from old Power RAG repo

# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
.env

test_indices*/
test_*.py
!tests/**
packages/leann-backend-diskann/third_party/DiskANN/_deps/
.gitmodules (vendored, Normal file, 6 lines)
@@ -0,0 +1,6 @@
[submodule "packages/leann-backend-diskann/third_party/DiskANN"]
	path = packages/leann-backend-diskann/third_party/DiskANN
	url = https://github.com/yichuan520030910320/DiskANN.git
[submodule "packages/leann-backend-hnsw/third_party/faiss"]
	path = packages/leann-backend-hnsw/third_party/faiss
	url = https://github.com/yichuan520030910320/faiss.git
.python-version (Normal file, 1 line)
@@ -0,0 +1 @@
3.11
.vscode/extensions.json (vendored, Normal file, 9 lines)
@@ -0,0 +1,9 @@
{
    "recommendations": [
        "llvm-vs-code-extensions.vscode-clangd",
        "ms-python.python",
        "ms-vscode.cmake-tools",
        "vadimcn.vscode-lldb",
        "eamodio.gitlens"
    ]
}
.vscode/launch.json (vendored, Executable file, 283 lines)
@@ -0,0 +1,283 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        // new embedder
        {
            "name": "New Embedder",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/main.py",
            "console": "integratedTerminal",
            "args": [
                "--search",
                "--use-original",
                "--domain",
                "dpr",
                "--nprobe",
                "5000",
                "--load",
                "flat",
                "--embedder",
                "intfloat/multilingual-e5-small"
            ]
        },
        // python /home/ubuntu/Power-RAG/faiss/demo/simple_build.py
        {
            "name": "main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/main.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--query",
                "1000",
                "--load",
                "bm25"
            ]
        },
        {
            "name": "Simple Build",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "faiss/demo/simple_build.py"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
        // # Fix for Intel MKL error
        // export LD_PRELOAD=/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so
        // python faiss/demo/build_demo.py
        {
            "name": "Build Demo",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "faiss/demo/build_demo.py"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        },
        {
            "name": "DiskANN Serve",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "demo/main.py",
                "--mode",
                "serve",
                "--engine",
                "sglang",
                "--load-indices",
                "diskann",
                "--domain",
                "rpj_wiki",
                "--lazy-load",
                "--recompute-beighbor-embeddings",
                "--port",
                "8082",
                "--diskann-search-memory-maximum",
                "2",
                "--diskann-graph",
                "240",
                "--search-only"
            ],
            "env": {
                "PYTHONPATH": "${workspaceFolder}/faiss_repo/build/faiss/python:$PYTHONPATH"
            },
            "preLaunchTask": "CMake: build"
        },
        {
            "name": "DiskANN Serve MAC",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "demo/main.py",
                "--mode",
                "serve",
                "--engine",
                "ollama",
                "--load-indices",
                "diskann",
                "--domain",
                "rpj_wiki",
                "--lazy-load",
                "--recompute-beighbor-embeddings"
            ],
            "preLaunchTask": "CMake: build",
            "env": {
                "KMP_DUPLICATE_LIB_OK": "TRUE",
                "OMP_NUM_THREADS": "1",
                "MKL_NUM_THREADS": "1",
                "DYLD_INSERT_LIBRARIES": "/Users/ec2-user/Power-RAG/.venv/lib/python3.10/site-packages/torch/lib/libomp.dylib",
                "KMP_BLOCKTIME": "0"
            }
        },
        {
            "name": "Python Debugger: Current File with Arguments",
            "type": "debugpy",
            "request": "launch",
            "program": "ric/main_ric.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--config-name",
                "${input:configSelection}"
            ],
            "justMyCode": false
        },
        // python ./demo/validate_equivalence.py sglang
        {
            "name": "Validate Equivalence",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/validate_equivalence.py",
            "console": "integratedTerminal",
            "args": [
                "sglang"
            ]
        },
        // python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices flat ivf_flat
        {
            "name": "Retrieval Demo",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/retrieval_demo.py",
            "console": "integratedTerminal",
            "args": [
                "--engine",
                "vllm",
                "--skip-embeddings",
                "--domain",
                "dpr",
                "--load-indices",
                // "flat",
                "ivf_flat"
            ]
        },
        // python demo/retrieval_demo.py --engine sglang --skip-embeddings --domain dpr --load-indices diskann --hnsw-M 64 --hnsw-efConstruction 150 --hnsw-efSearch 128 --hnsw-sq-bits 8
        {
            "name": "Retrieval Demo DiskANN",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/retrieval_demo.py",
            "console": "integratedTerminal",
            "args": [
                "--engine",
                "sglang",
                "--skip-embeddings",
                "--domain",
                "dpr",
                "--load-indices",
                "diskann",
                "--hnsw-M",
                "64",
                "--hnsw-efConstruction",
                "150",
                "--hnsw-efSearch",
                "128",
                "--hnsw-sq-bits",
                "8"
            ]
        },
        {
            "name": "Find Probe",
            "type": "debugpy",
            "request": "launch",
            "program": "find_probe.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}"
        },
        {
            "name": "Python: Attach",
            "type": "debugpy",
            "request": "attach",
            "processId": "${command:pickProcess}",
            "justMyCode": true
        },
        {
            "name": "Edge RAG",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "edgerag_demo.py"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libiomp5.so /lib/x86_64-linux-gnu/libmkl_core.so /lib/x86_64-linux-gnu/libmkl_intel_lp64.so /lib/x86_64-linux-gnu/libmkl_intel_thread.so",
                "MKL_NUM_THREADS": "1",
                "OMP_NUM_THREADS": "1"
            }
        },
        {
            "name": "Launch Embedding Server",
            "type": "debugpy",
            "request": "launch",
            "program": "demo/embedding_server.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "--domain",
                "rpj_wiki",
                "--zmq-port",
                "5556"
            ]
        },
        {
            "name": "HNSW Serve",
            "type": "lldb",
            "request": "launch",
            "program": "${workspaceFolder}/.venv/bin/python",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "args": [
                "demo/main.py",
                "--domain",
                "rpj_wiki",
                "--load",
                "hnsw",
                "--mode",
                "serve",
                "--search",
                "--skip-pa",
                "--recompute",
                "--hnsw-old"
            ],
            "env": {
                "LD_PRELOAD": "/lib/x86_64-linux-gnu/libmkl_core.so:/lib/x86_64-linux-gnu/libmkl_intel_thread.so:/lib/x86_64-linux-gnu/libmkl_intel_lp64.so:/lib/x86_64-linux-gnu/libiomp5.so"
            }
        }
    ],
    "inputs": [
        {
            "id": "configSelection",
            "type": "pickString",
            "description": "Select a configuration",
            "options": [
                "example_config",
                "vllm_gritlm"
            ],
            "default": "example_config"
        }
    ]
}
.vscode/settings.json (vendored, Executable file, 43 lines)
@@ -0,0 +1,43 @@
{
    "python.analysis.extraPaths": [
        "./sglang_repo/python"
    ],
    "cmake.sourceDirectory": "${workspaceFolder}/DiskANN",
    "cmake.configureArgs": [
        "-DPYBIND=True",
        "-DUPDATE_EDITABLE_INSTALL=ON"
    ],
    "cmake.environment": {
        "PATH": "/Users/ec2-user/Power-RAG/.venv/bin:${env:PATH}"
    },
    "cmake.buildDirectory": "${workspaceFolder}/build",
    "files.associations": {
        "*.tcc": "cpp",
        "deque": "cpp",
        "string": "cpp",
        "unordered_map": "cpp",
        "vector": "cpp",
        "map": "cpp",
        "unordered_set": "cpp",
        "atomic": "cpp",
        "inplace_vector": "cpp",
        "*.ipp": "cpp",
        "forward_list": "cpp",
        "list": "cpp",
        "any": "cpp",
        "system_error": "cpp",
        "__hash_table": "cpp",
        "__split_buffer": "cpp",
        "__tree": "cpp",
        "ios": "cpp",
        "set": "cpp",
        "__string": "cpp",
        "string_view": "cpp",
        "ranges": "cpp",
        "iosfwd": "cpp"
    },
    "lldb.displayFormat": "auto",
    "lldb.showDisassembly": "auto",
    "lldb.dereferencePointers": true,
    "lldb.consoleMode": "commands"
}
.vscode/tasks.json (vendored, Normal file, 16 lines)
@@ -0,0 +1,16 @@
{
    "version": "2.0.0",
    "tasks": [
        {
            "type": "cmake",
            "label": "CMake: build",
            "command": "build",
            "targets": [
                "all"
            ],
            "group": "build",
            "problemMatcher": [],
            "detail": "CMake template build task"
        }
    ]
}
LICENSE (Executable file, 9 lines)
@@ -0,0 +1,9 @@
MIT License

Copyright (c) 2024 Rulin Shao

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md (Executable file, 292 lines)
@@ -0,0 +1,292 @@
# 🚀 LEANN: A Low-Storage Vector Index

<p align="center">
  <img src="https://img.shields.io/badge/Python-3.9%2B-blue.svg" alt="Python 3.9+">
  <img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
  <img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs Welcome">
  <img src="https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey" alt="Platform">
</p>

<p align="center">
  <strong>⚡ Real-time embedding computation for large-scale RAG on consumer hardware</strong>
</p>

<p align="center">
  <a href="#-quick-start">Quick Start</a> •
  <a href="#-features">Features</a> •
  <a href="#-benchmarks">Benchmarks</a> •
  <a href="#-documentation">Documentation</a> •
  <a href="#-paper">Paper</a>
</p>

---

## 🌟 What is Leann?

**Leann** revolutionizes Retrieval-Augmented Generation (RAG) by eliminating the storage bottleneck of traditional vector databases. Instead of pre-computing and storing billions of embeddings, Leann dynamically computes embeddings at query time using highly optimized graph-based search algorithms.

### 🎯 Why Leann?

Traditional RAG systems face a fundamental trade-off:
- **💾 Storage**: Storing embeddings for millions of documents requires massive disk space
- **🔄 Freshness**: Pre-computed embeddings become stale when documents change
- **💰 Cost**: Vector databases are expensive to scale

**Leann solves this by:**
- ✅ **Zero embedding storage** - Only the graph structure is persisted
- ✅ **Real-time computation** - Embeddings computed on demand with millisecond latency
- ✅ **Memory efficient** - Runs on consumer hardware (8 GB RAM)
- ✅ **Always fresh** - No stale embeddings, ever

## 🚀 Quick Start

### Installation

```bash
git clone https://github.com/yichuan520030910320/Power-RAG.git leann
cd leann
uv sync
```

### 30-Second Example

```python
from leann.api import LeannBuilder, LeannSearcher

# 1. Build index (no embeddings stored!)
builder = LeannBuilder(backend_name="diskann")
builder.add_text("Python is a powerful programming language")
builder.add_text("Machine learning transforms industries")
builder.add_text("Neural networks process complex data")
builder.build_index("knowledge.leann")

# 2. Search with real-time embeddings
searcher = LeannSearcher("knowledge.leann")
results = searcher.search("programming languages", top_k=2)

for result in results:
    print(f"Score: {result['score']:.3f} - {result['text']}")
```

### Run the Demo

```bash
uv run examples/document_search.py
```

**PDF RAG Demo (using LlamaIndex for document parsing and Leann for indexing/search)**

This demo showcases how to build a RAG system for PDF documents using Leann.
1. Place your PDF files (and other supported formats such as .docx, .pptx, and .xlsx) into the `examples/data/` directory.
2. Ensure you have an `OPENAI_API_KEY` set in your environment variables or in a `.env` file so the LLM can respond.

```bash
uv run examples/main_cli_example.py
```

## ✨ Features

### 🔥 Core Features
- **📊 Multiple Distance Functions**: L2, Cosine, MIPS (Maximum Inner Product Search)
- **🏗️ Pluggable Backends**: DiskANN and HNSW/FAISS behind a unified API
- **🔄 Real-time Embeddings**: Dynamic computation using optimized ZMQ servers
- **📈 Scalable Architecture**: Handles millions of documents on consumer hardware
- **🎯 Graph Pruning**: Advanced techniques for memory-efficient search

### 🛠️ Technical Highlights
- **Zero-copy operations** for maximum performance
- **SIMD-optimized** distance computations (AVX2/AVX512)
- **Async embedding pipeline** with batched processing
- **Memory-mapped indices** for fast startup
- **Recompute mode** for highest-accuracy scenarios

### 🎨 Developer Experience
- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment
- **Rich debugging tools** - Built-in performance profiling

## 📊 Benchmarks

### Memory Usage Comparison

| System | 1M Documents | 10M Documents | 100M Documents |
|--------|--------------|---------------|----------------|
| Traditional Vector DB | 3.1 GB | 31 GB | 310 GB |
| **Leann** | **180 MB** | **1.2 GB** | **8.4 GB** |
| **Reduction** | **94.2%** | **96.1%** | **97.3%** |

### Query Performance

| Backend | Index Size | Query Time | Recall@10 |
|---------|------------|------------|-----------|
| DiskANN | 1M docs | 12 ms | 0.95 |
| DiskANN + Recompute | 1M docs | 145 ms | 0.98 |
| HNSW | 1M docs | 8 ms | 0.93 |

*Benchmarks run on AMD Ryzen 7 with 32 GB RAM*

## 🏗️ Architecture

```
┌─────────────────┐    ┌──────────────────┐    ┌─────────────────┐
│   Query Text    │───▶│    Embedding     │───▶│   Graph-based   │
│                 │    │   Computation    │    │     Search      │
└─────────────────┘    └──────────────────┘    └─────────────────┘
                              │                        │
                              ▼                        ▼
                       ┌──────────────┐         ┌──────────────┐
                       │  ZMQ Server  │         │ Pruned Graph │
                       │   (Cached)   │         │    Index     │
                       └──────────────┘         └──────────────┘
```

### Key Components

1. **🧠 Embedding Engine**: Real-time transformer inference with caching
2. **📊 Graph Index**: Memory-efficient navigation structures
3. **🔄 Search Coordinator**: Orchestrates embedding computation and graph search
4. **⚡ Backend Adapters**: Pluggable algorithm implementations
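
The recompute path above is driven entirely through search-time keyword arguments. A minimal sketch, assuming the parameter names used by `examples/document_search.py` later in this commit (note that `recompute_beighbor_embeddings` is spelled this way throughout the codebase):

```python
from leann.api import LeannSearcher

searcher = LeannSearcher("knowledge.leann")
results = searcher.search(
    "programming languages",
    top_k=2,
    recompute_beighbor_embeddings=True,  # fetch fresh embeddings over ZMQ
    zmq_port=5555,                       # port the embedding server listens on
    embedding_model="sentence-transformers/all-mpnet-base-v2",
)
```

If the embedding server cannot be reached, the DiskANN backend falls back to PQ-based distance computation rather than failing the query.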

## 🎓 Supported Models & Backends

### 🤖 Embedding Models
- **sentence-transformers/all-mpnet-base-v2** (default)
- **sentence-transformers/all-MiniLM-L6-v2** (lightweight)
- Any HuggingFace sentence-transformer model
- Custom model support via API

### 🔧 Search Backends
- **DiskANN**: Microsoft's billion-scale ANN algorithm
- **HNSW**: Hierarchical Navigable Small World graphs
- **Coming soon**: ScaNN, Faiss-IVF, NGT

### 📏 Distance Functions
- **L2**: Euclidean distance for precise similarity
- **Cosine**: Angular similarity for normalized vectors
- **MIPS**: Maximum Inner Product Search for recommendation systems

A metric is chosen per index at build time; see the sketch below.
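A minimal sketch of selecting a metric, assuming the `distance_metric` strings ("mips", "l2", "cosine") accepted by the DiskANN backend's `METRIC_MAP` later in this commit ("mips" is the default when unspecified):

```python
from leann.api import LeannBuilder

builder = LeannBuilder(backend_name="diskann", distance_metric="cosine")
builder.add_text("Normalized vectors pair naturally with cosine similarity")
builder.build_index("cosine_knowledge.leann")
```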

## 🔬 Paper

If you find Leann useful, please cite:

**[LEANN: A Low-Storage Vector Index](https://arxiv.org/abs/2506.08276)**

```bibtex
@misc{wang2025leannlowstoragevectorindex,
      title={LEANN: A Low-Storage Vector Index},
      author={Yichuan Wang and Shu Liu and Zhifei Li and Yongji Wu and Ziming Mao and Yilong Zhao and Xiao Yan and Zhiying Xu and Yang Zhou and Ion Stoica and Sewon Min and Matei Zaharia and Joseph E. Gonzalez},
      year={2025},
      eprint={2506.08276},
      archivePrefix={arXiv},
      primaryClass={cs.DB},
      url={https://arxiv.org/abs/2506.08276},
}
```

## 🌍 Use Cases

### 💼 Enterprise RAG
```python
# Handle millions of documents with limited resources
builder = LeannBuilder(
    backend_name="diskann",
    distance_metric="cosine",
    graph_degree=64,
    memory_budget="4GB"
)
```

### 🔬 Research & Experimentation
```python
# Quick prototyping with different algorithms
for backend in ["diskann", "hnsw"]:
    searcher = LeannSearcher(index_path, backend=backend)
    evaluate_recall(searcher, queries, ground_truth)
```

### 🚀 Real-time Applications
```python
# Sub-second response times
chat = LeannChat("knowledge.leann")
response = chat.ask("What is quantum computing?")
# Returns in <100ms with recompute mode
```

## 🤝 Contributing

We welcome contributions! Leann is built by the community, for the community.

### Ways to Contribute
- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

### Development Setup
```bash
git clone https://github.com/yourname/leann
cd leann
uv sync --dev
uv run pytest tests/
```

### Quick Tests
```bash
# Sanity check all distance functions
uv run python tests/sanity_checks/test_distance_functions.py

# Verify L2 implementation
uv run python tests/sanity_checks/test_l2_verification.py
```

## 📈 Roadmap

### 🎯 Q1 2024
- [x] DiskANN backend with MIPS/L2/Cosine support
- [x] HNSW backend integration
- [x] Real-time embedding pipeline
- [x] Memory-efficient graph pruning

### 🚀 Q2 2024
- [ ] Distributed search across multiple nodes
- [ ] ScaNN backend support
- [ ] Advanced caching strategies
- [ ] Kubernetes deployment guides

### 🌟 Q3 2024
- [ ] GPU-accelerated embedding computation
- [ ] Approximate distance functions
- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search

## 💬 Community

Join our growing community of researchers and engineers!

- 🐦 **Twitter**: [@LeannAI](https://twitter.com/LeannAI)
- 💬 **Discord**: [Join our server](https://discord.gg/leann)
- 📧 **Email**: leann@yourcompany.com
- 🐙 **GitHub Discussions**: [Ask questions here](https://github.com/yourname/leann/discussions)

## 📄 License

MIT License - see [LICENSE](LICENSE) for details.

## 🙏 Acknowledgments

- **Microsoft Research** for the DiskANN algorithm
- **Meta AI** for FAISS and optimization insights
- **HuggingFace** for the transformer ecosystem
- **Our amazing contributors** who make this possible

---

<p align="center">
  <strong>⭐ Star us on GitHub if Leann is useful for your research or applications!</strong>
</p>

<p align="center">
  Made with ❤️ by the Leann team
</p>
demo.ipynb (Normal file, 248 lines)
@@ -0,0 +1,248 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: LeannBuilder initialized with 'diskann' backend.\n",
      "INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batches: 100%|██████████| 1/1 [00:00<00:00, 77.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: Building DiskANN index for 6 vectors with metric Metric.INNER_PRODUCT...\n",
      "Using Inner Product search, so need to pre-process base data into temp file. Please ensure there is additional (n*(d+1)*4) bytes for storing pre-processed base vectors, apart from the interim indices created by DiskANN and the final index.\n",
      "Pre-processing base file by adding extra coordinate\n",
      "✅ DiskANN index built successfully at 'knowledge'\n",
      "Writing bin: knowledge_disk.index_max_base_norm.bin\n",
      "bin: #pts = 1, #dims = 1, size = 12B\n",
      "Finished writing bin.\n",
      "Time for preprocessing data for inner product: 0.000165 seconds\n",
      "Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
      "Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
      "Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
      "Metadata: #pts = 1, #dims = 1...\n",
      "done.\n",
      "max_norm_of_base: 1\n",
      "! Using prepped_base file at knowledge_prepped_base.bin\n",
      "Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
      "getting bin metadata\n",
      "Time for getting bin metadata: 0.000008 seconds\n",
      "Compressing 769-dimensional data into 512 bytes per vector.\n",
      "Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
      "Training data with 6 samples loaded.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 256, #dims = 769...\n",
      "done.\n",
      "PQ pivot file exists. Not generating again\n",
      "Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 4, #dims = 1...\n",
      "done.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 256, #dims = 769...\n",
      "done.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 769, #dims = 1...\n",
      "done.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 513, #dims = 1...\n",
      "done.\n",
      "Loaded PQ pivot information\n",
      "Processing points [0, 6)...done.\n",
      "Time for generating quantized data: 0.023918 seconds\n",
      "Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
      "L2: Using AVX2 distance computation DistanceL2Float\n",
      "Passed, empty search_params while creating index config\n",
      "Using only first 6 from file.. \n",
      "Starting index build with 6 points... \n",
      "0% of index build completed.Starting final cleanup..done. Link time: 9e-05s\n",
      "Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
      "Not saving tags as they are not enabled.\n",
      "Time taken for save: 0.000178s.\n",
      "Time for building merged vamana index: 0.000579 seconds\n",
      "Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
      "Vamana index file size=168\n",
      "Opened: knowledge_disk.index, cache_size: 67108864\n",
      "medoid: 0B\n",
      "max_node_len: 3100B\n",
      "nnodes_per_sector: 1B\n",
      "# sectors: 6\n",
      "Sector #0written\n",
      "Finished writing 28672B\n",
      "Writing bin: knowledge_disk.index\n",
      "bin: #pts = 9, #dims = 1, size = 80B\n",
      "Finished writing bin.\n",
      "Output disk index file written to knowledge_disk.index\n",
      "Finished writing 28672B\n",
      "Time for generating disk layout: 0.043488 seconds\n",
      "Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
      "Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
      "Wrote 1 points to sample file: knowledge_sample_data.bin\n",
      "Indexing time: 0.0684344\n",
      "INFO: Leann metadata saved to knowledge.leann.meta.json\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Opened file : knowledge_disk.index\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
      "L2: Using AVX2 distance computation DistanceL2Float\n",
      "L2: Using AVX2 distance computation DistanceL2Float\n",
      "Before index load\n",
      "✅ DiskANN index loaded successfully.\n",
      "INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
      "Reading bin file knowledge_pq_compressed.bin ...\n",
      "Opening bin file knowledge_pq_compressed.bin... \n",
      "Metadata: #pts = 6, #dims = 512...\n",
      "done.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 4, #dims = 1...\n",
      "done.\n",
      "Offsets: 4096 791560 794644 796704\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 256, #dims = 769...\n",
      "done.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 769, #dims = 1...\n",
      "done.\n",
      "Reading bin file knowledge_pq_pivots.bin ...\n",
      "Opening bin file knowledge_pq_pivots.bin... \n",
      "Metadata: #pts = 513, #dims = 1...\n",
      "done.\n",
      "Loaded PQ Pivots: #ctrs: 256, #dims: 769, #chunks: 512\n",
      "Loaded PQ centroids and in-memory compressed vectors. #points: 6 #dim: 769 #aligned_dim: 776 #chunks: 512\n",
      "Loading index metadata from knowledge_disk.index\n",
      "Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
      "Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
      "Setting up thread-specific contexts for nthreads: 8\n",
      "allocating ctx: 0x78348f4de000 to thread-id:132170359560000\n",
      "allocating ctx: 0x78348f4cd000 to thread-id:132158431693760\n",
      "allocating ctx: 0x78348f4bc000 to thread-id:132158442179392\n",
      "allocating ctx: 0x78348f4ab000 to thread-id:132158421208128\n",
      "allocating ctx: 0x78348f49a000 to thread-id:132158452665024\n",
      "allocating ctx: 0x78348f489000 to thread-id:132158389751232\n",
      "allocating ctx: 0x78348f478000 to thread-id:132158410722496\n",
      "allocating ctx: 0x78348f467000 to thread-id:132158400236864\n",
      "Loading centroid data from medoids vector data of 1 medoid(s)\n",
      "Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
      "Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
      "Metadata: #pts = 1, #dims = 1...\n",
      "done.\n",
      "Setting re-scaling factor of base vectors to 1\n",
      "load_from_separate_paths done.\n",
      "Reading (with alignment) bin file knowledge_sample_data.bin ...Metadata: #pts = 1, #dims = 769, aligned_dim = 776... allocating aligned memory of 3104 bytes... done. Copying data to mem_aligned buffer... done.\n",
      "reserve ratio: 1\n",
      "Graph traversal completed, hops: 3\n",
      "Loading the cache list into memory....done.\n",
      "After index load\n",
      "Clearing scratch\n",
      "INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Batches: 100%|██████████| 1/1 [00:00<00:00, 92.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Score: -0.481 - C++ is a powerful programming language\n",
      "Score: -1.049 - Java is a powerful programming language\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "reserve ratio: 1\n",
      "Graph traversal completed, hops: 3\n"
     ]
    }
   ],
   "source": [
    "from leann.api import LeannBuilder, LeannSearcher\n",
    "import leann_backend_diskann\n",
    "# 1. Build index (no embeddings stored!)\n",
    "builder = LeannBuilder(backend_name=\"diskann\")\n",
    "builder.add_text(\"Python is a powerful programming language\")\n",
    "builder.add_text(\"Machine learning transforms industries\")\n",
    "builder.add_text(\"Neural networks process complex data\")\n",
    "builder.add_text(\"Java is a powerful programming language\")\n",
    "builder.add_text(\"C++ is a powerful programming language\")\n",
    "builder.add_text(\"C# is a powerful programming language\")\n",
    "builder.build_index(\"knowledge.leann\")\n",
    "\n",
    "# 2. Search with real-time embeddings\n",
    "searcher = LeannSearcher(\"knowledge.leann\")\n",
    "results = searcher.search(\"C++ programming languages\", top_k=2)\n",
    "\n",
    "for result in results:\n",
    "    print(f\"Score: {result['score']:.3f} - {result['text']}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
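One detail worth pulling out of the notebook log above: the `(n*(d+1)*4)` pre-processing estimate. Inner-product (MIPS) search is reduced to L2 by appending one extra coordinate to each vector, which is why the 768-dimensional all-mpnet-base-v2 embeddings appear as 769-dimensional throughout the log. A quick arithmetic check (an illustration, assuming the two-uint32 header used by the backend's `_write_vectors_to_bin`):

```python
# 6 vectors from all-mpnet-base-v2 (768-d) gain one extra MIPS coordinate.
n, d = 6, 768
payload = n * (d + 1) * 4  # float32 payload: 18456 bytes
header = 2 * 4             # two uint32s: num_vectors, dim
print(payload + header)    # 18464, matching "size: 18464" in the log above
```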
examples/data/2506.08276v1.pdf (Normal file, 7905 lines)
File diff suppressed because it is too large.
examples/document_search.py (Normal file, 146 lines)
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
"""
Document search demo with recompute mode
"""

import os
from pathlib import Path
import shutil
import time

# Import backend packages to trigger plugin registration
try:
    import leann_backend_diskann
    import leann_backend_hnsw
    print("INFO: Backend packages imported successfully.")
except ImportError as e:
    print(f"WARNING: Could not import backend packages. Error: {e}")

# Import upper-level API from leann-core
from leann.api import LeannBuilder, LeannSearcher, LeannChat


def load_sample_documents():
    """Create sample documents for demonstration"""
    docs = [
        {"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
        {"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
        {"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
    ]
    return docs


def main():
    print("==========================================================")
    print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
    print("==========================================================")

    INDEX_DIR = Path("./test_indices")
    INDEX_PATH = str(INDEX_DIR / "documents.diskann")
    BACKEND_TO_TEST = "diskann"

    if INDEX_DIR.exists():
        print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
        shutil.rmtree(INDEX_DIR)

    # --- 1. Build index ---
    print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")

    builder = LeannBuilder(
        backend_name=BACKEND_TO_TEST,
        graph_degree=32,
        complexity=64
    )

    documents = load_sample_documents()
    print(f"Loaded {len(documents)} sample documents.")
    for doc in documents:
        builder.add_text(doc["content"], metadata={"title": doc["title"]})

    builder.build_index(INDEX_PATH)
    print("\nIndex built!")

    # --- 2. Basic search demo ---
    print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
    searcher = LeannSearcher(index_path=INDEX_PATH)

    query = "What is machine learning?"
    print(f"\nQuery: '{query}'")

    print("\n--- Basic search mode (PQ computation) ---")
    start_time = time.time()
    results = searcher.search(query, top_k=2)
    basic_time = time.time() - start_time

    print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
    print(">>> Basic search results <<<")
    for i, res in enumerate(results, 1):
        print(f"  {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")

    # --- 3. Recompute search demo ---
    print("\n[PHASE 3] Recompute search using embedding server...")

    print("\n--- Recompute search mode (get real embeddings via network) ---")

    # Configure recompute parameters
    recompute_params = {
        "recompute_beighbor_embeddings": True,  # Enable network recomputation
        "USE_DEFERRED_FETCH": False,            # Don't use deferred fetch
        "skip_search_reorder": True,            # Skip search reordering
        "dedup_node_dis": True,                 # Enable node distance deduplication
        "prune_ratio": 0.1,                     # Pruning ratio 10%
        "batch_recompute": False,               # Don't use batch recomputation
        "global_pruning": False,                # Don't use global pruning
        "zmq_port": 5555,                       # ZMQ port
        "embedding_model": "sentence-transformers/all-mpnet-base-v2"
    }

    print("Recompute parameter configuration:")
    for key, value in recompute_params.items():
        print(f"  {key}: {value}")

    print("\n🔄 Executing Recompute search...")
    try:
        start_time = time.time()
        recompute_results = searcher.search(query, top_k=2, **recompute_params)
        recompute_time = time.time() - start_time

        print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
        print(">>> Recompute search results <<<")
        for i, res in enumerate(recompute_results, 1):
            print(f"  {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")

        # Compare results
        print("\n--- Result comparison ---")
        print(f"Basic search time: {basic_time:.3f} seconds")
        print(f"Recompute time: {recompute_time:.3f} seconds")

        print("\nBasic search vs Recompute results:")
        for i in range(min(len(results), len(recompute_results))):
            basic_score = results[i]['score']
            recompute_score = recompute_results[i]['score']
            score_diff = abs(basic_score - recompute_score)
            print(f"  Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")

        if recompute_time > basic_time:
            print("✅ Recompute mode working correctly (more accurate but slower)")
        else:
            print("ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")

    except Exception as e:
        print(f"❌ Recompute search failed: {e}")
        print("This usually indicates an embedding server connection issue")

    # --- 4. Chat demo ---
    print("\n[PHASE 4] Starting chat session...")
    chat = LeannChat(index_path=INDEX_PATH)
    chat_response = chat.ask(query)
    print(f"You: {query}")
    print(f"Leann: {chat_response}")

    print("\n==========================================================")
    print("✅ Demo finished successfully!")
    print("==========================================================")


if __name__ == "__main__":
    main()
examples/main_cli_example.py (Normal file, 76 lines)
@@ -0,0 +1,76 @@
from llama_index.core import SimpleDirectoryReader, Settings
from llama_index.core.readers.base import BaseReader
from llama_index.node_parser.docling import DoclingNodeParser
from llama_index.readers.docling import DoclingReader
from docling_core.transforms.chunker.hybrid_chunker import HybridChunker
import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_diskann  # Import to ensure backend registration
import shutil
from pathlib import Path

dotenv.load_dotenv()

reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)
file_extractor: dict[str, BaseReader] = {
    ".docx": reader,
    ".pptx": reader,
    ".pdf": reader,
    ".xlsx": reader,
}
node_parser = DoclingNodeParser(
    chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=10240)
)

documents = SimpleDirectoryReader(
    "examples/data",
    recursive=True,
    file_extractor=file_extractor,
    encoding="utf-8",
    required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
).load_data(show_progress=True)

# Extract text from documents and prepare for Leann
all_texts = []
for doc in documents:
    # DoclingNodeParser returns Node objects, which have a text attribute
    nodes = node_parser.get_nodes_from_documents([doc])
    for node in nodes:
        all_texts.append(node.text)

INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")

if INDEX_DIR.exists():
    print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
    shutil.rmtree(INDEX_DIR)

print("\n[PHASE 1] Building Leann index...")

builder = LeannBuilder(
    backend_name="diskann",
    embedding_model="sentence-transformers/all-mpnet-base-v2",  # Using a common sentence transformer model
    graph_degree=32,
    complexity=64
)

print(f"Loaded {len(all_texts)} text chunks from documents.")
for chunk_text in all_texts:
    builder.add_text(chunk_text)

builder.build_index(INDEX_PATH)
print(f"\nLeann index built at {INDEX_PATH}!")


async def main():
    print("\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=INDEX_PATH)

    query = "Based on the paper, what are the two main techniques LEANN uses to achieve low storage overhead and high retrieval accuracy?"
    print(f"You: {query}")
    chat_response = chat.ask(query, recompute_beighbor_embeddings=True)
    print(f"Leann: {chat_response}")


if __name__ == "__main__":
    asyncio.run(main())
examples/simple_demo.py (Normal file, 81 lines)
@@ -0,0 +1,81 @@
"""
Simple demo showing basic leann usage
Run: uv run python examples/simple_demo.py
"""

from leann import LeannBuilder, LeannSearcher, LeannChat


def main():
    print("=== Leann Simple Demo ===")
    print()

    # Sample knowledge base
    chunks = [
        "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
        "Deep learning uses neural networks with multiple layers to process data and make decisions.",
        "Natural language processing helps computers understand and generate human language.",
        "Computer vision enables machines to interpret and understand visual information from images and videos.",
        "Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
        "Data science combines statistics, programming, and domain expertise to extract insights from data.",
        "Big data refers to extremely large datasets that require special tools and techniques to process.",
        "Cloud computing provides on-demand access to computing resources over the internet.",
    ]

    print("1. Building index (no embeddings stored)...")
    builder = LeannBuilder(
        embedding_model="sentence-transformers/all-mpnet-base-v2",
        prune_ratio=0.7,  # Keep 30% of connections
    )
    builder.add_chunks(chunks)
    builder.build_index("demo_knowledge.leann")
    print()

    print("2. Searching with real-time embeddings...")
    searcher = LeannSearcher("demo_knowledge.leann")

    queries = [
        "What is machine learning?",
        "How does neural network work?",
        "Tell me about data processing",
    ]

    for query in queries:
        print(f"Query: {query}")
        results = searcher.search(query, top_k=2)

        for i, result in enumerate(results, 1):
            print(f"  {i}. Score: {result.score:.3f}")
            print(f"     Text: {result.text[:100]}...")
        print()

    print("3. Memory stats:")
    stats = searcher.get_memory_stats()
    print(f"   Cache size: {stats.embedding_cache_size}")
    print(f"   Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
    print(f"   Total chunks: {stats.total_chunks}")
    print()

    print("4. Interactive chat demo:")
    print("   (Note: Requires OpenAI API key for real responses)")

    chat = LeannChat("demo_knowledge.leann")

    # Demo questions
    demo_questions: list[str] = [
        "What is the difference between machine learning and deep learning?",
        "How is data science related to big data?",
    ]

    for question in demo_questions:
        print(f"  Q: {question}")
        response = chat.ask(question)
        print(f"  A: {response}")
        print()

    print("Demo completed! Try running:")
    print("  uv run python examples/document_search.py")


if __name__ == "__main__":
    main()
knowledge.leann.meta.json (Normal file, 32 lines)
@@ -0,0 +1,32 @@
{
  "version": "0.1.0",
  "backend_name": "diskann",
  "embedding_model": "sentence-transformers/all-mpnet-base-v2",
  "num_chunks": 6,
  "chunks": [
    {
      "text": "Python is a powerful programming language",
      "metadata": {}
    },
    {
      "text": "Machine learning transforms industries",
      "metadata": {}
    },
    {
      "text": "Neural networks process complex data",
      "metadata": {}
    },
    {
      "text": "Java is a powerful programming language",
      "metadata": {}
    },
    {
      "text": "C++ is a powerful programming language",
      "metadata": {}
    },
    {
      "text": "C# is a powerful programming language",
      "metadata": {}
    }
  ]
}
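This sidecar metadata is what `DiskannBackend.searcher` reads to recover the embedding model and infer the vector dimension (see `diskann_backend.py` below). A minimal loading sketch, assuming a local `knowledge.leann` index:

```python
import json
from pathlib import Path

index_path = Path("knowledge.leann")
# Same lookup rule as DiskannBackend.searcher: "<index name>.meta.json" beside the index.
meta_path = index_path.parent / f"{index_path.name}.meta.json"
meta = json.loads(meta_path.read_text())
print(meta["backend_name"], meta["embedding_model"], meta["num_chunks"])
```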
packages/leann-backend-diskann/CMakeLists.txt (Normal file, 8 lines)
@@ -0,0 +1,8 @@
# packages/leann-backend-diskann/CMakeLists.txt (final simplified version)

cmake_minimum_required(VERSION 3.20)
project(leann_backend_diskann_wrapper)

# Tell CMake to descend into the DiskANN submodule and run its own CMakeLists.txt.
# DiskANN handles everything itself, including building the Python bindings.
add_subdirectory(src/third_party/DiskANN)
packages/leann-backend-diskann/leann_backend_diskann/__init__.py (Normal file, 7 lines; filename not shown in the diff, inferred from the module path used by the embedding-server launcher below)
@@ -0,0 +1,7 @@
print("Initializing leann-backend-diskann...")

try:
    from .diskann_backend import DiskannBackend
    print("INFO: DiskANN backend loaded successfully")
except ImportError as e:
    print(f"WARNING: Could not import DiskANN backend: {e}")
packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py (Normal file, 299 lines; filename not shown in the diff, inferred from the import in __init__.py)
@@ -0,0 +1,299 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Any, Dict
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys

from leann.registry import register_backend
from leann.interface import (
    LeannBackendFactoryInterface,
    LeannBackendBuilderInterface,
    LeannBackendSearcherInterface
)
from . import _diskannpy as diskannpy

METRIC_MAP = {
    "mips": diskannpy.Metric.INNER_PRODUCT,
    "l2": diskannpy.Metric.L2,
    "cosine": diskannpy.Metric.COSINE,
}


@contextlib.contextmanager
def chdir(path):
    original_dir = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original_dir)


def _write_vectors_to_bin(data: np.ndarray, file_path: str):
    num_vectors, dim = data.shape
    with open(file_path, 'wb') as f:
        f.write(struct.pack('I', num_vectors))
        f.write(struct.pack('I', dim))
        f.write(data.tobytes())
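
# --- Editor's sketch (not part of this commit) ------------------------------
# The inverse of _write_vectors_to_bin, for reference: the file is two
# little-endian uint32s (num_vectors, dim) followed by the row-major
# float32 payload. Uses the struct/numpy imports already at the top of
# this module.
def _read_vectors_from_bin(file_path: str) -> np.ndarray:
    with open(file_path, 'rb') as f:
        num_vectors, dim = struct.unpack('II', f.read(8))
        payload = f.read(num_vectors * dim * 4)
    return np.frombuffer(payload, dtype=np.float32).reshape(num_vectors, dim)
# ----------------------------------------------------------------------------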

def _check_port(port: int) -> bool:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0


class EmbeddingServerManager:
    def __init__(self):
        self.server_process = None
        self.server_port = None
        atexit.register(self.stop_server)

    def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
        if self.server_process and self.server_process.poll() is None:
            print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
            return True

        # Check whether the port is already occupied by another, unrelated process
        if _check_port(port):
            print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
            return True

        print("INFO: Starting session-level embedding server as a background process...")

        try:
            command = [
                sys.executable,
                "-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
                "--zmq-port", str(port),
                "--model-name", model_name
            ]
            project_root = Path(__file__).parent.parent.parent.parent
            print(f"INFO: Running command from project root: {project_root}")
            self.server_process = subprocess.Popen(
                command,
                cwd=project_root,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding='utf-8'
            )
            self.server_port = port
            print(f"INFO: Server process started with PID: {self.server_process.pid}")

            max_wait, wait_interval = 30, 0.5
            for _ in range(int(max_wait / wait_interval)):
                if _check_port(port):
                    print("✅ Embedding server is up and ready for this session.")
                    log_thread = threading.Thread(target=self._log_monitor, daemon=True)
                    log_thread.start()
                    return True
                if self.server_process.poll() is not None:
                    print("❌ ERROR: Server process terminated unexpectedly during startup.")
                    self._log_monitor()
                    return False
                time.sleep(wait_interval)

            print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
            self.stop_server()
            return False

        except Exception as e:
            print(f"❌ ERROR: Failed to start embedding server process: {e}")
            return False

    def _log_monitor(self):
        if not self.server_process:
            return
        try:
            if self.server_process.stdout:
                for line in iter(self.server_process.stdout.readline, ''):
                    print(f"[EmbeddingServer LOG]: {line.strip()}")
                self.server_process.stdout.close()
            if self.server_process.stderr:
                for line in iter(self.server_process.stderr.readline, ''):
                    print(f"[EmbeddingServer ERROR]: {line.strip()}")
                self.server_process.stderr.close()
        except Exception as e:
            print(f"Log monitor error: {e}")

    def stop_server(self):
        if self.server_process and self.server_process.poll() is None:
            print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
            self.server_process.terminate()
            try:
                self.server_process.wait(timeout=5)
                print("INFO: Server process terminated.")
            except subprocess.TimeoutExpired:
                print("WARNING: Server process did not terminate gracefully, killing it.")
                self.server_process.kill()
        self.server_process = None
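
# --- Editor's note (not part of this commit) --------------------------------
# How DiskannSearcher (below) drives this manager: one manager per searcher
# instance, started lazily on the first recompute-mode search and torn down
# by the atexit hook registered in __init__. In sketch form:
#
#   manager = EmbeddingServerManager()
#   ok = manager.start_server(port=5555,
#                             model_name="sentence-transformers/all-mpnet-base-v2")
#   if not ok:
#       ...fall back to PQ-based distances instead of fresh embeddings...
# ----------------------------------------------------------------------------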
|
||||
@register_backend("diskann")
|
||||
class DiskannBackend(LeannBackendFactoryInterface):
|
||||
@staticmethod
|
||||
def builder(**kwargs) -> LeannBackendBuilderInterface:
|
||||
return DiskannBuilder(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
|
||||
path = Path(index_path)
|
||||
meta_path = path.parent / f"{path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer(meta.get("embedding_model"))
|
||||
dimensions = model.get_sentence_embedding_dimension()
|
||||
kwargs['dimensions'] = dimensions
|
||||
except ImportError:
|
||||
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
|
||||
|
||||
return DiskannSearcher(index_path, **kwargs)
|
||||
|
||||
class DiskannBuilder(LeannBackendBuilderInterface):
    def __init__(self, **kwargs):
        self.build_params = kwargs

    def build(self, data: np.ndarray, index_path: str, **kwargs):
        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem

        index_dir.mkdir(parents=True, exist_ok=True)

        if data.dtype != np.float32:
            data = data.astype(np.float32)
        if not data.flags['C_CONTIGUOUS']:
            data = np.ascontiguousarray(data)

        data_filename = f"{index_prefix}_data.bin"
        _write_vectors_to_bin(data, index_dir / data_filename)

        build_kwargs = {**self.build_params, **kwargs}
        metric_str = build_kwargs.get("distance_metric", "mips").lower()
        metric_enum = METRIC_MAP.get(metric_str)
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{metric_str}'.")

        complexity = build_kwargs.get("complexity", 64)
        graph_degree = build_kwargs.get("graph_degree", 32)
        final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
        indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
        num_threads = build_kwargs.get("num_threads", 8)
        pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
        codebook_prefix = ""

        print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")

        try:
            with chdir(index_dir):
                diskannpy.build_disk_float_index(
                    metric_enum,
                    data_filename,
                    index_prefix,
                    complexity,
                    graph_degree,
                    final_index_ram_limit,
                    indexing_ram_budget,
                    num_threads,
                    pq_disk_bytes,
                    codebook_prefix,
                )
            print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
        except Exception as e:
            print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
            raise
        finally:
            temp_data_file = index_dir / data_filename
            if temp_data_file.exists():
                os.remove(temp_data_file)

class DiskannSearcher(LeannBackendSearcherInterface):
    def __init__(self, index_path: str, **kwargs):
        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem
        metric_str = kwargs.get("distance_metric", "mips").lower()
        metric_enum = METRIC_MAP.get(metric_str)
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{metric_str}'.")

        num_threads = kwargs.get("num_threads", 8)
        num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
        dimensions = kwargs.get("dimensions")
        if not dimensions:
            raise ValueError("Vector dimension not provided to DiskannSearcher.")

        try:
            full_index_prefix = str(index_dir / index_prefix)
            self._index = diskannpy.StaticDiskFloatIndex(
                metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
            )
            self.num_threads = num_threads
            self.embedding_server_manager = EmbeddingServerManager()
            print("✅ DiskANN index loaded successfully.")
        except Exception as e:
            print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
            raise

    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
        complexity = kwargs.get("complexity", 100)
        beam_width = kwargs.get("beam_width", 4)

        USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
        skip_search_reorder = kwargs.get("skip_search_reorder", False)
        # NOTE: "beighbor" is the spelling used consistently for this kwarg
        # throughout the backend, so it is preserved here.
        recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
        dedup_node_dis = kwargs.get("dedup_node_dis", False)
        prune_ratio = kwargs.get("prune_ratio", 0.0)
        batch_recompute = kwargs.get("batch_recompute", False)
        global_pruning = kwargs.get("global_pruning", False)

        if recompute_beighbor_embeddings:
            print("INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
            zmq_port = kwargs.get("zmq_port", 5555)
            embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")

            if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
                print("WARNING: Failed to start embedding server, falling back to PQ computation")
                kwargs['recompute_beighbor_embeddings'] = False

        if query.dtype != np.float32:
            query = query.astype(np.float32)
        if query.ndim == 1:
            query = np.expand_dims(query, axis=0)

        try:
            labels, distances = self._index.batch_search(
                query,
                query.shape[0],
                top_k,
                complexity,
                beam_width,
                self.num_threads,
                USE_DEFERRED_FETCH,
                skip_search_reorder,
                recompute_beighbor_embeddings,
                dedup_node_dis,
                prune_ratio,
                batch_recompute,
                global_pruning,
            )
            return {"labels": labels, "distances": distances}
        except Exception as e:
            print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
            batch_size = query.shape[0]
            return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
                    "distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}

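    # A hypothetical call exercising the ZMQ recompute path (flag and kwarg
    # names exactly as consumed above; the values are illustrative):
    #
    #   searcher.search(
    #       query_vec, top_k=10, complexity=100, beam_width=4,
    #       recompute_beighbor_embeddings=True, zmq_port=5555,
    #       embedding_model="sentence-transformers/all-mpnet-base-v2",
    #   )
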
    def __del__(self):
        if hasattr(self, 'embedding_server_manager'):
            self.embedding_server_manager.stop_server()
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: embedding.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x65mbedding.proto\x12\x0eprotoembedding\"(\n\x14NodeEmbeddingRequest\x12\x10\n\x08node_ids\x18\x01 \x03(\r\"Y\n\x15NodeEmbeddingResponse\x12\x17\n\x0f\x65mbeddings_data\x18\x01 \x01(\x0c\x12\x12\n\ndimensions\x18\x02 \x03(\x05\x12\x13\n\x0bmissing_ids\x18\x03 \x03(\rb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'embedding_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  _NODEEMBEDDINGREQUEST._serialized_start=35
  _NODEEMBEDDINGREQUEST._serialized_end=75
  _NODEEMBEDDINGRESPONSE._serialized_start=77
  _NODEEMBEDDINGRESPONSE._serialized_end=166
# @@protoc_insertion_point(module_scope)
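# The messages above mirror embedding.proto: NodeEmbeddingRequest carries
# repeated uint32 node_ids, and NodeEmbeddingResponse carries raw float32
# bytes (embeddings_data) plus a [rows, cols] shape (dimensions) and any
# missing_ids. A decoding sketch (illustrative only):
#
#   resp = NodeEmbeddingResponse()
#   resp.ParseFromString(payload)  # payload: serialized response bytes
#   rows, cols = resp.dimensions
#   emb = np.frombuffer(resp.embeddings_data, dtype=np.float32).reshape(rows, cols)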
@@ -0,0 +1,397 @@
#!/usr/bin/env python3
"""
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
"""

import pickle
import argparse
import threading
import time

from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np

RED = "\033[91m"
RESET = "\033[0m"

# Simplified document store - replaces LazyPassages
class SimpleDocumentStore:
    """Simplified document store that supports arbitrary IDs."""
    def __init__(self, documents: dict = None):
        self.documents = documents or {}
        # Default demo documents
        self.default_docs = {
            0: "Python is a high-level, interpreted language known for simplicity.",
            1: "Machine learning builds systems that learn from data.",
            2: "Data structures like arrays, lists, and graphs organize data.",
        }

    def __getitem__(self, doc_id):
        doc_id = int(doc_id)

        # Prefer explicitly supplied documents
        if doc_id in self.documents:
            return {"text": self.documents[doc_id]}

        # Fall back to the default demo documents
        if doc_id in self.default_docs:
            return {"text": self.default_docs[doc_id]}

        # For any other ID, return a generic document
        fallback_docs = [
            "This is a general document about technology and programming concepts.",
            "This document discusses machine learning and artificial intelligence topics.",
            "This content covers data structures, algorithms, and computer science fundamentals.",
            "This is a document about software engineering and development practices.",
            "This content focuses on databases, data management, and information systems."
        ]

        # Pick a fallback document deterministically from the ID
        fallback_text = fallback_docs[doc_id % len(fallback_docs)]
        return {"text": f"[ID:{doc_id}] {fallback_text}"}

    def __len__(self):
        return len(self.documents) + len(self.default_docs)

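# Lookup sketch: any integer ID resolves to some text. For example,
# SimpleDocumentStore()[7] falls through to fallback_docs[7 % 5] and returns
# {"text": "[ID:7] This content covers data structures, algorithms, and computer science fundamentals."}
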
def create_embedding_server_thread(
    zmq_port=5555,
    model_name="sentence-transformers/all-mpnet-base-v2",
    max_batch_size=128,
):
    """
    Create and run the embedding server in the current thread.
    This function is designed to be called from a separate thread.
    """
    print(f"INFO: Initializing embedding server thread on port {zmq_port}")

    try:
        # Check whether the port is already in use
        import socket
        def check_port(port):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                return s.connect_ex(('localhost', port)) == 0

        if check_port(zmq_port):
            print(f"{RED}Port {zmq_port} is already in use{RESET}")
            return

        # Initialize the tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        import torch

        # Select a device
        mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
        cuda_available = torch.cuda.is_available()

        if cuda_available:
            device = torch.device("cuda")
            print("INFO: Using CUDA device")
        elif mps_available:
            device = torch.device("mps")
            print("INFO: Using MPS device (Apple Silicon)")
        else:
            device = torch.device("cpu")
            print("INFO: Using CPU device")

        # Load the model
        print(f"INFO: Loading model {model_name}")
        model = AutoModel.from_pretrained(model_name).to(device).eval()

        # Optimize the model (FP16 + torch.compile) where supported
        if cuda_available or mps_available:
            try:
                model = model.half()
                model = torch.compile(model)
                print(f"INFO: Using FP16 precision with model: {model_name}")
            except Exception as e:
                print(f"WARNING: Model optimization failed: {e}")

        # Default demo documents
        demo_documents = {
            0: "Python is a high-level, interpreted language known for simplicity.",
            1: "Machine learning builds systems that learn from data.",
            2: "Data structures like arrays, lists, and graphs organize data.",
        }

        passages = SimpleDocumentStore(demo_documents)
        print(f"INFO: Loaded {len(passages)} demo documents")

        class DeviceTimer:
            """Device-aware timer: CUDA events on GPU, wall-clock time elsewhere."""
            def __init__(self, name="", device=device):
                self.name = name
                self.device = device
                self.start_time = 0
                self.end_time = 0

                if cuda_available:
                    self.start_event = torch.cuda.Event(enable_timing=True)
                    self.end_event = torch.cuda.Event(enable_timing=True)
                else:
                    self.start_event = None
                    self.end_event = None

            @contextmanager
            def timing(self):
                self.start()
                yield
                self.end()

            def start(self):
                if cuda_available:
                    torch.cuda.synchronize()
                    self.start_event.record()
                else:
                    if self.device.type == "mps":
                        torch.mps.synchronize()
                    self.start_time = time.time()

            def end(self):
                if cuda_available:
                    self.end_event.record()
                    torch.cuda.synchronize()
                else:
                    if self.device.type == "mps":
                        torch.mps.synchronize()
                    self.end_time = time.time()

            def elapsed_time(self):
                if cuda_available:
                    return self.start_event.elapsed_time(self.end_event) / 1000.0
                else:
                    return self.end_time - self.start_time

            def print_elapsed(self):
                print(f"Time taken for {self.name}: {self.elapsed_time():.6f} seconds")

        def process_batch(texts_batch, ids_batch, missing_ids):
            """Tokenize, embed, and mean-pool one batch of texts."""
            batch_size = len(texts_batch)
            print(f"INFO: Processing batch of size {batch_size}")

            tokenize_timer = DeviceTimer("tokenization (batch)", device)
            to_device_timer = DeviceTimer("transfer to device (batch)", device)
            embed_timer = DeviceTimer("embedding (batch)", device)
            pool_timer = DeviceTimer("mean pooling (batch)", device)

            with tokenize_timer.timing():
                encoded_batch = tokenizer.batch_encode_plus(
                    texts_batch,
                    padding="max_length",
                    truncation=True,
                    max_length=256,
                    return_tensors="pt",
                    return_token_type_ids=False,
                )
            tokenize_timer.print_elapsed()

            seq_length = encoded_batch["input_ids"].size(1)
            print(f"Batch size: {batch_size}, Sequence length: {seq_length}")

            with to_device_timer.timing():
                enc = {k: v.to(device) for k, v in encoded_batch.items()}
            to_device_timer.print_elapsed()

            with torch.no_grad():
                with embed_timer.timing():
                    out = model(enc["input_ids"], enc["attention_mask"])
                embed_timer.print_elapsed()

                with pool_timer.timing():
                    # Attention-masked mean pooling over token embeddings
                    hidden_states = out.last_hidden_state if hasattr(out, "last_hidden_state") else out
                    mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                    batch_embeddings = sum_embeddings / sum_mask
                pool_timer.print_elapsed()

            return batch_embeddings.cpu().numpy()

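        # Note on the pooling above: for sequence i and hidden dimension d,
        #   pooled[i, d] = sum_t mask[i, t] * h[i, t, d] / max(sum_t mask[i, t], 1e-9)
        # so padding tokens never contribute to the sentence embedding.
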
        # ZMQ server main loop, served over a ROUTER socket so replies can be
        # addressed to individual REQ clients by identity frame
        context = zmq.Context()
        socket = context.socket(zmq.ROUTER)
        socket.bind(f"tcp://127.0.0.1:{zmq_port}")
        print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")

        # Configure socket timeouts
        socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 s receive timeout
        socket.setsockopt(zmq.SNDTIMEO, 300000)  # 300 s send timeout

        from . import embedding_pb2

        print("INFO: Embedding server ready to serve requests")

        while True:
            try:
                parts = socket.recv_multipart()

                # --- Robust message-format handling ---
                # Check len(parts) first to avoid an IndexError
                if len(parts) >= 3:
                    identity = parts[0]
                    # empty = parts[1]  # the empty delimiter frame is usually ignored
                    message = parts[2]
                elif len(parts) == 2:
                    # Also handle messages that arrive without an empty delimiter frame
                    identity = parts[0]
                    message = parts[1]
                else:
                    # On a malformed message, warn and skip it instead of crashing
                    print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
                    continue
                print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")

                e2e_start = time.time()
                lookup_timer = DeviceTimer("text lookup", device)

                # Parse the request
                req_proto = embedding_pb2.NodeEmbeddingRequest()
                req_proto.ParseFromString(message)
                node_ids = req_proto.node_ids
                print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")

                # Debug information
                if len(node_ids) > 0:
                    print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")

                # Look up passage texts
                texts = []
                missing_ids = []
                with lookup_timer.timing():
                    for nid in node_ids:
                        txtinfo = passages[nid]
                        txt = txtinfo["text"]
                        texts.append(txt)
                lookup_timer.print_elapsed()

                if missing_ids:
                    print(f"WARNING: Missing passages for IDs: {missing_ids}")

                # Split the request into sub-batches if needed
                total_size = len(texts)
                print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")

                all_embeddings = []

                if total_size > max_batch_size:
                    print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
                    for i in range(0, total_size, max_batch_size):
                        end_idx = min(i + max_batch_size, total_size)
                        print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")

                        chunk_texts = texts[i:end_idx]
                        chunk_ids = node_ids[i:end_idx]

                        embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
                        all_embeddings.append(embeddings_chunk)

                        if cuda_available:
                            torch.cuda.empty_cache()
                        elif device.type == "mps":
                            torch.mps.empty_cache()

                    hidden = np.vstack(all_embeddings)
                    print(f"INFO: Combined embeddings shape: {hidden.shape}")
                else:
                    hidden = process_batch(texts, node_ids, missing_ids)

                # Serialize the response
                ser_start = time.time()

                resp_proto = embedding_pb2.NodeEmbeddingResponse()
                hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
                resp_proto.embeddings_data = hidden_contiguous.tobytes()
                resp_proto.dimensions.append(hidden_contiguous.shape[0])
                resp_proto.dimensions.append(hidden_contiguous.shape[1])
                resp_proto.missing_ids.extend(missing_ids)

                response_data = resp_proto.SerializeToString()

                # ROUTER reply: echo the client identity, an empty delimiter, then the payload
                socket.send_multipart([identity, b'', response_data])

                ser_end = time.time()

                print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")

                if device.type == "cuda":
                    torch.cuda.synchronize()
                elif device.type == "mps":
                    torch.mps.synchronize()
                e2e_end = time.time()
                print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")

            except zmq.Again:
                print("INFO: ZMQ socket timeout, continuing to listen")
                # Nothing to recreate on a receive timeout; just keep listening
                continue

            except Exception as e:
                print(f"ERROR: Error in ZMQ server: {e}")
                try:
                    # Send an empty response to keep the client's request-reply state machine consistent
                    empty_resp = embedding_pb2.NodeEmbeddingResponse()
                    socket.send(empty_resp.SerializeToString())
                except:
                    # If sending fails, recreate the socket as a ROUTER so the
                    # main loop's multipart recv/send keeps working
                    socket.close()
                    socket = context.socket(zmq.ROUTER)
                    socket.bind(f"tcp://127.0.0.1:{zmq_port}")
                    socket.setsockopt(zmq.RCVTIMEO, 5000)
                    socket.setsockopt(zmq.SNDTIMEO, 300000)
                    print("INFO: ZMQ socket recreated after error")

    except Exception as e:
        print(f"ERROR: Failed to start embedding server: {e}")
        raise


# Keep the original create_embedding_server function unchanged; only the
# threaded version above was added
def create_embedding_server(
    domain="demo",
    load_passages=True,
    load_embeddings=False,
    use_fp16=True,
    use_int8=False,
    use_cuda_graphs=False,
    zmq_port=5555,
    max_batch_size=128,
    lazy_load_passages=False,
    model_name="sentence-transformers/all-mpnet-base-v2",
):
    """
    The original create_embedding_server signature is kept unchanged.
    This is the blocking version, intended for direct invocation; only
    zmq_port, model_name, and max_batch_size are currently forwarded.
    """
    create_embedding_server_thread(zmq_port, model_name, max_batch_size)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument("--domain", type=str, default="demo", help="Domain name")
    parser.add_argument("--load-passages", action="store_true", default=True)
    parser.add_argument("--load-embeddings", action="store_true", default=False)
    parser.add_argument("--use-fp16", action="store_true", default=False)
    parser.add_argument("--use-int8", action="store_true", default=False)
    parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
    parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
    parser.add_argument("--lazy-load-passages", action="store_true", default=True)
    parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                        help="Embedding model name")
    args = parser.parse_args()

    create_embedding_server(
        domain=args.domain,
        load_passages=args.load_passages,
        load_embeddings=args.load_embeddings,
        use_fp16=args.use_fp16,
        use_int8=args.use_int8,
        use_cuda_graphs=args.use_cuda_graphs,
        zmq_port=args.zmq_port,
        max_batch_size=args.max_batch_size,
        lazy_load_passages=args.lazy_load_passages,
        model_name=args.model_name,
    )
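
# A minimal REQ client sketch for this server (illustrative; assumes the
# server is already listening on port 5555 and that embedding_pb2 is
# importable from the same package):
#
#   import zmq
#   from leann_backend_diskann import embedding_pb2  # import path is an assumption
#
#   ctx = zmq.Context()
#   sock = ctx.socket(zmq.REQ)  # REQ pairs with the server's ROUTER socket
#   sock.connect("tcp://127.0.0.1:5555")
#   req = embedding_pb2.NodeEmbeddingRequest(node_ids=[0, 1, 2])
#   sock.send(req.SerializeToString())
#   reply = sock.recv()  # a serialized NodeEmbeddingResponse (decode as shown earlier)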
16
packages/leann-backend-diskann/pyproject.toml
Normal file
@@ -0,0 +1,16 @@
[build-system]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
build-backend = "scikit_build_core.build"

[project]
name = "leann-backend-diskann"
version = "0.1.0"
dependencies = ["leann-core==0.1.0", "numpy"]

[tool.scikit-build]
# Key: simplified CMake source path
cmake.source-dir = "third_party/DiskANN"
# Key: the Python package lives at the package root, so the paths match exactly
wheel.packages = ["leann_backend_diskann"]
# Use the default "redirect" editable mode
editable.mode = "redirect"
6
packages/leann-backend-diskann/third_party/DiskANN/.clang-format
vendored
Normal file
@@ -0,0 +1,6 @@
---
BasedOnStyle: Microsoft
---
Language: Cpp
SortIncludes: false
...
14
packages/leann-backend-diskann/third_party/DiskANN/.gitattributes
vendored
Normal file
@@ -0,0 +1,14 @@
# Set the default behavior, in case people don't have core.autocrlf set.
* text=auto

# Explicitly declare text files you want to always be normalized and converted
# to native line endings on checkout.
*.c text
*.h text

# Declare files that will always have CRLF line endings on checkout.
*.sln text eol=crlf

# Denote all files that are truly binary and should not be modified.
*.png binary
*.jpg binary
40
packages/leann-backend-diskann/third_party/DiskANN/.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,40 @@
---
name: Bug report
about: Bug reports help us improve! Thanks for submitting yours!
title: "[BUG] "
labels: bug
assignees: ''

---

## Expected Behavior
Tell us what should happen

## Actual Behavior
Tell us what happens instead

## Example Code
Please see [How to create a Minimal, Reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) for some guidance on creating the best possible example of the problem
```bash

```

## Dataset Description
Please tell us about the shape and datatype of your data, (e.g. 128 dimensions, 12.3 billion points, floats)
- Dimensions:
- Number of Points:
- Data type:

## Error
```
Paste the full error, with any sensitive information minimally redacted and marked $$REDACTED$$

```

## Your Environment
* Operating system (e.g. Windows 11 Pro, Ubuntu 22.04.1 LTS)
* DiskANN version (or commit built from)

## Additional Details
Any other contextual information you might feel is important.

2
packages/leann-backend-diskann/third_party/DiskANN/.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,2 @@
blank_issues_enabled: false

25
packages/leann-backend-diskann/third_party/DiskANN/.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,25 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement
assignees: ''

---

## Is your feature request related to a problem? Please describe.
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

## Describe the solution you'd like
A clear and concise description of what you want to happen.

## Describe alternatives you've considered
A clear and concise description of any alternative solutions or features you've considered.

## Provide references (if applicable)
If your feature request is related to a published algorithm/idea, please provide links to
any relevant articles or webpages.

## Additional context
Add any other context or screenshots about the feature request here.

11
packages/leann-backend-diskann/third_party/DiskANN/.github/ISSUE_TEMPLATE/usage-question.md
vendored
Normal file
@@ -0,0 +1,11 @@
---
name: Usage Question
about: Ask us a question about DiskANN!
title: "[Question]"
labels: question
assignees: ''

---

This is our forum for asking whatever DiskANN question you'd like! No need to feel shy - we're happy to talk about use cases and optimal tuning strategies!

22
packages/leann-backend-diskann/third_party/DiskANN/.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,22 @@
<!--
Thanks for contributing a pull request! Please ensure you have taken a look at
the contribution guidelines: https://github.com/microsoft/DiskANN/blob/main/CONTRIBUTING.md
-->
- [ ] Does this PR have a descriptive title that could go in our release notes?
- [ ] Does this PR add any new dependencies?
- [ ] Does this PR modify any existing APIs?
  - [ ] Is the change to the API backwards compatible?
- [ ] Should this result in any changes to our documentation, either updating existing docs or adding new ones?

#### Reference Issues/PRs
<!--
Example: Fixes #1234. See also #3456.
Please use keywords (e.g., Fixes) to create link to the issues or pull requests
you resolved, so that they will automatically be closed when your pull request
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
-->

#### What does this implement/fix? Briefly explain your changes.

#### Any other comments?

39
packages/leann-backend-diskann/third_party/DiskANN/.github/actions/build/action.yml
vendored
Normal file
@@ -0,0 +1,39 @@
name: 'DiskANN Build Bootstrap'
description: 'Prepares DiskANN build environment and executes build'
runs:
  using: "composite"
  steps:
    # ------------ Linux Build ---------------
    - name: Prepare and Execute Build
      if: ${{ runner.os == 'Linux' }}
      run: |
        sudo scripts/dev/install-dev-deps-ubuntu.bash
        cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
        cmake --build build -- -j
        cmake --install build --prefix="dist"
      shell: bash
    # ------------ End Linux Build ---------------
    # ------------ Windows Build ---------------
    - name: Add VisualStudio command line tools into path
      if: runner.os == 'Windows'
      uses: ilammy/msvc-dev-cmd@v1
    - name: Run configure and build for Windows
      if: runner.os == 'Windows'
      run: |
        mkdir build && cd build && cmake .. -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
        cd ..
        mkdir dist
        mklink /j .\dist\bin .\x64\Release\
      shell: cmd
    # ------------ End Windows Build ---------------
    # ------------ Windows Build With EXEC_ENV_OLS and USE_BING_INFRA ---------------
    - name: Add VisualStudio command line tools into path
      if: runner.os == 'Windows'
      uses: ilammy/msvc-dev-cmd@v1
    - name: Run configure and build for Windows with Bing feature flags
      if: runner.os == 'Windows'
      run: |
        mkdir build_bing && cd build_bing && cmake .. -DEXEC_ENV_OLS=1 -DUSE_BING_INFRA=1 -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
        cd ..
      shell: cmd
    # ------------ End Windows Build ---------------
13
packages/leann-backend-diskann/third_party/DiskANN/.github/actions/format-check/action.yml
vendored
Normal file
@@ -0,0 +1,13 @@
name: 'Checking code formatting...'
description: 'Ensures code complies with code formatting rules'
runs:
  using: "composite"
  steps:
    - name: Checking code formatting...
      run: |
        sudo apt install clang-format
        find include -name '*.h' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
        find src -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
        find apps -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
        find python -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
      shell: bash
@@ -0,0 +1,28 @@
name: 'Generating Random Data (Basic)'
description: 'Generates the random data files used in acceptance tests'
runs:
  using: "composite"
  steps:
    - name: Generate Random Data (Basic)
      run: |
        mkdir data

        echo "Generating random 1020,1024,1536D float and 4096D int8 vectors for index"
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_5K_norm1.0.bin -D 1020 -N 5000 --norm 1.0
        #dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_5K_norm1.0.bin -D 1024 -N 5000 --norm 1.0
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_5K_norm1.0.bin -D 1536 -N 5000 --norm 1.0
        dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_5K_norm1.0.bin -D 4096 -N 5000 --norm 1.0

        echo "Generating random 1020,1024,1536D float and 4096D int8 vectors for query"
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_1K_norm1.0.bin -D 1020 -N 1000 --norm 1.0
        #dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_1K_norm1.0.bin -D 1024 -N 1000 --norm 1.0
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_1K_norm1.0.bin -D 1536 -N 1000 --norm 1.0
        dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_1K_norm1.0.bin -D 4096 -N 1000 --norm 1.0

        echo "Computing ground truth for 1020,1024,1536D float and 4096D int8 vectors for query"
        dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1020D_5K_norm1.0.bin --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --K 100
        #dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1024D_5K_norm1.0.bin --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1536D_5K_norm1.0.bin --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_4096D_5K_norm1.0.bin --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --K 100

      shell: bash
38
packages/leann-backend-diskann/third_party/DiskANN/.github/actions/generate-random/action.yml
vendored
Normal file
@@ -0,0 +1,38 @@
name: 'Generating Random Data (Basic)'
description: 'Generates the random data files used in acceptance tests'
runs:
  using: "composite"
  steps:
    - name: Generate Random Data (Basic)
      run: |
        mkdir data

        echo "Generating random vectors for index"
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_unnorm.bin -D 10 -N 10000 --rand_scaling 2.0
        dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
        dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0

        echo "Generating random vectors for query"
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
        dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_unnorm.bin -D 10 -N 1000 --rand_scaling 2.0
        dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
        dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0

        echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
        dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_unnorm.bin --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --K 100

        echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
        dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type int8 --dist_fn mips --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type int8 --dist_fn cosine --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100

        echo "Computing ground truth for uint8s across l2, mips, and cosine distance functions"
        dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type uint8 --dist_fn mips --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
        dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100

      shell: bash
22
packages/leann-backend-diskann/third_party/DiskANN/.github/actions/python-wheel/action.yml
vendored
Normal file
@@ -0,0 +1,22 @@
name: Build Python Wheel
description: Builds a python wheel with cibuildwheel
inputs:
  cibw-identifier:
    description: "CI build wheel identifier to build"
    required: true
runs:
  using: "composite"
  steps:
    - uses: actions/setup-python@v3
    - name: Install cibuildwheel
      run: python -m pip install cibuildwheel==2.11.3
      shell: bash
    - name: Building Python ${{inputs.cibw-identifier}} Wheel
      run: python -m cibuildwheel --output-dir dist
      env:
        CIBW_BUILD: ${{inputs.cibw-identifier}}
      shell: bash
    - uses: actions/upload-artifact@v3
      with:
        name: wheels
        path: ./dist/*.whl
81
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/build-python-pdoc.yml
vendored
Normal file
@@ -0,0 +1,81 @@
name: DiskANN Build PDoc Documentation
on: [workflow_call]
jobs:
  build-reference-documentation:
    permissions:
      contents: write
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Set up Python 3.9
        uses: actions/setup-python@v2
        with:
          python-version: 3.9
      - name: Install python build
        run: python -m pip install build
        shell: bash
      # Install required dependencies
      - name: Prepare Linux environment
        run: |
          sudo scripts/dev/install-dev-deps-ubuntu.bash
        shell: bash
      # We need to build the wheel in order to run pdoc. pdoc does not seem to work if you just point it at
      # our source directory.
      - name: Building Python Wheel for documentation generation
        run: python -m build --wheel --outdir documentation_dist
        shell: bash
      - name: "Run Reference Documentation Generation"
        run: |
          pip install pdoc pipdeptree
          pip install documentation_dist/*.whl
          echo "documentation" > dependencies_documentation.txt
          pipdeptree >> dependencies_documentation.txt
          pdoc -o docs/python/html diskannpy
      - name: Create version environment variable
        run: |
          echo "DISKANN_VERSION=$(python <<EOF
          from importlib.metadata import version
          v = version('diskannpy')
          print(v)
          EOF
          )" >> $GITHUB_ENV
      - name: Archive documentation version artifact
        uses: actions/upload-artifact@v4
        with:
          name: dependencies
          path: |
            ${{ github.run_id }}-dependencies_documentation.txt
          overwrite: true
      - name: Archive documentation artifacts
        uses: actions/upload-artifact@v4
        with:
          name: documentation-site
          path: |
            docs/python/html
      # Publish to /dev if we are on the "main" branch
      - name: Publish reference docs for latest development version (main branch)
        uses: peaceiris/actions-gh-pages@v3
        if: github.ref == 'refs/heads/main'
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: docs/python/html
          destination_dir: docs/python/dev
      # Publish to /<version> if we are releasing
      - name: Publish reference docs by version (main branch)
        uses: peaceiris/actions-gh-pages@v3
        if: github.event_name == 'release'
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: docs/python/html
          destination_dir: docs/python/${{ env.DISKANN_VERSION }}
      # Publish to /latest if we are releasing
      - name: Publish latest reference docs (main branch)
        uses: peaceiris/actions-gh-pages@v3
        if: github.event_name == 'release'
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
          publish_dir: docs/python/html
          destination_dir: docs/python/latest
42
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/build-python.yml
vendored
Normal file
@@ -0,0 +1,42 @@
name: DiskANN Build Python Wheel
on: [workflow_call]
jobs:
  linux-build:
    name: Python - Ubuntu - ${{matrix.cibw-identifier}}
    strategy:
      fail-fast: false
      matrix:
        cibw-identifier: ["cp39-manylinux_x86_64", "cp310-manylinux_x86_64", "cp311-manylinux_x86_64"]
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Building python wheel ${{matrix.cibw-identifier}}
        uses: ./.github/actions/python-wheel
        with:
          cibw-identifier: ${{matrix.cibw-identifier}}
  windows-build:
    name: Python - Windows - ${{matrix.cibw-identifier}}
    strategy:
      fail-fast: false
      matrix:
        cibw-identifier: ["cp39-win_amd64", "cp310-win_amd64", "cp311-win_amd64"]
    runs-on: windows-latest
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          submodules: true
          fetch-depth: 1
      - name: Building python wheel ${{matrix.cibw-identifier}}
        uses: ./.github/actions/python-wheel
        with:
          cibw-identifier: ${{matrix.cibw-identifier}}
28
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/common.yml
vendored
Normal file
@@ -0,0 +1,28 @@
name: DiskANN Common Checks
# common means common to both pr-test and push-test
on: [workflow_call]
jobs:
  formatting-check:
    strategy:
      fail-fast: true
    name: Code Formatting Test
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checking code formatting...
        uses: ./.github/actions/format-check
  docker-container-build:
    name: Docker Container Build
    needs: [formatting-check]
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Docker build
        run: |
          docker build .
117
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/disk-pq.yml
vendored
Normal file
@@ -0,0 +1,117 @@
name: Disk With PQ
on: [workflow_call]
jobs:
  acceptance-tests-disk-pq:
    name: Disk, PQ
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Generate Data
        uses: ./.github/actions/generate-random

      - name: build and search disk index (one shot graph build, L2, no diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, cosine, no diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
          dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, L2, no diskPQ) (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16

      - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (int8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16

      - name: build and search disk index (sharded graph build, L2, no diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (sharded graph build, cosine, no diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
          dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
        run: |
          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (sharded graph build, L2, no diskPQ) (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16

      - name: build and search disk index (one shot graph build, L2, diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, L2, diskPQ) (int8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search disk index (one shot graph build, L2, diskPQ) (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16

      - name: build and search disk index (sharded graph build, MIPS, diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16

      - name: upload data and bin
        uses: actions/upload-artifact@v4
        with:
          name: disk-pq-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
102
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/dynamic-labels.yml
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
name: Dynamic-Labels
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-dynamic:
|
||||
name: Dynamic-Labels
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: Generate Labels
|
||||
run: |
|
||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
||||
|
||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
||||
|
||||
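      # Our reading of generate_synthetic_labels: it attaches one of --num_labels
      # labels to each of the --num_points base vectors, drawing uniformly for
      # "random" and with skewed frequencies for "zipf"; the Zipf-distributed
      # steps below rely on that skew to stress rare and common filters differently.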
      - name: Test a streaming index (float) with labels (Zipf distributed)
        run: |
          dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_zipf_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51

          echo "Computing groundtruth with filter"
          dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 --label_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.tags
          echo "Searching with filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1

          echo "Computing groundtruth w/o filter"
          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000
          echo "Searching without filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64

      - name: Test a streaming index (float) with labels (random distributed)
        run: |
          dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_rand_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51

          echo "Computing groundtruth with filter"
          dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 --label_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.tags
          echo "Searching with filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1

          echo "Computing groundtruth w/o filter"
          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000
          echo "Searching without filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64

      - name: Test Insert Delete Consolidate (float) with labels (zipf distributed)
        run: |
          dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/zipf_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_zipf_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51

          echo "Computing groundtruth with filter"
          dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K_wlabel_5 --label_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.tags
          echo "Searching with filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 10 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_zipf_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1

          echo "Computing groundtruth w/o filter"
          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K
          echo "Searching without filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K -K 10 -L 20 40 60 80 100 -T 64

      - name: Test Insert Delete Consolidate (float) with labels (random distributed)
        run: |
          dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/rand_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_rand_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51

          echo "Computing groundtruth with filter"
          dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K_wlabel_5 --label_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.tags
          echo "Searching with filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 40 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_rand_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1

          echo "Computing groundtruth w/o filter"
          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K
          echo "Searching without filter"
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K -K 10 -L 20 40 60 80 100 -T 64

      - name: upload data and bin
        uses: actions/upload-artifact@v4
        with:
          name: dynamic-labels-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
75
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/dynamic.yml
vendored
Normal file
@@ -0,0 +1,75 @@
name: Dynamic
on: [workflow_call]
jobs:
  acceptance-tests-dynamic:
    name: Dynamic
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Generate Data
        uses: ./.github/actions/generate-random

      - name: test a streaming index (float)
        run: |
          dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
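      # Our reading of the streaming flags above: --max_points_to_insert vectors
      # stream in, only --active_window of them stay live, and every
      # --consolidate_interval insertions the oldest points are deleted and the
      # graph is consolidated; ground truth is then computed over the surviving
      # .data file rather than the original input.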
      - name: test a streaming index (int8)
        if: success() || failure()
        run: |
          dist/bin/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
          dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
      - name: test a streaming index (uint8)
        if: success() || failure()
        run: |
          dist/bin/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
          dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1

      - name: build and search an incremental index (float)
        if: success() || failure()
        run: |
          dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2
          dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
      - name: build and search an incremental index (int8)
        if: success() || failure()
        run: |
          dist/bin/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
          dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
      - name: build and search an incremental index (uint8)
        if: success() || failure()
        run: |
          dist/bin/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
          dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_10K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1

      - name: upload data and bin
        uses: actions/upload-artifact@v4
        with:
          name: dynamic-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
81
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/in-mem-no-pq.yml
vendored
Normal file
@@ -0,0 +1,81 @@
name: In-Memory Without PQ
on: [workflow_call]
jobs:
  acceptance-tests-mem-no-pq:
    name: In-Mem, Without PQ
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Generate Data
        uses: ./.github/actions/generate-random

      - name: build and search in-memory index with L2 metrics (float)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
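      # The search half of each pair sweeps the -L candidate-list sizes (16 and
      # 32) and, through --fail_if_recall_below 70, turns the measured recall@10
      # into a pass/fail gate for the step.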
      - name: build and search in-memory index with L2 metrics (int8)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0
          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
      - name: build and search in-memory index with L2 metrics (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32

      - name: Searching with fast_l2 distance function (float)
        if: runner.os != 'Windows' && (success() || failure())
        run: |
          dist/bin/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32

      - name: build and search in-memory index with MIPS metric (float)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_mips_rand_float_10D_10K_norm1.0
          dist/bin/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32

      - name: build and search in-memory index with cosine metric (float)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_10D_10K_norm1.0
          dist/bin/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
      - name: build and search in-memory index with cosine metric (int8)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type int8 --dist_fn cosine --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_int8_10D_10K_norm50.0
          dist/bin/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
      - name: build and search in-memory index with cosine metric (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50.0
          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32

      - name: upload data and bin
        uses: actions/upload-artifact@v4
        with:
          name: in-memory-no-pq-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
56
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/in-mem-pq.yml
vendored
Normal file
@@ -0,0 +1,56 @@
name: In-Memory With PQ
on: [workflow_call]
jobs:
  acceptance-tests-mem-pq:
    name: In-Mem, PQ
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Generate Data
        uses: ./.github/actions/generate-random

      - name: build and search in-memory index with L2 metric with PQ based distance comparisons (float)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
          dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
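      # --build_PQ_bytes 5 presumably switches graph construction to 5-byte
      # product-quantized distance comparisons instead of full-precision
      # distances; the search step then checks that the cheaper build still
      # clears the same 70% recall@10 gate as the no-PQ workflow.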
      - name: build and search in-memory index with L2 metric with PQ based distance comparisons (int8)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
          dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32

      - name: build and search in-memory index with L2 metric with PQ based distance comparisons (uint8)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32

      - name: upload data and bin
        uses: actions/upload-artifact@v4
        with:
          name: in-memory-pq-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
120
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/labels.yml
vendored
Normal file
@@ -0,0 +1,120 @@
name: Labels
on: [workflow_call]
jobs:
  acceptance-tests-labels:
    name: Labels
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Generate Data
        uses: ./.github/actions/generate-random

      - name: Generate Labels
        run: |
          echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
          dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100

          echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
          dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/mips_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100

          echo "Generating synthetic labels and computing ground truth for filtered search without a universal label"
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --K 100
          dist/bin/generate_synthetic_labels --num_labels 10 --num_points 1000 --output_file data/query_labels_1K.txt --distribution_type one_per_point
          dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label_file data/query_labels_1K.txt --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
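      # Our reading of the filtered ground truth above: --filter_label restricts
      # candidates to points carrying that label, --universal_label names a label
      # treated as matching every filter, and --filter_label_file supplies a
      # separate filter per query (consumed by the "combined GT" search step
      # further down).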
      - name: build and search in-memory index with labels using L2 and Cosine metrics (random distributed labels)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel
          dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32

          echo "Searching without filters"
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64

      - name: build and search disk index with labels using L2 and Cosine metrics (random distributed labels)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: build and search in-memory index with labels using L2 and Cosine metrics (zipf distributed labels)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
          dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32

          echo "Searching without filters"
          dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
          dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
          dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64

      - name: build and search disk index with labels using L2 and Cosine metrics (zipf distributed labels)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16

      - name: build and search in-memory and disk index (without universal label, zipf distributed)
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal
          dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal -R 32 -L 5 -B 0.00003 -M 1
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal -L 16 32
          dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
      - name: Generate combined GT for each query with a separate label and search
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --query_filters_file data/query_labels_1K.txt --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
      - name: build and search in-memory index with pq_dist of 5 with 10 dimensions
        if: success() || failure()
        run: |
          dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
          dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
      - name: Build and search stitched vamana with random and zipf distributed labels
        if: success() || failure()
        run: |
          dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_rand_32_100_64_new --universal_label 0
          dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_zipf_32_100_64_new --universal_label 0
          dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix data/stit_rand_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/rand_stit_96_10_90_new --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
          dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/stit_zipf_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/zipf_stit_96_10_90_new --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150

      - name: upload data and bin
        if: success() || failure()
        uses: actions/upload-artifact@v4
        with:
          name: labels-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
60
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/multi-sector-disk-pq.yml
vendored
Normal file
@@ -0,0 +1,60 @@
name: Disk With PQ
on: [workflow_call]
jobs:
  acceptance-tests-disk-pq:
    name: Disk, PQ
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Generate Data
        uses: ./.github/actions/generate-high-dim-random

      - name: build and search disk index (1020D, one shot graph build, L2, no diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1020D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
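      # These high-dimensional runs (1020D float here, 1536D float and 4096D int8
      # below) presumably exercise index entries too large for a single 4 KB disk
      # sector, hence the file name multi-sector-disk-pq.yml; note the much larger
      # build list (-L 500) and search list (-L 250) than in the 10D workflows.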
      #- name: build and search disk index (1024D, one shot graph build, L2, no diskPQ) (float)
      #  if: success() || failure()
      #  run: |
      #    dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1024D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
      #    dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
      - name: build and search disk index (1536D, one shot graph build, L2, no diskPQ) (float)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1536D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
          dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16

      - name: build and search disk index (4096D, one shot graph build, L2, no diskPQ) (int8)
        if: success() || failure()
        run: |
          dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_4096D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
          dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16

      - name: upload data and bin
        uses: actions/upload-artifact@v4
        with:
          name: multi-sector-disk-pq-${{matrix.os}}
          path: |
            ./dist/**
            ./data/**
26
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/perf.yml
vendored
Normal file
@@ -0,0 +1,26 @@
name: DiskANN Nightly Performance Metrics
on:
  schedule:
    - cron: "41 14 * * *" # 14:41 UTC, 7:41 PDT, 8:41 PST, 08:11 IST
jobs:
  perf-test:
    name: Run Perf Test from main
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Build Perf Container
        run: |
          docker build --build-arg GIT_COMMIT_ISH="$GITHUB_SHA" -t perf -f scripts/perf/Dockerfile scripts
      - name: Performance Tests
        run: |
          mkdir metrics
          docker run -v ./metrics:/app/logs perf &> ./metrics/combined_stdouterr.log
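      # The bind mount -v ./metrics:/app/logs keeps the container's logs on the
      # runner after it exits, and &> funnels its stdout and stderr into the same
      # directory. Note that the upload step below names its artifact
      # metrics-${{matrix.os}} although this workflow defines no matrix, so that
      # expression presumably expands to an empty string.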
      - name: Upload Metrics Logs
        uses: actions/upload-artifact@v4
        with:
          name: metrics-${{matrix.os}}
          path: |
            ./metrics/**
35
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/pr-test.yml
vendored
Normal file
@@ -0,0 +1,35 @@
name: DiskANN Pull Request Build and Test
on: [pull_request]
jobs:
  common:
    strategy:
      fail-fast: true
    name: DiskANN Common Build Checks
    uses: ./.github/workflows/common.yml
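  # Each job below simply dispatches to one of the reusable workflows in this
  # directory; those files declare `on: [workflow_call]`, which is why they run
  # only when referenced by a caller workflow such as this one.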
  unit-tests:
    name: Unit tests
    uses: ./.github/workflows/unit-tests.yml
  in-mem-pq:
    name: In-Memory with PQ
    uses: ./.github/workflows/in-mem-pq.yml
  in-mem-no-pq:
    name: In-Memory without PQ
    uses: ./.github/workflows/in-mem-no-pq.yml
  disk-pq:
    name: Disk with PQ
    uses: ./.github/workflows/disk-pq.yml
  multi-sector-disk-pq:
    name: Multi-sector Disk with PQ
    uses: ./.github/workflows/multi-sector-disk-pq.yml
  labels:
    name: Labels
    uses: ./.github/workflows/labels.yml
  dynamic:
    name: Dynamic
    uses: ./.github/workflows/dynamic.yml
  dynamic-labels:
    name: Dynamic Labels
    uses: ./.github/workflows/dynamic-labels.yml
  python:
    name: Python
    uses: ./.github/workflows/build-python.yml
50
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/push-test.yml
vendored
Normal file
@@ -0,0 +1,50 @@
name: DiskANN Push Build
on: [push]
jobs:
  common:
    strategy:
      fail-fast: true
    name: DiskANN Common Build Checks
    uses: ./.github/workflows/common.yml
  build-documentation:
    permissions:
      contents: write
    strategy:
      fail-fast: true
    name: DiskANN Build Documentation
    uses: ./.github/workflows/build-python-pdoc.yml
  build:
    strategy:
      fail-fast: false
      matrix:
        os: [ ubuntu-latest, windows-2019, windows-latest ]
    name: Build for ${{matrix.os}}
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: Build diskannpy dependency tree
        run: |
          pip install diskannpy pipdeptree
          echo "dependencies" > dependencies_${{ matrix.os }}.txt
          pipdeptree >> dependencies_${{ matrix.os }}.txt
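      # pipdeptree prints the dependency tree of whatever is installed in the
      # current environment, so appending it after `pip install diskannpy` gives
      # a per-OS snapshot of what the published wheel actually pulls in; the
      # step below archives that snapshot.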
      - name: Archive diskannpy dependencies artifact
        uses: actions/upload-artifact@v4
        with:
          name: dependencies_${{ matrix.os }}
          path: |
            dependencies_${{ matrix.os }}.txt
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build
43
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/python-release.yml
vendored
Normal file
@@ -0,0 +1,43 @@
name: Build and Release Python Wheels
on:
  release:
    types: [published]
jobs:
  python-release-wheels:
    name: Python
    uses: ./.github/workflows/build-python.yml
  build-documentation:
    strategy:
      fail-fast: true
    name: DiskANN Build Documentation
    uses: ./.github/workflows/build-python-pdoc.yml
  release:
    permissions:
      contents: write
    runs-on: ubuntu-latest
    needs: python-release-wheels
    steps:
      - uses: actions/download-artifact@v3
        with:
          name: wheels
          path: dist/
      - name: Generate SHA256 files for each wheel
        run: |
          sha256sum dist/*.whl > checksums.txt
          cat checksums.txt
      - uses: actions/setup-python@v3
      - name: Install twine
        run: python -m pip install twine
      - name: Publish with twine
        env:
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
        run: |
          twine upload dist/*.whl
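      # twine reads TWINE_USERNAME/TWINE_PASSWORD from the environment, so the
      # PyPI credentials never appear on the command line; dist/*.whl here are
      # the wheels restored from the python-release-wheels job's artifact.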
      - name: Update release with SHA256 and Artifacts
        uses: softprops/action-gh-release@v1
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          files: |
            dist/*.whl
            checksums.txt
32
packages/leann-backend-diskann/third_party/DiskANN/.github/workflows/unit-tests.yml
vendored
Normal file
@@ -0,0 +1,32 @@
name: Unit Tests
on: [workflow_call]
jobs:
  acceptance-tests-labels:
    name: Unit Tests
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-2019, windows-latest]
    runs-on: ${{matrix.os}}
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout repository
        if: ${{ runner.os == 'Linux' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
      - name: Checkout repository
        if: ${{ runner.os == 'Windows' }}
        uses: actions/checkout@v3
        with:
          fetch-depth: 1
          submodules: true
      - name: DiskANN Build CLI Applications
        uses: ./.github/actions/build

      - name: Run Unit Tests
        run: |
          cd build
          ctest -C Release
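          # `ctest -C Release` selects the Release configuration, which matters
          # on the Windows runners where MSVC produces multi-configuration build
          # trees; on the single-configuration Linux build it is effectively a
          # no-op.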
384
packages/leann-backend-diskann/third_party/DiskANN/.gitignore
vendored
Normal file
384
packages/leann-backend-diskann/third_party/DiskANN/.gitignore
vendored
Normal file
@@ -0,0 +1,384 @@
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore

# User-specific files
*.rsuser
*.suo
*.user
*.userosscache
*.sln.docstates

# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs

# Mono auto generated files
mono_crash.*

# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
[Aa][Rr][Mm]/
[Aa][Rr][Mm]64/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
[Ll]ogs/

# Visual Studio 2015/2017 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/

# Visual Studio 2017 auto generated files
Generated\ Files/

# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*

# NUnit
*.VisualState.xml
TestResult.xml
nunit-*.xml

# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c

# Benchmark Results
BenchmarkDotNet.Artifacts/

# .NET Core
project.lock.json
project.fragment.lock.json
artifacts/

# StyleCop
StyleCopReport.xml

# Files built by Visual Studio
*_i.c
*_p.c
*_h.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
*.pdb
*.ipdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*_wpftmp.csproj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc

# Chutzpah Test files
_Chutzpah*

# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb

# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap

# Visual Studio Trace Files
*.e2e

# TFS 2012 Local Workspace
$tf/

# Guidance Automation Toolkit
*.gpState

# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user

# TeamCity is a build add-in
_TeamCity*

# DotCover is a Code Coverage Tool
*.dotCover

# AxoCover is a Code Coverage Tool
.axoCover/*
!.axoCover/settings.json

# Visual Studio code coverage results
*.coverage
*.coveragexml

# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*

# MightyMoose
*.mm.*
AutoTest.Net/

# Web workbench (sass)
.sass-cache/

# Installshield output folder
[Ee]xpress/

# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html

# Click-Once directory
publish/

# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# Note: Comment the next line if you want to checkin your web deploy settings,
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj

# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/

# NuGet Packages
*.nupkg
# NuGet Symbol Packages
*.snupkg
# The packages folder can be ignored because of Package Restore
**/[Pp]ackages/*
# except build/, which is used as an MSBuild target.
!**/[Pp]ackages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/[Pp]ackages/repositories.config
# NuGet v3's project.json files produces more ignorable files
*.nuget.props
*.nuget.targets

# Microsoft Azure Build Output
csx/
*.build.csdef

# Microsoft Azure Emulator
ecf/
rcf/

# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
*.appx
*.appxbundle
*.appxupload

# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!?*.[Cc]ache/

# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
orleans.codegen.cs

# Including strong name files can present a security risk
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
#*.snk

# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/

# RIA/Silverlight projects
Generated_Code/

# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
ServiceFabricBackup/
*.rptproj.bak

# SQL Server files
*.mdf
*.ldf
*.ndf

# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
*.rptproj.rsuser
*- [Bb]ackup.rdl
*- [Bb]ackup ([0-9]).rdl
*- [Bb]ackup ([0-9][0-9]).rdl

# Microsoft Fakes
FakesAssemblies/

# GhostDoc plugin setting file
*.GhostDoc.xml

# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/

# Visual Studio 6 build log
*.plg

# Visual Studio 6 workspace options file
*.opt

# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw

# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions

# Paket dependency manager
.paket/paket.exe
paket-files/

# FAKE - F# Make
.fake/

# CodeRush personal settings
.cr/personal

# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc

# Cake - Uncomment if you are using it
# tools/**
# !tools/packages.config

# Tabs Studio
*.tss

# Telerik's JustMock configuration file
*.jmconfig

# BizTalk build output
*.btp.cs
*.btm.cs
*.odx.cs
*.xsd.cs

# OpenCover UI analysis results
OpenCover/

# Azure Stream Analytics local run output
ASALocalRun/

# MSBuild Binary and Structured Log
*.binlog

# NVidia Nsight GPU debugger configuration file
*.nvuser

# MFractors (Xamarin productivity tool) working folder
.mfractor/

# Local History for Visual Studio
.localhistory/

# BeatPulse healthcheck temp database
healthchecksdb

# Backup folder for Package Reference Convert tool in Visual Studio 2017
MigrationBackup/

# Ionide (cross platform F# VS Code tools) working folder
.ionide/

/vcproj/nsg/x64/Debug/nsg.Build.CppClean.log
/vcproj/test_recall/x64/Debug/test_recall.Build.CppClean.log
/vcproj/test_recall/test_recall.vcxproj.user
/.vs
/out/build/x64-Debug
cscope*

build/
build_linux/
!.github/actions/build

# jetbrains specific stuff
.idea/
cmake-build-debug/

# python extension module ignores
python/diskannpy.egg-info/
python/dist/

**/*.egg-info
wheelhouse/*
dist/*
venv*/**
*.swp

gperftools

# Rust
rust/target

python/src/*.so

compile_commands.json
3
packages/leann-backend-diskann/third_party/DiskANN/.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
[submodule "gperftools"]
	path = gperftools
	url = https://github.com/gperftools/gperftools.git
563
packages/leann-backend-diskann/third_party/DiskANN/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,563 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

# Parameters:
#
# BOOST_ROOT:
#    Specify root of the Boost library if Boost cannot be auto-detected. On Windows, a fallback to a
#    downloaded nuget version will be used if Boost cannot be found.
#
# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS:
#    This is a work-in-progress feature, not completed yet. The core DiskANN library will be split into
#    build-related and search-related functionality. In build-related functionality, when using tcmalloc,
#    it's possible to release memory that's free but reserved by tcmalloc. Setting this to true enables
#    such behavior.
#    Contact for this feature: gopalrs.


# Some variables like MSVC are defined only after project(), so put that first.
cmake_minimum_required(VERSION 3.20)
project(diskann)

# Set option to use tcmalloc
option(USE_TCMALLOC "Use tcmalloc from gperftools" ON)

# set tcmalloc to false when on macos
if(APPLE)
    set(USE_TCMALLOC OFF)
endif()

option(PYBIND "Build with Python bindings" ON)

if(PYBIND)
    # Find Python
    find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
    execute_process(
        COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
        OUTPUT_VARIABLE pybind11_DIR
        OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    find_package(pybind11 CONFIG REQUIRED)
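    # pybind11 ships its CMake package config inside the installed Python package, so asking the
    # interpreter for pybind11.get_cmake_dir() lets find_package(pybind11 CONFIG) succeed without
    # a system-wide pybind11 installation.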

    message(STATUS "Python include dirs: ${Python_INCLUDE_DIRS}")
    message(STATUS "Pybind11 include dirs: ${pybind11_INCLUDE_DIRS}")

    # Add pybind11 include directories
    include_directories(SYSTEM ${pybind11_INCLUDE_DIRS} ${Python_INCLUDE_DIRS})

    # Add compilation definitions
    add_definitions(-DPYBIND11_EMBEDDED)

    # Set visibility flags
    if(NOT MSVC)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
    endif()
endif()

set(CMAKE_STANDARD 17)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# if(NOT MSVC)
#     set(CMAKE_CXX_COMPILER g++)
# endif()

set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")

# Install nuget packages for dependencies.
if (MSVC)
    find_program(NUGET_EXE NAMES nuget)

    if (NOT NUGET_EXE)
        message(FATAL_ERROR "Cannot find nuget command line tool.\nPlease install it from e.g. https://www.nuget.org/downloads")
    endif()

    set(DISKANN_MSVC_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/packages.config)
    set(DISKANN_MSVC_PACKAGES ${CMAKE_BINARY_DIR}/packages)

    message(STATUS "Invoking nuget to download Boost, OpenMP and MKL dependencies...")
    configure_file(${PROJECT_SOURCE_DIR}/windows/packages.config.in ${DISKANN_MSVC_PACKAGES_CONFIG})
    exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
    if (RESTAPI)
        set(DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/restapi/packages.config)
        configure_file(${PROJECT_SOURCE_DIR}/windows/packages_restapi.config.in ${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG})
        exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
    endif()
    message(STATUS "Finished setting up nuget dependencies")
endif()

include_directories(${PROJECT_SOURCE_DIR}/include)

include(FetchContent)

if(USE_TCMALLOC)
    FetchContent_Declare(
        tcmalloc
        GIT_REPOSITORY https://github.com/google/tcmalloc.git
        GIT_TAG origin/master # or specify a particular version or commit
    )

    FetchContent_MakeAvailable(tcmalloc)
endif()

if(NOT PYBIND)
    set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
endif()
# It's necessary to include tcmalloc headers only if calling into the MallocExtension interface.
# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
    include_directories(${tcmalloc_SOURCE_DIR}/src)
    if (MSVC)
        include_directories(${tcmalloc_SOURCE_DIR}/src/windows)
    endif()
endif()

# OpenMP
if (MSVC)
    # Do not use find_package here since it would use VisualStudio's built-in OpenMP, but MKL libraries
    # refer to Intel's OpenMP.
    #
    # No extra settings are needed for compilation: it only needs the /openmp flag which is set further below,
    # in the common MSVC compiler options block.
    include_directories(BEFORE "${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/include")
    link_libraries("${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/win-x64/libiomp5md.lib")

    set(OPENMP_WINDOWS_RUNTIME_FILES
        "${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.dll"
        "${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.pdb")
elseif(APPLE)
    # Check if we're building Python bindings
    if(PYBIND)
        # First look for PyTorch's OpenMP to avoid conflicts
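        # Loading two different OpenMP runtimes into one process (e.g., Homebrew's libomp plus the
        # copy bundled with torch) can abort or hang on macOS, so when the bindings will be imported
        # alongside PyTorch we prefer to reuse the libomp that PyTorch already links against.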
        execute_process(
            COMMAND ${Python_EXECUTABLE} -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libomp.dylib'))"
            RESULT_VARIABLE TORCH_PATH_RESULT
            OUTPUT_VARIABLE TORCH_LIBOMP_PATH
            OUTPUT_STRIP_TRAILING_WHITESPACE
            ERROR_QUIET
        )

        execute_process(
            COMMAND brew --prefix libomp
            OUTPUT_VARIABLE LIBOMP_ROOT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )

        if(EXISTS "${TORCH_LIBOMP_PATH}")
            message(STATUS "Found PyTorch's libomp: ${TORCH_LIBOMP_PATH}")
            set(OpenMP_CXX_FLAGS "-Xclang -fopenmp")
            set(OpenMP_C_FLAGS "-Xclang -fopenmp")
            set(OpenMP_CXX_LIBRARIES "${TORCH_LIBOMP_PATH}")
            set(OpenMP_C_LIBRARIES "${TORCH_LIBOMP_PATH}")
            set(OpenMP_FOUND TRUE)

            include_directories(${LIBOMP_ROOT}/include)

            # Set compiler flags and link libraries
            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
            link_libraries("${TORCH_LIBOMP_PATH}")
        else()
            message(STATUS "No PyTorch libomp found, falling back to normal OpenMP detection")
            # Fallback to normal OpenMP detection
            execute_process(
                COMMAND brew --prefix libomp
                OUTPUT_VARIABLE LIBOMP_ROOT
                OUTPUT_STRIP_TRAILING_WHITESPACE
            )

            set(OpenMP_ROOT "${LIBOMP_ROOT}")
            find_package(OpenMP)

            if (OPENMP_FOUND)
                set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
                set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
                link_libraries(OpenMP::OpenMP_CXX)
            else()
                message(FATAL_ERROR "No OpenMP support")
            endif()
        endif()
    else()
        # Regular OpenMP setup for non-Python builds
        execute_process(
            COMMAND brew --prefix libomp
            OUTPUT_VARIABLE LIBOMP_ROOT
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        set(OpenMP_ROOT "${LIBOMP_ROOT}")
        find_package(OpenMP)

        if (OPENMP_FOUND)
            set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
            link_libraries(OpenMP::OpenMP_CXX)
        else()
            message(FATAL_ERROR "No OpenMP support")
        endif()
    endif()
else()
    find_package(OpenMP)

    if (OPENMP_FOUND)
        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
    else()
        message(FATAL_ERROR "No OpenMP support")
    endif()
endif()

# DiskANN core uses header-only libraries. Only DiskANN tools need program_options which has a linker library,
# but its size is small. Reduce the number of dependent DLLs by linking statically.
if (MSVC)
    set(Boost_USE_STATIC_LIBS ON)
endif()

if(NOT MSVC)
    find_package(Boost COMPONENTS program_options)
endif()

# For Windows, fall back to the nuget version if find_package didn't find it.
if (MSVC AND NOT Boost_FOUND)
    set(DISKANN_BOOST_INCLUDE "${DISKANN_MSVC_PACKAGES}/boost/lib/native/include")
    # Multi-threaded static library.
    set(PROGRAM_OPTIONS_LIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-x64-*.lib")
    file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_LIB ${PROGRAM_OPTIONS_LIB_PATTERN})

    set(PROGRAM_OPTIONS_DLIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-gd-x64-*.lib")
    file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_DLIB ${PROGRAM_OPTIONS_DLIB_PATTERN})

    if (EXISTS ${DISKANN_BOOST_INCLUDE} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_LIB} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB})
        set(Boost_FOUND ON)
        set(Boost_INCLUDE_DIR ${DISKANN_BOOST_INCLUDE})
        add_library(Boost::program_options STATIC IMPORTED)
        set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_RELEASE "${DISKANN_BOOST_PROGRAM_OPTIONS_LIB}")
        set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_DEBUG "${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB}")
        message(STATUS "Falling back to using Boost from the nuget package")
    else()
        message(WARNING "Couldn't find Boost. Was looking for ${DISKANN_BOOST_INCLUDE} and ${PROGRAM_OPTIONS_LIB_PATTERN}")
    endif()
endif()

if (NOT Boost_FOUND)
    message(FATAL_ERROR "Couldn't find Boost dependency")
endif()

include_directories(${Boost_INCLUDE_DIR})

# MKL Config
if (MSVC)
    # Only the DiskANN DLL and one of the tools need MKL libraries. Additionally, only a small part of MKL is used.
    # Given that and given that MKL DLLs are huge, use static linking to end up with no MKL DLL dependencies and with
    # a significantly smaller disk footprint.
    #
    # The compile options are not modified as there's already an unconditional -DMKL_ILP64 define below
    # for all architectures, which is all that's needed.
    set(DISKANN_MKL_INCLUDE_DIRECTORIES "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/include")
    set(DISKANN_MKL_LIB_PATH "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/win-x64")

    set(DISKANN_MKL_LINK_LIBRARIES
        "${DISKANN_MKL_LIB_PATH}/mkl_intel_ilp64.lib"
        "${DISKANN_MKL_LIB_PATH}/mkl_core.lib"
        "${DISKANN_MKL_LIB_PATH}/mkl_intel_thread.lib")
elseif(APPLE)
    # no MKL on non-Intel devices; use Apple's Accelerate framework instead
    find_library(ACCELERATE_LIBRARY Accelerate)
    message(STATUS "Found Accelerate (${ACCELERATE_LIBRARY})")
    set(DISKANN_ACCEL_LINK_OPTIONS ${ACCELERATE_LIBRARY})
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
    # expected paths for manual Intel MKL installs
    set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/2025.0/lib/libiomp5.so;/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so")
    foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS})
        if (EXISTS ${POSSIBLE_OMP_PATH})
            get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY)
        endif()
    endforeach()

    if(NOT OMP_PATH)
        message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment")
    endif()
    link_directories(${OMP_PATH})

    set(POSSIBLE_MKL_LIB_PATHS "/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so")
    foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS})
        if (EXISTS ${POSSIBLE_MKL_LIB_PATH})
            get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY)
        endif()
    endforeach()

    set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;")
    foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS})
        if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH})
            set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH})
        endif()
    endforeach()
    if(NOT MKL_PATH)
        message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment")
    elseif(NOT MKL_INCLUDE_PATH)
        message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment")
    endif()
    if (EXISTS ${MKL_PATH}/libmkl_def.so.2)
        set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2)
    elseif(EXISTS ${MKL_PATH}/libmkl_def.so)
        set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so)
    else()
        message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.")
    endif()
    link_directories(${MKL_PATH})
    include_directories(${MKL_INCLUDE_PATH})

    # compile flags and link libraries
    # if gcc/g++
    if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
        add_compile_options(-m64 -Wl,--no-as-needed)
    endif()
    if (NOT PYBIND)
        link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
    else()
        # static linking for python so as to minimize customer dependency issues
        if (CMAKE_BUILD_TYPE STREQUAL "Debug")
            # In debug mode, use dynamic linking to ensure all symbols are available
            link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core ${MKL_DEF_SO} iomp5 pthread m dl)
        else()
            # In release mode, use static linking to minimize dependencies
            link_libraries(
                ${MKL_PATH}/libmkl_intel_ilp64.a
                ${MKL_PATH}/libmkl_intel_thread.a
                ${MKL_PATH}/libmkl_core.a
                ${MKL_DEF_SO}
                iomp5
                pthread
                m
                dl
            )
        endif()
    endif()

    add_definitions(-DMKL_ILP64)
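    # MKL_ILP64 (defined above) selects MKL's 64-bit-integer (ILP64) interface; it has to match
    # the mkl_intel_ilp64 libraries linked above, otherwise integer arguments would be misread.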
endif()


# Section for tcmalloc. The DiskANN tools are always linked to tcmalloc. For Windows, they also need to
# force-include the _tcmalloc symbol for enabling tcmalloc.
#
# The DLL itself needs to be linked to tcmalloc only if DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS
# is enabled.
if(USE_TCMALLOC)
    if (MSVC)
        if (NOT EXISTS "${PROJECT_SOURCE_DIR}/gperftools/gperftools.sln")
            message(FATAL_ERROR "The gperftools submodule was not found. "
                "Please check out git submodules by doing 'git submodule init' followed by 'git submodule update'")
        endif()

        set(TCMALLOC_LINK_LIBRARY "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.lib")
        set(TCMALLOC_WINDOWS_RUNTIME_FILES
            "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.dll"
            "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.pdb")

        # Tell CMake how to build the tcmalloc linker library from the submodule.
        add_custom_target(build_libtcmalloc_minimal DEPENDS ${TCMALLOC_LINK_LIBRARY})
        add_custom_command(OUTPUT ${TCMALLOC_LINK_LIBRARY}
            COMMAND ${CMAKE_VS_MSBUILD_COMMAND} gperftools.sln /m /nologo
                /t:libtcmalloc_minimal /p:Configuration="Release-Patch"
                /property:Platform="x64"
                /p:PlatformToolset=v${MSVC_TOOLSET_VERSION}
                /p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION}
            WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/gperftools)

        add_library(libtcmalloc_minimal_for_exe STATIC IMPORTED)
        add_library(libtcmalloc_minimal_for_dll STATIC IMPORTED)

        set_target_properties(libtcmalloc_minimal_for_dll PROPERTIES
            IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}")

        set_target_properties(libtcmalloc_minimal_for_exe PROPERTIES
            IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}"
            INTERFACE_LINK_OPTIONS /INCLUDE:_tcmalloc)

        # Ensure libtcmalloc_minimal is built before it's being used.
        add_dependencies(libtcmalloc_minimal_for_dll build_libtcmalloc_minimal)
        add_dependencies(libtcmalloc_minimal_for_exe build_libtcmalloc_minimal)

        set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_exe)
    elseif(APPLE) # ! Inherited from #474, not yet adjusted for the tcmalloc removal
        execute_process(
            COMMAND brew --prefix gperftools
            OUTPUT_VARIABLE GPERFTOOLS_PREFIX
            OUTPUT_STRIP_TRAILING_WHITESPACE
        )
        set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-L${GPERFTOOLS_PREFIX}/lib -ltcmalloc")
    elseif(NOT PYBIND)
        set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc")
    endif()

    if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
        add_definitions(-DRELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)

        if (MSVC)
            set(DISKANN_DLL_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_dll)
        endif()
    endif()
endif()

if (NOT MSVC AND NOT APPLE)
    set(DISKANN_ASYNC_LIB aio)
endif()

# Main compiler/linker settings
if(MSVC)
    # language options
    add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++17 /Gd /W3 /MP /Zi /FC /nologo)
    # code generation options
    add_compile_options(/arch:AVX2 /fp:fast /fp:except- /EHsc /GS- /Gy)
    # optimization options
    add_compile_options(/Ot /Oy /Oi)
    # path options
    add_definitions(-DUSE_AVX2 -DUSE_ACCELERATED_PQ -D_WINDOWS -DNOMINMAX -DUNICODE)
    # Linker options. Exclude VCOMP/VCOMPD.LIB which contain VisualStudio's version of OpenMP.
    # MKL was linked against Intel's OpenMP and depends on the corresponding DLL.
    add_link_options(/NODEFAULTLIB:VCOMP.LIB /NODEFAULTLIB:VCOMPD.LIB /DEBUG:FULL /OPT:REF /OPT:ICF)

    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
    set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)

    set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
    set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
elseif(APPLE)
    set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -Xclang -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -Wno-inconsistent-missing-override -Wno-return-type")
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -ftree-vectorize")
    if (NOT PYBIND)
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
        if (NOT PORTABLE)
            set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -mtune=native")
        endif()
    else()
        # -Ofast is not supported in a python extension module
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -fPIC")
    endif()
else()
    set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -msse2 -ftree-vectorize -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2 -fPIC")
    if(USE_TCMALLOC)
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free")
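        # The -fno-builtin-* flags stop the compiler from inlining its built-in allocator
        # assumptions so that tcmalloc's malloc/free replacements are actually used.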
    endif()
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
    if (NOT PYBIND)
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
        if (NOT PORTABLE)
            set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native")
        endif()
    else()
        # -Ofast is not supported in a python extension module
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
    endif()
endif()

add_subdirectory(src)
if (NOT PYBIND)
    add_subdirectory(apps)
    add_subdirectory(apps/utils)
endif()

if (UNIT_TEST)
    enable_testing()
    add_subdirectory(tests)
endif()

if (MSVC)
    message(STATUS "The ${PROJECT_NAME}.sln has been created; open it from VisualStudio to build Release or Debug configurations.\n"
        "Alternatively, use MSBuild to build:\n\n"
        "msbuild.exe ${PROJECT_NAME}.sln /m /nologo /t:Build /p:Configuration=\"Release\" /property:Platform=\"x64\"\n")
endif()

if (RESTAPI)
    if (MSVC)
        set(DISKANN_CPPRESTSDK "${DISKANN_MSVC_PACKAGES}/cpprestsdk.v142/build/native")
        # link against the C++ REST SDK restored from nuget above
        link_libraries("${DISKANN_CPPRESTSDK}/x64/lib/cpprest142_2_10.lib")
        include_directories("${DISKANN_CPPRESTSDK}/include")
    endif()
    add_subdirectory(apps/restapi)
endif()

include(clang-format.cmake)

if(PYBIND)
    add_subdirectory(python)

    install(TARGETS _diskannpy
        DESTINATION leann_backend_diskann
        COMPONENT python_modules
    )

endif()
###############################################################################
# PROTOBUF SECTION - Corrected to use CONFIG mode explicitly
###############################################################################
set(Protobuf_USE_STATIC_LIBS OFF)

find_package(ZLIB REQUIRED)

find_package(Protobuf REQUIRED)

message(STATUS "Protobuf found: ${Protobuf_VERSION}")
message(STATUS "Protobuf include dirs: ${Protobuf_INCLUDE_DIRS}")
message(STATUS "Protobuf libraries: ${Protobuf_LIBRARIES}")
message(STATUS "Protobuf protoc executable: ${Protobuf_PROTOC_EXECUTABLE}")

include_directories(${Protobuf_INCLUDE_DIRS})

set(PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/../embedding.proto")
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
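# protobuf_generate_cpp writes the generated .pb.cc/.pb.h pair into CMAKE_CURRENT_BINARY_DIR,
# which is why that directory is added to the include paths of the consuming targets below.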
set(generated_proto_sources ${PROTO_SRCS})


add_library(proto_embeddings STATIC ${generated_proto_sources})
target_link_libraries(proto_embeddings PUBLIC protobuf::libprotobuf)
target_include_directories(proto_embeddings PUBLIC
    ${CMAKE_CURRENT_BINARY_DIR}
    ${Protobuf_INCLUDE_DIRS}
)

target_link_libraries(diskann PRIVATE proto_embeddings protobuf::libprotobuf)
target_include_directories(diskann PRIVATE
    ${CMAKE_CURRENT_BINARY_DIR}
    ${Protobuf_INCLUDE_DIRS}
)

target_link_libraries(diskann_s PRIVATE proto_embeddings protobuf::libprotobuf)
target_include_directories(diskann_s PRIVATE
    ${CMAKE_CURRENT_BINARY_DIR}
    ${Protobuf_INCLUDE_DIRS}
)


###############################################################################
# ZEROMQ SECTION - REQUIRED
###############################################################################

find_package(ZeroMQ QUIET)
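# libzmq installed from distro packages typically ships without a CMake config package, so when
# the QUIET lookup above fails we fall back to locating zmq.h and the library manually.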
if(NOT ZeroMQ_FOUND)
    find_path(ZeroMQ_INCLUDE_DIR zmq.h)
    find_library(ZeroMQ_LIBRARY zmq)
    if(ZeroMQ_INCLUDE_DIR AND ZeroMQ_LIBRARY)
        set(ZeroMQ_FOUND TRUE)
    endif()
endif()

if(ZeroMQ_FOUND)
    message(STATUS "Found ZeroMQ: ${ZeroMQ_LIBRARY}")
    include_directories(${ZeroMQ_INCLUDE_DIR})
    target_link_libraries(diskann PRIVATE ${ZeroMQ_LIBRARY})
    target_link_libraries(diskann_s PRIVATE ${ZeroMQ_LIBRARY})
    add_definitions(-DUSE_ZEROMQ)
else()
    message(FATAL_ERROR "ZeroMQ is required but not found. Please install ZeroMQ and try again.")
endif()

target_link_libraries(diskann ${PYBIND11_LIBRARIES})
target_link_libraries(diskann_s ${PYBIND11_LIBRARIES})
28
packages/leann-backend-diskann/third_party/DiskANN/CMakeSettings.json
vendored
Normal file
@@ -0,0 +1,28 @@
{
    "configurations": [
        {
            "name": "x64-Release",
            "generator": "Ninja",
            "configurationType": "Release",
            "inheritEnvironments": [ "msvc_x64" ],
            "buildRoot": "${projectDir}\\out\\build\\${name}",
            "installRoot": "${projectDir}\\out\\install\\${name}",
            "cmakeCommandArgs": "",
            "buildCommandArgs": "",
            "ctestCommandArgs": ""
        },
        {
            "name": "WSL-GCC-Release",
            "generator": "Ninja",
            "configurationType": "RelWithDebInfo",
            "buildRoot": "${projectDir}\\out\\build\\${name}",
            "installRoot": "${projectDir}\\out\\install\\${name}",
            "cmakeExecutable": "cmake",
            "cmakeCommandArgs": "",
            "buildCommandArgs": "",
            "ctestCommandArgs": "",
            "inheritEnvironments": [ "linux_x64" ],
            "wslPath": "${defaultWSLPath}"
        }
    ]
}
9
packages/leann-backend-diskann/third_party/DiskANN/CODE_OF_CONDUCT.md
vendored
Normal file
@@ -0,0 +1,9 @@
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
9
packages/leann-backend-diskann/third_party/DiskANN/CONTRIBUTING.md
vendored
Normal file
@@ -0,0 +1,9 @@
# Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
17
packages/leann-backend-diskann/third_party/DiskANN/Dockerfile
vendored
Normal file
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

FROM ubuntu:jammy

RUN apt update
RUN apt install -y software-properties-common
RUN add-apt-repository -y ppa:git-core/ppa
RUN apt update
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10

WORKDIR /app
RUN git clone https://github.com/microsoft/DiskANN.git
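# Note: the image builds upstream microsoft/DiskANN cloned at image build time, not the vendored
# sources checked in under this repository.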
WORKDIR /app/DiskANN
RUN mkdir build
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
RUN cmake --build build -- -j
17
packages/leann-backend-diskann/third_party/DiskANN/DockerfileDev
vendored
Normal file
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

FROM ubuntu:jammy

RUN apt update
RUN apt install -y software-properties-common
RUN add-apt-repository -y ppa:git-core/ppa
RUN apt update
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libboost-test-dev libmkl-full-dev libcpprest-dev python3.10

WORKDIR /app
RUN git clone https://github.com/microsoft/DiskANN.git
WORKDIR /app/DiskANN
RUN mkdir build
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
RUN cmake --build build -- -j
23
packages/leann-backend-diskann/third_party/DiskANN/LICENSE
vendored
Normal file
@@ -0,0 +1,23 @@
DiskANN

MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
12
packages/leann-backend-diskann/third_party/DiskANN/MANIFEST.in
vendored
Normal file
@@ -0,0 +1,12 @@
include MANIFEST.in
include *.txt
include *.md
include setup.py
include pyproject.toml
include *.cmake
recursive-include gperftools *
recursive-include include *
recursive-include python *
recursive-include windows *
prune python/tests
recursive-include src *
135
packages/leann-backend-diskann/third_party/DiskANN/README.md
vendored
Normal file
@@ -0,0 +1,135 @@
# DiskANN

[](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml)
[](https://pypi.org/project/diskannpy/)
[](https://pepy.tech/project/diskannpy)
[](https://opensource.org/licenses/MIT)

[](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf)
[](https://arxiv.org/abs/2105.09613)
[](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf)


DiskANN is a suite of scalable, accurate and cost-effective approximate nearest neighbor search algorithms for large-scale vector search that support real-time changes and simple filters.
This code is based on ideas from the [DiskANN](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf), [Fresh-DiskANN](https://arxiv.org/abs/2105.09613) and the [Filtered-DiskANN](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf) papers, with further improvements.
This code was forked from the [code for the NSG](https://github.com/ZJULearning/nsg) algorithm.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

See [guidelines](CONTRIBUTING.md) for contributing to this project.

## Linux build:

Install the following packages through apt-get:

```bash
sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format libboost-all-dev
```

### Install Intel MKL
#### Ubuntu 20.04 or newer
```bash
sudo apt install libmkl-full-dev
```

#### Earlier versions of Ubuntu
Install Intel MKL either by downloading the [oneAPI MKL installer](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) or using [apt](https://software.intel.com/en-us/articles/installing-intel-free-libs-and-python-apt-repo) (we tested with builds 2019.4-070 and 2022.1.2.146).

```
# OneAPI MKL Installer
wget https://registrationcenter-download.intel.com/akdlm/irc_nas/18487/l_BaseKit_p_2022.1.2.146.sh
sudo sh l_BaseKit_p_2022.1.2.146.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
```

### Build
```bash
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
```

## Windows build:

The Windows version has been tested with Enterprise editions of Visual Studio 2022, 2019 and 2017. It should work with the Community and Professional editions as well without any changes.

**Prerequisites:**

* CMake 3.15+ (available in VisualStudio 2019+ or from https://cmake.org)
* NuGet.exe (install from https://www.nuget.org/downloads)
  * The build script will use NuGet to get MKL, OpenMP and Boost packages.
* DiskANN git repository checked out together with submodules. To check out submodules after git clone:
    ```
    git submodule init
    git submodule update
    ```

* Environment variables:
  * [optional] If you would like to override the Boost library listed in windows/packages.config.in, set BOOST_ROOT to your Boost folder.

**Build steps:**
* Open the "x64 Native Tools Command Prompt for VS 2019" (or corresponding version) and change to the DiskANN folder
* Create a "build" directory inside it
* Change to the "build" directory and run
    ```
    cmake ..
    ```
    OR for Visual Studio 2017 and earlier:
    ```
    <full-path-to-installed-cmake>\cmake ..
    ```
**This will create a diskann.sln solution**. Now you can:

- Open it from VisualStudio and build either the Release or Debug configuration.
- `<full-path-to-installed-cmake>\cmake --build build`
- Use MSBuild:
    ```
    msbuild.exe diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64"
    ```

* This will also build the gperftools submodule for the libtcmalloc_minimal dependency.
* Generated binaries are stored in the x64/Release or x64/Debug directories.

## macOS Build

### Prerequisites
* Apple Silicon. The code should still work on Intel-based Macs, but there are no guarantees.
* macOS >= 12.0
* XCode Command Line Tools (install with `xcode-select --install`)
* [homebrew](https://brew.sh/)

### Install Required Packages
```zsh
brew install cmake
brew install boost
brew install gperftools
brew install libomp
```

### Build DiskANN
```zsh
# same as ubuntu instructions
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
```

## Usage:

Please see the following pages on using the compiled code:

- [Command-line interface for building and searching SSD-based indices](workflows/SSD_index.md)
- [Command-line interface for building and searching in-memory indices](workflows/in_memory_index.md)
- [Command-line examples for using in-memory streaming indices](workflows/dynamic_index.md)
- [Command-line interface for building and searching in-memory indices with label data and filters](workflows/filtered_in_memory.md)
- [Command-line interface for building and searching SSD-based indices with label data and filters](workflows/filtered_ssd_index.md)
- [diskannpy - DiskANN as a python extension module](python/README.md)
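
As a minimal sketch (option names taken from `apps/build_disk_index.cpp` in this repository; the data file and output paths are placeholders):

```bash
# Build an SSD-based index over float vectors with L2 distance.
# -B caps the DRAM used while serving searches; -M caps the DRAM used during build (both in GB).
./apps/build_disk_index --data_type float --dist_fn l2 \
    --data_path data/base.fbin --index_path_prefix indices/my_index \
    -B 4 -M 16
```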

Please cite this software in your work as:

```
@misc{diskann-github,
    author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan and Patel, Yash},
    title = {{DiskANN: Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search}},
    url = {https://github.com/Microsoft/DiskANN},
    version = {0.6.1},
    year = {2023}
}
```
41
packages/leann-backend-diskann/third_party/DiskANN/SECURITY.md
vendored
Normal file
@@ -0,0 +1,41 @@
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->
42
packages/leann-backend-diskann/third_party/DiskANN/apps/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
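# Each CLI tool below links the diskann library (${PROJECT_NAME}), the tcmalloc link options
# chosen in the top-level CMakeLists.txt, and Boost::program_options for argument parsing.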

add_executable(build_memory_index build_memory_index.cpp)
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(build_stitched_index build_stitched_index.cpp)
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(search_memory_index search_memory_index.cpp)
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(build_disk_index build_disk_index.cpp)
target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options)

add_executable(search_disk_index search_disk_index.cpp)
target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(range_search_disk_index range_search_disk_index.cpp)
target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(test_streaming_scenario test_streaming_scenario.cpp)
target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

if (NOT MSVC)
    install(TARGETS build_memory_index
        build_stitched_index
        search_memory_index
        build_disk_index
        search_disk_index
        range_search_disk_index
        test_streaming_scenario
        test_insert_deletes_consolidate
        RUNTIME
    )
endif()
191
packages/leann-backend-diskann/third_party/DiskANN/apps/build_disk_index.cpp
vendored
Normal file
@@ -0,0 +1,191 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <omp.h>
#include <boost/program_options.hpp>

#include "utils.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "index.h"
#include "partition.h"
#include "program_options_utils.hpp"

namespace po = boost::program_options;

int main(int argc, char **argv)
{
    std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
        label_type;
    uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
    float B, M;
    bool append_reorder_data = false;
    bool use_opq = false;

    po::options_description desc{
        program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
                                       program_options_utils::INPUT_DATA_PATH);
        required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
                                       "DRAM budget in GB for searching the index to set the "
                                       "compressed level for data while search happens");
        required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
                                       "DRAM budget in GB for building the index");

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
                                       program_options_utils::MAX_BUILD_DEGREE);
        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
        optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
                                       " Quantized Dimension for compression");
        optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
                                       "Path prefix for pre-trained codebook");
        optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
                                       "Number of bytes to which vectors should be compressed "
                                       "on SSD; 0 for no compression");
        optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
                                       "Include full precision data in the index. Use only in "
                                       "conjunction with compressed data on SSD.");
        optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
                                       program_options_utils::BUIlD_GRAPH_PQ_BYTES);
        optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
                                       program_options_utils::USE_OPQ);
        optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
                                       program_options_utils::LABEL_FILE);
        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
                                       program_options_utils::UNIVERSAL_LABEL);
        optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
                                       program_options_utils::FILTERED_LBUILD);
        optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
                                       "Threshold to break up the existing nodes to generate new graph "
                                       "internally where each node has a maximum F labels.");
        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
                                       program_options_utils::LABEL_TYPE_DESCRIPTION);

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
        if (vm["append_reorder_data"].as<bool>())
            append_reorder_data = true;
        if (vm["use_opq"].as<bool>())
            use_opq = true;
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    bool use_filters = (label_file != "") ? true : false;
    diskann::Metric metric;
    if (dist_fn == std::string("l2"))
        metric = diskann::Metric::L2;
    else if (dist_fn == std::string("mips"))
        metric = diskann::Metric::INNER_PRODUCT;
    else if (dist_fn == std::string("cosine"))
        metric = diskann::Metric::COSINE;
    else
    {
        std::cout << "Error. Only l2, mips and cosine distance functions are supported" << std::endl;
        return -1;
    }

    if (append_reorder_data)
    {
        if (disk_PQ == 0)
        {
            std::cout << "Error: It is not necessary to append data for reordering "
                         "when vectors are not compressed on disk."
                      << std::endl;
            return -1;
        }
        if (data_type != std::string("float"))
        {
            std::cout << "Error: Appending data for reordering currently only "
                         "supported for float data type."
                      << std::endl;
            return -1;
        }
    }

std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
|
||||
std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
|
||||
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
|
||||
std::string(std::to_string(append_reorder_data)) + " " +
|
||||
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
|
||||
|
||||
try
|
||||
{
|
||||
if (label_file != "" && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
164 packages/leann-backend-diskann/third_party/DiskANN/apps/build_memory_index.cpp vendored Normal file
@@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <omp.h>
#include <cstring>
#include <boost/program_options.hpp>

#include "index.h"
#include "utils.h"
#include "program_options_utils.hpp"

#ifndef _WINDOWS
#include <sys/mman.h>
#include <unistd.h>
#else
#include <Windows.h>
#endif

#include "memory_mapper.h"
#include "ann_exception.h"
#include "index_factory.h"

namespace po = boost::program_options;

int main(int argc, char **argv)
{
    std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
    uint32_t num_threads, R, L, Lf, build_PQ_bytes;
    float alpha;
    bool use_pq_build, use_opq;

    po::options_description desc{
        program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
                                       program_options_utils::INPUT_DATA_PATH);

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
                                       program_options_utils::MAX_BUILD_DEGREE);
        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
                                       program_options_utils::GRAPH_BUILD_ALPHA);
        optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
                                       program_options_utils::BUIlD_GRAPH_PQ_BYTES);
        optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
                                       program_options_utils::USE_OPQ);
        optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
                                       program_options_utils::LABEL_FILE);
        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
                                       program_options_utils::UNIVERSAL_LABEL);

        optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
                                       program_options_utils::FILTERED_LBUILD);
        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
                                       program_options_utils::LABEL_TYPE_DESCRIPTION);

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
        use_pq_build = (build_PQ_bytes > 0);
        use_opq = vm["use_opq"].as<bool>();
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    diskann::Metric metric;
    if (dist_fn == std::string("mips"))
    {
        metric = diskann::Metric::INNER_PRODUCT;
    }
    else if (dist_fn == std::string("l2"))
    {
        metric = diskann::Metric::L2;
    }
    else if (dist_fn == std::string("cosine"))
    {
        metric = diskann::Metric::COSINE;
    }
    else
    {
        std::cout << "Unsupported distance function. Currently only L2/ Inner "
                     "Product/Cosine are supported."
                  << std::endl;
        return -1;
    }

    try
    {
        diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
                      << " #threads: " << num_threads << std::endl;

        size_t data_num, data_dim;
        diskann::get_bin_metadata(data_path, data_num, data_dim);

        auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
                                      .with_filter_list_size(Lf)
                                      .with_alpha(alpha)
                                      .with_saturate_graph(false)
                                      .with_num_threads(num_threads)
                                      .build();

        auto filter_params = diskann::IndexFilterParamsBuilder()
                                 .with_universal_label(universal_label)
                                 .with_label_file(label_file)
                                 .with_save_path_prefix(index_path_prefix)
                                 .build();
        auto config = diskann::IndexConfigBuilder()
                          .with_metric(metric)
                          .with_dimension(data_dim)
                          .with_max_points(data_num)
                          .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
                          .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
                          .with_data_type(data_type)
                          .with_label_type(label_type)
                          .is_dynamic_index(false)
                          .with_index_write_params(index_build_params)
                          .is_enable_tags(false)
                          .is_use_opq(use_opq)
                          .is_pq_dist_build(use_pq_build)
                          .with_num_pq_chunks(build_PQ_bytes)
                          .build();

        auto index_factory = diskann::IndexFactory(config);
        auto index = index_factory.create_instance();
        index->build(data_path, data_num, filter_params);
        index->save(index_path_prefix.c_str());
        index.reset();
        return 0;
    }
    catch (const std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index build failed." << std::endl;
        return -1;
    }
}
441 packages/leann-backend-diskann/third_party/DiskANN/apps/build_stitched_index.cpp vendored Normal file
@@ -0,0 +1,441 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <boost/program_options.hpp>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <random>
#include <string>
#include <tuple>
#include "filter_utils.h"
#include <omp.h>
#ifndef _WINDOWS
#include <sys/uio.h>
#endif

#include "index.h"
#include "memory_mapper.h"
#include "parameters.h"
#include "utils.h"
#include "program_options_utils.hpp"

namespace po = boost::program_options;
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;

/*
 * Inline function to display a progress bar.
 */
inline void print_progress(double percentage)
{
    int val = (int)(percentage * 100);
    int lpad = (int)(percentage * PBWIDTH);
    int rpad = PBWIDTH - lpad;
    printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
    fflush(stdout);
}

/*
 * Inline function to generate a random integer in a range.
 */
inline size_t random(size_t range_from, size_t range_to)
{
    std::random_device rand_dev;
    std::mt19937 generator(rand_dev());
    std::uniform_int_distribution<size_t> distr(range_from, range_to);
    return distr(generator);
}

/*
 * Function to handle command line parsing.
 *
 * Arguments are merely the inputs from the command line.
 */
void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
                 path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
                 uint32_t &stitched_R, float &alpha)
{
    po::options_description desc{
        program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("index_path_prefix",
                                       po::value<std::string>(&final_index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("data_path", po::value<std::string>(&input_data_path)->required(),
                                       program_options_utils::INPUT_DATA_PATH);

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
                                       program_options_utils::MAX_BUILD_DEGREE);
        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
                                       program_options_utils::GRAPH_BUILD_ALPHA);
        optional_configs.add_options()("label_file", po::value<std::string>(&label_data_path)->default_value(""),
                                       program_options_utils::LABEL_FILE);
        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
                                       program_options_utils::UNIVERSAL_LABEL);
        optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
                                       "Degree to prune final graph down to");

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            exit(0);
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        throw;
    }
}

/*
 * Custom index save to write the in-memory index to disk.
 * Also writes required files for diskANN API -
 *  1. labels_to_medoids
 *  2. universal_label
 *  3. data (redundant for static indices)
 *  4. labels (redundant for static indices)
 */
void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
                     std::vector<std::vector<uint32_t>> stitched_graph,
                     tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
                     path label_data_path)
{
    // aux. file 1
    auto saving_index_timer = std::chrono::high_resolution_clock::now();
    std::ifstream original_label_data_stream;
    original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
    original_label_data_stream.open(label_data_path, std::ios::binary);
    std::ofstream new_label_data_stream;
    new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
    new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
    new_label_data_stream << original_label_data_stream.rdbuf();
    original_label_data_stream.close();
    new_label_data_stream.close();

    // aux. file 2
    std::ifstream original_input_data_stream;
    original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
    original_input_data_stream.open(input_data_path, std::ios::binary);
    std::ofstream new_input_data_stream;
    new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
    new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
    new_input_data_stream << original_input_data_stream.rdbuf();
    original_input_data_stream.close();
    new_input_data_stream.close();

    // aux. file 3
    std::ofstream labels_to_medoids_writer;
    labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
    labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
    for (auto iter : entry_points)
        labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
    labels_to_medoids_writer.close();

    // aux. file 4 (only if we're using a universal label)
    if (universal_label != "")
    {
        std::ofstream universal_label_writer;
        universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
        universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
        universal_label_writer << universal_label << std::endl;
        universal_label_writer.close();
    }

    // main index
    uint64_t index_num_frozen_points = 0, index_num_edges = 0;
    uint32_t index_max_observed_degree = 0, index_entry_point = 0;
    const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
    for (auto &point_neighbors : stitched_graph)
    {
        index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
    }

    std::ofstream stitched_graph_writer;
    stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
    stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);

    stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
    stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
    stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
    stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));

    size_t bytes_written = METADATA;
    for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
    {
        uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
        std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
        stitched_graph_writer.write((char *)&current_node_num_neighbors, sizeof(uint32_t));
        bytes_written += sizeof(uint32_t);
        for (const auto &current_node_neighbor : current_node_neighbors)
        {
            stitched_graph_writer.write((char *)&current_node_neighbor, sizeof(uint32_t));
            bytes_written += sizeof(uint32_t);
        }
        index_num_edges += current_node_num_neighbors;
    }

    if (bytes_written != final_index_size)
    {
        std::cerr << "Error: number of bytes written does not match the allocated space" << std::endl;
        throw;
    }

    stitched_graph_writer.close();

    std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
    std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
    std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
              << std::endl;
    std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
}

/*
 * Unions the per-label graph indices together via the following policy:
 *  - any two nodes can have at most one edge between them
 *
 * Returns the "stitched" graph and its expected file size.
 */
template <typename T>
stitch_indices_return_values stitch_label_indices(
    path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
    tsl::robin_map<std::string, uint32_t> &label_entry_points,
    tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
{
    size_t final_index_size = 0;
    std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);

    auto stitching_index_timer = std::chrono::high_resolution_clock::now();
    for (const auto &lbl : all_labels)
    {
        path curr_label_index_path(final_index_path_prefix + "_" + lbl);
        std::vector<std::vector<uint32_t>> curr_label_index;
        uint64_t curr_label_index_size;
        uint32_t curr_label_entry_point;

        std::tie(curr_label_index, curr_label_index_size) =
            diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
        curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
        label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];

        for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
        {
            uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
            for (auto &node_neighbor : curr_label_index[node_point])
            {
                uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
                std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
                if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
                    curr_point_neighbors.end())
                {
                    stitched_graph[original_point_id].push_back(original_neighbor_id);
                    final_index_size += sizeof(uint32_t);
                }
            }
        }
    }

    const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
    final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);

    std::chrono::duration<double> stitching_index_time =
        std::chrono::high_resolution_clock::now() - stitching_index_timer;
    std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;

    return std::make_tuple(stitched_graph, final_index_size);
}

/*
 * Applies the prune_neighbors function from src/index.cpp to
 * every node in the stitched graph.
 *
 * This is an optional step, hence the saving of both the full
 * and pruned graph.
 */
template <typename T>
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
                    std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
                    tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
                    path label_data_path, uint32_t num_threads)
{
    size_t dimension, number_of_label_points;
    auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
    auto std_cout_buffer = std::cout.rdbuf(nullptr);
    auto pruning_index_timer = std::chrono::high_resolution_clock::now();

    diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);

    diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
                            false, false, 0, false);

    // not searching this index, set search_l to 0
    index.load(full_index_path_prefix.c_str(), num_threads, 1);

    std::cout << "parsing labels" << std::endl;

    index.prune_all_neighbors(stitched_R, 750, 1.2);
    index.save((final_index_path_prefix).c_str());

    diskann::cout.rdbuf(diskann_cout_buffer);
    std::cout.rdbuf(std_cout_buffer);
    std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
    std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
}

/*
 * Delete all temporary artifacts.
 * In the process of creating the stitched index, some temporary artifacts are
 * created:
 *  1. the separate bin files for each label's points
 *  2. the separate diskANN indices built for each label
 *  3. the '.data' file created while generating the indices
 */
void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
{
    for (const auto &lbl : all_labels)
    {
        path curr_label_input_data_path(input_data_path + "_" + lbl);
        path curr_label_index_path(final_index_path_prefix + "_" + lbl);
        path curr_label_index_path_data(curr_label_index_path + ".data");

        if (std::remove(curr_label_index_path.c_str()) != 0)
            throw;
        if (std::remove(curr_label_input_data_path.c_str()) != 0)
            throw;
        if (std::remove(curr_label_index_path_data.c_str()) != 0)
            throw;
    }
}

int main(int argc, char **argv)
{
    // 1. handle cmdline inputs
    std::string data_type;
    path input_data_path, final_index_path_prefix, label_data_path;
    std::string universal_label;
    uint32_t num_threads, R, L, stitched_R;
    float alpha;

    auto index_timer = std::chrono::high_resolution_clock::now();
    handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
                num_threads, R, L, stitched_R, alpha);

    path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
    path labels_map_file = final_index_path_prefix + "_labels_map.txt";

    convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);

    // 2. parse label file and create necessary data structures
    std::vector<label_set> point_ids_to_labels;
    tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
    label_set all_labels;

    std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
        diskann::parse_label_file(labels_file_to_use, universal_label);

    // 3. for each label, make a separate data file
    tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
    uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();

#ifndef _WINDOWS
    if (data_type == "uint8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "int8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "float")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else
        throw;
#else
    if (data_type == "uint8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "int8")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else if (data_type == "float")
        label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
    else
        throw;
#endif

    // 4. for each created data file, create a vanilla diskANN index
    if (data_type == "uint8")
        diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
                                                 num_threads);
    else if (data_type == "int8")
        diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
                                                num_threads);
    else if (data_type == "float")
        diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
                                               num_threads);
    else
        throw;

    // 5. "stitch" the indices together
    std::vector<std::vector<uint32_t>> stitched_graph;
    tsl::robin_map<std::string, uint32_t> label_entry_points;
    uint64_t stitched_graph_size;

    if (data_type == "uint8")
        std::tie(stitched_graph, stitched_graph_size) =
            stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
                                          labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
    else if (data_type == "int8")
        std::tie(stitched_graph, stitched_graph_size) =
            stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
                                         labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
    else if (data_type == "float")
        std::tie(stitched_graph, stitched_graph_size) =
            stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
                                        labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
    else
        throw;
    path full_index_path_prefix = final_index_path_prefix + "_full";
    // 5a. save the stitched graph to disk
    save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
                    universal_label, labels_file_to_use);

    // 6. run a prune on the stitched index, and save to disk
    if (data_type == "uint8")
        prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
                                stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
    else if (data_type == "int8")
        prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
                               stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
    else if (data_type == "float")
        prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
                              stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
    else
        throw;

    std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
    std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;

    clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
}
46 packages/leann-backend-diskann/third_party/DiskANN/apps/python/README.md vendored Normal file
@@ -0,0 +1,46 @@
<!-- Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license. -->

# Integration Tests
The following tests use Python to prepare, run, verify, and tear down the REST API services.

We make use of the built-in `unittest` library, but only to take advantage of its test reporting.

These are decidedly **not** _unit_ tests. These are end-to-end integration tests.

## Caveats
This has only been tested and built for Linux, though we have written platform-agnostic Python for the smoke test
(i.e. using `os.path.join`, etc.)

It has been tested on Python 3.9 and 3.10, but should work on Python 3.6+.

## How to Run

First, build the DiskANN RestAPI code; see $REPOSITORY_ROOT/workflows/rest_api.md for detailed instructions.

```bash
cd tests/python
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt

export DISKANN_BUILD_DIR=/path/to/your/diskann/build
python -m unittest
```

## Smoke Test Failed, Now What?
The smoke test takes advantage of temporary directories that are only valid during the
lifetime of the test. The contents of these directories include:
- Randomized vectors (first in tsv, then bin form) used to build the PQFlashIndex
- The PQFlashIndex files

It is useful to keep these around. By setting some environment variables, you can control whether an ephemeral,
temporary directory is used (and deleted on test completion) or left for the developer to clean up.

The valid environment variables are (see the sketch after this list):
- `DISKANN_REST_TEST_WORKING_DIR` (example: `$USER/DiskANNRestTest`)
  - If this is specified, it **must exist** and **must be writeable**. Any existing files will be clobbered.
- `DISKANN_REST_SERVER` (example: `http://127.0.0.1:10067`)
  - Note that if this is set, no data will be generated, nor will a server be started; it is presumed you have done
    all the work in creating and starting the rest server prior to running the test, and the test just submits requests against it.
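For concreteness, a minimal sketch of a run that reuses a fixed working directory and an already-running server; the directory and address below are the placeholder values from the examples above, not required settings:

```bash
# Sketch: point the smoke test at a pre-created working directory and a
# REST server you started yourself (both values are placeholders).
export DISKANN_REST_TEST_WORKING_DIR="$HOME/DiskANNRestTest"  # must exist and be writeable
export DISKANN_REST_SERVER="http://127.0.0.1:10067"           # server must already be running
python -m unittest
```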
0 packages/leann-backend-diskann/third_party/DiskANN/apps/python/restapi/__init__.py vendored Normal file
67 packages/leann-backend-diskann/third_party/DiskANN/apps/python/restapi/disk_ann_util.py vendored Normal file
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import numpy as np
import os
import subprocess


def output_vectors(
    diskann_build_path: str,
    temporary_file_path: str,
    vectors: np.ndarray,
    timeout: int = 60
) -> str:
    vectors_as_tsv_path = os.path.join(temporary_file_path, "vectors.tsv")
    with open(vectors_as_tsv_path, "w") as vectors_tsv_out:
        for vector in vectors:
            as_str = "\t".join(str(component) for component in vector)
            print(as_str, file=vectors_tsv_out)
    # there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in
    # favor of something more sane later
    vectors_as_bin_path = os.path.join(temporary_file_path, "vectors.bin")
    tsv_to_bin_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin")

    number_of_points, dimensions = vectors.shape
    args = [
        tsv_to_bin_path,
        "float",
        vectors_as_tsv_path,
        vectors_as_bin_path,
        str(dimensions),
        str(number_of_points)
    ]
    completed = subprocess.run(args, timeout=timeout)
    if completed.returncode != 0:
        raise Exception(f"Unable to convert tsv to binary using tsv_to_bin, completed_process: {completed}")
    return vectors_as_bin_path


def build_ssd_index(
    diskann_build_path: str,
    temporary_file_path: str,
    vectors: np.ndarray,
    per_process_timeout: int = 60  # this may not be long enough if you're doing something larger
):
    vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout)

    ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index")
    args = [
        ssd_builder_path,
        "--data_type", "float",
        "--dist_fn", "l2",
        "--data_path", vectors_as_bin_path,
        "--index_path_prefix", os.path.join(temporary_file_path, "smoke_test"),
        "-R", "64",
        "-L", "100",
        "--search_DRAM_budget", "1",
        "--build_DRAM_budget", "1",
        "--num_threads", "1",
        "--PQ_disk_bytes", "0"
    ]
    completed = subprocess.run(args, timeout=per_process_timeout)

    if completed.returncode != 0:
        command_run = " ".join(args)
        raise Exception(f"Unable to build a disk index with the command: '{command_run}'\ncompleted_process: {completed}\nstdout: {completed.stdout}\nstderr: {completed.stderr}")
    # index is now built inside of temporary_file_path
379 packages/leann-backend-diskann/third_party/DiskANN/apps/range_search_disk_index.cpp vendored Normal file
@@ -0,0 +1,379 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <atomic>
#include <cstring>
#include <iomanip>
#include <omp.h>
#include <set>
#include <boost/program_options.hpp>

#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "pq_flash_index.h"
#include "partition.h"
#include "timer.h"
#include "program_options_utils.hpp"

#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif

namespace po = boost::program_options;

#define WARMUP false

void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
    diskann::cout << std::setw(20) << category << ": " << std::flush;
    for (uint32_t s = 0; s < percentiles.size(); s++)
    {
        diskann::cout << std::setw(8) << percentiles[s] << "%";
    }
    diskann::cout << std::endl;
    diskann::cout << std::setw(22) << " " << std::flush;
    for (uint32_t s = 0; s < percentiles.size(); s++)
    {
        diskann::cout << std::setw(9) << results[s];
    }
    diskann::cout << std::endl;
}

template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
                      std::string &gt_file, const uint32_t num_threads, const float search_range,
                      const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
{
    std::string pq_prefix = index_path_prefix + "_pq";
    std::string disk_index_file = index_path_prefix + "_disk.index";
    std::string warmup_query_file = index_path_prefix + "_sample_data.bin";

    diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
    if (beamwidth <= 0)
        diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
    else
        diskann::cout << " beamwidth: " << beamwidth << std::endl;

    // load query bin
    T *query = nullptr;
    std::vector<std::vector<uint32_t>> groundtruth_ids;
    size_t query_num, query_dim, query_aligned_dim, gt_num;
    diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);

    bool calc_recall_flag = false;
    if (gt_file != std::string("null") && file_exists(gt_file))
    {
        diskann::load_range_truthset(gt_file, groundtruth_ids,
                                     gt_num); // use for range search type of truthset
        // diskann::prune_truthset_for_range(gt_file, search_range,
        // groundtruth_ids, gt_num); // use for traditional truthset
        if (gt_num != query_num)
        {
            diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
            return -1;
        }
        calc_recall_flag = true;
    }

    std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
    reader.reset(new WindowsAlignedFileReader());
#else
    reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
    reader.reset(new LinuxAlignedFileReader());
#endif

    std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
        new diskann::PQFlashIndex<T, LabelT>(reader, metric));

    int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());

    if (res != 0)
    {
        return res;
    }
    // cache bfs levels
    std::vector<uint32_t> node_list;
    diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
    _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
    // _pFlashIndex->generate_cache_list_from_sample_queries(
    //     warmup_query_file, 15, 6, num_nodes_to_cache, num_threads,
    //     node_list);
    _pFlashIndex->load_cache_list(node_list);
    node_list.clear();
    node_list.shrink_to_fit();

    omp_set_num_threads(num_threads);

    uint64_t warmup_L = 20;
    uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
    T *warmup = nullptr;

    if (WARMUP)
    {
        if (file_exists(warmup_query_file))
        {
            diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
        }
        else
        {
            warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
            warmup_dim = query_dim;
            warmup_aligned_dim = query_aligned_dim;
            diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
            std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
            std::random_device rd;
            std::mt19937 gen(rd());
            std::uniform_int_distribution<> dis(-128, 127);
            for (uint32_t i = 0; i < warmup_num; i++)
            {
                for (uint32_t d = 0; d < warmup_dim; d++)
                {
                    warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
                }
            }
        }
        diskann::cout << "Warming up index... " << std::flush;
        std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
        std::vector<float> warmup_result_dists(warmup_num, 0);

#pragma omp parallel for schedule(dynamic, 1)
        for (int64_t i = 0; i < (int64_t)warmup_num; i++)
        {
            _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
                                             warmup_result_ids_64.data() + (i * 1),
                                             warmup_result_dists.data() + (i * 1), 4);
        }
        diskann::cout << "..done" << std::endl;
    }

    diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
    diskann::cout.precision(2);

    std::string recall_string = "Recall@rng=" + std::to_string(search_range);
    diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
                  << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
                  << "CPU (s)";
    if (calc_recall_flag)
    {
        diskann::cout << std::setw(16) << recall_string << std::endl;
    }
    else
        diskann::cout << std::endl;
    diskann::cout << "==============================================================="
                     "==========================================="
                  << std::endl;

    std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());

    uint32_t optimized_beamwidth = 2;
    uint32_t max_list_size = 10000;

    for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
    {
        uint32_t L = Lvec[test_id];

        if (beamwidth <= 0)
        {
            optimized_beamwidth =
                optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
        }
        else
            optimized_beamwidth = beamwidth;

        query_result_ids[test_id].clear();
        query_result_ids[test_id].resize(query_num);

        diskann::QueryStats *stats = new diskann::QueryStats[query_num];

        auto s = std::chrono::high_resolution_clock::now();
#pragma omp parallel for schedule(dynamic, 1)
        for (int64_t i = 0; i < (int64_t)query_num; i++)
        {
            std::vector<uint64_t> indices;
            std::vector<float> distances;
            uint32_t res_count =
                _pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
                                           distances, optimized_beamwidth, stats + i);
            query_result_ids[test_id][i].reserve(res_count);
            query_result_ids[test_id][i].resize(res_count);
            for (uint32_t idx = 0; idx < res_count; idx++)
                query_result_ids[test_id][i][idx] = (uint32_t)indices[idx];
        }
        auto e = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double> diff = e - s;
        auto qps = (1.0 * query_num) / (1.0 * diff.count());

        auto mean_latency = diskann::get_mean_stats<float>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });

        auto latency_999 = diskann::get_percentile_stats<float>(
            stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });

        auto mean_ios = diskann::get_mean_stats<uint32_t>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_ios; });

        double mean_cpuus = diskann::get_mean_stats<float>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });

        double recall = 0;
        double ratio_of_sums = 0;
        if (calc_recall_flag)
        {
            recall =
                diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]);

            uint32_t total_true_positive = 0;
            uint32_t total_positive = 0;
            for (uint32_t i = 0; i < query_num; i++)
            {
                total_true_positive += (uint32_t)query_result_ids[test_id][i].size();
                total_positive += (uint32_t)groundtruth_ids[i].size();
            }

            ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
        }

        diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
                      << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
                      << std::setw(16) << mean_cpuus;
        if (calc_recall_flag)
        {
            diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl;
        }
        else
            diskann::cout << std::endl;
    }

    diskann::cout << "Done searching. " << std::endl;

    diskann::aligned_free(query);
    if (warmup != nullptr)
        diskann::aligned_free(warmup);
    return 0;
}

int main(int argc, char **argv)
{
    std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file;
    uint32_t num_threads, W, num_nodes_to_cache;
    std::vector<uint32_t> Lvec;
    float range;

    po::options_description desc{program_options_utils::make_program_description(
        "range_search_disk_index", "Searches disk DiskANN indexes using ranges")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
                                       program_options_utils::QUERY_FILE_DESCRIPTION);
        required_configs.add_options()("search_list,L",
                                       po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
                                       program_options_utils::SEARCH_LIST_DESCRIPTION);
        required_configs.add_options()("range_threshold,K", po::value<float>(&range)->required(),
                                       "Range threshold for the search; points within this distance are returned");

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
                                       program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
        optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
                                       program_options_utils::NUMBER_OF_NODES_TO_CACHE);
        optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
                                       program_options_utils::BEAMWIDTH);

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    diskann::Metric metric;
    if (dist_fn == std::string("mips"))
    {
        metric = diskann::Metric::INNER_PRODUCT;
    }
    else if (dist_fn == std::string("l2"))
    {
        metric = diskann::Metric::L2;
    }
    else if (dist_fn == std::string("cosine"))
    {
        metric = diskann::Metric::COSINE;
    }
    else
    {
        std::cout << "Unsupported distance function. Currently only L2/ Inner "
                     "Product/Cosine are supported."
                  << std::endl;
        return -1;
    }

    if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
    {
        std::cout << "Currently support only floating point data for Inner Product." << std::endl;
        return -1;
    }

    try
    {
        if (data_type == std::string("float"))
            return search_disk_index<float>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
                                            num_nodes_to_cache, Lvec);
        else if (data_type == std::string("int8"))
            return search_disk_index<int8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
                                             num_nodes_to_cache, Lvec);
        else if (data_type == std::string("uint8"))
            return search_disk_index<uint8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
                                              num_nodes_to_cache, Lvec);
        else
        {
            std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
            return -1;
        }
    }
    catch (const std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index search failed." << std::endl;
        return -1;
    }
}
40 packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/CMakeLists.txt vendored Normal file
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

set(CMAKE_CXX_STANDARD 17)

add_executable(inmem_server inmem_server.cpp)
if(MSVC)
    target_link_options(inmem_server PRIVATE /MACHINE:x64)
    target_link_libraries(inmem_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
    target_link_libraries(inmem_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
    target_link_libraries(inmem_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()

add_executable(ssd_server ssd_server.cpp)
if(MSVC)
    target_link_options(ssd_server PRIVATE /MACHINE:x64)
    target_link_libraries(ssd_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
    target_link_libraries(ssd_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
    target_link_libraries(ssd_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()

add_executable(multiple_ssdindex_server multiple_ssdindex_server.cpp)
if(MSVC)
    target_link_options(multiple_ssdindex_server PRIVATE /MACHINE:x64)
    target_link_libraries(multiple_ssdindex_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
    target_link_libraries(multiple_ssdindex_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
    target_link_libraries(multiple_ssdindex_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()

add_executable(client client.cpp)
if(MSVC)
    target_link_options(client PRIVATE /MACHINE:x64)
    target_link_libraries(client debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
    target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
else()
    target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
endif()
124
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/client.cpp
vendored
Normal file
124
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/client.cpp
vendored
Normal file
@@ -0,0 +1,124 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>

#include <cpprest/http_client.h>
#include <restapi/common.h>

using namespace web;
using namespace web::http;
using namespace web::http::client;

using namespace diskann;
namespace po = boost::program_options;

template <typename T>
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
                const unsigned k_value)
{
    web::http::client::http_client client(U(ip_addr_port));

    T *data;
    size_t npts = 1, ndims = 128, rounded_dim = 128;
    diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);

    for (unsigned i = 0; i < nq; ++i)
    {
        T *vec = data + i * rounded_dim;
        web::http::http_request http_query(methods::POST);
        web::json::value queryJson = web::json::value::object();
        queryJson[QUERY_ID_KEY] = i;
        queryJson[K_KEY] = k_value;
        queryJson[L_KEY] = Ls;
        // Use a separate index for the vector components so the outer query index is not shadowed.
        for (size_t d = 0; d < ndims; ++d)
        {
            queryJson[VECTOR_KEY][d] = web::json::value::number(vec[d]);
        }
        http_query.set_body(queryJson);

        client.request(http_query)
            .then([](web::http::http_response response) -> pplx::task<utility::string_t> {
                if (response.status_code() == status_codes::OK)
                {
                    return response.extract_string();
                }
                std::cerr << "Query failed" << std::endl;
                return pplx::task_from_result(utility::string_t());
            })
            .then([](pplx::task<utility::string_t> previousTask) {
                try
                {
                    std::cout << previousTask.get() << std::endl;
                }
                catch (http_exception const &e)
                {
                    std::wcout << e.what() << std::endl;
                }
            })
            .wait();
    }
}

int main(int argc, char *argv[])
{
    std::string data_type, query_file, address;
    uint32_t num_queries;
    uint32_t l_search, k_value;

    po::options_description desc{"Arguments"};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");
        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
        desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
        desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
                           "File containing the queries to search");
        desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
                           "Number of queries to search");
        desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
        desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << std::endl;
        return -1;
    }

    if (data_type == std::string("float"))
    {
        query_loop<float>(address, query_file, num_queries, l_search, k_value);
    }
    else if (data_type == std::string("int8"))
    {
        query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
    }
    else if (data_type == std::string("uint8"))
    {
        query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
    }
    else
    {
        // Report the parsed option value rather than a positional argv slot, which may not hold the data type.
        std::cerr << "Unsupported type " << data_type << std::endl;
        return -1;
    }

    return 0;
}
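A minimal invocation sketch for the client above, assuming the built binary is named client and a DiskANN REST server is already listening (the address and file paths are illustrative; the flags are the ones declared in main):

// Illustrative only -- binary name, address, and paths are assumptions:
//   client --data_type float --address http://127.0.0.1:8080 \
//          --query_file queries.bin --num_queries 100 --l_search 64 --k_value 10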
138
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/inmem_server.cpp
vendored
Normal file
@@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>

#include <restapi/server.h>

using namespace diskann;
namespace po = boost::program_options;

std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;

void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch, typestring));
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    ucout << U"Listening for requests on: " << address << std::endl;
}

void teardown(const utility::string_t &address)
{
    g_httpServer->close().wait();
}

int main(int argc, char *argv[])
{
    std::string data_type, index_file, data_file, address, dist_fn, tags_file;
    uint32_t num_threads;
    uint32_t l_search;

    po::options_description desc{"Arguments"};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");
        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
        desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
        desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
                           "File containing the data found in the index");
        desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
                           "Path prefix for saving index file components");
        desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
                           "Number of threads used for building index");
        desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
        desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
                           "distance function <l2/mips>");
        desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
                           "Tags file location");
        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << std::endl;
        return -1;
    }
    diskann::Metric metric;
    if (dist_fn == std::string("l2"))
        metric = diskann::Metric::L2;
    else if (dist_fn == std::string("mips"))
        metric = diskann::Metric::INNER_PRODUCT;
    else
    {
        std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
        return -1;
    }

    if (data_type == std::string("float"))
    {
        auto searcher = std::unique_ptr<diskann::BaseSearch>(
            new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
        g_inMemorySearch.push_back(std::move(searcher));
    }
    else if (data_type == std::string("int8"))
    {
        auto searcher = std::unique_ptr<diskann::BaseSearch>(
            new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
        g_inMemorySearch.push_back(std::move(searcher));
    }
    else if (data_type == std::string("uint8"))
    {
        auto searcher = std::unique_ptr<diskann::BaseSearch>(
            new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
        g_inMemorySearch.push_back(std::move(searcher));
    }
    else
    {
        // Report the parsed option value and bail out instead of starting a server with no searcher.
        std::cerr << "Unsupported data type " << data_type << std::endl;
        return -1;
    }

    while (1)
    {
        try
        {
            setup(address, data_type);
            std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
            std::string line;
            std::getline(std::cin, line);
            if (line == "exit")
            {
                // teardown() already closes the server; a second close().wait() would be redundant.
                teardown(address);
                exit(0);
            }
        }
        catch (const std::exception &ex)
        {
            std::cerr << "Exception occurred: " << ex.what() << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
        catch (...)
        {
            std::cerr << "Unknown exception occurred" << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
    }
}
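A launch sketch for this in-memory server, assuming a binary named inmem_server; paths are illustrative and the flags mirror the options parsed above:

// Illustrative only:
//   inmem_server --data_type float --address http://127.0.0.1:8080 \
//                --data_file base.bin --index_path_prefix idx --num_threads 8 --l_search 100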
83
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/main.cpp
vendored
Normal file
@@ -0,0 +1,83 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <restapi/server.h>
#include <restapi/in_memory_search.h>
#include <codecvt>
#include <iostream>

std::unique_ptr<Server> g_httpServer(nullptr);
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);

void setup(const utility::string_t &address)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch));
    g_httpServer->open().wait();

    ucout << U"Listening for requests on: " << address << std::endl;
}

void teardown(const utility::string_t &address)
{
    g_httpServer->close().wait();
}

void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
{
    auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
    g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
}

std::wstring getHostingAddress(const char *hostNameAndPort)
{
    wchar_t buffer[4096];
    mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
               sizeof(buffer) / sizeof(buffer[0]));
    return std::wstring(buffer);
}

int main(int argc, char *argv[])
{
    if (argc != 5)
    {
        std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
                     "<base_file> <ids_file> "
                  << std::endl;
        exit(1);
    }

    auto address = getHostingAddress(argv[1]);
    loadIndex(argv[2], argv[3], argv[4]);
    while (1)
    {
        try
        {
            setup(address);
            std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
            std::string line;
            std::getline(std::cin, line);
            if (line == "exit")
            {
                teardown(address);
                exit(0);
            }
        }
        catch (const std::exception &ex)
        {
            std::cerr << "Exception occurred: " << ex.what() << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
        catch (...)
        {
            std::cerr << "Unknown exception occurred" << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
    }
}
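Unlike the other servers here, this one takes positional arguments; per the usage string above, an invocation would look like (all arguments illustrative, and the built binary's actual name may differ from the nsg_server name printed in the usage message):

// nsg_server 127.0.0.1:8080 index_file base_file ids_file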
182
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/multiple_ssdindex_server.cpp
vendored
Normal file
@@ -0,0 +1,182 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>

#include <restapi/server.h>

using namespace diskann;
namespace po = boost::program_options;

std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;

void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    ucout << U"Listening for requests on: " << address << std::endl;
}

void teardown(const utility::string_t &address)
{
    g_httpServer->close().wait();
}

int main(int argc, char *argv[])
{
    std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
    uint32_t num_nodes_to_cache;
    uint32_t num_threads;

    po::options_description desc{"Arguments"};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");
        desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
        desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
                           "Path prefix for loading index file components");
        desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
                           "Number of nodes to cache during search");
        desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                           "Number of threads used for building index (defaults to "
                           "omp_get_num_procs())");
        desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
                           "distance function <l2/mips>");
        desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
                           "Tags file location");

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << std::endl;
        return -1;
    }

    diskann::Metric metric;
    if (dist_fn == std::string("l2"))
        metric = diskann::Metric::L2;
    else if (dist_fn == std::string("mips"))
        metric = diskann::Metric::INNER_PRODUCT;
    else
    {
        std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
        return -1;
    }

    std::vector<std::pair<std::string, std::string>> index_tag_paths;
    std::ifstream index_in(index_prefix_paths);
    if (!index_in.is_open())
    {
        std::cerr << "Could not open " << index_prefix_paths << std::endl;
        exit(-1);
    }
    std::ifstream tags_in(tags_file);
    if (!tags_in.is_open())
    {
        std::cerr << "Could not open " << tags_file << std::endl;
        exit(-1);
    }
    std::string prefix, tagfile;
    while (std::getline(index_in, prefix))
    {
        if (std::getline(tags_in, tagfile))
        {
            index_tag_paths.push_back(std::make_pair(prefix, tagfile));
        }
        else
        {
            std::cerr << "The number of tags specified does not match the number of "
                         "indices specified"
                      << std::endl;
            exit(-1);
        }
    }
    index_in.close();
    tags_in.close();

    if (data_type == std::string("float"))
    {
        for (auto &index_tag : index_tag_paths)
        {
            auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
                index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
            g_ssdSearch.push_back(std::move(searcher));
        }
    }
    else if (data_type == std::string("int8"))
    {
        for (auto &index_tag : index_tag_paths)
        {
            auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
                index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
            g_ssdSearch.push_back(std::move(searcher));
        }
    }
    else if (data_type == std::string("uint8"))
    {
        for (auto &index_tag : index_tag_paths)
        {
            auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
                index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
            g_ssdSearch.push_back(std::move(searcher));
        }
    }
    else
    {
        std::cerr << "Unsupported data type " << data_type << std::endl;
        exit(-1);
    }

    while (1)
    {
        try
        {
            setup(address, data_type);
            std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
            std::string line;
            std::getline(std::cin, line);
            if (line == "exit")
            {
                // teardown() already closes the server; a second close().wait() would be redundant.
                teardown(address);
                exit(0);
            }
        }
        catch (const std::exception &ex)
        {
            std::cerr << "Exception occurred: " << ex.what() << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
        catch (...)
        {
            std::cerr << "Unknown exception occurred" << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
    }
}
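Note that index_prefix_paths and tags_file are read line by line and zipped into pairs, so each file should list one entry per index. A launch sketch under that assumption (binary name and paths illustrative):

// multiple_ssdindex_server --data_type float --address http://127.0.0.1:8080 \
//     --index_prefix_paths prefixes.txt --tags_file tags.txt --num_nodes_to_cache 100000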
141
packages/leann-backend-diskann/third_party/DiskANN/apps/restapi/ssd_server.cpp
vendored
Normal file
@@ -0,0 +1,141 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <ctime>
#include <functional>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <codecvt>
#include <boost/program_options.hpp>
#include <omp.h>

#include <restapi/server.h>

using namespace diskann;
namespace po = boost::program_options;

std::unique_ptr<Server> g_httpServer(nullptr);
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;

void setup(const utility::string_t &address, const std::string &typestring)
{
    web::http::uri_builder uriBldr(address);
    auto uri = uriBldr.to_uri();

    std::cout << "Attempting to start server on " << uri.to_string() << std::endl;

    g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
    std::cout << "Created a server object" << std::endl;

    g_httpServer->open().wait();
    ucout << U"Listening for requests on: " << address << std::endl;
}

void teardown(const utility::string_t &address)
{
    g_httpServer->close().wait();
}

int main(int argc, char *argv[])
{
    std::string data_type, index_path_prefix, address, dist_fn, tags_file;
    uint32_t num_nodes_to_cache;
    uint32_t num_threads;

    po::options_description desc{"Arguments"};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");
        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
        desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
        desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                           "Path prefix for loading index file components");
        desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
                           "Number of nodes to cache during search");
        desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                           "Number of threads used for building index (defaults to "
                           "omp_get_num_procs())");
        desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
                           "distance function <l2/mips>");
        desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
                           "Tags file location");
        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << std::endl;
        return -1;
    }

    diskann::Metric metric;
    if (dist_fn == std::string("l2"))
        metric = diskann::Metric::L2;
    else if (dist_fn == std::string("mips"))
        metric = diskann::Metric::INNER_PRODUCT;
    else
    {
        std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
        return -1;
    }

    if (data_type == std::string("float"))
    {
        auto searcher = std::unique_ptr<diskann::BaseSearch>(
            new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
        g_ssdSearch.push_back(std::move(searcher));
    }
    else if (data_type == std::string("int8"))
    {
        auto searcher = std::unique_ptr<diskann::BaseSearch>(
            new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
        g_ssdSearch.push_back(std::move(searcher));
    }
    else if (data_type == std::string("uint8"))
    {
        auto searcher = std::unique_ptr<diskann::BaseSearch>(
            new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
        g_ssdSearch.push_back(std::move(searcher));
    }
    else
    {
        // Report the parsed option value rather than a positional argv slot, which may not hold the data type.
        std::cerr << "Unsupported data type " << data_type << std::endl;
        exit(-1);
    }

    while (1)
    {
        try
        {
            setup(address, data_type);
            std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
            std::string line;
            std::getline(std::cin, line);
            if (line == "exit")
            {
                // teardown() already closes the server; a second close().wait() would be redundant.
                teardown(address);
                exit(0);
            }
        }
        catch (const std::exception &ex)
        {
            std::cerr << "Exception occurred: " << ex.what() << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
        catch (...)
        {
            std::cerr << "Unknown exception occurred" << std::endl;
            std::cerr << "Restarting HTTP server" << std::endl;
            teardown(address);
        }
    }
}
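A launch sketch for the single-index SSD server, assuming a binary named ssd_server (address and paths illustrative; flags as declared above):

// ssd_server --data_type float --address http://127.0.0.1:8080 \
//            --index_path_prefix disk_idx --num_nodes_to_cache 100000 --num_threads 16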
499
packages/leann-backend-diskann/third_party/DiskANN/apps/search_disk_index.cpp
vendored
Normal file
@@ -0,0 +1,499 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "common_includes.h"
#include <boost/program_options.hpp>

#include "index.h"
#include "disk_utils.h"
#include "math_utils.h"
#include "memory_mapper.h"
#include "partition.h"
#include "pq_flash_index.h"
#include "timer.h"
#include "percentile_stats.h"
#include "program_options_utils.hpp"

#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include "linux_aligned_file_reader.h"
#else
#ifdef USE_BING_INFRA
#include "bing_aligned_file_reader.h"
#else
#include "windows_aligned_file_reader.h"
#endif
#endif

#define WARMUP false

namespace po = boost::program_options;

void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
{
    diskann::cout << std::setw(20) << category << ": " << std::flush;
    for (uint32_t s = 0; s < percentiles.size(); s++)
    {
        diskann::cout << std::setw(8) << percentiles[s] << "%";
    }
    diskann::cout << std::endl;
    diskann::cout << std::setw(22) << " " << std::flush;
    for (uint32_t s = 0; s < percentiles.size(); s++)
    {
        diskann::cout << std::setw(9) << results[s];
    }
    diskann::cout << std::endl;
}

template <typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
                      const std::string &result_output_prefix, const std::string &query_file, std::string &gt_file,
                      const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
                      const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
                      const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
                      const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
{
    diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
    if (beamwidth <= 0)
        diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
    else
        diskann::cout << " beamwidth: " << beamwidth << std::flush;
    if (search_io_limit == std::numeric_limits<uint32_t>::max())
        diskann::cout << "." << std::endl;
    else
        diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;

    std::string warmup_query_file = index_path_prefix + "_sample_data.bin";

    // load query bin
    T *query = nullptr;
    uint32_t *gt_ids = nullptr;
    float *gt_dists = nullptr;
    size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
    diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);

    bool filtered_search = false;
    if (!query_filters.empty())
    {
        filtered_search = true;
        if (query_filters.size() != 1 && query_filters.size() != query_num)
        {
            std::cout << "Error. Mismatch in number of queries and size of query "
                         "filters file"
                      << std::endl;
            return -1; // To return -1 or some other error handling?
        }
    }

    bool calc_recall_flag = false;
    if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
    {
        diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
        if (gt_num != query_num)
        {
            diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
        }
        calc_recall_flag = true;
    }

    std::shared_ptr<AlignedFileReader> reader = nullptr;
#ifdef _WINDOWS
#ifndef USE_BING_INFRA
    reader.reset(new WindowsAlignedFileReader());
#else
    reader.reset(new diskann::BingAlignedFileReader());
#endif
#else
    reader.reset(new LinuxAlignedFileReader());
#endif

    std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
        new diskann::PQFlashIndex<T, LabelT>(reader, metric));

    int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());

    if (res != 0)
    {
        return res;
    }

    std::vector<uint32_t> node_list;
    diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl;
    _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
    // if (num_nodes_to_cache > 0)
    //     _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
    //     num_threads, node_list);
    _pFlashIndex->load_cache_list(node_list);
    node_list.clear();
    node_list.shrink_to_fit();

    omp_set_num_threads(num_threads);

    uint64_t warmup_L = 20;
    uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
    T *warmup = nullptr;

    if (WARMUP)
    {
        if (file_exists(warmup_query_file))
        {
            diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
        }
        else
        {
            warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
            warmup_dim = query_dim;
            warmup_aligned_dim = query_aligned_dim;
            diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
            std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
            std::random_device rd;
            std::mt19937 gen(rd());
            std::uniform_int_distribution<> dis(-128, 127);
            for (uint32_t i = 0; i < warmup_num; i++)
            {
                for (uint32_t d = 0; d < warmup_dim; d++)
                {
                    warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
                }
            }
        }
        diskann::cout << "Warming up index... " << std::flush;
        std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
        std::vector<float> warmup_result_dists(warmup_num, 0);

#pragma omp parallel for schedule(dynamic, 1)
        for (int64_t i = 0; i < (int64_t)warmup_num; i++)
        {
            _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
                                             warmup_result_ids_64.data() + (i * 1),
                                             warmup_result_dists.data() + (i * 1), 4);
        }
        diskann::cout << "..done" << std::endl;
    }

    diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
    diskann::cout.precision(2);

    std::string recall_string = "Recall@" + std::to_string(recall_at);
    diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
                  << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
                  << "Mean IO (us)" << std::setw(16) << "CPU (s)";
    if (calc_recall_flag)
    {
        diskann::cout << std::setw(16) << recall_string << std::endl;
    }
    else
        diskann::cout << std::endl;
    diskann::cout << "=================================================================="
                     "================================================================="
                  << std::endl;

    std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
    std::vector<std::vector<float>> query_result_dists(Lvec.size());

    uint32_t optimized_beamwidth = 2;

    double best_recall = 0.0;

    for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
    {
        uint32_t L = Lvec[test_id];

        if (L < recall_at)
        {
            diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
            continue;
        }

        if (beamwidth <= 0)
        {
            diskann::cout << "Tuning beamwidth.." << std::endl;
            optimized_beamwidth =
                optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
        }
        else
            optimized_beamwidth = beamwidth;

        query_result_ids[test_id].resize(recall_at * query_num);
        query_result_dists[test_id].resize(recall_at * query_num);

        auto stats = new diskann::QueryStats[query_num];

        std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
        auto s = std::chrono::high_resolution_clock::now();

#pragma omp parallel for schedule(dynamic, 1)
        for (int64_t i = 0; i < (int64_t)query_num; i++)
        {
            if (!filtered_search)
            {
                _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
                                                 query_result_ids_64.data() + (i * recall_at),
                                                 query_result_dists[test_id].data() + (i * recall_at),
                                                 optimized_beamwidth, use_reorder_data, stats + i);
            }
            else
            {
                LabelT label_for_search;
                if (query_filters.size() == 1)
                { // one label for all queries
                    label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
                }
                else
                { // one label for each query
                    label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
                }
                _pFlashIndex->cached_beam_search(
                    query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
                    query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
                    use_reorder_data, stats + i);
            }
        }
        auto e = std::chrono::high_resolution_clock::now();
        std::chrono::duration<double> diff = e - s;
        double qps = (1.0 * query_num) / (1.0 * diff.count());

        diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
                                                   query_num, recall_at);

        auto mean_latency = diskann::get_mean_stats<float>(
            stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });

        auto latency_999 = diskann::get_percentile_stats<float>(
            stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });

        auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
                                                          [](const diskann::QueryStats &stats) { return stats.n_ios; });

        auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
                                                         [](const diskann::QueryStats &stats) { return stats.cpu_us; });

        auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
                                                         [](const diskann::QueryStats &stats) { return stats.io_us; });

        double recall = 0;
        if (calc_recall_flag)
        {
            recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
                                               query_result_ids[test_id].data(), recall_at, recall_at);
            best_recall = std::max(recall, best_recall);
        }

        diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
                      << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
                      << std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
        if (calc_recall_flag)
        {
            diskann::cout << std::setw(16) << recall << std::endl;
        }
        else
            diskann::cout << std::endl;
        delete[] stats;
    }

    diskann::cout << "Done searching. Now saving results " << std::endl;
    uint64_t test_id = 0;
    for (auto L : Lvec)
    {
        if (L < recall_at)
            continue;

        std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
        diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);

        cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
        diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
    }

    diskann::aligned_free(query);
    if (warmup != nullptr)
        diskann::aligned_free(warmup);
    return best_recall >= fail_if_recall_below ? 0 : -1;
}

int main(int argc, char **argv)
{
    std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
        label_type, query_filters_file;
    uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
    std::vector<uint32_t> Lvec;
    bool use_reorder_data = false;
    float fail_if_recall_below = 0.0f;

    po::options_description desc{
        program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
                                       program_options_utils::RESULT_PATH_DESCRIPTION);
        required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
                                       program_options_utils::QUERY_FILE_DESCRIPTION);
        required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
                                       program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
        required_configs.add_options()("search_list,L",
                                       po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
                                       program_options_utils::SEARCH_LIST_DESCRIPTION);

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("gt_file", po::value<std::string>(&gt_file)->default_value(std::string("null")),
                                       program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
        optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
                                       program_options_utils::BEAMWIDTH);
        optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
                                       program_options_utils::NUMBER_OF_NODES_TO_CACHE);
        optional_configs.add_options()(
            "search_io_limit",
            po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
            "Max #IOs for search. Default value: uint32::max()");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
                                       "Include full precision data in the index. Use only in "
                                       "conjunction with compressed data on SSD. Default value: false");
        optional_configs.add_options()("filter_label",
                                       po::value<std::string>(&filter_label)->default_value(std::string("")),
                                       program_options_utils::FILTER_LABEL_DESCRIPTION);
        optional_configs.add_options()("query_filters_file",
                                       po::value<std::string>(&query_filters_file)->default_value(std::string("")),
                                       program_options_utils::FILTERS_FILE_DESCRIPTION);
        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
                                       program_options_utils::LABEL_TYPE_DESCRIPTION);
        optional_configs.add_options()("fail_if_recall_below",
                                       po::value<float>(&fail_if_recall_below)->default_value(0.0f),
                                       program_options_utils::FAIL_IF_RECALL_BELOW);

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
        if (vm["use_reorder_data"].as<bool>())
            use_reorder_data = true;
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    diskann::Metric metric;
    if (dist_fn == std::string("mips"))
    {
        metric = diskann::Metric::INNER_PRODUCT;
    }
    else if (dist_fn == std::string("l2"))
    {
        metric = diskann::Metric::L2;
    }
    else if (dist_fn == std::string("cosine"))
    {
        metric = diskann::Metric::COSINE;
    }
    else
    {
        std::cout << "Unsupported distance function. Currently only L2/ Inner "
                     "Product/Cosine are supported."
                  << std::endl;
        return -1;
    }

    if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
    {
        std::cout << "Currently support only floating point data for Inner Product." << std::endl;
        return -1;
    }

    if (use_reorder_data && data_type != std::string("float"))
    {
        std::cout << "Error: Reorder data for reordering currently only "
                     "supported for float data type."
                  << std::endl;
        return -1;
    }

    if (filter_label != "" && query_filters_file != "")
    {
        std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
        return -1;
    }

    std::vector<std::string> query_filters;
    if (filter_label != "")
    {
        query_filters.push_back(filter_label);
    }
    else if (query_filters_file != "")
    {
        query_filters = read_file_to_vector_of_strings(query_filters_file);
    }

    try
    {
        if (!query_filters.empty() && label_type == "ushort")
        {
            if (data_type == std::string("float"))
                return search_disk_index<float, uint16_t>(
                    metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
                    num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
            else if (data_type == std::string("int8"))
                return search_disk_index<int8_t, uint16_t>(
                    metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
                    num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
            else if (data_type == std::string("uint8"))
                return search_disk_index<uint8_t, uint16_t>(
                    metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
                    num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
            else
            {
                std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
                return -1;
            }
        }
        else
        {
            if (data_type == std::string("float"))
                return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
                                                num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
                                                fail_if_recall_below, query_filters, use_reorder_data);
            else if (data_type == std::string("int8"))
                return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
                                                 num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
                                                 fail_if_recall_below, query_filters, use_reorder_data);
            else if (data_type == std::string("uint8"))
                return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
                                                  num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
                                                  fail_if_recall_below, query_filters, use_reorder_data);
            else
            {
                std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
                return -1;
            }
        }
    }
    catch (const std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index search failed." << std::endl;
        return -1;
    }
}
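An invocation sketch for the benchmark above, assuming a built search_disk_index binary; note that --search_list accepts multiple L values and each must be at least recall_at (paths illustrative):

// search_disk_index --data_type float --dist_fn l2 --index_path_prefix disk_idx \
//     --query_file queries.bin --gt_file gt.bin --recall_at 10 --search_list 20 40 80 \
//     --result_path results/res --num_nodes_to_cache 100000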
477
packages/leann-backend-diskann/third_party/DiskANN/apps/search_memory_index.cpp
vendored
Normal file
@@ -0,0 +1,477 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <string.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
#include "index_factory.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
|
||||
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
|
||||
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
|
||||
const bool dynamic, const bool tags, const bool show_qps_per_thread,
|
||||
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
|
||||
{
|
||||
using TagT = uint32_t;
|
||||
// Load the query file
|
||||
T *query = nullptr;
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (truthset_file != std::string("null") && file_exists(truthset_file))
|
||||
{
|
||||
diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl;
|
||||
}
|
||||
|
||||
bool filtered_search = false;
|
||||
if (!query_filters.empty())
|
||||
{
|
||||
filtered_search = true;
|
||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
||||
"filters file"
|
||||
<< std::endl;
|
||||
return -1; // To return -1 or some other error handling?
|
||||
}
|
||||
}
|
||||
|
||||
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
|
||||
|
||||
auto config = diskann::IndexConfigBuilder()
|
||||
.with_metric(metric)
|
||||
.with_dimension(query_dim)
|
||||
.with_max_points(0)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.is_dynamic_index(dynamic)
|
||||
.is_enable_tags(tags)
|
||||
.is_concurrent_consolidate(false)
|
||||
.is_pq_dist_build(false)
|
||||
.is_use_opq(false)
|
||||
.with_num_pq_chunks(0)
|
||||
.with_num_frozen_pts(num_frozen_pts)
|
||||
.build();
|
||||
|
||||
auto index_factory = diskann::IndexFactory(config);
|
||||
auto index = index_factory.create_instance();
|
||||
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
|
||||
if (metric == diskann::FAST_L2)
|
||||
index->optimize_index_layout();
|
||||
|
||||
std::cout << "Using " << num_threads << " threads to search" << std::endl;
|
||||
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
std::cout.precision(2);
|
||||
const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS";
|
||||
uint32_t table_width = 0;
|
||||
if (tags)
|
||||
{
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)"
|
||||
<< std::setw(15) << "99.9 Latency";
|
||||
table_width += 4 + 12 + 20 + 15;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps"
|
||||
<< std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency";
|
||||
table_width += 4 + 12 + 18 + 20 + 15;
|
||||
}
|
||||
uint32_t recalls_to_print = 0;
|
||||
const uint32_t first_recall = print_all_recalls ? 1 : recall_at;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
||||
{
|
||||
std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall));
|
||||
}
|
||||
recalls_to_print = recall_at + 1 - first_recall;
|
||||
table_width += recalls_to_print * 12;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << std::string(table_width, '=') << std::endl;
|
||||
|
||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
||||
std::vector<float> latency_stats(query_num, 0);
|
||||
std::vector<uint32_t> cmp_stats;
|
||||
if (not tags || filtered_search)
|
||||
{
|
||||
cmp_stats = std::vector<uint32_t>(query_num, 0);
|
||||
}
|
||||
|
||||
std::vector<TagT> query_result_tags;
|
||||
if (tags)
|
||||
{
|
||||
query_result_tags.resize(recall_at * query_num);
|
||||
}
|
||||
|
||||
double best_recall = 0.0;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
query_result_dists[test_id].resize(recall_at * query_num);
|
||||
std::vector<T *> res = std::vector<T *>();
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
omp_set_num_threads(num_threads);
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
auto qs = std::chrono::high_resolution_clock::now();
|
||||
if (filtered_search && !tags)
|
||||
{
|
||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
||||
|
||||
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at,
|
||||
query_result_dists[test_id].data() + i * recall_at);
|
||||
cmp_stats[i] = retval.second;
|
||||
}
|
||||
else if (metric == diskann::FAST_L2)
|
||||
{
|
||||
index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at);
|
||||
}
|
||||
else if (tags)
|
||||
{
|
||||
if (!filtered_search)
|
||||
{
|
||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_tags.data() + i * recall_at, nullptr, res);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
||||
|
||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
|
||||
}
|
||||
|
||||
for (int64_t r = 0; r < (int64_t)recall_at; r++)
|
||||
{
|
||||
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cmp_stats[i] = index
|
||||
->search(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at)
|
||||
.second;
|
||||
}
|
||||
auto qe = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = qe - qs;
|
||||
latency_stats[i] = (float)(diff.count() * 1000000);
|
||||
}
|
||||
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
|
||||
|
||||
double displayed_qps = query_num / diff.count();
|
||||
|
||||
if (show_qps_per_thread)
|
||||
displayed_qps /= num_threads;
|
||||
|
||||
std::vector<double> recalls;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recalls.reserve(recalls_to_print);
|
||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
||||
{
|
||||
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
||||
query_result_ids[test_id].data(), recall_at, curr_recall));
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(latency_stats.begin(), latency_stats.end());
|
||||
double mean_latency =
|
||||
std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast<float>(query_num);
|
||||
|
||||
float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num;
|
||||
|
||||
if (tags && !filtered_search)
|
||||
{
|
||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency
|
||||
<< std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps
|
||||
<< std::setw(20) << (float)mean_latency << std::setw(15)
|
||||
<< (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
||||
}
|
||||
for (double recall : recalls)
|
||||
{
|
||||
std::cout << std::setw(12) << recall;
|
||||
best_recall = std::max(recall, best_recall);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Done searching. Now saving results " << std::endl;
|
||||
uint64_t test_id = 0;
|
||||
for (auto L : Lvec)
|
||||
{
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L);
|
||||
|
||||
std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin";
|
||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
||||
|
||||
cur_result_path = cur_result_path_prefix + "_dists_float.bin";
|
||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at);
|
||||
|
||||
test_id++;
|
||||
}
|
||||
|
||||
diskann::aligned_free(query);
|
||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
|
||||
query_filters_file;
|
||||
uint32_t num_threads, K;
|
||||
std::vector<uint32_t> Lvec;
|
||||
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
|
||||
float fail_if_recall_below = 0.0f;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print this information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path)->required(),
|
||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
||||
optional_configs.add_options()("query_filters_file",
|
||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()(
            "dynamic", po::value<bool>(&dynamic)->default_value(false),
            "Whether the index is dynamic. Dynamic indices must have associated tags. Default false.");
        optional_configs.add_options()("tags", po::value<bool>(&tags)->default_value(false),
                                       "Whether to search with external identifiers (tags). Default false.");
        optional_configs.add_options()("fail_if_recall_below",
                                       po::value<float>(&fail_if_recall_below)->default_value(0.0f),
                                       program_options_utils::FAIL_IF_RECALL_BELOW);

        // Output controls
        po::options_description output_controls("Output controls");
        output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls),
                                      "Print recalls at all positions, from 1 up to specified "
                                      "recall_at value");
        output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread),
                                      "Print overall QPS divided by the number of threads in "
                                      "the output table");

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs).add(output_controls);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    diskann::Metric metric;
    if ((dist_fn == std::string("mips")) && (data_type == std::string("float")))
    {
        metric = diskann::Metric::INNER_PRODUCT;
    }
    else if (dist_fn == std::string("l2"))
    {
        metric = diskann::Metric::L2;
    }
    else if (dist_fn == std::string("cosine"))
    {
        metric = diskann::Metric::COSINE;
    }
    else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float")))
    {
        metric = diskann::Metric::FAST_L2;
    }
    else
    {
std::cout << "Unsupported distance function. Currently only l2/ cosine are "
|
||||
"supported in general, and mips/fast_l2 only for floating "
|
||||
"point data."
|
||||
<< std::endl;
        return -1;
    }

    if (dynamic && not tags)
    {
        std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl;
        return -1;
    }

    if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0)
    {
        std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl;
        return -1;
    }

    if (filter_label != "" && query_filters_file != "")
    {
        std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
        return -1;
    }

    std::vector<std::string> query_filters;
    if (filter_label != "")
    {
        query_filters.push_back(filter_label);
    }
    else if (query_filters_file != "")
    {
        query_filters = read_file_to_vector_of_strings(query_filters_file);
    }

    try
    {
        if (!query_filters.empty() && label_type == "ushort")
        {
            if (data_type == std::string("int8"))
            {
                return search_memory_index<int8_t, uint16_t>(
                    metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
                    Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
            }
            else if (data_type == std::string("uint8"))
            {
                return search_memory_index<uint8_t, uint16_t>(
                    metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
                    Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
            }
            else if (data_type == std::string("float"))
            {
                return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
                                                            num_threads, K, print_all_recalls, Lvec, dynamic, tags,
                                                            show_qps_per_thread, query_filters, fail_if_recall_below);
            }
            else
            {
                std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
                return -1;
            }
        }
        else
        {
            if (data_type == std::string("int8"))
            {
                return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
                                                   num_threads, K, print_all_recalls, Lvec, dynamic, tags,
                                                   show_qps_per_thread, query_filters, fail_if_recall_below);
            }
            else if (data_type == std::string("uint8"))
            {
                return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
                                                    num_threads, K, print_all_recalls, Lvec, dynamic, tags,
                                                    show_qps_per_thread, query_filters, fail_if_recall_below);
            }
            else if (data_type == std::string("float"))
            {
                return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
                                                  num_threads, K, print_all_recalls, Lvec, dynamic, tags,
                                                  show_qps_per_thread, query_filters, fail_if_recall_below);
            }
            else
            {
                std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
                return -1;
            }
        }
    }
    catch (std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index search failed." << std::endl;
        return -1;
    }
}
536
packages/leann-backend-diskann/third_party/DiskANN/apps/test_insert_deletes_consolidate.cpp
vendored
Normal file
@@ -0,0 +1,536 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <index.h>
#include <numeric>
#include <omp.h>
#include <string.h>
#include <time.h>
#include <timer.h>
#include <boost/program_options.hpp>
#include <future>

#include "utils.h"
#include "filter_utils.h"
#include "program_options_utils.hpp"
#include "index_factory.h"

#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif

#include "memory_mapper.h"

namespace po = boost::program_options;

// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
    diskann::Timer timer;
    std::ifstream reader;
    reader.exceptions(std::ios::failbit | std::ios::badbit);
    reader.open(bin_file, std::ios::binary | std::ios::ate);
    size_t actual_file_size = reader.tellg();
    reader.seekg(0, std::ios::beg);

    int npts_i32, dim_i32;
    reader.read((char *)&npts_i32, sizeof(int));
    reader.read((char *)&dim_i32, sizeof(int));
    size_t npts = (uint32_t)npts_i32;
    size_t dim = (uint32_t)dim_i32;

    size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
    if (actual_file_size != expected_actual_file_size)
    {
        std::stringstream stream;
        stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
               << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
               << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    if (offset_points + points_to_read > npts)
    {
        std::stringstream stream;
        stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
               << " points, but have only " << npts << " points" << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));

    const size_t rounded_dim = ROUND_UP(dim, 8);

    for (size_t i = 0; i < points_to_read; i++)
    {
        reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
        memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
    }
    reader.close();

    const double elapsedSeconds = timer.elapsed() / 1000000.0;
    std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl;
}
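
// Layout notes for the .bin files consumed above: a 4-byte point count and a
// 4-byte dimension, then npts * dim values of type T stored row-major. Each
// point is copied into a slot of rounded_dim = ROUND_UP(dim, 8) values with
// the tail zero-padded, so distance kernels can assume a fixed, aligned stride.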

std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted,
                              size_t last_point_threshold)
{
    std::string final_path = save_path;
    if (points_to_skip > 0)
    {
        final_path += "skip" + std::to_string(points_to_skip) + "-";
    }

    final_path += "del" + std::to_string(points_deleted) + "-";
    final_path += std::to_string(last_point_threshold);
    return final_path;
}

template <typename T, typename TagT, typename LabelT>
void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data,
                                 size_t aligned_dim, std::vector<std::vector<LabelT>> &location_to_labels)
{
    diskann::Timer insert_timer;
#pragma omp parallel for num_threads(thread_count) schedule(dynamic)
    for (int64_t j = start; j < (int64_t)end; j++)
    {
        if (!location_to_labels.empty())
        {
            index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
                               location_to_labels[j - start]);
        }
        else
        {
            index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
        }
    }
    const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
    std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
              << " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n";
}
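
// Tag convention used throughout this app: the point at file location j gets
// tag j + 1, keeping tags non-zero; delete_from_beginning below relies on the
// same offset when issuing lazy deletes.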

template <typename T, typename TagT>
void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params,
                           size_t points_to_skip, size_t points_to_delete_from_beginning)
{
    try
    {
        std::cout << std::endl
                  << "Lazy deleting points " << points_to_skip << " to "
                  << points_to_skip + points_to_delete_from_beginning << "... ";
        for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i)
            index.lazy_delete(static_cast<TagT>(i + 1)); // Since tags are data location + 1
        std::cout << "done." << std::endl;

        auto report = index.consolidate_deletes(delete_params);
        std::cout << "#active points: " << report._active_points << std::endl
                  << "max points: " << report._max_points << std::endl
                  << "empty slots: " << report._empty_slots << std::endl
                  << "deletes processed: " << report._slots_released << std::endl
                  << "latest delete size: " << report._delete_set_size << std::endl
                  << "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, "
                  << points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)"
                  << std::endl;
    }
    catch (std::system_error &e)
    {
        std::cout << "Exception caught in deletion thread: " << e.what() << std::endl;
    }
}

template <typename T>
void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters &params, size_t points_to_skip,
                             size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
                             uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
                             const std::string &save_path, size_t points_to_delete_from_beginning,
                             size_t start_deletes_after, bool concurrent, const std::string &label_file,
                             const std::string &universal_label)
{
    size_t dim, aligned_dim;
    size_t num_points;
    diskann::get_bin_metadata(data_path, num_points, dim);
    aligned_dim = ROUND_UP(dim, 8);
    bool has_labels = label_file != "";
    using TagT = uint32_t;
    using LabelT = uint32_t;

    size_t current_point_offset = points_to_skip;
    const size_t last_point_threshold = points_to_skip + max_points_to_insert;

    bool enable_tags = true;
    auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
    diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
                                            .with_metric(diskann::L2)
                                            .with_dimension(dim)
                                            .with_max_points(max_points_to_insert)
                                            .is_dynamic_index(true)
                                            .with_index_write_params(params)
                                            .with_index_search_params(index_search_params)
                                            .with_data_type(diskann_type_to_name<T>())
                                            .with_tag_type(diskann_type_to_name<TagT>())
                                            .with_label_type(diskann_type_to_name<LabelT>())
                                            .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
                                            .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
                                            .is_enable_tags(enable_tags)
                                            .is_filtered(has_labels)
                                            .with_num_frozen_pts(num_start_pts)
                                            .is_concurrent_consolidate(concurrent)
                                            .build();
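
    // num_start_pts frozen points serve as search entry points and are never
    // deleted; main() raises it to unique_labels_supported so every label can
    // keep at least one entry point in a filtered dynamic index.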

    diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
    auto index = index_factory.create_instance();

    if (universal_label != "")
    {
        LabelT u_label = 0;
        index->set_universal_label(u_label);
    }

    if (points_to_skip > num_points)
    {
        throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    if (max_points_to_insert == 0)
    {
        max_points_to_insert = num_points;
    }

    if (points_to_skip + max_points_to_insert > num_points)
    {
        max_points_to_insert = num_points - points_to_skip;
        std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
                  << " points since the data file has only that many" << std::endl;
    }

    if (beginning_index_size > max_points_to_insert)
    {
        beginning_index_size = max_points_to_insert;
        std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size
                  << " points since the data file has only that many" << std::endl;
    }
    if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint)
    {
        beginning_index_size = points_per_checkpoint;
        std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl;
    }

    T *data = nullptr;
    diskann::alloc_aligned(
        (void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T));

    std::vector<TagT> tags(beginning_index_size);
    std::iota(tags.begin(), tags.end(), 1 + static_cast<TagT>(current_point_offset));

    load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size);
    std::cout << "load aligned bin succeeded" << std::endl;
    diskann::Timer timer;

    if (beginning_index_size > 0)
    {
        index->build(data, beginning_index_size, tags);
    }
    else
    {
        index->set_start_points_at_random(static_cast<T>(start_point_norm));
    }

    const double elapsedSeconds = timer.elapsed() / 1000000.0;
    std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took "
              << elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n";

    current_point_offset += beginning_index_size;

    if (points_to_delete_from_beginning > max_points_to_insert)
    {
        points_to_delete_from_beginning = static_cast<uint32_t>(max_points_to_insert);
        std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning
                  << " points since the data file has only that many" << std::endl;
    }

    std::vector<std::vector<LabelT>> location_to_labels;
    if (concurrent)
    {
        // handle labels
        const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip,
                                                     points_to_delete_from_beginning, last_point_threshold);
        std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
        std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
        if (has_labels)
        {
            convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
            auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
            location_to_labels = std::get<0>(parse_result);
        }

        int32_t sub_threads = (params.num_threads + 1) / 2;
        bool delete_launched = false;
        std::future<void> delete_task;

        diskann::Timer timer;

        for (size_t start = current_point_offset; start < last_point_threshold;
             start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
        {
            const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
            std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;

            auto insert_task = std::async(std::launch::async, [&]() {
                load_aligned_bin_part(data_path, data, start, end - start);
                insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, sub_threads, data, aligned_dim,
                                                             location_to_labels);
            });
            insert_task.wait();

            if (!delete_launched && end >= start_deletes_after &&
                end >= points_to_skip + points_to_delete_from_beginning)
            {
                delete_launched = true;
                diskann::IndexWriteParameters delete_params =
                    diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build();

                delete_task = std::async(std::launch::async, [&]() {
                    delete_from_beginning<T, TagT>(*index, delete_params, points_to_skip,
                                                   points_to_delete_from_beginning);
                });
            }
        }
        // The delete task is only launched once the insert frontier passes the
        // delete threshold, so guard against waiting on an empty future.
        if (delete_task.valid())
            delete_task.wait();

        std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
        index->save(save_path_inc.c_str(), true);
    }
    else
    {
        const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip,
                                                     points_to_delete_from_beginning, last_point_threshold);
        std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
        std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
        if (has_labels)
        {
            convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
            auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
            location_to_labels = std::get<0>(parse_result);
        }

        size_t last_snapshot_points_threshold = 0;
        size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot;

        for (size_t start = current_point_offset; start < last_point_threshold;
             start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
        {
            const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
            std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;

            load_aligned_bin_part(data_path, data, start, end - start);
            insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, (int32_t)params.num_threads, data,
                                                         aligned_dim, location_to_labels);

            if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0)
            {
                diskann::Timer save_timer;

                const auto save_path_inc =
                    get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end);
                index->save(save_path_inc.c_str(), false);
                const double elapsedSeconds = save_timer.elapsed() / 1000000.0;
                const size_t points_saved = end - points_to_skip;

                std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds ("
                          << points_saved / elapsedSeconds << " points/second)\n";

                num_checkpoints_till_snapshot = checkpoints_per_snapshot;
                last_snapshot_points_threshold = end;
            }

            std::cout << "Number of points in the index post insertion " << end << std::endl;
        }

        if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold)
        {
            const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip,
                                                         points_to_delete_from_beginning, last_point_threshold);
            // index.save(save_path_inc.c_str(), false);
        }

        if (points_to_delete_from_beginning > 0)
        {
            delete_from_beginning<T, TagT>(*index, params, points_to_skip, points_to_delete_from_beginning);
        }

        index->save(save_path_inc.c_str(), true);
    }

    diskann::aligned_free(data);
}

int main(int argc, char **argv)
{
    std::string data_type, dist_fn, data_path, index_path_prefix;
    uint32_t num_threads, R, L, num_start_pts;
    float alpha, start_point_norm;
    size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
        points_to_delete_from_beginning, start_deletes_after;
    bool concurrent;

    // label options
    std::string label_file, label_type, universal_label;
    std::uint32_t Lf, unique_labels_supported;

    po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
                                                                                 "Test insert deletes & consolidate")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
                                       program_options_utils::INPUT_DATA_PATH);
        required_configs.add_options()("points_to_skip", po::value<uint64_t>(&points_to_skip)->required(),
                                       "Skip these first set of points from file");
        required_configs.add_options()("beginning_index_size", po::value<uint64_t>(&beginning_index_size)->required(),
                                       "Batch build will be called on these set of points");
        required_configs.add_options()("points_per_checkpoint", po::value<uint64_t>(&points_per_checkpoint)->required(),
                                       "Insertions are done in batches of points_per_checkpoint");
        required_configs.add_options()("checkpoints_per_snapshot",
                                       po::value<uint64_t>(&checkpoints_per_snapshot)->required(),
                                       "Save the index to disk every few checkpoints");
        required_configs.add_options()("points_to_delete_from_beginning",
                                       po::value<uint64_t>(&points_to_delete_from_beginning)->required(), "");

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("num_threads,T",
                                       po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
                                       program_options_utils::NUMBER_THREADS_DESCRIPTION);
        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
                                       program_options_utils::MAX_BUILD_DEGREE);
        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
                                       program_options_utils::GRAPH_BUILD_ALPHA);
        optional_configs.add_options()("max_points_to_insert",
                                       po::value<uint64_t>(&max_points_to_insert)->default_value(0),
                                       "These number of points from the file are inserted after "
                                       "points_to_skip");
        optional_configs.add_options()("do_concurrent", po::value<bool>(&concurrent)->default_value(false), "");
        optional_configs.add_options()("start_deletes_after",
                                       po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
        optional_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
                                       "Set the start point to a random point on a sphere of this radius");

        // optional params for filters
        optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
                                       "Input label file in txt format for Filtered Index search. "
                                       "The file should contain comma separated filters for each node "
                                       "with each line corresponding to a graph node");
        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
                                       "Universal label, if using it, only in conjunction with labels_file");
        optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
                                       "Build complexity for filtered points, higher value "
                                       "results in better graphs");
        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
                                       "Storage type of Labels <uint/ushort>, default value is uint which "
                                       "will consume memory 4 bytes per filter");
        optional_configs.add_options()("unique_labels_supported",
                                       po::value<uint32_t>(&unique_labels_supported)->default_value(0),
                                       "Number of unique labels supported by the dynamic index.");

        optional_configs.add_options()(
            "num_start_points",
            po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
            "Set the number of random start (frozen) points to use when "
            "inserting and searching");

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
        if (beginning_index_size == 0 && start_point_norm == 0)
        {
            std::cout << "When beginning_index_size is 0, use a start point with "
                         "appropriate norm"
                      << std::endl;
            return -1;
        }
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    bool has_labels = false;
    if (!label_file.empty())
    {
        has_labels = true;
    }

    if (num_start_pts < unique_labels_supported)
    {
        num_start_pts = unique_labels_supported;
    }

    try
    {
        diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
                                                   .with_max_occlusion_size(500)
                                                   .with_alpha(alpha)
                                                   .with_num_threads(num_threads)
                                                   .with_filter_list_size(Lf)
                                                   .build();

        if (data_type == std::string("int8"))
            build_incremental_index<int8_t>(
                data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
                num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
                points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
        else if (data_type == std::string("uint8"))
            build_incremental_index<uint8_t>(
                data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
                num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
                points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
        else if (data_type == std::string("float"))
            build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
                                           beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
                                           checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
                                           start_deletes_after, concurrent, label_file, universal_label);
        else
            std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
    }
    catch (const std::exception &e)
    {
        std::cerr << "Caught exception: " << e.what() << std::endl;
        exit(-1);
    }
    catch (...)
    {
        std::cerr << "Caught unknown exception" << std::endl;
        exit(-1);
    }

    return 0;
}
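
Taken together, the app above exercises the dynamic-index API in a fixed order: batch build, tagged inserts, lazy deletes, consolidation, save. A minimal sketch of that lifecycle, using only calls that appear in the file above; the parameter values and save path are illustrative, and the data pointer is assumed to be allocated and zero-padded to the aligned stride as in load_aligned_bin_part:

#include <index_factory.h>
#include <numeric>
#include <vector>

// Minimal sketch (not part of DiskANN): build -> insert-capable dynamic index
// -> lazy delete -> consolidate -> save.
void dynamic_index_lifecycle(float *aligned_data, size_t num_points, size_t dim)
{
    auto params = diskann::IndexWriteParametersBuilder(/*L=*/100, /*R=*/64).build();
    auto search_params = diskann::IndexSearchParams(/*L=*/100, /*num_threads=*/1);
    auto config = diskann::IndexConfigBuilder()
                      .with_metric(diskann::L2)
                      .with_dimension(dim)
                      .with_max_points(num_points)
                      .is_dynamic_index(true)
                      .is_enable_tags(true)
                      .with_index_write_params(params)
                      .with_index_search_params(search_params)
                      .with_data_type(diskann_type_to_name<float>())
                      .with_tag_type(diskann_type_to_name<uint32_t>())
                      .with_label_type(diskann_type_to_name<uint32_t>())
                      .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
                      .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
                      .build();
    auto index = diskann::IndexFactory(config).create_instance();

    std::vector<uint32_t> tags(num_points);
    std::iota(tags.begin(), tags.end(), 1u); // tags are location + 1, as above

    index->build(aligned_data, num_points, tags); // initial batch build
    index->lazy_delete(1u);                       // mark the first point for removal
    index->consolidate_deletes(params);           // reclaim its slot
    index->save("lifecycle_sketch", true);        // illustrative path
}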
523
packages/leann-backend-diskann/third_party/DiskANN/apps/test_streaming_scenario.cpp
vendored
Normal file
@@ -0,0 +1,523 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <index.h>
#include <numeric>
#include <omp.h>
#include <string.h>
#include <time.h>
#include <timer.h>
#include <boost/program_options.hpp>
#include <future>
#include <abstract_index.h>
#include <index_factory.h>

#include "utils.h"
#include "filter_utils.h"
#include "program_options_utils.hpp"

#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif

#include "memory_mapper.h"

namespace po = boost::program_options;

// load_aligned_bin modified to read pieces of the file, but using ifstream
// instead of cached_ifstream.
template <typename T>
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
{
    std::ifstream reader;
    reader.exceptions(std::ios::failbit | std::ios::badbit);
    reader.open(bin_file, std::ios::binary | std::ios::ate);
    size_t actual_file_size = reader.tellg();
    reader.seekg(0, std::ios::beg);

    int npts_i32, dim_i32;
    reader.read((char *)&npts_i32, sizeof(int));
    reader.read((char *)&dim_i32, sizeof(int));
    size_t npts = (uint32_t)npts_i32;
    size_t dim = (uint32_t)dim_i32;

    size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
    if (actual_file_size != expected_actual_file_size)
    {
        std::stringstream stream;
        stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
               << expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
               << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    if (offset_points + points_to_read > npts)
    {
        std::stringstream stream;
        stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
               << " points, but have only " << npts << " points" << std::endl;
        std::cout << stream.str();
        throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
    }

    reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));

    const size_t rounded_dim = ROUND_UP(dim, 8);

    for (size_t i = 0; i < points_to_read; i++)
    {
        reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
        memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
    }
    reader.close();
}

std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval,
                              size_t max_points_to_insert)
{
    std::string final_path = save_path;
    final_path += "act" + std::to_string(active_window) + "-";
    final_path += "cons" + std::to_string(consolidate_interval) + "-";
    final_path += "max" + std::to_string(max_points_to_insert);
    return final_path;
}

template <typename T, typename TagT, typename LabelT>
void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data,
                       size_t aligned_dim, std::vector<std::vector<LabelT>> &pts_to_labels)
{
    try
    {
        diskann::Timer insert_timer;
        std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;

        size_t num_failed = 0;
#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed)
        for (int64_t j = start; j < (int64_t)end; j++)
        {
            int insert_result = -1;
            if (pts_to_labels.size() > 0)
            {
                insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
                                                   pts_to_labels[j - start]);
            }
            else
            {
                insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
            }

            if (insert_result != 0)
            {
                std::cerr << "Insert failed " << j << std::endl;
                num_failed++;
            }
        }
        const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
        std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
                  << " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)"
                  << std::endl;
        if (num_failed > 0)
            std::cout << num_failed << " of " << end - start << " inserts failed" << std::endl;
    }
    catch (std::system_error &e)
    {
        std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl;
        exit(-1);
    }
}

template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start,
                            size_t end)
{
    try
    {
        std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... ";
        for (size_t i = start; i < end; ++i)
            index.lazy_delete(static_cast<TagT>(1 + i));
        std::cout << "lazy delete done." << std::endl;

        auto report = index.consolidate_deletes(delete_params);
        while (report._status != diskann::consolidation_report::status_code::SUCCESS)
        {
            int wait_time = 5;
            if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL)
            {
                diskann::cerr << "Unable to acquire consolidate delete lock after "
                              << "deleting points " << start << " to " << end << ". Will retry in " << wait_time
                              << " seconds." << std::endl;
            }
            else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR)
            {
                diskann::cerr << "Inconsistent counts in data structure. "
                              << "Will retry in " << wait_time << " seconds." << std::endl;
            }
            else
            {
                std::cerr << "Exiting after unknown error in consolidate delete" << std::endl;
                exit(-1);
            }
            std::this_thread::sleep_for(std::chrono::seconds(wait_time));
            report = index.consolidate_deletes(delete_params);
        }
        auto points_processed = report._active_points + report._slots_released;
        auto deletion_rate = points_processed / report._time;
        std::cout << "#active points: " << report._active_points << std::endl
                  << "max points: " << report._max_points << std::endl
                  << "empty slots: " << report._empty_slots << std::endl
                  << "deletes processed: " << report._slots_released << std::endl
                  << "latest delete size: " << report._delete_set_size << std::endl
                  << "Deletion rate: " << deletion_rate << "/sec "
                  << "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl;
    }
    catch (std::system_error &e)
    {
        std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl;
        exit(-1);
    }
}
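
// consolidate_deletes is retried above because a concurrent consolidate may
// hold the lock (LOCK_FAIL) or the index may be mid-update
// (INCONSISTENT_COUNT_ERROR); both are treated as transient and re-attempted
// after a fixed 5-second back-off, while any other status aborts the run.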

template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha,
                             const uint32_t insert_threads, const uint32_t consolidate_threads,
                             size_t max_points_to_insert, size_t active_window, size_t consolidate_interval,
                             const float start_point_norm, uint32_t num_start_pts, const std::string &save_path,
                             const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
{
    const uint32_t C = 500;
    const bool saturate_graph = false;
    bool has_labels = label_file != "";

    diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
                                               .with_max_occlusion_size(C)
                                               .with_alpha(alpha)
                                               .with_saturate_graph(saturate_graph)
                                               .with_num_threads(insert_threads)
                                               .with_filter_list_size(Lf)
                                               .build();

    auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
    diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
                                                      .with_max_occlusion_size(C)
                                                      .with_alpha(alpha)
                                                      .with_saturate_graph(saturate_graph)
                                                      .with_num_threads(consolidate_threads)
                                                      .with_filter_list_size(Lf)
                                                      .build();

    size_t dim, aligned_dim;
    size_t num_points;

    std::vector<std::vector<LabelT>> pts_to_labels;

    const auto save_path_inc =
        get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert);
    std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
    std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
    if (has_labels)
    {
        convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
        auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
        pts_to_labels = std::get<0>(parse_result);
    }

    diskann::get_bin_metadata(data_path, num_points, dim);
    diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
                  << std::endl;
    aligned_dim = ROUND_UP(dim, 8);
    auto index_config = diskann::IndexConfigBuilder()
                            .with_metric(diskann::L2)
                            .with_dimension(dim)
                            .with_max_points(active_window + 4 * consolidate_interval)
                            .is_dynamic_index(true)
                            .is_enable_tags(true)
                            .is_use_opq(false)
                            .is_filtered(has_labels)
                            .with_num_pq_chunks(0)
                            .is_pq_dist_build(false)
                            .with_num_frozen_pts(num_start_pts)
                            .with_tag_type(diskann_type_to_name<TagT>())
                            .with_label_type(diskann_type_to_name<LabelT>())
                            .with_data_type(diskann_type_to_name<T>())
                            .with_index_write_params(params)
                            .with_index_search_params(index_search_params)
                            .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
                            .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
                            .build();

    diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
    auto index = index_factory.create_instance();

    if (universal_label != "")
    {
        LabelT u_label = 0;
        index->set_universal_label(u_label);
    }

    if (max_points_to_insert == 0)
    {
        max_points_to_insert = num_points;
    }

    if (num_points < max_points_to_insert)
        throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) +
                                        ") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")",
                                    -1, __FUNCSIG__, __FILE__, __LINE__);

    if (max_points_to_insert < active_window + consolidate_interval)
        throw diskann::ANNException("ERROR: max_points_to_insert < "
                                    "active_window + consolidate_interval",
                                    -1, __FUNCSIG__, __FILE__, __LINE__);

    if (consolidate_interval < max_points_to_insert / 1000)
        throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__);

    index->set_start_points_at_random(static_cast<T>(start_point_norm));

    T *data = nullptr;
    diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T),
                           8 * sizeof(T));

    std::vector<TagT> tags(max_points_to_insert);
    std::iota(tags.begin(), tags.end(), static_cast<TagT>(0));

    diskann::Timer timer;

    std::vector<std::future<void>> delete_tasks;

    auto insert_task = std::async(std::launch::async, [&]() {
        load_aligned_bin_part(data_path, data, 0, active_window);
        insert_next_batch<T, TagT, LabelT>(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim,
                                           pts_to_labels);
    });
    insert_task.wait();
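
    // Sliding-window schedule: after the first active_window points are in,
    // each pass inserts the next consolidate_interval points while a background
    // task deletes the oldest consolidate_interval ones, keeping the live set
    // near active_window; max_points was sized above with headroom at
    // active_window + 4 * consolidate_interval to absorb the overlap.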
    for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert;
         start += consolidate_interval)
    {
        auto end = std::min(start + consolidate_interval, max_points_to_insert);
        auto insert_task = std::async(std::launch::async, [&]() {
            load_aligned_bin_part(data_path, data, start, end - start);
            insert_next_batch<T, TagT, LabelT>(*index, start, end, params.num_threads, data, aligned_dim,
                                               pts_to_labels);
        });
        insert_task.wait();

        if (delete_tasks.size() > 0)
            delete_tasks[delete_tasks.size() - 1].wait();
        if (start >= active_window + consolidate_interval)
        {
            auto start_del = start - active_window - consolidate_interval;
            auto end_del = start - active_window;

            delete_tasks.emplace_back(std::async(std::launch::async, [&]() {
                delete_and_consolidate<T, TagT, LabelT>(*index, delete_params, (size_t)start_del, (size_t)end_del);
            }));
        }
    }
    if (delete_tasks.size() > 0)
        delete_tasks[delete_tasks.size() - 1].wait();

    std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";

    index->save(save_path_inc.c_str(), true);

    diskann::aligned_free(data);
}

int main(int argc, char **argv)
{
    std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
    uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported;
    float alpha, start_point_norm;
    size_t max_points_to_insert, active_window, consolidate_interval;

    po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario",
                                                                                 "Test insert deletes & consolidate")};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");

        // Required parameters
        po::options_description required_configs("Required");
        required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
                                       program_options_utils::DATA_TYPE_DESCRIPTION);
        required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
                                       program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
        required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                                       program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
        required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
                                       program_options_utils::INPUT_DATA_PATH);
        required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
                                       "Program maintains an index over an active window of "
                                       "this size that slides through the data");
        required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
                                       "The program simultaneously adds this number of points to the "
                                       "right of "
                                       "the window while deleting the same number from the left");
        required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
                                       "Set the start point to a random point on a sphere of this radius");

        // Optional parameters
        po::options_description optional_configs("Optional");
        optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
                                       program_options_utils::MAX_BUILD_DEGREE);
        optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
                                       program_options_utils::GRAPH_BUILD_COMPLEXITY);
        optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
                                       program_options_utils::GRAPH_BUILD_ALPHA);
        optional_configs.add_options()("insert_threads",
                                       po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
                                       "Number of threads used for inserting into the index (defaults to "
                                       "omp_get_num_procs()/2)");
        optional_configs.add_options()(
            "consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
            "Number of threads used for consolidating deletes to "
            "the index (defaults to omp_get_num_procs()/2)");
        optional_configs.add_options()("max_points_to_insert",
                                       po::value<uint64_t>(&max_points_to_insert)->default_value(0),
                                       "The number of points from the file that the program streams "
                                       "over ");
        optional_configs.add_options()(
            "num_start_points",
            po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
            "Set the number of random start (frozen) points to use when "
            "inserting and searching");

        optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
                                       "Input label file in txt format for Filtered Index search. "
                                       "The file should contain comma separated filters for each node "
                                       "with each line corresponding to a graph node");
        optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
                                       "Universal label, if using it, only in conjunction with labels_file");
        optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
                                       "Build complexity for filtered points, higher value "
                                       "results in better graphs");
        optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
                                       "Storage type of Labels <uint/ushort>, default value is uint which "
                                       "will consume memory 4 bytes per filter");
        optional_configs.add_options()("unique_labels_supported",
                                       po::value<uint32_t>(&unique_labels_supported)->default_value(0),
                                       "Number of unique labels supported by the dynamic index.");

        // Merge required and optional parameters
        desc.add(required_configs).add(optional_configs);

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }
    // Validate arguments
    if (start_point_norm == 0)
    {
        std::cout << "start_point_norm must be non-zero: the index starts empty, "
                     "so the start point is a random point of this norm"
                  << std::endl;
        return -1;
    }

    if (label_type != std::string("ushort") && label_type != std::string("uint"))
    {
        std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl;
        return -1;
    }

    if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float"))
    {
        std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl;
        return -1;
    }

    // TODO: Are additional distance functions supported?
    if (dist_fn != std::string("l2") && dist_fn != std::string("mips"))
    {
        std::cerr << "Invalid distance function. Supported functions are l2 and mips" << std::endl;
        return -1;
    }

    if (num_start_pts < unique_labels_supported)
    {
        num_start_pts = unique_labels_supported;
    }

    try
    {
        if (data_type == std::string("uint8"))
        {
            if (label_type == std::string("ushort"))
            {
                build_incremental_index<uint8_t, uint32_t, uint16_t>(
                    data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
                    consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
                    universal_label, Lf);
            }
            else if (label_type == std::string("uint"))
            {
                build_incremental_index<uint8_t, uint32_t, uint32_t>(
                    data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
                    consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
                    universal_label, Lf);
            }
        }
        else if (data_type == std::string("int8"))
        {
            if (label_type == std::string("ushort"))
            {
                build_incremental_index<int8_t, uint32_t, uint16_t>(
                    data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
                    consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
                    universal_label, Lf);
            }
            else if (label_type == std::string("uint"))
            {
                build_incremental_index<int8_t, uint32_t, uint32_t>(
                    data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
                    consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
                    universal_label, Lf);
            }
        }
        else if (data_type == std::string("float"))
        {
            if (label_type == std::string("ushort"))
            {
                build_incremental_index<float, uint32_t, uint16_t>(
                    data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
                    consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
                    universal_label, Lf);
            }
            else if (label_type == std::string("uint"))
            {
                build_incremental_index<float, uint32_t, uint32_t>(
                    data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
                    consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
                    universal_label, Lf);
            }
        }
    }
    catch (const std::exception &e)
    {
        std::cerr << "Caught exception: " << e.what() << std::endl;
        exit(-1);
    }
    catch (...)
    {
        std::cerr << "Caught unknown exception" << std::endl;
        exit(-1);
    }

    return 0;
}
110
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/CMakeLists.txt
vendored
Normal file
@@ -0,0 +1,110 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)


add_executable(fvecs_to_bin fvecs_to_bin.cpp)

add_executable(fvecs_to_bvecs fvecs_to_bvecs.cpp)

add_executable(rand_data_gen rand_data_gen.cpp)
target_link_libraries(rand_data_gen ${PROJECT_NAME} Boost::program_options)

add_executable(float_bin_to_int8 float_bin_to_int8.cpp)

add_executable(ivecs_to_bin ivecs_to_bin.cpp)

add_executable(count_bfs_levels count_bfs_levels.cpp)
target_link_libraries(count_bfs_levels ${PROJECT_NAME} Boost::program_options)

add_executable(tsv_to_bin tsv_to_bin.cpp)

add_executable(bin_to_tsv bin_to_tsv.cpp)

add_executable(int8_to_float int8_to_float.cpp)
target_link_libraries(int8_to_float ${PROJECT_NAME})

add_executable(int8_to_float_scale int8_to_float_scale.cpp)
target_link_libraries(int8_to_float_scale ${PROJECT_NAME})

add_executable(uint8_to_float uint8_to_float.cpp)
target_link_libraries(uint8_to_float ${PROJECT_NAME})

add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})

add_executable(vector_analysis vector_analysis.cpp)
target_link_libraries(vector_analysis ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(gen_random_slice gen_random_slice.cpp)
target_link_libraries(gen_random_slice ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)

add_executable(calculate_recall calculate_recall.cpp)
target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
# Ground-truth computation lives outside the DiskANN main source because it depends on MKL.
add_executable(compute_groundtruth compute_groundtruth.cpp)
target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)

add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp)
target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)


add_executable(generate_pq generate_pq.cpp)
target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})


add_executable(partition_data partition_data.cpp)
target_link_libraries(partition_data ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(partition_with_ram_budget partition_with_ram_budget.cpp)
target_link_libraries(partition_with_ram_budget ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(merge_shards merge_shards.cpp)
target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB})

add_executable(create_disk_layout create_disk_layout.cpp)
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})

add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)

add_executable(stats_label_data stats_label_data.cpp)
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)

if (NOT MSVC)
    include(GNUInstallDirs)
    install(TARGETS fvecs_to_bin
                    fvecs_to_bvecs
                    rand_data_gen
                    float_bin_to_int8
                    ivecs_to_bin
                    count_bfs_levels
                    tsv_to_bin
                    bin_to_tsv
                    int8_to_float
                    int8_to_float_scale
                    uint8_to_float
                    uint32_to_uint8
                    vector_analysis
                    gen_random_slice
                    simulate_aggregate_recall
                    calculate_recall
                    compute_groundtruth
                    compute_groundtruth_for_filters
                    generate_pq
                    partition_data
                    partition_with_ram_budget
                    merge_shards
                    create_disk_layout
                    generate_synthetic_labels
                    stats_label_data
            RUNTIME
    )
endif()
63
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/bin_to_fvecs.cpp
vendored
Normal file
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <algorithm>
#include <cstring>
#include <fstream>
#include <iostream>

#include "utils.h"

// Converts one block of `npts` vectors from .bin layout (plain row-major
// floats) to .fvecs layout, where each vector is prefixed by its dimension
// stored as a uint32.
void block_convert(std::ofstream &writer, std::ifstream &reader, float *read_buf, float *write_buf, uint64_t npts,
                   uint64_t ndims)
{
    reader.read((char *)read_buf, npts * ndims * sizeof(float));
    uint32_t ndims_u32 = (uint32_t)ndims;
#pragma omp parallel for
    for (uint64_t i = 0; i < npts; i++)
    {
        std::memcpy(write_buf + i * (ndims + 1), &ndims_u32, sizeof(uint32_t));
        std::memcpy(write_buf + i * (ndims + 1) + 1, read_buf + i * ndims, ndims * sizeof(float));
    }
    writer.write((char *)write_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
}

int main(int argc, char **argv)
{
    if (argc != 3)
    {
        std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
        exit(-1);
    }
    std::ifstream reader(argv[1], std::ios::binary);
    int npts_s32, ndims_s32;
    reader.read((char *)&npts_s32, sizeof(int32_t));
    reader.read((char *)&ndims_s32, sizeof(int32_t));
    uint64_t npts = (uint64_t)npts_s32;
    uint64_t ndims = (uint64_t)ndims_s32;
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    uint64_t blk_size = 131072;
    uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size;
    std::cout << "# blks: " << nblks << std::endl;

    std::ofstream writer(argv[2], std::ios::binary);
    // Buffers are sized per block, not for the whole dataset.
    float *read_buf = new float[blk_size * ndims];
    float *write_buf = new float[blk_size * (ndims + 1)];
    for (uint64_t i = 0; i < nblks; i++)
    {
        uint64_t cblk_size = std::min(npts - i * blk_size, blk_size);
        block_convert(writer, reader, read_buf, write_buf, cblk_size, ndims);
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    writer.close();
    reader.close();
}
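Aside: a minimal sketch (not part of the commit) of the two on-disk layouts the tool above translates between. The .bin format is a global header of two int32s followed by raw row-major floats; .fvecs repeats the dimension as a uint32 before every single vector. File names and toy values here are invented for illustration.

#include <cstdint>
#include <fstream>

int main()
{
    const int32_t npts = 2, ndims = 3;
    const float data[6] = {1, 2, 3, 4, 5, 6};

    // .bin: [int32 npts][int32 ndims][npts * ndims floats]
    std::ofstream bin("toy.bin", std::ios::binary);
    bin.write((const char *)&npts, sizeof(int32_t));
    bin.write((const char *)&ndims, sizeof(int32_t));
    bin.write((const char *)data, sizeof(data));

    // .fvecs: per vector, [uint32 ndims][ndims floats]
    std::ofstream fvecs("toy.fvecs", std::ios::binary);
    for (int32_t i = 0; i < npts; i++)
    {
        const uint32_t d = (uint32_t)ndims;
        fvecs.write((const char *)&d, sizeof(uint32_t));
        fvecs.write((const char *)(data + i * ndims), ndims * sizeof(float));
    }
}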
69
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/bin_to_tsv.cpp
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
template <class T>
void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims)
{
    // Read one block of raw values; the element size must match T, not float.
    reader.read((char *)read_buf, npts * ndims * sizeof(T));

    for (size_t i = 0; i < npts; i++)
    {
        for (size_t d = 0; d < ndims; d++)
        {
            // Unary plus promotes int8_t/uint8_t to int so digits, not raw
            // characters, are written to the TSV.
            writer << +read_buf[d + i * ndims];
            if (d < ndims - 1)
                writer << "\t";
            else
                writer << "\n";
        }
    }
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <float/int8/uint8> input_bin output_tsv" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
    std::string type_string(argv[1]);
    if ((type_string != std::string("float")) && (type_string != std::string("int8")) &&
        (type_string != std::string("uint8")))
    {
        std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl;
        exit(-1);
    }
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[3]);
|
||||
char *read_buf = new char[blk_size * ndims * 4];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (type_string == std::string("float"))
|
||||
block_convert<float>(writer, reader, (float *)read_buf, cblk_size, ndims);
|
||||
else if (type_string == std::string("int8"))
|
||||
block_convert<int8_t>(writer, reader, (int8_t *)read_buf, cblk_size, ndims);
|
||||
else if (type_string == std::string("uint8"))
|
||||
block_convert<uint8_t>(writer, reader, (uint8_t *)read_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
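Aside: as a quick sanity check, an invocation along the lines of "bin_to_tsv float base.bin base.tsv" (file names invented for illustration) should produce exactly npts lines of ndims tab-separated values, matching the metadata the tool prints at startup.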
55
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/calculate_recall.cpp
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <ground_truth_bin> <our_results_bin> <r> " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
uint32_t *gold_std = NULL;
|
||||
float *gs_dist = nullptr;
|
||||
uint32_t *our_results = NULL;
|
||||
float *or_dist = nullptr;
|
||||
size_t points_num, points_num_gs, points_num_or;
|
||||
size_t dim_gs;
|
||||
size_t dim_or;
|
||||
diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs);
|
||||
diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or);
|
||||
|
||||
if (points_num_gs != points_num_or)
|
||||
{
|
||||
std::cout << "Error. Number of queries mismatch in ground truth and "
|
||||
"our results"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
points_num = points_num_gs;
|
||||
|
||||
uint32_t recall_at = std::atoi(argv[3]);
|
||||
|
||||
if ((dim_or < recall_at) || (recall_at > dim_gs))
|
||||
{
|
||||
std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall "
|
||||
<< recall_at << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "Calculating recall@" << recall_at << std::endl;
|
||||
double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs,
|
||||
our_results, (uint32_t)dim_or, (uint32_t)recall_at);
|
||||
|
||||
// double avg_recall = (recall*1.0)/(points_num*1.0);
|
||||
std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
|
||||
}
|
||||
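Aside: the recall computation above is delegated to diskann::calculate_recall; a standalone sketch of the same idea (not part of the commit; it ignores the tie handling the library version gets from the ground-truth distances) looks like this:

#include <cstdint>
#include <set>
#include <vector>

// Fraction of the true top-r ids found in the returned top-r, averaged over
// queries. Assumes every inner vector holds at least r entries.
double recall_at_r(const std::vector<std::vector<uint32_t>> &ground_truth,
                   const std::vector<std::vector<uint32_t>> &results, size_t r)
{
    double total = 0.0;
    for (size_t q = 0; q < ground_truth.size(); q++)
    {
        std::set<uint32_t> gt(ground_truth[q].begin(), ground_truth[q].begin() + r);
        size_t hits = 0;
        for (size_t j = 0; j < r; j++)
            hits += gt.count(results[q][j]);
        total += (double)hits / (double)r;
    }
    return total / (double)ground_truth.size();
}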
574
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/compute_groundtruth.cpp
vendored
Normal file
@@ -0,0 +1,574 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <unordered_map>
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
#include "filter_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Works for up to 2 billion points (ids are stored as int rather than unsigned).
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <class T> T div_round_up(const T numerator, const T denominator)
|
||||
{
|
||||
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
|
||||
}
|
||||
|
||||
using pairIF = std::pair<size_t, float>;
|
||||
struct cmpmaxstruct
|
||||
{
|
||||
bool operator()(const pairIF &l, const pairIF &r)
|
||||
{
|
||||
return l.second < r.second;
|
||||
};
|
||||
};
|
||||
|
||||
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
|
||||
{
|
||||
#ifdef _WINDOWS
|
||||
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
|
||||
#else
|
||||
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
|
||||
{
|
||||
return a.second < b.second;
|
||||
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
|
||||
{
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
|
||||
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
|
||||
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
|
||||
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
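// Note: the three GEMM calls above implement the expansion
//     ||p - q||^2 = ||p||^2 + ||q||^2 - 2 * <p, q>
// for all (point, query) pairs at once: the first call writes -2 * P^T Q into
// dist_matrix, and the two rank-1 updates add ||p_i||^2 and ||q_j||^2 to
// entry (i, j).
#if 0
// Illustrative scalar check of the same identity for one pair (not compiled;
// `p`, `q`, and `dim` are placeholders):
float direct = 0, dot = 0, p_sq = 0, q_sq = 0;
for (size_t j = 0; j < dim; j++)
{
    direct += (p[j] - q[j]) * (p[j] - q[j]);
    dot += p[j] * q[j];
    p_sq += p[j] * p[j];
    q_sq += q[j] * q[j];
}
// direct == p_sq + q_sq - 2 * dot, up to floating-point rounding.
#endif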
|
||||
void inner_prod_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
size_t *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
float *points_in, // points in Col major
|
||||
size_t nqueries, float *queries_in,
|
||||
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
compute_l2sq(points_l2sq, points_in, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
|
||||
|
||||
float *points = points_in;
|
||||
float *queries = queries_in;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
    { // We convert cosine distance to L2 distance on unit-normalized vectors:
      // for unit vectors, ||p - q||^2 = 2 - 2 * cos(p, q), so nearest-neighbor
      // order is preserved.
|
||||
points = new float[npoints * dim];
|
||||
queries = new float[nqueries * dim];
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)npoints; i++)
|
||||
{
|
||||
float norm = std::sqrt(points_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
points[i * dim + j] = points_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)nqueries; i++)
|
||||
{
|
||||
float norm = std::sqrt(queries_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
queries[i * dim + j] = queries_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
// recalculate norms after normalizing, they should all be one.
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
}
|
||||
|
||||
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
|
||||
<< dim << " dimensions using";
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
std::cout << " MIPS ";
|
||||
else if (metric == diskann::Metric::COSINE)
|
||||
std::cout << " Cosine ";
|
||||
else
|
||||
std::cout << " L2 ";
|
||||
std::cout << "distance fn. " << std::endl;
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
|
||||
|
||||
for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
|
||||
{
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
|
||||
{
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
|
||||
}
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++)
|
||||
{
|
||||
maxPQIFCS point_dist;
|
||||
for (size_t p = 0; p < k; p++)
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
for (size_t p = k; p < npoints; p++)
|
||||
{
|
||||
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
|
||||
{
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{
|
||||
delete[] points;
|
||||
delete[] queries;
|
||||
}
|
||||
}
|
||||
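// Note: exact_knn above keeps one size-k max-heap per query: it is seeded with
// the first k candidates, every later candidate is pushed only if it beats the
// current worst, and the heap is trimmed back to k. Draining the heap yields
// the neighbors in decreasing distance, which is why results are written back
// to front. Distances are squared L2 (no square root is taken).
#if 0
// Hedged usage sketch (not compiled; toy values invented): two 2-d points,
// one query, k = 1.
float base[4] = {0.f, 0.f, 3.f, 4.f}; // each point's coordinates contiguous
float query[2] = {1.f, 0.f};
size_t nn[1];
float nn_dist[1];
exact_knn(/*dim=*/2, /*k=*/1, nn, nn_dist, /*npoints=*/2, base, /*nqueries=*/1, query);
// Expected: nn[0] == 0, nn_dist[0] == 1.0f (squared distance to point 0).
#endif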
|
||||
template <typename T> inline int get_num_parts(const char *filename)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
uint32_t num_parts =
|
||||
(npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (uint64_t)ndims_i32;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B"
|
||||
<< std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
T *data_T = new T[npts * ndims];
|
||||
reader.read((char *)data_T, sizeof(T) * npts * ndims);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
data = aligned_malloc<float>(npts * ndims, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (int64_t j = 0; j < (int64_t)ndims; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndims + j];
|
||||
std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
|
||||
{
|
||||
std::ofstream writer;
|
||||
writer.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
writer.open(filename, std::ios::binary | std::ios::out);
|
||||
std::cout << "Writing bin: " << filename << "\n";
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(T));
|
||||
writer.close();
|
||||
std::cout << "Finished writing bin" << std::endl;
|
||||
}
|
||||
|
||||
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
|
||||
"npts*dim dist-matrix) with npts = "
|
||||
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
|
||||
<< "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
|
||||
writer.write((char *)distances, npts * ndims * sizeof(float));
|
||||
writer.close();
|
||||
std::cout << "Finished writing truthset" << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
|
||||
size_t &nqueries, size_t &npoints,
|
||||
size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric,
|
||||
std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
float *base_data = nullptr;
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
|
||||
metric);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (size_t j = 0; j < part_k; j++)
|
||||
{
|
||||
                if (!location_to_tag.empty())
                    // Index with part_k (not k): when this part has fewer than
                    // k points the two strides differ, and k would read the
                    // wrong entry.
                    if (location_to_tag[closest_points_part[i * part_k + j] + start_id] == 0)
                        continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k,
|
||||
const diskann::Metric &metric, const std::string &tags_file = std::string(""))
|
||||
{
|
||||
size_t npoints, nqueries, dim;
|
||||
|
||||
    float *query_data = nullptr;
|
||||
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
if (nqueries > PARTSIZE)
|
||||
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
|
||||
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
|
||||
|
||||
// load tags
|
||||
const bool tags_enabled = tags_file.empty() ? false : true;
|
||||
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
|
||||
|
||||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results =
|
||||
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
size_t j = 0;
|
||||
for (auto iter : cur_res)
|
||||
{
|
||||
if (j == k)
|
||||
break;
|
||||
if (tags_enabled)
|
||||
{
|
||||
std::uint32_t index_with_tag = location_to_tag[iter.first];
|
||||
closest_points[i * k + j] = (int32_t)index_with_tag;
|
||||
}
|
||||
else
|
||||
{
|
||||
closest_points[i * k + j] = (int32_t)iter.first;
|
||||
}
|
||||
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
dist_closest_points[i * k + j] = -iter.second;
|
||||
else
|
||||
dist_closest_points[i * k + j] = iter.second;
|
||||
|
||||
++j;
|
||||
}
|
||||
if (j < k)
|
||||
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
diskann::aligned_free(query_data);
|
||||
|
||||
return 0;
|
||||
}
|
||||
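// Note: end to end, aux_main loads the queries, computes exact k-NN against the
// base in PARTSIZE-sized parts via processUnfilteredParts, merges and sorts the
// per-part candidates, optionally remaps ids through the tags file, and writes
// ids plus distances with save_groundtruth_as_one_file. With the flags declared
// in main() below, a typical invocation (file names invented for illustration)
// would be:
//   compute_groundtruth --data_type float --dist_fn l2 --base_file base.bin \
//                       --query_file query.bin --gt_file gt --K 100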
|
||||
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
|
||||
{
|
||||
size_t read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
npts = (uint32_t)npts_i32;
|
||||
dim = (uint32_t)dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1)
|
||||
{
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char *)dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
|
||||
uint64_t K;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
"distance function <l2/mips/cosine>");
|
||||
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
|
||||
"File containing the base vectors in binary format");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary "
|
||||
"format, please don' append .bin at end if "
|
||||
"no filter_label or filter_label_file is provided it "
|
||||
"will save the file with '.bin' at end."
|
||||
"else it will save the file as filename_label.bin");
|
||||
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
|
||||
"Number of ground truth nearest neighbors to compute");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"File containing the tags in binary format");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
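Aside: the ground-truth file written by this tool can be read back with a few lines of standalone code; a minimal sketch (not part of the commit; file name invented, and without the size validation load_truthset performs):

#include <cstdint>
#include <fstream>
#include <vector>

int main()
{
    std::ifstream in("gt.bin", std::ios::binary);
    int32_t npts = 0, k = 0;
    in.read((char *)&npts, sizeof(int32_t));
    in.read((char *)&k, sizeof(int32_t));

    std::vector<uint32_t> ids((size_t)npts * k);
    std::vector<float> dists((size_t)npts * k);
    in.read((char *)ids.data(), ids.size() * sizeof(uint32_t));
    in.read((char *)dists.data(), dists.size() * sizeof(float));
    // ids[q * k + j] and dists[q * k + j]: the j-th nearest neighbor of query q.
}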
919
packages/leann-backend-diskann/third_party/DiskANN/apps/utils/compute_groundtruth_for_filters.cpp
vendored
Normal file
@@ -0,0 +1,919 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <unordered_map>
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
|
||||
#include "filter_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
// Works for up to 2 billion points (ids are stored as int rather than unsigned).
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <class T> T div_round_up(const T numerator, const T denominator)
|
||||
{
|
||||
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
|
||||
}
|
||||
|
||||
using pairIF = std::pair<size_t, float>;
|
||||
struct cmpmaxstruct
|
||||
{
|
||||
bool operator()(const pairIF &l, const pairIF &r)
|
||||
{
|
||||
return l.second < r.second;
|
||||
};
|
||||
};
|
||||
|
||||
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
|
||||
{
|
||||
#ifdef _WINDOWS
|
||||
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
|
||||
#else
|
||||
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
|
||||
{
|
||||
return a.second < b.second;
|
||||
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
|
||||
{
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
|
||||
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
|
||||
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
|
||||
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void inner_prod_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
size_t *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
float *points_in, // points in Col major
|
||||
size_t nqueries, float *queries_in,
|
||||
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
compute_l2sq(points_l2sq, points_in, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
|
||||
|
||||
float *points = points_in;
|
||||
float *queries = queries_in;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
    { // We convert cosine distance to L2 distance on unit-normalized vectors:
      // for unit vectors, ||p - q||^2 = 2 - 2 * cos(p, q), so nearest-neighbor
      // order is preserved.
|
||||
points = new float[npoints * dim];
|
||||
queries = new float[nqueries * dim];
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)npoints; i++)
|
||||
{
|
||||
float norm = std::sqrt(points_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
points[i * dim + j] = points_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)nqueries; i++)
|
||||
{
|
||||
float norm = std::sqrt(queries_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
queries[i * dim + j] = queries_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
// recalculate norms after normalizing, they should all be one.
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
}
|
||||
|
||||
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
|
||||
<< dim << " dimensions using";
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
std::cout << " MIPS ";
|
||||
else if (metric == diskann::Metric::COSINE)
|
||||
std::cout << " Cosine ";
|
||||
else
|
||||
std::cout << " L2 ";
|
||||
std::cout << "distance fn. " << std::endl;
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
|
||||
|
||||
for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
|
||||
{
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
|
||||
{
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
|
||||
}
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++)
|
||||
{
|
||||
maxPQIFCS point_dist;
|
||||
for (size_t p = 0; p < k; p++)
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
for (size_t p = k; p < npoints; p++)
|
||||
{
|
||||
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
|
||||
{
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{
|
||||
delete[] points;
|
||||
delete[] queries;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> inline int get_num_parts(const char *filename)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts_u64 = end_id - start_id;
|
||||
ndims_u64 = (uint64_t)ndims_i32;
|
||||
std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64
|
||||
<< ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
T *data_T = new T[npts_u64 * ndims_u64];
|
||||
reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
data = aligned_malloc<float>(npts_u64 * ndims_u64, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t)npts_u64; i++)
|
||||
{
|
||||
for (int64_t j = 0; j < (int64_t)ndims_u64; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndims_u64 + j];
|
||||
std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<size_t> load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims,
|
||||
int part_num, const char *label_file,
|
||||
const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &npoints_filt,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels)
|
||||
{
|
||||
std::ifstream reader(filename, std::ios::binary);
|
||||
if (reader.fail())
|
||||
{
|
||||
throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
|
||||
}
|
||||
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
std::vector<size_t> rev_map;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (uint32_t)ndims_i32;
|
||||
uint64_t nptsuint64_t = (uint64_t)npts;
|
||||
uint64_t ndimsuint64_t = (uint64_t)ndims;
|
||||
npoints_filt = 0;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl;
|
||||
std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
|
||||
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
|
||||
reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
|
||||
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
|
||||
|
||||
for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++)
|
||||
{
|
||||
if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) !=
|
||||
pts_to_labels[start_id + i].end() ||
|
||||
std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) !=
|
||||
pts_to_labels[start_id + i].end())
|
||||
{
|
||||
rev_map.push_back(start_id + i);
|
||||
for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndimsuint64_t + j];
|
||||
std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
npoints_filt++;
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float.. identified " << npoints_filt
|
||||
<< " points matching the filter." << std::endl;
|
||||
return rev_map;
|
||||
}
|
||||
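// Note: load_filtered_bin_as_float compacts each part down to the points whose
// label set contains filter_label (or the universal label). The returned
// rev_map gives, for the i-th kept point, its id in the full base set; this is
// how processFilteredParts below translates filtered k-NN results back to
// global ids.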
|
||||
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
|
||||
{
|
||||
std::ofstream writer;
|
||||
writer.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
writer.open(filename, std::ios::binary | std::ios::out);
|
||||
std::cout << "Writing bin: " << filename << "\n";
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(T));
|
||||
writer.close();
|
||||
std::cout << "Finished writing bin" << std::endl;
|
||||
}
|
||||
|
||||
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
|
||||
"npts*dim dist-matrix) with npts = "
|
||||
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
|
||||
<< "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
|
||||
writer.write((char *)distances, npts * ndims * sizeof(float));
|
||||
writer.close();
|
||||
std::cout << "Finished writing truthset" << std::endl;
|
||||
}
|
||||
|
||||
inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels)
|
||||
{
|
||||
std::ifstream infile(map_file);
|
||||
std::string line, token;
|
||||
std::set<std::string> labels;
|
||||
infile.clear();
|
||||
infile.seekg(0, std::ios::beg);
|
||||
while (std::getline(infile, line))
|
||||
{
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> lbls(0);
|
||||
|
||||
getline(iss, token, '\t');
|
||||
std::istringstream new_iss(token);
|
||||
while (getline(new_iss, token, ','))
|
||||
{
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
lbls.push_back(token);
|
||||
labels.insert(token);
|
||||
}
|
||||
std::sort(lbls.begin(), lbls.end());
|
||||
pts_to_labels.push_back(lbls);
|
||||
    }
    line_cnt = pts_to_labels.size(); // report the number of parsed lines back to the caller
    std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for "
              << pts_to_labels.size() << " points" << std::endl;
|
||||
}
|
||||
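// Note: the label file is one line per base point; the line's first
// tab-separated field is read and split on commas into that point's label set.
// An illustrative two-point file (contents invented):
//   red,large
//   blue
// Point 0 then matches filter "red" or "large"; point 1 matches "blue" (or the
// universal label, if one is configured).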
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
|
||||
size_t &nqueries, size_t &npoints,
|
||||
size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric,
|
||||
std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
float *base_data = nullptr;
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
|
||||
metric);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (uint64_t j = 0; j < part_k; j++)
|
||||
{
|
||||
                if (!location_to_tag.empty())
                    // Index with part_k (not k): when this part has fewer than
                    // k points the two strides differ, and k would read the
                    // wrong entry.
                    if (location_to_tag[closest_points_part[i * part_k + j] + start_id] == 0)
                        continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processFilteredParts(
|
||||
const std::string &base_file, const std::string &label_file, const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric, std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
size_t npoints_filt = 0;
|
||||
float *base_data = nullptr;
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
|
||||
std::vector<std::vector<std::string>> pts_to_labels;
|
||||
if (filter_label != "")
|
||||
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
|
||||
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
std::vector<size_t> rev_map;
|
||||
if (filter_label != "")
|
||||
rev_map = load_filtered_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
|
||||
filter_label, universal_label, npoints_filt, pts_to_labels);
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints_filt ? k : npoints_filt;
|
||||
if (npoints_filt > 0)
|
||||
{
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries,
|
||||
query_data, metric);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (uint64_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
|
||||
             const std::string &gt_file, size_t k, const std::string &universal_label, const diskann::Metric &metric,
|
||||
const std::string &filter_label, const std::string &tags_file = std::string(""))
|
||||
{
|
||||
size_t npoints, nqueries, dim;
|
||||
|
||||
float *query_data = nullptr;
|
||||
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
if (nqueries > PARTSIZE)
|
||||
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
|
||||
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
|
||||
|
||||
// load tags
|
||||
const bool tags_enabled = tags_file.empty() ? false : true;
|
||||
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
|
||||
|
||||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results;
|
||||
if (filter_label == "")
|
||||
{
|
||||
results = processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
|
||||
}
|
||||
else
|
||||
{
|
||||
results = processFilteredParts<T>(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim,
|
||||
k, query_data, metric, location_to_tag);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
size_t j = 0;
|
||||
for (auto iter : cur_res)
|
||||
{
|
||||
if (j == k)
|
||||
break;
|
||||
if (tags_enabled)
|
||||
{
|
||||
std::uint32_t index_with_tag = location_to_tag[iter.first];
|
||||
closest_points[i * k + j] = (int32_t)index_with_tag;
|
||||
}
|
||||
else
|
||||
{
|
||||
closest_points[i * k + j] = (int32_t)iter.first;
|
||||
}
|
||||
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
dist_closest_points[i * k + j] = -iter.second;
|
||||
else
|
||||
dist_closest_points[i * k + j] = iter.second;
|
||||
|
||||
++j;
|
||||
}
|
||||
if (j < k)
|
||||
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
diskann::aligned_free(query_data);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
|
||||
{
|
||||
size_t read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
npts = (uint32_t)npts_i32;
|
||||
dim = (uint32_t)dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1)
|
||||
{
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char *)dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label,
|
||||
universal_label, filter_label_file;
|
||||
uint64_t K;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips>");
|
||||
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
|
||||
"File containing the base vectors in binary format");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input labels file in txt format if present");
|
||||
desc.add_options()("filter_label", po::value<std::string>(&filter_label)->default_value(""),
|
||||
"Input filter label if doing filtered groundtruth");
|
||||
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with label_file");
|
||||
desc.add_options()("gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary "
|
||||
"format, please don' append .bin at end if "
|
||||
"no filter_label or filter_label_file is provided it "
|
||||
"will save the file with '.bin' at end."
|
||||
"else it will save the file as filename_label.bin");
|
||||
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
|
||||
"Number of ground truth nearest neighbors to compute");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"File containing the tags in binary format");
|
||||
desc.add_options()("filter_label_file",
|
||||
po::value<std::string>(&filter_label_file)->default_value(std::string("")),
|
||||
"Filter file for Queries for Filtered Search ");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && filter_label_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> filter_labels;
|
||||
if (filter_label != "")
|
||||
{
|
||||
filter_labels.push_back(filter_label);
|
||||
}
|
||||
else if (filter_label_file != "")
|
||||
{
|
||||
filter_labels = read_file_to_vector_of_strings(filter_label_file, false);
|
||||
}
|
||||
|
||||
// only when there is no filter label or 1 filter label for all queries
|
||||
if (filter_labels.size() == 1)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{ // Each query has its own filter label
|
||||
// Split up data and query bins into label specific ones
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_queries;
|
||||
|
||||
label_set all_labels;
|
||||
for (size_t i = 0; i < filter_labels.size(); i++)
|
||||
{
|
||||
std::string label = filter_labels[i];
|
||||
all_labels.insert(label);
|
||||
|
||||
if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end())
|
||||
{
|
||||
labels_to_number_of_queries[label] = 0;
|
||||
}
|
||||
labels_to_number_of_queries[label] += 1;
|
||||
}
|
||||
|
||||
size_t npoints;
|
||||
std::vector<std::vector<std::string>> point_to_labels;
|
||||
parse_label_file_into_vec(npoints, label_file, point_to_labels);
|
||||
std::vector<label_set> point_ids_to_labels(point_to_labels.size());
|
||||
std::vector<label_set> query_ids_to_labels(filter_labels.size());
|
||||
|
||||
for (size_t i = 0; i < point_to_labels.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < point_to_labels[i].size(); j++)
|
||||
{
|
||||
std::string label = point_to_labels[i][j];
|
||||
if (all_labels.find(label) != all_labels.end())
|
||||
{
|
||||
point_ids_to_labels[i].insert(point_to_labels[i][j]);
|
||||
if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end())
|
||||
{
|
||||
labels_to_number_of_points[label] = 0;
|
||||
}
|
||||
labels_to_number_of_points[label] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < filter_labels.size(); i++)
|
||||
{
|
||||
query_ids_to_labels[i].insert(filter_labels[i]);
|
||||
}
|
||||
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_query_id_to_orig_id;
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Invalid data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Generate label specific ground truths
|
||||
|
||||
try
|
||||
{
|
||||
for (const auto &label : all_labels)
|
||||
{
|
||||
std::string filtered_base_file = base_file + "_" + label;
|
||||
std::string filtered_query_file = query_file + "_" + label;
|
||||
std::string filtered_gt_file = gt_file + "_" + label;
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Combine the label specific ground truths to produce a single GT file
|
||||
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t gt_num, gt_dim;
|
||||
|
||||
std::vector<std::vector<int32_t>> final_gt_ids;
|
||||
std::vector<std::vector<float>> final_gt_dists;
|
||||
|
||||
uint32_t query_num = 0;
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
query_num += labels_to_number_of_queries[lbl];
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
final_gt_ids.push_back(std::vector<int32_t>(K));
|
||||
final_gt_dists.push_back(std::vector<float>(K));
|
||||
}
|
||||
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
std::string filtered_gt_file = gt_file + "_" + lbl;
|
||||
load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
|
||||
for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++)
|
||||
{
|
||||
uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i];
|
||||
for (uint64_t j = 0; j < K; j++)
|
||||
{
|
||||
final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]];
|
||||
final_gt_dists[orig_query_id][j] = gt_dists[i * K + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int32_t *closest_points = new int32_t[query_num * K];
|
||||
float *dist_closest_points = new float[query_num * K];
|
||||
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
for (uint32_t j = 0; j < K; j++)
|
||||
{
|
||||
closest_points[i * K + j] = final_gt_ids[i][j];
|
||||
dist_closest_points[i * K + j] = final_gt_dists[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K);
|
||||
|
||||
// cleanup artifacts
|
||||
std::cout << "Cleaning up artifacts..." << std::endl;
|
||||
tsl::robin_set<std::string> paths_to_clean{gt_file, base_file, query_file};
|
||||
clean_up_artifacts(paths_to_clean, all_labels);
|
||||
}
|
||||
}
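
Assuming the binary is built under the name of this source file, compute_groundtruth_for_filters (the name is an assumption; only the option names are taken from the code above, and paths are placeholders), a filtered-groundtruth run would look roughly like:

./compute_groundtruth_for_filters --data_type float --dist_fn l2 --base_file base.bin --query_file query.bin --label_file labels.txt --filter_label_file query_filters.txt --gt_file gt --K 100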

82 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/count_bfs_levels.cpp vendored Normal file
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <cstring>
#include <iomanip>
#include <algorithm>
#include <numeric>
#include <omp.h>
#include <set>
#include <string.h>
#include <boost/program_options.hpp>

#ifndef _WINDOWS
#include <sys/mman.h>
#include <sys/stat.h>
#include <time.h>
#include <unistd.h>
#endif

#include "utils.h"
#include "index.h"
#include "memory_mapper.h"

namespace po = boost::program_options;

template <typename T> void bfs_count(const std::string &index_path, uint32_t data_dims)
{
    using TagT = uint32_t;
    using LabelT = uint32_t;
    diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false,
                                          false, 0, false);
    std::cout << "Index class instantiated" << std::endl;
    index.load(index_path.c_str(), 1, 100);
    std::cout << "Index loaded" << std::endl;
    index.count_nodes_at_bfs_levels();
}

int main(int argc, char **argv)
{
    std::string data_type, index_path_prefix;
    uint32_t data_dims;

    po::options_description desc{"Arguments"};
    try
    {
        desc.add_options()("help,h", "Print information on arguments");
        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
        desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
                           "Path prefix to the index");
        desc.add_options()("data_dims", po::value<uint32_t>(&data_dims)->required(), "Dimensionality of the data");

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    try
    {
        if (data_type == std::string("int8"))
            bfs_count<int8_t>(index_path_prefix, data_dims);
        else if (data_type == std::string("uint8"))
            bfs_count<uint8_t>(index_path_prefix, data_dims);
        else if (data_type == std::string("float"))
            bfs_count<float>(index_path_prefix, data_dims);
    }
    catch (std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index BFS failed." << std::endl;
        return -1;
    }
}

48 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/create_disk_layout.cpp vendored Normal file
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <vector>

#include "utils.h"
#include "disk_utils.h"
#include "cached_io.h"

template <typename T> int create_disk_layout(char **argv)
{
    std::string base_file(argv[2]);
    std::string vamana_file(argv[3]);
    std::string output_file(argv[4]);
    diskann::create_disk_layout<T>(base_file, vamana_file, output_file);
    return 0;
}

int main(int argc, char **argv)
{
    if (argc != 5)
    {
        std::cout << argv[0]
                  << " data_type <float/int8/uint8> data_bin "
                     "vamana_index_file output_diskann_index_file"
                  << std::endl;
        exit(-1);
    }

    int ret_val = -1;
    if (std::string(argv[1]) == std::string("float"))
        ret_val = create_disk_layout<float>(argv);
    else if (std::string(argv[1]) == std::string("int8"))
        ret_val = create_disk_layout<int8_t>(argv);
    else if (std::string(argv[1]) == std::string("uint8"))
        ret_val = create_disk_layout<uint8_t>(argv);
    else
    {
        std::cout << "unsupported type. use int8/uint8/float " << std::endl;
        ret_val = -2;
    }
    return ret_val;
}

63 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/float_bin_to_int8.cpp vendored Normal file
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include "utils.h"

void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts,
                   size_t ndims, float bias, float scale)
{
    reader.read((char *)read_buf, npts * ndims * sizeof(float));

    for (size_t i = 0; i < npts; i++)
    {
        for (size_t d = 0; d < ndims; d++)
        {
            write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale));
        }
    }
    writer.write((char *)write_buf, npts * ndims);
}

int main(int argc, char **argv)
{
    if (argc != 5)
    {
        std::cout << "Usage: " << argv[0] << " input_float_bin output_int8_bin bias scale" << std::endl;
        exit(-1);
    }

    std::ifstream reader(argv[1], std::ios::binary);
    uint32_t npts_u32;
    uint32_t ndims_u32;
    reader.read((char *)&npts_u32, sizeof(uint32_t));
    reader.read((char *)&ndims_u32, sizeof(uint32_t));
    size_t npts = npts_u32;
    size_t ndims = ndims_u32;
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    size_t blk_size = 131072;
    size_t nblks = ROUND_UP(npts, blk_size) / blk_size;

    std::ofstream writer(argv[2], std::ios::binary);
    auto read_buf = new float[blk_size * ndims];
    auto write_buf = new int8_t[blk_size * ndims];
    float bias = (float)atof(argv[3]);
    float scale = (float)atof(argv[4]);

    writer.write((char *)(&npts_u32), sizeof(uint32_t));
    writer.write((char *)(&ndims_u32), sizeof(uint32_t));

    for (size_t i = 0; i < nblks; i++)
    {
        size_t cblk_size = std::min(npts - i * blk_size, blk_size);
        block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    writer.close();
    reader.close();
}
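
A note on the quantization above: with the mapping q = (x - bias) * (254.0 / scale), inputs inside the window [bias - scale/2, bias + scale/2] land in [-127, 127], so bias should be the center of the data range and scale its width. For example, data in [0, 1] with bias = 0.5 and scale = 1.0 maps 0 to -127 and 1 to 127; values outside the window overflow the int8 cast, so pick scale generously.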

95 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/fvecs_to_bin.cpp vendored Normal file
@@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include "utils.h"

// Convert float types
void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts,
                         size_t ndims)
{
    reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
    for (size_t i = 0; i < npts; i++)
    {
        memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
    }
    writer.write((char *)write_buf, npts * ndims * sizeof(float));
}

// Convert byte types
void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf,
                        size_t npts, size_t ndims)
{
    reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t)));
    for (size_t i = 0; i < npts; i++)
    {
        memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t),
               ndims * sizeof(uint8_t));
    }
    writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t));
}

int main(int argc, char **argv)
{
    if (argc != 4)
    {
        std::cout << argv[0] << " <float/int8/uint8> input_vecs output_bin" << std::endl;
        exit(-1);
    }

    int datasize = sizeof(float);

    if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0)
    {
        datasize = sizeof(uint8_t);
    }
    else if (strcmp(argv[1], "float") != 0)
    {
        std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl;
        exit(-1);
    }

    std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
    size_t fsize = reader.tellg();
    reader.seekg(0, std::ios::beg);

    uint32_t ndims_u32;
    reader.read((char *)&ndims_u32, sizeof(uint32_t));
    reader.seekg(0, std::ios::beg);
    size_t ndims = (size_t)ndims_u32;
    size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t));
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    size_t blk_size = 131072;
    size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
    std::cout << "# blks: " << nblks << std::endl;
    std::ofstream writer(argv[3], std::ios::binary);
    int32_t npts_s32 = (int32_t)npts;
    int32_t ndims_s32 = (int32_t)ndims;
    writer.write((char *)&npts_s32, sizeof(int32_t));
    writer.write((char *)&ndims_s32, sizeof(int32_t));

    size_t chunknpts = std::min(npts, blk_size);
    uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))];
    uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize];

    for (size_t i = 0; i < nblks; i++)
    {
        size_t cblk_size = std::min(npts - i * blk_size, blk_size);
        if (datasize == sizeof(float))
        {
            block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims);
        }
        else
        {
            block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims);
        }
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    reader.close();
    writer.close();
}
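
The converter above assumes the classic .fvecs/.bvecs record layout, in which every vector carries its own 32-bit dimension prefix (which is why npts is derived from the file size rather than stored). A minimal standalone sketch that writes a toy .fvecs file to feed the tool -- file name and values are made up:

// toy_fvecs_writer.cpp -- illustrative only: writes 2 vectors of dim 4
#include <cstdint>
#include <fstream>

int main()
{
    std::ofstream out("toy.fvecs", std::ios::binary);
    float vecs[2][4] = {{0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 2.0f, 3.0f, 4.0f}};
    uint32_t dim = 4;
    for (auto &v : vecs)
    {
        out.write((char *)&dim, sizeof(uint32_t)); // per-record dimension prefix
        out.write((char *)v, dim * sizeof(float)); // then the coordinates
    }
}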

56 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/fvecs_to_bvecs.cpp vendored Normal file
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include "utils.h"

void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts,
                   size_t ndims)
{
    reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
    for (size_t i = 0; i < npts; i++)
    {
        memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t));
        for (size_t d = 0; d < ndims; d++)
            write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d];
    }
    writer.write((char *)write_buf, npts * (ndims * 1 + 4));
}

int main(int argc, char **argv)
{
    if (argc != 3)
    {
        std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl;
        exit(-1);
    }
    std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
    size_t fsize = reader.tellg();
    reader.seekg(0, std::ios::beg);

    uint32_t ndims_u32;
    reader.read((char *)&ndims_u32, sizeof(uint32_t));
    reader.seekg(0, std::ios::beg);
    size_t ndims = (size_t)ndims_u32;
    size_t npts = fsize / ((ndims + 1) * sizeof(float));
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    size_t blk_size = 131072;
    size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
    std::cout << "# blks: " << nblks << std::endl;
    std::ofstream writer(argv[2], std::ios::binary);
    auto read_buf = new float[npts * (ndims + 1)];
    auto write_buf = new uint8_t[npts * (ndims + 4)];
    for (size_t i = 0; i < nblks; i++)
    {
        size_t cblk_size = std::min(npts - i * blk_size, blk_size);
        block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    reader.close();
    writer.close();
}

58 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/gen_random_slice.cpp vendored Normal file
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <omp.h>
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include "partition.h"
#include "utils.h"

#include <fcntl.h>
#include <sys/stat.h>
#include <time.h>
#include <typeinfo>

template <typename T> int aux_main(char **argv)
{
    std::string base_file(argv[2]);
    std::string output_prefix(argv[3]);
    float sampling_rate = (float)(std::atof(argv[4]));
    gen_random_slice<T>(base_file, output_prefix, sampling_rate);
    return 0;
}

int main(int argc, char **argv)
{
    if (argc != 5)
    {
        std::cout << argv[0]
                  << " data_type [float/int8/uint8] base_bin_file "
                     "sample_output_prefix sampling_probability"
                  << std::endl;
        exit(-1);
    }

    if (std::string(argv[1]) == std::string("float"))
    {
        aux_main<float>(argv);
    }
    else if (std::string(argv[1]) == std::string("int8"))
    {
        aux_main<int8_t>(argv);
    }
    else if (std::string(argv[1]) == std::string("uint8"))
    {
        aux_main<uint8_t>(argv);
    }
    else
        std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
    return 0;
}

70 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/generate_pq.cpp vendored Normal file
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "math_utils.h"
#include "pq.h"
#include "partition.h"

#define KMEANS_ITERS_FOR_PQ 15

template <typename T>
bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers,
                 const size_t num_pq_chunks, const float sampling_rate, const bool opq)
{
    std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
    std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";

    // generates random sample and sets it to train_data and updates train_size
    size_t train_size, train_dim;
    float *train_data;
    gen_random_slice<T>(data_path, sampling_rate, train_data, train_size, train_dim);
    std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl;

    if (opq)
    {
        diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
                                     (uint32_t)num_pq_chunks, pq_pivots_path, true);
    }
    else
    {
        diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
                                    (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path);
    }
    diskann::generate_pq_data_from_pivots<T>(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks,
                                             pq_pivots_path, pq_compressed_vectors_path, true);

    delete[] train_data;

    return true;
}

int main(int argc, char **argv)
{
    if (argc != 7)
    {
        std::cout << "Usage: \n"
                  << argv[0]
                  << " <data_type[float/uint8/int8]> <data_file[.bin]>"
                     " <PQ_prefix_path> <target-bytes/data-point> "
                     "<sampling_rate> <PQ(0)/OPQ(1)>"
                  << std::endl;
    }
    else
    {
        const std::string data_path(argv[2]);
        const std::string index_prefix_path(argv[3]);
        const size_t num_pq_centers = 256;
        const size_t num_pq_chunks = (size_t)atoi(argv[4]);
        const float sampling_rate = (float)atof(argv[5]);
        const bool opq = atoi(argv[6]) == 0 ? false : true;

        if (std::string(argv[1]) == std::string("float"))
            generate_pq<float>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
        else if (std::string(argv[1]) == std::string("int8"))
            generate_pq<int8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
        else if (std::string(argv[1]) == std::string("uint8"))
            generate_pq<uint8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
        else
            std::cout << "Error. wrong file type" << std::endl;
    }
}
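
Since num_pq_centers is hard-coded to 256 above, each chunk's centroid id fits in exactly one byte, so the <target-bytes/data-point> argument (num_pq_chunks) is also the compressed size per vector. For example, 32 chunks shrink a 128-dimensional float vector from 512 bytes to 32 bytes, a 16x reduction (pivot tables excluded).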

204 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/generate_synthetic_labels.cpp vendored Normal file
@@ -0,0 +1,204 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include <random>
#include <boost/program_options.hpp>
#include <math.h>
#include <cmath>
#include "utils.h"

namespace po = boost::program_options;
class ZipfDistribution
{
  public:
    ZipfDistribution(uint64_t num_points, uint32_t num_labels)
        : num_labels(num_labels), num_points(num_points),
          uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0))
    {
    }

    std::unordered_map<uint32_t, uint32_t> createDistributionMap()
    {
        std::unordered_map<uint32_t, uint32_t> map;
        uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor);
        for (uint32_t i{1}; i < num_labels + 1; i++)
        {
            map[i] = (uint32_t)ceil(primary_label_freq / i);
        }
        return map;
    }

    int writeDistribution(std::ofstream &outfile)
    {
        auto distribution_map = createDistributionMap();
        for (uint32_t i{0}; i < num_points; i++)
        {
            bool label_written = false;
            for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); it++)
            {
                auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first);
                if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0)
                {
                    if (label_written)
                    {
                        outfile << ',';
                    }
                    outfile << it->first;
                    label_written = true;
                    // decrement the label's remaining budget; a label stops being
                    // assigned once its quota from the distribution map is used up
                    distribution_map[it->first] -= 1;
                }
            }
            if (!label_written)
            {
                outfile << 0;
            }
            if (i < num_points - 1)
            {
                outfile << '\n';
            }
        }
        return 0;
    }

    int writeDistribution(std::string filename)
    {
        std::ofstream outfile(filename);
        if (!outfile.is_open())
        {
            std::cerr << "Error: could not open output file " << filename << '\n';
            return -1;
        }
        writeDistribution(outfile);
        outfile.close();
        return 0;
    }

  private:
    const uint32_t num_labels;
    const uint64_t num_points;
    const double distribution_factor = 0.7;
    std::knuth_b rand_engine;
    const std::uniform_real_distribution<double> uniform_zero_to_one;
};

int main(int argc, char **argv)
{
    std::string output_file, distribution_type;
    uint32_t num_labels;
    uint64_t num_points;

    try
    {
        po::options_description desc{"Arguments"};

        desc.add_options()("help,h", "Print information on arguments");
        desc.add_options()("output_file,O", po::value<std::string>(&output_file)->required(),
                           "Filename for saving the label file");
        desc.add_options()("num_points,N", po::value<uint64_t>(&num_points)->required(), "Number of points in dataset");
        desc.add_options()("num_labels,L", po::value<uint32_t>(&num_labels)->required(),
                           "Number of unique labels, up to 5000");
        desc.add_options()("distribution_type,DT", po::value<std::string>(&distribution_type)->default_value("random"),
                           "Distribution function for labels <random/zipf/one_per_point> defaults "
                           "to random");

        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    if (num_labels > 5000)
    {
        std::cerr << "Error: num_labels must be 5000 or less" << '\n';
        return -1;
    }

    if (num_points <= 0)
    {
        std::cerr << "Error: num_points must be greater than 0" << '\n';
        return -1;
    }

    std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels"
              << '\n';

    try
    {
        std::ofstream outfile(output_file);
        if (!outfile.is_open())
        {
            std::cerr << "Error: could not open output file " << output_file << '\n';
            return -1;
        }

        if (distribution_type == "zipf")
        {
            ZipfDistribution zipf(num_points, num_labels);
            zipf.writeDistribution(outfile);
        }
        else if (distribution_type == "random")
        {
            for (size_t i = 0; i < num_points; i++)
            {
                bool label_written = false;
                for (size_t j = 1; j <= num_labels; j++)
                {
                    // 50% chance to assign each label
                    if (rand() > (RAND_MAX / 2))
                    {
                        if (label_written)
                        {
                            outfile << ',';
                        }
                        outfile << j;
                        label_written = true;
                    }
                }
                if (!label_written)
                {
                    outfile << 0;
                }
                if (i < num_points - 1)
                {
                    outfile << '\n';
                }
            }
        }
        else if (distribution_type == "one_per_point")
        {
            std::random_device rd;                                // obtain a random number from hardware
            std::mt19937 gen(rd());                               // seed the generator
            std::uniform_int_distribution<> distr(0, num_labels); // define the range

            for (size_t i = 0; i < num_points; i++)
            {
                outfile << distr(gen);
                if (i != num_points - 1)
                    outfile << '\n';
            }
        }
        if (outfile.is_open())
        {
            outfile.close();
        }

        std::cout << "Labels written to " << output_file << '\n';
    }
    catch (const std::exception &ex)
    {
        std::cerr << "Label generation failed: " << ex.what() << '\n';
        return -1;
    }

    return 0;
}
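
In the zipf mode above, label i is offered to each point with Bernoulli probability distribution_factor / i = 0.7 / i and is capped at ceil(0.7 * num_points) / i total assignments, so label frequencies fall off roughly as 1/i. For example, with num_points = 1000, label 1 appears about 700 times, label 2 about 350, and label 10 about 70.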

23 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/int8_to_float.cpp vendored Normal file
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include "utils.h"

int main(int argc, char **argv)
{
    if (argc != 3)
    {
        std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl;
        exit(-1);
    }

    int8_t *input;
    size_t npts, nd;
    diskann::load_bin<int8_t>(argv[1], input, npts, nd);
    float *output = new float[npts * nd];
    diskann::convert_types<int8_t, float>(input, output, npts, nd);
    diskann::save_bin<float>(argv[2], output, npts, nd);
    delete[] output;
    delete[] input;
}

63 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/int8_to_float_scale.cpp vendored Normal file
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include "utils.h"

void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts,
                   size_t ndims, float bias, float scale)
{
    reader.read((char *)read_buf, npts * ndims * sizeof(int8_t));

    for (size_t i = 0; i < npts; i++)
    {
        for (size_t d = 0; d < ndims; d++)
        {
            write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale);
        }
    }
    writer.write((char *)write_buf, npts * ndims * sizeof(float));
}

int main(int argc, char **argv)
{
    if (argc != 5)
    {
        std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl;
        exit(-1);
    }

    std::ifstream reader(argv[1], std::ios::binary);
    uint32_t npts_u32;
    uint32_t ndims_u32;
    reader.read((char *)&npts_u32, sizeof(uint32_t));
    reader.read((char *)&ndims_u32, sizeof(uint32_t));
    size_t npts = npts_u32;
    size_t ndims = ndims_u32;
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    size_t blk_size = 131072;
    size_t nblks = ROUND_UP(npts, blk_size) / blk_size;

    std::ofstream writer(argv[2], std::ios::binary);
    auto read_buf = new int8_t[blk_size * ndims];
    auto write_buf = new float[blk_size * ndims];
    float bias = (float)atof(argv[3]);
    float scale = (float)atof(argv[4]);

    writer.write((char *)(&npts_u32), sizeof(uint32_t));
    writer.write((char *)(&ndims_u32), sizeof(uint32_t));

    for (size_t i = 0; i < nblks; i++)
    {
        size_t cblk_size = std::min(npts - i * blk_size, blk_size);
        block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    writer.close();
    reader.close();
}

58 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/ivecs_to_bin.cpp vendored Normal file
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include "utils.h"

void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts,
                   size_t ndims)
{
    reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t)));
    for (size_t i = 0; i < npts; i++)
    {
        memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t));
    }
    writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t));
}

int main(int argc, char **argv)
{
    if (argc != 3)
    {
        std::cout << argv[0] << " input_ivecs output_bin" << std::endl;
        exit(-1);
    }
    std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
    size_t fsize = reader.tellg();
    reader.seekg(0, std::ios::beg);

    uint32_t ndims_u32;
    reader.read((char *)&ndims_u32, sizeof(uint32_t));
    reader.seekg(0, std::ios::beg);
    size_t ndims = (size_t)ndims_u32;
    size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t));
    std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;

    size_t blk_size = 131072;
    size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
    std::cout << "# blks: " << nblks << std::endl;
    std::ofstream writer(argv[2], std::ios::binary);
    int npts_s32 = (int)npts;
    int ndims_s32 = (int)ndims;
    writer.write((char *)&npts_s32, sizeof(int));
    writer.write((char *)&ndims_s32, sizeof(int));
    uint32_t *read_buf = new uint32_t[npts * (ndims + 1)];
    uint32_t *write_buf = new uint32_t[npts * ndims];
    for (size_t i = 0; i < nblks; i++)
    {
        size_t cblk_size = std::min(npts - i * blk_size, blk_size);
        block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
        std::cout << "Block #" << i << " written" << std::endl;
    }

    delete[] read_buf;
    delete[] write_buf;

    reader.close();
    writer.close();
}

42 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/merge_shards.cpp vendored Normal file
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <algorithm>
#include <atomic>
#include <cassert>
#include <fstream>
#include <iostream>
#include <set>
#include <string>
#include <vector>

#include "disk_utils.h"
#include "cached_io.h"
#include "utils.h"

int main(int argc, char **argv)
{
    if (argc != 9)
    {
        std::cout << argv[0]
                  << " vamana_index_prefix[1] vamana_index_suffix[2] "
                     "idmaps_prefix[3] "
                     "idmaps_suffix[4] n_shards[5] max_degree[6] "
                     "output_vamana_path[7] "
                     "output_medoids_path[8]"
                  << std::endl;
        exit(-1);
    }

    std::string vamana_prefix(argv[1]);
    std::string vamana_suffix(argv[2]);
    std::string idmaps_prefix(argv[3]);
    std::string idmaps_suffix(argv[4]);
    uint64_t nshards = (uint64_t)std::atoi(argv[5]);
    uint32_t max_degree = (uint32_t)std::atoi(argv[6]);
    std::string output_index(argv[7]);
    std::string output_medoids(argv[8]);

    return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree,
                                 output_index, output_medoids);
}

39 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/partition_data.cpp vendored Normal file
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <index.h>
#include <math_utils.h>
#include "cached_io.h"
#include "partition.h"

// DEPRECATED: NEED TO REPROGRAM

int main(int argc, char **argv)
{
    if (argc != 7)
    {
        std::cout << "Usage:\n"
                  << argv[0]
                  << " datatype<int8/uint8/float> <data_path>"
                     " <prefix_path> <sampling_rate> "
                     " <num_partitions> <k_index>"
                  << std::endl;
        exit(-1);
    }

    const std::string data_path(argv[2]);
    const std::string prefix_path(argv[3]);
    const float sampling_rate = (float)atof(argv[4]);
    const size_t num_partitions = (size_t)std::atoi(argv[5]);
    const size_t max_reps = 15;
    const size_t k_index = (size_t)std::atoi(argv[6]);

    if (std::string(argv[1]) == std::string("float"))
        partition<float>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
    else if (std::string(argv[1]) == std::string("int8"))
        partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
    else if (std::string(argv[1]) == std::string("uint8"))
        partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
    else
        std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
}

39 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/partition_with_ram_budget.cpp vendored Normal file
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <index.h>
#include <math_utils.h>
#include "cached_io.h"
#include "partition.h"

// DEPRECATED: NEED TO REPROGRAM

int main(int argc, char **argv)
{
    if (argc != 8)
    {
        std::cout << "Usage:\n"
                  << argv[0]
                  << " datatype<int8/uint8/float> <data_path>"
                     " <prefix_path> <sampling_rate> "
                     " <ram_budget(GB)> <graph_degree> <k_index>"
                  << std::endl;
        exit(-1);
    }

    const std::string data_path(argv[2]);
    const std::string prefix_path(argv[3]);
    const float sampling_rate = (float)atof(argv[4]);
    const double ram_budget = (double)std::atof(argv[5]);
    const size_t graph_degree = (size_t)std::atoi(argv[6]);
    const size_t k_index = (size_t)std::atoi(argv[7]);

    if (std::string(argv[1]) == std::string("float"))
        partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
    else if (std::string(argv[1]) == std::string("int8"))
        partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
    else if (std::string(argv[1]) == std::string("uint8"))
        partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
    else
        std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
}

237 packages/leann-backend-diskann/third_party/DiskANN/apps/utils/rand_data_gen.cpp vendored Normal file
@@ -0,0 +1,237 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include <iostream>
#include <cstdlib>
#include <random>
#include <cmath>
#include <boost/program_options.hpp>

#include "utils.h"

namespace po = boost::program_options;

int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm,
                      float rand_scale)
{
    auto vec = new float[ndims];

    std::random_device rd{};
    std::mt19937 gen{rd()};
    std::normal_distribution<> normal_rand{0, 1};
    std::uniform_real_distribution<> unif_dis(1.0, rand_scale);

    for (size_t i = 0; i < npts; i++)
    {
        float sum = 0;
        float scale = 1.0f;
        if (rand_scale > 1.0f)
            scale = (float)unif_dis(gen);
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = scale * (float)normal_rand(gen);
        if (normalization)
        {
            for (size_t d = 0; d < ndims; ++d)
                sum += vec[d] * vec[d];
            for (size_t d = 0; d < ndims; ++d)
                vec[d] = vec[d] * norm / std::sqrt(sum);
        }

        writer.write((char *)vec, ndims * sizeof(float));
    }

    delete[] vec;
    return 0;
}

int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
    auto vec = new float[ndims];
    auto vec_T = new int8_t[ndims];

    std::random_device rd{};
    std::mt19937 gen{rd()};
    std::normal_distribution<> normal_rand{0, 1};

    for (size_t i = 0; i < npts; i++)
    {
        float sum = 0;
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = (float)normal_rand(gen);
        for (size_t d = 0; d < ndims; ++d)
            sum += vec[d] * vec[d];
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = vec[d] * norm / std::sqrt(sum);

        for (size_t d = 0; d < ndims; ++d)
        {
            vec_T[d] = (int8_t)std::round(vec[d]);
        }

        writer.write((char *)vec_T, ndims * sizeof(int8_t));
    }

    delete[] vec;
    delete[] vec_T;
    return 0;
}

int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
{
    auto vec = new float[ndims];
    auto vec_T = new int8_t[ndims];

    std::random_device rd{};
    std::mt19937 gen{rd()};
    std::normal_distribution<> normal_rand{0, 1};

    for (size_t i = 0; i < npts; i++)
    {
        float sum = 0;
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = (float)normal_rand(gen);
        for (size_t d = 0; d < ndims; ++d)
            sum += vec[d] * vec[d];
        for (size_t d = 0; d < ndims; ++d)
            vec[d] = vec[d] * norm / std::sqrt(sum);

        for (size_t d = 0; d < ndims; ++d)
        {
            vec_T[d] = 128 + (int8_t)std::round(vec[d]);
        }

        writer.write((char *)vec_T, ndims * sizeof(uint8_t));
    }

    delete[] vec;
    delete[] vec_T;
    return 0;
}

int main(int argc, char **argv)
{
    std::string data_type, output_file;
    size_t ndims, npts;
    float norm, rand_scaling;
    bool normalization = false;
    try
    {
        po::options_description desc{"Arguments"};

        desc.add_options()("help,h", "Print information on arguments");

        desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
        desc.add_options()("output_file", po::value<std::string>(&output_file)->required(),
                           "File name for saving the random vectors");
        desc.add_options()("ndims,D", po::value<uint64_t>(&ndims)->required(), "Dimensionality of the vector");
        desc.add_options()("npts,N", po::value<uint64_t>(&npts)->required(), "Number of vectors");
        desc.add_options()("norm", po::value<float>(&norm)->default_value(-1.0f),
                           "Norm of the vectors (if not specified, vectors are not normalized)");
        desc.add_options()("rand_scaling", po::value<float>(&rand_scaling)->default_value(1.0f),
                           "Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from "
                           "[1, rand_scale]. Only applicable for floating point data");
        po::variables_map vm;
        po::store(po::parse_command_line(argc, argv, desc), vm);
        if (vm.count("help"))
        {
            std::cout << desc;
            return 0;
        }
        po::notify(vm);
    }
    catch (const std::exception &ex)
    {
        std::cerr << ex.what() << '\n';
        return -1;
    }

    if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
    {
        std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
        return -1;
    }

    if (norm > 0.0)
    {
        normalization = true;
    }

    if (rand_scaling < 1.0)
    {
        std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl;
        return -1;
    }

    if ((rand_scaling > 1.0) && (normalization == true))
    {
        std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl;
        return -1;
    }

    if (data_type == std::string("int8") || data_type == std::string("uint8"))
    {
        if (norm > 127)
        {
            std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be "
                         "greater "
                         "than 127"
                      << std::endl;
            return -1;
        }
        if (rand_scaling > 1.0)
        {
            std::cout << "Data scaling only supported for floating point data." << std::endl;
            return -1;
        }
    }

    try
    {
        std::ofstream writer;
        writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
        writer.open(output_file, std::ios::binary);
        auto npts_u32 = (uint32_t)npts;
        auto ndims_u32 = (uint32_t)ndims;
        writer.write((char *)&npts_u32, sizeof(uint32_t));
        writer.write((char *)&ndims_u32, sizeof(uint32_t));

        size_t blk_size = 131072;
        size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
        std::cout << "# blks: " << nblks << std::endl;

        int ret = 0;
        for (size_t i = 0; i < nblks; i++)
        {
            size_t cblk_size = std::min(npts - i * blk_size, blk_size);
            if (data_type == std::string("float"))
            {
                ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling);
            }
            else if (data_type == std::string("int8"))
            {
                ret = block_write_int8(writer, ndims, cblk_size, norm);
            }
            else if (data_type == std::string("uint8"))
            {
                ret = block_write_uint8(writer, ndims, cblk_size, norm);
            }
            if (ret == 0)
                std::cout << "Block #" << i << " written" << std::endl;
            else
            {
                writer.close();
                std::cout << "failed to write" << std::endl;
                return -1;
            }
        }
        writer.close();
    }
    catch (const std::exception &e)
    {
        std::cout << std::string(e.what()) << std::endl;
        diskann::cerr << "Index build failed." << std::endl;
        return -1;
    }

    return 0;
}
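
A note on the normalization in the writers above: after v = v * norm / ||v||, every vector has L2 norm exactly norm, so a typical coordinate has magnitude around norm / sqrt(ndims) (for example, norm = 100 at 128 dims gives coordinates near +/-8.8 before rounding). In the worst case a single coordinate can carry the whole norm, which is why the int8/uint8 paths reject norm > 127: each rounded coordinate must fit the signed byte range.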