Compare commits
13 Commits
master
...
debug_disk
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df63526503 | ||
|
|
e92deee1e8 | ||
|
|
910927a405 | ||
|
|
0aa84e147b | ||
|
|
368474d036 | ||
|
|
a627abe794 | ||
|
|
44815ee7fd | ||
|
|
371e3de04e | ||
|
|
b81b5d0f86 | ||
|
|
ee507bfe7a | ||
|
|
30898814ae | ||
|
|
a075fd6f47 | ||
|
|
303ff6fe1d |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,6 +29,7 @@ build/
|
||||
nprobe_logs/
|
||||
micro/results
|
||||
micro/contriever-INT8
|
||||
examples/data/
|
||||
*.qdstrm
|
||||
benchmark_results/
|
||||
results/
|
||||
|
||||
22
README.md
22
README.md
@@ -43,8 +43,9 @@ Traditional RAG systems face a fundamental trade-off:
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yichuan520030910320/Power-RAG.git leann
|
||||
git clone git@github.com:yichuan520030910320/LEANN-RAG.git leann
|
||||
cd leann
|
||||
git submodule update --init --recursive
|
||||
uv sync
|
||||
```
|
||||
|
||||
@@ -240,6 +241,25 @@ uv run python tests/sanity_checks/test_distance_functions.py
|
||||
# Verify L2 implementation
|
||||
uv run python tests/sanity_checks/test_l2_verification.py
|
||||
```
|
||||
## ❓ FAQ
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### NCCL Topology Error
|
||||
**Problem**: You encounter `ncclTopoComputePaths` error during document processing:
|
||||
```
|
||||
ncclTopoComputePaths (system=<optimized out>, comm=comm@entry=0x5555a82fa3c0) at graph/paths.cc:688
|
||||
```
|
||||
|
||||
**Solution**: Set these environment variables before running your script:
|
||||
```bash
|
||||
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
|
||||
export NCCL_DEBUG=INFO
|
||||
export NCCL_DEBUG_SUBSYS=INIT,GRAPH
|
||||
export NCCL_IB_DISABLE=1
|
||||
export NCCL_NET_PLUGIN=none
|
||||
export NCCL_SOCKET_IFNAME=ens5
|
||||
|
||||
|
||||
## 📈 Roadmap
|
||||
|
||||
|
||||
194
demo.ipynb
194
demo.ipynb
@@ -2,14 +2,31 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: LeannBuilder initialized with 'diskann' backend.\n",
|
||||
"Initializing leann-backend-diskann...\n",
|
||||
"INFO: Registering backend 'diskann'\n",
|
||||
"INFO: DiskANN backend loaded successfully\n",
|
||||
"INFO: LeannBuilder initialized with 'diskann' backend.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ubuntu/LEANN_clean/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: Computing embeddings for 6 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
||||
]
|
||||
},
|
||||
@@ -17,7 +34,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 77.61it/s]"
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 2.91it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -31,7 +48,7 @@
|
||||
"Writing bin: knowledge_disk.index_max_base_norm.bin\n",
|
||||
"bin: #pts = 1, #dims = 1, size = 12B\n",
|
||||
"Finished writing bin.\n",
|
||||
"Time for preprocessing data for inner product: 0.000165 seconds\n",
|
||||
"Time for preprocessing data for inner product: 0.000172 seconds\n",
|
||||
"Reading max_norm_of_base from knowledge_disk.index_max_base_norm.bin\n",
|
||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
||||
@@ -41,7 +58,7 @@
|
||||
"! Using prepped_base file at knowledge_prepped_base.bin\n",
|
||||
"Starting index build: R=32 L=64 Query RAM budget: 4.02653e+09 Indexing ram budget: 8 T: 8\n",
|
||||
"getting bin metadata\n",
|
||||
"Time for getting bin metadata: 0.000008 seconds\n",
|
||||
"Time for getting bin metadata: 0.000019 seconds\n",
|
||||
"Compressing 769-dimensional data into 512 bytes per vector.\n",
|
||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
||||
"Training data with 6 samples loaded.\n",
|
||||
@@ -69,17 +86,17 @@
|
||||
"done.\n",
|
||||
"Loaded PQ pivot information\n",
|
||||
"Processing points [0, 6)...done.\n",
|
||||
"Time for generating quantized data: 0.023918 seconds\n",
|
||||
"Time for generating quantized data: 0.055587 seconds\n",
|
||||
"Full index fits in RAM budget, should consume at most 2.03973e-05GiBs, so building in one shot\n",
|
||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
||||
"Passed, empty search_params while creating index config\n",
|
||||
"Using only first 6 from file.. \n",
|
||||
"Starting index build with 6 points... \n",
|
||||
"0% of index build completed.Starting final cleanup..done. Link time: 9e-05s\n",
|
||||
"0% of index build completed.Starting final cleanup..done. Link time: 0.00011s\n",
|
||||
"Index built with degree: max:5 avg:5 min:5 count(deg<2):0\n",
|
||||
"Not saving tags as they are not enabled.\n",
|
||||
"Time taken for save: 0.000178s.\n",
|
||||
"Time for building merged vamana index: 0.000579 seconds\n",
|
||||
"Time taken for save: 0.000148s.\n",
|
||||
"Time for building merged vamana index: 0.000836 seconds\n",
|
||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
||||
"Vamana index file size=168\n",
|
||||
"Opened: knowledge_disk.index, cache_size: 67108864\n",
|
||||
@@ -94,11 +111,11 @@
|
||||
"Finished writing bin.\n",
|
||||
"Output disk index file written to knowledge_disk.index\n",
|
||||
"Finished writing 28672B\n",
|
||||
"Time for generating disk layout: 0.043488 seconds\n",
|
||||
"Time for generating disk layout: 0.040268 seconds\n",
|
||||
"Opened: knowledge_prepped_base.bin, size: 18464, cache_size: 18464\n",
|
||||
"Loading base knowledge_prepped_base.bin. #points: 6. #dim: 769.\n",
|
||||
"Wrote 1 points to sample file: knowledge_sample_data.bin\n",
|
||||
"Indexing time: 0.0684344\n",
|
||||
"Indexing time: 0.0970594\n",
|
||||
"INFO: Leann metadata saved to knowledge.leann.meta.json\n"
|
||||
]
|
||||
},
|
||||
@@ -106,7 +123,6 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Opened file : knowledge_disk.index\n"
|
||||
]
|
||||
},
|
||||
@@ -114,12 +130,12 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"✅ DiskANN index loaded successfully.\n",
|
||||
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
|
||||
"Since data is floating point, we assume that it has been appropriately pre-processed (normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we shall invoke an l2 distance function.\n",
|
||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
||||
"L2: Using AVX2 distance computation DistanceL2Float\n",
|
||||
"Before index load\n",
|
||||
"✅ DiskANN index loaded successfully.\n",
|
||||
"INFO: LeannSearcher initialized with 'diskann' backend using index 'knowledge.leann'.\n",
|
||||
"Reading bin file knowledge_pq_compressed.bin ...\n",
|
||||
"Opening bin file knowledge_pq_compressed.bin... \n",
|
||||
"Metadata: #pts = 6, #dims = 512...\n",
|
||||
@@ -147,14 +163,14 @@
|
||||
"Disk-Index File Meta-data: # nodes per sector: 1, max node len (bytes): 3100, max node degree: 5\n",
|
||||
"Disk-Index Meta: nodes per sector: 1, max node len: 3100, max node degree: 5\n",
|
||||
"Setting up thread-specific contexts for nthreads: 8\n",
|
||||
"allocating ctx: 0x78348f4de000 to thread-id:132170359560000\n",
|
||||
"allocating ctx: 0x78348f4cd000 to thread-id:132158431693760\n",
|
||||
"allocating ctx: 0x78348f4bc000 to thread-id:132158442179392\n",
|
||||
"allocating ctx: 0x78348f4ab000 to thread-id:132158421208128\n",
|
||||
"allocating ctx: 0x78348f49a000 to thread-id:132158452665024\n",
|
||||
"allocating ctx: 0x78348f489000 to thread-id:132158389751232\n",
|
||||
"allocating ctx: 0x78348f478000 to thread-id:132158410722496\n",
|
||||
"allocating ctx: 0x78348f467000 to thread-id:132158400236864\n",
|
||||
"allocating ctx: 0x7a33f7204000 to thread-id:134367072315200\n",
|
||||
"allocating ctx: 0x7a33f6805000 to thread-id:134355206802368\n",
|
||||
"allocating ctx: 0x7a33f5e72000 to thread-id:134355217288000\n",
|
||||
"allocating ctx: 0x7a33f5e61000 to thread-id:134355227773632\n",
|
||||
"allocating ctx: 0x7a33f5e50000 to thread-id:134355196316736\n",
|
||||
"allocating ctx: 0x7a33f5e3f000 to thread-id:134355164859840\n",
|
||||
"allocating ctx: 0x7a33f5e2e000 to thread-id:134355175345472\n",
|
||||
"allocating ctx: 0x7a33f5e1d000 to thread-id:134355185831104\n",
|
||||
"Loading centroid data from medoids vector data of 1 medoid(s)\n",
|
||||
"Reading bin file knowledge_disk.index_max_base_norm.bin ...\n",
|
||||
"Opening bin file knowledge_disk.index_max_base_norm.bin... \n",
|
||||
@@ -167,7 +183,6 @@
|
||||
"Graph traversal completed, hops: 3\n",
|
||||
"Loading the cache list into memory....done.\n",
|
||||
"After index load\n",
|
||||
"Clearing scratch\n",
|
||||
"INFO: Computing embeddings for 1 chunks using 'sentence-transformers/all-mpnet-base-v2'...\n"
|
||||
]
|
||||
},
|
||||
@@ -175,15 +190,17 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 92.66it/s]"
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 60.54it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Score: -0.481 - C++ is a powerful programming language\n",
|
||||
"Score: -1.049 - Java is a powerful programming language\n"
|
||||
"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running\n",
|
||||
"INFO: Starting session-level embedding server as a background process...\n",
|
||||
"INFO: Running command from project root: /home/ubuntu/LEANN_clean/leann\n",
|
||||
"INFO: Server process started with PID: 424761\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -197,8 +214,127 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"reserve ratio: 1\n",
|
||||
"Graph traversal completed, hops: 3\n"
|
||||
"✅ Embedding server is up and ready for this session.\n",
|
||||
"[EmbeddingServer LOG]: Initializing leann-backend-diskann...\n",
|
||||
"[EmbeddingServer LOG]: WARNING: Could not import DiskANN backend: cannot import name '_diskannpy' from partially initialized module 'packages.leann-backend-diskann.leann_backend_diskann' (most likely due to a circular import) (/home/ubuntu/LEANN_clean/leann/packages/leann-backend-diskann/leann_backend_diskann/__init__.py)\n",
|
||||
"[EmbeddingServer LOG]: INFO: Initializing embedding server thread on port 5555\n",
|
||||
"[EmbeddingServer LOG]: INFO: Using CUDA device\n",
|
||||
"[EmbeddingServer LOG]: INFO: Loading model sentence-transformers/all-mpnet-base-v2\n",
|
||||
"[EmbeddingServer LOG]: INFO: Using FP16 precision with model: sentence-transformers/all-mpnet-base-v2\n",
|
||||
"[EmbeddingServer LOG]: INFO: Loaded 6 demo documents\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ ROUTER server listening on port 5555\n",
|
||||
"[EmbeddingServer LOG]: INFO: Embedding server ready to serve requests\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 3 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 1 node embeddings: [0]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 0\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000028 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 1, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 1\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.019294 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 1, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000210 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.065444 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.041810 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000194 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.128073 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 2, 3, 4, 5]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 1 to 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000042 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001791 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000112 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 3.674183 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000372 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000177 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 3.677425 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 4, 2, 1, 0]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 4\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000030 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001550 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000097 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.009335 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000154 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000073 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011773 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 1, 2, 4, 5]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001041 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000125 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008972 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000151 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000048 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010853 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [3, 1, 0, 2, 5]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001350 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000088 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008869 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000146 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000063 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.011054 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [0, 2, 3, 4, 5]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000022 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001195 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008903 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000145 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000060 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010921 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Received ZMQ request from client 006b8b45, size 7 bytes\n",
|
||||
"[EmbeddingServer LOG]: INFO: Request for 5 node embeddings: [1, 0, 3, 4, 5]\n",
|
||||
"[EmbeddingServer LOG]: DEBUG: Node ID range: 0 to 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for text lookup: 0.000020 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Total batch size: 5, max_batch_size: 128\n",
|
||||
"[EmbeddingServer LOG]: INFO: Processing batch of size 5\n",
|
||||
"[EmbeddingServer LOG]: Time taken for tokenization (batch): 0.001188 seconds\n",
|
||||
"[EmbeddingServer LOG]: Batch size: 5, Sequence length: 256\n",
|
||||
"[EmbeddingServer LOG]: Time taken for transfer to device (batch): 0.000087 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for embedding (batch): 0.008858 seconds\n",
|
||||
"[EmbeddingServer LOG]: Time taken for mean pooling (batch): 0.000153 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: Serialize time: 0.000052 seconds\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ E2E time: 0.010886 seconds\n",
|
||||
"reserve ratio: Score: -0.481 - C++ is a powerful programming language1\n",
|
||||
"Graph traversal completed, hops: 3\n",
|
||||
"\n",
|
||||
"Score: -1.049 - Java is a powerful programming language\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n",
|
||||
"[EmbeddingServer LOG]: INFO: ZMQ socket timeout, continuing to listen\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -217,7 +353,7 @@
|
||||
"\n",
|
||||
"# 2. Search with real-time embeddings\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"results = searcher.search(\"C++ programming languages\", top_k=2)\n",
|
||||
"results = searcher.search(\"C++ programming languages\", top_k=2,recompute_beighbor_embeddings=True)\n",
|
||||
"\n",
|
||||
"for result in results:\n",
|
||||
" print(f\"Score: {result['score']:.3f} - {result['text']}\")"
|
||||
|
||||
BIN
examples/data/FairTree__OSDI_25_ (1).pdf
Normal file
BIN
examples/data/FairTree__OSDI_25_ (1).pdf
Normal file
Binary file not shown.
@@ -1,3 +1,6 @@
|
||||
import faulthandler
|
||||
faulthandler.enable()
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, Settings
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.node_parser.docling import DoclingNodeParser
|
||||
@@ -7,7 +10,7 @@ import asyncio
|
||||
import os
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
||||
import leann_backend_diskann # Import to ensure backend registration
|
||||
import leann_backend_hnsw # Import to ensure backend registration
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
@@ -21,9 +24,9 @@ file_extractor: dict[str, BaseReader] = {
|
||||
".xlsx": reader,
|
||||
}
|
||||
node_parser = DoclingNodeParser(
|
||||
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=10240)
|
||||
chunker=HybridChunker(tokenizer="Qwen/Qwen3-Embedding-4B", max_tokens=64)
|
||||
)
|
||||
|
||||
print("Loading documents...")
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
recursive=True,
|
||||
@@ -31,11 +34,9 @@ documents = SimpleDirectoryReader(
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".docx", ".pptx", ".xlsx"]
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Extract text from documents and prepare for Leann
|
||||
print("Documents loaded.")
|
||||
all_texts = []
|
||||
for doc in documents:
|
||||
# DoclingNodeParser returns Node objects, which have a text attribute
|
||||
nodes = node_parser.get_nodes_from_documents([doc])
|
||||
for node in nodes:
|
||||
all_texts.append(node.text)
|
||||
@@ -43,33 +44,37 @@ for doc in documents:
|
||||
INDEX_DIR = Path("./test_pdf_index")
|
||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
||||
|
||||
if INDEX_DIR.exists():
|
||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
||||
shutil.rmtree(INDEX_DIR)
|
||||
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name="diskann",
|
||||
embedding_model="sentence-transformers/all-mpnet-base-v2", # Using a common sentence transformer model
|
||||
graph_degree=32,
|
||||
complexity=64
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
if not INDEX_DIR.exists():
|
||||
print(f"--- Index directory not found, building new index ---")
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
print(f"\n[PHASE 1] Building Leann index...")
|
||||
|
||||
# CSR compact mode with recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model="facebook/contriever",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_compact=True,
|
||||
is_recompute=True
|
||||
)
|
||||
|
||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
||||
for chunk_text in all_texts:
|
||||
builder.add_text(chunk_text)
|
||||
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
||||
else:
|
||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
||||
|
||||
async def main():
|
||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
||||
chat = LeannChat(index_path=INDEX_PATH)
|
||||
|
||||
query = "Based on the paper, what are the two main techniques LEANN uses to achieve low storage overhead and high retrieval accuracy?"
|
||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
||||
print(f"You: {query}")
|
||||
chat_response = chat.ask(query, recompute_beighbor_embeddings=True)
|
||||
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
|
||||
print(f"Leann: {chat_response}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -3,11 +3,17 @@ Simple demo showing basic leann usage
|
||||
Run: uv run python examples/simple_demo.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
||||
|
||||
|
||||
def main():
|
||||
print("=== Leann Simple Demo ===")
|
||||
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
|
||||
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||
print()
|
||||
|
||||
# Sample knowledge base
|
||||
@@ -24,10 +30,11 @@ def main():
|
||||
|
||||
print("1. Building index (no embeddings stored)...")
|
||||
builder = LeannBuilder(
|
||||
embedding_model="sentence-transformers/all-mpnet-base-v2",
|
||||
prune_ratio=0.7, # Keep 30% of connections
|
||||
embedding_model=args.embedding_model,
|
||||
backend_name="hnsw",
|
||||
)
|
||||
builder.add_chunks(chunks)
|
||||
for chunk in chunks:
|
||||
builder.add_text(chunk)
|
||||
builder.build_index("demo_knowledge.leann")
|
||||
print()
|
||||
|
||||
@@ -49,14 +56,7 @@ def main():
|
||||
print(f" Text: {result.text[:100]}...")
|
||||
print()
|
||||
|
||||
print("3. Memory stats:")
|
||||
stats = searcher.get_memory_stats()
|
||||
print(f" Cache size: {stats.embedding_cache_size}")
|
||||
print(f" Cache memory: {stats.embedding_cache_memory_mb:.1f} MB")
|
||||
print(f" Total chunks: {stats.total_chunks}")
|
||||
print()
|
||||
|
||||
print("4. Interactive chat demo:")
|
||||
print("3. Interactive chat demo:")
|
||||
print(" (Note: Requires OpenAI API key for real responses)")
|
||||
|
||||
chat = LeannChat("demo_knowledge.leann")
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
print("Initializing leann-backend-diskann...")
|
||||
|
||||
try:
|
||||
from .diskann_backend import DiskannBackend
|
||||
print("INFO: DiskANN backend loaded successfully")
|
||||
except ImportError as e:
|
||||
print(f"WARNING: Could not import DiskANN backend: {e}")
|
||||
@@ -76,8 +76,8 @@ class EmbeddingServerManager:
|
||||
self.server_process = subprocess.Popen(
|
||||
command,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
# stdout=subprocess.PIPE,
|
||||
# stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
@@ -143,20 +143,16 @@ class DiskannBackend(LeannBackendFactoryInterface):
|
||||
path = Path(index_path)
|
||||
meta_path = path.parent / f"{path.name}.meta.json"
|
||||
if not meta_path.exists():
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
|
||||
|
||||
with open(meta_path, 'r') as f:
|
||||
meta = json.load(f)
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer(meta.get("embedding_model"))
|
||||
dimensions = model.get_sentence_embedding_dimension()
|
||||
kwargs['dimensions'] = dimensions
|
||||
except ImportError:
|
||||
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
|
||||
|
||||
dimensions = meta.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
|
||||
|
||||
kwargs['dimensions'] = dimensions
|
||||
return DiskannSearcher(index_path, **kwargs)
|
||||
|
||||
class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
@@ -246,7 +242,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
raise
|
||||
|
||||
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
|
||||
complexity = kwargs.get("complexity", 100)
|
||||
complexity = kwargs.get("complexity", 256)
|
||||
beam_width = kwargs.get("beam_width", 4)
|
||||
|
||||
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
|
||||
@@ -259,7 +255,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
if recompute_beighbor_embeddings:
|
||||
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
|
||||
zmq_port = kwargs.get("zmq_port", 5555)
|
||||
zmq_port = kwargs.get("zmq_port", 6666)
|
||||
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
|
||||
|
||||
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
|
||||
|
||||
1
packages/leann-backend-diskann/third_party/DiskANN
vendored
Submodule
1
packages/leann-backend-diskann/third_party/DiskANN
vendored
Submodule
Submodule packages/leann-backend-diskann/third_party/DiskANN added at 015c201141
@@ -1,6 +0,0 @@
|
||||
---
|
||||
BasedOnStyle: Microsoft
|
||||
---
|
||||
Language: Cpp
|
||||
SortIncludes: false
|
||||
...
|
||||
@@ -1,14 +0,0 @@
|
||||
# Set the default behavior, in case people don't have core.autocrlf set.
|
||||
* text=auto
|
||||
|
||||
# Explicitly declare text files you want to always be normalized and converted
|
||||
# to native line endings on checkout.
|
||||
*.c text
|
||||
*.h text
|
||||
|
||||
# Declare files that will always have CRLF line endings on checkout.
|
||||
*.sln text eol=crlf
|
||||
|
||||
# Denote all files that are truly binary and should not be modified.
|
||||
*.png binary
|
||||
*.jpg binary
|
||||
@@ -1,40 +0,0 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Bug reports help us improve! Thanks for submitting yours!
|
||||
title: "[BUG] "
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Expected Behavior
|
||||
Tell us what should happen
|
||||
|
||||
## Actual Behavior
|
||||
Tell us what happens instead
|
||||
|
||||
## Example Code
|
||||
Please see [How to create a Minimal, Reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) for some guidance on creating the best possible example of the problem
|
||||
```bash
|
||||
|
||||
```
|
||||
|
||||
## Dataset Description
|
||||
Please tell us about the shape and datatype of your data, (e.g. 128 dimensions, 12.3 billion points, floats)
|
||||
- Dimensions:
|
||||
- Number of Points:
|
||||
- Data type:
|
||||
|
||||
## Error
|
||||
```
|
||||
Paste the full error, with any sensitive information minimally redacted and marked $$REDACTED$$
|
||||
|
||||
```
|
||||
|
||||
## Your Environment
|
||||
* Operating system (e.g. Windows 11 Pro, Ubuntu 22.04.1 LTS)
|
||||
* DiskANN version (or commit built from)
|
||||
|
||||
## Additional Details
|
||||
Any other contextual information you might feel is important.
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
blank_issues_enabled: false
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Is your feature request related to a problem? Please describe.
|
||||
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
|
||||
|
||||
## Describe the solution you'd like
|
||||
A clear and concise description of what you want to happen.
|
||||
|
||||
## Describe alternatives you've considered
|
||||
A clear and concise description of any alternative solutions or features you've considered.
|
||||
|
||||
## Provide references (if applicable)
|
||||
If your feature request is related to a published algorithm/idea, please provide links to
|
||||
any relevant articles or webpages.
|
||||
|
||||
## Additional context
|
||||
Add any other context or screenshots about the feature request here.
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
---
|
||||
name: Usage Question
|
||||
about: Ask us a question about DiskANN!
|
||||
title: "[Question]"
|
||||
labels: question
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
This is our forum for asking whatever DiskANN question you'd like! No need to feel shy - we're happy to talk about use cases and optimal tuning strategies!
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
<!--
|
||||
Thanks for contributing a pull request! Please ensure you have taken a look at
|
||||
the contribution guidelines: https://github.com/microsoft/DiskANN/blob/main/CONTRIBUTING.md
|
||||
-->
|
||||
- [ ] Does this PR have a descriptive title that could go in our release notes?
|
||||
- [ ] Does this PR add any new dependencies?
|
||||
- [ ] Does this PR modify any existing APIs?
|
||||
- [ ] Is the change to the API backwards compatible?
|
||||
- [ ] Should this result in any changes to our documentation, either updating existing docs or adding new ones?
|
||||
|
||||
#### Reference Issues/PRs
|
||||
<!--
|
||||
Example: Fixes #1234. See also #3456.
|
||||
Please use keywords (e.g., Fixes) to create link to the issues or pull requests
|
||||
you resolved, so that they will automatically be closed when your pull request
|
||||
is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests
|
||||
-->
|
||||
|
||||
#### What does this implement/fix? Briefly explain your changes.
|
||||
|
||||
#### Any other comments?
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
name: 'DiskANN Build Bootstrap'
|
||||
description: 'Prepares DiskANN build environment and executes build'
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
# ------------ Linux Build ---------------
|
||||
- name: Prepare and Execute Build
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
run: |
|
||||
sudo scripts/dev/install-dev-deps-ubuntu.bash
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
|
||||
cmake --build build -- -j
|
||||
cmake --install build --prefix="dist"
|
||||
shell: bash
|
||||
# ------------ End Linux Build ---------------
|
||||
# ------------ Windows Build ---------------
|
||||
- name: Add VisualStudio command line tools into path
|
||||
if: runner.os == 'Windows'
|
||||
uses: ilammy/msvc-dev-cmd@v1
|
||||
- name: Run configure and build for Windows
|
||||
if: runner.os == 'Windows'
|
||||
run: |
|
||||
mkdir build && cd build && cmake .. -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
|
||||
cd ..
|
||||
mkdir dist
|
||||
mklink /j .\dist\bin .\x64\Release\
|
||||
shell: cmd
|
||||
# ------------ End Windows Build ---------------
|
||||
# ------------ Windows Build With EXEC_ENV_OLS and USE_BING_INFRA ---------------
|
||||
- name: Add VisualStudio command line tools into path
|
||||
if: runner.os == 'Windows'
|
||||
uses: ilammy/msvc-dev-cmd@v1
|
||||
- name: Run configure and build for Windows with Bing feature flags
|
||||
if: runner.os == 'Windows'
|
||||
run: |
|
||||
mkdir build_bing && cd build_bing && cmake .. -DEXEC_ENV_OLS=1 -DUSE_BING_INFRA=1 -DUNIT_TEST=True && msbuild diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64" -consoleloggerparameters:"ErrorsOnly;Summary"
|
||||
cd ..
|
||||
shell: cmd
|
||||
# ------------ End Windows Build ---------------
|
||||
@@ -1,13 +0,0 @@
|
||||
name: 'Checking code formatting...'
|
||||
description: 'Ensures code complies with code formatting rules'
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Checking code formatting...
|
||||
run: |
|
||||
sudo apt install clang-format
|
||||
find include -name '*.h' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
||||
find src -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
||||
find apps -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
||||
find python -name '*.cpp' -type f -print0 | xargs -0 -P 16 /usr/bin/clang-format --Werror --dry-run
|
||||
shell: bash
|
||||
@@ -1,28 +0,0 @@
|
||||
name: 'Generating Random Data (Basic)'
|
||||
description: 'Generates the random data files used in acceptance tests'
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Generate Random Data (Basic)
|
||||
run: |
|
||||
mkdir data
|
||||
|
||||
echo "Generating random 1020,1024,1536D float and 4096 int8 vectors for index"
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_5K_norm1.0.bin -D 1020 -N 5000 --norm 1.0
|
||||
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_5K_norm1.0.bin -D 1024 -N 5000 --norm 1.0
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_5K_norm1.0.bin -D 1536 -N 5000 --norm 1.0
|
||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_5K_norm1.0.bin -D 4096 -N 5000 --norm 1.0
|
||||
|
||||
echo "Generating random 1020,1024,1536D float and 4096D int8 avectors for query"
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1020D_1K_norm1.0.bin -D 1020 -N 1000 --norm 1.0
|
||||
#dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1024D_1K_norm1.0.bin -D 1024 -N 1000 --norm 1.0
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_1536D_1K_norm1.0.bin -D 1536 -N 1000 --norm 1.0
|
||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_4096D_1K_norm1.0.bin -D 4096 -N 1000 --norm 1.0
|
||||
|
||||
echo "Computing ground truth for 1020,1024,1536D float and 4096D int8 avectors for query"
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1020D_5K_norm1.0.bin --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --K 100
|
||||
#dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1024D_5K_norm1.0.bin --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_1536D_5K_norm1.0.bin --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_4096D_5K_norm1.0.bin --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --K 100
|
||||
|
||||
shell: bash
|
||||
@@ -1,38 +0,0 @@
|
||||
name: 'Generating Random Data (Basic)'
|
||||
description: 'Generates the random data files used in acceptance tests'
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Generate Random Data (Basic)
|
||||
run: |
|
||||
mkdir data
|
||||
|
||||
echo "Generating random vectors for index"
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_norm1.0.bin -D 10 -N 10000 --norm 1.0
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_10K_unnorm.bin -D 10 -N 10000 --rand_scaling 2.0
|
||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
||||
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
||||
|
||||
echo "Generating random vectors for query"
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_norm1.0.bin -D 10 -N 1000 --norm 1.0
|
||||
dist/bin/rand_data_gen --data_type float --output_file data/rand_float_10D_1K_unnorm.bin -D 10 -N 1000 --rand_scaling 2.0
|
||||
dist/bin/rand_data_gen --data_type int8 --output_file data/rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
||||
dist/bin/rand_data_gen --data_type uint8 --output_file data/rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
||||
|
||||
echo "Computing ground truth for floats across l2, mips, and cosine distance functions"
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn mips --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_norm1.0.bin --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn cosine --base_file data/rand_float_10D_10K_unnorm.bin --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --K 100
|
||||
|
||||
echo "Computing ground truth for int8s across l2, mips, and cosine distance functions"
|
||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn mips --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/mips_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn cosine --base_file data/rand_int8_10D_10K_norm50.0.bin --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
|
||||
echo "Computing ground truth for uint8s across l2, mips, and cosine distance functions"
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn mips --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
|
||||
shell: bash
|
||||
@@ -1,22 +0,0 @@
|
||||
name: Build Python Wheel
|
||||
description: Builds a python wheel with cibuildwheel
|
||||
inputs:
|
||||
cibw-identifier:
|
||||
description: "CI build wheel identifier to build"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- uses: actions/setup-python@v3
|
||||
- name: Install cibuildwheel
|
||||
run: python -m pip install cibuildwheel==2.11.3
|
||||
shell: bash
|
||||
- name: Building Python ${{inputs.cibw-identifier}} Wheel
|
||||
run: python -m cibuildwheel --output-dir dist
|
||||
env:
|
||||
CIBW_BUILD: ${{inputs.cibw-identifier}}
|
||||
shell: bash
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: wheels
|
||||
path: ./dist/*.whl
|
||||
@@ -1,81 +0,0 @@
|
||||
name: DiskANN Build PDoc Documentation
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
build-reference-documentation:
|
||||
permissions:
|
||||
contents: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Set up Python 3.9
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.9
|
||||
- name: Install python build
|
||||
run: python -m pip install build
|
||||
shell: bash
|
||||
# Install required dependencies
|
||||
- name: Prepare Linux environment
|
||||
run: |
|
||||
sudo scripts/dev/install-dev-deps-ubuntu.bash
|
||||
shell: bash
|
||||
# We need to build the wheel in order to run pdoc. pdoc does not seem to work if you just point it at
|
||||
# our source directory.
|
||||
- name: Building Python Wheel for documentation generation
|
||||
run: python -m build --wheel --outdir documentation_dist
|
||||
shell: bash
|
||||
- name: "Run Reference Documentation Generation"
|
||||
run: |
|
||||
pip install pdoc pipdeptree
|
||||
pip install documentation_dist/*.whl
|
||||
echo "documentation" > dependencies_documentation.txt
|
||||
pipdeptree >> dependencies_documentation.txt
|
||||
pdoc -o docs/python/html diskannpy
|
||||
- name: Create version environment variable
|
||||
run: |
|
||||
echo "DISKANN_VERSION=$(python <<EOF
|
||||
from importlib.metadata import version
|
||||
v = version('diskannpy')
|
||||
print(v)
|
||||
EOF
|
||||
)" >> $GITHUB_ENV
|
||||
- name: Archive documentation version artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dependencies
|
||||
path: |
|
||||
${{ github.run_id }}-dependencies_documentation.txt
|
||||
overwrite: true
|
||||
- name: Archive documentation artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: documentation-site
|
||||
path: |
|
||||
docs/python/html
|
||||
# Publish to /dev if we are on the "main" branch
|
||||
- name: Publish reference docs for latest development version (main branch)
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
if: github.ref == 'refs/heads/main'
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: docs/python/html
|
||||
destination_dir: docs/python/dev
|
||||
# Publish to /<version> if we are releasing
|
||||
- name: Publish reference docs by version (main branch)
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
if: github.event_name == 'release'
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: docs/python/html
|
||||
destination_dir: docs/python/${{ env.DISKANN_VERSION }}
|
||||
# Publish to /latest if we are releasing
|
||||
- name: Publish latest reference docs (main branch)
|
||||
uses: peaceiris/actions-gh-pages@v3
|
||||
if: github.event_name == 'release'
|
||||
with:
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
publish_dir: docs/python/html
|
||||
destination_dir: docs/python/latest
|
||||
@@ -1,42 +0,0 @@
|
||||
name: DiskANN Build Python Wheel
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
linux-build:
|
||||
name: Python - Ubuntu - ${{matrix.cibw-identifier}}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
cibw-identifier: ["cp39-manylinux_x86_64", "cp310-manylinux_x86_64", "cp311-manylinux_x86_64"]
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Building python wheel ${{matrix.cibw-identifier}}
|
||||
uses: ./.github/actions/python-wheel
|
||||
with:
|
||||
cibw-identifier: ${{matrix.cibw-identifier}}
|
||||
windows-build:
|
||||
name: Python - Windows - ${{matrix.cibw-identifier}}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
cibw-identifier: ["cp39-win_amd64", "cp310-win_amd64", "cp311-win_amd64"]
|
||||
runs-on: windows-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
- name: Building python wheel ${{matrix.cibw-identifier}}
|
||||
uses: ./.github/actions/python-wheel
|
||||
with:
|
||||
cibw-identifier: ${{matrix.cibw-identifier}}
|
||||
@@ -1,28 +0,0 @@
|
||||
name: DiskANN Common Checks
|
||||
# common means common to both pr-test and push-test
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
formatting-check:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
name: Code Formatting Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checking code formatting...
|
||||
uses: ./.github/actions/format-check
|
||||
docker-container-build:
|
||||
name: Docker Container Build
|
||||
needs: [formatting-check]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Docker build
|
||||
run: |
|
||||
docker build .
|
||||
@@ -1,117 +0,0 @@
|
||||
name: Disk With PQ
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-disk-pq:
|
||||
name: Disk, PQ
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, cosine, no diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
||||
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ) (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16\
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
|
||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (sharded graph build, cosine, no diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_unnorm.bin --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
||||
dist/bin/search_disk_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/disk_index_cosine_rand_float_10D_10K_unnorm_diskfull_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_unnorm.bin --gt_file data/cosine_rand_float_10D_10K_unnorm_10D_1K_unnorm_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (int8)
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (sharded graph build, L2, no diskPQ) (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
|
||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_10D_10K_norm1.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, diskPQ) (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
|
||||
- name: build and search disk index (sharded graph build, MIPS, diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 --PQ_disk_bytes 5
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_mips_rand_float_10D_10K_norm1.0_diskpq_sharded --result_path /tmp/res --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
|
||||
- name: upload data and bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: disk-pq-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,102 +0,0 @@
|
||||
name: Dynamic-Labels
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-dynamic:
|
||||
name: Dynamic-Labels
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: Generate Labels
|
||||
run: |
|
||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
||||
|
||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
||||
|
||||
- name: Test a streaming index (float) with labels (Zipf distributed)
|
||||
run: |
|
||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_zipf_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
|
||||
|
||||
echo "Computing groundtruth with filter"
|
||||
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 --label_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
echo "Searching with filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
||||
|
||||
echo "Computing groundtruth w/o filter"
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_base-act4000-cons2000-max10000
|
||||
echo "Searching without filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
|
||||
|
||||
- name: Test a streaming index (float) with labels (random distributed)
|
||||
run: |
|
||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --universal_label 0 --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_rand_stream -R 64 --FilteredLbuild 200 -L 50 --alpha 1.2 --insert_threads 8 --consolidate_threads 8 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2 --unique_labels_supported 51
|
||||
|
||||
echo "Computing groundtruth with filter"
|
||||
dist/bin/compute_groundtruth_for_filters --data_type float --universal_label 0 --filter_label 1 --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 --label_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000_raw_labels.txt --tags_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
echo "Searching with filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 1 --fail_if_recall_below 40 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000_1 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
||||
|
||||
echo "Computing groundtruth w/o filter"
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_base-act4000-cons2000-max10000
|
||||
echo "Searching without filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64
|
||||
|
||||
- name: Test Insert Delete Consolidate (float) with labels (zipf distributed)
|
||||
run: |
|
||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/zipf_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_zipf_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
|
||||
|
||||
echo "Computing groundtruth with filter"
|
||||
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K_wlabel_5 --label_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
echo "Searching with filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 10 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_zipf_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
||||
|
||||
echo "Computing groundtruth w/o filter"
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_zipf_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_zipf_random10D_1K
|
||||
echo "Searching without filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_zipf_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_zipf_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
|
||||
|
||||
- name: Test Insert Delete Consolidate (float) with labels (random distributed)
|
||||
run: |
|
||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --universal_label 0 --label_file data/rand_labels_50_10K.txt --FilteredLbuild 70 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_rand_ins_del -R 64 -L 10 --alpha 1.2 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2 --unique_labels_supported 51
|
||||
|
||||
echo "Computing groundtruth with filter"
|
||||
dist/bin/compute_groundtruth_for_filters --data_type float --filter_label 5 --universal_label 0 --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K_wlabel_5 --label_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500_raw_labels.txt --tags_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
echo "Searching with filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --filter_label 5 --fail_if_recall_below 40 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_rand_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K_wlabel_5 -K 10 -L 20 40 60 80 100 150 -T 64 --dynamic true --tags 1
|
||||
|
||||
echo "Computing groundtruth w/o filter"
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_rand_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_rand_random10D_1K
|
||||
echo "Searching without filter"
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_rand_ins_del.after-concurrent-delete-del2500-7500 --result_path res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_rand_random10D_1K -K 10 -L 20 40 60 80 100 -T 64
|
||||
|
||||
- name: upload data and bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dynamic-labels-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,75 +0,0 @@
|
||||
name: Dynamic
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-dynamic:
|
||||
name: Dynamic
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: test a streaming index (float)
|
||||
run: |
|
||||
dist/bin/test_streaming_scenario --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 3.2
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
||||
- name: test a streaming index (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
||||
- name: test a streaming index
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_stream.after-streaming-act4000-cons2000-max10000.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_base-act4000-cons2000-max10000 --tags_file data/index_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_stream.after-streaming-act4000-cons2000-max10000 --result_path data/res_stream --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
||||
|
||||
- name: build and search an incremental index (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/test_insert_deletes_consolidate --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 3.2;
|
||||
dist/bin/compute_groundtruth --data_type float --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_float_10D_1K_norm1.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_float_10D_1K_norm1.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
||||
- name: build and search an incremental index (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
||||
dist/bin/compute_groundtruth --data_type int8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_1K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_int8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
||||
- name: build and search an incremental index (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/index_ins_del.after-concurrent-delete-del2500-7500.data --query_file data/rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file data/gt100_random10D_10K-conc-2500-7500 --tags_file data/index_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_ins_del.after-concurrent-delete-del2500-7500 --result_path data/res_ins_del --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
||||
|
||||
- name: upload data and bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dynamic-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,81 +0,0 @@
|
||||
name: In-Memory Without PQ
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-mem-no-pq:
|
||||
name: In-Mem, Without PQ
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: build and search in-memory index with L2 metrics (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with L2 metrics (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0
|
||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with L2 metrics (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
|
||||
- name: Searching with fast_l2 distance function (float)
|
||||
if: runner.os != 'Windows' && (success() || failure())
|
||||
run: |
|
||||
dist/bin/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
|
||||
- name: build and search in-memory index with MIPS metric (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type float --dist_fn mips --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_mips_rand_float_10D_10K_norm1.0
|
||||
dist/bin/search_memory_index --data_type float --dist_fn mips --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/mips_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
|
||||
- name: build and search in-memory index with cosine metric (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type float --dist_fn cosine --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_cosine_rand_float_10D_10K_norm1.0
|
||||
dist/bin/search_memory_index --data_type float --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with cosine metric (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type int8 --dist_fn cosine --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_int8_10D_10K_norm50.0
|
||||
dist/bin/search_memory_index --data_type int8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with cosine metric
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50.0
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
|
||||
- name: upload data and bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: in-memory-no-pq-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,56 +0,0 @@
|
||||
name: In-Memory With PQ
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-mem-pq:
|
||||
name: In-Mem, PQ
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: build and search in-memory index with L2 metric with PQ based distance comparisons (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type float --dist_fn l2 --data_path data/rand_float_10D_10K_norm1.0.bin --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --build_PQ_bytes 5
|
||||
dist/bin/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_float_10D_10K_norm1.0_buildpq5 --query_file data/rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
|
||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
||||
dist/bin/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_int8_10D_10K_norm50.0_buildpq5 --query_file data/rand_int8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
|
||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons (uint8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --data_path data/rand_uint8_10D_10K_norm50.0.bin --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
|
||||
- name: upload data and bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: in-memory-pq-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,120 +0,0 @@
|
||||
name: Labels
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-labels:
|
||||
name: Labels
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-random
|
||||
|
||||
- name: Generate Labels
|
||||
run: |
|
||||
echo "Generating synthetic labels and computing ground truth for filtered search with universal label"
|
||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/rand_labels_50_10K.txt --distribution_type random
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 10 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
|
||||
echo "Generating synthetic labels with a zipf distribution and computing ground truth for filtered search with universal label"
|
||||
dist/bin/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file data/zipf_labels_50_10K.txt --distribution_type zipf
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn mips --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/mips_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn cosine --universal_label 0 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
|
||||
echo "Generating synthetic labels and computing ground truth for filtered search without a universal label"
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --filter_label 5 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --K 100
|
||||
dist/bin/generate_synthetic_labels --num_labels 10 --num_points 1000 --output_file data/query_labels_1K.txt --distribution_type one_per_point
|
||||
dist/bin/compute_groundtruth_for_filters --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label_file data/query_labels_1K.txt --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
|
||||
- name: build and search in-memory index with labels using L2 and Cosine metrics (random distributed labels)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
|
||||
echo "Searching without filters"
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
||||
|
||||
- name: build and search disk index with labels using L2 and Cosine metrics (random distributed labels)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_rand_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search in-memory index with labels using L2 and Cosine metrics (zipf distributed labels)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn cosine --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
|
||||
echo "Searching without filters"
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
dist/bin/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file data/rand_uint8_10D_10K_norm50.0.bin --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix data/index_cosine_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/cosine_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 32 64
|
||||
|
||||
- name: build and search disk index with labels using L2 and Cosine metrics (zipf distributed labels)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --universal_label 0 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel -R 32 -L 5 -B 0.00003 -M 1
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 50 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
|
||||
- name : build and search in-memory and disk index (without universal label, zipf distributed)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal
|
||||
dist/bin/build_disk_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal -R 32 -L 5 -B 0.00003 -M 1
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 5 --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal -L 16 32
|
||||
dist/bin/search_disk_index --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/disk_index_l2_zipf_uint8_10D_10K_norm50_wlabel_nouniversal --result_path temp --query_file data/rand_uint8_10D_1K_norm50.0.bin --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel_nouniversal --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: Generate combined GT for each query with a separate label and search
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --query_filters_file data/query_labels_1K.txt --fail_if_recall_below 70 --index_path_prefix data/index_l2_zipf_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/combined_l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
- name: build and search in-memory index with pq_dist of 5 with 10 dimensions
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
|
||||
dist/bin/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix data/index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file data/rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
- name: Build and search stitched vamana with random and zipf distributed labels
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/rand_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_rand_32_100_64_new --universal_label 0
|
||||
dist/bin/build_stitched_index --num_threads 48 --data_type uint8 --data_path data/rand_uint8_10D_10K_norm50.0.bin --label_file data/zipf_labels_50_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix data/stit_zipf_32_100_64_new --universal_label 0
|
||||
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix data/stit_rand_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/rand_stit_96_10_90_new --gt_file data/l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
|
||||
dist/bin/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 5 --index_path_prefix data/stit_zipf_32_100_64_new --query_file data/rand_uint8_10D_1K_norm50.0.bin --result_path data/zipf_stit_96_10_90_new --gt_file data/l2_zipf_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 16 32 150
|
||||
|
||||
- name: upload data and bin
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: labels-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,60 +0,0 @@
|
||||
name: Disk With PQ
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-disk-pq:
|
||||
name: Disk, PQ
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Generate Data
|
||||
uses: ./.github/actions/generate-high-dim-random
|
||||
|
||||
- name: build and search disk index (1020D, one shot graph build, L2, no diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1020D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1020D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1020D_1K_norm1.0.bin --gt_file data/l2_rand_float_1020D_5K_norm1.0_1020D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
||||
#- name: build and search disk index (1024D, one shot graph build, L2, no diskPQ) (float)
|
||||
# if: success() || failure()
|
||||
# run: |
|
||||
# dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1024D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
||||
# dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1024D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1024D_1K_norm1.0.bin --gt_file data/l2_rand_float_1024D_5K_norm1.0_1024D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
||||
- name: build and search disk index (1536D, one shot graph build, L2, no diskPQ) (float)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type float --dist_fn l2 --data_path data/rand_float_1536D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
||||
dist/bin/search_disk_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_float_1536D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_float_1536D_1K_norm1.0.bin --gt_file data/l2_rand_float_1536D_5K_norm1.0_1536D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
||||
|
||||
- name: build and search disk index (4096D, one shot graph build, L2, no diskPQ) (int8)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
dist/bin/build_disk_index --data_type int8 --dist_fn l2 --data_path data/rand_int8_4096D_5K_norm1.0.bin --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot -R 32 -L 500 -B 0.003 -M 1
|
||||
dist/bin/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix data/disk_index_l2_rand_int8_4096D_5K_norm1.0_diskfull_oneshot --result_path /tmp/res --query_file data/rand_int8_4096D_1K_norm1.0.bin --gt_file data/l2_rand_int8_4096D_5K_norm1.0_4096D_1K_norm1.0_gt100 --recall_at 5 -L 250 -W 2 --num_nodes_to_cache 100 -T 16
|
||||
|
||||
- name: upload data and bin
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: multi-sector-disk-pq-${{matrix.os}}
|
||||
path: |
|
||||
./dist/**
|
||||
./data/**
|
||||
@@ -1,26 +0,0 @@
|
||||
name: DiskANN Nightly Performance Metrics
|
||||
on:
|
||||
schedule:
|
||||
- cron: "41 14 * * *" # 14:41 UTC, 7:41 PDT, 8:41 PST, 08:11 IST
|
||||
jobs:
|
||||
perf-test:
|
||||
name: Run Perf Test from main
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Build Perf Container
|
||||
run: |
|
||||
docker build --build-arg GIT_COMMIT_ISH="$GITHUB_SHA" -t perf -f scripts/perf/Dockerfile scripts
|
||||
- name: Performance Tests
|
||||
run: |
|
||||
mkdir metrics
|
||||
docker run -v ./metrics:/app/logs perf &> ./metrics/combined_stdouterr.log
|
||||
- name: Upload Metrics Logs
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: metrics-${{matrix.os}}
|
||||
path: |
|
||||
./metrics/**
|
||||
@@ -1,35 +0,0 @@
|
||||
name: DiskANN Pull Request Build and Test
|
||||
on: [pull_request]
|
||||
jobs:
|
||||
common:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
name: DiskANN Common Build Checks
|
||||
uses: ./.github/workflows/common.yml
|
||||
unit-tests:
|
||||
name: Unit tests
|
||||
uses: ./.github/workflows/unit-tests.yml
|
||||
in-mem-pq:
|
||||
name: In-Memory with PQ
|
||||
uses: ./.github/workflows/in-mem-pq.yml
|
||||
in-mem-no-pq:
|
||||
name: In-Memory without PQ
|
||||
uses: ./.github/workflows/in-mem-no-pq.yml
|
||||
disk-pq:
|
||||
name: Disk with PQ
|
||||
uses: ./.github/workflows/disk-pq.yml
|
||||
multi-sector-disk-pq:
|
||||
name: Multi-sector Disk with PQ
|
||||
uses: ./.github/workflows/multi-sector-disk-pq.yml
|
||||
labels:
|
||||
name: Labels
|
||||
uses: ./.github/workflows/labels.yml
|
||||
dynamic:
|
||||
name: Dynamic
|
||||
uses: ./.github/workflows/dynamic.yml
|
||||
dynamic-labels:
|
||||
name: Dynamic Labels
|
||||
uses: ./.github/workflows/dynamic-labels.yml
|
||||
python:
|
||||
name: Python
|
||||
uses: ./.github/workflows/build-python.yml
|
||||
@@ -1,50 +0,0 @@
|
||||
name: DiskANN Push Build
|
||||
on: [push]
|
||||
jobs:
|
||||
common:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
name: DiskANN Common Build Checks
|
||||
uses: ./.github/workflows/common.yml
|
||||
build-documentation:
|
||||
permissions:
|
||||
contents: write
|
||||
strategy:
|
||||
fail-fast: true
|
||||
name: DiskANN Build Documentation
|
||||
uses: ./.github/workflows/build-python-pdoc.yml
|
||||
build:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ ubuntu-latest, windows-2019, windows-latest ]
|
||||
name: Build for ${{matrix.os}}
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: Build diskannpy dependency tree
|
||||
run: |
|
||||
pip install diskannpy pipdeptree
|
||||
echo "dependencies" > dependencies_${{ matrix.os }}.txt
|
||||
pipdeptree >> dependencies_${{ matrix.os }}.txt
|
||||
- name: Archive diskannpy dependencies artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dependencies_${{ matrix.os }}
|
||||
path: |
|
||||
dependencies_${{ matrix.os }}.txt
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
@@ -1,43 +0,0 @@
|
||||
name: Build and Release Python Wheels
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
jobs:
|
||||
python-release-wheels:
|
||||
name: Python
|
||||
uses: ./.github/workflows/build-python.yml
|
||||
build-documentation:
|
||||
strategy:
|
||||
fail-fast: true
|
||||
name: DiskANN Build Documentation
|
||||
uses: ./.github/workflows/build-python-pdoc.yml
|
||||
release:
|
||||
permissions:
|
||||
contents: write
|
||||
runs-on: ubuntu-latest
|
||||
needs: python-release-wheels
|
||||
steps:
|
||||
- uses: actions/download-artifact@v3
|
||||
with:
|
||||
name: wheels
|
||||
path: dist/
|
||||
- name: Generate SHA256 files for each wheel
|
||||
run: |
|
||||
sha256sum dist/*.whl > checksums.txt
|
||||
cat checksums.txt
|
||||
- uses: actions/setup-python@v3
|
||||
- name: Install twine
|
||||
run: python -m pip install twine
|
||||
- name: Publish with twine
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
|
||||
run: |
|
||||
twine upload dist/*.whl
|
||||
- name: Update release with SHA256 and Artifacts
|
||||
uses: softprops/action-gh-release@v1
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
files: |
|
||||
dist/*.whl
|
||||
checksums.txt
|
||||
@@ -1,32 +0,0 @@
|
||||
name: Unit Tests
|
||||
on: [workflow_call]
|
||||
jobs:
|
||||
acceptance-tests-labels:
|
||||
name: Unit Tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
runs-on: ${{matrix.os}}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- name: Checkout repository
|
||||
if: ${{ runner.os == 'Windows' }}
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
submodules: true
|
||||
- name: DiskANN Build CLI Applications
|
||||
uses: ./.github/actions/build
|
||||
|
||||
- name: Run Unit Tests
|
||||
run: |
|
||||
cd build
|
||||
ctest -C Release
|
||||
@@ -1,384 +0,0 @@
|
||||
## Ignore Visual Studio temporary files, build results, and
|
||||
## files generated by popular Visual Studio add-ons.
|
||||
##
|
||||
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
||||
|
||||
# User-specific files
|
||||
*.rsuser
|
||||
*.suo
|
||||
*.user
|
||||
*.userosscache
|
||||
*.sln.docstates
|
||||
|
||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||
*.userprefs
|
||||
|
||||
# Mono auto generated files
|
||||
mono_crash.*
|
||||
|
||||
# Build results
|
||||
[Dd]ebug/
|
||||
[Dd]ebugPublic/
|
||||
[Rr]elease/
|
||||
[Rr]eleases/
|
||||
x64/
|
||||
x86/
|
||||
[Aa][Rr][Mm]/
|
||||
[Aa][Rr][Mm]64/
|
||||
bld/
|
||||
[Bb]in/
|
||||
[Oo]bj/
|
||||
[Ll]og/
|
||||
[Ll]ogs/
|
||||
|
||||
# Visual Studio 2015/2017 cache/options directory
|
||||
.vs/
|
||||
# Uncomment if you have tasks that create the project's static files in wwwroot
|
||||
#wwwroot/
|
||||
|
||||
# Visual Studio 2017 auto generated files
|
||||
Generated\ Files/
|
||||
|
||||
# MSTest test Results
|
||||
[Tt]est[Rr]esult*/
|
||||
[Bb]uild[Ll]og.*
|
||||
|
||||
# NUnit
|
||||
*.VisualState.xml
|
||||
TestResult.xml
|
||||
nunit-*.xml
|
||||
|
||||
# Build Results of an ATL Project
|
||||
[Dd]ebugPS/
|
||||
[Rr]eleasePS/
|
||||
dlldata.c
|
||||
|
||||
# Benchmark Results
|
||||
BenchmarkDotNet.Artifacts/
|
||||
|
||||
# .NET Core
|
||||
project.lock.json
|
||||
project.fragment.lock.json
|
||||
artifacts/
|
||||
|
||||
# StyleCop
|
||||
StyleCopReport.xml
|
||||
|
||||
# Files built by Visual Studio
|
||||
*_i.c
|
||||
*_p.c
|
||||
*_h.h
|
||||
*.ilk
|
||||
*.meta
|
||||
*.obj
|
||||
*.iobj
|
||||
*.pch
|
||||
*.pdb
|
||||
*.ipdb
|
||||
*.pgc
|
||||
*.pgd
|
||||
*.rsp
|
||||
*.sbr
|
||||
*.tlb
|
||||
*.tli
|
||||
*.tlh
|
||||
*.tmp
|
||||
*.tmp_proj
|
||||
*_wpftmp.csproj
|
||||
*.log
|
||||
*.vspscc
|
||||
*.vssscc
|
||||
.builds
|
||||
*.pidb
|
||||
*.svclog
|
||||
*.scc
|
||||
|
||||
# Chutzpah Test files
|
||||
_Chutzpah*
|
||||
|
||||
# Visual C++ cache files
|
||||
ipch/
|
||||
*.aps
|
||||
*.ncb
|
||||
*.opendb
|
||||
*.opensdf
|
||||
*.sdf
|
||||
*.cachefile
|
||||
*.VC.db
|
||||
*.VC.VC.opendb
|
||||
|
||||
# Visual Studio profiler
|
||||
*.psess
|
||||
*.vsp
|
||||
*.vspx
|
||||
*.sap
|
||||
|
||||
# Visual Studio Trace Files
|
||||
*.e2e
|
||||
|
||||
# TFS 2012 Local Workspace
|
||||
$tf/
|
||||
|
||||
# Guidance Automation Toolkit
|
||||
*.gpState
|
||||
|
||||
# ReSharper is a .NET coding add-in
|
||||
_ReSharper*/
|
||||
*.[Rr]e[Ss]harper
|
||||
*.DotSettings.user
|
||||
|
||||
# TeamCity is a build add-in
|
||||
_TeamCity*
|
||||
|
||||
# DotCover is a Code Coverage Tool
|
||||
*.dotCover
|
||||
|
||||
# AxoCover is a Code Coverage Tool
|
||||
.axoCover/*
|
||||
!.axoCover/settings.json
|
||||
|
||||
# Visual Studio code coverage results
|
||||
*.coverage
|
||||
*.coveragexml
|
||||
|
||||
# NCrunch
|
||||
_NCrunch_*
|
||||
.*crunch*.local.xml
|
||||
nCrunchTemp_*
|
||||
|
||||
# MightyMoose
|
||||
*.mm.*
|
||||
AutoTest.Net/
|
||||
|
||||
# Web workbench (sass)
|
||||
.sass-cache/
|
||||
|
||||
# Installshield output folder
|
||||
[Ee]xpress/
|
||||
|
||||
# DocProject is a documentation generator add-in
|
||||
DocProject/buildhelp/
|
||||
DocProject/Help/*.HxT
|
||||
DocProject/Help/*.HxC
|
||||
DocProject/Help/*.hhc
|
||||
DocProject/Help/*.hhk
|
||||
DocProject/Help/*.hhp
|
||||
DocProject/Help/Html2
|
||||
DocProject/Help/html
|
||||
|
||||
# Click-Once directory
|
||||
publish/
|
||||
|
||||
# Publish Web Output
|
||||
*.[Pp]ublish.xml
|
||||
*.azurePubxml
|
||||
# Note: Comment the next line if you want to checkin your web deploy settings,
|
||||
# but database connection strings (with potential passwords) will be unencrypted
|
||||
*.pubxml
|
||||
*.publishproj
|
||||
|
||||
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
||||
# checkin your Azure Web App publish settings, but sensitive information contained
|
||||
# in these scripts will be unencrypted
|
||||
PublishScripts/
|
||||
|
||||
# NuGet Packages
|
||||
*.nupkg
|
||||
# NuGet Symbol Packages
|
||||
*.snupkg
|
||||
# The packages folder can be ignored because of Package Restore
|
||||
**/[Pp]ackages/*
|
||||
# except build/, which is used as an MSBuild target.
|
||||
!**/[Pp]ackages/build/
|
||||
# Uncomment if necessary however generally it will be regenerated when needed
|
||||
#!**/[Pp]ackages/repositories.config
|
||||
# NuGet v3's project.json files produces more ignorable files
|
||||
*.nuget.props
|
||||
*.nuget.targets
|
||||
|
||||
# Microsoft Azure Build Output
|
||||
csx/
|
||||
*.build.csdef
|
||||
|
||||
# Microsoft Azure Emulator
|
||||
ecf/
|
||||
rcf/
|
||||
|
||||
# Windows Store app package directories and files
|
||||
AppPackages/
|
||||
BundleArtifacts/
|
||||
Package.StoreAssociation.xml
|
||||
_pkginfo.txt
|
||||
*.appx
|
||||
*.appxbundle
|
||||
*.appxupload
|
||||
|
||||
# Visual Studio cache files
|
||||
# files ending in .cache can be ignored
|
||||
*.[Cc]ache
|
||||
# but keep track of directories ending in .cache
|
||||
!?*.[Cc]ache/
|
||||
|
||||
# Others
|
||||
ClientBin/
|
||||
~$*
|
||||
*~
|
||||
*.dbmdl
|
||||
*.dbproj.schemaview
|
||||
*.jfm
|
||||
*.pfx
|
||||
*.publishsettings
|
||||
orleans.codegen.cs
|
||||
|
||||
# Including strong name files can present a security risk
|
||||
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
||||
#*.snk
|
||||
|
||||
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
||||
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
||||
#bower_components/
|
||||
|
||||
# RIA/Silverlight projects
|
||||
Generated_Code/
|
||||
|
||||
# Backup & report files from converting an old project file
|
||||
# to a newer Visual Studio version. Backup files are not needed,
|
||||
# because we have git ;-)
|
||||
_UpgradeReport_Files/
|
||||
Backup*/
|
||||
UpgradeLog*.XML
|
||||
UpgradeLog*.htm
|
||||
ServiceFabricBackup/
|
||||
*.rptproj.bak
|
||||
|
||||
# SQL Server files
|
||||
*.mdf
|
||||
*.ldf
|
||||
*.ndf
|
||||
|
||||
# Business Intelligence projects
|
||||
*.rdl.data
|
||||
*.bim.layout
|
||||
*.bim_*.settings
|
||||
*.rptproj.rsuser
|
||||
*- [Bb]ackup.rdl
|
||||
*- [Bb]ackup ([0-9]).rdl
|
||||
*- [Bb]ackup ([0-9][0-9]).rdl
|
||||
|
||||
# Microsoft Fakes
|
||||
FakesAssemblies/
|
||||
|
||||
# GhostDoc plugin setting file
|
||||
*.GhostDoc.xml
|
||||
|
||||
# Node.js Tools for Visual Studio
|
||||
.ntvs_analysis.dat
|
||||
node_modules/
|
||||
|
||||
# Visual Studio 6 build log
|
||||
*.plg
|
||||
|
||||
# Visual Studio 6 workspace options file
|
||||
*.opt
|
||||
|
||||
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
||||
*.vbw
|
||||
|
||||
# Visual Studio LightSwitch build output
|
||||
**/*.HTMLClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/ModelManifest.xml
|
||||
**/*.Server/GeneratedArtifacts
|
||||
**/*.Server/ModelManifest.xml
|
||||
_Pvt_Extensions
|
||||
|
||||
# Paket dependency manager
|
||||
.paket/paket.exe
|
||||
paket-files/
|
||||
|
||||
# FAKE - F# Make
|
||||
.fake/
|
||||
|
||||
# CodeRush personal settings
|
||||
.cr/personal
|
||||
|
||||
# Python Tools for Visual Studio (PTVS)
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Cake - Uncomment if you are using it
|
||||
# tools/**
|
||||
# !tools/packages.config
|
||||
|
||||
# Tabs Studio
|
||||
*.tss
|
||||
|
||||
# Telerik's JustMock configuration file
|
||||
*.jmconfig
|
||||
|
||||
# BizTalk build output
|
||||
*.btp.cs
|
||||
*.btm.cs
|
||||
*.odx.cs
|
||||
*.xsd.cs
|
||||
|
||||
# OpenCover UI analysis results
|
||||
OpenCover/
|
||||
|
||||
# Azure Stream Analytics local run output
|
||||
ASALocalRun/
|
||||
|
||||
# MSBuild Binary and Structured Log
|
||||
*.binlog
|
||||
|
||||
# NVidia Nsight GPU debugger configuration file
|
||||
*.nvuser
|
||||
|
||||
# MFractors (Xamarin productivity tool) working folder
|
||||
.mfractor/
|
||||
|
||||
# Local History for Visual Studio
|
||||
.localhistory/
|
||||
|
||||
# BeatPulse healthcheck temp database
|
||||
healthchecksdb
|
||||
|
||||
# Backup folder for Package Reference Convert tool in Visual Studio 2017
|
||||
MigrationBackup/
|
||||
|
||||
# Ionide (cross platform F# VS Code tools) working folder
|
||||
.ionide/
|
||||
|
||||
/vcproj/nsg/x64/Debug/nsg.Build.CppClean.log
|
||||
/vcproj/test_recall/x64/Debug/test_recall.Build.CppClean.log
|
||||
/vcproj/test_recall/test_recall.vcxproj.user
|
||||
/.vs
|
||||
/out/build/x64-Debug
|
||||
cscope*
|
||||
|
||||
build/
|
||||
build_linux/
|
||||
!.github/actions/build
|
||||
|
||||
# jetbrains specific stuff
|
||||
.idea/
|
||||
cmake-build-debug/
|
||||
|
||||
#python extension module ignores
|
||||
python/diskannpy.egg-info/
|
||||
python/dist/
|
||||
|
||||
**/*.egg-info
|
||||
wheelhouse/*
|
||||
dist/*
|
||||
venv*/**
|
||||
*.swp
|
||||
|
||||
gperftools
|
||||
|
||||
# Rust
|
||||
rust/target
|
||||
|
||||
python/src/*.so
|
||||
|
||||
compile_commands.json
|
||||
@@ -1,3 +0,0 @@
|
||||
[submodule "gperftools"]
|
||||
path = gperftools
|
||||
url = https://github.com/gperftools/gperftools.git
|
||||
@@ -1,563 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Parameters:
|
||||
#
|
||||
# BOOST_ROOT:
|
||||
# Specify root of the Boost library if Boost cannot be auto-detected. On Windows, a fallback to a
|
||||
# downloaded nuget version will be used if Boost cannot be found.
|
||||
#
|
||||
# DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS:
|
||||
# This is a work-in-progress feature, not completed yet. The core DiskANN library will be split into
|
||||
# build-related and search-related functionality. In build-related functionality, when using tcmalloc,
|
||||
# it's possible to release memory that's free but reserved by tcmalloc. Setting this to true enables
|
||||
# such behavior.
|
||||
# Contact for this feature: gopalrs.
|
||||
|
||||
|
||||
# Some variables like MSVC are defined only after project(), so put that first.
|
||||
cmake_minimum_required(VERSION 3.20)
|
||||
project(diskann)
|
||||
|
||||
#Set option to use tcmalloc
|
||||
option(USE_TCMALLOC "Use tcmalloc from gperftools" ON)
|
||||
|
||||
# set tcmalloc to false when on macos
|
||||
if(APPLE)
|
||||
set(USE_TCMALLOC OFF)
|
||||
endif()
|
||||
|
||||
option(PYBIND "Build with Python bindings" ON)
|
||||
|
||||
if(PYBIND)
|
||||
# Find Python
|
||||
find_package(Python 3.6 COMPONENTS Interpreter Development REQUIRED)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())"
|
||||
OUTPUT_VARIABLE pybind11_DIR
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
|
||||
message(STATUS "Python include dirs: ${Python_INCLUDE_DIRS}")
|
||||
message(STATUS "Pybind11 include dirs: ${pybind11_INCLUDE_DIRS}")
|
||||
|
||||
# Add pybind11 include directories
|
||||
include_directories(SYSTEM ${pybind11_INCLUDE_DIRS} ${Python_INCLUDE_DIRS})
|
||||
|
||||
# Add compilation definitions
|
||||
add_definitions(-DPYBIND11_EMBEDDED)
|
||||
|
||||
# Set visibility flags
|
||||
if(NOT MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(CMAKE_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
# if(NOT MSVC)
|
||||
# set(CMAKE_CXX_COMPILER g++)
|
||||
# endif()
|
||||
|
||||
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake;${CMAKE_MODULE_PATH}")
|
||||
|
||||
# Install nuget packages for dependencies.
|
||||
if (MSVC)
|
||||
find_program(NUGET_EXE NAMES nuget)
|
||||
|
||||
if (NOT NUGET_EXE)
|
||||
message(FATAL_ERROR "Cannot find nuget command line tool.\nPlease install it from e.g. https://www.nuget.org/downloads")
|
||||
endif()
|
||||
|
||||
set(DISKANN_MSVC_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/packages.config)
|
||||
set(DISKANN_MSVC_PACKAGES ${CMAKE_BINARY_DIR}/packages)
|
||||
|
||||
message(STATUS "Invoking nuget to download Boost, OpenMP and MKL dependencies...")
|
||||
configure_file(${PROJECT_SOURCE_DIR}/windows/packages.config.in ${DISKANN_MSVC_PACKAGES_CONFIG})
|
||||
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
|
||||
if (RESTAPI)
|
||||
set(DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG ${CMAKE_BINARY_DIR}/restapi/packages.config)
|
||||
configure_file(${PROJECT_SOURCE_DIR}/windows/packages_restapi.config.in ${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG})
|
||||
exec_program(${NUGET_EXE} ARGS install \"${DISKANN_MSVC_RESTAPI_PACKAGES_CONFIG}\" -ExcludeVersion -OutputDirectory \"${DISKANN_MSVC_PACKAGES}\")
|
||||
endif()
|
||||
message(STATUS "Finished setting up nuget dependencies")
|
||||
endif()
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
if(USE_TCMALLOC)
|
||||
FetchContent_Declare(
|
||||
tcmalloc
|
||||
GIT_REPOSITORY https://github.com/google/tcmalloc.git
|
||||
GIT_TAG origin/master # or specify a particular version or commit
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(tcmalloc)
|
||||
endif()
|
||||
|
||||
if(NOT PYBIND)
|
||||
set(DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS ON)
|
||||
endif()
|
||||
# It's necessary to include tcmalloc headers only if calling into MallocExtension interface.
|
||||
# For using tcmalloc in DiskANN tools, it's enough to just link with tcmalloc.
|
||||
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
||||
include_directories(${tcmalloc_SOURCE_DIR}/src)
|
||||
if (MSVC)
|
||||
include_directories(${tcmalloc_SOURCE_DIR}/src/windows)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#OpenMP
|
||||
if (MSVC)
|
||||
# Do not use find_package here since it would use VisualStudio's built-in OpenMP, but MKL libraries
|
||||
# refer to Intel's OpenMP.
|
||||
#
|
||||
# No extra settings are needed for compilation: it only needs /openmp flag which is set further below,
|
||||
# in the common MSVC compiler options block.
|
||||
include_directories(BEFORE "${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/include")
|
||||
link_libraries("${DISKANN_MSVC_PACKAGES}/intelopenmp.devel.win/lib/native/win-x64/libiomp5md.lib")
|
||||
|
||||
set(OPENMP_WINDOWS_RUNTIME_FILES
|
||||
"${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.dll"
|
||||
"${DISKANN_MSVC_PACKAGES}/intelopenmp.redist.win/runtimes/win-x64/native/libiomp5md.pdb")
|
||||
elseif(APPLE)
|
||||
# Check if we're building Python bindings
|
||||
if(PYBIND)
|
||||
# First look for PyTorch's OpenMP to avoid conflicts
|
||||
execute_process(
|
||||
COMMAND ${Python_EXECUTABLE} -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libomp.dylib'))"
|
||||
RESULT_VARIABLE TORCH_PATH_RESULT
|
||||
OUTPUT_VARIABLE TORCH_LIBOMP_PATH
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
ERROR_QUIET
|
||||
)
|
||||
|
||||
execute_process(
|
||||
COMMAND brew --prefix libomp
|
||||
OUTPUT_VARIABLE LIBOMP_ROOT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if(EXISTS "${TORCH_LIBOMP_PATH}")
|
||||
message(STATUS "Found PyTorch's libomp: ${TORCH_LIBOMP_PATH}")
|
||||
set(OpenMP_CXX_FLAGS "-Xclang -fopenmp")
|
||||
set(OpenMP_C_FLAGS "-Xclang -fopenmp")
|
||||
set(OpenMP_CXX_LIBRARIES "${TORCH_LIBOMP_PATH}")
|
||||
set(OpenMP_C_LIBRARIES "${TORCH_LIBOMP_PATH}")
|
||||
set(OpenMP_FOUND TRUE)
|
||||
|
||||
include_directories(${LIBOMP_ROOT}/include)
|
||||
|
||||
# Set compiler flags and link libraries
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
link_libraries("${TORCH_LIBOMP_PATH}")
|
||||
else()
|
||||
message(STATUS "No PyTorch's libomp found, falling back to normal OpenMP detection")
|
||||
# Fallback to normal OpenMP detection
|
||||
execute_process(
|
||||
COMMAND brew --prefix libomp
|
||||
OUTPUT_VARIABLE LIBOMP_ROOT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
set(OpenMP_ROOT "${LIBOMP_ROOT}")
|
||||
find_package(OpenMP)
|
||||
|
||||
if (OPENMP_FOUND)
|
||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
link_libraries(OpenMP::OpenMP_CXX)
|
||||
else()
|
||||
message(FATAL_ERROR "No OpenMP support")
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
# Regular OpenMP setup for non-Python builds
|
||||
execute_process(
|
||||
COMMAND brew --prefix libomp
|
||||
OUTPUT_VARIABLE LIBOMP_ROOT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
set(OpenMP_ROOT "${LIBOMP_ROOT}")
|
||||
find_package(OpenMP)
|
||||
|
||||
if (OPENMP_FOUND)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
link_libraries(OpenMP::OpenMP_CXX)
|
||||
else()
|
||||
message(FATAL_ERROR "No OpenMP support")
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
find_package(OpenMP)
|
||||
|
||||
if (OPENMP_FOUND)
|
||||
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
else()
|
||||
message(FATAL_ERROR "No OpenMP support")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# DiskANN core uses header-only libraries. Only DiskANN tools need program_options which has a linker library,
|
||||
# but its size is small. Reduce number of dependent DLLs by linking statically.
|
||||
if (MSVC)
|
||||
set(Boost_USE_STATIC_LIBS ON)
|
||||
endif()
|
||||
|
||||
if(NOT MSVC)
|
||||
find_package(Boost COMPONENTS program_options)
|
||||
endif()
|
||||
|
||||
# For Windows, fall back to nuget version if find_package didn't find it.
|
||||
if (MSVC AND NOT Boost_FOUND)
|
||||
set(DISKANN_BOOST_INCLUDE "${DISKANN_MSVC_PACKAGES}/boost/lib/native/include")
|
||||
# Multi-threaded static library.
|
||||
set(PROGRAM_OPTIONS_LIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-x64-*.lib")
|
||||
file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_LIB ${PROGRAM_OPTIONS_LIB_PATTERN})
|
||||
|
||||
set(PROGRAM_OPTIONS_DLIB_PATTERN "${DISKANN_MSVC_PACKAGES}/boost_program_options-vc${MSVC_TOOLSET_VERSION}/lib/native/libboost_program_options-vc${MSVC_TOOLSET_VERSION}-mt-gd-x64-*.lib")
|
||||
file(GLOB DISKANN_BOOST_PROGRAM_OPTIONS_DLIB ${PROGRAM_OPTIONS_DLIB_PATTERN})
|
||||
|
||||
if (EXISTS ${DISKANN_BOOST_INCLUDE} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_LIB} AND EXISTS ${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB})
|
||||
set(Boost_FOUND ON)
|
||||
set(Boost_INCLUDE_DIR ${DISKANN_BOOST_INCLUDE})
|
||||
add_library(Boost::program_options STATIC IMPORTED)
|
||||
set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_RELEASE "${DISKANN_BOOST_PROGRAM_OPTIONS_LIB}")
|
||||
set_target_properties(Boost::program_options PROPERTIES IMPORTED_LOCATION_DEBUG "${DISKANN_BOOST_PROGRAM_OPTIONS_DLIB}")
|
||||
message(STATUS "Falling back to using Boost from the nuget package")
|
||||
else()
|
||||
message(WARNING "Couldn't find Boost. Was looking for ${DISKANN_BOOST_INCLUDE} and ${PROGRAM_OPTIONS_LIB_PATTERN}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT Boost_FOUND)
|
||||
message(FATAL_ERROR "Couldn't find Boost dependency")
|
||||
endif()
|
||||
|
||||
include_directories(${Boost_INCLUDE_DIR})
|
||||
|
||||
#MKL Config
|
||||
if (MSVC)
|
||||
# Only the DiskANN DLL and one of the tools need MKL libraries. Additionally, only a small part of MKL is used.
|
||||
# Given that and given that MKL DLLs are huge, use static linking to end up with no MKL DLL dependencies and with
|
||||
# significantly smaller disk footprint.
|
||||
#
|
||||
# The compile options are not modified as there's already an unconditional -DMKL_ILP64 define below
|
||||
# for all architectures, which is all that's needed.
|
||||
set(DISKANN_MKL_INCLUDE_DIRECTORIES "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/include")
|
||||
set(DISKANN_MKL_LIB_PATH "${DISKANN_MSVC_PACKAGES}/intelmkl.static.win-x64/lib/native/win-x64")
|
||||
|
||||
set(DISKANN_MKL_LINK_LIBRARIES
|
||||
"${DISKANN_MKL_LIB_PATH}/mkl_intel_ilp64.lib"
|
||||
"${DISKANN_MKL_LIB_PATH}/mkl_core.lib"
|
||||
"${DISKANN_MKL_LIB_PATH}/mkl_intel_thread.lib")
|
||||
elseif(APPLE)
|
||||
# no mkl on non-intel devices
|
||||
find_library(ACCELERATE_LIBRARY Accelerate)
|
||||
message(STATUS "Found Accelerate (${ACCELERATE_LIBRARY})")
|
||||
set(DISKANN_ACCEL_LINK_OPTIONS ${ACCELERATE_LIBRARY})
|
||||
add_compile_definitions(ACCELERATE_NEW_LAPACK)
|
||||
else()
|
||||
# expected path for manual intel mkl installs
|
||||
set(POSSIBLE_OMP_PATHS "/opt/intel/oneapi/compiler/2025.0/lib/libiomp5.so;/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin/libiomp5.so;/usr/lib/x86_64-linux-gnu/libiomp5.so;/opt/intel/lib/intel64_lin/libiomp5.so")
|
||||
foreach(POSSIBLE_OMP_PATH ${POSSIBLE_OMP_PATHS})
|
||||
if (EXISTS ${POSSIBLE_OMP_PATH})
|
||||
get_filename_component(OMP_PATH ${POSSIBLE_OMP_PATH} DIRECTORY)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(NOT OMP_PATH)
|
||||
message(FATAL_ERROR "Could not find Intel OMP in standard locations; use -DOMP_PATH to specify the install location for your environment")
|
||||
endif()
|
||||
link_directories(${OMP_PATH})
|
||||
|
||||
set(POSSIBLE_MKL_LIB_PATHS "/opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so;/usr/lib/x86_64-linux-gnu/libmkl_core.so;/opt/intel/mkl/lib/intel64/libmkl_core.so")
|
||||
foreach(POSSIBLE_MKL_LIB_PATH ${POSSIBLE_MKL_LIB_PATHS})
|
||||
if (EXISTS ${POSSIBLE_MKL_LIB_PATH})
|
||||
get_filename_component(MKL_PATH ${POSSIBLE_MKL_LIB_PATH} DIRECTORY)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(POSSIBLE_MKL_INCLUDE_PATHS "/opt/intel/oneapi/mkl/latest/include;/usr/include/mkl;/opt/intel/mkl/include/;")
|
||||
foreach(POSSIBLE_MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATHS})
|
||||
if (EXISTS ${POSSIBLE_MKL_INCLUDE_PATH})
|
||||
set(MKL_INCLUDE_PATH ${POSSIBLE_MKL_INCLUDE_PATH})
|
||||
endif()
|
||||
endforeach()
|
||||
if(NOT MKL_PATH)
|
||||
message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_PATH to specify the install location for your environment")
|
||||
elseif(NOT MKL_INCLUDE_PATH)
|
||||
message(FATAL_ERROR "Could not find Intel MKL in standard locations; use -DMKL_INCLUDE_PATH to specify the install location for headers for your environment")
|
||||
endif()
|
||||
if (EXISTS ${MKL_PATH}/libmkl_def.so.2)
|
||||
set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so.2)
|
||||
elseif(EXISTS ${MKL_PATH}/libmkl_def.so)
|
||||
set(MKL_DEF_SO ${MKL_PATH}/libmkl_def.so)
|
||||
else()
|
||||
message(FATAL_ERROR "Despite finding MKL, libmkl_def.so was not found in expected locations.")
|
||||
endif()
|
||||
link_directories(${MKL_PATH})
|
||||
include_directories(${MKL_INCLUDE_PATH})
|
||||
|
||||
# compile flags and link libraries
|
||||
# if gcc/g++
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
add_compile_options(-m64 -Wl,--no-as-needed)
|
||||
endif()
|
||||
if (NOT PYBIND)
|
||||
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core iomp5 pthread m dl)
|
||||
else()
|
||||
# static linking for python so as to minimize customer dependency issues
|
||||
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
|
||||
# In debug mode, use dynamic linking to ensure all symbols are available
|
||||
link_libraries(mkl_intel_ilp64 mkl_intel_thread mkl_core ${MKL_DEF_SO} iomp5 pthread m dl)
|
||||
else()
|
||||
# In release mode, use static linking to minimize dependencies
|
||||
link_libraries(
|
||||
${MKL_PATH}/libmkl_intel_ilp64.a
|
||||
${MKL_PATH}/libmkl_intel_thread.a
|
||||
${MKL_PATH}/libmkl_core.a
|
||||
${MKL_DEF_SO}
|
||||
iomp5
|
||||
pthread
|
||||
m
|
||||
dl
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_definitions(-DMKL_ILP64)
|
||||
endif()
|
||||
|
||||
|
||||
# Section for tcmalloc. The DiskANN tools are always linked to tcmalloc. For Windows, they also need to
|
||||
# force-include the _tcmalloc symbol for enabling tcmalloc.
|
||||
#
|
||||
# The DLL itself needs to be linked to tcmalloc only if DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS
|
||||
# is enabled.
|
||||
if(USE_TCMALLOC)
|
||||
if (MSVC)
|
||||
if (NOT EXISTS "${PROJECT_SOURCE_DIR}/gperftools/gperftools.sln")
|
||||
message(FATAL_ERROR "The gperftools submodule was not found. "
|
||||
"Please check-out git submodules by doing 'git submodule init' followed by 'git submodule update'")
|
||||
endif()
|
||||
|
||||
set(TCMALLOC_LINK_LIBRARY "${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.lib")
|
||||
set(TCMALLOC_WINDOWS_RUNTIME_FILES
|
||||
"${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.dll"
|
||||
"${PROJECT_SOURCE_DIR}/gperftools/x64/Release-Patch/libtcmalloc_minimal.pdb")
|
||||
|
||||
# Tell CMake how to build the tcmalloc linker library from the submodule.
|
||||
add_custom_target(build_libtcmalloc_minimal DEPENDS ${TCMALLOC_LINK_LIBRARY})
|
||||
add_custom_command(OUTPUT ${TCMALLOC_LINK_LIBRARY}
|
||||
COMMAND ${CMAKE_VS_MSBUILD_COMMAND} gperftools.sln /m /nologo
|
||||
/t:libtcmalloc_minimal /p:Configuration="Release-Patch"
|
||||
/property:Platform="x64"
|
||||
/p:PlatformToolset=v${MSVC_TOOLSET_VERSION}
|
||||
/p:WindowsTargetPlatformVersion=${CMAKE_VS_WINDOWS_TARGET_PLATFORM_VERSION}
|
||||
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/gperftools)
|
||||
|
||||
add_library(libtcmalloc_minimal_for_exe STATIC IMPORTED)
|
||||
add_library(libtcmalloc_minimal_for_dll STATIC IMPORTED)
|
||||
|
||||
set_target_properties(libtcmalloc_minimal_for_dll PROPERTIES
|
||||
IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}")
|
||||
|
||||
set_target_properties(libtcmalloc_minimal_for_exe PROPERTIES
|
||||
IMPORTED_LOCATION "${TCMALLOC_LINK_LIBRARY}"
|
||||
INTERFACE_LINK_OPTIONS /INCLUDE:_tcmalloc)
|
||||
|
||||
# Ensure libtcmalloc_minimal is built before it's being used.
|
||||
add_dependencies(libtcmalloc_minimal_for_dll build_libtcmalloc_minimal)
|
||||
add_dependencies(libtcmalloc_minimal_for_exe build_libtcmalloc_minimal)
|
||||
|
||||
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_exe)
|
||||
elseif(APPLE) # ! Inherited from #474, not been adjusted for TCMalloc Removal
|
||||
execute_process(
|
||||
COMMAND brew --prefix gperftools
|
||||
OUTPUT_VARIABLE GPERFTOOLS_PREFIX
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-L${GPERFTOOLS_PREFIX}/lib -ltcmalloc")
|
||||
elseif(NOT PYBIND)
|
||||
set(DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS "-ltcmalloc")
|
||||
endif()
|
||||
|
||||
if (DISKANN_RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
||||
add_definitions(-DRELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS)
|
||||
|
||||
if (MSVC)
|
||||
set(DISKANN_DLL_TCMALLOC_LINK_OPTIONS libtcmalloc_minimal_for_dll)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT MSVC AND NOT APPLE)
|
||||
set(DISKANN_ASYNC_LIB aio)
|
||||
endif()
|
||||
|
||||
#Main compiler/linker settings
|
||||
if(MSVC)
|
||||
#language options
|
||||
add_compile_options(/permissive- /openmp:experimental /Zc:twoPhase- /Zc:inline /WX- /std:c++17 /Gd /W3 /MP /Zi /FC /nologo)
|
||||
#code generation options
|
||||
add_compile_options(/arch:AVX2 /fp:fast /fp:except- /EHsc /GS- /Gy)
|
||||
#optimization options
|
||||
add_compile_options(/Ot /Oy /Oi)
|
||||
#path options
|
||||
add_definitions(-DUSE_AVX2 -DUSE_ACCELERATED_PQ -D_WINDOWS -DNOMINMAX -DUNICODE)
|
||||
# Linker options. Exclude VCOMP/VCOMPD.LIB which contain VisualStudio's version of OpenMP.
|
||||
# MKL was linked against Intel's OpenMP and depends on the corresponding DLL.
|
||||
add_link_options(/NODEFAULTLIB:VCOMP.LIB /NODEFAULTLIB:VCOMPD.LIB /DEBUG:FULL /OPT:REF /OPT:ICF)
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR}/x64/Debug)
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR}/x64/Release)
|
||||
elseif(APPLE)
|
||||
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftree-vectorize -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free -Xclang -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -Wno-inconsistent-missing-override -Wno-return-type")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Ofast -DNDEBUG -ftree-vectorize")
|
||||
if (NOT PYBIND)
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
|
||||
if (NOT PORTABLE)
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -mtune=native")
|
||||
endif()
|
||||
else()
|
||||
# -Ofast is not supported in a python extension module
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -fPIC")
|
||||
endif()
|
||||
else()
|
||||
set(ENV{TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD} 500000000000)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2 -mfma -msse2 -ftree-vectorize -fopenmp -fopenmp-simd -funroll-loops -Wfatal-errors -DUSE_AVX2 -fPIC")
|
||||
if(USE_TCMALLOC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free")
|
||||
endif()
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -DDEBUG")
|
||||
if (NOT PYBIND)
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG -Ofast")
|
||||
if (NOT PORTABLE)
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native")
|
||||
endif()
|
||||
else()
|
||||
# -Ofast is not supported in a python extension module
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_subdirectory(src)
|
||||
if (NOT PYBIND)
|
||||
add_subdirectory(apps)
|
||||
add_subdirectory(apps/utils)
|
||||
endif()
|
||||
|
||||
if (UNIT_TEST)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
if (MSVC)
|
||||
message(STATUS "The ${PROJECT_NAME}.sln has been created, opened it from VisualStudio to build Release or Debug configurations.\n"
|
||||
"Alternatively, use MSBuild to build:\n\n"
|
||||
"msbuild.exe ${PROJECT_NAME}.sln /m /nologo /t:Build /p:Configuration=\"Release\" /property:Platform=\"x64\"\n")
|
||||
endif()
|
||||
|
||||
if (RESTAPI)
|
||||
if (MSVC)
|
||||
set(DISKANN_CPPRESTSDK "${DISKANN_MSVC_PACKAGES}/cpprestsdk.v142/build/native")
|
||||
# expected path for apt packaged intel mkl installs
|
||||
link_libraries("${DISKANN_CPPRESTSDK}/x64/lib/cpprest142_2_10.lib")
|
||||
include_directories("${DISKANN_CPPRESTSDK}/include")
|
||||
endif()
|
||||
add_subdirectory(apps/restapi)
|
||||
endif()
|
||||
|
||||
include(clang-format.cmake)
|
||||
|
||||
if(PYBIND)
|
||||
add_subdirectory(python)
|
||||
|
||||
install(TARGETS _diskannpy
|
||||
DESTINATION leann_backend_diskann
|
||||
COMPONENT python_modules
|
||||
)
|
||||
|
||||
endif()
|
||||
###############################################################################
|
||||
# PROTOBUF SECTION - Corrected to use CONFIG mode explicitly
|
||||
###############################################################################
|
||||
set(Protobuf_USE_STATIC_LIBS OFF)
|
||||
|
||||
find_package(ZLIB REQUIRED)
|
||||
|
||||
find_package(Protobuf REQUIRED)
|
||||
|
||||
message(STATUS "Protobuf found: ${Protobuf_VERSION}")
|
||||
message(STATUS "Protobuf include dirs: ${Protobuf_INCLUDE_DIRS}")
|
||||
message(STATUS "Protobuf libraries: ${Protobuf_LIBRARIES}")
|
||||
message(STATUS "Protobuf protoc executable: ${Protobuf_PROTOC_EXECUTABLE}")
|
||||
|
||||
include_directories(${Protobuf_INCLUDE_DIRS})
|
||||
|
||||
set(PROTO_FILE "${CMAKE_CURRENT_SOURCE_DIR}/../embedding.proto")
|
||||
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILE})
|
||||
set(generated_proto_sources ${PROTO_SRCS})
|
||||
|
||||
|
||||
add_library(proto_embeddings STATIC ${generated_proto_sources})
|
||||
target_link_libraries(proto_embeddings PUBLIC protobuf::libprotobuf)
|
||||
target_include_directories(proto_embeddings PUBLIC
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Protobuf_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
target_link_libraries(diskann PRIVATE proto_embeddings protobuf::libprotobuf)
|
||||
target_include_directories(diskann PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Protobuf_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
target_link_libraries(diskann_s PRIVATE proto_embeddings protobuf::libprotobuf)
|
||||
target_include_directories(diskann_s PRIVATE
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
${Protobuf_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# ZEROMQ SECTION - REQUIRED
|
||||
###############################################################################
|
||||
|
||||
find_package(ZeroMQ QUIET)
|
||||
if(NOT ZeroMQ_FOUND)
|
||||
find_path(ZeroMQ_INCLUDE_DIR zmq.h)
|
||||
find_library(ZeroMQ_LIBRARY zmq)
|
||||
if(ZeroMQ_INCLUDE_DIR AND ZeroMQ_LIBRARY)
|
||||
set(ZeroMQ_FOUND TRUE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(ZeroMQ_FOUND)
|
||||
message(STATUS "Found ZeroMQ: ${ZeroMQ_LIBRARY}")
|
||||
include_directories(${ZeroMQ_INCLUDE_DIR})
|
||||
target_link_libraries(diskann PRIVATE ${ZeroMQ_LIBRARY})
|
||||
target_link_libraries(diskann_s PRIVATE ${ZeroMQ_LIBRARY})
|
||||
add_definitions(-DUSE_ZEROMQ)
|
||||
else()
|
||||
message(FATAL_ERROR "ZeroMQ is required but not found. Please install ZeroMQ and try again.")
|
||||
endif()
|
||||
|
||||
target_link_libraries(diskann ${PYBIND11_LIBRARIES})
|
||||
target_link_libraries(diskann_s ${PYBIND11_LIBRARIES})
|
||||
@@ -1,28 +0,0 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "x64-Release",
|
||||
"generator": "Ninja",
|
||||
"configurationType": "Release",
|
||||
"inheritEnvironments": [ "msvc_x64" ],
|
||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
||||
"cmakeCommandArgs": "",
|
||||
"buildCommandArgs": "",
|
||||
"ctestCommandArgs": ""
|
||||
},
|
||||
{
|
||||
"name": "WSL-GCC-Release",
|
||||
"generator": "Ninja",
|
||||
"configurationType": "RelWithDebInfo",
|
||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
||||
"cmakeExecutable": "cmake",
|
||||
"cmakeCommandArgs": "",
|
||||
"buildCommandArgs": "",
|
||||
"ctestCommandArgs": "",
|
||||
"inheritEnvironments": [ "linux_x64" ],
|
||||
"wslPath": "${defaultWSLPath}"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
# Microsoft Open Source Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
|
||||
Resources:
|
||||
|
||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
||||
@@ -1,9 +0,0 @@
|
||||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
@@ -1,17 +0,0 @@
|
||||
#Copyright(c) Microsoft Corporation.All rights reserved.
|
||||
#Licensed under the MIT license.
|
||||
|
||||
FROM ubuntu:jammy
|
||||
|
||||
RUN apt update
|
||||
RUN apt install -y software-properties-common
|
||||
RUN add-apt-repository -y ppa:git-core/ppa
|
||||
RUN apt update
|
||||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10
|
||||
|
||||
WORKDIR /app
|
||||
RUN git clone https://github.com/microsoft/DiskANN.git
|
||||
WORKDIR /app/DiskANN
|
||||
RUN mkdir build
|
||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
|
||||
RUN cmake --build build -- -j
|
||||
@@ -1,17 +0,0 @@
|
||||
#Copyright(c) Microsoft Corporation.All rights reserved.
|
||||
#Licensed under the MIT license.
|
||||
|
||||
FROM ubuntu:jammy
|
||||
|
||||
RUN apt update
|
||||
RUN apt install -y software-properties-common
|
||||
RUN add-apt-repository -y ppa:git-core/ppa
|
||||
RUN apt update
|
||||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libboost-test-dev libmkl-full-dev libcpprest-dev python3.10
|
||||
|
||||
WORKDIR /app
|
||||
RUN git clone https://github.com/microsoft/DiskANN.git
|
||||
WORKDIR /app/DiskANN
|
||||
RUN mkdir build
|
||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DUNIT_TEST=True
|
||||
RUN cmake --build build -- -j
|
||||
@@ -1,23 +0,0 @@
|
||||
DiskANN
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
@@ -1,12 +0,0 @@
|
||||
include MANIFEST.in
|
||||
include *.txt
|
||||
include *.md
|
||||
include setup.py
|
||||
include pyproject.toml
|
||||
include *.cmake
|
||||
recursive-include gperftools *
|
||||
recursive-include include *
|
||||
recursive-include python *
|
||||
recursive-include windows *
|
||||
prune python/tests
|
||||
recursive-include src *
|
||||
@@ -1,135 +0,0 @@
|
||||
# DiskANN
|
||||
|
||||
[](https://github.com/microsoft/DiskANN/actions/workflows/push-test.yml)
|
||||
[](https://pypi.org/project/diskannpy/)
|
||||
[](https://pepy.tech/project/diskannpy)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
|
||||
[](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf)
|
||||
[](https://arxiv.org/abs/2105.09613)
|
||||
[](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf)
|
||||
|
||||
|
||||
DiskANN is a suite of scalable, accurate and cost-effective approximate nearest neighbor search algorithms for large-scale vector search that support real-time changes and simple filters.
|
||||
This code is based on ideas from the [DiskANN](https://papers.nips.cc/paper/9527-rand-nsg-fast-accurate-billion-point-nearest-neighbor-search-on-a-single-node.pdf), [Fresh-DiskANN](https://arxiv.org/abs/2105.09613) and the [Filtered-DiskANN](https://harsha-simhadri.org/pubs/Filtered-DiskANN23.pdf) papers with further improvements.
|
||||
This code forked off from [code for NSG](https://github.com/ZJULearning/nsg) algorithm.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
See [guidelines](CONTRIBUTING.md) for contributing to this project.
|
||||
|
||||
## Linux build:
|
||||
|
||||
Install the following packages through apt-get
|
||||
|
||||
```bash
|
||||
sudo apt install make cmake g++ libaio-dev libgoogle-perftools-dev clang-format libboost-all-dev
|
||||
```
|
||||
|
||||
### Install Intel MKL
|
||||
#### Ubuntu 20.04 or newer
|
||||
```bash
|
||||
sudo apt install libmkl-full-dev
|
||||
```
|
||||
|
||||
#### Earlier versions of Ubuntu
|
||||
Install Intel MKL either by downloading the [oneAPI MKL installer](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) or using [apt](https://software.intel.com/en-us/articles/installing-intel-free-libs-and-python-apt-repo) (we tested with build 2019.4-070 and 2022.1.2.146).
|
||||
|
||||
```
|
||||
# OneAPI MKL Installer
|
||||
wget https://registrationcenter-download.intel.com/akdlm/irc_nas/18487/l_BaseKit_p_2022.1.2.146.sh
|
||||
sudo sh l_BaseKit_p_2022.1.2.146.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
```
|
||||
|
||||
### Build
|
||||
```bash
|
||||
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
|
||||
```
|
||||
|
||||
## Windows build:
|
||||
|
||||
The Windows version has been tested with Enterprise editions of Visual Studio 2022, 2019 and 2017. It should work with the Community and Professional editions as well without any changes.
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
* CMake 3.15+ (available in VisualStudio 2019+ or from https://cmake.org)
|
||||
* NuGet.exe (install from https://www.nuget.org/downloads)
|
||||
* The build script will use NuGet to get MKL, OpenMP and Boost packages.
|
||||
* DiskANN git repository checked out together with submodules. To check out submodules after git clone:
|
||||
```
|
||||
git submodule init
|
||||
git submodule update
|
||||
```
|
||||
|
||||
* Environment variables:
|
||||
* [optional] If you would like to override the Boost library listed in windows/packages.config.in, set BOOST_ROOT to your Boost folder.
|
||||
|
||||
**Build steps:**
|
||||
* Open the "x64 Native Tools Command Prompt for VS 2019" (or corresponding version) and change to DiskANN folder
|
||||
* Create a "build" directory inside it
|
||||
* Change to the "build" directory and run
|
||||
```
|
||||
cmake ..
|
||||
```
|
||||
OR for Visual Studio 2017 and earlier:
|
||||
```
|
||||
<full-path-to-installed-cmake>\cmake ..
|
||||
```
|
||||
**This will create a diskann.sln solution**. Now you can:
|
||||
|
||||
- Open it from VisualStudio and build either Release or Debug configuration.
|
||||
- `<full-path-to-installed-cmake>\cmake --build build`
|
||||
- Use MSBuild:
|
||||
```
|
||||
msbuild.exe diskann.sln /m /nologo /t:Build /p:Configuration="Release" /property:Platform="x64"
|
||||
```
|
||||
|
||||
* This will also build gperftools submodule for libtcmalloc_minimal dependency.
|
||||
* Generated binaries are stored in the x64/Release or x64/Debug directories.
|
||||
|
||||
## macOS Build
|
||||
|
||||
### Prerequisites
|
||||
* Apple Silicon. The code should still work on Intel-based Macs, but there are no guarantees.
|
||||
* macOS >= 12.0
|
||||
* XCode Command Line Tools (install with `xcode-select --install`)
|
||||
* [homebrew](https://brew.sh/)
|
||||
|
||||
### Install Required Packages
|
||||
```zsh
|
||||
brew install cmake
|
||||
brew install boost
|
||||
brew install gperftools
|
||||
brew install libomp
|
||||
```
|
||||
|
||||
### Build DiskANN
|
||||
```zsh
|
||||
# same as ubuntu instructions
|
||||
mkdir build && cd build && cmake -DCMAKE_BUILD_TYPE=Release .. && make -j
|
||||
```
|
||||
|
||||
## Usage:
|
||||
|
||||
Please see the following pages on using the compiled code:
|
||||
|
||||
- [Commandline interface for building and search SSD based indices](workflows/SSD_index.md)
|
||||
- [Commandline interface for building and search in memory indices](workflows/in_memory_index.md)
|
||||
- [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md)
|
||||
- [Commandline interface for building and search in memory indices with label data and filters](workflows/filtered_in_memory.md)
|
||||
- [Commandline interface for building and search SSD based indices with label data and filters](workflows/filtered_ssd_index.md)
|
||||
- [diskannpy - DiskANN as a python extension module](python/README.md)
|
||||
|
||||
Please cite this software in your work as:
|
||||
|
||||
```
|
||||
@misc{diskann-github,
|
||||
author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan and Patel, Yash}},
|
||||
title = {{DiskANN: Graph-structured Indices for Scalable, Fast, Fresh and Filtered Approximate Nearest Neighbor Search}},
|
||||
url = {https://github.com/Microsoft/DiskANN},
|
||||
version = {0.6.1},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
@@ -1,41 +0,0 @@
|
||||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
||||
@@ -1,42 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
|
||||
|
||||
add_executable(build_memory_index build_memory_index.cpp)
|
||||
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(build_stitched_index build_stitched_index.cpp)
|
||||
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(search_memory_index search_memory_index.cpp)
|
||||
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(build_disk_index build_disk_index.cpp)
|
||||
target_link_libraries(build_disk_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB} Boost::program_options)
|
||||
|
||||
add_executable(search_disk_index search_disk_index.cpp)
|
||||
target_link_libraries(search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(range_search_disk_index range_search_disk_index.cpp)
|
||||
target_link_libraries(range_search_disk_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(test_streaming_scenario test_streaming_scenario.cpp)
|
||||
target_link_libraries(test_streaming_scenario ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(test_insert_deletes_consolidate test_insert_deletes_consolidate.cpp)
|
||||
target_link_libraries(test_insert_deletes_consolidate ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
if (NOT MSVC)
|
||||
install(TARGETS build_memory_index
|
||||
build_stitched_index
|
||||
search_memory_index
|
||||
build_disk_index
|
||||
search_disk_index
|
||||
range_search_disk_index
|
||||
test_streaming_scenario
|
||||
test_insert_deletes_consolidate
|
||||
RUNTIME
|
||||
)
|
||||
endif()
|
||||
@@ -1,191 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
#include "math_utils.h"
|
||||
#include "index.h"
|
||||
#include "partition.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
|
||||
label_type;
|
||||
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
|
||||
float B, M;
|
||||
bool append_reorder_data = false;
|
||||
bool use_opq = false;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
required_configs.add_options()("search_DRAM_budget,B", po::value<float>(&B)->required(),
|
||||
"DRAM budget in GB for searching the index to set the "
|
||||
"compressed level for data while search happens");
|
||||
required_configs.add_options()("build_DRAM_budget,M", po::value<float>(&M)->required(),
|
||||
"DRAM budget in GB for building the index");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("QD", po::value<uint32_t>(&QD)->default_value(0),
|
||||
" Quantized Dimension for compression");
|
||||
optional_configs.add_options()("codebook_prefix", po::value<std::string>(&codebook_prefix)->default_value(""),
|
||||
"Path prefix for pre-trained codebook");
|
||||
optional_configs.add_options()("PQ_disk_bytes", po::value<uint32_t>(&disk_PQ)->default_value(0),
|
||||
"Number of bytes to which vectors should be compressed "
|
||||
"on SSD; 0 for no compression");
|
||||
optional_configs.add_options()("append_reorder_data", po::bool_switch()->default_value(false),
|
||||
"Include full precision data in the index. Use only in "
|
||||
"conjuction with compressed data on SSD.");
|
||||
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
|
||||
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
|
||||
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
|
||||
program_options_utils::USE_OPQ);
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
program_options_utils::LABEL_FILE);
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
program_options_utils::UNIVERSAL_LABEL);
|
||||
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
program_options_utils::FILTERED_LBUILD);
|
||||
optional_configs.add_options()("filter_threshold,F", po::value<uint32_t>(&filter_threshold)->default_value(0),
|
||||
"Threshold to break up the existing nodes to generate new graph "
|
||||
"internally where each node has a maximum F labels.");
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
if (vm["append_reorder_data"].as<bool>())
|
||||
append_reorder_data = true;
|
||||
if (vm["use_opq"].as<bool>())
|
||||
use_opq = true;
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool use_filters = (label_file != "") ? true : false;
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
metric = diskann::Metric::COSINE;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (append_reorder_data)
|
||||
{
|
||||
if (disk_PQ == 0)
|
||||
{
|
||||
std::cout << "Error: It is not necessary to append data for reordering "
|
||||
"when vectors are not compressed on disk."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
if (data_type != std::string("float"))
|
||||
{
|
||||
std::cout << "Error: Appending data for reordering currently only "
|
||||
"supported for float data type."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
std::string params = std::string(std::to_string(R)) + " " + std::string(std::to_string(L)) + " " +
|
||||
std::string(std::to_string(B)) + " " + std::string(std::to_string(M)) + " " +
|
||||
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
|
||||
std::string(std::to_string(append_reorder_data)) + " " +
|
||||
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));
|
||||
|
||||
try
|
||||
{
|
||||
if (label_file != "" && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
|
||||
use_filters, label_file, universal_label, filter_threshold, Lf);
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, codebook_prefix, use_filters, label_file,
|
||||
universal_label, filter_threshold, Lf);
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <cstring>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "index.h"
|
||||
#include "utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
#include "ann_exception.h"
|
||||
#include "index_factory.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
|
||||
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
|
||||
float alpha;
|
||||
bool use_pq_build, use_opq;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
|
||||
program_options_utils::BUIlD_GRAPH_PQ_BYTES);
|
||||
optional_configs.add_options()("use_opq", po::bool_switch()->default_value(false),
|
||||
program_options_utils::USE_OPQ);
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
program_options_utils::LABEL_FILE);
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
program_options_utils::UNIVERSAL_LABEL);
|
||||
|
||||
optional_configs.add_options()("FilteredLbuild", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
program_options_utils::FILTERED_LBUILD);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
use_pq_build = (build_PQ_bytes > 0);
|
||||
use_opq = vm["use_opq"].as<bool>();
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
||||
"Product/Cosine are supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
|
||||
<< " #threads: " << num_threads << std::endl;
|
||||
|
||||
size_t data_num, data_dim;
|
||||
diskann::get_bin_metadata(data_path, data_num, data_dim);
|
||||
|
||||
auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_filter_list_size(Lf)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
auto filter_params = diskann::IndexFilterParamsBuilder()
|
||||
.with_universal_label(universal_label)
|
||||
.with_label_file(label_file)
|
||||
.with_save_path_prefix(index_path_prefix)
|
||||
.build();
|
||||
auto config = diskann::IndexConfigBuilder()
|
||||
.with_metric(metric)
|
||||
.with_dimension(data_dim)
|
||||
.with_max_points(data_num)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.with_data_type(data_type)
|
||||
.with_label_type(label_type)
|
||||
.is_dynamic_index(false)
|
||||
.with_index_write_params(index_build_params)
|
||||
.is_enable_tags(false)
|
||||
.is_use_opq(use_opq)
|
||||
.is_pq_dist_build(use_pq_build)
|
||||
.with_num_pq_chunks(build_PQ_bytes)
|
||||
.build();
|
||||
|
||||
auto index_factory = diskann::IndexFactory(config);
|
||||
auto index = index_factory.create_instance();
|
||||
index->build(data_path, data_num, filter_params);
|
||||
index->save(index_path_prefix.c_str());
|
||||
index.reset();
|
||||
return 0;
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,441 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include "filter_utils.h"
|
||||
#include <omp.h>
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/uio.h>
|
||||
#endif
|
||||
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "parameters.h"
|
||||
#include "utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
typedef std::tuple<std::vector<std::vector<uint32_t>>, uint64_t> stitch_indices_return_values;
|
||||
|
||||
/*
|
||||
* Inline function to display progress bar.
|
||||
*/
|
||||
inline void print_progress(double percentage)
|
||||
{
|
||||
int val = (int)(percentage * 100);
|
||||
int lpad = (int)(percentage * PBWIDTH);
|
||||
int rpad = PBWIDTH - lpad;
|
||||
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
/*
|
||||
* Inline function to generate a random integer in a range.
|
||||
*/
|
||||
inline size_t random(size_t range_from, size_t range_to)
|
||||
{
|
||||
std::random_device rand_dev;
|
||||
std::mt19937 generator(rand_dev());
|
||||
std::uniform_int_distribution<size_t> distr(range_from, range_to);
|
||||
return distr(generator);
|
||||
}
|
||||
|
||||
/*
|
||||
* function to handle command line parsing.
|
||||
*
|
||||
* Arguments are merely the inputs from the command line.
|
||||
*/
|
||||
void handle_args(int argc, char **argv, std::string &data_type, path &input_data_path, path &final_index_path_prefix,
|
||||
path &label_data_path, std::string &universal_label, uint32_t &num_threads, uint32_t &R, uint32_t &L,
|
||||
uint32_t &stitched_R, float &alpha)
|
||||
{
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("build_stitched_index", "Build a stitched DiskANN index.")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix",
|
||||
po::value<std::string>(&final_index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&input_data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_data_path)->default_value(""),
|
||||
program_options_utils::LABEL_FILE);
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
program_options_utils::UNIVERSAL_LABEL);
|
||||
optional_configs.add_options()("stitched_R", po::value<uint32_t>(&stitched_R)->default_value(100),
|
||||
"Degree to prune final graph down to");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
exit(0);
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Custom index save to write the in-memory index to disk.
|
||||
* Also writes required files for diskANN API -
|
||||
* 1. labels_to_medoids
|
||||
* 2. universal_label
|
||||
* 3. data (redundant for static indices)
|
||||
* 4. labels (redundant for static indices)
|
||||
*/
|
||||
void save_full_index(path final_index_path_prefix, path input_data_path, uint64_t final_index_size,
|
||||
std::vector<std::vector<uint32_t>> stitched_graph,
|
||||
tsl::robin_map<std::string, uint32_t> entry_points, std::string universal_label,
|
||||
path label_data_path)
|
||||
{
|
||||
// aux. file 1
|
||||
auto saving_index_timer = std::chrono::high_resolution_clock::now();
|
||||
std::ifstream original_label_data_stream;
|
||||
original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
original_label_data_stream.open(label_data_path, std::ios::binary);
|
||||
std::ofstream new_label_data_stream;
|
||||
new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
new_label_data_stream.open(final_index_path_prefix + "_labels.txt", std::ios::binary);
|
||||
new_label_data_stream << original_label_data_stream.rdbuf();
|
||||
original_label_data_stream.close();
|
||||
new_label_data_stream.close();
|
||||
|
||||
// aux. file 2
|
||||
std::ifstream original_input_data_stream;
|
||||
original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
original_input_data_stream.open(input_data_path, std::ios::binary);
|
||||
std::ofstream new_input_data_stream;
|
||||
new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
new_input_data_stream.open(final_index_path_prefix + ".data", std::ios::binary);
|
||||
new_input_data_stream << original_input_data_stream.rdbuf();
|
||||
original_input_data_stream.close();
|
||||
new_input_data_stream.close();
|
||||
|
||||
// aux. file 3
|
||||
std::ofstream labels_to_medoids_writer;
|
||||
labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
labels_to_medoids_writer.open(final_index_path_prefix + "_labels_to_medoids.txt");
|
||||
for (auto iter : entry_points)
|
||||
labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
|
||||
labels_to_medoids_writer.close();
|
||||
|
||||
// aux. file 4 (only if we're using a universal label)
|
||||
if (universal_label != "")
|
||||
{
|
||||
std::ofstream universal_label_writer;
|
||||
universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
universal_label_writer.open(final_index_path_prefix + "_universal_label.txt");
|
||||
universal_label_writer << universal_label << std::endl;
|
||||
universal_label_writer.close();
|
||||
}
|
||||
|
||||
// main index
|
||||
uint64_t index_num_frozen_points = 0, index_num_edges = 0;
|
||||
uint32_t index_max_observed_degree = 0, index_entry_point = 0;
|
||||
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
|
||||
for (auto &point_neighbors : stitched_graph)
|
||||
{
|
||||
index_max_observed_degree = std::max(index_max_observed_degree, (uint32_t)point_neighbors.size());
|
||||
}
|
||||
|
||||
std::ofstream stitched_graph_writer;
|
||||
stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
|
||||
|
||||
stitched_graph_writer.write((char *)&final_index_size, sizeof(uint64_t));
|
||||
stitched_graph_writer.write((char *)&index_max_observed_degree, sizeof(uint32_t));
|
||||
stitched_graph_writer.write((char *)&index_entry_point, sizeof(uint32_t));
|
||||
stitched_graph_writer.write((char *)&index_num_frozen_points, sizeof(uint64_t));
|
||||
|
||||
size_t bytes_written = METADATA;
|
||||
for (uint32_t node_point = 0; node_point < stitched_graph.size(); node_point++)
|
||||
{
|
||||
uint32_t current_node_num_neighbors = (uint32_t)stitched_graph[node_point].size();
|
||||
std::vector<uint32_t> current_node_neighbors = stitched_graph[node_point];
|
||||
stitched_graph_writer.write((char *)¤t_node_num_neighbors, sizeof(uint32_t));
|
||||
bytes_written += sizeof(uint32_t);
|
||||
for (const auto ¤t_node_neighbor : current_node_neighbors)
|
||||
{
|
||||
stitched_graph_writer.write((char *)¤t_node_neighbor, sizeof(uint32_t));
|
||||
bytes_written += sizeof(uint32_t);
|
||||
}
|
||||
index_num_edges += current_node_num_neighbors;
|
||||
}
|
||||
|
||||
if (bytes_written != final_index_size)
|
||||
{
|
||||
std::cerr << "Error: written bytes does not match allocated space" << std::endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
stitched_graph_writer.close();
|
||||
|
||||
std::chrono::duration<double> saving_index_time = std::chrono::high_resolution_clock::now() - saving_index_timer;
|
||||
std::cout << "Stitched graph written in " << saving_index_time.count() << " seconds" << std::endl;
|
||||
std::cout << "Stitched graph average degree: " << ((float)index_num_edges) / ((float)(stitched_graph.size()))
|
||||
<< std::endl;
|
||||
std::cout << "Stitched graph max degree: " << index_max_observed_degree << std::endl << std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
* Unions the per-label graph indices together via the following policy:
|
||||
* - any two nodes can only have at most one edge between them -
|
||||
*
|
||||
* Returns the "stitched" graph and its expected file size.
|
||||
*/
|
||||
template <typename T>
|
||||
stitch_indices_return_values stitch_label_indices(
|
||||
path final_index_path_prefix, uint32_t total_number_of_points, label_set all_labels,
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points,
|
||||
tsl::robin_map<std::string, uint32_t> &label_entry_points,
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map)
|
||||
{
|
||||
size_t final_index_size = 0;
|
||||
std::vector<std::vector<uint32_t>> stitched_graph(total_number_of_points);
|
||||
|
||||
auto stitching_index_timer = std::chrono::high_resolution_clock::now();
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
std::vector<std::vector<uint32_t>> curr_label_index;
|
||||
uint64_t curr_label_index_size;
|
||||
uint32_t curr_label_entry_point;
|
||||
|
||||
std::tie(curr_label_index, curr_label_index_size) =
|
||||
diskann::load_label_index(curr_label_index_path, labels_to_number_of_points[lbl]);
|
||||
curr_label_entry_point = (uint32_t)random(0, curr_label_index.size());
|
||||
label_entry_points[lbl] = label_id_to_orig_id_map[lbl][curr_label_entry_point];
|
||||
|
||||
for (uint32_t node_point = 0; node_point < curr_label_index.size(); node_point++)
|
||||
{
|
||||
uint32_t original_point_id = label_id_to_orig_id_map[lbl][node_point];
|
||||
for (auto &node_neighbor : curr_label_index[node_point])
|
||||
{
|
||||
uint32_t original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
|
||||
std::vector<uint32_t> curr_point_neighbors = stitched_graph[original_point_id];
|
||||
if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), original_neighbor_id) ==
|
||||
curr_point_neighbors.end())
|
||||
{
|
||||
stitched_graph[original_point_id].push_back(original_neighbor_id);
|
||||
final_index_size += sizeof(uint32_t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const size_t METADATA = 2 * sizeof(uint64_t) + 2 * sizeof(uint32_t);
|
||||
final_index_size += (total_number_of_points * sizeof(uint32_t) + METADATA);
|
||||
|
||||
std::chrono::duration<double> stitching_index_time =
|
||||
std::chrono::high_resolution_clock::now() - stitching_index_timer;
|
||||
std::cout << "stitched graph generated in memory in " << stitching_index_time.count() << " seconds" << std::endl;
|
||||
|
||||
return std::make_tuple(stitched_graph, final_index_size);
|
||||
}
|
||||
|
||||
/*
|
||||
* Applies the prune_neighbors function from src/index.cpp to
|
||||
* every node in the stitched graph.
|
||||
*
|
||||
* This is an optional step, hence the saving of both the full
|
||||
* and pruned graph.
|
||||
*/
|
||||
template <typename T>
|
||||
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, path input_data_path,
|
||||
std::vector<std::vector<uint32_t>> stitched_graph, uint32_t stitched_R,
|
||||
tsl::robin_map<std::string, uint32_t> label_entry_points, std::string universal_label,
|
||||
path label_data_path, uint32_t num_threads)
|
||||
{
|
||||
size_t dimension, number_of_label_points;
|
||||
auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
|
||||
auto std_cout_buffer = std::cout.rdbuf(nullptr);
|
||||
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
|
||||
|
||||
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
|
||||
|
||||
diskann::Index<T> index(diskann::Metric::L2, dimension, number_of_label_points, nullptr, nullptr, 0, false, false,
|
||||
false, false, 0, false);
|
||||
|
||||
// not searching this index, set search_l to 0
|
||||
index.load(full_index_path_prefix.c_str(), num_threads, 1);
|
||||
|
||||
std::cout << "parsing labels" << std::endl;
|
||||
|
||||
index.prune_all_neighbors(stitched_R, 750, 1.2);
|
||||
index.save((final_index_path_prefix).c_str());
|
||||
|
||||
diskann::cout.rdbuf(diskann_cout_buffer);
|
||||
std::cout.rdbuf(std_cout_buffer);
|
||||
std::chrono::duration<double> pruning_index_time = std::chrono::high_resolution_clock::now() - pruning_index_timer;
|
||||
std::cout << "pruning performed in " << pruning_index_time.count() << " seconds\n" << std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
* Delete all temporary artifacts.
|
||||
* In the process of creating the stitched index, some temporary artifacts are
|
||||
* created:
|
||||
* 1. the separate bin files for each labels' points
|
||||
* 2. the separate diskANN indices built for each label
|
||||
* 3. the '.data' file created while generating the indices
|
||||
*/
|
||||
void clean_up_artifacts(path input_data_path, path final_index_path_prefix, label_set all_labels)
|
||||
{
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
path curr_label_input_data_path(input_data_path + "_" + lbl);
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
path curr_label_index_path_data(curr_label_index_path + ".data");
|
||||
|
||||
if (std::remove(curr_label_index_path.c_str()) != 0)
|
||||
throw;
|
||||
if (std::remove(curr_label_input_data_path.c_str()) != 0)
|
||||
throw;
|
||||
if (std::remove(curr_label_index_path_data.c_str()) != 0)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
// 1. handle cmdline inputs
|
||||
std::string data_type;
|
||||
path input_data_path, final_index_path_prefix, label_data_path;
|
||||
std::string universal_label;
|
||||
uint32_t num_threads, R, L, stitched_R;
|
||||
float alpha;
|
||||
|
||||
auto index_timer = std::chrono::high_resolution_clock::now();
|
||||
handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, label_data_path, universal_label,
|
||||
num_threads, R, L, stitched_R, alpha);
|
||||
|
||||
path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
|
||||
path labels_map_file = final_index_path_prefix + "_labels_map.txt";
|
||||
|
||||
convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label);
|
||||
|
||||
// 2. parse label file and create necessary data structures
|
||||
std::vector<label_set> point_ids_to_labels;
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
|
||||
label_set all_labels;
|
||||
|
||||
std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
|
||||
diskann::parse_label_file(labels_file_to_use, universal_label);
|
||||
|
||||
// 3. for each label, make a separate data file
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id_map;
|
||||
uint32_t total_number_of_points = (uint32_t)point_ids_to_labels.size();
|
||||
|
||||
#ifndef _WINDOWS
|
||||
if (data_type == "uint8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<uint8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "int8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<int8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "float")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files<float>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else
|
||||
throw;
|
||||
#else
|
||||
if (data_type == "uint8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "int8")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else if (data_type == "float")
|
||||
label_id_to_orig_id_map = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
else
|
||||
throw;
|
||||
#endif
|
||||
|
||||
// 4. for each created data file, create a vanilla diskANN index
|
||||
if (data_type == "uint8")
|
||||
diskann::generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
|
||||
num_threads);
|
||||
else if (data_type == "int8")
|
||||
diskann::generate_label_indices<int8_t>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
|
||||
num_threads);
|
||||
else if (data_type == "float")
|
||||
diskann::generate_label_indices<float>(input_data_path, final_index_path_prefix, all_labels, R, L, alpha,
|
||||
num_threads);
|
||||
else
|
||||
throw;
|
||||
|
||||
// 5. "stitch" the indices together
|
||||
std::vector<std::vector<uint32_t>> stitched_graph;
|
||||
tsl::robin_map<std::string, uint32_t> label_entry_points;
|
||||
uint64_t stitched_graph_size;
|
||||
|
||||
if (data_type == "uint8")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<uint8_t>(final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
|
||||
else if (data_type == "int8")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<int8_t>(final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
|
||||
else if (data_type == "float")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<float>(final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points, label_id_to_orig_id_map);
|
||||
else
|
||||
throw;
|
||||
path full_index_path_prefix = final_index_path_prefix + "_full";
|
||||
// 5a. save the stitched graph to disk
|
||||
save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, stitched_graph, label_entry_points,
|
||||
universal_label, labels_file_to_use);
|
||||
|
||||
// 6. run a prune on the stitched index, and save to disk
|
||||
if (data_type == "uint8")
|
||||
prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
|
||||
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
|
||||
else if (data_type == "int8")
|
||||
prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
|
||||
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
|
||||
else if (data_type == "float")
|
||||
prune_and_save<float>(final_index_path_prefix, full_index_path_prefix, input_data_path, stitched_graph,
|
||||
stitched_R, label_entry_points, universal_label, labels_file_to_use, num_threads);
|
||||
else
|
||||
throw;
|
||||
|
||||
std::chrono::duration<double> index_time = std::chrono::high_resolution_clock::now() - index_timer;
|
||||
std::cout << "pruned/stitched graph generated in " << index_time.count() << " seconds" << std::endl;
|
||||
|
||||
clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
<!-- Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
Licensed under the MIT license. -->
|
||||
|
||||
# Integration Tests
|
||||
The following tests use Python to prepare, run, verify, and tear down the rest api services.
|
||||
|
||||
We do make use of the built-in `unittest` library, but that's only to take advantage of test reporting purposes.
|
||||
|
||||
These are decidedly **not** _unit_ tests. These are end to end integration tests.
|
||||
|
||||
## Caveats
|
||||
This has only been tested or built for Linux, though we have written platform agnostic Python for the smoke test
|
||||
(i.e. using `os.path.join`, etc)
|
||||
|
||||
It has been tested on Python 3.9 and 3.10, but should work on Python 3.6+.
|
||||
|
||||
## How to Run
|
||||
|
||||
First, build the DiskANN RestAPI code; see $REPOSITORY_ROOT/workflows/rest_api.md for detailed instructions.
|
||||
|
||||
```bash
|
||||
cd tests/python
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
export DISKANN_BUILD_DIR=/path/to/your/diskann/build
|
||||
python -m unittest
|
||||
```
|
||||
|
||||
## Smoke Test Failed, Now What?
|
||||
The smoke test written takes advantage of temporary directories that are only valid during the
|
||||
lifetime of the test. The contents of these directories include:
|
||||
- Randomized vectors (first in tsv, then bin form) used to build the PQFlashIndex
|
||||
- The PQFlashIndex files
|
||||
|
||||
It is useful to keep these around. By setting some environment variables, you can control whether an ephemeral,
|
||||
temporary directory is used (and deleted on test completion), or left as an exercise for the developer to
|
||||
clean up.
|
||||
|
||||
The valid environment variables are:
|
||||
- `DISKANN_REST_TEST_WORKING_DIR` (example: `$USER/DiskANNRestTest`)
|
||||
- If this is specified, it **must exist** and **must be writeable**. Any existing files will be clobbered.
|
||||
- `DISKANN_REST_SERVER` (example: `http://127.0.0.1:10067`)
|
||||
- Note that if this is set, no data will be generated, nor will a server be started; it is presumed you have done
|
||||
all the work in creating and starting the rest server prior to running the test and just submits requests against it.
|
||||
@@ -1,67 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
|
||||
def output_vectors(
|
||||
diskann_build_path: str,
|
||||
temporary_file_path: str,
|
||||
vectors: np.ndarray,
|
||||
timeout: int = 60
|
||||
) -> str:
|
||||
vectors_as_tsv_path = os.path.join(temporary_file_path, "vectors.tsv")
|
||||
with open(vectors_as_tsv_path, "w") as vectors_tsv_out:
|
||||
for vector in vectors:
|
||||
as_str = "\t".join((str(component) for component in vector))
|
||||
print(as_str, file=vectors_tsv_out)
|
||||
# there is probably a clever way to have numpy write out C++ friendly floats, so feel free to remove this in
|
||||
# favor of something more sane later
|
||||
vectors_as_bin_path = os.path.join(temporary_file_path, "vectors.bin")
|
||||
tsv_to_bin_path = os.path.join(diskann_build_path, "apps", "utils", "tsv_to_bin")
|
||||
|
||||
number_of_points, dimensions = vectors.shape
|
||||
args = [
|
||||
tsv_to_bin_path,
|
||||
"float",
|
||||
vectors_as_tsv_path,
|
||||
vectors_as_bin_path,
|
||||
str(dimensions),
|
||||
str(number_of_points)
|
||||
]
|
||||
completed = subprocess.run(args, timeout=timeout)
|
||||
if completed.returncode != 0:
|
||||
raise Exception(f"Unable to convert tsv to binary using tsv_to_bin, completed_process: {completed}")
|
||||
return vectors_as_bin_path
|
||||
|
||||
|
||||
def build_ssd_index(
|
||||
diskann_build_path: str,
|
||||
temporary_file_path: str,
|
||||
vectors: np.ndarray,
|
||||
per_process_timeout: int = 60 # this may not be long enough if you're doing something larger
|
||||
):
|
||||
vectors_as_bin_path = output_vectors(diskann_build_path, temporary_file_path, vectors, timeout=per_process_timeout)
|
||||
|
||||
ssd_builder_path = os.path.join(diskann_build_path, "apps", "build_disk_index")
|
||||
args = [
|
||||
ssd_builder_path,
|
||||
"--data_type", "float",
|
||||
"--dist_fn", "l2",
|
||||
"--data_path", vectors_as_bin_path,
|
||||
"--index_path_prefix", os.path.join(temporary_file_path, "smoke_test"),
|
||||
"-R", "64",
|
||||
"-L", "100",
|
||||
"--search_DRAM_budget", "1",
|
||||
"--build_DRAM_budget", "1",
|
||||
"--num_threads", "1",
|
||||
"--PQ_disk_bytes", "0"
|
||||
]
|
||||
completed = subprocess.run(args, timeout=per_process_timeout)
|
||||
|
||||
if completed.returncode != 0:
|
||||
command_run = " ".join(args)
|
||||
raise Exception(f"Unable to build a disk index with the command: '{command_run}'\ncompleted_process: {completed}\nstdout: {completed.stdout}\nstderr: {completed.stderr}")
|
||||
# index is now built inside of temporary_file_path
|
||||
@@ -1,379 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "index.h"
|
||||
#include "disk_utils.h"
|
||||
#include "math_utils.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "pq_flash_index.h"
|
||||
#include "partition.h"
|
||||
#include "timer.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include "linux_aligned_file_reader.h"
|
||||
#else
|
||||
#ifdef USE_BING_INFRA
|
||||
#include "bing_aligned_file_reader.h"
|
||||
#else
|
||||
#include "windows_aligned_file_reader.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
#define WARMUP false
|
||||
|
||||
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
|
||||
{
|
||||
diskann::cout << std::setw(20) << category << ": " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(8) << percentiles[s] << "%";
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << std::setw(22) << " " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(9) << results[s];
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &query_file,
|
||||
std::string >_file, const uint32_t num_threads, const float search_range,
|
||||
const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const std::vector<uint32_t> &Lvec)
|
||||
{
|
||||
std::string pq_prefix = index_path_prefix + "_pq";
|
||||
std::string disk_index_file = index_path_prefix + "_disk.index";
|
||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
||||
|
||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
||||
if (beamwidth <= 0)
|
||||
diskann::cout << "beamwidth to be optimized for each L value" << std::endl;
|
||||
else
|
||||
diskann::cout << " beamwidth: " << beamwidth << std::endl;
|
||||
|
||||
// load query bin
|
||||
T *query = nullptr;
|
||||
std::vector<std::vector<uint32_t>> groundtruth_ids;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (gt_file != std::string("null") && file_exists(gt_file))
|
||||
{
|
||||
diskann::load_range_truthset(gt_file, groundtruth_ids,
|
||||
gt_num); // use for range search type of truthset
|
||||
// diskann::prune_truthset_for_range(gt_file, search_range,
|
||||
// groundtruth_ids, gt_num); // use for traditional truthset
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
|
||||
std::shared_ptr<AlignedFileReader> reader = nullptr;
|
||||
#ifdef _WINDOWS
|
||||
#ifndef USE_BING_INFRA
|
||||
reader.reset(new WindowsAlignedFileReader());
|
||||
#else
|
||||
reader.reset(new diskann::BingAlignedFileReader());
|
||||
#endif
|
||||
#else
|
||||
reader.reset(new LinuxAlignedFileReader());
|
||||
#endif
|
||||
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
||||
|
||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
||||
|
||||
if (res != 0)
|
||||
{
|
||||
return res;
|
||||
}
|
||||
// cache bfs levels
|
||||
std::vector<uint32_t> node_list;
|
||||
diskann::cout << "Caching " << num_nodes_to_cache << " BFS nodes around medoid(s)" << std::endl;
|
||||
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
|
||||
// _pFlashIndex->generate_cache_list_from_sample_queries(
|
||||
// warmup_query_file, 15, 6, num_nodes_to_cache, num_threads,
|
||||
// node_list);
|
||||
_pFlashIndex->load_cache_list(node_list);
|
||||
node_list.clear();
|
||||
node_list.shrink_to_fit();
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
|
||||
uint64_t warmup_L = 20;
|
||||
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
|
||||
T *warmup = nullptr;
|
||||
|
||||
if (WARMUP)
|
||||
{
|
||||
if (file_exists(warmup_query_file))
|
||||
{
|
||||
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
|
||||
}
|
||||
else
|
||||
{
|
||||
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
|
||||
warmup_dim = query_dim;
|
||||
warmup_aligned_dim = query_aligned_dim;
|
||||
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
|
||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(-128, 127);
|
||||
for (uint32_t i = 0; i < warmup_num; i++)
|
||||
{
|
||||
for (uint32_t d = 0; d < warmup_dim; d++)
|
||||
{
|
||||
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
|
||||
}
|
||||
}
|
||||
}
|
||||
diskann::cout << "Warming up index... " << std::flush;
|
||||
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
|
||||
std::vector<float> warmup_result_dists(warmup_num, 0);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
|
||||
{
|
||||
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
|
||||
warmup_result_ids_64.data() + (i * 1),
|
||||
warmup_result_dists.data() + (i * 1), 4);
|
||||
}
|
||||
diskann::cout << "..done" << std::endl;
|
||||
}
|
||||
|
||||
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
diskann::cout.precision(2);
|
||||
|
||||
std::string recall_string = "Recall@rng=" + std::to_string(search_range);
|
||||
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
|
||||
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
|
||||
<< "CPU (s)";
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall_string << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << "==============================================================="
|
||||
"==========================================="
|
||||
<< std::endl;
|
||||
|
||||
std::vector<std::vector<std::vector<uint32_t>>> query_result_ids(Lvec.size());
|
||||
|
||||
uint32_t optimized_beamwidth = 2;
|
||||
uint32_t max_list_size = 10000;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
|
||||
if (beamwidth <= 0)
|
||||
{
|
||||
optimized_beamwidth =
|
||||
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
|
||||
}
|
||||
else
|
||||
optimized_beamwidth = beamwidth;
|
||||
|
||||
query_result_ids[test_id].clear();
|
||||
query_result_ids[test_id].resize(query_num);
|
||||
|
||||
diskann::QueryStats *stats = new diskann::QueryStats[query_num];
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
std::vector<uint64_t> indices;
|
||||
std::vector<float> distances;
|
||||
uint32_t res_count =
|
||||
_pFlashIndex->range_search(query + (i * query_aligned_dim), search_range, L, max_list_size, indices,
|
||||
distances, optimized_beamwidth, stats + i);
|
||||
query_result_ids[test_id][i].reserve(res_count);
|
||||
query_result_ids[test_id][i].resize(res_count);
|
||||
for (uint32_t idx = 0; idx < res_count; idx++)
|
||||
query_result_ids[test_id][i][idx] = (uint32_t)indices[idx];
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e - s;
|
||||
auto qps = (1.0 * query_num) / (1.0 * diff.count());
|
||||
|
||||
auto mean_latency = diskann::get_mean_stats<float>(
|
||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto latency_999 = diskann::get_percentile_stats<float>(
|
||||
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.n_ios; });
|
||||
|
||||
double mean_cpuus = diskann::get_mean_stats<float>(
|
||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; });
|
||||
|
||||
double recall = 0;
|
||||
double ratio_of_sums = 0;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recall =
|
||||
diskann::calculate_range_search_recall((uint32_t)query_num, groundtruth_ids, query_result_ids[test_id]);
|
||||
|
||||
uint32_t total_true_positive = 0;
|
||||
uint32_t total_positive = 0;
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
total_true_positive += (uint32_t)query_result_ids[test_id][i].size();
|
||||
total_positive += (uint32_t)groundtruth_ids[i].size();
|
||||
}
|
||||
|
||||
ratio_of_sums = (1.0 * total_true_positive) / (1.0 * total_positive);
|
||||
}
|
||||
|
||||
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
|
||||
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
|
||||
<< std::setw(16) << mean_cpuus;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall << "," << ratio_of_sums << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
diskann::cout << "Done searching. " << std::endl;
|
||||
|
||||
diskann::aligned_free(query);
|
||||
if (warmup != nullptr)
|
||||
diskann::aligned_free(warmup);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file;
|
||||
uint32_t num_threads, W, num_nodes_to_cache;
|
||||
std::vector<uint32_t> Lvec;
|
||||
float range;
|
||||
|
||||
po::options_description desc{program_options_utils::make_program_description(
|
||||
"range_search_disk_index", "Searches disk DiskANN indexes using ranges")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
required_configs.add_options()("range_threshold,K", po::value<float>(&range)->required(),
|
||||
"Number of neighbors to be returned");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
|
||||
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
|
||||
program_options_utils::BEAMWIDTH);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
||||
"Product/Cosine are supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
|
||||
{
|
||||
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
||||
num_nodes_to_cache, Lvec);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
||||
num_nodes_to_cache, Lvec);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t>(metric, index_path_prefix, query_file, gt_file, num_threads, range, W,
|
||||
num_nodes_to_cache, Lvec);
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
|
||||
add_executable(inmem_server inmem_server.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(inmem_server PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(inmem_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
|
||||
target_link_libraries(inmem_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
|
||||
else()
|
||||
target_link_libraries(inmem_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
|
||||
endif()
|
||||
|
||||
add_executable(ssd_server ssd_server.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(ssd_server PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(ssd_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
|
||||
target_link_libraries(ssd_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
|
||||
else()
|
||||
target_link_libraries(ssd_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
|
||||
endif()
|
||||
|
||||
add_executable(multiple_ssdindex_server multiple_ssdindex_server.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(multiple_ssdindex_server PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(multiple_ssdindex_server debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
|
||||
target_link_libraries(multiple_ssdindex_server optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
|
||||
else()
|
||||
target_link_libraries(multiple_ssdindex_server ${PROJECT_NAME} aio -ltcmalloc -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
|
||||
endif()
|
||||
|
||||
add_executable(client client.cpp)
|
||||
if(MSVC)
|
||||
target_link_options(client PRIVATE /MACHINE:x64)
|
||||
target_link_libraries(client debug ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}/diskann_dll.lib Boost::program_options)
|
||||
target_link_libraries(client optimized ${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}/diskann_dll.lib Boost::program_options)
|
||||
else()
|
||||
target_link_libraries(client ${PROJECT_NAME} -lboost_system -lcrypto -lssl -lcpprest Boost::program_options)
|
||||
endif()
|
||||
@@ -1,124 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include <cpprest/http_client.h>
|
||||
#include <restapi/common.h>
|
||||
|
||||
using namespace web;
|
||||
using namespace web::http;
|
||||
using namespace web::http::client;
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T>
|
||||
void query_loop(const std::string &ip_addr_port, const std::string &query_file, const unsigned nq, const unsigned Ls,
|
||||
const unsigned k_value)
|
||||
{
|
||||
web::http::client::http_client client(U(ip_addr_port));
|
||||
|
||||
T *data;
|
||||
size_t npts = 1, ndims = 128, rounded_dim = 128;
|
||||
diskann::load_aligned_bin<T>(query_file, data, npts, ndims, rounded_dim);
|
||||
|
||||
for (unsigned i = 0; i < nq; ++i)
|
||||
{
|
||||
T *vec = data + i * rounded_dim;
|
||||
web::http::http_request http_query(methods::POST);
|
||||
web::json::value queryJson = web::json::value::object();
|
||||
queryJson[QUERY_ID_KEY] = i;
|
||||
queryJson[K_KEY] = k_value;
|
||||
queryJson[L_KEY] = Ls;
|
||||
for (size_t i = 0; i < ndims; ++i)
|
||||
{
|
||||
queryJson[VECTOR_KEY][i] = web::json::value::number(vec[i]);
|
||||
}
|
||||
http_query.set_body(queryJson);
|
||||
|
||||
client.request(http_query)
|
||||
.then([](web::http::http_response response) -> pplx::task<utility::string_t> {
|
||||
if (response.status_code() == status_codes::OK)
|
||||
{
|
||||
return response.extract_string();
|
||||
}
|
||||
std::cerr << "Query failed" << std::endl;
|
||||
return pplx::task_from_result(utility::string_t());
|
||||
})
|
||||
.then([](pplx::task<utility::string_t> previousTask) {
|
||||
try
|
||||
{
|
||||
std::cout << previousTask.get() << std::endl;
|
||||
}
|
||||
catch (http_exception const &e)
|
||||
{
|
||||
std::wcout << e.what() << std::endl;
|
||||
}
|
||||
})
|
||||
.wait();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, query_file, address;
|
||||
uint32_t num_queries;
|
||||
uint32_t l_search, k_value;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the queries to search");
|
||||
desc.add_options()("num_queries,Q", po::value<uint32_t>(&num_queries)->required(),
|
||||
"Number of queries to search");
|
||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
|
||||
desc.add_options()("k_value,K", po::value<uint32_t>(&k_value)->default_value(10), "Value of K (default 10)");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
query_loop<float>(address, query_file, num_queries, l_search, k_value);
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
query_loop<int8_t>(address, query_file, num_queries, l_search, k_value);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
query_loop<uint8_t>(address, query_file, num_queries, l_search, k_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported type " << argv[2] << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_inMemorySearch;
|
||||
|
||||
void setup(const utility::string_t &address, const std::string &typestring)
|
||||
{
|
||||
web::http::uri_builder uriBldr(address);
|
||||
auto uri = uriBldr.to_uri();
|
||||
|
||||
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
|
||||
|
||||
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch, typestring));
|
||||
std::cout << "Created a server object" << std::endl;
|
||||
|
||||
g_httpServer->open().wait();
|
||||
ucout << U"Listening for requests on: " << address << std::endl;
|
||||
}
|
||||
|
||||
void teardown(const utility::string_t &address)
|
||||
{
|
||||
g_httpServer->close().wait();
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
|
||||
uint32_t num_threads;
|
||||
uint32_t l_search;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("data_file", po::value<std::string>(&data_file)->required(),
|
||||
"File containing the data found in the index");
|
||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_file)->required(),
|
||||
"Path prefix for saving index file components");
|
||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->required(),
|
||||
"Number of threads used for building index");
|
||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(), "Value of L");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<float>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
}
|
||||
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <restapi/server.h>
|
||||
#include <restapi/in_memory_search.h>
|
||||
#include <codecvt>
|
||||
#include <iostream>
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::unique_ptr<diskann::InMemorySearch> g_inMemorySearch(nullptr);
|
||||
|
||||
void setup(const utility::string_t &address)
|
||||
{
|
||||
web::http::uri_builder uriBldr(address);
|
||||
auto uri = uriBldr.to_uri();
|
||||
|
||||
std::wcout << L"Attempting to start server on " << uri.to_string() << std::endl;
|
||||
|
||||
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_inMemorySearch));
|
||||
g_httpServer->open().wait();
|
||||
|
||||
ucout << U"Listening for requests on: " << address << std::endl;
|
||||
}
|
||||
|
||||
void teardown(const utility::string_t &address)
|
||||
{
|
||||
g_httpServer->close().wait();
|
||||
}
|
||||
|
||||
void loadIndex(const char *indexFile, const char *baseFile, const char *idsFile)
|
||||
{
|
||||
auto nsgSearch = new diskann::InMemorySearch(baseFile, indexFile, idsFile, diskann::L2);
|
||||
g_inMemorySearch = std::unique_ptr<diskann::InMemorySearch>(nsgSearch);
|
||||
}
|
||||
|
||||
std::wstring getHostingAddress(const char *hostNameAndPort)
|
||||
{
|
||||
wchar_t buffer[4096];
|
||||
mbstowcs_s(nullptr, buffer, sizeof(buffer) / sizeof(buffer[0]), hostNameAndPort,
|
||||
sizeof(buffer) / sizeof(buffer[0]));
|
||||
return std::wstring(buffer);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << "Usage: nsg_server <ip_addr_and_port> <index_file> "
|
||||
"<base_file> <ids_file> "
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto address = getHostingAddress(argv[1]);
|
||||
loadIndex(argv[2], argv[3], argv[4]);
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <omp.h>
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
||||
|
||||
void setup(const utility::string_t &address, const std::string &typestring)
|
||||
{
|
||||
web::http::uri_builder uriBldr(address);
|
||||
auto uri = uriBldr.to_uri();
|
||||
|
||||
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
|
||||
|
||||
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
|
||||
std::cout << "Created a server object" << std::endl;
|
||||
|
||||
g_httpServer->open().wait();
|
||||
ucout << U"Listening for requests on: " << address << std::endl;
|
||||
}
|
||||
|
||||
void teardown(const utility::string_t &address)
|
||||
{
|
||||
g_httpServer->close().wait();
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("index_prefix_paths", po::value<std::string>(&index_prefix_paths)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
"Number of nodes to cache during search");
|
||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> index_tag_paths;
|
||||
std::ifstream index_in(index_prefix_paths);
|
||||
if (!index_in.is_open())
|
||||
{
|
||||
std::cerr << "Could not open " << index_prefix_paths << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream tags_in(tags_file);
|
||||
if (!tags_in.is_open())
|
||||
{
|
||||
std::cerr << "Could not open " << tags_file << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::string prefix, tagfile;
|
||||
while (std::getline(index_in, prefix))
|
||||
{
|
||||
if (std::getline(tags_in, tagfile))
|
||||
{
|
||||
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "The number of tags specified does not match the number of "
|
||||
"indices specified"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
index_in.close();
|
||||
tags_in.close();
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
for (auto &index_tag : index_tag_paths)
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<float>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
for (auto &index_tag : index_tag_paths)
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
for (auto &index_tag : index_tag_paths)
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<uint8_t>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads, index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type " << data_type << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <ctime>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <omp.h>
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
namespace po = boost::program_options;
|
||||
|
||||
std::unique_ptr<Server> g_httpServer(nullptr);
|
||||
std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
||||
|
||||
void setup(const utility::string_t &address, const std::string &typestring)
|
||||
{
|
||||
web::http::uri_builder uriBldr(address);
|
||||
auto uri = uriBldr.to_uri();
|
||||
|
||||
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
|
||||
|
||||
g_httpServer = std::unique_ptr<Server>(new Server(uri, g_ssdSearch, typestring));
|
||||
std::cout << "Created a server object" << std::endl;
|
||||
|
||||
g_httpServer->open().wait();
|
||||
ucout << U"Listening for requests on: " << address << std::endl;
|
||||
}
|
||||
|
||||
void teardown(const utility::string_t &address)
|
||||
{
|
||||
g_httpServer->close().wait();
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(), "Web server address");
|
||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
"Number of nodes to cache during search");
|
||||
desc.add_options()("num_threads,T", po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else
|
||||
{
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<int8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<uint8_t>(index_path_prefix, num_nodes_to_cache, num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
while (1)
|
||||
{
|
||||
try
|
||||
{
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit")
|
||||
{
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,499 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "common_includes.h"
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "index.h"
|
||||
#include "disk_utils.h"
|
||||
#include "math_utils.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "partition.h"
|
||||
#include "pq_flash_index.h"
|
||||
#include "timer.h"
|
||||
#include "percentile_stats.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#include "linux_aligned_file_reader.h"
|
||||
#else
|
||||
#ifdef USE_BING_INFRA
|
||||
#include "bing_aligned_file_reader.h"
|
||||
#else
|
||||
#include "windows_aligned_file_reader.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define WARMUP false
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
void print_stats(std::string category, std::vector<float> percentiles, std::vector<float> results)
|
||||
{
|
||||
diskann::cout << std::setw(20) << category << ": " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(8) << percentiles[s] << "%";
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << std::setw(22) << " " << std::flush;
|
||||
for (uint32_t s = 0; s < percentiles.size(); s++)
|
||||
{
|
||||
diskann::cout << std::setw(9) << results[s];
|
||||
}
|
||||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix,
|
||||
const std::string &result_output_prefix, const std::string &query_file, std::string >_file,
|
||||
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
|
||||
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
|
||||
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
|
||||
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
|
||||
{
|
||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
||||
if (beamwidth <= 0)
|
||||
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
|
||||
else
|
||||
diskann::cout << " beamwidth: " << beamwidth << std::flush;
|
||||
if (search_io_limit == std::numeric_limits<uint32_t>::max())
|
||||
diskann::cout << "." << std::endl;
|
||||
else
|
||||
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
|
||||
|
||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
||||
|
||||
// load query bin
|
||||
T *query = nullptr;
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool filtered_search = false;
|
||||
if (!query_filters.empty())
|
||||
{
|
||||
filtered_search = true;
|
||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
||||
"filters file"
|
||||
<< std::endl;
|
||||
return -1; // To return -1 or some other error handling?
|
||||
}
|
||||
}
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file))
|
||||
{
|
||||
diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
|
||||
std::shared_ptr<AlignedFileReader> reader = nullptr;
|
||||
#ifdef _WINDOWS
|
||||
#ifndef USE_BING_INFRA
|
||||
reader.reset(new WindowsAlignedFileReader());
|
||||
#else
|
||||
reader.reset(new diskann::BingAlignedFileReader());
|
||||
#endif
|
||||
#else
|
||||
reader.reset(new LinuxAlignedFileReader());
|
||||
#endif
|
||||
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
||||
|
||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
||||
|
||||
if (res != 0)
|
||||
{
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> node_list;
|
||||
diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl;
|
||||
_pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list);
|
||||
// if (num_nodes_to_cache > 0)
|
||||
// _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
|
||||
// num_threads, node_list);
|
||||
_pFlashIndex->load_cache_list(node_list);
|
||||
node_list.clear();
|
||||
node_list.shrink_to_fit();
|
||||
|
||||
omp_set_num_threads(num_threads);
|
||||
|
||||
uint64_t warmup_L = 20;
|
||||
uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0;
|
||||
T *warmup = nullptr;
|
||||
|
||||
if (WARMUP)
|
||||
{
|
||||
if (file_exists(warmup_query_file))
|
||||
{
|
||||
diskann::load_aligned_bin<T>(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim);
|
||||
}
|
||||
else
|
||||
{
|
||||
warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads);
|
||||
warmup_dim = query_dim;
|
||||
warmup_aligned_dim = query_aligned_dim;
|
||||
diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T));
|
||||
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(-128, 127);
|
||||
for (uint32_t i = 0; i < warmup_num; i++)
|
||||
{
|
||||
for (uint32_t d = 0; d < warmup_dim; d++)
|
||||
{
|
||||
warmup[i * warmup_aligned_dim + d] = (T)dis(gen);
|
||||
}
|
||||
}
|
||||
}
|
||||
diskann::cout << "Warming up index... " << std::flush;
|
||||
std::vector<uint64_t> warmup_result_ids_64(warmup_num, 0);
|
||||
std::vector<float> warmup_result_dists(warmup_num, 0);
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)warmup_num; i++)
|
||||
{
|
||||
_pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L,
|
||||
warmup_result_ids_64.data() + (i * 1),
|
||||
warmup_result_dists.data() + (i * 1), 4);
|
||||
}
|
||||
diskann::cout << "..done" << std::endl;
|
||||
}
|
||||
|
||||
diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
diskann::cout.precision(2);
|
||||
|
||||
std::string recall_string = "Recall@" + std::to_string(recall_at);
|
||||
diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16)
|
||||
<< "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16)
|
||||
<< "Mean IO (us)" << std::setw(16) << "CPU (s)";
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall_string << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
diskann::cout << "=================================================================="
|
||||
"================================================================="
|
||||
<< std::endl;
|
||||
|
||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
||||
|
||||
uint32_t optimized_beamwidth = 2;
|
||||
|
||||
double best_recall = 0.0;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (beamwidth <= 0)
|
||||
{
|
||||
diskann::cout << "Tuning beamwidth.." << std::endl;
|
||||
optimized_beamwidth =
|
||||
optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth);
|
||||
}
|
||||
else
|
||||
optimized_beamwidth = beamwidth;
|
||||
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
query_result_dists[test_id].resize(recall_at * query_num);
|
||||
|
||||
auto stats = new diskann::QueryStats[query_num];
|
||||
|
||||
std::vector<uint64_t> query_result_ids_64(recall_at * query_num);
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
if (!filtered_search)
|
||||
{
|
||||
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
|
||||
query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at),
|
||||
optimized_beamwidth, use_reorder_data, stats + i);
|
||||
}
|
||||
else
|
||||
{
|
||||
LabelT label_for_search;
|
||||
if (query_filters.size() == 1)
|
||||
{ // one label for all queries
|
||||
label_for_search = _pFlashIndex->get_converted_label(query_filters[0]);
|
||||
}
|
||||
else
|
||||
{ // one label for each query
|
||||
label_for_search = _pFlashIndex->get_converted_label(query_filters[i]);
|
||||
}
|
||||
_pFlashIndex->cached_beam_search(
|
||||
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
|
||||
use_reorder_data, stats + i);
|
||||
}
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e - s;
|
||||
double qps = (1.0 * query_num) / (1.0 * diff.count());
|
||||
|
||||
diskann::convert_types<uint64_t, uint32_t>(query_result_ids_64.data(), query_result_ids[test_id].data(),
|
||||
query_num, recall_at);
|
||||
|
||||
auto mean_latency = diskann::get_mean_stats<float>(
|
||||
stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto latency_999 = diskann::get_percentile_stats<float>(
|
||||
stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; });
|
||||
|
||||
auto mean_ios = diskann::get_mean_stats<uint32_t>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.n_ios; });
|
||||
|
||||
auto mean_cpuus = diskann::get_mean_stats<float>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.cpu_us; });
|
||||
|
||||
auto mean_io_us = diskann::get_mean_stats<float>(stats, query_num,
|
||||
[](const diskann::QueryStats &stats) { return stats.io_us; });
|
||||
|
||||
double recall = 0;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
||||
query_result_ids[test_id].data(), recall_at, recall_at);
|
||||
best_recall = std::max(recall, best_recall);
|
||||
}
|
||||
|
||||
diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps
|
||||
<< std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios
|
||||
<< std::setw(16) << mean_io_us << std::setw(16) << mean_cpuus;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
diskann::cout << std::setw(16) << recall << std::endl;
|
||||
}
|
||||
else
|
||||
diskann::cout << std::endl;
|
||||
delete[] stats;
|
||||
}
|
||||
|
||||
diskann::cout << "Done searching. Now saving results " << std::endl;
|
||||
uint64_t test_id = 0;
|
||||
for (auto L : Lvec)
|
||||
{
|
||||
if (L < recall_at)
|
||||
continue;
|
||||
|
||||
std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin";
|
||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
||||
|
||||
cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin";
|
||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at);
|
||||
}
|
||||
|
||||
diskann::aligned_free(query);
|
||||
if (warmup != nullptr)
|
||||
diskann::aligned_free(warmup);
|
||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
|
||||
label_type, query_filters_file;
|
||||
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
|
||||
std::vector<uint32_t> Lvec;
|
||||
bool use_reorder_data = false;
|
||||
float fail_if_recall_below = 0.0f;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path_prefix)->required(),
|
||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("beamwidth,W", po::value<uint32_t>(&W)->default_value(2),
|
||||
program_options_utils::BEAMWIDTH);
|
||||
optional_configs.add_options()("num_nodes_to_cache", po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
program_options_utils::NUMBER_OF_NODES_TO_CACHE);
|
||||
optional_configs.add_options()(
|
||||
"search_io_limit",
|
||||
po::value<uint32_t>(&search_io_limit)->default_value(std::numeric_limits<uint32_t>::max()),
|
||||
"Max #IOs for search. Default value: uint32::max()");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false),
|
||||
"Include full precision data in the index. Use only in "
|
||||
"conjuction with compressed data on SSD. Default value: false");
|
||||
optional_configs.add_options()("filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
||||
optional_configs.add_options()("query_filters_file",
|
||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
optional_configs.add_options()("fail_if_recall_below",
|
||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
||||
program_options_utils::FAIL_IF_RECALL_BELOW);
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
if (vm["use_reorder_data"].as<bool>())
|
||||
use_reorder_data = true;
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only L2/ Inner "
|
||||
"Product/Cosine are supported."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT))
|
||||
{
|
||||
std::cout << "Currently support only floating point data for Inner Product." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (use_reorder_data && data_type != std::string("float"))
|
||||
{
|
||||
std::cout << "Error: Reorder data for reordering currently only "
|
||||
"supported for float data type."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && query_filters_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> query_filters;
|
||||
if (filter_label != "")
|
||||
{
|
||||
query_filters.push_back(filter_label);
|
||||
}
|
||||
else if (query_filters_file != "")
|
||||
{
|
||||
query_filters = read_file_to_vector_of_strings(query_filters_file);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (!query_filters.empty() && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
|
||||
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, query_filters, use_reorder_data);
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,477 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <string.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
#include "index_factory.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T, typename LabelT = uint32_t>
|
||||
int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix,
|
||||
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
|
||||
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
|
||||
const bool dynamic, const bool tags, const bool show_qps_per_thread,
|
||||
const std::vector<std::string> &query_filters, const float fail_if_recall_below)
|
||||
{
|
||||
using TagT = uint32_t;
|
||||
// Load the query file
|
||||
T *query = nullptr;
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim;
|
||||
diskann::load_aligned_bin<T>(query_file, query, query_num, query_dim, query_aligned_dim);
|
||||
|
||||
bool calc_recall_flag = false;
|
||||
if (truthset_file != std::string("null") && file_exists(truthset_file))
|
||||
{
|
||||
diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
if (gt_num != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl;
|
||||
}
|
||||
calc_recall_flag = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl;
|
||||
}
|
||||
|
||||
bool filtered_search = false;
|
||||
if (!query_filters.empty())
|
||||
{
|
||||
filtered_search = true;
|
||||
if (query_filters.size() != 1 && query_filters.size() != query_num)
|
||||
{
|
||||
std::cout << "Error. Mismatch in number of queries and size of query "
|
||||
"filters file"
|
||||
<< std::endl;
|
||||
return -1; // To return -1 or some other error handling?
|
||||
}
|
||||
}
|
||||
|
||||
const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path);
|
||||
|
||||
auto config = diskann::IndexConfigBuilder()
|
||||
.with_metric(metric)
|
||||
.with_dimension(query_dim)
|
||||
.with_max_points(0)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.is_dynamic_index(dynamic)
|
||||
.is_enable_tags(tags)
|
||||
.is_concurrent_consolidate(false)
|
||||
.is_pq_dist_build(false)
|
||||
.is_use_opq(false)
|
||||
.with_num_pq_chunks(0)
|
||||
.with_num_frozen_pts(num_frozen_pts)
|
||||
.build();
|
||||
|
||||
auto index_factory = diskann::IndexFactory(config);
|
||||
auto index = index_factory.create_instance();
|
||||
index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end())));
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
|
||||
if (metric == diskann::FAST_L2)
|
||||
index->optimize_index_layout();
|
||||
|
||||
std::cout << "Using " << num_threads << " threads to search" << std::endl;
|
||||
std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield);
|
||||
std::cout.precision(2);
|
||||
const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS";
|
||||
uint32_t table_width = 0;
|
||||
if (tags)
|
||||
{
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)"
|
||||
<< std::setw(15) << "99.9 Latency";
|
||||
table_width += 4 + 12 + 20 + 15;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps"
|
||||
<< std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency";
|
||||
table_width += 4 + 12 + 18 + 20 + 15;
|
||||
}
|
||||
uint32_t recalls_to_print = 0;
|
||||
const uint32_t first_recall = print_all_recalls ? 1 : recall_at;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
||||
{
|
||||
std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall));
|
||||
}
|
||||
recalls_to_print = recall_at + 1 - first_recall;
|
||||
table_width += recalls_to_print * 12;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
std::cout << std::string(table_width, '=') << std::endl;
|
||||
|
||||
std::vector<std::vector<uint32_t>> query_result_ids(Lvec.size());
|
||||
std::vector<std::vector<float>> query_result_dists(Lvec.size());
|
||||
std::vector<float> latency_stats(query_num, 0);
|
||||
std::vector<uint32_t> cmp_stats;
|
||||
if (not tags || filtered_search)
|
||||
{
|
||||
cmp_stats = std::vector<uint32_t>(query_num, 0);
|
||||
}
|
||||
|
||||
std::vector<TagT> query_result_tags;
|
||||
if (tags)
|
||||
{
|
||||
query_result_tags.resize(recall_at * query_num);
|
||||
}
|
||||
|
||||
double best_recall = 0.0;
|
||||
|
||||
for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++)
|
||||
{
|
||||
uint32_t L = Lvec[test_id];
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
query_result_dists[test_id].resize(recall_at * query_num);
|
||||
std::vector<T *> res = std::vector<T *>();
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
omp_set_num_threads(num_threads);
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t)query_num; i++)
|
||||
{
|
||||
auto qs = std::chrono::high_resolution_clock::now();
|
||||
if (filtered_search && !tags)
|
||||
{
|
||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
||||
|
||||
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at,
|
||||
query_result_dists[test_id].data() + i * recall_at);
|
||||
cmp_stats[i] = retval.second;
|
||||
}
|
||||
else if (metric == diskann::FAST_L2)
|
||||
{
|
||||
index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at);
|
||||
}
|
||||
else if (tags)
|
||||
{
|
||||
if (!filtered_search)
|
||||
{
|
||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_tags.data() + i * recall_at, nullptr, res);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
|
||||
|
||||
index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter);
|
||||
}
|
||||
|
||||
for (int64_t r = 0; r < (int64_t)recall_at; r++)
|
||||
{
|
||||
query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cmp_stats[i] = index
|
||||
->search(query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at)
|
||||
.second;
|
||||
}
|
||||
auto qe = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = qe - qs;
|
||||
latency_stats[i] = (float)(diff.count() * 1000000);
|
||||
}
|
||||
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
|
||||
|
||||
double displayed_qps = query_num / diff.count();
|
||||
|
||||
if (show_qps_per_thread)
|
||||
displayed_qps /= num_threads;
|
||||
|
||||
std::vector<double> recalls;
|
||||
if (calc_recall_flag)
|
||||
{
|
||||
recalls.reserve(recalls_to_print);
|
||||
for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++)
|
||||
{
|
||||
recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim,
|
||||
query_result_ids[test_id].data(), recall_at, curr_recall));
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(latency_stats.begin(), latency_stats.end());
|
||||
double mean_latency =
|
||||
std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast<float>(query_num);
|
||||
|
||||
float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num;
|
||||
|
||||
if (tags && !filtered_search)
|
||||
{
|
||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency
|
||||
<< std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps
|
||||
<< std::setw(20) << (float)mean_latency << std::setw(15)
|
||||
<< (float)latency_stats[(uint64_t)(0.999 * query_num)];
|
||||
}
|
||||
for (double recall : recalls)
|
||||
{
|
||||
std::cout << std::setw(12) << recall;
|
||||
best_recall = std::max(recall, best_recall);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Done searching. Now saving results " << std::endl;
|
||||
uint64_t test_id = 0;
|
||||
for (auto L : Lvec)
|
||||
{
|
||||
if (L < recall_at)
|
||||
{
|
||||
diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl;
|
||||
continue;
|
||||
}
|
||||
std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L);
|
||||
|
||||
std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin";
|
||||
diskann::save_bin<uint32_t>(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at);
|
||||
|
||||
cur_result_path = cur_result_path_prefix + "_dists_float.bin";
|
||||
diskann::save_bin<float>(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at);
|
||||
|
||||
test_id++;
|
||||
}
|
||||
|
||||
diskann::aligned_free(query);
|
||||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
|
||||
query_filters_file;
|
||||
uint32_t num_threads, K;
|
||||
std::vector<uint32_t> Lvec;
|
||||
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
|
||||
float fail_if_recall_below = 0.0f;
|
||||
|
||||
po::options_description desc{
|
||||
program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print this information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("result_path", po::value<std::string>(&result_path)->required(),
|
||||
program_options_utils::RESULT_PATH_DESCRIPTION);
|
||||
required_configs.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
program_options_utils::QUERY_FILE_DESCRIPTION);
|
||||
required_configs.add_options()("recall_at,K", po::value<uint32_t>(&K)->required(),
|
||||
program_options_utils::NUMBER_OF_RESULTS_DESCRIPTION);
|
||||
required_configs.add_options()("search_list,L",
|
||||
po::value<std::vector<uint32_t>>(&Lvec)->multitoken()->required(),
|
||||
program_options_utils::SEARCH_LIST_DESCRIPTION);
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
program_options_utils::FILTER_LABEL_DESCRIPTION);
|
||||
optional_configs.add_options()("query_filters_file",
|
||||
po::value<std::string>(&query_filters_file)->default_value(std::string("")),
|
||||
program_options_utils::FILTERS_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
program_options_utils::LABEL_TYPE_DESCRIPTION);
|
||||
optional_configs.add_options()("gt_file", po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION);
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()(
|
||||
"dynamic", po::value<bool>(&dynamic)->default_value(false),
|
||||
"Whether the index is dynamic. Dynamic indices must have associated tags. Default false.");
|
||||
optional_configs.add_options()("tags", po::value<bool>(&tags)->default_value(false),
|
||||
"Whether to search with external identifiers (tags). Default false.");
|
||||
optional_configs.add_options()("fail_if_recall_below",
|
||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
||||
program_options_utils::FAIL_IF_RECALL_BELOW);
|
||||
|
||||
// Output controls
|
||||
po::options_description output_controls("Output controls");
|
||||
output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls),
|
||||
"Print recalls at all positions, from 1 up to specified "
|
||||
"recall_at value");
|
||||
output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread),
|
||||
"Print overall QPS divided by the number of threads in "
|
||||
"the output table");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs).add(output_controls);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if ((dist_fn == std::string("mips")) && (data_type == std::string("float")))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float")))
|
||||
{
|
||||
metric = diskann::Metric::FAST_L2;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported distance function. Currently only l2/ cosine are "
|
||||
"supported in general, and mips/fast_l2 only for floating "
|
||||
"point data."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (dynamic && not tags)
|
||||
{
|
||||
std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0)
|
||||
{
|
||||
std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && query_filters_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> query_filters;
|
||||
if (filter_label != "")
|
||||
{
|
||||
query_filters.push_back(filter_label);
|
||||
}
|
||||
else if (query_filters_file != "")
|
||||
{
|
||||
query_filters = read_file_to_vector_of_strings(query_filters_file);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (!query_filters.empty() && label_type == "ushort")
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
{
|
||||
return search_memory_index<int8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
|
||||
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
return search_memory_index<uint8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls,
|
||||
Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("float"))
|
||||
{
|
||||
return search_memory_index<float, uint16_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
{
|
||||
return search_memory_index<int8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
return search_memory_index<uint8_t>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else if (data_type == std::string("float"))
|
||||
{
|
||||
return search_memory_index<float>(metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, query_filters, fail_if_recall_below);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,536 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <timer.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <future>
|
||||
|
||||
#include "utils.h"
|
||||
#include "filter_utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
#include "index_factory.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// load_aligned_bin modified to read pieces of the file, but using ifstream
|
||||
// instead of cached_ifstream.
|
||||
template <typename T>
|
||||
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
|
||||
{
|
||||
diskann::Timer timer;
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(bin_file, std::ios::binary | std::ios::ate);
|
||||
size_t actual_file_size = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
size_t npts = (uint32_t)npts_i32;
|
||||
size_t dim = (uint32_t)dim_i32;
|
||||
|
||||
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != expected_actual_file_size)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
|
||||
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
|
||||
<< std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
if (offset_points + points_to_read > npts)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
|
||||
<< " points, but have only " << npts << " points" << std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
|
||||
|
||||
const size_t rounded_dim = ROUND_UP(dim, 8);
|
||||
|
||||
for (size_t i = 0; i < points_to_read; i++)
|
||||
{
|
||||
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
|
||||
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
|
||||
}
|
||||
reader.close();
|
||||
|
||||
const double elapsedSeconds = timer.elapsed() / 1000000.0;
|
||||
std::cout << "Read " << points_to_read << " points using non-cached reads in " << elapsedSeconds << std::endl;
|
||||
}
|
||||
|
||||
std::string get_save_filename(const std::string &save_path, size_t points_to_skip, size_t points_deleted,
|
||||
size_t last_point_threshold)
|
||||
{
|
||||
std::string final_path = save_path;
|
||||
if (points_to_skip > 0)
|
||||
{
|
||||
final_path += "skip" + std::to_string(points_to_skip) + "-";
|
||||
}
|
||||
|
||||
final_path += "del" + std::to_string(points_deleted) + "-";
|
||||
final_path += std::to_string(last_point_threshold);
|
||||
return final_path;
|
||||
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void insert_till_next_checkpoint(diskann::AbstractIndex &index, size_t start, size_t end, int32_t thread_count, T *data,
|
||||
size_t aligned_dim, std::vector<std::vector<LabelT>> &location_to_labels)
|
||||
{
|
||||
diskann::Timer insert_timer;
|
||||
#pragma omp parallel for num_threads(thread_count) schedule(dynamic)
|
||||
for (int64_t j = start; j < (int64_t)end; j++)
|
||||
{
|
||||
if (!location_to_labels.empty())
|
||||
{
|
||||
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
|
||||
location_to_labels[j - start]);
|
||||
}
|
||||
else
|
||||
{
|
||||
index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
|
||||
}
|
||||
}
|
||||
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
|
||||
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
|
||||
<< " points/second overall, " << (end - start) / elapsedSeconds / thread_count << " per thread)\n ";
|
||||
}
|
||||
|
||||
template <typename T, typename TagT>
|
||||
void delete_from_beginning(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params,
|
||||
size_t points_to_skip, size_t points_to_delete_from_beginning)
|
||||
{
|
||||
try
|
||||
{
|
||||
std::cout << std::endl
|
||||
<< "Lazy deleting points " << points_to_skip << " to "
|
||||
<< points_to_skip + points_to_delete_from_beginning << "... ";
|
||||
for (size_t i = points_to_skip; i < points_to_skip + points_to_delete_from_beginning; ++i)
|
||||
index.lazy_delete(static_cast<TagT>(i + 1)); // Since tags are data location + 1
|
||||
std::cout << "done." << std::endl;
|
||||
|
||||
auto report = index.consolidate_deletes(delete_params);
|
||||
std::cout << "#active points: " << report._active_points << std::endl
|
||||
<< "max points: " << report._max_points << std::endl
|
||||
<< "empty slots: " << report._empty_slots << std::endl
|
||||
<< "deletes processed: " << report._slots_released << std::endl
|
||||
<< "latest delete size: " << report._delete_set_size << std::endl
|
||||
<< "rate: (" << points_to_delete_from_beginning / report._time << " points/second overall, "
|
||||
<< points_to_delete_from_beginning / report._time / delete_params.num_threads << " per thread)"
|
||||
<< std::endl;
|
||||
}
|
||||
catch (std::system_error &e)
|
||||
{
|
||||
std::cout << "Exception caught in deletion thread: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void build_incremental_index(const std::string &data_path, diskann::IndexWriteParameters ¶ms, size_t points_to_skip,
|
||||
size_t max_points_to_insert, size_t beginning_index_size, float start_point_norm,
|
||||
uint32_t num_start_pts, size_t points_per_checkpoint, size_t checkpoints_per_snapshot,
|
||||
const std::string &save_path, size_t points_to_delete_from_beginning,
|
||||
size_t start_deletes_after, bool concurrent, const std::string &label_file,
|
||||
const std::string &universal_label)
|
||||
{
|
||||
size_t dim, aligned_dim;
|
||||
size_t num_points;
|
||||
diskann::get_bin_metadata(data_path, num_points, dim);
|
||||
aligned_dim = ROUND_UP(dim, 8);
|
||||
bool has_labels = label_file != "";
|
||||
using TagT = uint32_t;
|
||||
using LabelT = uint32_t;
|
||||
|
||||
size_t current_point_offset = points_to_skip;
|
||||
const size_t last_point_threshold = points_to_skip + max_points_to_insert;
|
||||
|
||||
bool enable_tags = true;
|
||||
using TagT = uint32_t;
|
||||
auto index_search_params = diskann::IndexSearchParams(params.search_list_size, params.num_threads);
|
||||
diskann::IndexConfig index_config = diskann::IndexConfigBuilder()
|
||||
.with_metric(diskann::L2)
|
||||
.with_dimension(dim)
|
||||
.with_max_points(max_points_to_insert)
|
||||
.is_dynamic_index(true)
|
||||
.with_index_write_params(params)
|
||||
.with_index_search_params(index_search_params)
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.is_enable_tags(enable_tags)
|
||||
.is_filtered(has_labels)
|
||||
.with_num_frozen_pts(num_start_pts)
|
||||
.is_concurrent_consolidate(concurrent)
|
||||
.build();
|
||||
|
||||
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
|
||||
auto index = index_factory.create_instance();
|
||||
|
||||
if (universal_label != "")
|
||||
{
|
||||
LabelT u_label = 0;
|
||||
index->set_universal_label(u_label);
|
||||
}
|
||||
|
||||
if (points_to_skip > num_points)
|
||||
{
|
||||
throw diskann::ANNException("Asked to skip more points than in data file", -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
if (max_points_to_insert == 0)
|
||||
{
|
||||
max_points_to_insert = num_points;
|
||||
}
|
||||
|
||||
if (points_to_skip + max_points_to_insert > num_points)
|
||||
{
|
||||
max_points_to_insert = num_points - points_to_skip;
|
||||
std::cerr << "WARNING: Reducing max_points_to_insert to " << max_points_to_insert
|
||||
<< " points since the data file has only that many" << std::endl;
|
||||
}
|
||||
|
||||
if (beginning_index_size > max_points_to_insert)
|
||||
{
|
||||
beginning_index_size = max_points_to_insert;
|
||||
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size
|
||||
<< " points since the data file has only that many" << std::endl;
|
||||
}
|
||||
if (checkpoints_per_snapshot > 0 && beginning_index_size > points_per_checkpoint)
|
||||
{
|
||||
beginning_index_size = points_per_checkpoint;
|
||||
std::cerr << "WARNING: Reducing beginning index size to " << beginning_index_size << std::endl;
|
||||
}
|
||||
|
||||
T *data = nullptr;
|
||||
diskann::alloc_aligned(
|
||||
(void **)&data, std::max(points_per_checkpoint, beginning_index_size) * aligned_dim * sizeof(T), 8 * sizeof(T));
|
||||
|
||||
std::vector<TagT> tags(beginning_index_size);
|
||||
std::iota(tags.begin(), tags.end(), 1 + static_cast<TagT>(current_point_offset));
|
||||
|
||||
load_aligned_bin_part(data_path, data, current_point_offset, beginning_index_size);
|
||||
std::cout << "load aligned bin succeeded" << std::endl;
|
||||
diskann::Timer timer;
|
||||
|
||||
if (beginning_index_size > 0)
|
||||
{
|
||||
index->build(data, beginning_index_size, tags);
|
||||
}
|
||||
else
|
||||
{
|
||||
index->set_start_points_at_random(static_cast<T>(start_point_norm));
|
||||
}
|
||||
|
||||
const double elapsedSeconds = timer.elapsed() / 1000000.0;
|
||||
std::cout << "Initial non-incremental index build time for " << beginning_index_size << " points took "
|
||||
<< elapsedSeconds << " seconds (" << beginning_index_size / elapsedSeconds << " points/second)\n ";
|
||||
|
||||
current_point_offset += beginning_index_size;
|
||||
|
||||
if (points_to_delete_from_beginning > max_points_to_insert)
|
||||
{
|
||||
points_to_delete_from_beginning = static_cast<uint32_t>(max_points_to_insert);
|
||||
std::cerr << "WARNING: Reducing points to delete from beginning to " << points_to_delete_from_beginning
|
||||
<< " points since the data file has only that many" << std::endl;
|
||||
}
|
||||
|
||||
std::vector<std::vector<LabelT>> location_to_labels;
|
||||
if (concurrent)
|
||||
{
|
||||
// handle labels
|
||||
const auto save_path_inc = get_save_filename(save_path + ".after-concurrent-delete-", points_to_skip,
|
||||
points_to_delete_from_beginning, last_point_threshold);
|
||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
||||
if (has_labels)
|
||||
{
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
||||
location_to_labels = std::get<0>(parse_result);
|
||||
}
|
||||
|
||||
int32_t sub_threads = (params.num_threads + 1) / 2;
|
||||
bool delete_launched = false;
|
||||
std::future<void> delete_task;
|
||||
|
||||
diskann::Timer timer;
|
||||
|
||||
for (size_t start = current_point_offset; start < last_point_threshold;
|
||||
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
|
||||
{
|
||||
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
|
||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
||||
|
||||
auto insert_task = std::async(std::launch::async, [&]() {
|
||||
load_aligned_bin_part(data_path, data, start, end - start);
|
||||
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, sub_threads, data, aligned_dim,
|
||||
location_to_labels);
|
||||
});
|
||||
insert_task.wait();
|
||||
|
||||
if (!delete_launched && end >= start_deletes_after &&
|
||||
end >= points_to_skip + points_to_delete_from_beginning)
|
||||
{
|
||||
delete_launched = true;
|
||||
diskann::IndexWriteParameters delete_params =
|
||||
diskann::IndexWriteParametersBuilder(params).with_num_threads(sub_threads).build();
|
||||
|
||||
delete_task = std::async(std::launch::async, [&]() {
|
||||
delete_from_beginning<T, TagT>(*index, delete_params, points_to_skip,
|
||||
points_to_delete_from_beginning);
|
||||
});
|
||||
}
|
||||
}
|
||||
delete_task.wait();
|
||||
|
||||
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
|
||||
index->save(save_path_inc.c_str(), true);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto save_path_inc = get_save_filename(save_path + ".after-delete-", points_to_skip,
|
||||
points_to_delete_from_beginning, last_point_threshold);
|
||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
||||
if (has_labels)
|
||||
{
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
||||
location_to_labels = std::get<0>(parse_result);
|
||||
}
|
||||
|
||||
size_t last_snapshot_points_threshold = 0;
|
||||
size_t num_checkpoints_till_snapshot = checkpoints_per_snapshot;
|
||||
|
||||
for (size_t start = current_point_offset; start < last_point_threshold;
|
||||
start += points_per_checkpoint, current_point_offset += points_per_checkpoint)
|
||||
{
|
||||
const size_t end = std::min(start + points_per_checkpoint, last_point_threshold);
|
||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
||||
|
||||
load_aligned_bin_part(data_path, data, start, end - start);
|
||||
insert_till_next_checkpoint<T, TagT, LabelT>(*index, start, end, (int32_t)params.num_threads, data,
|
||||
aligned_dim, location_to_labels);
|
||||
|
||||
if (checkpoints_per_snapshot > 0 && --num_checkpoints_till_snapshot == 0)
|
||||
{
|
||||
diskann::Timer save_timer;
|
||||
|
||||
const auto save_path_inc =
|
||||
get_save_filename(save_path + ".inc-", points_to_skip, points_to_delete_from_beginning, end);
|
||||
index->save(save_path_inc.c_str(), false);
|
||||
const double elapsedSeconds = save_timer.elapsed() / 1000000.0;
|
||||
const size_t points_saved = end - points_to_skip;
|
||||
|
||||
std::cout << "Saved " << points_saved << " points in " << elapsedSeconds << " seconds ("
|
||||
<< points_saved / elapsedSeconds << " points/second)\n";
|
||||
|
||||
num_checkpoints_till_snapshot = checkpoints_per_snapshot;
|
||||
last_snapshot_points_threshold = end;
|
||||
}
|
||||
|
||||
std::cout << "Number of points in the index post insertion " << end << std::endl;
|
||||
}
|
||||
|
||||
if (checkpoints_per_snapshot > 0 && last_snapshot_points_threshold != last_point_threshold)
|
||||
{
|
||||
const auto save_path_inc = get_save_filename(save_path + ".inc-", points_to_skip,
|
||||
points_to_delete_from_beginning, last_point_threshold);
|
||||
// index.save(save_path_inc.c_str(), false);
|
||||
}
|
||||
|
||||
if (points_to_delete_from_beginning > 0)
|
||||
{
|
||||
delete_from_beginning<T, TagT>(*index, params, points_to_skip, points_to_delete_from_beginning);
|
||||
}
|
||||
|
||||
index->save(save_path_inc.c_str(), true);
|
||||
}
|
||||
|
||||
diskann::aligned_free(data);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix;
|
||||
uint32_t num_threads, R, L, num_start_pts;
|
||||
float alpha, start_point_norm;
|
||||
size_t points_to_skip, max_points_to_insert, beginning_index_size, points_per_checkpoint, checkpoints_per_snapshot,
|
||||
points_to_delete_from_beginning, start_deletes_after;
|
||||
bool concurrent;
|
||||
|
||||
// label options
|
||||
std::string label_file, label_type, universal_label;
|
||||
std::uint32_t Lf, unique_labels_supported;
|
||||
|
||||
po::options_description desc{program_options_utils::make_program_description("test_insert_deletes_consolidate",
|
||||
"Test insert deletes & consolidate")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
required_configs.add_options()("points_to_skip", po::value<uint64_t>(&points_to_skip)->required(),
|
||||
"Skip these first set of points from file");
|
||||
required_configs.add_options()("beginning_index_size", po::value<uint64_t>(&beginning_index_size)->required(),
|
||||
"Batch build will be called on these set of points");
|
||||
required_configs.add_options()("points_per_checkpoint", po::value<uint64_t>(&points_per_checkpoint)->required(),
|
||||
"Insertions are done in batches of points_per_checkpoint");
|
||||
required_configs.add_options()("checkpoints_per_snapshot",
|
||||
po::value<uint64_t>(&checkpoints_per_snapshot)->required(),
|
||||
"Save the index to disk every few checkpoints");
|
||||
required_configs.add_options()("points_to_delete_from_beginning",
|
||||
po::value<uint64_t>(&points_to_delete_from_beginning)->required(), "");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
program_options_utils::NUMBER_THREADS_DESCRIPTION);
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("max_points_to_insert",
|
||||
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
|
||||
"These number of points from the file are inserted after "
|
||||
"points_to_skip");
|
||||
optional_configs.add_options()("do_concurrent", po::value<bool>(&concurrent)->default_value(false), "");
|
||||
optional_configs.add_options()("start_deletes_after",
|
||||
po::value<uint64_t>(&start_deletes_after)->default_value(0), "");
|
||||
optional_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->default_value(0),
|
||||
"Set the start point to a random point on a sphere of this radius");
|
||||
|
||||
// optional params for filters
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input label file in txt format for Filtered Index search. "
|
||||
"The file should contain comma separated filters for each node "
|
||||
"with each line corresponding to a graph node");
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with labels_file");
|
||||
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
"Build complexity for filtered points, higher value "
|
||||
"results in better graphs");
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
optional_configs.add_options()("unique_labels_supported",
|
||||
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
|
||||
"Number of unique labels supported by the dynamic index.");
|
||||
|
||||
optional_configs.add_options()(
|
||||
"num_start_points",
|
||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
if (beginning_index_size == 0)
|
||||
if (start_point_norm == 0)
|
||||
{
|
||||
std::cout << "When beginning_index_size is 0, use a start "
|
||||
"point with "
|
||||
"appropriate norm"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool has_labels = false;
|
||||
if (!label_file.empty() || label_file != "")
|
||||
{
|
||||
has_labels = true;
|
||||
}
|
||||
|
||||
if (num_start_pts < unique_labels_supported)
|
||||
{
|
||||
num_start_pts = unique_labels_supported;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(500)
|
||||
.with_alpha(alpha)
|
||||
.with_num_threads(num_threads)
|
||||
.with_filter_list_size(Lf)
|
||||
.build();
|
||||
|
||||
if (data_type == std::string("int8"))
|
||||
build_incremental_index<int8_t>(
|
||||
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
|
||||
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
|
||||
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
|
||||
else if (data_type == std::string("uint8"))
|
||||
build_incremental_index<uint8_t>(
|
||||
data_path, params, points_to_skip, max_points_to_insert, beginning_index_size, start_point_norm,
|
||||
num_start_pts, points_per_checkpoint, checkpoints_per_snapshot, index_path_prefix,
|
||||
points_to_delete_from_beginning, start_deletes_after, concurrent, label_file, universal_label);
|
||||
else if (data_type == std::string("float"))
|
||||
build_incremental_index<float>(data_path, params, points_to_skip, max_points_to_insert,
|
||||
beginning_index_size, start_point_norm, num_start_pts, points_per_checkpoint,
|
||||
checkpoints_per_snapshot, index_path_prefix, points_to_delete_from_beginning,
|
||||
start_deletes_after, concurrent, label_file, universal_label);
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cerr << "Caught exception: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Caught unknown exception" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,523 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <timer.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <future>
|
||||
#include <abstract_index.h>
|
||||
#include <index_factory.h>
|
||||
|
||||
#include "utils.h"
|
||||
#include "filter_utils.h"
|
||||
#include "program_options_utils.hpp"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "memory_mapper.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// load_aligned_bin modified to read pieces of the file, but using ifstream
|
||||
// instead of cached_ifstream.
|
||||
template <typename T>
|
||||
inline void load_aligned_bin_part(const std::string &bin_file, T *data, size_t offset_points, size_t points_to_read)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(bin_file, std::ios::binary | std::ios::ate);
|
||||
size_t actual_file_size = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
size_t npts = (uint32_t)npts_i32;
|
||||
size_t dim = (uint32_t)dim_i32;
|
||||
|
||||
size_t expected_actual_file_size = npts * dim * sizeof(T) + 2 * sizeof(uint32_t);
|
||||
if (actual_file_size != expected_actual_file_size)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. Actual size is " << actual_file_size << " while expected size is "
|
||||
<< expected_actual_file_size << " npts = " << npts << " dim = " << dim << " size of <T>= " << sizeof(T)
|
||||
<< std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
if (offset_points + points_to_read > npts)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. Not enough points in file. Requested " << offset_points << " offset and " << points_to_read
|
||||
<< " points, but have only " << npts << " points" << std::endl;
|
||||
std::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
reader.seekg(2 * sizeof(uint32_t) + offset_points * dim * sizeof(T));
|
||||
|
||||
const size_t rounded_dim = ROUND_UP(dim, 8);
|
||||
|
||||
for (size_t i = 0; i < points_to_read; i++)
|
||||
{
|
||||
reader.read((char *)(data + i * rounded_dim), dim * sizeof(T));
|
||||
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
|
||||
}
|
||||
reader.close();
|
||||
}
|
||||
|
||||
std::string get_save_filename(const std::string &save_path, size_t active_window, size_t consolidate_interval,
|
||||
size_t max_points_to_insert)
|
||||
{
|
||||
std::string final_path = save_path;
|
||||
final_path += "act" + std::to_string(active_window) + "-";
|
||||
final_path += "cons" + std::to_string(consolidate_interval) + "-";
|
||||
final_path += "max" + std::to_string(max_points_to_insert);
|
||||
return final_path;
|
||||
}
|
||||
|
||||
template <typename T, typename TagT, typename LabelT>
|
||||
void insert_next_batch(diskann::AbstractIndex &index, size_t start, size_t end, size_t insert_threads, T *data,
|
||||
size_t aligned_dim, std::vector<std::vector<LabelT>> &pts_to_labels)
|
||||
{
|
||||
try
|
||||
{
|
||||
diskann::Timer insert_timer;
|
||||
std::cout << std::endl << "Inserting from " << start << " to " << end << std::endl;
|
||||
|
||||
size_t num_failed = 0;
|
||||
#pragma omp parallel for num_threads((int32_t)insert_threads) schedule(dynamic) reduction(+ : num_failed)
|
||||
for (int64_t j = start; j < (int64_t)end; j++)
|
||||
{
|
||||
int insert_result = -1;
|
||||
if (pts_to_labels.size() > 0)
|
||||
{
|
||||
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j),
|
||||
pts_to_labels[j - start]);
|
||||
}
|
||||
else
|
||||
{
|
||||
insert_result = index.insert_point(&data[(j - start) * aligned_dim], 1 + static_cast<TagT>(j));
|
||||
}
|
||||
|
||||
if (insert_result != 0)
|
||||
{
|
||||
std::cerr << "Insert failed " << j << std::endl;
|
||||
num_failed++;
|
||||
}
|
||||
}
|
||||
const double elapsedSeconds = insert_timer.elapsed() / 1000000.0;
|
||||
std::cout << "Insertion time " << elapsedSeconds << " seconds (" << (end - start) / elapsedSeconds
|
||||
<< " points/second overall, " << (end - start) / elapsedSeconds / insert_threads << " per thread)"
|
||||
<< std::endl;
|
||||
if (num_failed > 0)
|
||||
std::cout << num_failed << " of " << end - start << "inserts failed" << std::endl;
|
||||
}
|
||||
catch (std::system_error &e)
|
||||
{
|
||||
std::cout << "Exiting after catching exception in insertion task: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
||||
void delete_and_consolidate(diskann::AbstractIndex &index, diskann::IndexWriteParameters &delete_params, size_t start,
|
||||
size_t end)
|
||||
{
|
||||
try
|
||||
{
|
||||
std::cout << std::endl << "Lazy deleting points " << start << " to " << end << "... ";
|
||||
for (size_t i = start; i < end; ++i)
|
||||
index.lazy_delete(static_cast<TagT>(1 + i));
|
||||
std::cout << "lazy delete done." << std::endl;
|
||||
|
||||
auto report = index.consolidate_deletes(delete_params);
|
||||
while (report._status != diskann::consolidation_report::status_code::SUCCESS)
|
||||
{
|
||||
int wait_time = 5;
|
||||
if (report._status == diskann::consolidation_report::status_code::LOCK_FAIL)
|
||||
{
|
||||
diskann::cerr << "Unable to acquire consolidate delete lock after "
|
||||
<< "deleting points " << start << " to " << end << ". Will retry in " << wait_time
|
||||
<< "seconds." << std::endl;
|
||||
}
|
||||
else if (report._status == diskann::consolidation_report::status_code::INCONSISTENT_COUNT_ERROR)
|
||||
{
|
||||
diskann::cerr << "Inconsistent counts in data structure. "
|
||||
<< "Will retry in " << wait_time << "seconds." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Exiting after unknown error in consolidate delete" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::seconds(wait_time));
|
||||
report = index.consolidate_deletes(delete_params);
|
||||
}
|
||||
auto points_processed = report._active_points + report._slots_released;
|
||||
auto deletion_rate = points_processed / report._time;
|
||||
std::cout << "#active points: " << report._active_points << std::endl
|
||||
<< "max points: " << report._max_points << std::endl
|
||||
<< "empty slots: " << report._empty_slots << std::endl
|
||||
<< "deletes processed: " << report._slots_released << std::endl
|
||||
<< "latest delete size: " << report._delete_set_size << std::endl
|
||||
<< "Deletion rate: " << deletion_rate << "/sec "
|
||||
<< "Deletion rate: " << deletion_rate / delete_params.num_threads << "/thread/sec " << std::endl;
|
||||
}
|
||||
catch (std::system_error &e)
|
||||
{
|
||||
std::cerr << "Exiting after catching exception in deletion task: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
||||
void build_incremental_index(const std::string &data_path, const uint32_t L, const uint32_t R, const float alpha,
|
||||
const uint32_t insert_threads, const uint32_t consolidate_threads,
|
||||
size_t max_points_to_insert, size_t active_window, size_t consolidate_interval,
|
||||
const float start_point_norm, uint32_t num_start_pts, const std::string &save_path,
|
||||
const std::string &label_file, const std::string &universal_label, const uint32_t Lf)
|
||||
{
|
||||
const uint32_t C = 500;
|
||||
const bool saturate_graph = false;
|
||||
bool has_labels = label_file != "";
|
||||
|
||||
diskann::IndexWriteParameters params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(C)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_num_threads(insert_threads)
|
||||
.with_filter_list_size(Lf)
|
||||
.build();
|
||||
|
||||
auto index_search_params = diskann::IndexSearchParams(L, insert_threads);
|
||||
diskann::IndexWriteParameters delete_params = diskann::IndexWriteParametersBuilder(L, R)
|
||||
.with_max_occlusion_size(C)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(saturate_graph)
|
||||
.with_num_threads(consolidate_threads)
|
||||
.with_filter_list_size(Lf)
|
||||
.build();
|
||||
|
||||
size_t dim, aligned_dim;
|
||||
size_t num_points;
|
||||
|
||||
std::vector<std::vector<LabelT>> pts_to_labels;
|
||||
|
||||
const auto save_path_inc =
|
||||
get_save_filename(save_path + ".after-streaming-", active_window, consolidate_interval, max_points_to_insert);
|
||||
std::string labels_file_to_use = save_path_inc + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path_inc + "_labels_map.txt";
|
||||
if (has_labels)
|
||||
{
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use, mem_labels_int_map_file, universal_label);
|
||||
auto parse_result = diskann::parse_formatted_label_file<LabelT>(labels_file_to_use);
|
||||
pts_to_labels = std::get<0>(parse_result);
|
||||
}
|
||||
|
||||
diskann::get_bin_metadata(data_path, num_points, dim);
|
||||
diskann::cout << "metadata: file " << data_path << " has " << num_points << " points in " << dim << " dims"
|
||||
<< std::endl;
|
||||
aligned_dim = ROUND_UP(dim, 8);
|
||||
auto index_config = diskann::IndexConfigBuilder()
|
||||
.with_metric(diskann::L2)
|
||||
.with_dimension(dim)
|
||||
.with_max_points(active_window + 4 * consolidate_interval)
|
||||
.is_dynamic_index(true)
|
||||
.is_enable_tags(true)
|
||||
.is_use_opq(false)
|
||||
.is_filtered(has_labels)
|
||||
.with_num_pq_chunks(0)
|
||||
.is_pq_dist_build(false)
|
||||
.with_num_frozen_pts(num_start_pts)
|
||||
.with_tag_type(diskann_type_to_name<TagT>())
|
||||
.with_label_type(diskann_type_to_name<LabelT>())
|
||||
.with_data_type(diskann_type_to_name<T>())
|
||||
.with_index_write_params(params)
|
||||
.with_index_search_params(index_search_params)
|
||||
.with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY)
|
||||
.with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY)
|
||||
.build();
|
||||
|
||||
diskann::IndexFactory index_factory = diskann::IndexFactory(index_config);
|
||||
auto index = index_factory.create_instance();
|
||||
|
||||
if (universal_label != "")
|
||||
{
|
||||
LabelT u_label = 0;
|
||||
index->set_universal_label(u_label);
|
||||
}
|
||||
|
||||
if (max_points_to_insert == 0)
|
||||
{
|
||||
max_points_to_insert = num_points;
|
||||
}
|
||||
|
||||
if (num_points < max_points_to_insert)
|
||||
throw diskann::ANNException(std::string("num_points(") + std::to_string(num_points) +
|
||||
") < max_points_to_insert(" + std::to_string(max_points_to_insert) + ")",
|
||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
if (max_points_to_insert < active_window + consolidate_interval)
|
||||
throw diskann::ANNException("ERROR: max_points_to_insert < "
|
||||
"active_window + consolidate_interval",
|
||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
if (consolidate_interval < max_points_to_insert / 1000)
|
||||
throw diskann::ANNException("ERROR: consolidate_interval is too small", -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
index->set_start_points_at_random(static_cast<T>(start_point_norm));
|
||||
|
||||
T *data = nullptr;
|
||||
diskann::alloc_aligned((void **)&data, std::max(consolidate_interval, active_window) * aligned_dim * sizeof(T),
|
||||
8 * sizeof(T));
|
||||
|
||||
std::vector<TagT> tags(max_points_to_insert);
|
||||
std::iota(tags.begin(), tags.end(), static_cast<TagT>(0));
|
||||
|
||||
diskann::Timer timer;
|
||||
|
||||
std::vector<std::future<void>> delete_tasks;
|
||||
|
||||
auto insert_task = std::async(std::launch::async, [&]() {
|
||||
load_aligned_bin_part(data_path, data, 0, active_window);
|
||||
insert_next_batch<T, TagT, LabelT>(*index, (size_t)0, active_window, params.num_threads, data, aligned_dim,
|
||||
pts_to_labels);
|
||||
});
|
||||
insert_task.wait();
|
||||
|
||||
for (size_t start = active_window; start + consolidate_interval <= max_points_to_insert;
|
||||
start += consolidate_interval)
|
||||
{
|
||||
auto end = std::min(start + consolidate_interval, max_points_to_insert);
|
||||
auto insert_task = std::async(std::launch::async, [&]() {
|
||||
load_aligned_bin_part(data_path, data, start, end - start);
|
||||
insert_next_batch<T, TagT, LabelT>(*index, start, end, params.num_threads, data, aligned_dim,
|
||||
pts_to_labels);
|
||||
});
|
||||
insert_task.wait();
|
||||
|
||||
if (delete_tasks.size() > 0)
|
||||
delete_tasks[delete_tasks.size() - 1].wait();
|
||||
if (start >= active_window + consolidate_interval)
|
||||
{
|
||||
auto start_del = start - active_window - consolidate_interval;
|
||||
auto end_del = start - active_window;
|
||||
|
||||
delete_tasks.emplace_back(std::async(std::launch::async, [&]() {
|
||||
delete_and_consolidate<T, TagT, LabelT>(*index, delete_params, (size_t)start_del, (size_t)end_del);
|
||||
}));
|
||||
}
|
||||
}
|
||||
if (delete_tasks.size() > 0)
|
||||
delete_tasks[delete_tasks.size() - 1].wait();
|
||||
|
||||
std::cout << "Time Elapsed " << timer.elapsed() / 1000 << "ms\n";
|
||||
|
||||
index->save(save_path_inc.c_str(), true);
|
||||
|
||||
diskann::aligned_free(data);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
|
||||
uint32_t insert_threads, consolidate_threads, R, L, num_start_pts, Lf, unique_labels_supported;
|
||||
float alpha, start_point_norm;
|
||||
size_t max_points_to_insert, active_window, consolidate_interval;
|
||||
|
||||
po::options_description desc{program_options_utils::make_program_description("test_streaming_scenario",
|
||||
"Test insert deletes & consolidate")};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
// Required parameters
|
||||
po::options_description required_configs("Required");
|
||||
required_configs.add_options()("data_type", po::value<std::string>(&data_type)->required(),
|
||||
program_options_utils::DATA_TYPE_DESCRIPTION);
|
||||
required_configs.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
program_options_utils::DISTANCE_FUNCTION_DESCRIPTION);
|
||||
required_configs.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION);
|
||||
required_configs.add_options()("data_path", po::value<std::string>(&data_path)->required(),
|
||||
program_options_utils::INPUT_DATA_PATH);
|
||||
required_configs.add_options()("active_window", po::value<uint64_t>(&active_window)->required(),
|
||||
"Program maintains an index over an active window of "
|
||||
"this size that slides through the data");
|
||||
required_configs.add_options()("consolidate_interval", po::value<uint64_t>(&consolidate_interval)->required(),
|
||||
"The program simultaneously adds this number of points to the "
|
||||
"right of "
|
||||
"the window while deleting the same number from the left");
|
||||
required_configs.add_options()("start_point_norm", po::value<float>(&start_point_norm)->required(),
|
||||
"Set the start point to a random point on a sphere of this radius");
|
||||
|
||||
// Optional parameters
|
||||
po::options_description optional_configs("Optional");
|
||||
optional_configs.add_options()("max_degree,R", po::value<uint32_t>(&R)->default_value(64),
|
||||
program_options_utils::MAX_BUILD_DEGREE);
|
||||
optional_configs.add_options()("Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
program_options_utils::GRAPH_BUILD_COMPLEXITY);
|
||||
optional_configs.add_options()("alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
program_options_utils::GRAPH_BUILD_ALPHA);
|
||||
optional_configs.add_options()("insert_threads",
|
||||
po::value<uint32_t>(&insert_threads)->default_value(omp_get_num_procs() / 2),
|
||||
"Number of threads used for inserting into the index (defaults to "
|
||||
"omp_get_num_procs()/2)");
|
||||
optional_configs.add_options()(
|
||||
"consolidate_threads", po::value<uint32_t>(&consolidate_threads)->default_value(omp_get_num_procs() / 2),
|
||||
"Number of threads used for consolidating deletes to "
|
||||
"the index (defaults to omp_get_num_procs()/2)");
|
||||
optional_configs.add_options()("max_points_to_insert",
|
||||
po::value<uint64_t>(&max_points_to_insert)->default_value(0),
|
||||
"The number of points from the file that the program streams "
|
||||
"over ");
|
||||
optional_configs.add_options()(
|
||||
"num_start_points",
|
||||
po::value<uint32_t>(&num_start_pts)->default_value(diskann::defaults::NUM_FROZEN_POINTS_DYNAMIC),
|
||||
"Set the number of random start (frozen) points to use when "
|
||||
"inserting and searching");
|
||||
|
||||
optional_configs.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input label file in txt format for Filtered Index search. "
|
||||
"The file should contain comma separated filters for each node "
|
||||
"with each line corresponding to a graph node");
|
||||
optional_configs.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with labels_file");
|
||||
optional_configs.add_options()("FilteredLbuild,Lf", po::value<uint32_t>(&Lf)->default_value(0),
|
||||
"Build complexity for filtered points, higher value "
|
||||
"results in better graphs");
|
||||
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
optional_configs.add_options()("unique_labels_supported",
|
||||
po::value<uint32_t>(&unique_labels_supported)->default_value(0),
|
||||
"Number of unique labels supported by the dynamic index.");
|
||||
|
||||
// Merge required and optional parameters
|
||||
desc.add(required_configs).add(optional_configs);
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Validate arguments
|
||||
if (start_point_norm == 0)
|
||||
{
|
||||
std::cout << "When beginning_index_size is 0, use a start point with "
|
||||
"appropriate norm"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (label_type != std::string("ushort") && label_type != std::string("uint"))
|
||||
{
|
||||
std::cerr << "Invalid label type. Supported types are uint and ushort" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("int8") && data_type != std::string("uint8") && data_type != std::string("float"))
|
||||
{
|
||||
std::cerr << "Invalid data type. Supported types are int8, uint8 and float" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// TODO: Are additional distance functions supported?
|
||||
if (dist_fn != std::string("l2") && dist_fn != std::string("mips"))
|
||||
{
|
||||
std::cerr << "Invalid distance function. Supported functions are l2 and mips" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_start_pts < unique_labels_supported)
|
||||
{
|
||||
num_start_pts = unique_labels_supported;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("uint8"))
|
||||
{
|
||||
if (label_type == std::string("ushort"))
|
||||
{
|
||||
build_incremental_index<uint8_t, uint32_t, uint16_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
else if (label_type == std::string("uint"))
|
||||
{
|
||||
build_incremental_index<uint8_t, uint32_t, uint32_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
if (label_type == std::string("ushort"))
|
||||
{
|
||||
build_incremental_index<int8_t, uint32_t, uint16_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
else if (label_type == std::string("uint"))
|
||||
{
|
||||
build_incremental_index<int8_t, uint32_t, uint32_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
}
|
||||
else if (data_type == std::string("float"))
|
||||
{
|
||||
if (label_type == std::string("ushort"))
|
||||
{
|
||||
build_incremental_index<float, uint32_t, uint16_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
else if (label_type == std::string("uint"))
|
||||
{
|
||||
build_incremental_index<float, uint32_t, uint32_t>(
|
||||
data_path, L, R, alpha, insert_threads, consolidate_threads, max_points_to_insert, active_window,
|
||||
consolidate_interval, start_point_norm, num_start_pts, index_path_prefix, label_file,
|
||||
universal_label, Lf);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cerr << "Caught exception: " << e.what() << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::cerr << "Caught unknown exception" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_COMPILE_WARNING_AS_ERROR ON)
|
||||
|
||||
|
||||
add_executable(fvecs_to_bin fvecs_to_bin.cpp)
|
||||
|
||||
add_executable(fvecs_to_bvecs fvecs_to_bvecs.cpp)
|
||||
|
||||
add_executable(rand_data_gen rand_data_gen.cpp)
|
||||
target_link_libraries(rand_data_gen ${PROJECT_NAME} Boost::program_options)
|
||||
|
||||
add_executable(float_bin_to_int8 float_bin_to_int8.cpp)
|
||||
|
||||
add_executable(ivecs_to_bin ivecs_to_bin.cpp)
|
||||
|
||||
add_executable(count_bfs_levels count_bfs_levels.cpp)
|
||||
target_link_libraries(count_bfs_levels ${PROJECT_NAME} Boost::program_options)
|
||||
|
||||
add_executable(tsv_to_bin tsv_to_bin.cpp)
|
||||
|
||||
add_executable(bin_to_tsv bin_to_tsv.cpp)
|
||||
|
||||
add_executable(int8_to_float int8_to_float.cpp)
|
||||
target_link_libraries(int8_to_float ${PROJECT_NAME})
|
||||
|
||||
add_executable(int8_to_float_scale int8_to_float_scale.cpp)
|
||||
target_link_libraries(int8_to_float_scale ${PROJECT_NAME})
|
||||
|
||||
add_executable(uint8_to_float uint8_to_float.cpp)
|
||||
target_link_libraries(uint8_to_float ${PROJECT_NAME})
|
||||
|
||||
add_executable(uint32_to_uint8 uint32_to_uint8.cpp)
|
||||
target_link_libraries(uint32_to_uint8 ${PROJECT_NAME})
|
||||
|
||||
add_executable(vector_analysis vector_analysis.cpp)
|
||||
target_link_libraries(vector_analysis ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
add_executable(gen_random_slice gen_random_slice.cpp)
|
||||
target_link_libraries(gen_random_slice ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
add_executable(simulate_aggregate_recall simulate_aggregate_recall.cpp)
|
||||
|
||||
add_executable(calculate_recall calculate_recall.cpp)
|
||||
target_link_libraries(calculate_recall ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
# Compute ground truth thing outside of DiskANN main source that depends on MKL.
|
||||
add_executable(compute_groundtruth compute_groundtruth.cpp)
|
||||
target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
|
||||
target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)
|
||||
|
||||
add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp)
|
||||
target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES})
|
||||
target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options)
|
||||
|
||||
|
||||
add_executable(generate_pq generate_pq.cpp)
|
||||
target_link_libraries(generate_pq ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
|
||||
add_executable(partition_data partition_data.cpp)
|
||||
target_link_libraries(partition_data ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
add_executable(partition_with_ram_budget partition_with_ram_budget.cpp)
|
||||
target_link_libraries(partition_with_ram_budget ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
add_executable(merge_shards merge_shards.cpp)
|
||||
target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} ${DISKANN_ASYNC_LIB})
|
||||
|
||||
add_executable(create_disk_layout create_disk_layout.cpp)
|
||||
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
|
||||
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)
|
||||
|
||||
add_executable(stats_label_data stats_label_data.cpp)
|
||||
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)
|
||||
|
||||
if (NOT MSVC)
|
||||
include(GNUInstallDirs)
|
||||
install(TARGETS fvecs_to_bin
|
||||
fvecs_to_bvecs
|
||||
rand_data_gen
|
||||
float_bin_to_int8
|
||||
ivecs_to_bin
|
||||
count_bfs_levels
|
||||
tsv_to_bin
|
||||
bin_to_tsv
|
||||
int8_to_float
|
||||
int8_to_float_scale
|
||||
uint8_to_float
|
||||
uint32_to_uint8
|
||||
vector_analysis
|
||||
gen_random_slice
|
||||
simulate_aggregate_recall
|
||||
calculate_recall
|
||||
compute_groundtruth
|
||||
compute_groundtruth_for_filters
|
||||
generate_pq
|
||||
partition_data
|
||||
partition_with_ram_budget
|
||||
merge_shards
|
||||
create_disk_layout
|
||||
generate_synthetic_labels
|
||||
stats_label_data
|
||||
RUNTIME
|
||||
)
|
||||
endif()
|
||||
@@ -1,63 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "util.h"
|
||||
|
||||
void block_convert(std::ifstream &writr, std::ofstream &readr, float *read_buf, float *write_buf, uint64_t npts,
|
||||
uint64_t ndims)
|
||||
{
|
||||
writr.write((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(unsigned)));
|
||||
#pragma omp parallel for
|
||||
for (uint64_t i = 0; i < npts; i++)
|
||||
{
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
|
||||
}
|
||||
readr.read((char *)write_buf, npts * ndims * sizeof(float));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_bin output_fvecs" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream readr(argv[1], std::ios::binary);
|
||||
int npts_s32;
|
||||
int ndims_s32;
|
||||
readr.read((char *)&npts_s32, sizeof(int32_t));
|
||||
readr.read((char *)&ndims_s32, sizeof(int32_t));
|
||||
size_t npts = npts_s32;
|
||||
size_t ndims = ndims_s32;
|
||||
uint32_t ndims_u32 = (uint32_t)ndims_s32;
|
||||
// uint64_t fsize = writr.tellg();
|
||||
readr.seekg(0, std::ios::beg);
|
||||
|
||||
unsigned ndims_u32;
|
||||
writr.write((char *)&ndims_u32, sizeof(unsigned));
|
||||
writr.seekg(0, std::ios::beg);
|
||||
uint64_t ndims = (uint64_t)ndims_u32;
|
||||
uint64_t npts = fsize / ((ndims + 1) * sizeof(float));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
uint64_t blk_size = 131072;
|
||||
uint64_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
|
||||
std::ofstream writr(argv[2], std::ios::binary);
|
||||
float *read_buf = new float[npts * (ndims + 1)];
|
||||
float *write_buf = new float[npts * ndims];
|
||||
for (uint64_t i = 0; i < nblks; i++)
|
||||
{
|
||||
uint64_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(writr, readr, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
writr.close();
|
||||
readr.close();
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
template <class T>
|
||||
void block_convert(std::ofstream &writer, std::ifstream &reader, T *read_buf, size_t npts, size_t ndims)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * ndims * sizeof(float));
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
{
|
||||
writer << read_buf[d + i * ndims];
|
||||
if (d < ndims - 1)
|
||||
writer << "\t";
|
||||
else
|
||||
writer << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <float/int8/uint8> input_bin output_tsv" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::string type_string(argv[1]);
|
||||
if ((type_string != std::string("float")) && (type_string != std::string("int8")) &&
|
||||
(type_string != std::string("uin8")))
|
||||
{
|
||||
std::cerr << "Error: type not supported. Use float/int8/uint8" << std::endl;
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[3]);
|
||||
char *read_buf = new char[blk_size * ndims * 4];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (type_string == std::string("float"))
|
||||
block_convert<float>(writer, reader, (float *)read_buf, cblk_size, ndims);
|
||||
else if (type_string == std::string("int8"))
|
||||
block_convert<int8_t>(writer, reader, (int8_t *)read_buf, cblk_size, ndims);
|
||||
else if (type_string == std::string("uint8"))
|
||||
block_convert<uint8_t>(writer, reader, (uint8_t *)read_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <ground_truth_bin> <our_results_bin> <r> " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
uint32_t *gold_std = NULL;
|
||||
float *gs_dist = nullptr;
|
||||
uint32_t *our_results = NULL;
|
||||
float *or_dist = nullptr;
|
||||
size_t points_num, points_num_gs, points_num_or;
|
||||
size_t dim_gs;
|
||||
size_t dim_or;
|
||||
diskann::load_truthset(argv[1], gold_std, gs_dist, points_num_gs, dim_gs);
|
||||
diskann::load_truthset(argv[2], our_results, or_dist, points_num_or, dim_or);
|
||||
|
||||
if (points_num_gs != points_num_or)
|
||||
{
|
||||
std::cout << "Error. Number of queries mismatch in ground truth and "
|
||||
"our results"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
points_num = points_num_gs;
|
||||
|
||||
uint32_t recall_at = std::atoi(argv[3]);
|
||||
|
||||
if ((dim_or < recall_at) || (recall_at > dim_gs))
|
||||
{
|
||||
std::cout << "ground truth has size " << dim_gs << "; our set has " << dim_or << " points. Asking for recall "
|
||||
<< recall_at << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "Calculating recall@" << recall_at << std::endl;
|
||||
double recall_val = diskann::calculate_recall((uint32_t)points_num, gold_std, gs_dist, (uint32_t)dim_gs,
|
||||
our_results, (uint32_t)dim_or, (uint32_t)recall_at);
|
||||
|
||||
// double avg_recall = (recall*1.0)/(points_num*1.0);
|
||||
std::cout << "Avg. recall@" << recall_at << " is " << recall_val << "\n";
|
||||
}
|
||||
@@ -1,574 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <unordered_map>
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
#include "filter_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <class T> T div_round_up(const T numerator, const T denominator)
|
||||
{
|
||||
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
|
||||
}
|
||||
|
||||
using pairIF = std::pair<size_t, float>;
|
||||
struct cmpmaxstruct
|
||||
{
|
||||
bool operator()(const pairIF &l, const pairIF &r)
|
||||
{
|
||||
return l.second < r.second;
|
||||
};
|
||||
};
|
||||
|
||||
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
|
||||
{
|
||||
#ifdef _WINDOWS
|
||||
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
|
||||
#else
|
||||
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
|
||||
{
|
||||
return a.second < b.second;
|
||||
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
|
||||
{
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
|
||||
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
|
||||
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
|
||||
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void inner_prod_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
size_t *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
float *points_in, // points in Col major
|
||||
size_t nqueries, float *queries_in,
|
||||
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
compute_l2sq(points_l2sq, points_in, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
|
||||
|
||||
float *points = points_in;
|
||||
float *queries = queries_in;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{ // we convert cosine distance as
|
||||
// normalized L2 distnace
|
||||
points = new float[npoints * dim];
|
||||
queries = new float[nqueries * dim];
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)npoints; i++)
|
||||
{
|
||||
float norm = std::sqrt(points_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
points[i * dim + j] = points_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)nqueries; i++)
|
||||
{
|
||||
float norm = std::sqrt(queries_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
queries[i * dim + j] = queries_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
// recalculate norms after normalizing, they should all be one.
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
}
|
||||
|
||||
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
|
||||
<< dim << " dimensions using";
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
std::cout << " MIPS ";
|
||||
else if (metric == diskann::Metric::COSINE)
|
||||
std::cout << " Cosine ";
|
||||
else
|
||||
std::cout << " L2 ";
|
||||
std::cout << "distance fn. " << std::endl;
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
|
||||
|
||||
for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
|
||||
{
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
|
||||
{
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
|
||||
}
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++)
|
||||
{
|
||||
maxPQIFCS point_dist;
|
||||
for (size_t p = 0; p < k; p++)
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
for (size_t p = k; p < npoints; p++)
|
||||
{
|
||||
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
|
||||
{
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{
|
||||
delete[] points;
|
||||
delete[] queries;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> inline int get_num_parts(const char *filename)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
uint32_t num_parts =
|
||||
(npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (uint64_t)ndims_i32;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B"
|
||||
<< std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
T *data_T = new T[npts * ndims];
|
||||
reader.read((char *)data_T, sizeof(T) * npts * ndims);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
data = aligned_malloc<float>(npts * ndims, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (int64_t j = 0; j < (int64_t)ndims; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndims + j];
|
||||
std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
|
||||
{
|
||||
std::ofstream writer;
|
||||
writer.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
writer.open(filename, std::ios::binary | std::ios::out);
|
||||
std::cout << "Writing bin: " << filename << "\n";
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(T));
|
||||
writer.close();
|
||||
std::cout << "Finished writing bin" << std::endl;
|
||||
}
|
||||
|
||||
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
|
||||
"npts*dim dist-matrix) with npts = "
|
||||
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
|
||||
<< "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
|
||||
writer.write((char *)distances, npts * ndims * sizeof(float));
|
||||
writer.close();
|
||||
std::cout << "Finished writing truthset" << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
|
||||
size_t &nqueries, size_t &npoints,
|
||||
size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric,
|
||||
std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
float *base_data = nullptr;
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
|
||||
metric);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (size_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &query_file, const std::string >_file, size_t k,
|
||||
const diskann::Metric &metric, const std::string &tags_file = std::string(""))
|
||||
{
|
||||
size_t npoints, nqueries, dim;
|
||||
|
||||
float *query_data;
|
||||
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
if (nqueries > PARTSIZE)
|
||||
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
|
||||
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
|
||||
|
||||
// load tags
|
||||
const bool tags_enabled = tags_file.empty() ? false : true;
|
||||
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
|
||||
|
||||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results =
|
||||
processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
size_t j = 0;
|
||||
for (auto iter : cur_res)
|
||||
{
|
||||
if (j == k)
|
||||
break;
|
||||
if (tags_enabled)
|
||||
{
|
||||
std::uint32_t index_with_tag = location_to_tag[iter.first];
|
||||
closest_points[i * k + j] = (int32_t)index_with_tag;
|
||||
}
|
||||
else
|
||||
{
|
||||
closest_points[i * k + j] = (int32_t)iter.first;
|
||||
}
|
||||
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
dist_closest_points[i * k + j] = -iter.second;
|
||||
else
|
||||
dist_closest_points[i * k + j] = iter.second;
|
||||
|
||||
++j;
|
||||
}
|
||||
if (j < k)
|
||||
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
diskann::aligned_free(query_data);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
|
||||
{
|
||||
size_t read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
npts = (uint32_t)npts_i32;
|
||||
dim = (uint32_t)dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1)
|
||||
{
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char *)dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
|
||||
uint64_t K;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(),
|
||||
"distance function <l2/mips/cosine>");
|
||||
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
|
||||
"File containing the base vectors in binary format");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary "
|
||||
"format, please don' append .bin at end if "
|
||||
"no filter_label or filter_label_file is provided it "
|
||||
"will save the file with '.bin' at end."
|
||||
"else it will save the file as filename_label.bin");
|
||||
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
|
||||
"Number of ground truth nearest neighbors to compute");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"File containing the tags in binary format");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,919 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <random>
|
||||
#include <limits>
|
||||
#include <cstring>
|
||||
#include <queue>
|
||||
#include <omp.h>
|
||||
#include <mkl.h>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <unordered_map>
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#ifdef _WINDOWS
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
|
||||
#include "filter_utils.h"
|
||||
#include "utils.h"
|
||||
|
||||
// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED)
|
||||
|
||||
#define PARTSIZE 10000000
|
||||
#define ALIGNMENT 512
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <class T> T div_round_up(const T numerator, const T denominator)
|
||||
{
|
||||
return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator);
|
||||
}
|
||||
|
||||
using pairIF = std::pair<size_t, float>;
|
||||
struct cmpmaxstruct
|
||||
{
|
||||
bool operator()(const pairIF &l, const pairIF &r)
|
||||
{
|
||||
return l.second < r.second;
|
||||
};
|
||||
};
|
||||
|
||||
using maxPQIFCS = std::priority_queue<pairIF, std::vector<pairIF>, cmpmaxstruct>;
|
||||
|
||||
template <class T> T *aligned_malloc(const size_t n, const size_t alignment)
|
||||
{
|
||||
#ifdef _WINDOWS
|
||||
return (T *)_aligned_malloc(sizeof(T) * n, alignment);
|
||||
#else
|
||||
return static_cast<T *>(aligned_alloc(alignment, sizeof(T) * n));
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool custom_dist(const std::pair<uint32_t, float> &a, const std::pair<uint32_t, float> &b)
|
||||
{
|
||||
return a.second < b.second;
|
||||
}
|
||||
|
||||
void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim)
|
||||
{
|
||||
assert(points_l2sq != NULL);
|
||||
#pragma omp parallel for schedule(static, 65536)
|
||||
for (int64_t d = 0; d < num_points; ++d)
|
||||
points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1,
|
||||
matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1);
|
||||
}
|
||||
|
||||
void distsq_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points,
|
||||
const float *const points_l2sq, // points in Col major
|
||||
size_t nqueries, const float *const queries,
|
||||
const float *const queries_l2sq, // queries in Col major
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints,
|
||||
ones_vec, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints,
|
||||
queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints);
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void inner_prod_to_points(const size_t dim,
|
||||
float *dist_matrix, // Col Major, cols are queries, rows are points
|
||||
size_t npoints, const float *const points, size_t nqueries, const float *const queries,
|
||||
float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0
|
||||
{
|
||||
bool ones_vec_alloc = false;
|
||||
if (ones_vec == NULL)
|
||||
{
|
||||
ones_vec = new float[nqueries > npoints ? nqueries : npoints];
|
||||
std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0);
|
||||
ones_vec_alloc = true;
|
||||
}
|
||||
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim,
|
||||
(float)0.0, dist_matrix, npoints);
|
||||
|
||||
if (ones_vec_alloc)
|
||||
delete[] ones_vec;
|
||||
}
|
||||
|
||||
void exact_knn(const size_t dim, const size_t k,
|
||||
size_t *const closest_points, // k * num_queries preallocated, col
|
||||
// major, queries columns
|
||||
float *const dist_closest_points, // k * num_queries
|
||||
// preallocated, Dist to
|
||||
// corresponding closes_points
|
||||
size_t npoints,
|
||||
float *points_in, // points in Col major
|
||||
size_t nqueries, float *queries_in,
|
||||
diskann::Metric metric = diskann::Metric::L2) // queries in Col major
|
||||
{
|
||||
float *points_l2sq = new float[npoints];
|
||||
float *queries_l2sq = new float[nqueries];
|
||||
compute_l2sq(points_l2sq, points_in, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries_in, nqueries, dim);
|
||||
|
||||
float *points = points_in;
|
||||
float *queries = queries_in;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{ // we convert cosine distance as
|
||||
// normalized L2 distnace
|
||||
points = new float[npoints * dim];
|
||||
queries = new float[nqueries * dim];
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)npoints; i++)
|
||||
{
|
||||
float norm = std::sqrt(points_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
points[i * dim + j] = points_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for schedule(static, 4096)
|
||||
for (int64_t i = 0; i < (int64_t)nqueries; i++)
|
||||
{
|
||||
float norm = std::sqrt(queries_l2sq[i]);
|
||||
if (norm == 0)
|
||||
{
|
||||
norm = std::numeric_limits<float>::epsilon();
|
||||
}
|
||||
for (uint32_t j = 0; j < dim; j++)
|
||||
{
|
||||
queries[i * dim + j] = queries_in[i * dim + j] / norm;
|
||||
}
|
||||
}
|
||||
// recalculate norms after normalizing, they should all be one.
|
||||
compute_l2sq(points_l2sq, points, npoints, dim);
|
||||
compute_l2sq(queries_l2sq, queries, nqueries, dim);
|
||||
}
|
||||
|
||||
std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in "
|
||||
<< dim << " dimensions using";
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
std::cout << " MIPS ";
|
||||
else if (metric == diskann::Metric::COSINE)
|
||||
std::cout << " Cosine ";
|
||||
else
|
||||
std::cout << " L2 ";
|
||||
std::cout << "distance fn. " << std::endl;
|
||||
|
||||
size_t q_batch_size = (1 << 9);
|
||||
float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints];
|
||||
|
||||
for (uint64_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b)
|
||||
{
|
||||
int64_t q_b = b * q_batch_size;
|
||||
int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size;
|
||||
|
||||
if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE)
|
||||
{
|
||||
distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b,
|
||||
queries + (ptrdiff_t)q_b * (ptrdiff_t)dim);
|
||||
}
|
||||
std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic, 16)
|
||||
for (long long q = q_b; q < q_e; q++)
|
||||
{
|
||||
maxPQIFCS point_dist;
|
||||
for (size_t p = 0; p < k; p++)
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
for (size_t p = k; p < npoints; p++)
|
||||
{
|
||||
if (point_dist.top().second > dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints])
|
||||
point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]);
|
||||
if (point_dist.size() > k)
|
||||
point_dist.pop();
|
||||
}
|
||||
for (ptrdiff_t l = 0; l < (ptrdiff_t)k; ++l)
|
||||
{
|
||||
closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().first;
|
||||
dist_closest_points[(ptrdiff_t)(k - 1 - l) + (ptrdiff_t)q * (ptrdiff_t)k] = point_dist.top().second;
|
||||
point_dist.pop();
|
||||
}
|
||||
assert(std::is_sorted(dist_closest_points + (ptrdiff_t)q * (ptrdiff_t)k,
|
||||
dist_closest_points + (ptrdiff_t)(q + 1) * (ptrdiff_t)k));
|
||||
}
|
||||
std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl;
|
||||
}
|
||||
|
||||
delete[] dist_matrix;
|
||||
|
||||
delete[] points_l2sq;
|
||||
delete[] queries_l2sq;
|
||||
|
||||
if (metric == diskann::Metric::COSINE)
|
||||
{
|
||||
delete[] points;
|
||||
delete[] queries;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> inline int get_num_parts(const char *filename)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl;
|
||||
reader.close();
|
||||
int num_parts = (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1;
|
||||
std::cout << "Number of parts: " << num_parts << std::endl;
|
||||
return num_parts;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void load_bin_as_float(const char *filename, float *&data, size_t &npts_u64, size_t &ndims_u64, int part_num)
|
||||
{
|
||||
std::ifstream reader;
|
||||
reader.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
reader.open(filename, std::ios::binary);
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts_u64 = end_id - start_id;
|
||||
ndims_u64 = (uint64_t)ndims_i32;
|
||||
std::cout << "#pts in part = " << npts_u64 << ", #dims = " << ndims_u64
|
||||
<< ", size = " << npts_u64 * ndims_u64 * sizeof(T) << "B" << std::endl;
|
||||
|
||||
reader.seekg(start_id * ndims_u64 * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
T *data_T = new T[npts_u64 * ndims_u64];
|
||||
reader.read((char *)data_T, sizeof(T) * npts_u64 * ndims_u64);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
data = aligned_malloc<float>(npts_u64 * ndims_u64, ALIGNMENT);
|
||||
#pragma omp parallel for schedule(dynamic, 32768)
|
||||
for (int64_t i = 0; i < (int64_t)npts_u64; i++)
|
||||
{
|
||||
for (int64_t j = 0; j < (int64_t)ndims_u64; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndims_u64 + j];
|
||||
std::memcpy((char *)(data + i * ndims_u64 + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::vector<size_t> load_filtered_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims,
|
||||
int part_num, const char *label_file,
|
||||
const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &npoints_filt,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels)
|
||||
{
|
||||
std::ifstream reader(filename, std::ios::binary);
|
||||
if (reader.fail())
|
||||
{
|
||||
throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
|
||||
}
|
||||
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
std::vector<size_t> rev_map;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (uint32_t)ndims_i32;
|
||||
uint64_t nptsuint64_t = (uint64_t)npts;
|
||||
uint64_t ndimsuint64_t = (uint64_t)ndims;
|
||||
npoints_filt = 0;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" << std::endl;
|
||||
std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg);
|
||||
|
||||
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
|
||||
reader.read((char *)data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
|
||||
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
|
||||
|
||||
for (int64_t i = 0; i < (int64_t)nptsuint64_t; i++)
|
||||
{
|
||||
if (std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), filter_label) !=
|
||||
pts_to_labels[start_id + i].end() ||
|
||||
std::find(pts_to_labels[start_id + i].begin(), pts_to_labels[start_id + i].end(), universal_label) !=
|
||||
pts_to_labels[start_id + i].end())
|
||||
{
|
||||
rev_map.push_back(start_id + i);
|
||||
for (int64_t j = 0; j < (int64_t)ndimsuint64_t; j++)
|
||||
{
|
||||
float cur_val_float = (float)data_T[i * ndimsuint64_t + j];
|
||||
std::memcpy((char *)(data + npoints_filt * ndimsuint64_t + j), (char *)&cur_val_float, sizeof(float));
|
||||
}
|
||||
npoints_filt++;
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float.. identified " << npoints_filt
|
||||
<< " points matching the filter." << std::endl;
|
||||
return rev_map;
|
||||
}
|
||||
|
||||
template <typename T> inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims)
|
||||
{
|
||||
std::ofstream writer;
|
||||
writer.exceptions(std::ios::failbit | std::ios::badbit);
|
||||
writer.open(filename, std::ios::binary | std::ios::out);
|
||||
std::cout << "Writing bin: " << filename << "\n";
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "bin: #pts = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(T));
|
||||
writer.close();
|
||||
std::cout << "Finished writing bin" << std::endl;
|
||||
}
|
||||
|
||||
inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
std::ofstream writer(filename, std::ios::binary | std::ios::out);
|
||||
int npts_i32 = (int)npts, ndims_i32 = (int)ndims;
|
||||
writer.write((char *)&npts_i32, sizeof(int));
|
||||
writer.write((char *)&ndims_i32, sizeof(int));
|
||||
std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, "
|
||||
"npts*dim dist-matrix) with npts = "
|
||||
<< npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int)
|
||||
<< "B" << std::endl;
|
||||
|
||||
writer.write((char *)data, npts * ndims * sizeof(uint32_t));
|
||||
writer.write((char *)distances, npts * ndims * sizeof(float));
|
||||
writer.close();
|
||||
std::cout << "Finished writing truthset" << std::endl;
|
||||
}
|
||||
|
||||
inline void parse_label_file_into_vec(size_t &line_cnt, const std::string &map_file,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels)
|
||||
{
|
||||
std::ifstream infile(map_file);
|
||||
std::string line, token;
|
||||
std::set<std::string> labels;
|
||||
infile.clear();
|
||||
infile.seekg(0, std::ios::beg);
|
||||
while (std::getline(infile, line))
|
||||
{
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> lbls(0);
|
||||
|
||||
getline(iss, token, '\t');
|
||||
std::istringstream new_iss(token);
|
||||
while (getline(new_iss, token, ','))
|
||||
{
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
lbls.push_back(token);
|
||||
labels.insert(token);
|
||||
}
|
||||
std::sort(lbls.begin(), lbls.end());
|
||||
pts_to_labels.push_back(lbls);
|
||||
}
|
||||
std::cout << "Identified " << labels.size() << " distinct label(s), and populated labels for "
|
||||
<< pts_to_labels.size() << " points" << std::endl;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processUnfilteredParts(const std::string &base_file,
|
||||
size_t &nqueries, size_t &npoints,
|
||||
size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric,
|
||||
std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
float *base_data = nullptr;
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, base_data, nqueries, query_data,
|
||||
metric);
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (uint64_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> processFilteredParts(
|
||||
const std::string &base_file, const std::string &label_file, const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &nqueries, size_t &npoints, size_t &dim, size_t &k, float *query_data,
|
||||
const diskann::Metric &metric, std::vector<uint32_t> &location_to_tag)
|
||||
{
|
||||
size_t npoints_filt = 0;
|
||||
float *base_data = nullptr;
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> res(nqueries);
|
||||
int num_parts = get_num_parts<T>(base_file.c_str());
|
||||
|
||||
std::vector<std::vector<std::string>> pts_to_labels;
|
||||
if (filter_label != "")
|
||||
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
|
||||
|
||||
for (int p = 0; p < num_parts; p++)
|
||||
{
|
||||
size_t start_id = p * PARTSIZE;
|
||||
std::vector<size_t> rev_map;
|
||||
if (filter_label != "")
|
||||
rev_map = load_filtered_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
|
||||
filter_label, universal_label, npoints_filt, pts_to_labels);
|
||||
size_t *closest_points_part = new size_t[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto part_k = k < npoints_filt ? k : npoints_filt;
|
||||
if (npoints_filt > 0)
|
||||
{
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints_filt, base_data, nqueries,
|
||||
query_data, metric);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
for (uint64_t j = 0; j < part_k; j++)
|
||||
{
|
||||
if (!location_to_tag.empty())
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
|
||||
res[i].push_back(std::make_pair((uint32_t)(rev_map[closest_points_part[i * part_k + j]]),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
|
||||
delete[] closest_points_part;
|
||||
delete[] dist_closest_points_part;
|
||||
|
||||
diskann::aligned_free(base_data);
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &label_file, const std::string &query_file,
|
||||
const std::string >_file, size_t k, const std::string &universal_label, const diskann::Metric &metric,
|
||||
const std::string &filter_label, const std::string &tags_file = std::string(""))
|
||||
{
|
||||
size_t npoints, nqueries, dim;
|
||||
|
||||
float *query_data = nullptr;
|
||||
|
||||
load_bin_as_float<T>(query_file.c_str(), query_data, nqueries, dim, 0);
|
||||
if (nqueries > PARTSIZE)
|
||||
std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE
|
||||
<< ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl;
|
||||
|
||||
// load tags
|
||||
const bool tags_enabled = tags_file.empty() ? false : true;
|
||||
std::vector<uint32_t> location_to_tag = diskann::loadTags(tags_file, base_file);
|
||||
|
||||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::pair<uint32_t, float>>> results;
|
||||
if (filter_label == "")
|
||||
{
|
||||
results = processUnfilteredParts<T>(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag);
|
||||
}
|
||||
else
|
||||
{
|
||||
results = processFilteredParts<T>(base_file, label_file, filter_label, universal_label, nqueries, npoints, dim,
|
||||
k, query_data, metric, location_to_tag);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < nqueries; i++)
|
||||
{
|
||||
std::vector<std::pair<uint32_t, float>> &cur_res = results[i];
|
||||
std::sort(cur_res.begin(), cur_res.end(), custom_dist);
|
||||
size_t j = 0;
|
||||
for (auto iter : cur_res)
|
||||
{
|
||||
if (j == k)
|
||||
break;
|
||||
if (tags_enabled)
|
||||
{
|
||||
std::uint32_t index_with_tag = location_to_tag[iter.first];
|
||||
closest_points[i * k + j] = (int32_t)index_with_tag;
|
||||
}
|
||||
else
|
||||
{
|
||||
closest_points[i * k + j] = (int32_t)iter.first;
|
||||
}
|
||||
|
||||
if (metric == diskann::Metric::INNER_PRODUCT)
|
||||
dist_closest_points[i * k + j] = -iter.second;
|
||||
else
|
||||
dist_closest_points[i * k + j] = iter.second;
|
||||
|
||||
++j;
|
||||
}
|
||||
if (j < k)
|
||||
std::cout << "WARNING: found less than k GT entries for query " << i << std::endl;
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k);
|
||||
delete[] closest_points;
|
||||
delete[] dist_closest_points;
|
||||
diskann::aligned_free(query_data);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim)
|
||||
{
|
||||
size_t read_blk_size = 64 * 1024 * 1024;
|
||||
cached_ifstream reader(bin_file, read_blk_size);
|
||||
diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl;
|
||||
size_t actual_file_size = reader.get_file_size();
|
||||
|
||||
int npts_i32, dim_i32;
|
||||
reader.read((char *)&npts_i32, sizeof(int));
|
||||
reader.read((char *)&dim_i32, sizeof(int));
|
||||
npts = (uint32_t)npts_i32;
|
||||
dim = (uint32_t)dim_i32;
|
||||
|
||||
diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl;
|
||||
|
||||
int truthset_type = -1; // 1 means truthset has ids and distances, 2 means
|
||||
// only ids, -1 is error
|
||||
size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_with_dists)
|
||||
truthset_type = 1;
|
||||
|
||||
size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t);
|
||||
|
||||
if (actual_file_size == expected_file_size_just_ids)
|
||||
truthset_type = 2;
|
||||
|
||||
if (truthset_type == -1)
|
||||
{
|
||||
std::stringstream stream;
|
||||
stream << "Error. File size mismatch. File should have bin format, with "
|
||||
"npts followed by ngt followed by npts*ngt ids and optionally "
|
||||
"followed by npts*ngt distance values; actual size: "
|
||||
<< actual_file_size << ", expected: " << expected_file_size_with_dists << " or "
|
||||
<< expected_file_size_just_ids;
|
||||
diskann::cout << stream.str();
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
ids = new uint32_t[npts * dim];
|
||||
reader.read((char *)ids, npts * dim * sizeof(uint32_t));
|
||||
|
||||
if (truthset_type == 1)
|
||||
{
|
||||
dists = new float[npts * dim];
|
||||
reader.read((char *)dists, npts * dim * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, label_file, filter_label,
|
||||
universal_label, filter_label_file;
|
||||
uint64_t K;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->required(), "distance function <l2/mips>");
|
||||
desc.add_options()("base_file", po::value<std::string>(&base_file)->required(),
|
||||
"File containing the base vectors in binary format");
|
||||
desc.add_options()("query_file", po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input labels file in txt format if present");
|
||||
desc.add_options()("filter_label", po::value<std::string>(&filter_label)->default_value(""),
|
||||
"Input filter label if doing filtered groundtruth");
|
||||
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with label_file");
|
||||
desc.add_options()("gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary "
|
||||
"format, please don' append .bin at end if "
|
||||
"no filter_label or filter_label_file is provided it "
|
||||
"will save the file with '.bin' at end."
|
||||
"else it will save the file as filename_label.bin");
|
||||
desc.add_options()("K", po::value<uint64_t>(&K)->required(),
|
||||
"Number of ground truth nearest neighbors to compute");
|
||||
desc.add_options()("tags_file", po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"File containing the tags in binary format");
|
||||
desc.add_options()("filter_label_file",
|
||||
po::value<std::string>(&filter_label_file)->default_value(std::string("")),
|
||||
"Filter file for Queries for Filtered Search ");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (filter_label != "" && filter_label_file != "")
|
||||
{
|
||||
std::cerr << "Only one of filter_label and query_filters_file should be provided" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
{
|
||||
metric = diskann::Metric::L2;
|
||||
}
|
||||
else if (dist_fn == std::string("mips"))
|
||||
{
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
}
|
||||
else if (dist_fn == std::string("cosine"))
|
||||
{
|
||||
metric = diskann::Metric::COSINE;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::string> filter_labels;
|
||||
if (filter_label != "")
|
||||
{
|
||||
filter_labels.push_back(filter_label);
|
||||
}
|
||||
else if (filter_label_file != "")
|
||||
{
|
||||
filter_labels = read_file_to_vector_of_strings(filter_label_file, false);
|
||||
}
|
||||
|
||||
// only when there is no filter label or 1 filter label for all queries
|
||||
if (filter_labels.size() == 1)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K, universal_label, metric,
|
||||
filter_labels[0], tags_file);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{ // Each query has its own filter label
|
||||
// Split up data and query bins into label specific ones
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_points;
|
||||
tsl::robin_map<std::string, uint32_t> labels_to_number_of_queries;
|
||||
|
||||
label_set all_labels;
|
||||
for (size_t i = 0; i < filter_labels.size(); i++)
|
||||
{
|
||||
std::string label = filter_labels[i];
|
||||
all_labels.insert(label);
|
||||
|
||||
if (labels_to_number_of_queries.find(label) == labels_to_number_of_queries.end())
|
||||
{
|
||||
labels_to_number_of_queries[label] = 0;
|
||||
}
|
||||
labels_to_number_of_queries[label] += 1;
|
||||
}
|
||||
|
||||
size_t npoints;
|
||||
std::vector<std::vector<std::string>> point_to_labels;
|
||||
parse_label_file_into_vec(npoints, label_file, point_to_labels);
|
||||
std::vector<label_set> point_ids_to_labels(point_to_labels.size());
|
||||
std::vector<label_set> query_ids_to_labels(filter_labels.size());
|
||||
|
||||
for (size_t i = 0; i < point_to_labels.size(); i++)
|
||||
{
|
||||
for (size_t j = 0; j < point_to_labels[i].size(); j++)
|
||||
{
|
||||
std::string label = point_to_labels[i][j];
|
||||
if (all_labels.find(label) != all_labels.end())
|
||||
{
|
||||
point_ids_to_labels[i].insert(point_to_labels[i][j]);
|
||||
if (labels_to_number_of_points.find(label) == labels_to_number_of_points.end())
|
||||
{
|
||||
labels_to_number_of_points[label] = 0;
|
||||
}
|
||||
labels_to_number_of_points[label] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < filter_labels.size(); i++)
|
||||
{
|
||||
query_ids_to_labels[i].insert(filter_labels[i]);
|
||||
}
|
||||
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_id_to_orig_id;
|
||||
tsl::robin_map<std::string, std::vector<uint32_t>> label_query_id_to_orig_id;
|
||||
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<float>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<int8_t>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
label_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
base_file, labels_to_number_of_points, point_ids_to_labels, all_labels);
|
||||
|
||||
label_query_id_to_orig_id = diskann::generate_label_specific_vector_files_compat<uint8_t>(
|
||||
query_file, labels_to_number_of_queries, query_ids_to_labels,
|
||||
all_labels); // query_filters acts like query_ids_to_labels
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::cerr << "Invalid data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Generate label specific ground truths
|
||||
|
||||
try
|
||||
{
|
||||
for (const auto &label : all_labels)
|
||||
{
|
||||
std::string filtered_base_file = base_file + "_" + label;
|
||||
std::string filtered_query_file = query_file + "_" + label;
|
||||
std::string filtered_gt_file = gt_file + "_" + label;
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(filtered_base_file, "", filtered_query_file, filtered_gt_file, K, "", metric, "");
|
||||
}
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Combine the label specific ground truths to produce a single GT file
|
||||
|
||||
uint32_t *gt_ids = nullptr;
|
||||
float *gt_dists = nullptr;
|
||||
size_t gt_num, gt_dim;
|
||||
|
||||
std::vector<std::vector<int32_t>> final_gt_ids;
|
||||
std::vector<std::vector<float>> final_gt_dists;
|
||||
|
||||
uint32_t query_num = 0;
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
query_num += labels_to_number_of_queries[lbl];
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
final_gt_ids.push_back(std::vector<int32_t>(K));
|
||||
final_gt_dists.push_back(std::vector<float>(K));
|
||||
}
|
||||
|
||||
for (const auto &lbl : all_labels)
|
||||
{
|
||||
std::string filtered_gt_file = gt_file + "_" + lbl;
|
||||
load_truthset(filtered_gt_file, gt_ids, gt_dists, gt_num, gt_dim);
|
||||
|
||||
for (uint32_t i = 0; i < labels_to_number_of_queries[lbl]; i++)
|
||||
{
|
||||
uint32_t orig_query_id = label_query_id_to_orig_id[lbl][i];
|
||||
for (uint64_t j = 0; j < K; j++)
|
||||
{
|
||||
final_gt_ids[orig_query_id][j] = label_id_to_orig_id[lbl][gt_ids[i * K + j]];
|
||||
final_gt_dists[orig_query_id][j] = gt_dists[i * K + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int32_t *closest_points = new int32_t[query_num * K];
|
||||
float *dist_closest_points = new float[query_num * K];
|
||||
|
||||
for (uint32_t i = 0; i < query_num; i++)
|
||||
{
|
||||
for (uint32_t j = 0; j < K; j++)
|
||||
{
|
||||
closest_points[i * K + j] = final_gt_ids[i][j];
|
||||
dist_closest_points[i * K + j] = final_gt_dists[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, query_num, K);
|
||||
|
||||
// cleanup artifacts
|
||||
std::cout << "Cleaning up artifacts..." << std::endl;
|
||||
tsl::robin_set<std::string> paths_to_clean{gt_file, base_file, query_file};
|
||||
clean_up_artifacts(paths_to_clean, all_labels);
|
||||
}
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <omp.h>
|
||||
#include <set>
|
||||
#include <string.h>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#include "utils.h"
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template <typename T> void bfs_count(const std::string &index_path, uint32_t data_dims)
|
||||
{
|
||||
using TagT = uint32_t;
|
||||
using LabelT = uint32_t;
|
||||
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0, nullptr, nullptr, 0, false, false, false,
|
||||
false, 0, false);
|
||||
std::cout << "Index class instantiated" << std::endl;
|
||||
index.load(index_path.c_str(), 1, 100);
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
index.count_nodes_at_bfs_levels();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, index_path_prefix;
|
||||
uint32_t data_dims;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("index_path_prefix", po::value<std::string>(&index_path_prefix)->required(),
|
||||
"Path prefix to the index");
|
||||
desc.add_options()("data_dims", po::value<uint32_t>(&data_dims)->required(), "Dimensionality of the data");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
if (data_type == std::string("int8"))
|
||||
bfs_count<int8_t>(index_path_prefix, data_dims);
|
||||
else if (data_type == std::string("uint8"))
|
||||
bfs_count<uint8_t>(index_path_prefix, data_dims);
|
||||
if (data_type == std::string("float"))
|
||||
bfs_count<float>(index_path_prefix, data_dims);
|
||||
}
|
||||
catch (std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index BFS failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
#include "disk_utils.h"
|
||||
#include "cached_io.h"
|
||||
|
||||
template <typename T> int create_disk_layout(char **argv)
|
||||
{
|
||||
std::string base_file(argv[2]);
|
||||
std::string vamana_file(argv[3]);
|
||||
std::string output_file(argv[4]);
|
||||
diskann::create_disk_layout<T>(base_file, vamana_file, output_file);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " data_type <float/int8/uint8> data_bin "
|
||||
"vamana_index_file output_diskann_index_file"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int ret_val = -1;
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
ret_val = create_disk_layout<float>(argv);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
ret_val = create_disk_layout<int8_t>(argv);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
ret_val = create_disk_layout<uint8_t>(argv);
|
||||
else
|
||||
{
|
||||
std::cout << "unsupported type. use int8/uint8/float " << std::endl;
|
||||
ret_val = -2;
|
||||
}
|
||||
return ret_val;
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert(std::ofstream &writer, int8_t *write_buf, std::ifstream &reader, float *read_buf, size_t npts,
|
||||
size_t ndims, float bias, float scale)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * ndims * sizeof(float));
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
{
|
||||
write_buf[d + i * ndims] = (int8_t)((read_buf[d + i * ndims] - bias) * (254.0 / scale));
|
||||
}
|
||||
}
|
||||
writer.write((char *)write_buf, npts * ndims);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << "Usage: " << argv[0] << " input_bin output_tsv bias scale" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[1], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
auto read_buf = new float[blk_size * ndims];
|
||||
auto write_buf = new int8_t[blk_size * ndims];
|
||||
float bias = (float)atof(argv[3]);
|
||||
float scale = (float)atof(argv[4]);
|
||||
|
||||
writer.write((char *)(&npts_u32), sizeof(uint32_t));
|
||||
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
// Convert float types
|
||||
void block_convert_float(std::ifstream &reader, std::ofstream &writer, float *read_buf, float *write_buf, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(float));
|
||||
}
|
||||
writer.write((char *)write_buf, npts * ndims * sizeof(float));
|
||||
}
|
||||
|
||||
// Convert byte types
|
||||
void block_convert_byte(std::ifstream &reader, std::ofstream &writer, uint8_t *read_buf, uint8_t *write_buf,
|
||||
size_t npts, size_t ndims)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * (ndims * sizeof(uint8_t) + sizeof(uint32_t)));
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + sizeof(uint32_t))) + sizeof(uint32_t),
|
||||
ndims * sizeof(uint8_t));
|
||||
}
|
||||
writer.write((char *)write_buf, npts * ndims * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cout << argv[0] << " <float/int8/uint8> input_vecs output_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int datasize = sizeof(float);
|
||||
|
||||
if (strcmp(argv[1], "uint8") == 0 || strcmp(argv[1], "int8") == 0)
|
||||
{
|
||||
datasize = sizeof(uint8_t);
|
||||
}
|
||||
else if (strcmp(argv[1], "float") != 0)
|
||||
{
|
||||
std::cout << "Error: type not supported. Use float/int8/uint8" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
|
||||
size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
size_t ndims = (size_t)ndims_u32;
|
||||
size_t npts = fsize / ((ndims * datasize) + sizeof(uint32_t));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[3], std::ios::binary);
|
||||
int32_t npts_s32 = (int32_t)npts;
|
||||
int32_t ndims_s32 = (int32_t)ndims;
|
||||
writer.write((char *)&npts_s32, sizeof(int32_t));
|
||||
writer.write((char *)&ndims_s32, sizeof(int32_t));
|
||||
|
||||
size_t chunknpts = std::min(npts, blk_size);
|
||||
uint8_t *read_buf = new uint8_t[chunknpts * ((ndims * datasize) + sizeof(uint32_t))];
|
||||
uint8_t *write_buf = new uint8_t[chunknpts * ndims * datasize];
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (datasize == sizeof(float))
|
||||
{
|
||||
block_convert_float(reader, writer, (float *)read_buf, (float *)write_buf, cblk_size, ndims);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_convert_byte(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
}
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert(std::ifstream &reader, std::ofstream &writer, float *read_buf, uint8_t *write_buf, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * (ndims * sizeof(float) + sizeof(uint32_t)));
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
memcpy(write_buf + i * (ndims + 4), read_buf + i * (ndims + 1), sizeof(uint32_t));
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
write_buf[i * (ndims + 4) + 4 + d] = (uint8_t)read_buf[i * (ndims + 1) + 1 + d];
|
||||
}
|
||||
writer.write((char *)write_buf, npts * (ndims * 1 + 4));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_fvecs output_bvecs(uint8)" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
|
||||
size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
size_t ndims = (size_t)ndims_u32;
|
||||
size_t npts = fsize / ((ndims + 1) * sizeof(float));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
auto read_buf = new float[npts * (ndims + 1)];
|
||||
auto write_buf = new uint8_t[npts * (ndims + 4)];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "partition.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <typeinfo>
|
||||
|
||||
template <typename T> int aux_main(char **argv)
|
||||
{
|
||||
std::string base_file(argv[2]);
|
||||
std::string output_prefix(argv[3]);
|
||||
float sampling_rate = (float)(std::atof(argv[4]));
|
||||
gen_random_slice<T>(base_file, output_prefix, sampling_rate);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " data_type [float/int8/uint8] base_bin_file "
|
||||
"sample_output_prefix sampling_probability"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
{
|
||||
aux_main<float>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
{
|
||||
aux_main<int8_t>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
{
|
||||
aux_main<uint8_t>(argv);
|
||||
}
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include "math_utils.h"
|
||||
#include "pq.h"
|
||||
#include "partition.h"
|
||||
|
||||
#define KMEANS_ITERS_FOR_PQ 15
|
||||
|
||||
template <typename T>
|
||||
bool generate_pq(const std::string &data_path, const std::string &index_prefix_path, const size_t num_pq_centers,
|
||||
const size_t num_pq_chunks, const float sampling_rate, const bool opq)
|
||||
{
|
||||
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
|
||||
std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin";
|
||||
|
||||
// generates random sample and sets it to train_data and updates train_size
|
||||
size_t train_size, train_dim;
|
||||
float *train_data;
|
||||
gen_random_slice<T>(data_path, sampling_rate, train_data, train_size, train_dim);
|
||||
std::cout << "For computing pivots, loaded sample data of size " << train_size << std::endl;
|
||||
|
||||
if (opq)
|
||||
{
|
||||
diskann::generate_opq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
|
||||
(uint32_t)num_pq_chunks, pq_pivots_path, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
diskann::generate_pq_pivots(train_data, train_size, (uint32_t)train_dim, (uint32_t)num_pq_centers,
|
||||
(uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path);
|
||||
}
|
||||
diskann::generate_pq_data_from_pivots<T>(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks,
|
||||
pq_pivots_path, pq_compressed_vectors_path, true);
|
||||
|
||||
delete[] train_data;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 7)
|
||||
{
|
||||
std::cout << "Usage: \n"
|
||||
<< argv[0]
|
||||
<< " <data_type[float/uint8/int8]> <data_file[.bin]>"
|
||||
" <PQ_prefix_path> <target-bytes/data-point> "
|
||||
"<sampling_rate> <PQ(0)/OPQ(1)>"
|
||||
<< std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string index_prefix_path(argv[3]);
|
||||
const size_t num_pq_centers = 256;
|
||||
const size_t num_pq_chunks = (size_t)atoi(argv[4]);
|
||||
const float sampling_rate = (float)atof(argv[5]);
|
||||
const bool opq = atoi(argv[6]) == 0 ? false : true;
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
generate_pq<float>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
generate_pq<int8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
generate_pq<uint8_t>(data_path, index_prefix_path, num_pq_centers, num_pq_chunks, sampling_rate, opq);
|
||||
else
|
||||
std::cout << "Error. wrong file type" << std::endl;
|
||||
}
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <math.h>
|
||||
#include <cmath>
|
||||
#include "utils.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
class ZipfDistribution
|
||||
{
|
||||
public:
|
||||
ZipfDistribution(uint64_t num_points, uint32_t num_labels)
|
||||
: num_labels(num_labels), num_points(num_points),
|
||||
uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0))
|
||||
{
|
||||
}
|
||||
|
||||
std::unordered_map<uint32_t, uint32_t> createDistributionMap()
|
||||
{
|
||||
std::unordered_map<uint32_t, uint32_t> map;
|
||||
uint32_t primary_label_freq = (uint32_t)ceil(num_points * distribution_factor);
|
||||
for (uint32_t i{1}; i < num_labels + 1; i++)
|
||||
{
|
||||
map[i] = (uint32_t)ceil(primary_label_freq / i);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
int writeDistribution(std::ofstream &outfile)
|
||||
{
|
||||
auto distribution_map = createDistributionMap();
|
||||
for (uint32_t i{0}; i < num_points; i++)
|
||||
{
|
||||
bool label_written = false;
|
||||
for (auto it = distribution_map.cbegin(); it != distribution_map.cend(); it++)
|
||||
{
|
||||
auto label_selection_probability = std::bernoulli_distribution(distribution_factor / (double)it->first);
|
||||
if (label_selection_probability(rand_engine) && distribution_map[it->first] > 0)
|
||||
{
|
||||
if (label_written)
|
||||
{
|
||||
outfile << ',';
|
||||
}
|
||||
outfile << it->first;
|
||||
label_written = true;
|
||||
// remove label from map if we have used all labels
|
||||
distribution_map[it->first] -= 1;
|
||||
}
|
||||
}
|
||||
if (!label_written)
|
||||
{
|
||||
outfile << 0;
|
||||
}
|
||||
if (i < num_points - 1)
|
||||
{
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int writeDistribution(std::string filename)
|
||||
{
|
||||
std::ofstream outfile(filename);
|
||||
if (!outfile.is_open())
|
||||
{
|
||||
std::cerr << "Error: could not open output file " << filename << '\n';
|
||||
return -1;
|
||||
}
|
||||
writeDistribution(outfile);
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
private:
|
||||
const uint32_t num_labels;
|
||||
const uint64_t num_points;
|
||||
const double distribution_factor = 0.7;
|
||||
std::knuth_b rand_engine;
|
||||
const std::uniform_real_distribution<double> uniform_zero_to_one;
|
||||
};
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string output_file, distribution_type;
|
||||
uint32_t num_labels;
|
||||
uint64_t num_points;
|
||||
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("output_file,O", po::value<std::string>(&output_file)->required(),
|
||||
"Filename for saving the label file");
|
||||
desc.add_options()("num_points,N", po::value<uint64_t>(&num_points)->required(), "Number of points in dataset");
|
||||
desc.add_options()("num_labels,L", po::value<uint32_t>(&num_labels)->required(),
|
||||
"Number of unique labels, up to 5000");
|
||||
desc.add_options()("distribution_type,DT", po::value<std::string>(&distribution_type)->default_value("random"),
|
||||
"Distribution function for labels <random/zipf/one_per_point> defaults "
|
||||
"to random");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_labels > 5000)
|
||||
{
|
||||
std::cerr << "Error: num_labels must be 5000 or less" << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_points <= 0)
|
||||
{
|
||||
std::cerr << "Error: num_points must be greater than 0" << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout << "Generating synthetic labels for " << num_points << " points with " << num_labels << " unique labels"
|
||||
<< '\n';
|
||||
|
||||
try
|
||||
{
|
||||
std::ofstream outfile(output_file);
|
||||
if (!outfile.is_open())
|
||||
{
|
||||
std::cerr << "Error: could not open output file " << output_file << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (distribution_type == "zipf")
|
||||
{
|
||||
ZipfDistribution zipf(num_points, num_labels);
|
||||
zipf.writeDistribution(outfile);
|
||||
}
|
||||
else if (distribution_type == "random")
|
||||
{
|
||||
for (size_t i = 0; i < num_points; i++)
|
||||
{
|
||||
bool label_written = false;
|
||||
for (size_t j = 1; j <= num_labels; j++)
|
||||
{
|
||||
// 50% chance to assign each label
|
||||
if (rand() > (RAND_MAX / 2))
|
||||
{
|
||||
if (label_written)
|
||||
{
|
||||
outfile << ',';
|
||||
}
|
||||
outfile << j;
|
||||
label_written = true;
|
||||
}
|
||||
}
|
||||
if (!label_written)
|
||||
{
|
||||
outfile << 0;
|
||||
}
|
||||
if (i < num_points - 1)
|
||||
{
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (distribution_type == "one_per_point")
|
||||
{
|
||||
std::random_device rd; // obtain a random number from hardware
|
||||
std::mt19937 gen(rd()); // seed the generator
|
||||
std::uniform_int_distribution<> distr(0, num_labels); // define the range
|
||||
|
||||
for (size_t i = 0; i < num_points; i++)
|
||||
{
|
||||
outfile << distr(gen);
|
||||
if (i != num_points - 1)
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
if (outfile.is_open())
|
||||
{
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
std::cout << "Labels written to " << output_file << '\n';
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << "Label generation failed: " << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_int8_bin output_float_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int8_t *input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<int8_t>(argv[1], input, npts, nd);
|
||||
float *output = new float[npts * nd];
|
||||
diskann::convert_types<int8_t, float>(input, output, npts, nd);
|
||||
diskann::save_bin<float>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert(std::ofstream &writer, float *write_buf, std::ifstream &reader, int8_t *read_buf, size_t npts,
|
||||
size_t ndims, float bias, float scale)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * ndims * sizeof(int8_t));
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
{
|
||||
write_buf[d + i * ndims] = (((float)read_buf[d + i * ndims] - bias) * scale);
|
||||
}
|
||||
}
|
||||
writer.write((char *)write_buf, npts * ndims * sizeof(float));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 5)
|
||||
{
|
||||
std::cout << "Usage: " << argv[0] << " input-int8.bin output-float.bin bias scale" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::ifstream reader(argv[1], std::ios::binary);
|
||||
uint32_t npts_u32;
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&npts_u32, sizeof(uint32_t));
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
size_t npts = npts_u32;
|
||||
size_t ndims = ndims_u32;
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
auto read_buf = new int8_t[blk_size * ndims];
|
||||
auto write_buf = new float[blk_size * ndims];
|
||||
float bias = (float)atof(argv[3]);
|
||||
float scale = (float)atof(argv[4]);
|
||||
|
||||
writer.write((char *)(&npts_u32), sizeof(uint32_t));
|
||||
writer.write((char *)(&ndims_u32), sizeof(uint32_t));
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(writer, write_buf, reader, read_buf, cblk_size, ndims, bias, scale);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
writer.close();
|
||||
reader.close();
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert(std::ifstream &reader, std::ofstream &writer, uint32_t *read_buf, uint32_t *write_buf, size_t npts,
|
||||
size_t ndims)
|
||||
{
|
||||
reader.read((char *)read_buf, npts * (ndims * sizeof(uint32_t) + sizeof(uint32_t)));
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
memcpy(write_buf + i * ndims, (read_buf + i * (ndims + 1)) + 1, ndims * sizeof(uint32_t));
|
||||
}
|
||||
writer.write((char *)write_buf, npts * ndims * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_ivecs output_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream reader(argv[1], std::ios::binary | std::ios::ate);
|
||||
size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
uint32_t ndims_u32;
|
||||
reader.read((char *)&ndims_u32, sizeof(uint32_t));
|
||||
reader.seekg(0, std::ios::beg);
|
||||
size_t ndims = (size_t)ndims_u32;
|
||||
size_t npts = fsize / ((ndims + 1) * sizeof(uint32_t));
|
||||
std::cout << "Dataset: #pts = " << npts << ", # dims = " << ndims << std::endl;
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[2], std::ios::binary);
|
||||
int npts_s32 = (int)npts;
|
||||
int ndims_s32 = (int)ndims;
|
||||
writer.write((char *)&npts_s32, sizeof(int));
|
||||
writer.write((char *)&ndims_s32, sizeof(int));
|
||||
uint32_t *read_buf = new uint32_t[npts * (ndims + 1)];
|
||||
uint32_t *write_buf = new uint32_t[npts * ndims];
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
block_convert(reader, writer, read_buf, write_buf, cblk_size, ndims);
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
delete[] read_buf;
|
||||
delete[] write_buf;
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "disk_utils.h"
|
||||
#include "cached_io.h"
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 9)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " vamana_index_prefix[1] vamana_index_suffix[2] "
|
||||
"idmaps_prefix[3] "
|
||||
"idmaps_suffix[4] n_shards[5] max_degree[6] "
|
||||
"output_vamana_path[7] "
|
||||
"output_medoids_path[8]"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::string vamana_prefix(argv[1]);
|
||||
std::string vamana_suffix(argv[2]);
|
||||
std::string idmaps_prefix(argv[3]);
|
||||
std::string idmaps_suffix(argv[4]);
|
||||
uint64_t nshards = (uint64_t)std::atoi(argv[5]);
|
||||
uint32_t max_degree = (uint64_t)std::atoi(argv[6]);
|
||||
std::string output_index(argv[7]);
|
||||
std::string output_medoids(argv[8]);
|
||||
|
||||
return diskann::merge_shards(vamana_prefix, vamana_suffix, idmaps_prefix, idmaps_suffix, nshards, max_degree,
|
||||
output_index, output_medoids);
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <math_utils.h>
|
||||
#include "cached_io.h"
|
||||
#include "partition.h"
|
||||
|
||||
// DEPRECATED: NEED TO REPROGRAM
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 7)
|
||||
{
|
||||
std::cout << "Usage:\n"
|
||||
<< argv[0]
|
||||
<< " datatype<int8/uint8/float> <data_path>"
|
||||
" <prefix_path> <sampling_rate> "
|
||||
" <num_partitions> <k_index>"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string prefix_path(argv[3]);
|
||||
const float sampling_rate = (float)atof(argv[4]);
|
||||
const size_t num_partitions = (size_t)std::atoi(argv[5]);
|
||||
const size_t max_reps = 15;
|
||||
const size_t k_index = (size_t)std::atoi(argv[6]);
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
partition<float>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
partition<int8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
partition<uint8_t>(data_path, sampling_rate, num_partitions, max_reps, prefix_path, k_index);
|
||||
else
|
||||
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <index.h>
|
||||
#include <math_utils.h>
|
||||
#include "cached_io.h"
|
||||
#include "partition.h"
|
||||
|
||||
// DEPRECATED: NEED TO REPROGRAM
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 8)
|
||||
{
|
||||
std::cout << "Usage:\n"
|
||||
<< argv[0]
|
||||
<< " datatype<int8/uint8/float> <data_path>"
|
||||
" <prefix_path> <sampling_rate> "
|
||||
" <ram_budget(GB)> <graph_degree> <k_index>"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const std::string data_path(argv[2]);
|
||||
const std::string prefix_path(argv[3]);
|
||||
const float sampling_rate = (float)atof(argv[4]);
|
||||
const double ram_budget = (double)std::atof(argv[5]);
|
||||
const size_t graph_degree = (size_t)std::atoi(argv[6]);
|
||||
const size_t k_index = (size_t)std::atoi(argv[7]);
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
partition_with_ram_budget<float>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
partition_with_ram_budget<int8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
partition_with_ram_budget<uint8_t>(data_path, sampling_rate, ram_budget, graph_degree, prefix_path, k_index);
|
||||
else
|
||||
std::cout << "unsupported data format. use float/int8/uint8" << std::endl;
|
||||
}
|
||||
@@ -1,237 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include <cmath>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int block_write_float(std::ofstream &writer, size_t ndims, size_t npts, bool normalization, float norm,
|
||||
float rand_scale)
|
||||
{
|
||||
auto vec = new float[ndims];
|
||||
|
||||
std::random_device rd{};
|
||||
std::mt19937 gen{rd()};
|
||||
std::normal_distribution<> normal_rand{0, 1};
|
||||
std::uniform_real_distribution<> unif_dis(1.0, rand_scale);
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
float sum = 0;
|
||||
float scale = 1.0f;
|
||||
if (rand_scale > 1.0f)
|
||||
scale = (float)unif_dis(gen);
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
vec[d] = scale * (float)normal_rand(gen);
|
||||
if (normalization)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
sum += vec[d] * vec[d];
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
vec[d] = vec[d] * norm / std::sqrt(sum);
|
||||
}
|
||||
|
||||
writer.write((char *)vec, ndims * sizeof(float));
|
||||
}
|
||||
|
||||
delete[] vec;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int block_write_int8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
|
||||
{
|
||||
auto vec = new float[ndims];
|
||||
auto vec_T = new int8_t[ndims];
|
||||
|
||||
std::random_device rd{};
|
||||
std::mt19937 gen{rd()};
|
||||
std::normal_distribution<> normal_rand{0, 1};
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
float sum = 0;
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
vec[d] = (float)normal_rand(gen);
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
sum += vec[d] * vec[d];
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
vec[d] = vec[d] * norm / std::sqrt(sum);
|
||||
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
{
|
||||
vec_T[d] = (int8_t)std::round(vec[d]);
|
||||
}
|
||||
|
||||
writer.write((char *)vec_T, ndims * sizeof(int8_t));
|
||||
}
|
||||
|
||||
delete[] vec;
|
||||
delete[] vec_T;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int block_write_uint8(std::ofstream &writer, size_t ndims, size_t npts, float norm)
|
||||
{
|
||||
auto vec = new float[ndims];
|
||||
auto vec_T = new int8_t[ndims];
|
||||
|
||||
std::random_device rd{};
|
||||
std::mt19937 gen{rd()};
|
||||
std::normal_distribution<> normal_rand{0, 1};
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
float sum = 0;
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
vec[d] = (float)normal_rand(gen);
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
sum += vec[d] * vec[d];
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
vec[d] = vec[d] * norm / std::sqrt(sum);
|
||||
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
{
|
||||
vec_T[d] = 128 + (int8_t)std::round(vec[d]);
|
||||
}
|
||||
|
||||
writer.write((char *)vec_T, ndims * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
delete[] vec;
|
||||
delete[] vec_T;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string data_type, output_file;
|
||||
size_t ndims, npts;
|
||||
float norm, rand_scaling;
|
||||
bool normalization = false;
|
||||
try
|
||||
{
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
|
||||
desc.add_options()("data_type", po::value<std::string>(&data_type)->required(), "data type <int8/uint8/float>");
|
||||
desc.add_options()("output_file", po::value<std::string>(&output_file)->required(),
|
||||
"File name for saving the random vectors");
|
||||
desc.add_options()("ndims,D", po::value<uint64_t>(&ndims)->required(), "Dimensoinality of the vector");
|
||||
desc.add_options()("npts,N", po::value<uint64_t>(&npts)->required(), "Number of vectors");
|
||||
desc.add_options()("norm", po::value<float>(&norm)->default_value(-1.0f),
|
||||
"Norm of the vectors (if not specified, vectors are not normalized)");
|
||||
desc.add_options()("rand_scaling", po::value<float>(&rand_scaling)->default_value(1.0f),
|
||||
"Each vector will be scaled (if not explicitly normalized) by a factor randomly chosen from "
|
||||
"[1, rand_scale]. Only applicable for floating point data");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &ex)
|
||||
{
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (norm > 0.0)
|
||||
{
|
||||
normalization = true;
|
||||
}
|
||||
|
||||
if (rand_scaling < 1.0)
|
||||
{
|
||||
std::cout << "We will only scale the vector norms randomly in [1, value], so value must be >= 1." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if ((rand_scaling > 1.0) && (normalization == true))
|
||||
{
|
||||
std::cout << "Data cannot be normalized and randomly scaled at same time. Use one or the other." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("int8") || data_type == std::string("uint8"))
|
||||
{
|
||||
if (norm > 127)
|
||||
{
|
||||
std::cerr << "Error: for int8/uint8 datatypes, L2 norm can not be "
|
||||
"greater "
|
||||
"than 127"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
if (rand_scaling > 1.0)
|
||||
{
|
||||
std::cout << "Data scaling only supported for floating point data." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
std::ofstream writer;
|
||||
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
|
||||
writer.open(output_file, std::ios::binary);
|
||||
auto npts_u32 = (uint32_t)npts;
|
||||
auto ndims_u32 = (uint32_t)ndims;
|
||||
writer.write((char *)&npts_u32, sizeof(uint32_t));
|
||||
writer.write((char *)&ndims_u32, sizeof(uint32_t));
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
|
||||
int ret = 0;
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (data_type == std::string("float"))
|
||||
{
|
||||
ret = block_write_float(writer, ndims, cblk_size, normalization, norm, rand_scaling);
|
||||
}
|
||||
else if (data_type == std::string("int8"))
|
||||
{
|
||||
ret = block_write_int8(writer, ndims, cblk_size, norm);
|
||||
}
|
||||
else if (data_type == std::string("uint8"))
|
||||
{
|
||||
ret = block_write_uint8(writer, ndims, cblk_size, norm);
|
||||
}
|
||||
if (ret == 0)
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
else
|
||||
{
|
||||
writer.close();
|
||||
std::cout << "failed to write" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
writer.close();
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index build failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
#include <cmath>
|
||||
|
||||
inline float aggregate_recall(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, uint32_t *count,
|
||||
const std::vector<float> &recalls)
|
||||
{
|
||||
float found = 0;
|
||||
for (uint32_t i = 0; i < npart; ++i)
|
||||
{
|
||||
size_t max_found = std::min(count[i], k);
|
||||
found += recalls[max_found - 1] * max_found;
|
||||
}
|
||||
return found / (float)k_aggr;
|
||||
}
|
||||
|
||||
void simulate(const uint32_t k_aggr, const uint32_t k, const uint32_t npart, const uint32_t nsim,
|
||||
const std::vector<float> &recalls)
|
||||
{
|
||||
std::random_device r;
|
||||
std::default_random_engine randeng(r());
|
||||
std::uniform_int_distribution<int> uniform_dist(0, npart - 1);
|
||||
|
||||
uint32_t *count = new uint32_t[npart];
|
||||
double aggr_recall = 0;
|
||||
|
||||
for (uint32_t i = 0; i < nsim; ++i)
|
||||
{
|
||||
for (uint32_t p = 0; p < npart; ++p)
|
||||
{
|
||||
count[p] = 0;
|
||||
}
|
||||
for (uint32_t t = 0; t < k_aggr; ++t)
|
||||
{
|
||||
count[uniform_dist(randeng)]++;
|
||||
}
|
||||
aggr_recall += aggregate_recall(k_aggr, k, npart, count, recalls);
|
||||
}
|
||||
|
||||
std::cout << "Aggregate recall is " << aggr_recall / (double)nsim << std::endl;
|
||||
delete[] count;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc < 6)
|
||||
{
|
||||
std::cout << argv[0] << " k_aggregate k_out npart nsim recall@1 recall@2 ... recall@k" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
const uint32_t k_aggr = atoi(argv[1]);
|
||||
const uint32_t k = atoi(argv[2]);
|
||||
const uint32_t npart = atoi(argv[3]);
|
||||
const uint32_t nsim = atoi(argv[4]);
|
||||
|
||||
std::vector<float> recalls;
|
||||
for (int ctr = 5; ctr < argc; ctr++)
|
||||
{
|
||||
recalls.push_back((float)atof(argv[ctr]));
|
||||
}
|
||||
|
||||
if (recalls.size() != k)
|
||||
{
|
||||
std::cerr << "Please input k numbers for recall@1, recall@2 .. recall@k" << std::endl;
|
||||
}
|
||||
if (k_aggr > npart * k)
|
||||
{
|
||||
std::cerr << "k_aggr must be <= k * npart" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
if (nsim <= npart * k_aggr)
|
||||
{
|
||||
std::cerr << "Choose nsim > npart*k_aggr" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
simulate(k_aggr, k, npart, nsim, recalls);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
namespace po = boost::program_options;
|
||||
|
||||
void stats_analysis(const std::string labels_file, std::string univeral_label, uint32_t density = 10)
|
||||
{
|
||||
std::string token, line;
|
||||
std::ifstream labels_stream(labels_file);
|
||||
std::unordered_map<std::string, uint32_t> label_counts;
|
||||
std::string label_with_max_points;
|
||||
uint32_t max_points = 0;
|
||||
long long sum = 0;
|
||||
long long point_cnt = 0;
|
||||
float avg_labels_per_pt, mean_label_size;
|
||||
|
||||
std::vector<uint32_t> labels_per_point;
|
||||
uint32_t dense_pts = 0;
|
||||
if (labels_stream.is_open())
|
||||
{
|
||||
while (getline(labels_stream, line))
|
||||
{
|
||||
point_cnt++;
|
||||
std::stringstream iss(line);
|
||||
uint32_t lbl_cnt = 0;
|
||||
while (getline(iss, token, ','))
|
||||
{
|
||||
lbl_cnt++;
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
if (label_counts.find(token) == label_counts.end())
|
||||
label_counts[token] = 0;
|
||||
label_counts[token]++;
|
||||
}
|
||||
if (lbl_cnt >= density)
|
||||
{
|
||||
dense_pts++;
|
||||
}
|
||||
labels_per_point.emplace_back(lbl_cnt);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "fraction of dense points with >= " << density
|
||||
<< " labels = " << (float)dense_pts / (float)labels_per_point.size() << std::endl;
|
||||
std::sort(labels_per_point.begin(), labels_per_point.end());
|
||||
|
||||
std::vector<std::pair<std::string, uint32_t>> label_count_vec;
|
||||
|
||||
for (auto it = label_counts.begin(); it != label_counts.end(); it++)
|
||||
{
|
||||
auto &lbl = *it;
|
||||
label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second));
|
||||
if (lbl.second > max_points)
|
||||
{
|
||||
max_points = lbl.second;
|
||||
label_with_max_points = lbl.first;
|
||||
}
|
||||
sum += lbl.second;
|
||||
}
|
||||
|
||||
sort(label_count_vec.begin(), label_count_vec.end(),
|
||||
[](const std::pair<std::string, uint32_t> &lhs, const std::pair<std::string, uint32_t> &rhs) {
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
for (float p = 0; p < 1; p += 0.05)
|
||||
{
|
||||
std::cout << "Percentile " << (100 * p) << "\t" << label_count_vec[(size_t)(p * label_count_vec.size())].first
|
||||
<< " with count=" << label_count_vec[(size_t)(p * label_count_vec.size())].second << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Most common label "
|
||||
<< "\t" << label_count_vec[label_count_vec.size() - 1].first
|
||||
<< " with count=" << label_count_vec[label_count_vec.size() - 1].second << std::endl;
|
||||
if (label_count_vec.size() > 1)
|
||||
std::cout << "Second common label "
|
||||
<< "\t" << label_count_vec[label_count_vec.size() - 2].first
|
||||
<< " with count=" << label_count_vec[label_count_vec.size() - 2].second << std::endl;
|
||||
if (label_count_vec.size() > 2)
|
||||
std::cout << "Third common label "
|
||||
<< "\t" << label_count_vec[label_count_vec.size() - 3].first
|
||||
<< " with count=" << label_count_vec[label_count_vec.size() - 3].second << std::endl;
|
||||
avg_labels_per_pt = sum / (float)point_cnt;
|
||||
mean_label_size = sum / (float)label_counts.size();
|
||||
std::cout << "Total number of points = " << point_cnt << ", number of labels = " << label_counts.size()
|
||||
<< std::endl;
|
||||
std::cout << "Average number of labels per point = " << avg_labels_per_pt << std::endl;
|
||||
std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl;
|
||||
std::cout << "Most popular label is " << label_with_max_points << " with " << max_points << " pts" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
std::string labels_file, universal_label;
|
||||
uint32_t density;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try
|
||||
{
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("labels_file", po::value<std::string>(&labels_file)->required(),
|
||||
"path to labels data file.");
|
||||
desc.add_options()("universal_label", po::value<std::string>(&universal_label)->required(),
|
||||
"Universal label used in labels file.");
|
||||
desc.add_options()("density", po::value<uint32_t>(&density)->default_value(1),
|
||||
"Number of labels each point in labels file, defaults to 1");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help"))
|
||||
{
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception &e)
|
||||
{
|
||||
std::cerr << e.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
stats_analysis(labels_file, universal_label, density);
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
void block_convert_float(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
|
||||
{
|
||||
auto read_buf = new float[npts * (ndims + 1)];
|
||||
|
||||
auto cursor = read_buf;
|
||||
float val;
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
{
|
||||
reader >> val;
|
||||
*cursor = val;
|
||||
cursor++;
|
||||
}
|
||||
}
|
||||
writer.write((char *)read_buf, npts * ndims * sizeof(float));
|
||||
delete[] read_buf;
|
||||
}
|
||||
|
||||
void block_convert_int8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
|
||||
{
|
||||
auto read_buf = new int8_t[npts * (ndims + 1)];
|
||||
|
||||
auto cursor = read_buf;
|
||||
int val;
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
{
|
||||
reader >> val;
|
||||
*cursor = (int8_t)val;
|
||||
cursor++;
|
||||
}
|
||||
}
|
||||
writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t));
|
||||
delete[] read_buf;
|
||||
}
|
||||
|
||||
void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t npts, size_t ndims)
|
||||
{
|
||||
auto read_buf = new uint8_t[npts * (ndims + 1)];
|
||||
|
||||
auto cursor = read_buf;
|
||||
int val;
|
||||
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; ++d)
|
||||
{
|
||||
reader >> val;
|
||||
*cursor = (uint8_t)val;
|
||||
cursor++;
|
||||
}
|
||||
}
|
||||
writer.write((char *)read_buf, npts * ndims * sizeof(uint8_t));
|
||||
delete[] read_buf;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 6)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< "<float/int8/uint8> input_filename.tsv output_filename.bin "
|
||||
"dim num_pts>"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) != std::string("float") && std::string(argv[1]) != std::string("int8") &&
|
||||
std::string(argv[1]) != std::string("uint8"))
|
||||
{
|
||||
std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl;
|
||||
}
|
||||
|
||||
size_t ndims = atoi(argv[4]);
|
||||
size_t npts = atoi(argv[5]);
|
||||
|
||||
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
|
||||
// size_t fsize = reader.tellg();
|
||||
reader.seekg(0, std::ios::beg);
|
||||
reader.seekg(0, std::ios::beg);
|
||||
|
||||
size_t blk_size = 131072;
|
||||
size_t nblks = ROUND_UP(npts, blk_size) / blk_size;
|
||||
std::cout << "# blks: " << nblks << std::endl;
|
||||
std::ofstream writer(argv[3], std::ios::binary);
|
||||
auto npts_u32 = (uint32_t)npts;
|
||||
auto ndims_u32 = (uint32_t)ndims;
|
||||
writer.write((char *)&npts_u32, sizeof(uint32_t));
|
||||
writer.write((char *)&ndims_u32, sizeof(uint32_t));
|
||||
|
||||
for (size_t i = 0; i < nblks; i++)
|
||||
{
|
||||
size_t cblk_size = std::min(npts - i * blk_size, blk_size);
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
{
|
||||
block_convert_float(reader, writer, cblk_size, ndims);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
{
|
||||
block_convert_int8(reader, writer, cblk_size, ndims);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
{
|
||||
block_convert_uint8(reader, writer, cblk_size, ndims);
|
||||
}
|
||||
std::cout << "Block #" << i << " written" << std::endl;
|
||||
}
|
||||
|
||||
reader.close();
|
||||
writer.close();
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_uint32_bin output_int8_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
uint32_t *input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<uint32_t>(argv[1], input, npts, nd);
|
||||
uint8_t *output = new uint8_t[npts * nd];
|
||||
diskann::convert_types<uint32_t, uint8_t>(input, output, npts, nd);
|
||||
diskann::save_bin<uint8_t>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include "utils.h"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc != 3)
|
||||
{
|
||||
std::cout << argv[0] << " input_uint8_bin output_float_bin" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
uint8_t *input;
|
||||
size_t npts, nd;
|
||||
diskann::load_bin<uint8_t>(argv[1], input, npts, nd);
|
||||
float *output = new float[npts * nd];
|
||||
diskann::convert_types<uint8_t, float>(input, output, npts, nd);
|
||||
diskann::save_bin<float>(argv[2], output, npts, nd);
|
||||
delete[] output;
|
||||
delete[] input;
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <omp.h>
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "partition.h"
|
||||
#include "utils.h"
|
||||
|
||||
template <typename T> int analyze_norm(std::string base_file)
|
||||
{
|
||||
std::cout << "Analyzing data norms" << std::endl;
|
||||
T *data;
|
||||
size_t npts, ndims;
|
||||
diskann::load_bin<T>(base_file, data, npts, ndims);
|
||||
std::vector<float> norms(npts, 0);
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
norms[i] += data[i * ndims + d] * data[i * ndims + d];
|
||||
norms[i] = std::sqrt(norms[i]);
|
||||
}
|
||||
std::sort(norms.begin(), norms.end());
|
||||
for (int p = 0; p < 100; p += 5)
|
||||
std::cout << "percentile " << p << ": " << norms[(uint64_t)(std::floor((p / 100.0) * npts))] << std::endl;
|
||||
std::cout << "percentile 100"
|
||||
<< ": " << norms[npts - 1] << std::endl;
|
||||
delete[] data;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T> int normalize_base(std::string base_file, std::string out_file)
|
||||
{
|
||||
std::cout << "Normalizing base" << std::endl;
|
||||
T *data;
|
||||
size_t npts, ndims;
|
||||
diskann::load_bin<T>(base_file, data, npts, ndims);
|
||||
// std::vector<float> norms(npts, 0);
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
float pt_norm = 0;
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
pt_norm += data[i * ndims + d] * data[i * ndims + d];
|
||||
pt_norm = std::sqrt(pt_norm);
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
data[i * ndims + d] = static_cast<T>(data[i * ndims + d] / pt_norm);
|
||||
}
|
||||
diskann::save_bin<T>(out_file, data, npts, ndims);
|
||||
delete[] data;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T> int augment_base(std::string base_file, std::string out_file, bool prep_base = true)
|
||||
{
|
||||
std::cout << "Analyzing data norms" << std::endl;
|
||||
T *data;
|
||||
size_t npts, ndims;
|
||||
diskann::load_bin<T>(base_file, data, npts, ndims);
|
||||
std::vector<float> norms(npts, 0);
|
||||
float max_norm = 0;
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int64_t i = 0; i < (int64_t)npts; i++)
|
||||
{
|
||||
for (size_t d = 0; d < ndims; d++)
|
||||
norms[i] += data[i * ndims + d] * data[i * ndims + d];
|
||||
max_norm = norms[i] > max_norm ? norms[i] : max_norm;
|
||||
}
|
||||
// std::sort(norms.begin(), norms.end());
|
||||
max_norm = std::sqrt(max_norm);
|
||||
std::cout << "Max norm: " << max_norm << std::endl;
|
||||
T *new_data;
|
||||
size_t newdims = ndims + 1;
|
||||
new_data = new T[npts * newdims];
|
||||
for (size_t i = 0; i < npts; i++)
|
||||
{
|
||||
if (prep_base)
|
||||
{
|
||||
for (size_t j = 0; j < ndims; j++)
|
||||
{
|
||||
new_data[i * newdims + j] = static_cast<T>(data[i * ndims + j] / max_norm);
|
||||
}
|
||||
float diff = 1 - (norms[i] / (max_norm * max_norm));
|
||||
diff = diff <= 0 ? 0 : std::sqrt(diff);
|
||||
new_data[i * newdims + ndims] = static_cast<T>(diff);
|
||||
if (diff <= 0)
|
||||
{
|
||||
std::cout << i << " has large max norm, investigate if needed. diff = " << diff << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t j = 0; j < ndims; j++)
|
||||
{
|
||||
new_data[i * newdims + j] = static_cast<T>(data[i * ndims + j] / std::sqrt(norms[i]));
|
||||
}
|
||||
new_data[i * newdims + ndims] = 0;
|
||||
}
|
||||
}
|
||||
diskann::save_bin<T>(out_file, new_data, npts, newdims);
|
||||
delete[] new_data;
|
||||
delete[] data;
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T> int aux_main(char **argv)
|
||||
{
|
||||
std::string base_file(argv[2]);
|
||||
uint32_t option = atoi(argv[3]);
|
||||
if (option == 1)
|
||||
analyze_norm<T>(base_file);
|
||||
else if (option == 2)
|
||||
augment_base<T>(base_file, std::string(argv[4]), true);
|
||||
else if (option == 3)
|
||||
augment_base<T>(base_file, std::string(argv[4]), false);
|
||||
else if (option == 4)
|
||||
normalize_base<T>(base_file, std::string(argv[4]));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
if (argc < 4)
|
||||
{
|
||||
std::cout << argv[0]
|
||||
<< " data_type [float/int8/uint8] base_bin_file "
|
||||
"[option: 1-norm analysis, 2-prep_base_for_mip, "
|
||||
"3-prep_query_for_mip, 4-normalize-vecs] [out_file for "
|
||||
"options 2/3/4]"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (std::string(argv[1]) == std::string("float"))
|
||||
{
|
||||
aux_main<float>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("int8"))
|
||||
{
|
||||
aux_main<int8_t>(argv);
|
||||
}
|
||||
else if (std::string(argv[1]) == std::string("uint8"))
|
||||
{
|
||||
aux_main<uint8_t>(argv);
|
||||
}
|
||||
else
|
||||
std::cout << "Unsupported type. Use float/int8/uint8." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
if (NOT MSVC)
|
||||
message(STATUS "Setting up `make format` and `make checkformat`")
|
||||
# additional target to perform clang-format run, requires clang-format
|
||||
# get all project files
|
||||
file(GLOB_RECURSE ALL_SOURCE_FILES include/*.h include/*.hpp python/src/*.cpp src/*.cpp src/*.hpp apps/*.cpp apps/*.hpp)
|
||||
|
||||
message(status ${ALL_SOURCE_FILES})
|
||||
|
||||
add_custom_target(
|
||||
format
|
||||
COMMAND /usr/bin/clang-format
|
||||
-i
|
||||
${ALL_SOURCE_FILES}
|
||||
)
|
||||
add_custom_target(
|
||||
checkformat
|
||||
COMMAND /usr/bin/clang-format
|
||||
--Werror
|
||||
--dry-run
|
||||
${ALL_SOURCE_FILES}
|
||||
)
|
||||
endif()
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "types.h"
|
||||
#include "windows_customizations.h"
|
||||
#include "distance.h"
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
|
||||
template <typename data_t> class AbstractScratch;
|
||||
|
||||
template <typename data_t> class AbstractDataStore
|
||||
{
|
||||
public:
|
||||
AbstractDataStore(const location_t capacity, const size_t dim);
|
||||
|
||||
virtual ~AbstractDataStore() = default;
|
||||
|
||||
// Return number of points returned
|
||||
virtual location_t load(const std::string &filename) = 0;
|
||||
|
||||
// Why does store take num_pts? Since store only has capacity, but we allow
|
||||
// resizing we can end up in a situation where the store has spare capacity.
|
||||
// To optimize disk utilization, we pass the number of points that are "true"
|
||||
// points, so that the store can discard the empty locations before saving.
|
||||
virtual size_t save(const std::string &filename, const location_t num_pts) = 0;
|
||||
|
||||
DISKANN_DLLEXPORT virtual location_t capacity() const;
|
||||
|
||||
DISKANN_DLLEXPORT virtual size_t get_dims() const;
|
||||
|
||||
// Implementers can choose to return _dim if they are not
|
||||
// concerned about memory alignment.
|
||||
// Some distance metrics (like l2) need data vectors to be aligned, so we
|
||||
// align the dimension by padding zeros.
|
||||
virtual size_t get_aligned_dim() const = 0;
|
||||
|
||||
// populate the store with vectors (either from a pointer or bin file),
|
||||
// potentially after pre-processing the vectors if the metric deems so
|
||||
// e.g., normalizing vectors for cosine distance over floating-point vectors
|
||||
// useful for bulk or static index building.
|
||||
virtual void populate_data(const data_t *vectors, const location_t num_pts) = 0;
|
||||
virtual void populate_data(const std::string &filename, const size_t offset) = 0;
|
||||
|
||||
// save the first num_pts many vectors back to bin file
|
||||
// note: cannot undo the pre-processing done in populate data
|
||||
virtual void extract_data_to_bin(const std::string &filename, const location_t num_pts) = 0;
|
||||
|
||||
// Returns the updated capacity of the datastore. Clients should check
|
||||
// if resize actually changed the capacity to new_num_points before
|
||||
// proceeding with operations. See the code below:
|
||||
// auto new_capcity = data_store->resize(new_num_points);
|
||||
// if ( new_capacity >= new_num_points) {
|
||||
// //PROCEED
|
||||
// else
|
||||
// //ERROR.
|
||||
virtual location_t resize(const location_t new_num_points);
|
||||
|
||||
// operations on vectors
|
||||
// like populate_data function, but over one vector at a time useful for
|
||||
// streaming setting
|
||||
virtual void get_vector(const location_t i, data_t *dest) const = 0;
|
||||
virtual void set_vector(const location_t i, const data_t *const vector) = 0;
|
||||
virtual void prefetch_vector(const location_t loc) = 0;
|
||||
|
||||
// internal shuffle operations to move around vectors
|
||||
// will bulk-move all the vectors in [old_start_loc, old_start_loc +
|
||||
// num_points) to [new_start_loc, new_start_loc + num_points) and set the old
|
||||
// positions to zero vectors.
|
||||
virtual void move_vectors(const location_t old_start_loc, const location_t new_start_loc,
|
||||
const location_t num_points) = 0;
|
||||
|
||||
// same as above, without resetting the vectors in [from_loc, from_loc +
|
||||
// num_points) to zero
|
||||
virtual void copy_vectors(const location_t from_loc, const location_t to_loc, const location_t num_points) = 0;
|
||||
|
||||
// With the PQ Data Store PR, we have also changed iterate_to_fixed_point to NOT take the query
|
||||
// from the scratch object. Therefore every data store has to implement preprocess_query which
|
||||
// at the least will be to copy the query into the scratch object. So making this pure virtual.
|
||||
virtual void preprocess_query(const data_t *aligned_query,
|
||||
AbstractScratch<data_t> *query_scratch = nullptr) const = 0;
|
||||
// distance functions.
|
||||
virtual float get_distance(const data_t *query, const location_t loc) const = 0;
|
||||
virtual void get_distance(const data_t *query, const location_t *locations, const uint32_t location_count,
|
||||
float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
|
||||
// Specific overload for index.cpp.
|
||||
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
|
||||
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const = 0;
|
||||
virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;
|
||||
|
||||
// stats of the data stored in store
|
||||
// Returns the point in the dataset that is closest to the mean of all points
|
||||
// in the dataset
|
||||
virtual location_t calculate_medoid() const = 0;
|
||||
|
||||
// REFACTOR PQ TODO: Each data store knows about its distance function, so this is
|
||||
// redundant. However, we don't have an OptmizedDataStore yet, and to preserve code
|
||||
// compability, we are exposing this function.
|
||||
virtual Distance<data_t> *get_dist_fn() const = 0;
|
||||
|
||||
// search helpers
|
||||
// if the base data is aligned per the request of the metric, this will tell
|
||||
// how to align the query vector in a consistent manner
|
||||
virtual size_t get_alignment_factor() const = 0;
|
||||
|
||||
protected:
|
||||
// Expand the datastore to new_num_points. Returns the new capacity created,
|
||||
// which should be == new_num_points in the normal case. Implementers can also
|
||||
// return _capacity to indicate that there are not implementing this method.
|
||||
virtual location_t expand(const location_t new_num_points) = 0;
|
||||
|
||||
// Shrink the datastore to new_num_points. It is NOT an error if shrink
|
||||
// doesn't reduce the capacity so callers need to check this correctly. See
|
||||
// also for "default" implementation
|
||||
virtual location_t shrink(const location_t new_num_points) = 0;
|
||||
|
||||
location_t _capacity;
|
||||
size_t _dim;
|
||||
};
|
||||
|
||||
} // namespace diskann
|
||||
@@ -1,68 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "types.h"
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
|
||||
class AbstractGraphStore
|
||||
{
|
||||
public:
|
||||
AbstractGraphStore(const size_t total_pts, const size_t reserve_graph_degree)
|
||||
: _capacity(total_pts), _reserve_graph_degree(reserve_graph_degree)
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~AbstractGraphStore() = default;
|
||||
|
||||
// returns tuple of <nodes_read, start, num_frozen_points>
|
||||
virtual std::tuple<uint32_t, uint32_t, size_t> load(const std::string &index_path_prefix,
|
||||
const size_t num_points) = 0;
|
||||
virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_fz_points,
|
||||
const uint32_t start) = 0;
|
||||
|
||||
// not synchronised, user should use lock when necvessary.
|
||||
virtual const std::vector<location_t> &get_neighbours(const location_t i) const = 0;
|
||||
virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0;
|
||||
virtual void clear_neighbours(const location_t i) = 0;
|
||||
virtual void swap_neighbours(const location_t a, location_t b) = 0;
|
||||
|
||||
virtual void set_neighbours(const location_t i, std::vector<location_t> &neighbours) = 0;
|
||||
|
||||
virtual size_t resize_graph(const size_t new_size) = 0;
|
||||
virtual void clear_graph() = 0;
|
||||
|
||||
virtual uint32_t get_max_observed_degree() = 0;
|
||||
|
||||
// set during load
|
||||
virtual size_t get_max_range_of_graph() = 0;
|
||||
|
||||
// Total internal points _max_points + _num_frozen_points
|
||||
size_t get_total_points()
|
||||
{
|
||||
return _capacity;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Internal function, changes total points when resize_graph is called.
|
||||
void set_total_points(size_t new_capacity)
|
||||
{
|
||||
_capacity = new_capacity;
|
||||
}
|
||||
|
||||
size_t get_reserve_graph_degree()
|
||||
{
|
||||
return _reserve_graph_degree;
|
||||
}
|
||||
|
||||
private:
|
||||
size_t _capacity;
|
||||
size_t _reserve_graph_degree;
|
||||
};
|
||||
|
||||
} // namespace diskann
|
||||
@@ -1,129 +0,0 @@
|
||||
#pragma once
|
||||
#include "distance.h"
|
||||
#include "parameters.h"
|
||||
#include "utils.h"
|
||||
#include "types.h"
|
||||
#include "index_config.h"
|
||||
#include "index_build_params.h"
|
||||
#include <any>
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
struct consolidation_report
|
||||
{
|
||||
enum status_code
|
||||
{
|
||||
SUCCESS = 0,
|
||||
FAIL = 1,
|
||||
LOCK_FAIL = 2,
|
||||
INCONSISTENT_COUNT_ERROR = 3
|
||||
};
|
||||
status_code _status;
|
||||
size_t _active_points, _max_points, _empty_slots, _slots_released, _delete_set_size, _num_calls_to_process_delete;
|
||||
double _time;
|
||||
|
||||
consolidation_report(status_code status, size_t active_points, size_t max_points, size_t empty_slots,
|
||||
size_t slots_released, size_t delete_set_size, size_t num_calls_to_process_delete,
|
||||
double time_secs)
|
||||
: _status(status), _active_points(active_points), _max_points(max_points), _empty_slots(empty_slots),
|
||||
_slots_released(slots_released), _delete_set_size(delete_set_size),
|
||||
_num_calls_to_process_delete(num_calls_to_process_delete), _time(time_secs)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/* A templated independent class for intercation with Index. Uses Type Erasure to add virtual implemetation of methods
|
||||
that can take any type(using std::any) and Provides a clean API that can be inherited by different type of Index.
|
||||
*/
|
||||
class AbstractIndex
|
||||
{
|
||||
public:
|
||||
AbstractIndex() = default;
|
||||
virtual ~AbstractIndex() = default;
|
||||
|
||||
virtual void build(const std::string &data_file, const size_t num_points_to_load,
|
||||
IndexFilterParams &build_params) = 0;
|
||||
|
||||
template <typename data_type, typename tag_type>
|
||||
void build(const data_type *data, const size_t num_points_to_load, const std::vector<tag_type> &tags);
|
||||
|
||||
virtual void save(const char *filename, bool compact_before_save = false) = 0;
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
|
||||
#else
|
||||
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l) = 0;
|
||||
#endif
|
||||
|
||||
// For FastL2 search on optimized layout
|
||||
template <typename data_type>
|
||||
void search_with_optimized_layout(const data_type *query, size_t K, size_t L, uint32_t *indices);
|
||||
|
||||
// Initialize space for res_vectors before calling.
|
||||
template <typename data_type, typename tag_type>
|
||||
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
|
||||
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
|
||||
const std::string filter_label = "");
|
||||
|
||||
// Added search overload that takes L as parameter, so that we
|
||||
// can customize L on a per-query basis without tampering with "Parameters"
|
||||
// IDtype is either uint32_t or uint64_t
|
||||
template <typename data_type, typename IDType>
|
||||
std::pair<uint32_t, uint32_t> search(const data_type *query, const size_t K, const uint32_t L, IDType *indices,
|
||||
float *distances = nullptr);
|
||||
|
||||
// Filter support search
|
||||
// IndexType is either uint32_t or uint64_t
|
||||
template <typename IndexType>
|
||||
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
|
||||
const size_t K, const uint32_t L, IndexType *indices,
|
||||
float *distances);
|
||||
|
||||
// insert points with labels, labels should be present for filtered index
|
||||
template <typename data_type, typename tag_type, typename label_type>
|
||||
int insert_point(const data_type *point, const tag_type tag, const std::vector<label_type> &labels);
|
||||
|
||||
// insert point for unfiltered index build. do not use with filtered index
|
||||
template <typename data_type, typename tag_type> int insert_point(const data_type *point, const tag_type tag);
|
||||
|
||||
// delete point with tag, or return -1 if point can not be deleted
|
||||
template <typename tag_type> int lazy_delete(const tag_type &tag);
|
||||
|
||||
// batch delete tags and populates failed tags if unabke to delete given tags.
|
||||
template <typename tag_type>
|
||||
void lazy_delete(const std::vector<tag_type> &tags, std::vector<tag_type> &failed_tags);
|
||||
|
||||
template <typename tag_type> void get_active_tags(tsl::robin_set<tag_type> &active_tags);
|
||||
|
||||
template <typename data_type> void set_start_points_at_random(data_type radius, uint32_t random_seed = 0);
|
||||
|
||||
virtual consolidation_report consolidate_deletes(const IndexWriteParameters ¶meters) = 0;
|
||||
|
||||
virtual void optimize_index_layout() = 0;
|
||||
|
||||
// memory should be allocated for vec before calling this function
|
||||
template <typename tag_type, typename data_type> int get_vector_by_tag(tag_type &tag, data_type *vec);
|
||||
|
||||
template <typename label_type> void set_universal_label(const label_type universal_label);
|
||||
|
||||
private:
|
||||
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
|
||||
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
|
||||
std::any &indices, float *distances = nullptr) = 0;
|
||||
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
|
||||
const size_t K, const uint32_t L, std::any &indices,
|
||||
float *distances) = 0;
|
||||
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
|
||||
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
|
||||
virtual int _lazy_delete(const TagType &tag) = 0;
|
||||
virtual void _lazy_delete(TagVector &tags, TagVector &failed_tags) = 0;
|
||||
virtual void _get_active_tags(TagRobinSet &active_tags) = 0;
|
||||
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
|
||||
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
|
||||
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
|
||||
float *distances, DataVector &res_vectors, bool use_filters = false,
|
||||
const std::string filter_label = "") = 0;
|
||||
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
|
||||
virtual void _set_universal_label(const LabelType universal_label) = 0;
|
||||
};
|
||||
} // namespace diskann
|
||||
@@ -1,35 +0,0 @@
|
||||
#pragma once
|
||||
namespace diskann
|
||||
{
|
||||
|
||||
template <typename data_t> class PQScratch;
|
||||
|
||||
// By somewhat more than a coincidence, it seems that both InMemQueryScratch
|
||||
// and SSDQueryScratch have the aligned query and PQScratch objects. So we
|
||||
// can put them in a neat hierarchy and keep PQScratch as a standalone class.
|
||||
template <typename data_t> class AbstractScratch
|
||||
{
|
||||
public:
|
||||
AbstractScratch() = default;
|
||||
// This class does not take any responsibilty for memory management of
|
||||
// its members. It is the responsibility of the derived classes to do so.
|
||||
virtual ~AbstractScratch() = default;
|
||||
|
||||
// Scratch objects should not be copied
|
||||
AbstractScratch(const AbstractScratch &) = delete;
|
||||
AbstractScratch &operator=(const AbstractScratch &) = delete;
|
||||
|
||||
data_t *aligned_query_T()
|
||||
{
|
||||
return _aligned_query_T;
|
||||
}
|
||||
PQScratch<data_t> *pq_scratch()
|
||||
{
|
||||
return _pq_scratch;
|
||||
}
|
||||
|
||||
protected:
|
||||
data_t *_aligned_query_T = nullptr;
|
||||
PQScratch<data_t> *_pq_scratch = nullptr;
|
||||
};
|
||||
} // namespace diskann
|
||||
@@ -1,138 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
|
||||
#define MAX_IO_DEPTH 128
|
||||
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <fcntl.h>
|
||||
#include <libaio.h>
|
||||
#include <unistd.h>
|
||||
#include <malloc.h>
|
||||
typedef io_context_t IOContext;
|
||||
#elif __APPLE__
|
||||
#include <dispatch/dispatch.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
|
||||
struct IOContext
|
||||
{
|
||||
int fd;
|
||||
dispatch_io_t channel;
|
||||
dispatch_queue_t queue;
|
||||
dispatch_group_t grp;
|
||||
};
|
||||
#elif _WINDOWS
|
||||
#include <Windows.h>
|
||||
#include <minwinbase.h>
|
||||
#include <malloc.h>
|
||||
|
||||
#ifndef USE_BING_INFRA
|
||||
struct IOContext
|
||||
{
|
||||
HANDLE fhandle = NULL;
|
||||
HANDLE iocp = NULL;
|
||||
std::vector<OVERLAPPED> reqs;
|
||||
};
|
||||
#else
|
||||
#include "IDiskPriorityIO.h"
|
||||
#include <atomic>
|
||||
// TODO: Caller code is very callous about copying IOContext objects
|
||||
// all over the place. MUST verify that it won't cause leaks/logical
|
||||
// errors.
|
||||
// Because of such callous copying, we have to use ptr->atomic instead
|
||||
// of atomic, as atomic is not copyable.
|
||||
struct IOContext
|
||||
{
|
||||
enum Status
|
||||
{
|
||||
READ_WAIT = 0,
|
||||
READ_SUCCESS,
|
||||
READ_FAILED,
|
||||
PROCESS_COMPLETE
|
||||
};
|
||||
|
||||
std::shared_ptr<ANNIndex::IDiskPriorityIO> m_pDiskIO = nullptr;
|
||||
std::shared_ptr<std::vector<ANNIndex::AsyncReadRequest>> m_pRequests;
|
||||
std::shared_ptr<std::vector<Status>> m_pRequestsStatus;
|
||||
|
||||
// waitonaddress on this memory to wait for IO completion signal
|
||||
// reader should signal this memory after IO completion
|
||||
// TODO: WindowsAlignedFileReader can be modified to take advantage of this
|
||||
// and can largely share code with the file reader for Bing.
|
||||
mutable volatile long m_completeCount = 0;
|
||||
|
||||
IOContext()
|
||||
: m_pRequestsStatus(new std::vector<Status>()), m_pRequests(new std::vector<ANNIndex::AsyncReadRequest>())
|
||||
{
|
||||
(*m_pRequestsStatus).reserve(MAX_IO_DEPTH);
|
||||
(*m_pRequests).reserve(MAX_IO_DEPTH);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#include <cstdio>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include "tsl/robin_map.h"
|
||||
#include "utils.h"
|
||||
|
||||
// NOTE :: all 3 fields must be 512-aligned
|
||||
struct AlignedRead
|
||||
{
|
||||
uint64_t offset; // where to read from
|
||||
uint64_t len; // how much to read
|
||||
void *buf; // where to read into
|
||||
|
||||
AlignedRead() : offset(0), len(0), buf(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
AlignedRead(uint64_t offset, uint64_t len, void *buf) : offset(offset), len(len), buf(buf)
|
||||
{
|
||||
assert(IS_512_ALIGNED(offset));
|
||||
assert(IS_512_ALIGNED(len));
|
||||
assert(IS_512_ALIGNED(buf));
|
||||
// assert(malloc_usable_size(buf) >= len);
|
||||
}
|
||||
};
|
||||
|
||||
class AlignedFileReader
|
||||
{
|
||||
protected:
|
||||
tsl::robin_map<std::thread::id, IOContext> ctx_map;
|
||||
std::mutex ctx_mut;
|
||||
|
||||
public:
|
||||
// returns the thread-specific context
|
||||
// returns (io_context_t)(-1) if thread is not registered
|
||||
virtual IOContext &get_ctx() = 0;
|
||||
|
||||
virtual ~AlignedFileReader(){};
|
||||
|
||||
// register thread-id for a context
|
||||
virtual void register_thread() = 0;
|
||||
// de-register thread-id for a context
|
||||
virtual void deregister_thread() = 0;
|
||||
virtual void deregister_all_threads() = 0;
|
||||
|
||||
// Open & close ops
|
||||
// Blocking calls
|
||||
virtual void open(const std::string &fname) = 0;
|
||||
virtual void close() = 0;
|
||||
|
||||
// process batch of aligned requests in parallel
|
||||
// NOTE :: blocking call
|
||||
virtual void read(std::vector<AlignedRead> &read_reqs, IOContext &ctx, bool async = false) = 0;
|
||||
|
||||
#ifdef USE_BING_INFRA
|
||||
// wait for completion of one request in a batch of requests
|
||||
virtual void wait(IOContext &ctx, int &completedIndex) = 0;
|
||||
#endif
|
||||
};
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <system_error>
|
||||
#include "windows_customizations.h"
|
||||
#include <cstdint>
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#define __FUNCSIG__ __PRETTY_FUNCTION__
|
||||
#endif
|
||||
|
||||
namespace diskann
|
||||
{
|
||||
|
||||
class ANNException : public std::runtime_error
|
||||
{
|
||||
public:
|
||||
DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode);
|
||||
DISKANN_DLLEXPORT ANNException(const std::string &message, int errorCode, const std::string &funcSig,
|
||||
const std::string &fileName, uint32_t lineNum);
|
||||
|
||||
private:
|
||||
int _errorCode;
|
||||
};
|
||||
|
||||
class FileException : public ANNException
|
||||
{
|
||||
public:
|
||||
DISKANN_DLLEXPORT FileException(const std::string &filename, std::system_error &e, const std::string &funcSig,
|
||||
const std::string &fileName, uint32_t lineNum);
|
||||
};
|
||||
} // namespace diskann
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user