Initial commit
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/Cargo.toml (vendored, new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "diskann"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bincode = "1.3.3"
bit-vec = "0.6.3"
byteorder = "1.4.3"
cblas = "0.4.0"
crossbeam = "0.8.2"
half = "2.2.1"
hashbrown = "0.13.2"
num-traits = "0.2.15"
once_cell = "1.17.1"
openblas-src = { version = "0.10.8", features = ["system"] }
rand = { version = "0.8.5", features = ["small_rng"] }
rayon = "1.7.0"
serde = { version = "1.0.130", features = ["derive"] }
thiserror = "1.0.40"
winapi = { version = "0.3.9", features = ["errhandlingapi", "fileapi", "ioapiset", "handleapi", "winnt", "minwindef", "basetsd", "winerror", "winbase"] }

logger = { path = "../logger" }
platform = { path = "../platform" }
vector = { path = "../vector" }

[build-dependencies]
cc = "1.0.79"

[dev-dependencies]
approx = "0.5.1"
criterion = "0.5.1"

[[bench]]
name = "distance_bench"
harness = false

[[bench]]
name = "neighbor_bench"
harness = false
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/benches/distance_bench.rs (vendored, new file, 47 lines)
@@ -0,0 +1,47 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use criterion::{black_box, criterion_group, criterion_main, Criterion};

use rand::{thread_rng, Rng};
use vector::{FullPrecisionDistance, Metric};

// make sure the vector is 32-byte (256-bit) aligned, as required by _mm256_load_ps
#[repr(C, align(32))]
struct Vector32ByteAligned {
    v: [f32; 256],
}

fn benchmark_l2_distance_float_rust(c: &mut Criterion) {
    let (a, b) = prepare_random_aligned_vectors();
    let mut group = c.benchmark_group("avx-computation");
    group.sample_size(5000);

    group.bench_function("AVX Rust run", |f| {
        f.iter(|| {
            black_box(<[f32; 256]>::distance_compare(
                black_box(&a.v),
                black_box(&b.v),
                Metric::L2,
            ))
        })
    });
}

// make sure the vectors are 32-byte (256-bit) aligned, as required by _mm256_load_ps
fn prepare_random_aligned_vectors() -> (Box<Vector32ByteAligned>, Box<Vector32ByteAligned>) {
    let a = Box::new(Vector32ByteAligned {
        v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)),
    });

    let b = Box::new(Vector32ByteAligned {
        v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)),
    });

    (a, b)
}

criterion_group!(benches, benchmark_l2_distance_float_rust,);
criterion_main!(benches);
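With `harness = false` in the manifest above, Criterion supplies its own `main`, so this file runs via `cargo bench --bench distance_bench`. For orientation, a scalar baseline for the computation being benchmarked might look like the sketch below; `l2_scalar` is a hypothetical helper, and `distance_compare` itself comes from the vendored `vector` crate.

// Hypothetical scalar reference for the AVX path benchmarked above.
fn l2_scalar(a: &[f32; 256], b: &[f32; 256]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y) * (x - y))
        .sum()
}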
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/benches/kmeans_bench.rs (vendored, new file, 70 lines)
@@ -0,0 +1,70 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use criterion::{criterion_group, criterion_main, Criterion};
use diskann::utils::k_means_clustering;
use rand::Rng;

const NUM_POINTS: usize = 10000;
const DIM: usize = 100;
const NUM_CENTERS: usize = 256;
const MAX_KMEANS_REPS: usize = 12;

fn benchmark_kmeans_rust(c: &mut Criterion) {
    let mut rng = rand::thread_rng();
    let data: Vec<f32> = (0..NUM_POINTS * DIM)
        .map(|_| rng.gen_range(-1.0..1.0))
        .collect();
    let centers: Vec<f32> = vec![0.0; NUM_CENTERS * DIM];

    let mut group = c.benchmark_group("kmeans-computation");
    group.sample_size(500);

    group.bench_function("K-Means Rust run", |f| {
        f.iter(|| {
            let data_copy = data.clone();
            let mut centers_copy = centers.clone();
            k_means_clustering(
                &data_copy,
                NUM_POINTS,
                DIM,
                &mut centers_copy,
                NUM_CENTERS,
                MAX_KMEANS_REPS,
            )
        })
    });
}

fn benchmark_kmeans_c(c: &mut Criterion) {
    let mut rng = rand::thread_rng();
    let data: Vec<f32> = (0..NUM_POINTS * DIM)
        .map(|_| rng.gen_range(-1.0..1.0))
        .collect();
    let centers: Vec<f32> = vec![0.0; NUM_CENTERS * DIM];

    let mut group = c.benchmark_group("kmeans-computation");
    group.sample_size(500);

    group.bench_function("K-Means C++ Run", |f| {
        f.iter(|| {
            let data_copy = data.clone();
            let mut centers_copy = centers.clone();
            let _ = k_means_clustering(
                data_copy.as_slice(),
                NUM_POINTS,
                DIM,
                centers_copy.as_mut_slice(),
                NUM_CENTERS,
                MAX_KMEANS_REPS,
            );
        })
    });
}

criterion_group!(benches, benchmark_kmeans_rust, benchmark_kmeans_c);

criterion_main!(benches);
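Both benchmarks above clone `data` and `centers` inside `f.iter`, so the clones are counted in the measured time. A minimal sketch of an alternative using Criterion's `iter_batched`, which moves the clones into an unmeasured setup closure (assuming the same `k_means_clustering` signature as above; the function name and numbers are illustrative):

use criterion::{BatchSize, Criterion};
use diskann::utils::k_means_clustering;
use rand::Rng;

fn benchmark_kmeans_batched(c: &mut Criterion) {
    let mut rng = rand::thread_rng();
    let data: Vec<f32> = (0..10_000 * 100).map(|_| rng.gen_range(-1.0..1.0)).collect();
    let centers = vec![0.0_f32; 256 * 100];

    c.bench_function("K-Means Rust run (batched)", |f| {
        f.iter_batched(
            // Setup runs outside the measured region.
            || (data.clone(), centers.clone()),
            |(data_copy, mut centers_copy)| {
                k_means_clustering(&data_copy, 10_000, 100, &mut centers_copy, 256, 12)
            },
            BatchSize::LargeInput,
        )
    });
}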
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/benches/neighbor_bench.rs (vendored, new file, 49 lines)
@@ -0,0 +1,49 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::time::Duration;

use criterion::{black_box, criterion_group, criterion_main, Criterion};

use diskann::model::{Neighbor, NeighborPriorityQueue};
use rand::distributions::{Distribution, Uniform};
use rand::rngs::StdRng;
use rand::SeedableRng;

fn benchmark_priority_queue_insert(c: &mut Criterion) {
    let vec = generate_random_floats();
    let mut group = c.benchmark_group("neighborqueue-insert");
    group.measurement_time(Duration::from_secs(3)).sample_size(500);

    let mut queue = NeighborPriorityQueue::with_capacity(64_usize);
    group.bench_function("Neighbor Priority Queue Insert", |f| {
        f.iter(|| {
            queue.clear();
            for n in vec.iter() {
                queue.insert(*n);
            }

            black_box(&1)
        });
    });
}

fn generate_random_floats() -> Vec<Neighbor> {
    let seed: [u8; 32] = [73; 32];
    let mut rng: StdRng = SeedableRng::from_seed(seed);
    let range = Uniform::new(0.0, 1.0);
    let mut random_floats = Vec::with_capacity(100);

    for i in 0..100 {
        let random_float = range.sample(&mut rng) as f32;
        let n = Neighbor::new(i, random_float);
        random_floats.push(n);
    }

    random_floats
}

criterion_group!(benches, benchmark_priority_queue_insert);
criterion_main!(benches);
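The bench above measures repeated bounded inserts into `NeighborPriorityQueue`. As a toy model of the behavior being exercised (a sketch, not the vendored implementation), a capacity-bounded insert amounts to a binary search plus a shifting `Vec::insert`:

#[derive(Clone, Copy)]
struct Candidate {
    id: u32,
    distance: f32,
}

struct BoundedQueue {
    capacity: usize,
    data: Vec<Candidate>, // kept sorted by ascending distance
}

impl BoundedQueue {
    fn insert(&mut self, c: Candidate) {
        let pos = self.data.partition_point(|m| m.distance < c.distance);
        if pos < self.capacity {
            self.data.insert(pos, c);
            self.data.truncate(self.capacity);
        }
    }
}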
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/algorithm/mod.rs (vendored, new file, 7 lines)
@@ -0,0 +1,7 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub mod search;

pub mod prune;
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/algorithm/prune/mod.rs (vendored, new file, 6 lines)
@@ -0,0 +1,6 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
pub mod prune;
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/algorithm/prune/prune.rs (vendored, new file, 288 lines)
@@ -0,0 +1,288 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use hashbrown::HashSet;
use vector::{FullPrecisionDistance, Metric};

use crate::common::{ANNError, ANNResult};
use crate::index::InmemIndex;
use crate::model::graph::AdjacencyList;
use crate::model::neighbor::SortedNeighborVector;
use crate::model::scratch::InMemQueryScratch;
use crate::model::Neighbor;

impl<T, const N: usize> InmemIndex<T, N>
where
    T: Default + Copy + Sync + Send + Into<f32>,
    [T; N]: FullPrecisionDistance<T, N>,
{
    /// A method that occludes a list of neighbors based on some criteria
    #[allow(clippy::too_many_arguments)]
    fn occlude_list(
        &self,
        location: u32,
        pool: &mut SortedNeighborVector,
        alpha: f32,
        degree: u32,
        max_candidate_size: usize,
        result: &mut AdjacencyList,
        scratch: &mut InMemQueryScratch<T, N>,
        delete_set_ptr: Option<&HashSet<u32>>,
    ) -> ANNResult<()> {
        if pool.is_empty() {
            return Ok(());
        }

        if !result.is_empty() {
            return Err(ANNError::log_index_error(
                "result is not empty.".to_string(),
            ));
        }

        // Truncate pool at max_candidate_size and initialize scratch spaces
        if pool.len() > max_candidate_size {
            pool.truncate(max_candidate_size);
        }

        let occlude_factor = &mut scratch.occlude_factor;

        // occlude_list can be called with the same scratch more than once by
        // search_for_point_and_add_link through inter_insert.
        occlude_factor.clear();

        // Initialize occlude_factor to pool.len() many 0.0 values for correctness
        occlude_factor.resize(pool.len(), 0.0);

        let mut cur_alpha = 1.0;
        while cur_alpha <= alpha && result.len() < degree as usize {
            for (i, neighbor) in pool.iter().enumerate() {
                if result.len() >= degree as usize {
                    break;
                }
                if occlude_factor[i] > cur_alpha {
                    continue;
                }
                // Set the entry to f32::MAX so that it is not considered again
                occlude_factor[i] = f32::MAX;

                // Add the entry to the result if it has not been deleted and doesn't
                // add a self loop
                if delete_set_ptr.map_or(true, |delete_set| !delete_set.contains(&neighbor.id))
                    && neighbor.id != location
                {
                    result.push(neighbor.id);
                }

                // Update occlude factor for points from i+1 to pool.len()
                for (j, neighbor2) in pool.iter().enumerate().skip(i + 1) {
                    if occlude_factor[j] > alpha {
                        continue;
                    }

                    // todo - self.filtered_index
                    let djk = self.get_distance(neighbor2.id, neighbor.id)?;
                    match self.configuration.dist_metric {
                        Metric::L2 | Metric::Cosine => {
                            occlude_factor[j] = if djk == 0.0 {
                                f32::MAX
                            } else {
                                occlude_factor[j].max(neighbor2.distance / djk)
                            };
                        }
                    }
                }
            }

            cur_alpha *= 1.2;
        }

        Ok(())
    }

    /// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids.
    ///
    /// # Arguments
    ///
    /// * `location` - The id of the data point whose neighbors are to be pruned.
    /// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point.
    /// * `pruned_list` - A vector to store the ids of the pruned neighbors.
    /// * `scratch` - A mutable reference to a scratch space for in-memory queries.
    ///
    /// # Error
    ///
    /// Returns an error if `pruned_list` contains more than `range` elements after pruning.
    pub fn prune_neighbors(
        &self,
        location: u32,
        pool: &mut Vec<Neighbor>,
        pruned_list: &mut AdjacencyList,
        scratch: &mut InMemQueryScratch<T, N>,
    ) -> ANNResult<()> {
        self.robust_prune(
            location,
            pool,
            self.configuration.index_write_parameter.max_degree,
            self.configuration.index_write_parameter.max_occlusion_size,
            self.configuration.index_write_parameter.alpha,
            pruned_list,
            scratch,
        )
    }

    /// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids.
    ///
    /// # Arguments
    ///
    /// * `location` - The id of the data point whose neighbors are to be pruned.
    /// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point.
    /// * `range` - The maximum number of neighbors to keep after pruning.
    /// * `max_candidate_size` - The maximum number of candidates to consider for pruning.
    /// * `alpha` - A parameter that controls the occlusion pruning strategy.
    /// * `pruned_list` - A vector to store the ids of the pruned neighbors.
    /// * `scratch` - A mutable reference to a scratch space for in-memory queries.
    ///
    /// # Error
    ///
    /// Returns an error if `pruned_list` contains more than `range` elements after pruning.
    #[allow(clippy::too_many_arguments)]
    fn robust_prune(
        &self,
        location: u32,
        pool: &mut Vec<Neighbor>,
        range: u32,
        max_candidate_size: u32,
        alpha: f32,
        pruned_list: &mut AdjacencyList,
        scratch: &mut InMemQueryScratch<T, N>,
    ) -> ANNResult<()> {
        if pool.is_empty() {
            // if the pool is empty, behave like a no-op
            pruned_list.clear();
            return Ok(());
        }

        // If using _pq_build, over-write the PQ distances with actual distances
        // todo : pq_dist

        // sort the pool based on distance to query and prune it with occlude_list
        let mut pool = SortedNeighborVector::new(pool);
        pruned_list.clear();

        self.occlude_list(
            location,
            &mut pool,
            alpha,
            range,
            max_candidate_size as usize,
            pruned_list,
            scratch,
            Option::None,
        )?;

        if pruned_list.len() > range as usize {
            return Err(ANNError::log_index_error(format!(
                "pruned_list's len {} is over range {}.",
                pruned_list.len(),
                range
            )));
        }

        if self.configuration.index_write_parameter.saturate_graph && alpha > 1.0f32 {
            for neighbor in pool.iter() {
                if pruned_list.len() >= (range as usize) {
                    break;
                }
                if !pruned_list.contains(&neighbor.id) && neighbor.id != location {
                    pruned_list.push(neighbor.id);
                }
            }
        }

        Ok(())
    }

    /// A method that inserts a point n into the graph of its neighbors and their neighbors,
    /// pruning the graph if necessary to keep it within the specified range
    /// * `n` - The index of the new point
    /// * `pruned_list` is a vector of the neighbors of n that have been pruned by a previous step
    /// * `range` is the target number of neighbors for each point
    /// * `scratch` is a mutable reference to a scratch space that can be reused for intermediate computations
    pub fn inter_insert(
        &self,
        n: u32,
        pruned_list: &Vec<u32>,
        range: u32,
        scratch: &mut InMemQueryScratch<T, N>,
    ) -> ANNResult<()> {
        // Borrow the pruned_list as a source pool of neighbors
        let src_pool = pruned_list;

        if src_pool.is_empty() {
            return Err(ANNError::log_index_error("src_pool is empty.".to_string()));
        }

        for &vertex_id in src_pool {
            // vertex is the index of a neighbor of n
            // Assert that vertex is within the valid range of points
            if (vertex_id as usize)
                >= self.configuration.max_points + self.configuration.num_frozen_pts
            {
                return Err(ANNError::log_index_error(format!(
                    "vertex_id {} is out of valid range of points {}",
                    vertex_id,
                    self.configuration.max_points + self.configuration.num_frozen_pts,
                )));
            }

            let neighbors = self.add_to_neighbors(vertex_id, n, range)?;

            if let Some(copy_of_neighbors) = neighbors {
                // Pruning is needed; create a dummy set and a dummy vector to store the unique neighbors of vertex_id
                let mut dummy_pool = self.get_unique_neighbors(&copy_of_neighbors, vertex_id)?;

                // Create a new vector to store the pruned neighbors of vertex_id
                let mut new_out_neighbors =
                    AdjacencyList::for_range(self.configuration.write_range());
                // Prune the neighbors of vertex_id using a helper method
                self.prune_neighbors(vertex_id, &mut dummy_pool, &mut new_out_neighbors, scratch)?;

                self.set_neighbors(vertex_id, new_out_neighbors)?;
            }
        }

        Ok(())
    }

    /// Adds a node to the list of neighbors for the given node.
    ///
    /// # Arguments
    ///
    /// * `vertex_id` - The ID of the node to add the neighbor to.
    /// * `node_id` - The ID of the node to add.
    /// * `range` - The range of the graph.
    ///
    /// # Return
    ///
    /// Returns `None` if the node is already in the list of neighbors, or a `Vec` containing the updated list of neighbors if the list of neighbors is full.
    fn add_to_neighbors(
        &self,
        vertex_id: u32,
        node_id: u32,
        range: u32,
    ) -> ANNResult<Option<Vec<u32>>> {
        // vertex contains a vector of the neighbors of vertex_id
        let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?;

        Ok(vertex_guard.add_to_neighbors(node_id, range))
    }

    fn set_neighbors(&self, vertex_id: u32, new_out_neighbors: AdjacencyList) -> ANNResult<()> {
        // vertex contains a vector of the neighbors of vertex_id
        let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?;

        vertex_guard.set_neighbors(new_out_neighbors);
        Ok(())
    }
}
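The heart of `occlude_list` above is the alpha-pruning rule from DiskANN's RobustPrune: once a neighbor `i` of point `p` has been kept, a remaining candidate `j` is dropped when `dist(p, j) / dist(i, j)` exceeds the current threshold, because routing through `i` already reaches `j` efficiently; the outer loop then relaxes the threshold (`cur_alpha *= 1.2`) until the degree budget is filled. A standalone restatement of that test (a sketch, not the vendored code):

/// True when candidate `j` is occluded by an already-kept neighbor `i`
/// of point `p`, for the current alpha threshold.
fn is_occluded(dist_p_j: f32, dist_i_j: f32, alpha: f32) -> bool {
    if dist_i_j == 0.0 {
        return true; // coincident points: always occluded
    }
    dist_p_j / dist_i_j > alpha
}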
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/algorithm/search/mod.rs (vendored, new file, 7 lines)
@@ -0,0 +1,7 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
pub mod search;
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/algorithm/search/search.rs (vendored, new file, 359 lines)
@@ -0,0 +1,359 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Search algorithm for index construction and query

use crate::common::{ANNError, ANNResult};
use crate::index::InmemIndex;
use crate::model::{scratch::InMemQueryScratch, Neighbor, Vertex};
use hashbrown::hash_set::Entry::*;
use vector::FullPrecisionDistance;

impl<T, const N: usize> InmemIndex<T, N>
where
    T: Default + Copy + Sync + Send + Into<f32>,
    [T; N]: FullPrecisionDistance<T, N>,
{
    /// Search for query using given L value, for benchmarking purposes
    /// # Arguments
    /// * `query` - query vertex
    /// * `scratch` - in-memory query scratch
    /// * `search_list_size` - search list size to use for the benchmark
    pub fn search_with_l_override(
        &self,
        query: &Vertex<T, N>,
        scratch: &mut InMemQueryScratch<T, N>,
        search_list_size: usize,
    ) -> ANNResult<u32> {
        let init_ids = self.get_init_ids()?;
        self.init_graph_for_point(query, init_ids, scratch)?;
        // Scratch is created using the largest L value from search_memory_index, so we artificially make it smaller here.
        // This allows us to use the same scratch for all L values without having to rebuild the query scratch.
        scratch.best_candidates.set_capacity(search_list_size);
        let (_, cmp) = self.greedy_search(query, scratch)?;

        Ok(cmp)
    }

    /// Search for point
    /// # Arguments
    /// * `query` - query vertex
    /// * `scratch` - in-memory query scratch
    /// TODO: use_filter, filteredLindex
    pub fn search_for_point(
        &self,
        query: &Vertex<T, N>,
        scratch: &mut InMemQueryScratch<T, N>,
    ) -> ANNResult<Vec<Neighbor>> {
        let init_ids = self.get_init_ids()?;
        self.init_graph_for_point(query, init_ids, scratch)?;
        let (mut visited_nodes, _) = self.greedy_search(query, scratch)?;

        visited_nodes.retain(|&element| element.id != query.vertex_id());
        Ok(visited_nodes)
    }

    /// Returns the locations of start point and frozen points suitable for use with iterate_to_fixed_point.
    fn get_init_ids(&self) -> ANNResult<Vec<u32>> {
        let mut init_ids = Vec::with_capacity(1 + self.configuration.num_frozen_pts);
        init_ids.push(self.start);

        for frozen in self.configuration.max_points
            ..(self.configuration.max_points + self.configuration.num_frozen_pts)
        {
            let frozen_u32 = frozen.try_into()?;
            if frozen_u32 != self.start {
                init_ids.push(frozen_u32);
            }
        }

        Ok(init_ids)
    }

    /// Initialize graph for point
    /// # Arguments
    /// * `query` - query vertex
    /// * `init_ids` - initial nodes from which search starts
    /// * `scratch` - in-memory query scratch
    fn init_graph_for_point(
        &self,
        query: &Vertex<T, N>,
        init_ids: Vec<u32>,
        scratch: &mut InMemQueryScratch<T, N>,
    ) -> ANNResult<()> {
        scratch
            .best_candidates
            .reserve(self.configuration.index_write_parameter.search_list_size as usize);
        scratch.query.memcpy(query.vector())?;

        if !scratch.id_scratch.is_empty() {
            return Err(ANNError::log_index_error(
                "id_scratch is not empty.".to_string(),
            ));
        }

        let query_vertex = Vertex::<T, N>::try_from((&scratch.query[..], query.vertex_id()))
            .map_err(|err| {
                ANNError::log_index_error(format!(
                    "TryFromSliceError: failed to get Vertex for query, err={}",
                    err
                ))
            })?;

        for id in init_ids {
            if (id as usize) >= self.configuration.max_points + self.configuration.num_frozen_pts {
                return Err(ANNError::log_index_error(format!(
                    "vertex_id {} is out of valid range of points {}",
                    id,
                    self.configuration.max_points + self.configuration.num_frozen_pts
                )));
            }

            if let Vacant(entry) = scratch.node_visited_robinset.entry(id) {
                entry.insert();

                let vertex = self.dataset.get_vertex(id)?;

                let distance = vertex.compare(&query_vertex, self.configuration.dist_metric);
                let neighbor = Neighbor::new(id, distance);
                scratch.best_candidates.insert(neighbor);
            }
        }

        Ok(())
    }

    /// GreedySearch against query node
    /// Returns visited nodes
    /// # Arguments
    /// * `query` - query vertex
    /// * `scratch` - in-memory query scratch
    /// TODO: use_filter, filter_label, search_invocation
    fn greedy_search(
        &self,
        query: &Vertex<T, N>,
        scratch: &mut InMemQueryScratch<T, N>,
    ) -> ANNResult<(Vec<Neighbor>, u32)> {
        let mut visited_nodes =
            Vec::with_capacity((3 * scratch.candidate_size + scratch.max_degree) as usize);

        // TODO: uncomment hops?
        // let mut hops: u32 = 0;
        let mut cmps: u32 = 0;

        let query_vertex = Vertex::<T, N>::try_from((&scratch.query[..], query.vertex_id()))
            .map_err(|err| {
                ANNError::log_index_error(format!(
                    "TryFromSliceError: failed to get Vertex for query, err={}",
                    err
                ))
            })?;

        while scratch.best_candidates.has_notvisited_node() {
            let closest_node = scratch.best_candidates.closest_notvisited();

            // Add node to visited nodes to create pool for prune later
            // TODO: search_invocation and use_filter
            visited_nodes.push(closest_node);

            // Find which of the closest node's neighbors have not been visited before
            scratch.id_scratch.clear();

            let max_vertex_id = self.configuration.max_points + self.configuration.num_frozen_pts;

            for id in self
                .final_graph
                .read_vertex_and_neighbors(closest_node.id)?
                .get_neighbors()
            {
                let current_vertex_id = *id;
                debug_assert!(
                    (current_vertex_id as usize) < max_vertex_id,
                    "current_vertex_id {} is out of valid range of points {}",
                    current_vertex_id,
                    max_vertex_id
                );
                if current_vertex_id as usize >= max_vertex_id {
                    continue;
                }

                // quickly de-dup. Remember, we are in a read lock;
                // we want to exit out of it quickly
                if scratch.node_visited_robinset.insert(current_vertex_id) {
                    scratch.id_scratch.push(current_vertex_id);
                }
            }

            let len = scratch.id_scratch.len();
            for (m, &id) in scratch.id_scratch.iter().enumerate() {
                if m + 1 < len {
                    let next_node = unsafe { *scratch.id_scratch.get_unchecked(m + 1) };
                    self.dataset.prefetch_vector(next_node);
                }

                let vertex = self.dataset.get_vertex(id)?;
                let distance = query_vertex.compare(&vertex, self.configuration.dist_metric);

                // Insert <id, dist> pairs into the pool of candidates
                scratch.best_candidates.insert(Neighbor::new(id, distance));
            }

            cmps += len as u32;
        }

        Ok((visited_nodes, cmps))
    }
}

#[cfg(test)]
mod search_test {
    use vector::Metric;

    use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
    use crate::model::graph::AdjacencyList;
    use crate::model::IndexConfiguration;
    use crate::test_utils::inmem_index_initialization::create_index_with_test_data;

    use super::*;

    #[test]
    fn get_init_ids_no_frozen_pts() {
        let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
            .with_alpha(1.2)
            .build();
        let config = IndexConfiguration::new(
            Metric::L2,
            256,
            256,
            256,
            false,
            0,
            false,
            0,
            1f32,
            index_write_parameters,
        );

        let index = InmemIndex::<f32, 256>::new(config).unwrap();
        let init_ids = index.get_init_ids().unwrap();
        assert_eq!(init_ids.len(), 1);
        assert_eq!(init_ids[0], 256);
    }

    #[test]
    fn get_init_ids_with_frozen_pts() {
        let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
            .with_alpha(1.2)
            .build();
        let config = IndexConfiguration::new(
            Metric::L2,
            256,
            256,
            256,
            false,
            0,
            false,
            2,
            1f32,
            index_write_parameters,
        );

        let index = InmemIndex::<f32, 256>::new(config).unwrap();
        let init_ids = index.get_init_ids().unwrap();
        assert_eq!(init_ids.len(), 2);
        assert_eq!(init_ids[0], 256);
        assert_eq!(init_ids[1], 257);
    }

    #[test]
    fn search_for_point_initial_call() {
        let index = create_index_with_test_data();
        let query = index.dataset.get_vertex(0).unwrap();

        let mut scratch = InMemQueryScratch::new(
            index.configuration.index_write_parameter.search_list_size,
            &index.configuration.index_write_parameter,
            false,
        )
        .unwrap();
        let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap();
        assert_eq!(visited_nodes.len(), 1);
        assert_eq!(scratch.best_candidates.size(), 1);
        assert_eq!(scratch.best_candidates[0].id, 72);
        assert_eq!(scratch.best_candidates[0].distance, 125678.0_f32);
        assert!(scratch.best_candidates[0].visited);
    }

    fn set_neighbors(index: &InmemIndex<f32, 128>, vertex_id: u32, neighbors: Vec<u32>) {
        index
            .final_graph
            .write_vertex_and_neighbors(vertex_id)
            .unwrap()
            .set_neighbors(AdjacencyList::from(neighbors));
    }

    #[test]
    fn search_for_point_works_with_edges() {
        let index = create_index_with_test_data();
        let query = index.dataset.get_vertex(14).unwrap();

        set_neighbors(&index, 0, vec![12, 72, 5, 9]);
        set_neighbors(&index, 1, vec![2, 12, 10, 4]);
        set_neighbors(&index, 2, vec![1, 72, 9]);
        set_neighbors(&index, 3, vec![13, 6, 5, 11]);
        set_neighbors(&index, 4, vec![1, 3, 7, 9]);
        set_neighbors(&index, 5, vec![3, 0, 8, 11, 13]);
        set_neighbors(&index, 6, vec![3, 72, 7, 10, 13]);
        set_neighbors(&index, 7, vec![72, 4, 6]);
        set_neighbors(&index, 8, vec![72, 5, 9, 12]);
        set_neighbors(&index, 9, vec![8, 4, 0, 2]);
        set_neighbors(&index, 10, vec![72, 1, 9, 6]);
        set_neighbors(&index, 11, vec![3, 0, 5]);
        set_neighbors(&index, 12, vec![1, 0, 8, 9]);
        set_neighbors(&index, 13, vec![3, 72, 5, 6]);
        set_neighbors(&index, 72, vec![7, 2, 10, 8, 13]);

        let mut scratch = InMemQueryScratch::new(
            index.configuration.index_write_parameter.search_list_size,
            &index.configuration.index_write_parameter,
            false,
        )
        .unwrap();
        let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap();
        assert_eq!(visited_nodes.len(), 15);
        assert_eq!(scratch.best_candidates.size(), 15);
        assert_eq!(scratch.best_candidates[0].id, 2);
        assert_eq!(scratch.best_candidates[0].distance, 120899.0_f32);
        assert_eq!(scratch.best_candidates[1].id, 8);
        assert_eq!(scratch.best_candidates[1].distance, 145538.0_f32);
        assert_eq!(scratch.best_candidates[2].id, 72);
        assert_eq!(scratch.best_candidates[2].distance, 146046.0_f32);
        assert_eq!(scratch.best_candidates[3].id, 4);
        assert_eq!(scratch.best_candidates[3].distance, 148462.0_f32);
        assert_eq!(scratch.best_candidates[4].id, 7);
        assert_eq!(scratch.best_candidates[4].distance, 148912.0_f32);
        assert_eq!(scratch.best_candidates[5].id, 10);
        assert_eq!(scratch.best_candidates[5].distance, 154570.0_f32);
        assert_eq!(scratch.best_candidates[6].id, 1);
        assert_eq!(scratch.best_candidates[6].distance, 159448.0_f32);
        assert_eq!(scratch.best_candidates[7].id, 12);
        assert_eq!(scratch.best_candidates[7].distance, 170698.0_f32);
        assert_eq!(scratch.best_candidates[8].id, 9);
        assert_eq!(scratch.best_candidates[8].distance, 177205.0_f32);
        assert_eq!(scratch.best_candidates[9].id, 0);
        assert_eq!(scratch.best_candidates[9].distance, 259996.0_f32);
        assert_eq!(scratch.best_candidates[10].id, 6);
        assert_eq!(scratch.best_candidates[10].distance, 371819.0_f32);
        assert_eq!(scratch.best_candidates[11].id, 5);
        assert_eq!(scratch.best_candidates[11].distance, 385240.0_f32);
        assert_eq!(scratch.best_candidates[12].id, 3);
        assert_eq!(scratch.best_candidates[12].distance, 413899.0_f32);
        assert_eq!(scratch.best_candidates[13].id, 13);
        assert_eq!(scratch.best_candidates[13].distance, 416386.0_f32);
        assert_eq!(scratch.best_candidates[14].id, 11);
        assert_eq!(scratch.best_candidates[14].distance, 449266.0_f32);
    }
}
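Stripped of locking, prefetching, and scratch reuse, `greedy_search` above is a best-first graph walk: expand the closest unexpanded candidate, score its unseen neighbors, repeat until the candidate pool within the beam is exhausted. A self-contained toy version of that control flow (stand-in closures and types, not the vendored ones):

use std::collections::HashSet;

fn toy_greedy_search(
    neighbors: impl Fn(u32) -> Vec<u32>, // adjacency lookup
    dist: impl Fn(u32) -> f32,           // distance from the query
    start: u32,
    beam: usize,                         // plays the role of L
) -> Vec<u32> {
    let mut frontier = vec![(dist(start), start)]; // kept sorted by distance
    let mut expanded: HashSet<u32> = HashSet::new();
    let mut order = Vec::new(); // expansion order == visited_nodes

    loop {
        // Closest candidate within the beam that has not been expanded yet.
        let Some(node) = frontier
            .iter()
            .take(beam)
            .find(|(_, n)| !expanded.contains(n))
            .map(|&(_, n)| n)
        else {
            break;
        };
        expanded.insert(node);
        order.push(node);

        for id in neighbors(node) {
            if expanded.contains(&id) || frontier.iter().any(|&(_, n)| n == id) {
                continue; // de-dup, as the robin set does above
            }
            let d = dist(id);
            let pos = frontier.partition_point(|&(fd, _)| fd < d);
            frontier.insert(pos, (d, id));
        }
    }
    order
}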
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/common/aligned_allocator.rs (vendored, new file, 281 lines)
@@ -0,0 +1,281 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Aligned allocator

use std::alloc::Layout;
use std::ops::{Deref, DerefMut, Range};
use std::ptr::copy_nonoverlapping;

use super::{ANNResult, ANNError};

#[derive(Debug)]
/// A box that holds a slice but is aligned to the specified layout.
///
/// This type is useful for working with types that require a certain alignment,
/// such as SIMD vectors or FFI structs. It allocates memory using the global allocator
/// and frees it when dropped. It also implements Deref and DerefMut to allow access
/// to the underlying slice.
pub struct AlignedBoxWithSlice<T> {
    /// The layout of the allocated memory.
    layout: Layout,

    /// The slice that points to the allocated memory.
    val: Box<[T]>,
}

impl<T> AlignedBoxWithSlice<T> {
    /// Creates a new `AlignedBoxWithSlice` with the given capacity and alignment.
    /// The allocated memory is set to 0.
    ///
    /// # Error
    ///
    /// Returns an IndexError if the capacity overflows, or a layout error if the
    /// alignment is not a power of two or the layout is otherwise invalid.
    ///
    /// Internally this allocates raw memory and casts it to a slice of `T`, so the
    /// caller must ensure that the capacity and alignment are valid for the type `T`.
    pub fn new(capacity: usize, alignment: usize) -> ANNResult<Self> {
        let allocsize = capacity.checked_mul(std::mem::size_of::<T>())
            .ok_or_else(|| ANNError::log_index_error("capacity overflow".to_string()))?;
        let layout = Layout::from_size_align(allocsize, alignment)
            .map_err(ANNError::log_mem_alloc_layout_error)?;

        let val = unsafe {
            let mem = std::alloc::alloc_zeroed(layout);
            let ptr = mem as *mut T;
            let slice = std::slice::from_raw_parts_mut(ptr, capacity);
            std::boxed::Box::from_raw(slice)
        };

        Ok(Self { layout, val })
    }

    /// Returns a reference to the slice.
    pub fn as_slice(&self) -> &[T] {
        &self.val
    }

    /// Returns a mutable reference to the slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.val
    }

    /// Copies data from the source slice to the destination box.
    pub fn memcpy(&mut self, src: &[T]) -> ANNResult<()> {
        if src.len() > self.val.len() {
            return Err(ANNError::log_index_error(format!(
                "source slice is too large (src:{}, dst:{})",
                src.len(),
                self.val.len()
            )));
        }

        // Check that they don't overlap
        let src_ptr = src.as_ptr();
        let src_end = unsafe { src_ptr.add(src.len()) };
        let dst_ptr = self.val.as_mut_ptr();
        let dst_end = unsafe { dst_ptr.add(self.val.len()) };

        if src_ptr < dst_end && src_end > dst_ptr {
            return Err(ANNError::log_index_error(
                "Source and destination overlap".to_string(),
            ));
        }

        unsafe {
            copy_nonoverlapping(src.as_ptr(), self.val.as_mut_ptr(), src.len());
        }

        Ok(())
    }

    /// Split the range of memory into nonoverlapping mutable slices.
    /// The number of returned slices is (range length / slice_len) and each has a length of slice_len.
    pub fn split_into_nonoverlapping_mut_slices(
        &mut self,
        range: Range<usize>,
        slice_len: usize,
    ) -> ANNResult<Vec<&mut [T]>> {
        if range.len() % slice_len != 0 || range.end > self.len() {
            return Err(ANNError::log_index_error(format!(
                "Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}",
                range,
                self.len(),
                slice_len,
            )));
        }

        let mut slices = Vec::with_capacity(range.len() / slice_len);
        let mut remaining_slice = &mut self.val[range];

        while remaining_slice.len() >= slice_len {
            let (left, right) = remaining_slice.split_at_mut(slice_len);
            slices.push(left);
            remaining_slice = right;
        }

        Ok(slices)
    }
}

impl<T> Drop for AlignedBoxWithSlice<T> {
    /// Frees the memory allocated for the slice using the global allocator.
    fn drop(&mut self) {
        let val = std::mem::take(&mut self.val);
        let mut val2 = std::mem::ManuallyDrop::new(val);
        let ptr = val2.as_mut_ptr();

        unsafe {
            std::alloc::dealloc(ptr as *mut u8, self.layout)
        }
    }
}

impl<T> Deref for AlignedBoxWithSlice<T> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        &self.val
    }
}

impl<T> DerefMut for AlignedBoxWithSlice<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.val
    }
}

#[cfg(test)]
mod tests {
    use rand::Rng;

    use crate::utils::is_aligned;

    use super::*;

    #[test]
    fn create_alignedvec_works_32() {
        (0..100).for_each(|_| {
            let size = 1_000_000;
            println!("Attempting {}", size);
            let data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
            assert_eq!(data.len(), size, "Capacity should match");

            let ptr = data.as_ptr() as usize;
            assert_eq!(ptr % 32, 0, "Ptr should be aligned to 32");

            // assert that the slice is initialized.
            (0..size).for_each(|i| {
                assert_eq!(data[i], f32::default());
            });

            drop(data);
        });
    }

    #[test]
    fn create_alignedvec_works_256() {
        let mut rng = rand::thread_rng();

        (0..100).for_each(|_| {
            let n = rng.gen::<u8>();
            let size = usize::from(n) + 1;
            println!("Attempting {}", size);
            let data = AlignedBoxWithSlice::<u8>::new(size, 256).unwrap();
            assert_eq!(data.len(), size, "Capacity should match");

            let ptr = data.as_ptr() as usize;
            assert_eq!(ptr % 256, 0, "Ptr should be aligned to 256");

            // assert that the slice is initialized.
            (0..size).for_each(|i| {
                assert_eq!(data[i], u8::default());
            });

            drop(data);
        });
    }

    #[test]
    fn as_slice_test() {
        let size = 1_000_000;
        let data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
        // assert that the slice is initialized.
        (0..size).for_each(|i| {
            assert_eq!(data[i], f32::default());
        });

        let slice = data.as_slice();
        (0..size).for_each(|i| {
            assert_eq!(slice[i], f32::default());
        });
    }

    #[test]
    fn as_mut_slice_test() {
        let size = 1_000_000;
        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
        let mut_slice = data.as_mut_slice();
        (0..size).for_each(|i| {
            assert_eq!(mut_slice[i], f32::default());
        });
    }

    #[test]
    fn memcpy_test() {
        let size = 1_000_000;
        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
        let mut destination = AlignedBoxWithSlice::<f32>::new(size - 2, 32).unwrap();
        let mut_destination = destination.as_mut_slice();
        data.memcpy(mut_destination).unwrap();
        (0..size - 2).for_each(|i| {
            assert_eq!(data[i], mut_destination[i]);
        });
    }

    #[test]
    #[should_panic(expected = "source slice is too large (src:1000000, dst:999998)")]
    fn memcpy_panic_test() {
        let size = 1_000_000;
        let mut data = AlignedBoxWithSlice::<f32>::new(size - 2, 32).unwrap();
        let mut destination = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
        let mut_destination = destination.as_mut_slice();
        data.memcpy(mut_destination).unwrap();
    }

    #[test]
    fn is_aligned_test() {
        assert!(is_aligned(256, 256));
        assert!(!is_aligned(255, 256));
    }

    #[test]
    fn split_into_nonoverlapping_mut_slices_test() {
        let size = 10;
        let slice_len = 2;
        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
        let slices = data.split_into_nonoverlapping_mut_slices(2..8, slice_len).unwrap();
        assert_eq!(slices.len(), 3);
        for (i, slice) in slices.into_iter().enumerate() {
            assert_eq!(slice.len(), slice_len);
            slice[0] = i as f32 + 1.0;
            slice[1] = i as f32 + 1.0;
        }
        let expected_arr = [0.0f32, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0];
        assert_eq!(data.as_ref(), &expected_arr);
    }

    #[test]
    fn split_into_nonoverlapping_mut_slices_error_when_indivisible() {
        let size = 10;
        let slice_len = 2;
        let range = 2..7;
        let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
        let result = data.split_into_nonoverlapping_mut_slices(range.clone(), slice_len);
        let expected_err_str = format!(
            "IndexError: Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}",
            range,
            size,
            slice_len,
        );
        assert!(result.is_err_and(|e| e.to_string() == expected_err_str));
    }
}
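A usage sketch for `AlignedBoxWithSlice`, using only the methods defined above (this assumes the crate is consumed as a dependency named `diskann`, as the manifest names it):

use diskann::common::{ANNResult, AlignedBoxWithSlice};

fn example() -> ANNResult<()> {
    // 1024 f32 values, 32-byte aligned (the alignment _mm256_load_ps needs).
    let mut buf = AlignedBoxWithSlice::<f32>::new(1024, 32)?;

    // Copy 512 values into the front of the zero-initialized buffer.
    let src = vec![1.0_f32; 512];
    buf.memcpy(&src)?;

    // Carve elements 0..512 into four disjoint 128-element mutable slices.
    let slices = buf.split_into_nonoverlapping_mut_slices(0..512, 128)?;
    assert_eq!(slices.len(), 4);
    Ok(())
}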
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/common/ann_result.rs (vendored, new file, 179 lines)
@@ -0,0 +1,179 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::alloc::LayoutError;
use std::array::TryFromSliceError;
use std::io;
use std::num::TryFromIntError;

use logger::error_logger::log_error;
use logger::log_error::LogError;

/// Result
pub type ANNResult<T> = Result<T, ANNError>;

/// DiskANN Error
/// ANNError is `Send` (i.e., safe to send across threads)
#[derive(thiserror::Error, Debug)]
pub enum ANNError {
    /// Index construction and search error
    #[error("IndexError: {err}")]
    IndexError { err: String },

    /// Index configuration error
    #[error("IndexConfigError: {parameter} is invalid, err={err}")]
    IndexConfigError { parameter: String, err: String },

    /// Integer conversion error
    #[error("TryFromIntError: {err}")]
    TryFromIntError {
        #[from]
        err: TryFromIntError,
    },

    /// IO error
    #[error("IOError: {err}")]
    IOError {
        #[from]
        err: io::Error,
    },

    /// Layout error in memory allocation
    #[error("MemoryAllocLayoutError: {err}")]
    MemoryAllocLayoutError {
        #[from]
        err: LayoutError,
    },

    /// PoisonError, which can be returned whenever a lock is acquired.
    /// Both Mutexes and RwLocks are poisoned whenever a thread fails while the lock is held.
    #[error("LockPoisonError: {err}")]
    LockPoisonError { err: String },

    /// DiskIOAlignmentError, which can be returned when calling the Windows API CreateFileA for the disk index file fails.
    #[error("DiskIOAlignmentError: {err}")]
    DiskIOAlignmentError { err: String },

    /// Logging error
    #[error("LogError: {err}")]
    LogError {
        #[from]
        err: LogError,
    },

    /// PQ construction error.
    /// Returned when constructing the PQ pivot or PQ compressed table fails.
    #[error("PQError: {err}")]
    PQError { err: String },

    /// Array conversion error
    #[error("Error try creating array from slice: {err}")]
    TryFromSliceError {
        #[from]
        err: TryFromSliceError,
    },
}

impl ANNError {
    /// Create, log and return IndexError
    #[inline]
    pub fn log_index_error(err: String) -> Self {
        let ann_err = ANNError::IndexError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return IndexConfigError
    #[inline]
    pub fn log_index_config_error(parameter: String, err: String) -> Self {
        let ann_err = ANNError::IndexConfigError { parameter, err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return TryFromIntError
    #[inline]
    pub fn log_try_from_int_error(err: TryFromIntError) -> Self {
        let ann_err = ANNError::TryFromIntError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return IOError
    #[inline]
    pub fn log_io_error(err: io::Error) -> Self {
        let ann_err = ANNError::IOError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return DiskIOAlignmentError
    #[inline]
    pub fn log_disk_io_request_alignment_error(err: String) -> Self {
        let ann_err: ANNError = ANNError::DiskIOAlignmentError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return MemoryAllocLayoutError
    #[inline]
    pub fn log_mem_alloc_layout_error(err: LayoutError) -> Self {
        let ann_err = ANNError::MemoryAllocLayoutError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return LockPoisonError
    #[inline]
    pub fn log_lock_poison_error(err: String) -> Self {
        let ann_err = ANNError::LockPoisonError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return PQError
    #[inline]
    pub fn log_pq_error(err: String) -> Self {
        let ann_err = ANNError::PQError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }

    /// Create, log and return TryFromSliceError
    #[inline]
    pub fn log_try_from_slice_error(err: TryFromSliceError) -> Self {
        let ann_err = ANNError::TryFromSliceError { err };
        match log_error(ann_err.to_string()) {
            Ok(()) => ann_err,
            Err(log_err) => ANNError::LogError { err: log_err },
        }
    }
}

#[cfg(test)]
mod ann_result_test {
    use super::*;

    #[test]
    fn ann_err_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ANNError>();
    }
}
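Each `log_*` constructor above logs the message and hands back the error, falling back to `LogError` when logging itself fails, so call sites stay one-liners. A minimal usage sketch (hypothetical function names):

use diskann::common::{ANNError, ANNResult};

fn checked_dim(dim: usize) -> ANNResult<usize> {
    if dim == 0 {
        // Constructs ANNError::IndexError and logs it in one step.
        return Err(ANNError::log_index_error("dim must be nonzero".to_string()));
    }
    Ok(dim)
}

fn caller() -> ANNResult<()> {
    let _dim = checked_dim(128)?; // `?` works because ANNResult<T> = Result<T, ANNError>
    Ok(())
}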
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/common/mod.rs (vendored, new file, 9 lines)
@@ -0,0 +1,9 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
mod aligned_allocator;
pub use aligned_allocator::AlignedBoxWithSlice;

mod ann_result;
pub use ann_result::*;
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/index/disk_index/ann_disk_index.rs (vendored, new file, 54 lines; path inferred from the trait and module declarations below)
@@ -0,0 +1,54 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_docs)]

//! ANN disk index abstraction

use vector::FullPrecisionDistance;

use crate::model::{IndexConfiguration, DiskIndexBuildParameters};
use crate::storage::DiskIndexStorage;
use crate::model::vertex::{DIM_128, DIM_256, DIM_104};

use crate::common::{ANNResult, ANNError};

use super::DiskIndex;

/// ANN disk index abstraction for custom <T, N>
pub trait ANNDiskIndex<T>: Sync + Send
where
    T: Default + Copy + Sync + Send + Into<f32>,
{
    /// Build index
    fn build(&mut self, codebook_prefix: &str) -> ANNResult<()>;
}

/// Create Index<T, N> based on configuration
pub fn create_disk_index<'a, T>(
    disk_build_param: Option<DiskIndexBuildParameters>,
    config: IndexConfiguration,
    storage: DiskIndexStorage<T>,
) -> ANNResult<Box<dyn ANNDiskIndex<T> + 'a>>
where
    T: Default + Copy + Sync + Send + Into<f32> + 'a,
    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
{
    match config.aligned_dim {
        DIM_104 => {
            let index = Box::new(DiskIndex::<T, DIM_104>::new(disk_build_param, config, storage));
            Ok(index as Box<dyn ANNDiskIndex<T>>)
        },
        DIM_128 => {
            let index = Box::new(DiskIndex::<T, DIM_128>::new(disk_build_param, config, storage));
            Ok(index as Box<dyn ANNDiskIndex<T>>)
        },
        DIM_256 => {
            let index = Box::new(DiskIndex::<T, DIM_256>::new(disk_build_param, config, storage));
            Ok(index as Box<dyn ANNDiskIndex<T>>)
        },
        _ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))),
    }
}
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/index/disk_index/disk_index.rs (vendored, new file, 161 lines)
@@ -0,0 +1,161 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use std::mem;
|
||||
|
||||
use logger::logger::indexlog::DiskIndexConstructionCheckpoint;
|
||||
use vector::FullPrecisionDistance;
|
||||
|
||||
use crate::common::{ANNResult, ANNError};
|
||||
use crate::index::{InmemIndex, ANNInmemIndex};
|
||||
use crate::instrumentation::DiskIndexBuildLogger;
|
||||
use crate::model::configuration::DiskIndexBuildParameters;
|
||||
use crate::model::{IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE, MAX_PQ_CHUNKS, generate_quantized_data, GRAPH_SLACK_FACTOR};
|
||||
use crate::storage::DiskIndexStorage;
|
||||
use crate::utils::set_rayon_num_threads;
|
||||
|
||||
use super::ann_disk_index::ANNDiskIndex;
|
||||
|
||||
pub const OVERHEAD_FACTOR: f64 = 1.1f64;
|
||||
|
||||
pub const MAX_SAMPLE_POINTS_FOR_WARMUP: usize = 100_000;
|
||||
|
||||
pub struct DiskIndex<T, const N: usize>
|
||||
where
|
||||
[T; N]: FullPrecisionDistance<T, N>,
|
||||
{
|
||||
/// Parameters for index construction
|
||||
/// None for query path
|
||||
disk_build_param: Option<DiskIndexBuildParameters>,
|
||||
|
||||
configuration: IndexConfiguration,
|
||||
|
||||
pub storage: DiskIndexStorage<T>,
|
||||
}
|
||||
|
||||
impl<T, const N: usize> DiskIndex<T, N>
|
||||
where
|
||||
T: Default + Copy + Sync + Send + Into<f32>,
|
||||
[T; N]: FullPrecisionDistance<T, N>,
|
||||
{
|
||||
pub fn new(
|
||||
disk_build_param: Option<DiskIndexBuildParameters>,
|
||||
configuration: IndexConfiguration,
|
||||
storage: DiskIndexStorage<T>,
|
||||
) -> Self {
|
||||
Self {
|
||||
disk_build_param,
|
||||
configuration,
|
||||
storage,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn disk_build_param(&self) -> &Option<DiskIndexBuildParameters> {
|
||||
&self.disk_build_param
|
||||
}
|
||||
|
||||
pub fn index_configuration(&self) -> &IndexConfiguration {
|
||||
&self.configuration
|
||||
}
|
||||
|
||||
fn build_inmem_index(&self, num_points: usize, data_path: &str, inmem_index_path: &str) -> ANNResult<()> {
|
||||
let estimated_index_ram = self.estimate_ram_usage(num_points);
|
||||
if estimated_index_ram >= self.fetch_disk_build_param()?.index_build_ram_limit() * 1024_f64 * 1024_f64 * 1024_f64 {
|
||||
return Err(ANNError::log_index_error(format!(
|
||||
"Insufficient memory budget for index build, index_build_ram_limit={}GB estimated_index_ram={}GB",
|
||||
self.fetch_disk_build_param()?.index_build_ram_limit(),
|
||||
estimated_index_ram / (1024_f64 * 1024_f64 * 1024_f64),
|
||||
)));
|
||||
}
|
||||
|
||||
let mut index = InmemIndex::<T, N>::new(self.configuration.clone())?;
|
||||
index.build(data_path, num_points)?;
|
||||
index.save(inmem_index_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn estimate_ram_usage(&self, size: usize) -> f64 {
|
||||
let degree = self.configuration.index_write_parameter.max_degree as usize;
|
||||
let datasize = mem::size_of::<T>();
|
||||
|
||||
let dataset_size = (size * N * datasize) as f64;
|
||||
let graph_size = (size * degree * mem::size_of::<u32>()) as f64 * GRAPH_SLACK_FACTOR;
|
||||
|
||||
OVERHEAD_FACTOR * (dataset_size + graph_size)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fetch_disk_build_param(&self) -> ANNResult<&DiskIndexBuildParameters> {
|
||||
self.disk_build_param
|
||||
.as_ref()
|
||||
.ok_or_else(|| ANNError::log_index_config_error(
|
||||
"disk_build_param".to_string(),
|
||||
"disk_build_param is None".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const N: usize> ANNDiskIndex<T> for DiskIndex<T, N>
where
    T: Default + Copy + Sync + Send + Into<f32>,
    [T; N]: FullPrecisionDistance<T, N>,
{
    fn build(&mut self, codebook_prefix: &str) -> ANNResult<()> {
        if self.configuration.index_write_parameter.num_threads > 0 {
            set_rayon_num_threads(self.configuration.index_write_parameter.num_threads);
        }

        println!("Starting index build: R={} L={} Query RAM budget={} Indexing RAM budget={} T={}",
            self.configuration.index_write_parameter.max_degree,
            self.configuration.index_write_parameter.search_list_size,
            self.fetch_disk_build_param()?.search_ram_limit(),
            self.fetch_disk_build_param()?.index_build_ram_limit(),
            self.configuration.index_write_parameter.num_threads
        );

        let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction);

        // PQ memory consumption = PQ pivots + PQ compressed table
        // PQ pivots: dim * num_centroids * sizeof::<T>()
        // PQ compressed table: num_pts * num_pq_chunks * (dim / num_pq_chunks) * sizeof::<u8>()
        // * Because num_centroids is 256, a centroid id can be represented by u8
        let num_points = self.configuration.max_points;
        let dim = self.configuration.dim;
        let p_val = MAX_PQ_TRAINING_SET_SIZE / (num_points as f64);
        let mut num_pq_chunks = ((self.fetch_disk_build_param()?.search_ram_limit() / (num_points as f64)).floor()) as usize;
        num_pq_chunks = if num_pq_chunks == 0 { 1 } else { num_pq_chunks };
        num_pq_chunks = if num_pq_chunks > dim { dim } else { num_pq_chunks };
        num_pq_chunks = if num_pq_chunks > MAX_PQ_CHUNKS { MAX_PQ_CHUNKS } else { num_pq_chunks };

        println!("Compressing {}-dimensional data into {} bytes per vector.", dim, num_pq_chunks);

        // TODO: Decouple PQ from file access
        generate_quantized_data::<T>(
            p_val,
            num_pq_chunks,
            codebook_prefix,
            self.storage.get_pq_storage(),
        )?;
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild)?;

        // TODO: Decouple index from file access
        let inmem_index_path = self.storage.index_path_prefix().clone() + "_mem.index";
        self.build_inmem_index(num_points, self.storage.dataset_file(), inmem_index_path.as_str())?;
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::DiskLayout)?;

        self.storage.create_disk_layout()?;
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::None)?;

        let ten_percent_points = ((num_points as f64) * 0.1_f64).ceil();
        let num_sample_points = if ten_percent_points > (MAX_SAMPLE_POINTS_FOR_WARMUP as f64) { MAX_SAMPLE_POINTS_FOR_WARMUP as f64 } else { ten_percent_points };
        let sample_sampling_rate = num_sample_points / (num_points as f64);
        self.storage.gen_query_warmup_data(sample_sampling_rate)?;

        self.storage.index_build_cleanup()?;

        Ok(())
    }
}
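The chunk count above is just the per-vector byte budget implied by the search-time RAM limit, clamped to [1, min(dim, MAX_PQ_CHUNKS)]: with 256 centroids per chunk, each PQ chunk costs one byte per vector. A standalone sketch of the same rule (the parameter values in the example are illustrative, not taken from this commit):

fn pq_chunks(search_ram_limit_bytes: f64, num_points: usize, dim: usize, max_pq_chunks: usize) -> usize {
    // bytes available per vector in the search-time budget
    let bytes_per_vector = (search_ram_limit_bytes / num_points as f64).floor() as usize;
    bytes_per_vector.clamp(1, dim.min(max_pq_chunks))
}

// e.g. a 10 GiB search budget over 100M 128-dim vectors:
// pq_chunks(10.0 * 1024_f64.powi(3), 100_000_000, 128, 512) == 107 bytes per vector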
9
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/index/disk_index/mod.rs
vendored
Normal file
@@ -0,0 +1,9 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod disk_index;
pub use disk_index::DiskIndex;

pub mod ann_disk_index;
@@ -0,0 +1,97 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_docs)]

//! ANN in-memory index abstraction

use vector::FullPrecisionDistance;

use crate::model::{vertex::{DIM_128, DIM_256, DIM_104}, IndexConfiguration};
use crate::common::{ANNResult, ANNError};

use super::InmemIndex;

/// ANN inmem-index abstraction for custom <T, N>
pub trait ANNInmemIndex<T>: Sync + Send
where T: Default + Copy + Sync + Send + Into<f32>
{
    /// Build index
    fn build(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()>;

    /// Save index
    fn save(&mut self, filename: &str) -> ANNResult<()>;

    /// Load index
    fn load(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<()>;

    /// Insert vectors into the index
    fn insert(&mut self, filename: &str, num_points_to_insert: usize) -> ANNResult<()>;

    /// Search the index for the K nearest neighbors of query using the given L value, for benchmarking purposes
    fn search(&self, query: &[T], k_value: usize, l_value: u32, indices: &mut [u32]) -> ANNResult<u32>;

    /// Soft-delete the nodes with the ids in the given array.
    fn soft_delete(&mut self, vertex_ids_to_delete: Vec<u32>, num_points_to_delete: usize) -> ANNResult<()>;
}

/// Create Index<T, N> based on configuration
pub fn create_inmem_index<'a, T>(config: IndexConfiguration) -> ANNResult<Box<dyn ANNInmemIndex<T> + 'a>>
where
    T: Default + Copy + Sync + Send + Into<f32> + 'a,
    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
{
    match config.aligned_dim {
        DIM_104 => {
            let index = Box::new(InmemIndex::<T, DIM_104>::new(config)?);
            Ok(index as Box<dyn ANNInmemIndex<T>>)
        },
        DIM_128 => {
            let index = Box::new(InmemIndex::<T, DIM_128>::new(config)?);
            Ok(index as Box<dyn ANNInmemIndex<T>>)
        },
        DIM_256 => {
            let index = Box::new(InmemIndex::<T, DIM_256>::new(config)?);
            Ok(index as Box<dyn ANNInmemIndex<T>>)
        },
        _ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))),
    }
}

#[cfg(test)]
mod dataset_test {
    use vector::Metric;

    use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;

    use super::*;

    #[test]
    #[should_panic(expected = "ERROR: Data file fake_file does not exist.")]
    fn create_index_test() {
        let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
            .with_alpha(1.2)
            .with_saturate_graph(false)
            .with_num_threads(1)
            .build();

        let config = IndexConfiguration::new(
            Metric::L2,
            128,
            256,
            1_000_000,
            false,
            0,
            false,
            0,
            1f32,
            index_write_parameters,
        );
        let mut index = create_inmem_index::<f32>(config).unwrap();
        index.build("fake_file", 100).unwrap();
    }
}
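A minimal sketch of how a caller might drive this factory; the point count and file names are hypothetical:

fn build_and_save(config: IndexConfiguration, data_file: &str, index_file: &str) -> ANNResult<()> {
    // Dispatches on config.aligned_dim to one of the monomorphized InmemIndex<T, N> variants.
    let mut index = create_inmem_index::<f32>(config)?;
    index.build(data_file, 1_000_000)?;
    index.save(index_file)?;
    Ok(())
}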
1033
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/index/inmem_index/inmem_index.rs
vendored
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,304 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::fs::File;
use std::io::{BufReader, BufWriter, Seek, SeekFrom, Write};
use std::path::Path;

use byteorder::{LittleEndian, ReadBytesExt};
use vector::FullPrecisionDistance;

use crate::common::{ANNError, ANNResult};
use crate::model::graph::AdjacencyList;
use crate::model::InMemoryGraph;
use crate::utils::{file_exists, save_data_in_base_dimensions};

use super::InmemIndex;

impl<T, const N: usize> InmemIndex<T, N>
where
    T: Default + Copy + Sync + Send + Into<f32>,
    [T; N]: FullPrecisionDistance<T, N>,
{
    pub fn load_graph(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<usize> {
        // let file_offset = 0; // will need this for single file format support

        let mut in_file = BufReader::new(File::open(Path::new(filename))?);
        // in_file.seek(SeekFrom::Start(file_offset as u64))?;

        let expected_file_size: usize = in_file.read_u64::<LittleEndian>()? as usize;
        self.max_observed_degree = in_file.read_u32::<LittleEndian>()?;
        self.start = in_file.read_u32::<LittleEndian>()?;
        let file_frozen_pts: usize = in_file.read_u64::<LittleEndian>()? as usize;

        let vamana_metadata_size = 24;

        println!("From graph header, expected_file_size: {}, max_observed_degree: {}, start: {}, file_frozen_pts: {}",
            expected_file_size, self.max_observed_degree, self.start, file_frozen_pts);

        if file_frozen_pts != self.configuration.num_frozen_pts {
            if file_frozen_pts == 1 {
                return Err(ANNError::log_index_config_error(
                    "num_frozen_pts".to_string(),
                    "ERROR: When loading index, detected dynamic index, but constructor asks for static index. Exiting.".to_string())
                );
            } else {
                return Err(ANNError::log_index_config_error(
                    "num_frozen_pts".to_string(),
                    "ERROR: When loading index, detected static index, but constructor asks for dynamic index. Exiting.".to_string())
                );
            }
        }

        println!("Loading vamana graph {}...", filename);

        let expected_max_points = expected_num_points - file_frozen_pts;

        // If the user provides more points than max_points,
        // resize the _final_graph to the larger size.
        if self.configuration.max_points < expected_max_points {
            println!("Number of points in data: {} is greater than max_points: {}. Setting max points to: {}", expected_max_points, self.configuration.max_points, expected_max_points);

            self.configuration.max_points = expected_max_points;
            self.final_graph = InMemoryGraph::new(
                self.configuration.max_points + self.configuration.num_frozen_pts,
                self.configuration.index_write_parameter.max_degree,
            );
        }

        let mut bytes_read = vamana_metadata_size;
        let mut num_edges = 0;
        let mut nodes_read = 0;
        let mut max_observed_degree = 0;

        while bytes_read != expected_file_size {
            let num_nbrs = in_file.read_u32::<LittleEndian>()?;
            max_observed_degree = if num_nbrs > max_observed_degree {
                num_nbrs
            } else {
                max_observed_degree
            };

            if num_nbrs == 0 {
                return Err(ANNError::log_index_error(format!(
                    "ERROR: Point found with no out-neighbors, point# {}",
                    nodes_read
                )));
            }

            num_edges += num_nbrs;
            nodes_read += 1;
            let mut tmp: Vec<u32> = Vec::with_capacity(num_nbrs as usize);
            for _ in 0..num_nbrs {
                tmp.push(in_file.read_u32::<LittleEndian>()?);
            }

            self.final_graph
                .write_vertex_and_neighbors(nodes_read - 1)?
                .set_neighbors(AdjacencyList::from(tmp));
            bytes_read += 4 * (num_nbrs as usize + 1);
        }

        println!(
            "Done. Index has {} nodes and {} out-edges, _start is set to {}",
            nodes_read, num_edges, self.start
        );

        self.max_observed_degree = max_observed_degree;
        Ok(nodes_read as usize)
    }

    /// Save the graph index on a file as an adjacency list.
    /// For each point, first store the number of neighbors,
    /// and then the neighbor list (each as a 4-byte u32).
    pub fn save_graph(&mut self, graph_file: &str) -> ANNResult<u64> {
        let file: File = File::create(graph_file)?;
        let mut out = BufWriter::new(file);

        let file_offset: u64 = 0;
        out.seek(SeekFrom::Start(file_offset))?;
        let mut index_size: u64 = 24;
        let mut max_degree: u32 = 0;
        out.write_all(&index_size.to_le_bytes())?;
        out.write_all(&self.max_observed_degree.to_le_bytes())?;
        out.write_all(&self.start.to_le_bytes())?;
        out.write_all(&(self.configuration.num_frozen_pts as u64).to_le_bytes())?;

        // At this point, either nd == max_points or any frozen points have
        // been temporarily moved to nd, so nd + num_frozen_points is the valid
        // location limit.
        for i in 0..self.num_active_pts + self.configuration.num_frozen_pts {
            let idx = i as u32;
            let gk: u32 = self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32;
            out.write_all(&gk.to_le_bytes())?;
            for neighbor in self
                .final_graph
                .read_vertex_and_neighbors(idx)?
                .get_neighbors()
                .iter()
            {
                out.write_all(&neighbor.to_le_bytes())?;
            }
            max_degree =
                if self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32 > max_degree {
                    self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32
                } else {
                    max_degree
                };
            index_size += (std::mem::size_of::<u32>() * (gk as usize + 1)) as u64;
        }
        out.seek(SeekFrom::Start(file_offset))?;
        out.write_all(&index_size.to_le_bytes())?;
        out.write_all(&max_degree.to_le_bytes())?;
        out.flush()?;
        Ok(index_size)
    }

    /// Save the data on a file.
    pub fn save_data(&mut self, data_file: &str) -> ANNResult<usize> {
        // Note: at this point, either _nd == _max_points or any frozen points have
        // been temporarily moved to _nd, so _nd + _num_frozen_points is the valid
        // location limit.
        Ok(save_data_in_base_dimensions(
            data_file,
            &mut self.dataset.data,
            self.num_active_pts + self.configuration.num_frozen_pts,
            self.configuration.dim,
            self.configuration.aligned_dim,
            0,
        )?)
    }

    /// Save the delete list to a file only if the delete list length is not zero.
    pub fn save_delete_list(&mut self, delete_list_file: &str) -> ANNResult<usize> {
        let mut delete_file_size = 0;
        if let Ok(delete_set) = self.delete_set.read() {
            let delete_set_len = delete_set.len() as u32;

            if delete_set_len != 0 {
                let file: File = File::create(delete_list_file)?;
                let mut writer = BufWriter::new(file);

                // Write the length of the set.
                writer.write_all(&delete_set_len.to_le_bytes())?;
                delete_file_size += std::mem::size_of::<u32>();

                // Write the elements of the set, little-endian to match load_delete_list.
                for &item in delete_set.iter() {
                    writer.write_all(&item.to_le_bytes())?;
                    delete_file_size += std::mem::size_of::<u32>();
                }

                writer.flush()?;
            }
        } else {
            return Err(ANNError::log_lock_poison_error(
                "Poisoned lock on delete set. Can't save deleted list.".to_string(),
            ));
        }

        Ok(delete_file_size)
    }

    /// Load the delete list from the delete file if it exists.
    pub fn load_delete_list(&mut self, delete_list_file: &str) -> ANNResult<usize> {
        let mut len = 0;

        if file_exists(delete_list_file) {
            let file = File::open(delete_list_file)?;
            let mut reader = BufReader::new(file);

            len = reader.read_u32::<LittleEndian>()? as usize;

            if let Ok(mut delete_set) = self.delete_set.write() {
                for _ in 0..len {
                    let item = reader.read_u32::<LittleEndian>()?;
                    delete_set.insert(item);
                }
            } else {
                return Err(ANNError::log_lock_poison_error(
                    "Poisoned lock on delete set. Can't load deleted list.".to_string(),
                ));
            }
        }

        Ok(len)
    }
}

#[cfg(test)]
mod index_test {
    use std::fs;

    use vector::Metric;

    use super::*;
    use crate::{
        index::ANNInmemIndex,
        model::{
            configuration::index_write_parameters::IndexWriteParametersBuilder, vertex::DIM_128,
            IndexConfiguration,
        },
        utils::{load_metadata_from_file, round_up},
    };

    const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
    const R: u32 = 4;
    const L: u32 = 50;
    const ALPHA: f32 = 1.2;

    #[cfg_attr(not(coverage), test)]
    fn save_graph_test() {
        let parameters = IndexWriteParametersBuilder::new(50, 4)
            .with_alpha(1.2)
            .build();
        let config =
            IndexConfiguration::new(Metric::L2, 10, 16, 16, false, 0, false, 8, 1f32, parameters);
        let mut index = InmemIndex::<f32, 3>::new(config).unwrap();
        let final_graph = InMemoryGraph::new(10, 3);
        let num_active_pts = 2_usize;
        index.final_graph = final_graph;
        index.num_active_pts = num_active_pts;
        let graph_file = "test_save_graph_data.bin";
        let result = index.save_graph(graph_file);
        assert!(result.is_ok());

        fs::remove_file(graph_file).expect("Failed to delete file");
    }

    #[test]
    fn save_data_test() {
        let (data_num, dim) = load_metadata_from_file(TEST_DATA_FILE).unwrap();

        let index_write_parameters = IndexWriteParametersBuilder::new(L, R)
            .with_alpha(ALPHA)
            .build();
        let config = IndexConfiguration::new(
            Metric::L2,
            dim,
            round_up(dim as u64, 16_u64) as usize,
            data_num,
            false,
            0,
            false,
            0,
            1f32,
            index_write_parameters,
        );
        let mut index: InmemIndex<f32, DIM_128> = InmemIndex::new(config).unwrap();

        index.build(TEST_DATA_FILE, data_num).unwrap();

        let data_file = "test.data";
        let result = index.save_data(data_file);
        assert_eq!(
            result.unwrap(),
            2 * std::mem::size_of::<u32>()
                + (index.num_active_pts + index.configuration.num_frozen_pts)
                    * index.configuration.dim
                    * (std::mem::size_of::<f32>())
        );
        fs::remove_file(data_file).expect("Failed to delete file");
    }
}
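For reference, the on-disk layout that load_graph parses is: a 24-byte little-endian header (u64 file size, u32 max observed degree, u32 start vertex, u64 frozen-point count), followed per node by a u32 neighbor count and that many u32 ids. A minimal sketch that writes a two-node graph in this layout (the file name is hypothetical):

use byteorder::{LittleEndian, WriteBytesExt};
use std::io::Write;

fn write_tiny_vamana_graph() -> std::io::Result<()> {
    let mut buf: Vec<u8> = Vec::new();
    // Two nodes pointing at each other: 24-byte header + 2 * (4 + 4) bytes of adjacency data.
    buf.write_u64::<LittleEndian>(24 + 16)?; // expected_file_size
    buf.write_u32::<LittleEndian>(1)?;       // max_observed_degree
    buf.write_u32::<LittleEndian>(0)?;       // start vertex
    buf.write_u64::<LittleEndian>(0)?;       // file_frozen_pts
    for neighbor in [1u32, 0u32] {
        buf.write_u32::<LittleEndian>(1)?;        // num_nbrs
        buf.write_u32::<LittleEndian>(neighbor)?; // neighbor id
    }
    std::fs::File::create("tiny_graph.index")?.write_all(&buf)
}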
12
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/index/inmem_index/mod.rs
vendored
Normal file
@@ -0,0 +1,12 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod inmem_index;
pub use inmem_index::InmemIndex;

mod inmem_index_storage;

pub mod ann_inmem_index;
11
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/index/mod.rs
vendored
Normal file
@@ -0,0 +1,11 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
mod inmem_index;
pub use inmem_index::ann_inmem_index::*;
pub use inmem_index::InmemIndex;

mod disk_index;
pub use disk_index::*;
@@ -0,0 +1,57 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use logger::logger::indexlog::DiskIndexConstructionCheckpoint;
use logger::logger::indexlog::DiskIndexConstructionLog;
use logger::logger::indexlog::Log;
use logger::logger::indexlog::LogLevel;
use logger::message_handler::send_log;

use crate::{utils::Timer, common::ANNResult};

pub struct DiskIndexBuildLogger {
    timer: Timer,
    checkpoint: DiskIndexConstructionCheckpoint,
}

impl DiskIndexBuildLogger {
    pub fn new(checkpoint: DiskIndexConstructionCheckpoint) -> Self {
        Self {
            timer: Timer::new(),
            checkpoint,
        }
    }

    pub fn log_checkpoint(&mut self, next_checkpoint: DiskIndexConstructionCheckpoint) -> ANNResult<()> {
        if self.checkpoint == DiskIndexConstructionCheckpoint::None {
            return Ok(());
        }

        let mut log = Log::default();
        let disk_index_construction_log = DiskIndexConstructionLog {
            checkpoint: self.checkpoint as i32,
            time_spent_in_seconds: self.timer.elapsed().as_secs_f32(),
            g_cycles_spent: self.timer.elapsed_gcycles(),
            log_level: LogLevel::Info as i32,
        };
        log.disk_index_construction_log = Some(disk_index_construction_log);

        send_log(log)?;
        self.checkpoint = next_checkpoint;
        self.timer.reset();
        Ok(())
    }
}

#[cfg(test)]
mod dataset_test {
    use super::*;

    #[test]
    fn test_log() {
        let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction);
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild).unwrap();
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::DiskLayout).unwrap();
    }
}
@@ -0,0 +1,47 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::sync::atomic::{AtomicUsize, Ordering};

use logger::logger::indexlog::IndexConstructionLog;
use logger::logger::indexlog::Log;
use logger::logger::indexlog::LogLevel;
use logger::message_handler::send_log;

use crate::common::ANNResult;
use crate::utils::Timer;

pub struct IndexLogger {
    items_processed: AtomicUsize,
    timer: Timer,
    range: usize,
}

impl IndexLogger {
    pub fn new(range: usize) -> Self {
        Self {
            items_processed: AtomicUsize::new(0),
            timer: Timer::new(),
            range,
        }
    }

    pub fn vertex_processed(&self) -> ANNResult<()> {
        let count = self.items_processed.fetch_add(1, Ordering::Relaxed);
        if count % 100_000 == 0 {
            let mut log = Log::default();
            let index_construction_log = IndexConstructionLog {
                percentage_complete: (100_f32 * count as f32) / (self.range as f32),
                time_spent_in_seconds: self.timer.elapsed().as_secs_f32(),
                g_cycles_spent: self.timer.elapsed_gcycles(),
                log_level: LogLevel::Info as i32,
            };
            log.index_construction_log = Some(index_construction_log);

            send_log(log)?;
        }

        Ok(())
    }
}
9
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/instrumentation/mod.rs
vendored
Normal file
@@ -0,0 +1,9 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
mod index_logger;
pub use index_logger::IndexLogger;

mod disk_index_build_logger;
pub use disk_index_build_logger::DiskIndexBuildLogger;
26
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/lib.rs
vendored
Normal file
@@ -0,0 +1,26 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![cfg_attr(
    not(test),
    warn(clippy::panic, clippy::unwrap_used, clippy::expect_used)
)]
#![cfg_attr(test, allow(clippy::unused_io_amount))]

pub mod utils;

pub mod algorithm;

pub mod model;

pub mod common;

pub mod index;

pub mod storage;

pub mod instrumentation;

#[cfg(test)]
pub mod test_utils;
@@ -0,0 +1,85 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Parameters for disk index construction.

use crate::common::{ANNResult, ANNError};

/// Space reserved for cached nodes, in GB
const SPACE_FOR_CACHED_NODES_IN_GB: f64 = 0.25;

/// Threshold for caching, in GB
const THRESHOLD_FOR_CACHING_IN_GB: f64 = 1.0;

/// Parameters specific to disk index construction.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct DiskIndexBuildParameters {
    /// Bound on the memory footprint of the index at search time, in bytes.
    /// Once built, the index will use only the specified RAM limit; the rest resides on disk.
    /// This dictates how aggressively we compress the data vectors to store in memory.
    /// Larger values yield better performance at search time.
    search_ram_limit: f64,

    /// Limit on the memory allowed for building the index, in bytes.
    index_build_ram_limit: f64,
}

impl DiskIndexBuildParameters {
    /// Create a DiskIndexBuildParameters instance
    pub fn new(search_ram_limit_gb: f64, index_build_ram_limit_gb: f64) -> ANNResult<Self> {
        let param = Self {
            search_ram_limit: Self::get_memory_budget(search_ram_limit_gb),
            index_build_ram_limit: index_build_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64,
        };

        if param.search_ram_limit <= 0f64 {
            return Err(ANNError::log_index_config_error("search_ram_limit".to_string(), "RAM budget should be > 0".to_string()));
        }

        if param.index_build_ram_limit <= 0f64 {
            return Err(ANNError::log_index_config_error("index_build_ram_limit".to_string(), "RAM budget should be > 0".to_string()));
        }

        Ok(param)
    }

    /// Get search_ram_limit
    pub fn search_ram_limit(&self) -> f64 {
        self.search_ram_limit
    }

    /// Get index_build_ram_limit
    pub fn index_build_ram_limit(&self) -> f64 {
        self.index_build_ram_limit
    }

    fn get_memory_budget(mut index_ram_limit_gb: f64) -> f64 {
        if index_ram_limit_gb - SPACE_FOR_CACHED_NODES_IN_GB > THRESHOLD_FOR_CACHING_IN_GB {
            // slack for space used by cached nodes
            index_ram_limit_gb -= SPACE_FOR_CACHED_NODES_IN_GB;
        }

        index_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64
    }
}

#[cfg(test)]
mod dataset_test {
    use super::*;

    #[test]
    fn sufficient_ram_for_caching() {
        let param = DiskIndexBuildParameters::new(1.26_f64, 1.0_f64).unwrap();
        assert_eq!(param.search_ram_limit, 1.01_f64 * 1024_f64 * 1024_f64 * 1024_f64);
    }

    #[test]
    fn insufficient_ram_for_caching() {
        let param = DiskIndexBuildParameters::new(0.03_f64, 1.0_f64).unwrap();
        assert_eq!(param.search_ram_limit, 0.03_f64 * 1024_f64 * 1024_f64 * 1024_f64);
    }
}
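The two tests above pin down get_memory_budget: a budget whose remainder after subtracting SPACE_FOR_CACHED_NODES_IN_GB still exceeds THRESHOLD_FOR_CACHING_IN_GB donates that 0.25 GB to cached nodes (1.26 GB becomes 1.01 GB), while a small budget passes through unchanged (0.03 GB stays 0.03 GB); in both cases the result is then converted to bytes.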
@@ -0,0 +1,92 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Index configuration.

use vector::Metric;

use super::index_write_parameters::IndexWriteParameters;

/// The index configuration
#[derive(Debug, Clone)]
pub struct IndexConfiguration {
    /// Index write parameter
    pub index_write_parameter: IndexWriteParameters,

    /// Distance metric
    pub dist_metric: Metric,

    /// Dimension of the raw data
    pub dim: usize,

    /// Aligned dimension - round up dim to the nearest multiple of 8
    pub aligned_dim: usize,

    /// Total number of points in given data set
    pub max_points: usize,

    /// Number of points which are used as initial candidates when iterating to
    /// closest point(s). These are not visible externally and won't be returned
    /// by search. DiskANN forces at least 1 frozen point for dynamic index.
    /// The frozen points have consecutive locations.
    pub num_frozen_pts: usize,

    /// Calculate distance by PQ or not
    pub use_pq_dist: bool,

    /// Number of PQ chunks
    pub num_pq_chunks: usize,

    /// Use optimized product quantization
    /// Currently not supported
    pub use_opq: bool,

    /// Potential for growth. 1.2 means the index can grow by up to 20%.
    pub growth_potential: f32,

    // TODO: below settings are not supported in current iteration
    // pub concurrent_consolidate: bool,
    // pub has_built: bool,
    // pub save_as_one_file: bool,
    // pub dynamic_index: bool,
    // pub enable_tags: bool,
    // pub normalize_vecs: bool,
}

impl IndexConfiguration {
    /// Create IndexConfiguration instance
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        dist_metric: Metric,
        dim: usize,
        aligned_dim: usize,
        max_points: usize,
        use_pq_dist: bool,
        num_pq_chunks: usize,
        use_opq: bool,
        num_frozen_pts: usize,
        growth_potential: f32,
        index_write_parameter: IndexWriteParameters,
    ) -> Self {
        Self {
            index_write_parameter,
            dist_metric,
            dim,
            aligned_dim,
            max_points,
            num_frozen_pts,
            use_pq_dist,
            num_pq_chunks,
            use_opq,
            growth_potential,
        }
    }

    /// Get the size of the adjacency list that we build out.
    pub fn write_range(&self) -> usize {
        self.index_write_parameter.max_degree as usize
    }
}
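Because new takes ten positional arguments, call sites elsewhere in this commit (e.g. the tests) are easy to misread; an annotated sketch of the same call made in ann_inmem_index.rs:

let config = IndexConfiguration::new(
    Metric::L2,              // dist_metric
    128,                     // dim
    256,                     // aligned_dim
    1_000_000,               // max_points
    false,                   // use_pq_dist
    0,                       // num_pq_chunks
    false,                   // use_opq
    0,                       // num_frozen_pts
    1f32,                    // growth_potential (1.0 = no growth)
    index_write_parameters,  // index_write_parameter
);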
@@ -0,0 +1,245 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Index write parameters.

/// Default parameter values.
pub mod default_param_vals {
    /// Default value of alpha.
    pub const ALPHA: f32 = 1.2;

    /// Default value of number of threads.
    pub const NUM_THREADS: u32 = 0;

    /// Default value of number of rounds.
    pub const NUM_ROUNDS: u32 = 2;

    /// Default value of max occlusion size.
    pub const MAX_OCCLUSION_SIZE: u32 = 750;

    /// Default value of filter list size.
    pub const FILTER_LIST_SIZE: u32 = 0;

    /// Default value of number of frozen points.
    pub const NUM_FROZEN_POINTS: u32 = 0;

    /// Default value of max degree.
    pub const MAX_DEGREE: u32 = 64;

    /// Default value of build list size.
    pub const BUILD_LIST_SIZE: u32 = 100;

    /// Default value of saturate graph.
    pub const SATURATE_GRAPH: bool = false;

    /// Default value of search list size.
    pub const SEARCH_LIST_SIZE: u32 = 100;
}

/// Index write parameters.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct IndexWriteParameters {
    /// Search list size - L.
    pub search_list_size: u32,

    /// Max degree - R.
    pub max_degree: u32,

    /// Saturate graph.
    pub saturate_graph: bool,

    /// Max occlusion size - C.
    pub max_occlusion_size: u32,

    /// Alpha.
    pub alpha: f32,

    /// Number of rounds.
    pub num_rounds: u32,

    /// Number of threads.
    pub num_threads: u32,

    /// Number of frozen points.
    pub num_frozen_points: u32,
}

impl Default for IndexWriteParameters {
    /// Create IndexWriteParameters with default values
    fn default() -> Self {
        Self {
            search_list_size: default_param_vals::SEARCH_LIST_SIZE,
            max_degree: default_param_vals::MAX_DEGREE,
            saturate_graph: default_param_vals::SATURATE_GRAPH,
            max_occlusion_size: default_param_vals::MAX_OCCLUSION_SIZE,
            alpha: default_param_vals::ALPHA,
            num_rounds: default_param_vals::NUM_ROUNDS,
            num_threads: default_param_vals::NUM_THREADS,
            num_frozen_points: default_param_vals::NUM_FROZEN_POINTS,
        }
    }
}

/// The builder for IndexWriteParameters.
#[derive(Debug)]
pub struct IndexWriteParametersBuilder {
    search_list_size: u32,
    max_degree: u32,
    max_occlusion_size: Option<u32>,
    saturate_graph: Option<bool>,
    alpha: Option<f32>,
    num_rounds: Option<u32>,
    num_threads: Option<u32>,
    // filter_list_size: Option<u32>,
    num_frozen_points: Option<u32>,
}

impl IndexWriteParametersBuilder {
    /// Initialize IndexWriteParametersBuilder
    pub fn new(search_list_size: u32, max_degree: u32) -> Self {
        Self {
            search_list_size,
            max_degree,
            max_occlusion_size: None,
            saturate_graph: None,
            alpha: None,
            num_rounds: None,
            num_threads: None,
            // filter_list_size: None,
            num_frozen_points: None,
        }
    }

    /// Set max occlusion size.
    pub fn with_max_occlusion_size(mut self, max_occlusion_size: u32) -> Self {
        self.max_occlusion_size = Some(max_occlusion_size);
        self
    }

    /// Set saturate graph.
    pub fn with_saturate_graph(mut self, saturate_graph: bool) -> Self {
        self.saturate_graph = Some(saturate_graph);
        self
    }

    /// Set alpha.
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Set number of rounds.
    pub fn with_num_rounds(mut self, num_rounds: u32) -> Self {
        self.num_rounds = Some(num_rounds);
        self
    }

    /// Set number of threads.
    pub fn with_num_threads(mut self, num_threads: u32) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /*
    pub fn with_filter_list_size(mut self, filter_list_size: u32) -> Self {
        self.filter_list_size = Some(filter_list_size);
        self
    }
    */

    /// Set number of frozen points.
    pub fn with_num_frozen_points(mut self, num_frozen_points: u32) -> Self {
        self.num_frozen_points = Some(num_frozen_points);
        self
    }

    /// Build IndexWriteParameters from IndexWriteParametersBuilder.
    pub fn build(self) -> IndexWriteParameters {
        IndexWriteParameters {
            search_list_size: self.search_list_size,
            max_degree: self.max_degree,
            saturate_graph: self.saturate_graph.unwrap_or(default_param_vals::SATURATE_GRAPH),
            max_occlusion_size: self.max_occlusion_size.unwrap_or(default_param_vals::MAX_OCCLUSION_SIZE),
            alpha: self.alpha.unwrap_or(default_param_vals::ALPHA),
            num_rounds: self.num_rounds.unwrap_or(default_param_vals::NUM_ROUNDS),
            num_threads: self.num_threads.unwrap_or(default_param_vals::NUM_THREADS),
            // filter_list_size: self.filter_list_size.unwrap_or(default_param_vals::FILTER_LIST_SIZE),
            num_frozen_points: self.num_frozen_points.unwrap_or(default_param_vals::NUM_FROZEN_POINTS),
        }
    }
}

/// Construct IndexWriteParametersBuilder from IndexWriteParameters.
impl From<IndexWriteParameters> for IndexWriteParametersBuilder {
    fn from(param: IndexWriteParameters) -> Self {
        Self {
            search_list_size: param.search_list_size,
            max_degree: param.max_degree,
            max_occlusion_size: Some(param.max_occlusion_size),
            saturate_graph: Some(param.saturate_graph),
            alpha: Some(param.alpha),
            num_rounds: Some(param.num_rounds),
            num_threads: Some(param.num_threads),
            // filter_list_size: Some(param.filter_list_size),
            num_frozen_points: Some(param.num_frozen_points),
        }
    }
}

#[cfg(test)]
mod parameters_test {
    use crate::model::configuration::index_write_parameters::*;

    #[test]
    fn test_default_index_params() {
        let wp1 = IndexWriteParameters::default();
        assert_eq!(wp1.search_list_size, default_param_vals::SEARCH_LIST_SIZE);
        assert_eq!(wp1.max_degree, default_param_vals::MAX_DEGREE);
        assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH);
        assert_eq!(wp1.max_occlusion_size, default_param_vals::MAX_OCCLUSION_SIZE);
        assert_eq!(wp1.alpha, default_param_vals::ALPHA);
        assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS);
        assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS);
        assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS);
    }

    #[test]
    fn test_index_write_parameters_builder() {
        // default values
        let wp1 = IndexWriteParametersBuilder::new(10, 20).build();
        assert_eq!(wp1.search_list_size, 10);
        assert_eq!(wp1.max_degree, 20);
        assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH);
        assert_eq!(wp1.max_occlusion_size, default_param_vals::MAX_OCCLUSION_SIZE);
        assert_eq!(wp1.alpha, default_param_vals::ALPHA);
        assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS);
        assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS);
        assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS);

        // build with custom values
        let wp2 = IndexWriteParametersBuilder::new(10, 20)
            .with_max_occlusion_size(30)
            .with_saturate_graph(true)
            .with_alpha(0.5)
            .with_num_rounds(40)
            .with_num_threads(50)
            .with_num_frozen_points(60)
            .build();
        assert_eq!(wp2.search_list_size, 10);
        assert_eq!(wp2.max_degree, 20);
        assert!(wp2.saturate_graph);
        assert_eq!(wp2.max_occlusion_size, 30);
        assert_eq!(wp2.alpha, 0.5);
        assert_eq!(wp2.num_rounds, 40);
        assert_eq!(wp2.num_threads, 50);
        assert_eq!(wp2.num_frozen_points, 60);

        // test From
        let wp3 = IndexWriteParametersBuilder::from(wp2).build();
        assert_eq!(wp3, wp2);
    }
}
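The From<IndexWriteParameters> impl above also makes the builder usable for copy-with-one-change edits; a small sketch:

let base = IndexWriteParameters::default();
let tweaked = IndexWriteParametersBuilder::from(base)
    .with_num_threads(8) // override one field, keep the rest
    .build();
assert_eq!(tweaked.max_degree, base.max_degree);
assert_eq!(tweaked.num_threads, 8);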
12
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/configuration/mod.rs
vendored
Normal file
@@ -0,0 +1,12 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub mod index_configuration;
pub use index_configuration::IndexConfiguration;

pub mod index_write_parameters;
pub use index_write_parameters::*;

pub mod disk_index_build_parameter;
pub use disk_index_build_parameter::DiskIndexBuildParameters;
@@ -0,0 +1,76 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Disk scratch dataset

use std::mem::{size_of, size_of_val};
use std::ptr;

use crate::common::{AlignedBoxWithSlice, ANNResult};
use crate::model::MAX_N_CMPS;
use crate::utils::round_up;

/// DiskScratchDataset alignment
pub const DISK_SCRATCH_DATASET_ALIGN: usize = 256;

/// Disk scratch dataset storing fp vectors with aligned dim
#[derive(Debug)]
pub struct DiskScratchDataset<T, const N: usize> {
    /// fp vectors with aligned dim
    pub data: AlignedBoxWithSlice<T>,

    /// current index to store the next fp vector
    pub cur_index: usize,
}

impl<T, const N: usize> DiskScratchDataset<T, N> {
    /// Create DiskScratchDataset instance
    pub fn new() -> ANNResult<Self> {
        Ok(Self {
            // C++ code allocates round_up(MAX_N_CMPS * N, 256) bytes; shouldn't it be round_up(MAX_N_CMPS * N, 256) * size_of::<T>() bytes?
            data: AlignedBoxWithSlice::new(
                round_up(MAX_N_CMPS * N, DISK_SCRATCH_DATASET_ALIGN),
                DISK_SCRATCH_DATASET_ALIGN)?,
            cur_index: 0,
        })
    }

    /// memcpy from fp vector bytes (its len should be `dim * size_of::<T>()`) to self.data.
    /// The dest slice is a fp vector with aligned dim.
    /// * fp_vector_buf's dim might not be the aligned dim (N)
    /// # Safety
    /// Behavior is undefined if any of the following conditions are violated:
    ///
    /// * `fp_vector_buf`'s len must be `dim * size_of::<T>()` bytes
    ///
    /// * `fp_vector_buf` must be smaller than or equal to `N * size_of::<T>()` bytes.
    ///
    /// * `fp_vector_buf` and `self.data` must be nonoverlapping.
    pub unsafe fn memcpy_from_fp_vector_buf(&mut self, fp_vector_buf: &[u8]) -> &[T] {
        if self.cur_index == MAX_N_CMPS {
            self.cur_index = 0;
        }

        let aligned_dim_vector = &mut self.data[self.cur_index * N..(self.cur_index + 1) * N];

        assert!(fp_vector_buf.len() % size_of::<T>() == 0);
        assert!(fp_vector_buf.len() <= size_of_val(aligned_dim_vector));

        // memcpy from fp_vector_buf to aligned_dim_vector
        unsafe {
            ptr::copy_nonoverlapping(
                fp_vector_buf.as_ptr(),
                aligned_dim_vector.as_mut_ptr() as *mut u8,
                fp_vector_buf.len(),
            );
        }

        self.cur_index += 1;
        aligned_dim_vector
    }
}
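memcpy_from_fp_vector_buf behaves as a ring buffer over MAX_N_CMPS slots: cur_index wraps to 0 when full, so a returned slice is only meaningful until its slot is reused. A hedged usage sketch (the element type and dims are assumptions for illustration):

fn scratch_demo() -> ANNResult<()> {
    // Hypothetical: copy a 104-float raw buffer into a scratch slot padded to N = 128.
    let mut scratch = DiskScratchDataset::<f32, 128>::new()?;
    let raw = vec![0u8; 104 * std::mem::size_of::<f32>()];
    // Safety: len is a multiple of size_of::<f32>(), fits within N floats,
    // and does not overlap scratch.data, i.e. the documented conditions hold.
    let padded: &[f32] = unsafe { scratch.memcpy_from_fp_vector_buf(&raw) };
    assert_eq!(padded.len(), 128);
    Ok(())
}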
@@ -0,0 +1,285 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! In-memory Dataset

use rayon::prelude::*;
use std::mem;
use vector::{FullPrecisionDistance, Metric};

use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice};
use crate::model::Vertex;
use crate::utils::copy_aligned_data_from_file;

/// Dataset of all in-memory FP points
#[derive(Debug)]
pub struct InmemDataset<T, const N: usize>
where
    [T; N]: FullPrecisionDistance<T, N>,
{
    /// All in-memory points
    pub data: AlignedBoxWithSlice<T>,

    /// Number of points we anticipate to have
    pub num_points: usize,

    /// Number of active points, i.e. existing in the graph
    pub num_active_pts: usize,

    /// Capacity of the dataset
    pub capacity: usize,
}

impl<'a, T, const N: usize> InmemDataset<T, N>
where
    T: Default + Copy + Sync + Send + Into<f32>,
    [T; N]: FullPrecisionDistance<T, N>,
{
    /// Create the dataset with size num_points and a growth factor.
    /// growth factor = 1 means no growth (provision 100% of num_points);
    /// growth factor = 1.2 means provision 120% of num_points (20% extra space).
    pub fn new(num_points: usize, index_growth_factor: f32) -> ANNResult<Self> {
        let capacity = (((num_points * N) as f32) * index_growth_factor) as usize;

        Ok(Self {
            data: AlignedBoxWithSlice::new(capacity, mem::size_of::<T>() * 16)?,
            num_points,
            num_active_pts: num_points,
            capacity,
        })
    }

    /// Get immutable data slice
    pub fn get_data(&self) -> &[T] {
        &self.data
    }

    /// Build the dataset from file
    pub fn build_from_file(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()> {
        println!(
            "Loading {} vectors from file {} into dataset...",
            num_points_to_load, filename
        );
        self.num_active_pts = num_points_to_load;

        copy_aligned_data_from_file(filename, self.into_dto(), 0)?;

        println!("Dataset loaded.");
        Ok(())
    }

    /// Append to the dataset from file
    pub fn append_from_file(
        &mut self,
        filename: &str,
        num_points_to_append: usize,
    ) -> ANNResult<()> {
        println!(
            "Appending {} vectors from file {} into dataset...",
            num_points_to_append, filename
        );
        if self.num_points + num_points_to_append > self.capacity {
            return Err(ANNError::log_index_error(format!(
                "Cannot append {} points to dataset of capacity {}",
                num_points_to_append, self.capacity
            )));
        }

        let pts_offset = self.num_active_pts;
        copy_aligned_data_from_file(filename, self.into_dto(), pts_offset)?;

        self.num_active_pts += num_points_to_append;
        self.num_points += num_points_to_append;

        println!("Dataset appended.");
        Ok(())
    }

    /// Get vertex by id
    pub fn get_vertex(&'a self, id: u32) -> ANNResult<Vertex<'a, T, N>> {
        let start = id as usize * N;
        let end = start + N;

        if end <= self.data.len() {
            let val = <&[T; N]>::try_from(&self.data[start..end]).map_err(|err| {
                ANNError::log_index_error(format!("Failed to get vertex {}, err={}", id, err))
            })?;
            Ok(Vertex::new(val, id))
        } else {
            Err(ANNError::log_index_error(format!(
                "Invalid vertex id {}.",
                id
            )))
        }
    }

    /// Get full precision distance between two nodes
    pub fn get_distance(&self, id1: u32, id2: u32, metric: Metric) -> ANNResult<f32> {
        let vertex1 = self.get_vertex(id1)?;
        let vertex2 = self.get_vertex(id2)?;

        Ok(vertex1.compare(&vertex2, metric))
    }

    /// Find the medoid, the vertex in the dataset that is closest to the centroid
    pub fn calculate_medoid_point_id(&self) -> ANNResult<u32> {
        Ok(self.find_nearest_point_id(self.calculate_centroid_point()?))
    }

    /// Calculate the centroid, the average of all vertices in the dataset
    fn calculate_centroid_point(&self) -> ANNResult<[f32; N]> {
        // Allocate and initialize the centroid vector
        let mut center: [f32; N] = [0.0; N];

        // Sum the data points' components
        for i in 0..self.num_active_pts {
            let vertex = self.get_vertex(i as u32)?;
            let vertex_slice = vertex.vector();
            for j in 0..N {
                center[j] += vertex_slice[j].into();
            }
        }

        // Divide by the number of points to calculate the centroid
        let capacity = self.num_active_pts as f32;
        for item in center.iter_mut().take(N) {
            *item /= capacity;
        }

        Ok(center)
    }

    /// Find the vertex closest to the given point
    fn find_nearest_point_id(&self, point: [f32; N]) -> u32 {
        // compute all-to-one distances
        let mut distances = vec![0f32; self.num_active_pts];
        let slice = &self.data[..];
        distances.par_iter_mut().enumerate().for_each(|(i, dist)| {
            let start = i * N;
            for j in 0..N {
                let diff: f32 = (point.as_slice()[j] - slice[start + j].into())
                    * (point.as_slice()[j] - slice[start + j].into());
                *dist += diff;
            }
        });

        let mut min_idx = 0;
        let mut min_dist = f32::MAX;
        for (i, distance) in distances.iter().enumerate().take(self.num_active_pts) {
            if *distance < min_dist {
                min_idx = i;
                min_dist = *distance;
            }
        }
        min_idx as u32
    }

    /// Prefetch vertex data in the memory hierarchy
    /// NOTE: good efficiency when total_vec_size is an integral multiple of 64
    #[inline]
    pub fn prefetch_vector(&self, id: u32) {
        let start = id as usize * N;
        let end = start + N;

        if end <= self.data.len() {
            let vec = &self.data[start..end];
            vector::prefetch_vector(vec);
        }
    }

    /// Convert into dto object
    pub fn into_dto(&mut self) -> DatasetDto<T> {
        DatasetDto {
            data: &mut self.data,
            rounded_dim: N,
        }
    }
}

/// Dataset dto used by other layers, such as storage.
/// N is the aligned dimension.
#[derive(Debug)]
pub struct DatasetDto<'a, T> {
    /// data slice borrowed from dataset
    pub data: &'a mut [T],

    /// rounded dimension
    pub rounded_dim: usize,
}

#[cfg(test)]
mod dataset_test {
    use std::fs;

    use super::*;
    use crate::model::vertex::DIM_128;

    #[test]
    fn get_vertex_within_range() {
        let num_points = 1_000_000;
        let id = 999_999;
        let dataset = InmemDataset::<f32, DIM_128>::new(num_points, 1f32).unwrap();

        let vertex = dataset.get_vertex(999_999).unwrap();

        assert_eq!(vertex.vertex_id(), id);
        assert_eq!(vertex.vector().len(), DIM_128);
        assert_eq!(vertex.vector().as_ptr(), unsafe {
            dataset.data.as_ptr().add((id as usize) * DIM_128)
        });
    }

    #[test]
    fn get_vertex_out_of_range() {
        let num_points = 1_000_000;
        let invalid_id = 1_000_000;
        let dataset = InmemDataset::<f32, DIM_128>::new(num_points, 1f32).unwrap();

        if dataset.get_vertex(invalid_id).is_ok() {
            panic!("id ({}) should be out of range", invalid_id)
        };
    }

    #[test]
    fn load_data_test() {
        let file_name = "dataset_test_load_data_test.bin";
        // npoints=2, dim=8: two vectors [1.0, 2.0, ..., 8.0] and [9.0, 10.0, ..., 16.0]
        let data: [u8; 72] = [
            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
        ];
        std::fs::write(file_name, data).expect("Failed to write sample file");

        let mut dataset = InmemDataset::<f32, 8>::new(2, 1f32).unwrap();

        match copy_aligned_data_from_file(
            file_name,
            dataset.into_dto(),
            0,
        ) {
            Ok((npts, dim)) => {
                fs::remove_file(file_name).expect("Failed to delete file");
                assert!(npts == 2);
                assert!(dim == 8);
                assert!(dataset.data.len() == 16);

                let first_vertex = dataset.get_vertex(0).unwrap();
                let second_vertex = dataset.get_vertex(1).unwrap();

                assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
                assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);
            }
            Err(e) => {
                fs::remove_file(file_name).expect("Failed to delete file");
                panic!("{}", e)
            }
        }
    }
}
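As a concrete check of the centroid/medoid path above, using the two vectors from load_data_test:

// centroid = ([1..=8] + [9..=16]) / 2 = [5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
// squared distance from the centroid to either vector = 8 * 4^2 = 128, a tie,
// and the strict `<` in find_nearest_point_id keeps the earlier point,
// so calculate_medoid_point_id would return 0 here.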
11
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/data_store/mod.rs
vendored
Normal file
@@ -0,0 +1,11 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod inmem_dataset;
pub use inmem_dataset::InmemDataset;
pub use inmem_dataset::DatasetDto;

mod disk_scratch_dataset;
pub use disk_scratch_dataset::*;
@@ -0,0 +1,64 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Adjacency List

use std::ops::{Deref, DerefMut};

#[derive(Debug, Eq, PartialEq)]
/// Represents the out neighbors of a vertex
pub struct AdjacencyList {
    edges: Vec<u32>,
}

/// In-mem index related limits
const GRAPH_SLACK_FACTOR: f32 = 1.3_f32;

impl AdjacencyList {
    /// Create AdjacencyList with capacity slack for a range.
    pub fn for_range(range: usize) -> Self {
        let capacity = (range as f32 * GRAPH_SLACK_FACTOR).ceil() as usize;
        Self {
            edges: Vec::with_capacity(capacity),
        }
    }

    /// Push a node to the list of neighbors for the given node.
    pub fn push(&mut self, node_id: u32) {
        debug_assert!(self.edges.len() < self.edges.capacity());
        self.edges.push(node_id);
    }
}

impl From<Vec<u32>> for AdjacencyList {
    fn from(edges: Vec<u32>) -> Self {
        Self { edges }
    }
}

impl Deref for AdjacencyList {
    type Target = Vec<u32>;

    fn deref(&self) -> &Self::Target {
        &self.edges
    }
}

impl DerefMut for AdjacencyList {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.edges
    }
}

impl<'a> IntoIterator for &'a AdjacencyList {
    type Item = &'a u32;
    type IntoIter = std::slice::Iter<'a, u32>;

    fn into_iter(self) -> Self::IntoIter {
        self.edges.iter()
    }
}
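Thanks to the Deref impls above, an AdjacencyList reads like a plain Vec<u32> at call sites; a quick sketch:

fn adjacency_list_demo() {
    let mut neighbors = AdjacencyList::for_range(4); // capacity = ceil(4 * 1.3) = 6
    neighbors.push(7);
    neighbors.push(9);
    assert_eq!(neighbors.len(), 2);               // Vec method via Deref
    assert_eq!(neighbors.iter().max(), Some(&9)); // iteration also via Deref
}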
179
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/graph/disk_graph.rs
vendored
Normal file
@@ -0,0 +1,179 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! Disk graph
|
||||
|
||||
use byteorder::{LittleEndian, ByteOrder};
|
||||
use vector::FullPrecisionDistance;
|
||||
|
||||
use crate::common::{ANNResult, ANNError};
|
||||
use crate::model::data_store::DiskScratchDataset;
|
||||
use crate::model::Vertex;
|
||||
use crate::storage::DiskGraphStorage;
|
||||
|
||||
use super::{VertexAndNeighbors, SectorGraph, AdjacencyList};
|
||||
|
||||
/// Disk graph
|
||||
pub struct DiskGraph {
|
||||
/// dim of fp vector in disk sector
|
||||
dim: usize,
|
||||
|
||||
/// number of nodes per sector
|
||||
num_nodes_per_sector: u64,
|
||||
|
||||
/// max node length in bytes
|
||||
max_node_len: u64,
|
||||
|
||||
/// the len of fp vector
|
||||
fp_vector_len: u64,
|
||||
|
||||
/// list of nodes (vertex_id) to fetch from disk
|
||||
nodes_to_fetch: Vec<u32>,
|
||||
|
||||
/// Sector graph
|
||||
sector_graph: SectorGraph,
|
||||
}
|
||||
|
||||
impl<'a> DiskGraph {
|
||||
/// Create DiskGraph instance
|
||||
pub fn new(
|
||||
dim: usize,
|
||||
num_nodes_per_sector: u64,
|
||||
max_node_len: u64,
|
||||
fp_vector_len: u64,
|
||||
beam_width: usize,
|
||||
graph_storage: DiskGraphStorage,
|
||||
) -> ANNResult<Self> {
|
||||
let graph = Self {
|
||||
dim,
|
||||
num_nodes_per_sector,
|
||||
max_node_len,
|
||||
fp_vector_len,
|
||||
nodes_to_fetch: Vec::with_capacity(2 * beam_width),
|
||||
sector_graph: SectorGraph::new(graph_storage)?,
|
||||
};
|
||||
|
||||
Ok(graph)
|
||||
}
|
||||
|
||||
/// Add vertex_id into the list to fetch from disk
|
||||
pub fn add_vertex(&mut self, id: u32) {
|
||||
self.nodes_to_fetch.push(id);
|
||||
}
|
||||
|
||||
/// Fetch nodes from disk index
|
||||
pub fn fetch_nodes(&mut self) -> ANNResult<()> {
|
||||
let sectors_to_fetch: Vec<u64> = self.nodes_to_fetch.iter().map(|&id| self.node_sector_index(id)).collect();
|
||||
self.sector_graph.read_graph(§ors_to_fetch)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Copy disk fp vector to DiskScratchDataset
|
||||
/// Return the fp vector with aligned dim from DiskScratchDataset
|
||||
pub fn copy_fp_vector_to_disk_scratch_dataset<T, const N: usize>(
|
||||
&self,
|
||||
node_index: usize,
|
||||
disk_scratch_dataset: &'a mut DiskScratchDataset<T, N>
|
||||
) -> ANNResult<Vertex<'a, T, N>>
|
||||
where
|
||||
[T; N]: FullPrecisionDistance<T, N>,
|
||||
{
|
||||
if self.dim > N {
|
||||
return Err(ANNError::log_index_error(format!(
|
||||
"copy_sector_fp_to_aligned_dataset: dim {} is greater than aligned dim {}",
|
||||
self.dim, N)));
|
||||
}
|
||||
|
||||
let fp_vector_buf = self.node_fp_vector_buf(node_index);
|
||||
|
||||
// Safety condition is met here
|
||||
let aligned_dim_vector = unsafe { disk_scratch_dataset.memcpy_from_fp_vector_buf(fp_vector_buf) };
|
||||
|
||||
Vertex::<'a, T, N>::try_from((aligned_dim_vector, self.nodes_to_fetch[node_index]))
|
||||
.map_err(|err| ANNError::log_index_error(format!("TryFromSliceError: failed to get Vertex for disk index node, err={}", err)))
|
||||
}
|
||||
|
||||
/// Reset graph
|
||||
pub fn reset(&mut self) {
|
||||
self.nodes_to_fetch.clear();
|
||||
self.sector_graph.reset();
|
||||
}

    fn get_vertex_and_neighbors(&self, node_index: usize) -> VertexAndNeighbors {
        let node_disk_buf = self.node_disk_buf(node_index);
        let buf = &node_disk_buf[self.fp_vector_len as usize..];
        let num_neighbors = LittleEndian::read_u32(&buf[0..4]) as usize;
        let neighbors_buf = &buf[4..4 + num_neighbors * 4];

        let mut adjacency_list = AdjacencyList::for_range(num_neighbors);
        for chunk in neighbors_buf.chunks(4) {
            let neighbor_id = LittleEndian::read_u32(chunk);
            adjacency_list.push(neighbor_id);
        }

        VertexAndNeighbors::new(self.nodes_to_fetch[node_index], adjacency_list)
    }

    #[inline]
    fn node_sector_index(&self, vertex_id: u32) -> u64 {
        vertex_id as u64 / self.num_nodes_per_sector + 1
    }

    #[inline]
    fn node_disk_buf(&self, node_index: usize) -> &[u8] {
        let vertex_id = self.nodes_to_fetch[node_index];

        // get sector_buf where this node is located
        let sector_buf = self.sector_graph.get_sector_buf(node_index);
        let node_offset = (vertex_id as u64 % self.num_nodes_per_sector * self.max_node_len) as usize;
        &sector_buf[node_offset..node_offset + self.max_node_len as usize]
    }

    #[inline]
    fn node_fp_vector_buf(&self, node_index: usize) -> &[u8] {
        let node_disk_buf = self.node_disk_buf(node_index);
        &node_disk_buf[..self.fp_vector_len as usize]
    }
}

/// Iterator for DiskGraph
pub struct DiskGraphIntoIterator<'a> {
    graph: &'a DiskGraph,
    index: usize,
}

impl<'a> IntoIterator for &'a DiskGraph {
    type IntoIter = DiskGraphIntoIterator<'a>;
    type Item = ANNResult<(usize, VertexAndNeighbors)>;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        DiskGraphIntoIterator {
            graph: self,
            index: 0,
        }
    }
}

impl<'a> Iterator for DiskGraphIntoIterator<'a> {
    type Item = ANNResult<(usize, VertexAndNeighbors)>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.graph.nodes_to_fetch.len() {
            return None;
        }

        let node_index = self.index;
        let vertex_and_neighbors = self.graph.get_vertex_and_neighbors(self.index);

        self.index += 1;
        Some(Ok((node_index, vertex_and_neighbors)))
    }
}
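
// Illustrative usage sketch (not part of the original source): the
// fetch-then-iterate cycle DiskGraph supports during beam search.
// `storage`, `frontier`, and the layout parameters are placeholders.
//
//     let mut graph = DiskGraph::new(dim, num_nodes_per_sector, max_node_len,
//                                    fp_vector_len, beam_width, storage)?;
//     for &id in &frontier {
//         graph.add_vertex(id);
//     }
//     graph.fetch_nodes()?;
//     for result in &graph {
//         let (node_index, vertex_and_neighbors) = result?;
//         // expand vertex_and_neighbors.get_neighbors() here
//     }
//     graph.reset();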
141
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/graph/inmem_graph.rs
vendored
Normal file
@@ -0,0 +1,141 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! In-memory graph

use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

use crate::common::ANNError;

use super::VertexAndNeighbors;

/// The entire graph of in-memory index
#[derive(Debug)]
pub struct InMemoryGraph {
    /// The entire graph
    pub final_graph: Vec<RwLock<VertexAndNeighbors>>,
}

impl InMemoryGraph {
    /// Create InMemoryGraph instance
    pub fn new(size: usize, max_degree: u32) -> Self {
        let mut graph = Vec::with_capacity(size);
        for id in 0..size {
            graph.push(RwLock::new(VertexAndNeighbors::for_range(
                id as u32,
                max_degree as usize,
            )));
        }
        Self { final_graph: graph }
    }

    /// Size of graph
    pub fn size(&self) -> usize {
        self.final_graph.len()
    }

    /// Extend the graph with `size` more vertices
    pub fn extend(&mut self, size: usize, max_degree: u32) {
        for id in 0..size {
            self.final_graph
                .push(RwLock::new(VertexAndNeighbors::for_range(
                    id as u32,
                    max_degree as usize,
                )));
        }
    }

    /// Get read guard of vertex_id
    pub fn read_vertex_and_neighbors(
        &self,
        vertex_id: u32,
    ) -> Result<RwLockReadGuard<VertexAndNeighbors>, ANNError> {
        self.final_graph[vertex_id as usize].read().map_err(|err| {
            ANNError::log_lock_poison_error(format!(
                "PoisonError: Lock poisoned when reading final_graph for vertex_id {}, err={}",
                vertex_id, err
            ))
        })
    }

    /// Get write guard of vertex_id
    pub fn write_vertex_and_neighbors(
        &self,
        vertex_id: u32,
    ) -> Result<RwLockWriteGuard<VertexAndNeighbors>, ANNError> {
        self.final_graph[vertex_id as usize].write().map_err(|err| {
            ANNError::log_lock_poison_error(format!(
                "PoisonError: Lock poisoned when writing final_graph for vertex_id {}, err={}",
                vertex_id, err
            ))
        })
    }
}
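
// A minimal concurrency sketch (not part of the original source): the
// per-vertex RwLock lets multiple readers share a vertex concurrently.
// Assumes a toolchain with std::thread::scope (Rust 1.63+).
#[cfg(test)]
mod concurrent_read_sketch {
    use super::*;

    #[test]
    fn parallel_reads_do_not_block_each_other() {
        let graph = InMemoryGraph::new(4, 8);
        std::thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    // read guards on the same vertex can coexist
                    let guard = graph.read_vertex_and_neighbors(0).unwrap();
                    assert_eq!(guard.vertex_id, 0);
                });
            }
        });
    }
}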

#[cfg(test)]
mod graph_tests {
    use crate::model::{graph::AdjacencyList, GRAPH_SLACK_FACTOR};

    use super::*;

    #[test]
    fn test_new() {
        let graph = InMemoryGraph::new(10, 10);
        let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize;

        assert_eq!(graph.final_graph.len(), 10);
        for i in 0..10 {
            let neighbor = graph.final_graph[i].read().unwrap();
            assert_eq!(neighbor.vertex_id, i as u32);
            assert_eq!(neighbor.get_neighbors().capacity(), capacity);
        }
    }

    #[test]
    fn test_size() {
        let graph = InMemoryGraph::new(10, 10);
        assert_eq!(graph.size(), 10);
    }

    #[test]
    fn test_extend() {
        let mut graph = InMemoryGraph::new(10, 10);
        graph.extend(10, 10);

        assert_eq!(graph.size(), 20);

        let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize;
        let mut id: u32 = 0;

        for i in 10..20 {
            let neighbor = graph.final_graph[i].read().unwrap();
            assert_eq!(neighbor.vertex_id, id);
            assert_eq!(neighbor.get_neighbors().capacity(), capacity);
            id += 1;
        }
    }

    #[test]
    fn test_read_vertex_and_neighbors() {
        let graph = InMemoryGraph::new(10, 10);
        let neighbor = graph.read_vertex_and_neighbors(0);
        assert!(neighbor.is_ok());
        assert_eq!(neighbor.unwrap().vertex_id, 0);
    }

    #[test]
    fn test_write_vertex_and_neighbors() {
        let graph = InMemoryGraph::new(10, 10);
        {
            let neighbor = graph.write_vertex_and_neighbors(0);
            assert!(neighbor.is_ok());
            neighbor.unwrap().add_to_neighbors(10, 10);
        }

        let neighbor = graph.read_vertex_and_neighbors(0).unwrap();
        assert_eq!(neighbor.get_neighbors(), &AdjacencyList::from(vec![10_u32]));
    }
}
20
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/graph/mod.rs
vendored
Normal file
@@ -0,0 +1,20 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod inmem_graph;
pub use inmem_graph::InMemoryGraph;

pub mod vertex_and_neighbors;
pub use vertex_and_neighbors::VertexAndNeighbors;

mod adjacency_list;
pub use adjacency_list::AdjacencyList;

mod sector_graph;
pub use sector_graph::*;

mod disk_graph;
pub use disk_graph::*;
87
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/graph/sector_graph.rs
vendored
Normal file
@@ -0,0 +1,87 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_docs)]

//! Sector graph

use std::ops::Deref;

use crate::common::{AlignedBoxWithSlice, ANNResult, ANNError};
use crate::model::{MAX_N_SECTOR_READS, SECTOR_LEN, AlignedRead};
use crate::storage::DiskGraphStorage;

/// Sector graph read from disk index
pub struct SectorGraph {
    /// Sector bytes from disk
    /// One sector has num_nodes_per_sector nodes
    /// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]}
    /// The fp vector is not aligned
    sectors_data: AlignedBoxWithSlice<u8>,

    /// Graph storage to read sectors
    graph_storage: DiskGraphStorage,

    /// Index of the sector slot where the next read will store its data
    cur_sector_idx: u64,
}

impl SectorGraph {
    /// Create SectorGraph instance
    pub fn new(graph_storage: DiskGraphStorage) -> ANNResult<Self> {
        Ok(Self {
            sectors_data: AlignedBoxWithSlice::new(MAX_N_SECTOR_READS * SECTOR_LEN, SECTOR_LEN)?,
            graph_storage,
            cur_sector_idx: 0,
        })
    }

    /// Reset SectorGraph
    pub fn reset(&mut self) {
        self.cur_sector_idx = 0;
    }

    /// Read sectors into sectors_data
    /// They are in the same order as sectors_to_fetch
    pub fn read_graph(&mut self, sectors_to_fetch: &[u64]) -> ANNResult<()> {
        let cur_sector_idx_usize: usize = self.cur_sector_idx.try_into()?;
        if sectors_to_fetch.len() > MAX_N_SECTOR_READS - cur_sector_idx_usize {
            return Err(ANNError::log_index_error(format!(
                "Trying to read too many sectors. number of sectors to read: {}, max number of sectors that can be read: {}",
                sectors_to_fetch.len(),
                MAX_N_SECTOR_READS - cur_sector_idx_usize,
            )));
        }

        let mut sector_slices = self.sectors_data.split_into_nonoverlapping_mut_slices(
            cur_sector_idx_usize * SECTOR_LEN..(cur_sector_idx_usize + sectors_to_fetch.len()) * SECTOR_LEN,
            SECTOR_LEN)?;

        let mut read_requests = Vec::with_capacity(sector_slices.len());
        for (local_sector_idx, slice) in sector_slices.iter_mut().enumerate() {
            let sector_id = sectors_to_fetch[local_sector_idx];
            read_requests.push(AlignedRead::new(sector_id * SECTOR_LEN as u64, slice)?);
        }

        self.graph_storage.read(&mut read_requests)?;
        self.cur_sector_idx += sectors_to_fetch.len() as u64;

        Ok(())
    }

    /// Get sector data by local index
    #[inline]
    pub fn get_sector_buf(&self, local_sector_idx: usize) -> &[u8] {
        &self.sectors_data[local_sector_idx * SECTOR_LEN..(local_sector_idx + 1) * SECTOR_LEN]
    }
}

impl Deref for SectorGraph {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &self.sectors_data
    }
}
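
// A small sketch (not part of the original source) of the offset math used by
// read_graph and get_sector_buf above: the on-disk byte offset of a sector is
// sector_id * SECTOR_LEN, while the i-th fetched sector lands in the i-th
// SECTOR_LEN-sized slot of sectors_data.
#[cfg(test)]
mod sector_offset_sketch {
    use super::*;

    #[test]
    fn local_slots_are_contiguous_and_sector_sized() {
        let local_range = |idx: usize| idx * SECTOR_LEN..(idx + 1) * SECTOR_LEN;
        // slot i ends exactly where slot i + 1 begins
        assert_eq!(local_range(0).end, local_range(1).start);
        assert_eq!(local_range(1).len(), SECTOR_LEN);
    }
}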
159
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/graph/vertex_and_neighbors.rs
vendored
Normal file
@@ -0,0 +1,159 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Vertex and its Adjacency List

use crate::model::GRAPH_SLACK_FACTOR;

use super::AdjacencyList;

/// The out neighbors of vertex_id
#[derive(Debug)]
pub struct VertexAndNeighbors {
    /// The id of the vertex
    pub vertex_id: u32,

    /// All out neighbors (id) of vertex_id
    neighbors: AdjacencyList,
}

impl VertexAndNeighbors {
    /// Create VertexAndNeighbors with id and capacity
    pub fn for_range(id: u32, range: usize) -> Self {
        Self {
            vertex_id: id,
            neighbors: AdjacencyList::for_range(range),
        }
    }

    /// Create VertexAndNeighbors with id and neighbors
    pub fn new(vertex_id: u32, neighbors: AdjacencyList) -> Self {
        Self {
            vertex_id,
            neighbors,
        }
    }

    /// Get size of neighbors
    #[inline(always)]
    pub fn size(&self) -> usize {
        self.neighbors.len()
    }

    /// Update the neighbors vector (post a pruning exercise)
    #[inline(always)]
    pub fn set_neighbors(&mut self, new_neighbors: AdjacencyList) {
        // Replace the graph entry with the pruned neighbors
        self.neighbors = new_neighbors;
    }

    /// Get the neighbors
    #[inline(always)]
    pub fn get_neighbors(&self) -> &AdjacencyList {
        &self.neighbors
    }

    /// Adds a node to the list of neighbors for the given node.
    ///
    /// # Arguments
    ///
    /// * `node_id` - The ID of the node to add.
    /// * `range` - The range of the graph.
    ///
    /// # Return
    ///
    /// Returns `None` if the node is already in the list of neighbors, or if it was added in place.
    /// If the list is full, returns a `Vec` containing the existing neighbors plus `node_id`, leaving the list itself unchanged.
    pub fn add_to_neighbors(&mut self, node_id: u32, range: u32) -> Option<Vec<u32>> {
        // Check if node_id is already in the graph entry
        if self.neighbors.contains(&node_id) {
            return None;
        }

        let neighbor_len = self.neighbors.len();

        // If not, check if the graph entry has enough space
        if neighbor_len < (GRAPH_SLACK_FACTOR * range as f64) as usize {
            // If yes, add node_id to the graph entry
            self.neighbors.push(node_id);
            return None;
        }

        let mut copy_of_neighbors = Vec::with_capacity(neighbor_len + 1);
        unsafe {
            let dst = copy_of_neighbors.as_mut_ptr();
            std::ptr::copy_nonoverlapping(self.neighbors.as_ptr(), dst, neighbor_len);
            dst.add(neighbor_len).write(node_id);
            copy_of_neighbors.set_len(neighbor_len + 1);
        }

        Some(copy_of_neighbors)
    }
}
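
// A small sketch (not part of the original source) of the capacity rule in
// add_to_neighbors above: floor(GRAPH_SLACK_FACTOR * range) inserts happen in
// place; the next distinct id returns a copy and leaves the list untouched.
#[cfg(test)]
mod slack_factor_sketch {
    use crate::model::GRAPH_SLACK_FACTOR;

    use super::*;

    #[test]
    fn overflow_returns_copy_without_mutating() {
        let mut vertex = VertexAndNeighbors::for_range(0, 1);
        let in_place_slots = (GRAPH_SLACK_FACTOR * 1_f64) as usize;
        for id in 0..in_place_slots as u32 {
            assert!(vertex.add_to_neighbors(id, 1).is_none());
        }
        let overflow = vertex.add_to_neighbors(u32::MAX, 1);
        assert!(overflow.is_some());
        assert_eq!(vertex.size(), in_place_slots);
    }
}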

#[cfg(test)]
mod vertex_and_neighbors_tests {
    use crate::model::GRAPH_SLACK_FACTOR;

    use super::*;

    #[test]
    fn test_set_with_capacity() {
        let neighbors = VertexAndNeighbors::for_range(20, 10);
        assert_eq!(neighbors.vertex_id, 20);
        assert_eq!(
            neighbors.neighbors.capacity(),
            (10_f32 * GRAPH_SLACK_FACTOR as f32).ceil() as usize
        );
    }

    #[test]
    fn test_size() {
        let mut neighbors = VertexAndNeighbors::for_range(20, 10);

        for i in 0..5 {
            neighbors.neighbors.push(i);
        }

        assert_eq!(neighbors.size(), 5);
    }

    #[test]
    fn test_set_neighbors() {
        let mut neighbors = VertexAndNeighbors::for_range(20, 10);
        let new_vec = AdjacencyList::from(vec![1, 2, 3, 4, 5]);
        neighbors.set_neighbors(AdjacencyList::from(new_vec.clone()));

        assert_eq!(neighbors.neighbors, new_vec);
    }

    #[test]
    fn test_get_neighbors() {
        let mut neighbors = VertexAndNeighbors::for_range(20, 10);
        neighbors.set_neighbors(AdjacencyList::from(vec![1, 2, 3, 4, 5]));
        let neighbor_ref = neighbors.get_neighbors();

        assert!(std::ptr::eq(&neighbors.neighbors, neighbor_ref))
    }

    #[test]
    fn test_add_to_neighbors() {
        let mut neighbors = VertexAndNeighbors::for_range(20, 10);

        assert_eq!(neighbors.add_to_neighbors(1, 1), None);
        assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1]));

        assert_eq!(neighbors.add_to_neighbors(1, 1), None);
        assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1]));

        let ret = neighbors.add_to_neighbors(2, 1);
        assert!(ret.is_some());
        assert_eq!(ret.unwrap(), vec![1, 2]);
        assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1]));

        assert_eq!(neighbors.add_to_neighbors(2, 2), None);
        assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1, 2]));
    }
}
29
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/mod.rs
vendored
Normal file
@@ -0,0 +1,29 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub mod neighbor;
pub use neighbor::Neighbor;
pub use neighbor::NeighborPriorityQueue;

pub mod data_store;
pub use data_store::InmemDataset;

pub mod graph;
pub use graph::InMemoryGraph;
pub use graph::VertexAndNeighbors;

pub mod configuration;
pub use configuration::*;

pub mod scratch;
pub use scratch::*;

pub mod vertex;
pub use vertex::Vertex;

pub mod pq;
pub use pq::*;

pub mod windows_aligned_file_reader;
pub use windows_aligned_file_reader::*;
13
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/neighbor/mod.rs
vendored
Normal file
@@ -0,0 +1,13 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod neighbor;
pub use neighbor::*;

mod neighbor_priority_queue;
pub use neighbor_priority_queue::*;

mod sorted_neighbor_vector;
pub use sorted_neighbor_vector::SortedNeighborVector;
104
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/neighbor/neighbor.rs
vendored
Normal file
@@ -0,0 +1,104 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::cmp::Ordering;

/// Neighbor node
#[derive(Debug, Clone, Copy)]
pub struct Neighbor {
    /// The id of the node
    pub id: u32,

    /// The distance from the query node to the current node
    pub distance: f32,

    /// Whether the current node has been visited or not
    pub visited: bool,
}

impl Neighbor {
    /// Create a neighbor node that has not been visited yet
    pub fn new(id: u32, distance: f32) -> Self {
        Self {
            id,
            distance,
            visited: false
        }
    }
}

impl Default for Neighbor {
    fn default() -> Self {
        Self { id: 0, distance: 0.0_f32, visited: false }
    }
}

impl PartialEq for Neighbor {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

impl Eq for Neighbor {}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        let ord = self.distance.partial_cmp(&other.distance).unwrap_or(std::cmp::Ordering::Equal);

        if ord == Ordering::Equal {
            return self.id.cmp(&other.id);
        }

        ord
    }
}

impl PartialOrd for Neighbor {
    #[inline]
    fn lt(&self, other: &Self) -> bool {
        self.distance < other.distance || (self.distance == other.distance && self.id < other.id)
    }

    // Reason for allowing panic = "Does not support comparing Neighbor with partial_cmp"
    #[allow(clippy::panic)]
    fn partial_cmp(&self, _: &Self) -> Option<std::cmp::Ordering> {
        panic!("Neighbor only allows eq and lt")
    }
}
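
// Illustrative sketch (not part of the original source): Neighbor orders by
// distance with id as tie-breaker via Ord, while eq compares id only and
// partial_cmp deliberately panics, so only ==, != and < are meaningful.
#[cfg(test)]
mod ordering_sketch {
    use super::*;

    #[test]
    fn sort_orders_by_distance_then_id() {
        let mut v = vec![
            Neighbor::new(3, 2.0),
            Neighbor::new(2, 1.0),
            Neighbor::new(1, 1.0),
        ];
        // sort_unstable uses Ord::cmp, which is total here
        v.sort_unstable();
        let ids: Vec<u32> = v.iter().map(|n| n.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }
}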

#[cfg(test)]
mod neighbor_test {
    use super::*;

    #[test]
    fn eq_lt_works() {
        let n1 = Neighbor::new(1, 1.1);
        let n2 = Neighbor::new(2, 2.0);
        let n3 = Neighbor::new(1, 1.1);

        assert!(n1 != n2);
        assert!(n1 < n2);
        assert!(n1 == n3);
    }

    #[test]
    #[should_panic]
    fn gt_should_panic() {
        let n1 = Neighbor::new(1, 1.1);
        let n2 = Neighbor::new(2, 2.0);

        assert!(n2 > n1);
    }

    #[test]
    #[should_panic]
    fn le_should_panic() {
        let n1 = Neighbor::new(1, 1.1);
        let n2 = Neighbor::new(2, 2.0);

        assert!(n1 <= n2);
    }
}
241
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs
vendored
Normal file
@@ -0,0 +1,241 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use crate::model::Neighbor;

/// Neighbor priority queue based on the distance to the query node
#[derive(Debug)]
pub struct NeighborPriorityQueue {
    /// The size of the priority queue
    size: usize,

    /// The capacity of the priority queue
    capacity: usize,

    /// The index of the notvisited neighbor whose distance is smallest among all notvisited neighbors
    cur: usize,

    /// The neighbor collection
    data: Vec<Neighbor>,
}

impl Default for NeighborPriorityQueue {
    fn default() -> Self {
        Self::new()
    }
}

impl NeighborPriorityQueue {
    /// Create NeighborPriorityQueue without capacity
    pub fn new() -> Self {
        Self {
            size: 0,
            capacity: 0,
            cur: 0,
            data: Vec::new(),
        }
    }

    /// Create NeighborPriorityQueue with capacity
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            size: 0,
            capacity,
            cur: 0,
            data: vec![Neighbor::default(); capacity + 1],
        }
    }

    /// Inserts the item in order.
    /// The item is dropped if the queue is full, the item already exists in the queue, or it has a greater distance than the last item.
    /// The cursor used to pop the next item is set to the lowest index of an unvisited item.
    pub fn insert(&mut self, nbr: Neighbor) {
        if self.size == self.capacity && self.get_at(self.size - 1) < &nbr {
            return;
        }

        let mut lo = 0;
        let mut hi = self.size;
        while lo < hi {
            let mid = (lo + hi) >> 1;
            if &nbr < self.get_at(mid) {
                hi = mid;
            } else if self.get_at(mid).id == nbr.id {
                // Make sure the same neighbor isn't inserted into the set
                return;
            } else {
                lo = mid + 1;
            }
        }

        if lo < self.capacity {
            self.data.copy_within(lo..self.size, lo + 1);
        }
        self.data[lo] = Neighbor::new(nbr.id, nbr.distance);
        if self.size < self.capacity {
            self.size += 1;
        }
        if lo < self.cur {
            self.cur = lo;
        }
    }

    /// Get the neighbor at index - SAFETY: index must be less than size
    fn get_at(&self, index: usize) -> &Neighbor {
        unsafe { self.data.get_unchecked(index) }
    }

    /// Get the closest and notvisited neighbor
    pub fn closest_notvisited(&mut self) -> Neighbor {
        self.data[self.cur].visited = true;
        let pre = self.cur;
        while self.cur < self.size && self.get_at(self.cur).visited {
            self.cur += 1;
        }
        self.data[pre]
    }

    /// Whether there is a notvisited node or not
    pub fn has_notvisited_node(&self) -> bool {
        self.cur < self.size
    }

    /// Get the size of the NeighborPriorityQueue
    pub fn size(&self) -> usize {
        self.size
    }

    /// Get the capacity of the NeighborPriorityQueue
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Sets an artificial capacity of the NeighborPriorityQueue. For benchmarking purposes only.
    pub fn set_capacity(&mut self, capacity: usize) {
        if capacity < self.data.len() {
            self.capacity = capacity;
        }
    }

    /// Reserve capacity
    pub fn reserve(&mut self, capacity: usize) {
        if capacity > self.capacity {
            self.data.resize(capacity + 1, Neighbor::default());
            self.capacity = capacity;
        }
    }

    /// Set size and cur to 0
    pub fn clear(&mut self) {
        self.size = 0;
        self.cur = 0;
    }
}

impl std::ops::Index<usize> for NeighborPriorityQueue {
    type Output = Neighbor;

    fn index(&self, i: usize) -> &Self::Output {
        &self.data[i]
    }
}
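
// Illustrative sketch (not part of the original source): the best-first search
// loop this queue is designed for. Candidates are inserted as they are
// discovered, and closest_notvisited pops them in ascending distance order.
#[cfg(test)]
mod search_loop_sketch {
    use super::*;

    #[test]
    fn drains_in_ascending_distance_order() {
        let mut queue = NeighborPriorityQueue::with_capacity(4);
        for (id, distance) in [(7, 0.9), (3, 0.1), (5, 0.4)] {
            queue.insert(Neighbor::new(id, distance));
        }

        let mut visited = Vec::new();
        while queue.has_notvisited_node() {
            // in a real search, the expanded node's neighbors would be
            // inserted here before the next iteration
            visited.push(queue.closest_notvisited().id);
        }
        assert_eq!(visited, vec![3, 5, 7]);
    }
}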

#[cfg(test)]
mod neighbor_priority_queue_test {
    use super::*;

    #[test]
    fn test_reserve_capacity() {
        let mut queue = NeighborPriorityQueue::with_capacity(10);
        assert_eq!(queue.capacity(), 10);
        queue.reserve(20);
        assert_eq!(queue.capacity(), 20);
    }

    #[test]
    fn test_insert() {
        let mut queue = NeighborPriorityQueue::with_capacity(3);
        assert_eq!(queue.size(), 0);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert_eq!(queue.size(), 2);
        queue.insert(Neighbor::new(2, 0.5)); // should be ignored as the same neighbor
        assert_eq!(queue.size(), 2);
        queue.insert(Neighbor::new(3, 0.9));
        assert_eq!(queue.size(), 3);
        assert_eq!(queue[2].id, 1);
        queue.insert(Neighbor::new(4, 2.0)); // should be dropped as queue is full and distance is greater than last item
        assert_eq!(queue.size(), 3);
        assert_eq!(queue[0].id, 2); // node id in queue should be [2,3,1]
        assert_eq!(queue[1].id, 3);
        assert_eq!(queue[2].id, 1);
        println!("{:?}", queue);
    }

    #[test]
    fn test_index() {
        let mut queue = NeighborPriorityQueue::with_capacity(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5));
        assert_eq!(queue[0].id, 2);
        assert_eq!(queue[0].distance, 0.5);
    }

    #[test]
    fn test_visit() {
        let mut queue = NeighborPriorityQueue::with_capacity(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        queue.insert(Neighbor::new(3, 1.5)); // node id in queue should be [2,1,3]
        assert!(queue.has_notvisited_node());
        let nbr = queue.closest_notvisited();
        assert_eq!(nbr.id, 2);
        assert_eq!(nbr.distance, 0.5);
        assert!(nbr.visited);
        assert!(queue.has_notvisited_node());
        let nbr = queue.closest_notvisited();
        assert_eq!(nbr.id, 1);
        assert_eq!(nbr.distance, 1.0);
        assert!(nbr.visited);
        assert!(queue.has_notvisited_node());
        let nbr = queue.closest_notvisited();
        assert_eq!(nbr.id, 3);
        assert_eq!(nbr.distance, 1.5);
        assert!(nbr.visited);
        assert!(!queue.has_notvisited_node());
    }

    #[test]
    fn test_clear_queue() {
        let mut queue = NeighborPriorityQueue::with_capacity(3);
        queue.insert(Neighbor::new(1, 1.0));
        queue.insert(Neighbor::new(2, 0.5));
        assert_eq!(queue.size(), 2);
        assert!(queue.has_notvisited_node());
        queue.clear();
        assert_eq!(queue.size(), 0);
        assert!(!queue.has_notvisited_node());
    }

    #[test]
    fn test_reserve() {
        let mut queue = NeighborPriorityQueue::new();
        queue.reserve(10);
        assert_eq!(queue.data.len(), 11);
        assert_eq!(queue.capacity, 10);
    }

    #[test]
    fn test_set_capacity() {
        let mut queue = NeighborPriorityQueue::with_capacity(10);
        queue.set_capacity(5);
        assert_eq!(queue.capacity, 5);
        assert_eq!(queue.data.len(), 11);

        queue.set_capacity(11);
        assert_eq!(queue.capacity, 5);
    }
}
37
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs
vendored
Normal file
@@ -0,0 +1,37 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Sorted Neighbor Vector

use std::ops::{Deref, DerefMut};

use super::Neighbor;

/// A newtype over a mutable vector of neighbors, sorted by distance on construction
#[derive(Debug)]
pub struct SortedNeighborVector<'a>(&'a mut Vec<Neighbor>);

impl<'a> SortedNeighborVector<'a> {
    /// Create a new SortedNeighborVector
    pub fn new(vec: &'a mut Vec<Neighbor>) -> Self {
        vec.sort_unstable();
        Self(vec)
    }
}

impl<'a> Deref for SortedNeighborVector<'a> {
    type Target = Vec<Neighbor>;

    fn deref(&self) -> &Self::Target {
        self.0
    }
}

impl<'a> DerefMut for SortedNeighborVector<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.0
    }
}
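
// A minimal usage sketch (not part of the original source): constructing the
// newtype sorts the underlying vector in place by (distance, id), and Deref
// exposes the sorted contents.
#[cfg(test)]
mod sorted_vector_sketch {
    use super::*;

    #[test]
    fn construction_sorts_by_distance() {
        let mut raw = vec![
            Neighbor::new(2, 0.8),
            Neighbor::new(1, 0.2),
            Neighbor::new(3, 0.5),
        ];
        let sorted = SortedNeighborVector::new(&mut raw);
        let ids: Vec<u32> = sorted.iter().map(|n| n.id).collect();
        assert_eq!(ids, vec![1, 3, 2]);
    }
}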
483
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs
vendored
Normal file
@@ -0,0 +1,483 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations)]

use hashbrown::HashMap;
use rayon::prelude::{
    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut,
};
use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

use crate::{
    common::{ANNError, ANNResult},
    model::NUM_PQ_CENTROIDS,
};

/// PQ pivot table loading and distance calculation
#[derive(Debug)]
pub struct FixedChunkPQTable {
    /// pq_tables = float array of size [256 * ndims]
    pq_table: Vec<f32>,

    /// ndims = true dimension of vectors
    dim: usize,

    /// num_pq_chunks = the pq chunk number
    num_pq_chunks: usize,

    /// chunk_offsets = the offset of each chunk, starting from 0
    chunk_offsets: Vec<usize>,

    /// centroid of each dimension
    centroids: Vec<f32>,

    /// Because we're using L2 distance, this is not needed now.
    /// Transpose of pq_table. transport_pq_table = float array of size [ndims * 256].
    /// e.g. if pq_table is 2 centroids * 3 dims
    /// [ 1, 2, 3,
    ///   4, 5, 6]
    /// then transport_pq_table would be 3 dims * 2 centroids
    /// [ 1, 4,
    ///   2, 5,
    ///   3, 6]
    /// transport_pq_table: Vec<f32>,

    /// Map dim offset to chunk index e.g., 8 dims into 2 chunks
    /// would be [(0,0), (1,0), (2,0), (3,0), (4,1), (5,1), (6,1), (7,1)]
    dimoffset_chunk_mapping: HashMap<usize, usize>,
}

impl FixedChunkPQTable {
    /// Create the FixedChunkPQTable with dim and chunk numbers and pivot file data (pivot table + centroids + chunk offsets)
    pub fn new(
        dim: usize,
        num_pq_chunks: usize,
        pq_table: Vec<f32>,
        centroids: Vec<f32>,
        chunk_offsets: Vec<usize>,
    ) -> Self {
        let mut dimoffset_chunk_mapping = HashMap::new();
        for chunk_index in 0..num_pq_chunks {
            for dim_offset in chunk_offsets[chunk_index]..chunk_offsets[chunk_index + 1] {
                dimoffset_chunk_mapping.insert(dim_offset, chunk_index);
            }
        }

        Self {
            pq_table,
            dim,
            num_pq_chunks,
            chunk_offsets,
            centroids,
            dimoffset_chunk_mapping,
        }
    }

    /// Get chunk number
    pub fn get_num_chunks(&self) -> usize {
        self.num_pq_chunks
    }

    /// Shift the query by the centroid (mean) of the whole corpus
    pub fn preprocess_query(&self, query_vec: &mut [f32]) {
        for (query, &centroid) in query_vec.iter_mut().zip(self.centroids.iter()) {
            *query -= centroid;
        }
    }

    /// Pre-calculate the distance between the query and each centroid by L2 distance
    /// * `query_vec` - query vector: 1 * dim
    /// * `dist_vec` - pre-calculated distance between the query and each centroid: chunk_size * num_centroids
    #[allow(clippy::needless_range_loop)]
    pub fn populate_chunk_distances(&self, query_vec: &[f32]) -> Vec<f32> {
        let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
        for centroid_index in 0..NUM_PQ_CENTROIDS {
            for chunk_index in 0..self.num_pq_chunks {
                for dim_offset in
                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
                {
                    let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
                        - query_vec[dim_offset];
                    dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] += diff * diff;
                }
            }
        }
        dist_vec
    }

    /// Pre-calculate the distance between the query and each centroid by inner product
    /// * `query_vec` - query vector: 1 * dim
    /// * `dist_vec` - pre-calculated distance between the query and each centroid: chunk_size * num_centroids
    ///
    /// Reason to allow clippy::needless_range_loop:
    /// The inner loop is operating over a range that is different for each iteration of the outer loop.
    /// This isn't a scenario where using iter().enumerate() would be easily applicable,
    /// because the inner loop isn't iterating directly over the contents of a slice or array.
    /// Thus, using indexing might be the most straightforward way to express this logic.
    #[allow(clippy::needless_range_loop)]
    pub fn populate_chunk_inner_products(&self, query_vec: &[f32]) -> Vec<f32> {
        let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
        for centroid_index in 0..NUM_PQ_CENTROIDS {
            for chunk_index in 0..self.num_pq_chunks {
                for dim_offset in
                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
                {
                    // assumes that we are not shifting the vectors to mean zero, i.e., the centroid
                    // array should be all zeros; returns negative values to keep the search code
                    // clean (max inner product vs min distance)
                    let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
                        * query_vec[dim_offset];
                    dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] -= diff;
                }
            }
        }
        dist_vec
    }

    /// Calculate the distance between the query and the given centroid array by L2 distance
    /// * `query_vec` - query vector: 1 * dim
    /// * `base_vec` - given centroid array: 1 * num_pq_chunks
    #[allow(clippy::needless_range_loop)]
    pub fn l2_distance(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 {
        let mut res_vec: Vec<f32> = vec![0.0; self.num_pq_chunks];
        res_vec
            .par_iter_mut()
            .enumerate()
            .for_each(|(chunk_index, chunk_diff)| {
                for dim_offset in
                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
                {
                    let diff = self.pq_table
                        [self.dim * base_vec[chunk_index] as usize + dim_offset]
                        - query_vec[dim_offset];
                    *chunk_diff += diff * diff;
                }
            });

        let res: f32 = res_vec.iter().sum::<f32>();

        res
    }

    /// Calculate the distance between the query and the given centroid array by inner product
    /// * `query_vec` - query vector: 1 * dim
    /// * `base_vec` - given centroid array: 1 * num_pq_chunks
    #[allow(clippy::needless_range_loop)]
    pub fn inner_product(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 {
        let mut res_vec: Vec<f32> = vec![0.0; self.num_pq_chunks];
        res_vec
            .par_iter_mut()
            .enumerate()
            .for_each(|(chunk_index, chunk_diff)| {
                for dim_offset in
                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
                {
                    *chunk_diff += self.pq_table
                        [self.dim * base_vec[chunk_index] as usize + dim_offset]
                        * query_vec[dim_offset];
                }
            });

        let res: f32 = res_vec.iter().sum::<f32>();

        // returns negative value to simulate distances (max -> min conversion)
        -res
    }

    /// Revert the vector by adding the centroid
    /// * `base_vec` - given centroid array: 1 * num_pq_chunks
    /// * `out_vec` - reverted vector
    pub fn inflate_vector(&self, base_vec: &[u8]) -> ANNResult<Vec<f32>> {
        let mut out_vec: Vec<f32> = vec![0.0; self.dim];
        for (dim_offset, value) in out_vec.iter_mut().enumerate() {
            let chunk_index =
                self.dimoffset_chunk_mapping
                    .get(&dim_offset)
                    .ok_or(ANNError::log_pq_error(
                        "ERROR: dim_offset not found in dimoffset_chunk_mapping".to_string(),
                    ))?;
            *value = self.pq_table[self.dim * base_vec[*chunk_index] as usize + dim_offset]
                + self.centroids[dim_offset];
        }

        Ok(out_vec)
    }
}
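
// A small illustrative sketch (not part of the original source): with a single
// chunk covering dim = 2 and all-zero centroids, the precomputed entry for
// centroid 0 in populate_chunk_distances equals the squared L2 distance that
// l2_distance reports for a base vector encoded as centroid id 0.
#[cfg(test)]
mod chunk_distance_sketch {
    use super::*;

    #[test]
    fn table_lookup_matches_direct_distance() {
        let dim = 2;
        let mut pq_table = vec![0.0f32; NUM_PQ_CENTROIDS * dim];
        // centroid 0 is the point (1.0, 2.0)
        pq_table[0] = 1.0;
        pq_table[1] = 2.0;
        let table = FixedChunkPQTable::new(dim, 1, pq_table, vec![0.0; dim], vec![0, dim]);

        let query = [4.0f32, 6.0f32];
        let dist_vec = table.populate_chunk_distances(&query);
        // (4 - 1)^2 + (6 - 2)^2 = 25
        assert_eq!(dist_vec[0], 25.0);
        assert_eq!(table.l2_distance(&query, &[0u8]), 25.0);
    }
}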

/// Given a batch of input nodes, return a batch of PQ distances
/// * `pq_ids` - batch nodes: n_pts * pq_nchunks
/// * `n_pts` - batch number
/// * `pq_nchunks` - pq chunk number
/// * `pq_dists` - pre-calculated distance between the query and each centroid: chunk_size * num_centroids
/// * `dists_out` - n_pts * 1
pub fn pq_dist_lookup(
    pq_ids: &[u8],
    n_pts: usize,
    pq_nchunks: usize,
    pq_dists: &[f32],
) -> Vec<f32> {
    let mut dists_out: Vec<f32> = vec![0.0; n_pts];
    unsafe {
        _mm_prefetch(dists_out.as_ptr() as *const i8, _MM_HINT_T0);
        _mm_prefetch(pq_ids.as_ptr() as *const i8, _MM_HINT_T0);
        _mm_prefetch(pq_ids.as_ptr().add(64) as *const i8, _MM_HINT_T0);
        _mm_prefetch(pq_ids.as_ptr().add(128) as *const i8, _MM_HINT_T0);
    }
    for chunk in 0..pq_nchunks {
        let chunk_dists = &pq_dists[256 * chunk..];
        if chunk < pq_nchunks - 1 {
            unsafe {
                _mm_prefetch(
                    chunk_dists.as_ptr().offset(256 * chunk as isize).add(256) as *const i8,
                    _MM_HINT_T0,
                );
            }
        }
        dists_out
            .par_iter_mut()
            .enumerate()
            .for_each(|(n_iter, dist)| {
                let pq_centerid = pq_ids[pq_nchunks * n_iter + chunk];
                *dist += chunk_dists[pq_centerid as usize];
            });
    }
    dists_out
}

/// Gather the compressed coordinates of each id into one contiguous output buffer
pub fn aggregate_coords(ids: &[u32], all_coords: &[u8], ndims: usize) -> Vec<u8> {
    let mut out: Vec<u8> = vec![0u8; ids.len() * ndims];
    let ndim_u32 = ndims as u32;
    out.par_chunks_mut(ndims)
        .enumerate()
        .for_each(|(index, chunk)| {
            let id_compressed_pivot = &all_coords
                [(ids[index] * ndim_u32) as usize..(ids[index] * ndim_u32 + ndim_u32) as usize];
            let temp_slice =
                unsafe { std::slice::from_raw_parts(id_compressed_pivot.as_ptr(), ndims) };
            chunk.copy_from_slice(temp_slice);
        });

    out
}

#[cfg(test)]
mod fixed_chunk_pq_table_test {

    use super::*;
    use crate::common::{ANNError, ANNResult};
    use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, file_exists, load_bin};

    const DIM: usize = 128;

    #[test]
    fn load_pivot_test() {
        let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
        let (dim, pq_table, centroids, chunk_offsets) =
            load_pq_pivots_bin(pq_pivots_path, &1).unwrap();
        let fixed_chunk_pq_table =
            FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets);

        assert_eq!(dim, DIM);
        assert_eq!(fixed_chunk_pq_table.pq_table.len(), DIM * NUM_PQ_CENTROIDS);
        assert_eq!(fixed_chunk_pq_table.centroids.len(), DIM);

        assert_eq!(fixed_chunk_pq_table.chunk_offsets[0], 0);
        assert_eq!(fixed_chunk_pq_table.chunk_offsets[1], DIM);
        assert_eq!(fixed_chunk_pq_table.chunk_offsets.len(), 2);
    }

    #[test]
    fn get_num_chunks_test() {
        let num_chunks = 7;
        let pq_table = vec![0.0; DIM * NUM_PQ_CENTROIDS];
        let centroids = vec![0.0; DIM];
        let chunk_offsets = vec![0, 7, 9, 11, 22, 34, 78, 127];
        let fixed_chunk_pq_table =
            FixedChunkPQTable::new(DIM, num_chunks, pq_table, centroids, chunk_offsets);
        let chunk: usize = fixed_chunk_pq_table.get_num_chunks();
        assert_eq!(chunk, num_chunks);
    }

    #[test]
    fn preprocess_query_test() {
        let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
        let (dim, pq_table, centroids, chunk_offsets) =
            load_pq_pivots_bin(pq_pivots_path, &1).unwrap();
        let fixed_chunk_pq_table =
            FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets);

        let mut query_vec: Vec<f32> = vec![
            32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32,
            68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32,
            7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32,
            65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32,
            11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32,
            59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32,
            7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32,
            48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32,
            47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32,
            74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32,
            65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32,
            41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32,
            84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32,
            74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32,
            37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32,
        ];
        fixed_chunk_pq_table.preprocess_query(&mut query_vec);
        assert_eq!(query_vec[0], 32.39f32 - fixed_chunk_pq_table.centroids[0]);
        assert_eq!(
            query_vec[127],
            35.6f32 - fixed_chunk_pq_table.centroids[127]
        );
    }

    #[test]
    fn calculate_distances_tests() {
        let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";

        let (dim, pq_table, centroids, chunk_offsets) =
            load_pq_pivots_bin(pq_pivots_path, &1).unwrap();
        let fixed_chunk_pq_table =
            FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets);

        let query_vec: Vec<f32> = vec![
            32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32,
            68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32,
            7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32,
            65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32,
            11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32,
            59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32,
            7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32,
            48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32,
            47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32,
            74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32,
            65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32,
            41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32,
            84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32,
            74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32,
            37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32,
        ];

        let dist_vec = fixed_chunk_pq_table.populate_chunk_distances(&query_vec);
        assert_eq!(dist_vec.len(), 256);

        // populate_chunk_distances_test
        let mut sampled_output = 0.0;
        (0..DIM).for_each(|dim_offset| {
            let diff = fixed_chunk_pq_table.pq_table[dim_offset] - query_vec[dim_offset];
            sampled_output += diff * diff;
        });
        assert_eq!(sampled_output, dist_vec[0]);

        // populate_chunk_inner_products_test
        let dist_vec = fixed_chunk_pq_table.populate_chunk_inner_products(&query_vec);
        assert_eq!(dist_vec.len(), 256);

        let mut sampled_output = 0.0;
        (0..DIM).for_each(|dim_offset| {
            sampled_output -= fixed_chunk_pq_table.pq_table[dim_offset] * query_vec[dim_offset];
        });
        assert_eq!(sampled_output, dist_vec[0]);

        // l2_distance_test
        let base_vec: Vec<u8> = vec![3u8];
        let dist = fixed_chunk_pq_table.l2_distance(&query_vec, &base_vec);
        let mut l2_output = 0.0;
        (0..DIM).for_each(|dim_offset| {
            let diff = fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] - query_vec[dim_offset];
            l2_output += diff * diff;
        });
        assert_eq!(l2_output, dist);

        // inner_product_test
        let dist = fixed_chunk_pq_table.inner_product(&query_vec, &base_vec);
        let mut l2_output = 0.0;
        (0..DIM).for_each(|dim_offset| {
            l2_output -=
                fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] * query_vec[dim_offset];
        });
        assert_eq!(l2_output, dist);

        // inflate_vector_test
        let inflate_vector = fixed_chunk_pq_table.inflate_vector(&base_vec).unwrap();
        assert_eq!(inflate_vector.len(), DIM);
        assert_eq!(
            inflate_vector[0],
            fixed_chunk_pq_table.pq_table[3 * DIM] + fixed_chunk_pq_table.centroids[0]
        );
        assert_eq!(
            inflate_vector[1],
            fixed_chunk_pq_table.pq_table[3 * DIM + 1] + fixed_chunk_pq_table.centroids[1]
        );
        assert_eq!(
            inflate_vector[127],
            fixed_chunk_pq_table.pq_table[3 * DIM + 127] + fixed_chunk_pq_table.centroids[127]
        );
    }

    fn load_pq_pivots_bin(
        pq_pivots_path: &str,
        num_pq_chunks: &usize,
    ) -> ANNResult<(usize, Vec<f32>, Vec<f32>, Vec<usize>)> {
        if !file_exists(pq_pivots_path) {
            return Err(ANNError::log_pq_error(
                "ERROR: PQ k-means pivot file not found.".to_string(),
            ));
        }

        let (data, offset_num, offset_dim) = load_bin::<u64>(pq_pivots_path, 0)?;
        let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim);
        if offset_num != 4 {
            let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, pq_center_num, dim) = load_bin::<f32>(pq_pivots_path, file_offset_data[0])?;
        let pq_table = data.to_vec();
        if pq_center_num != NUM_PQ_CENTROIDS {
            let error_message = format!(
                "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.",
                pq_pivots_path, pq_center_num, NUM_PQ_CENTROIDS
            );
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, centroid_dim, nc) = load_bin::<f32>(pq_pivots_path, file_offset_data[1])?;
        let centroids = data.to_vec();
        if centroid_dim != dim || nc != 1 {
            let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, chunk_offset_num, nc) = load_bin::<u32>(pq_pivots_path, file_offset_data[2])?;
        let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc);
        if chunk_offset_num != num_pq_chunks + 1 || nc != 1 {
            let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1);
            return Err(ANNError::log_pq_error(error_message));
        }

        Ok((dim, pq_table, centroids, chunk_offsets))
    }
}

#[cfg(test)]
mod pq_index_prune_query_test {

    use super::*;

    #[test]
    fn pq_dist_lookup_test() {
        let pq_ids: Vec<u8> = vec![1u8, 3u8, 2u8, 2u8];
        let mut pq_dists: Vec<f32> = Vec::with_capacity(256 * 2);
        for _ in 0..pq_dists.capacity() {
            pq_dists.push(rand::random());
        }

        let dists_out = pq_dist_lookup(&pq_ids, 2, 2, &pq_dists);
        assert_eq!(dists_out.len(), 2);
        assert_eq!(dists_out[0], pq_dists[0 + 1] + pq_dists[256 + 3]);
        assert_eq!(dists_out[1], pq_dists[0 + 2] + pq_dists[256 + 2]);
    }
}
9
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/pq/mod.rs
vendored
Normal file
@@ -0,0 +1,9 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
mod fixed_chunk_pq_table;
pub use fixed_chunk_pq_table::*;

mod pq_construction;
pub use pq_construction::*;
398
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/pq/pq_construction.rs
vendored
Normal file
@@ -0,0 +1,398 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations)]

use rayon::prelude::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;

use crate::common::{ANNError, ANNResult};
use crate::storage::PQStorage;
use crate::utils::{compute_closest_centers, file_exists, k_means_clustering};

/// Max size of PQ training set
pub const MAX_PQ_TRAINING_SET_SIZE: f64 = 256_000f64;

/// Max number of PQ chunks
pub const MAX_PQ_CHUNKS: usize = 512;

/// Number of PQ centroids
pub const NUM_PQ_CENTROIDS: usize = 256;
/// block size for reading/processing large files and matrices in blocks
const BLOCK_SIZE: usize = 5000000;
const NUM_KMEANS_REPS_PQ: usize = 12;

/// Given training data in train_data of dimensions num_train * dim, generate
/// PQ pivots using the k-means algorithm to partition the coordinates into
/// num_pq_chunks (if it divides dimension, else rounded) chunks, and run
/// k-means in each chunk to compute the PQ pivots, storing them in bin format
/// in file pq_pivots_path as a num_centers * dim floating point binary file.
/// PQ pivot table layout: {pivot offsets data: METADATA_SIZE}{pivot vector:[dim; num_centroid]}{centroid vector:[dim; 1]}{chunk offsets:[chunk_num+1; 1]}
fn generate_pq_pivots(
    train_data: &mut [f32],
    num_train: usize,
    dim: usize,
    num_centers: usize,
    num_pq_chunks: usize,
    max_k_means_reps: usize,
    pq_storage: &mut PQStorage,
) -> ANNResult<()> {
    if num_pq_chunks > dim {
        return Err(ANNError::log_pq_error(
            "Error: number of chunks more than dimension.".to_string(),
        ));
    }

    if pq_storage.pivot_data_exist() {
        let (file_num_centers, file_dim) = pq_storage.read_pivot_metadata()?;
        if file_dim == dim && file_num_centers == num_centers {
            // PQ pivot file exists. Not generating again.
            return Ok(());
        }
    }

    // Calculate the centroid and center the training data.
    // If we use L2 distance, there is an option to
    // translate all vectors to make them centered and
    // then compute PQ. This needs to be set to false
    // when using PQ for MIPS as such translations don't
    // preserve inner products.
    // Now, we're using L2 as default.
    let mut centroid: Vec<f32> = vec![0.0; dim];
    for dim_index in 0..dim {
        for train_data_index in 0..num_train {
            centroid[dim_index] += train_data[train_data_index * dim + dim_index];
        }
        centroid[dim_index] /= num_train as f32;
    }
    for dim_index in 0..dim {
        for train_data_index in 0..num_train {
            train_data[train_data_index * dim + dim_index] -= centroid[dim_index];
        }
    }

    // Calculate each chunk's offset
    // If we have 8 dimensions and 3 chunks then the offsets would be [0,3,6,8]
    let mut chunk_offsets: Vec<usize> = vec![0; num_pq_chunks + 1];
    let mut chunk_offset: usize = 0;
    for chunk_index in 0..num_pq_chunks {
        chunk_offset += dim / num_pq_chunks;
        if chunk_index < (dim % num_pq_chunks) {
            chunk_offset += 1;
        }
        chunk_offsets[chunk_index + 1] = chunk_offset;
    }

    let mut full_pivot_data: Vec<f32> = vec![0.0; num_centers * dim];
    for chunk_index in 0..num_pq_chunks {
        let chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];

        let mut cur_train_data: Vec<f32> = vec![0.0; num_train * chunk_size];
        let mut cur_pivot_data: Vec<f32> = vec![0.0; num_centers * chunk_size];

        cur_train_data
            .par_chunks_mut(chunk_size)
            .enumerate()
            .for_each(|(train_data_index, chunk)| {
                for (dim_offset, item) in chunk.iter_mut().enumerate() {
                    *item = train_data
                        [train_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
                }
            });

        // Run k-means to get the centroids of this chunk.
        let (_closest_docs, _closest_center, _residual) = k_means_clustering(
            &cur_train_data,
            num_train,
            chunk_size,
            &mut cur_pivot_data,
            num_centers,
            max_k_means_reps,
        )?;

        // Copy centroids from this chunk table to the full table
        for center_index in 0..num_centers {
            full_pivot_data[center_index * dim + chunk_offsets[chunk_index]
                ..center_index * dim + chunk_offsets[chunk_index + 1]]
                .copy_from_slice(
                    &cur_pivot_data[center_index * chunk_size..(center_index + 1) * chunk_size],
                );
        }
    }

    pq_storage.write_pivot_data(
        &full_pivot_data,
        &centroid,
        &chunk_offsets,
        num_centers,
        dim,
    )?;

    Ok(())
}
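
// A small sketch (not part of the original source) of the chunk-offset rule
// used in generate_pq_pivots above: each chunk gets dim / num_chunks
// dimensions, with the first dim % num_chunks chunks taking one extra.
#[cfg(test)]
mod chunk_offset_sketch {
    #[test]
    fn eight_dims_in_three_chunks() {
        let (dim, num_pq_chunks) = (8usize, 3usize);
        let mut chunk_offsets = vec![0; num_pq_chunks + 1];
        let mut chunk_offset = 0;
        for chunk_index in 0..num_pq_chunks {
            chunk_offset += dim / num_pq_chunks;
            if chunk_index < dim % num_pq_chunks {
                chunk_offset += 1;
            }
            chunk_offsets[chunk_index + 1] = chunk_offset;
        }
        // matches the [0,3,6,8] example in the comment above
        assert_eq!(chunk_offsets, vec![0, 3, 6, 8]);
    }
}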
/// streams the base file (data_file), and computes the closest centers in each
|
||||
/// chunk to generate the compressed data_file and stores it in
|
||||
/// pq_compressed_vectors_path.
|
||||
/// If the numbber of centers is < 256, it stores as byte vector, else as
|
||||
/// 4-byte vector in binary format.
|
||||
/// Compressed PQ table layout: {num_points: usize}{num_chunks: usize}{compressed pq table: [num_points; num_chunks]}
|
||||
fn generate_pq_data_from_pivots<T: Copy + Into<f32>>(
    num_centers: usize,
    num_pq_chunks: usize,
    pq_storage: &mut PQStorage,
) -> ANNResult<()> {
    let (num_points, dim) = pq_storage.read_pq_data_metadata()?;

    let full_pivot_data: Vec<f32>;
    let centroid: Vec<f32>;
    let chunk_offsets: Vec<usize>;

    if !pq_storage.pivot_data_exist() {
        return Err(ANNError::log_pq_error(
            "ERROR: PQ k-means pivot file not found.".to_string(),
        ));
    } else {
        (full_pivot_data, centroid, chunk_offsets) =
            pq_storage.load_pivot_data(&num_pq_chunks, &num_centers, &dim)?;
    }

    pq_storage.write_compressed_pivot_metadata(num_points as i32, num_pq_chunks as i32)?;

    let block_size = if num_points <= BLOCK_SIZE {
        num_points
    } else {
        BLOCK_SIZE
    };
    let num_blocks = (num_points / block_size) + (num_points % block_size != 0) as usize;

    for block_index in 0..num_blocks {
        let start_index: usize = block_index * block_size;
        let end_index: usize = std::cmp::min((block_index + 1) * block_size, num_points);
        let cur_block_size: usize = end_index - start_index;

        let mut block_compressed_base: Vec<usize> = vec![0; cur_block_size * num_pq_chunks];

        let block_data: Vec<T> = pq_storage.read_pq_block_data(cur_block_size, dim)?;

        let mut adjusted_block_data: Vec<f32> = vec![0.0; cur_block_size * dim];

        for block_data_index in 0..cur_block_size {
            for dim_index in 0..dim {
                adjusted_block_data[block_data_index * dim + dim_index] =
                    block_data[block_data_index * dim + dim_index].into() - centroid[dim_index];
            }
        }

        for chunk_index in 0..num_pq_chunks {
            let cur_chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];
            if cur_chunk_size == 0 {
                continue;
            }

            let mut cur_pivot_data: Vec<f32> = vec![0.0; num_centers * cur_chunk_size];
            let mut cur_data: Vec<f32> = vec![0.0; cur_block_size * cur_chunk_size];
            let mut closest_center: Vec<u32> = vec![0; cur_block_size];

            // Divide the data into chunks and process each chunk in parallel.
            cur_data
                .par_chunks_mut(cur_chunk_size)
                .enumerate()
                .for_each(|(block_data_index, chunk)| {
                    for (dim_offset, item) in chunk.iter_mut().enumerate() {
                        *item = adjusted_block_data
                            [block_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
                    }
                });

            cur_pivot_data
                .par_chunks_mut(cur_chunk_size)
                .enumerate()
                .for_each(|(center_index, chunk)| {
                    for (dim_offset, item) in chunk.iter_mut().enumerate() {
                        *item = full_pivot_data
                            [center_index * dim + chunk_offsets[chunk_index] + dim_offset];
                    }
                });

            // Compute the closest centers
            compute_closest_centers(
                &cur_data,
                cur_block_size,
                cur_chunk_size,
                &cur_pivot_data,
                num_centers,
                1,
                &mut closest_center,
                None,
                None,
            )?;

            block_compressed_base
                .par_chunks_mut(num_pq_chunks)
                .enumerate()
                .for_each(|(block_data_index, slice)| {
                    slice[chunk_index] = closest_center[block_data_index] as usize;
                });
        }

        _ = pq_storage.write_compressed_pivot_data(
            &block_compressed_base,
            num_centers,
            cur_block_size,
            num_pq_chunks,
        );
    }
    Ok(())
}
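// Added illustration (not in the original source): with num_points = 5 and
// num_pq_chunks = 2, the compressed file written above lays out as
//
//     [5][2][c(0,0) c(0,1)][c(1,0) c(1,1)] ... [c(4,0) c(4,1)]
//
// where c(p, k) is the index of the closest center of point p in chunk k,
// stored as one byte per entry whenever num_centers fits in a u8.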

/// Generate quantized (PQ-compressed) data and save it to file.
/// # Arguments
/// * `p_val` - ratio of the dataset to sample as training data for pivot generation
/// * `num_pq_chunks` - number of PQ chunks
/// * `codebook_prefix` - path of a predefined pivots (codebook) file
/// * `pq_storage` - PQ file accessor
pub fn generate_quantized_data<T: Default + Copy + Into<f32>>(
    p_val: f64,
    num_pq_chunks: usize,
    codebook_prefix: &str,
    pq_storage: &mut PQStorage,
) -> ANNResult<()> {
    // If predefined pivots already exist, skip training.
    if !file_exists(codebook_prefix) {
        // Instantiate train data from a random sample; updates train_data_vector.
        // Training data with train_size samples loaded.
        // Each sampled vector has train_dim dimensions.
        let (mut train_data_vector, train_size, train_dim) =
            pq_storage.gen_random_slice::<T>(p_val)?;

        generate_pq_pivots(
            &mut train_data_vector,
            train_size,
            train_dim,
            NUM_PQ_CENTROIDS,
            num_pq_chunks,
            NUM_KMEANS_REPS_PQ,
            pq_storage,
        )?;
    }
    generate_pq_data_from_pivots::<T>(NUM_PQ_CENTROIDS, num_pq_chunks, pq_storage)?;
    Ok(())
}
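// Usage sketch (added note): a minimal end-to-end run of the quantization
// pipeline above. The three file names are hypothetical.
//
//     let mut pq_storage = PQStorage::new(
//         "index_pq_pivots.bin",      // pivot table output
//         "index_pq_compressed.bin",  // compressed vectors output
//         "base_data.bin",            // input dataset
//     )?;
//     // Sample half of the points for training and quantize into 16 chunks.
//     generate_quantized_data::<f32>(0.5, 16, "index_pq_pivots.bin", &mut pq_storage)?;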

#[cfg(test)]
mod pq_test {

    use std::fs::File;
    use std::io::Write;

    use super::*;
    use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, METADATA_SIZE};

    #[test]
    fn generate_pq_pivots_test() {
        let pivot_file_name = "generate_pq_pivots_test.bin";
        let compressed_file_name = "compressed.bin";
        let pq_training_file_name = "tests/data/siftsmall_learn.bin";
        let mut pq_storage =
            PQStorage::new(pivot_file_name, compressed_file_name, pq_training_file_name).unwrap();
        let mut train_data: Vec<f32> = vec![
            1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
            2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
            2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
            100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
        ];
        generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap();

        let (data, nr, nc) = load_bin::<u64>(pivot_file_name, 0).unwrap();
        let file_offset_data = convert_types_u64_usize(&data, nr, nc);
        assert_eq!(file_offset_data[0], METADATA_SIZE);
        assert_eq!(nr, 4);
        assert_eq!(nc, 1);

        let (data, nr, nc) = load_bin::<f32>(pivot_file_name, file_offset_data[0]).unwrap();
        let full_pivot_data = data.to_vec();
        assert_eq!(full_pivot_data.len(), 16);
        assert_eq!(nr, 2);
        assert_eq!(nc, 8);

        let (data, nr, nc) = load_bin::<f32>(pivot_file_name, file_offset_data[1]).unwrap();
        let centroid = data.to_vec();
        assert_eq!(
            centroid[0],
            (1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32
        );
        assert_eq!(nr, 8);
        assert_eq!(nc, 1);

        let (data, nr, nc) = load_bin::<u32>(pivot_file_name, file_offset_data[2]).unwrap();
        let chunk_offsets = convert_types_u32_usize(&data, nr, nc);
        assert_eq!(chunk_offsets[0], 0);
        assert_eq!(chunk_offsets[1], 4);
        assert_eq!(chunk_offsets[2], 8);
        assert_eq!(nr, 3);
        assert_eq!(nc, 1);
        std::fs::remove_file(pivot_file_name).unwrap();
    }

    #[test]
    fn generate_pq_data_from_pivots_test() {
        let data_file = "generate_pq_data_from_pivots_test_data.bin";
        // npoints=5, dim=8, 5 vectors [1.0;8] [2.0;8] [2.1;8] [2.2;8] [100.0;8]
        let mut train_data: Vec<f32> = vec![
            1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
            2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
            2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
            100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
        ];
        let my_nums_unstructured: &[u8] = unsafe {
            std::slice::from_raw_parts(train_data.as_ptr() as *const u8, train_data.len() * 4)
        };
        let meta: Vec<i32> = vec![5, 8];
        let meta_unstructured: &[u8] =
            unsafe { std::slice::from_raw_parts(meta.as_ptr() as *const u8, meta.len() * 4) };
        let mut data_file_writer = File::create(data_file).unwrap();
        data_file_writer
            .write_all(meta_unstructured)
            .expect("Failed to write sample file");
        data_file_writer
            .write_all(my_nums_unstructured)
            .expect("Failed to write sample file");

        let pq_pivots_path = "generate_pq_data_from_pivots_test_pivot.bin";
        let pq_compressed_vectors_path = "generate_pq_data_from_pivots_test.bin";
        let mut pq_storage =
            PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap();
        generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap();
        generate_pq_data_from_pivots::<f32>(2, 2, &mut pq_storage).unwrap();
        let (data, nr, nc) = load_bin::<u8>(pq_compressed_vectors_path, 0).unwrap();
        assert_eq!(nr, 5);
        assert_eq!(nc, 2);
        assert_eq!(data[0], data[2]);
        assert_ne!(data[0], data[8]);

        std::fs::remove_file(data_file).unwrap();
        std::fs::remove_file(pq_pivots_path).unwrap();
        std::fs::remove_file(pq_compressed_vectors_path).unwrap();
    }

    #[test]
    fn pq_end_to_end_validation_with_codebook_test() {
        let data_file = "tests/data/siftsmall_learn.bin";
        let pq_pivots_path = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
        let ground_truth_path = "tests/data/siftsmall_learn.bin_pq_compressed.bin";
        let pq_compressed_vectors_path = "validation.bin";
        let mut pq_storage =
            PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap();
        generate_quantized_data::<f32>(0.5, 1, pq_pivots_path, &mut pq_storage).unwrap();

        let (data, nr, nc) = load_bin::<u8>(pq_compressed_vectors_path, 0).unwrap();
        let (gt_data, gt_nr, gt_nc) = load_bin::<u8>(ground_truth_path, 0).unwrap();
        assert_eq!(nr, gt_nr);
        assert_eq!(nc, gt_nc);
        for i in 0..data.len() {
            assert_eq!(data[i], gt_data[i]);
        }
        std::fs::remove_file(pq_compressed_vectors_path).unwrap();
    }
}
@@ -0,0 +1,312 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Concurrent queue

use std::collections::VecDeque;
use std::ops::Deref;
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::time::Duration;

use crate::common::{ANNError, ANNResult};

#[derive(Debug)]
/// A thread-safe FIFO queue
pub struct ConcurrentQueue<T> {
    q: Mutex<VecDeque<T>>,
    c: Mutex<bool>,
    push_cv: Condvar,
}

impl Default for ConcurrentQueue<usize> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> ConcurrentQueue<T> {
    /// Create a concurrent queue
    pub fn new() -> Self {
        Self {
            q: Mutex::new(VecDeque::new()),
            c: Mutex::new(false),
            push_cv: Condvar::new(),
        }
    }

    /// Reserve capacity for at least `size` more elements. Blocks the current
    /// thread until it is able to acquire the mutex.
    pub fn reserve(&self, size: usize) -> ANNResult<()> {
        let mut guard = lock(&self.q)?;
        guard.reserve(size);
        Ok(())
    }

    /// Number of elements currently in the queue
    pub fn size(&self) -> ANNResult<usize> {
        let guard = lock(&self.q)?;

        Ok(guard.len())
    }

    /// Whether the queue is empty
    pub fn is_empty(&self) -> ANNResult<bool> {
        Ok(self.size()? == 0)
    }

    /// push back
    pub fn push(&self, new_val: T) -> ANNResult<()> {
        let mut guard = lock(&self.q)?;
        self.push_internal(&mut guard, new_val);
        self.push_cv.notify_all();
        Ok(())
    }

    /// push back
    fn push_internal(&self, guard: &mut MutexGuard<VecDeque<T>>, new_val: T) {
        guard.push_back(new_val);
    }

    /// insert into queue
    pub fn insert<I>(&self, iter: I) -> ANNResult<()>
    where
        I: IntoIterator<Item = T>,
    {
        let mut guard = lock(&self.q)?;
        for item in iter {
            self.push_internal(&mut guard, item);
        }

        self.push_cv.notify_all();
        Ok(())
    }

    /// pop front
    pub fn pop(&self) -> ANNResult<Option<T>> {
        let mut guard = lock(&self.q)?;
        Ok(guard.pop_front())
    }

    /// Empty the queue. (Is this necessary?)
    pub fn empty_queue(&self) -> ANNResult<()> {
        let mut guard = lock(&self.q)?;
        while !guard.is_empty() {
            let _ = guard.pop_front();
        }
        Ok(())
    }

    /// Wait for a push notification, up to `wait_time`
    pub fn wait_for_push_notify(&self, wait_time: Duration) -> ANNResult<()> {
        let guard_lock = lock(&self.c)?;
        let _ = self
            .push_cv
            .wait_timeout(guard_lock, wait_time)
            .map_err(|err| {
                ANNError::log_lock_poison_error(format!(
                    "ConcurrentQueue Lock is poisoned, err={}",
                    err
                ))
            })?;
        Ok(())
    }
}

fn lock<T>(mutex: &Mutex<T>) -> ANNResult<MutexGuard<T>> {
    let guard = mutex.lock().map_err(|err| {
        ANNError::log_lock_poison_error(format!("ConcurrentQueue lock is poisoned, err={}", err))
    })?;
    Ok(guard)
}

/// A thread-safe queue that holds instances of `T`.
/// Each instance is stored in a `Box` to keep the size of the queue node constant.
#[derive(Debug)]
pub struct ArcConcurrentBoxedQueue<T> {
    internal_queue: Arc<ConcurrentQueue<Box<T>>>,
}

impl<T> ArcConcurrentBoxedQueue<T> {
    /// Create a new `ArcConcurrentBoxedQueue`.
    pub fn new() -> Self {
        Self {
            internal_queue: Arc::new(ConcurrentQueue::new()),
        }
    }
}

impl<T> Default for ArcConcurrentBoxedQueue<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Clone for ArcConcurrentBoxedQueue<T> {
    /// Create a new `ArcConcurrentBoxedQueue` that shares the same internal queue
    /// with the existing one. This allows multiple `ArcConcurrentBoxedQueue` to
    /// operate on the same underlying queue.
    fn clone(&self) -> Self {
        Self {
            internal_queue: Arc::clone(&self.internal_queue),
        }
    }
}

/// Deref to the ConcurrentQueue.
impl<T> Deref for ArcConcurrentBoxedQueue<T> {
    type Target = ConcurrentQueue<Box<T>>;

    fn deref(&self) -> &Self::Target {
        &self.internal_queue
    }
}

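// Sharing sketch (added note): `clone` only bumps the `Arc`, so several threads
// can feed the same pool. A hypothetical worker setup:
//
//     let pool: ArcConcurrentBoxedQueue<Vec<u8>> = ArcConcurrentBoxedQueue::new();
//     for _ in 0..4 {
//         let handle = pool.clone(); // same underlying queue
//         std::thread::spawn(move || {
//             let _ = handle.push(Box::new(vec![0u8; 64]));
//         });
//     }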
#[cfg(test)]
mod tests {
    use crate::model::ConcurrentQueue;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_push_pop() {
        let queue = ConcurrentQueue::<i32>::new();

        queue.push(1).unwrap();
        queue.push(2).unwrap();
        queue.push(3).unwrap();

        assert_eq!(queue.pop().unwrap(), Some(1));
        assert_eq!(queue.pop().unwrap(), Some(2));
        assert_eq!(queue.pop().unwrap(), Some(3));
        assert_eq!(queue.pop().unwrap(), None);
    }

    #[test]
    fn test_size_empty() {
        let queue = ConcurrentQueue::new();

        assert_eq!(queue.size().unwrap(), 0);
        assert!(queue.is_empty().unwrap());

        queue.push(1).unwrap();
        queue.push(2).unwrap();

        assert_eq!(queue.size().unwrap(), 2);
        assert!(!queue.is_empty().unwrap());

        queue.pop().unwrap();
        queue.pop().unwrap();

        assert_eq!(queue.size().unwrap(), 0);
        assert!(queue.is_empty().unwrap());
    }

    #[test]
    fn test_insert() {
        let queue = ConcurrentQueue::new();

        let data = vec![1, 2, 3];
        queue.insert(data.into_iter()).unwrap();

        assert_eq!(queue.pop().unwrap(), Some(1));
        assert_eq!(queue.pop().unwrap(), Some(2));
        assert_eq!(queue.pop().unwrap(), Some(3));
        assert_eq!(queue.pop().unwrap(), None);
    }

    #[test]
    fn test_notifications() {
        let queue = Arc::new(ConcurrentQueue::new());
        let queue_clone = Arc::clone(&queue);

        let producer = thread::spawn(move || {
            for i in 0..3 {
                thread::sleep(Duration::from_millis(50));
                queue_clone.push(i).unwrap();
            }
        });

        let consumer = thread::spawn(move || {
            let mut values = vec![];

            for _ in 0..3 {
                let mut val = -1;
                while val == -1 {
                    queue
                        .wait_for_push_notify(Duration::from_millis(10))
                        .unwrap();
                    val = queue.pop().unwrap().unwrap_or(-1);
                }

                values.push(val);
            }

            values
        });

        producer.join().unwrap();
        let consumer_results = consumer.join().unwrap();

        assert_eq!(consumer_results, vec![0, 1, 2]);
    }

    #[test]
    fn test_multithreaded_push_pop() {
        let queue = Arc::new(ConcurrentQueue::new());
        let queue_clone = Arc::clone(&queue);

        let producer = thread::spawn(move || {
            for i in 0..10 {
                queue_clone.push(i).unwrap();
                thread::sleep(Duration::from_millis(50));
            }
        });

        let consumer = thread::spawn(move || {
            let mut values = vec![];

            for _ in 0..10 {
                let mut val = -1;
                while val == -1 {
                    val = queue.pop().unwrap().unwrap_or(-1);
                    thread::sleep(Duration::from_millis(10));
                }

                values.push(val);
            }

            values
        });

        producer.join().unwrap();
        let consumer_results = consumer.join().unwrap();

        assert_eq!(consumer_results, (0..10).collect::<Vec<_>>());
    }

    /// This is a single-value test. It avoids the unbounded wait in the previous
    /// test until the collection becomes empty.
    /// It makes sure the signal mutex matches the waiting mutex.
    #[test]
    fn test_wait_for_push_notify() {
        let queue = Arc::new(ConcurrentQueue::<usize>::new());
        let queue_clone = Arc::clone(&queue);

        let producer = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            queue_clone.push(1).unwrap();
        });

        let consumer = thread::spawn(move || {
            queue
                .wait_for_push_notify(Duration::from_millis(200))
                .unwrap();
            assert_eq!(queue.pop().unwrap(), Some(1));
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }
}
@@ -0,0 +1,186 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Scratch space for in-memory index based search

use std::cmp::max;
use std::mem;

use hashbrown::HashSet;

use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice};
use crate::model::configuration::index_write_parameters::IndexWriteParameters;
use crate::model::{Neighbor, NeighborPriorityQueue, PQScratch};

use super::Scratch;

/// In-mem index related limits
pub const GRAPH_SLACK_FACTOR: f64 = 1.3_f64;

/// Max number of points for using bitset
pub const MAX_POINTS_FOR_USING_BITSET: usize = 100000;

/// TODO: SSD Index related limits
pub const MAX_GRAPH_DEGREE: usize = 512;

/// TODO: SSD Index related limits
pub const MAX_N_CMPS: usize = 16384;

/// TODO: SSD Index related limits
pub const SECTOR_LEN: usize = 4096;

/// TODO: SSD Index related limits
pub const MAX_N_SECTOR_READS: usize = 128;

/// The alignment required for memory access. This will be multiplied with size of T to get the actual alignment
pub const QUERY_ALIGNMENT_OF_T_SIZE: usize = 16;

/// Scratch space for in-memory index based search
#[derive(Debug)]
pub struct InMemQueryScratch<T, const N: usize> {
    /// Size of the candidate queue
    pub candidate_size: u32,

    /// Max degree for each vertex
    pub max_degree: u32,

    /// Max occlusion size
    pub max_occlusion_size: u32,

    /// Query node
    pub query: AlignedBoxWithSlice<T>,

    /// Best candidates, whose size is candidate_queue_size
    pub best_candidates: NeighborPriorityQueue,

    /// Occlude factor
    pub occlude_factor: Vec<f32>,

    /// Visited neighbor id
    pub id_scratch: Vec<u32>,

    /// The distance between visited neighbor and query node
    pub dist_scratch: Vec<f32>,

    /// The PQ scratch. Keep it private since this struct uses the Box to own the memory; use the pq_scratch function to get a reference
    pub pq_scratch: Option<Box<PQScratch>>,

    /// Buffers used in process delete, capacity increases as needed
    pub expanded_nodes_set: HashSet<u32>,

    /// Expanded neighbors
    pub expanded_neighbors_vector: Vec<Neighbor>,

    /// Occlude list
    pub occlude_list_output: Vec<u32>,

    /// RobinSet for larger dataset
    pub node_visited_robinset: HashSet<u32>,
}

impl<T: Default + Copy, const N: usize> InMemQueryScratch<T, N> {
    /// Create InMemQueryScratch instance
    pub fn new(
        search_candidate_size: u32,
        index_write_parameter: &IndexWriteParameters,
        init_pq_scratch: bool,
    ) -> ANNResult<Self> {
        let indexing_candidate_size = index_write_parameter.search_list_size;
        let max_degree = index_write_parameter.max_degree;
        let max_occlusion_size = index_write_parameter.max_occlusion_size;

        if search_candidate_size == 0 || indexing_candidate_size == 0 || max_degree == 0 || N == 0 {
            return Err(ANNError::log_index_error(format!(
                "In InMemQueryScratch, one of search_candidate_size = {}, indexing_candidate_size = {}, dim = {} or max_degree = {} is zero.",
                search_candidate_size, indexing_candidate_size, N, max_degree)));
        }

        let query = AlignedBoxWithSlice::new(N, mem::size_of::<T>() * QUERY_ALIGNMENT_OF_T_SIZE)?;
        let pq_scratch = if init_pq_scratch {
            Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?))
        } else {
            None
        };

        let occlude_factor = Vec::with_capacity(max_occlusion_size as usize);

        let capacity = (1.5 * GRAPH_SLACK_FACTOR * (max_degree as f64)).ceil() as usize;
        let id_scratch = Vec::with_capacity(capacity);
        let dist_scratch = Vec::with_capacity(capacity);

        let expanded_nodes_set = HashSet::<u32>::new();
        let expanded_neighbors_vector = Vec::<Neighbor>::new();
        let occlude_list_output = Vec::<u32>::new();

        let candidate_size = max(search_candidate_size, indexing_candidate_size);
        let node_visited_robinset = HashSet::<u32>::with_capacity(20 * candidate_size as usize);
        let scratch = Self {
            candidate_size,
            max_degree,
            max_occlusion_size,
            query,
            best_candidates: NeighborPriorityQueue::with_capacity(candidate_size as usize),
            occlude_factor,
            id_scratch,
            dist_scratch,
            pq_scratch,
            expanded_nodes_set,
            expanded_neighbors_vector,
            occlude_list_output,
            node_visited_robinset,
        };

        Ok(scratch)
    }

    /// Resize the scratch with new candidate size
    pub fn resize_for_new_candidate_size(&mut self, new_candidate_size: u32) {
        if new_candidate_size > self.candidate_size {
            let delta = new_candidate_size - self.candidate_size;
            self.candidate_size = new_candidate_size;
            self.best_candidates.reserve(delta as usize);
            self.node_visited_robinset.reserve((20 * delta) as usize);
        }
    }
}

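// Construction sketch (added note, parameter values are illustrative): a scratch
// for 32-dimensional f32 vectors, using the builder exercised in the test below.
//
//     let params = IndexWriteParametersBuilder::new(50, 64)
//         .with_max_occlusion_size(300)
//         .build();
//     let mut scratch = InMemQueryScratch::<f32, 32>::new(100, &params, true)?;
//     scratch.clear(); // reset between queries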
impl<T: Default + Copy, const N: usize> Scratch for InMemQueryScratch<T, N> {
    fn clear(&mut self) {
        self.best_candidates.clear();
        self.occlude_factor.clear();

        self.node_visited_robinset.clear();

        self.id_scratch.clear();
        self.dist_scratch.clear();

        self.expanded_nodes_set.clear();
        self.expanded_neighbors_vector.clear();
        self.occlude_list_output.clear();
    }
}

#[cfg(test)]
mod inmemory_query_scratch_test {
    use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;

    use super::*;

    #[test]
    fn node_visited_robinset_test() {
        let index_write_parameter = IndexWriteParametersBuilder::new(10, 10)
            .with_max_occlusion_size(5)
            .build();

        let mut scratch =
            InMemQueryScratch::<f32, 32>::new(100, &index_write_parameter, false).unwrap();

        assert_eq!(scratch.node_visited_robinset.len(), 0);

        scratch.clear();
        assert_eq!(scratch.node_visited_robinset.len(), 0);
    }
}
28
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/scratch/mod.rs
vendored
Normal file
28
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/scratch/mod.rs
vendored
Normal file
@@ -0,0 +1,28 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub mod scratch_traits;
pub use scratch_traits::*;

pub mod concurrent_queue;
pub use concurrent_queue::*;

pub mod pq_scratch;
pub use pq_scratch::*;

pub mod inmem_query_scratch;
pub use inmem_query_scratch::*;

pub mod scratch_store_manager;
pub use scratch_store_manager::*;

pub mod ssd_query_scratch;
pub use ssd_query_scratch::*;

pub mod ssd_thread_data;
pub use ssd_thread_data::*;

pub mod ssd_io_context;
pub use ssd_io_context::*;
105
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/scratch/pq_scratch.rs
vendored
Normal file
105
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/scratch/pq_scratch.rs
vendored
Normal file
@@ -0,0 +1,105 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! PQ scratch

use std::mem::size_of;

use crate::common::{ANNResult, AlignedBoxWithSlice};

const MAX_PQ_CHUNKS: usize = 512;

#[derive(Debug)]
/// PQ scratch
pub struct PQScratch {
    /// Aligned pq table dist scratch, must be at least [256 * NCHUNKS]
    pub aligned_pqtable_dist_scratch: AlignedBoxWithSlice<f32>,
    /// Aligned dist scratch, must be at least diskann MAX_DEGREE
    pub aligned_dist_scratch: AlignedBoxWithSlice<f32>,
    /// Aligned pq coord scratch, must be at least [N_CHUNKS * MAX_DEGREE]
    pub aligned_pq_coord_scratch: AlignedBoxWithSlice<u8>,
    /// Rotated query
    pub rotated_query: AlignedBoxWithSlice<f32>,
    /// Aligned query float
    pub aligned_query_float: AlignedBoxWithSlice<f32>,
}

impl PQScratch {
    const ALIGNED_ALLOC_256: usize = 256;

    /// Create a new pq scratch
    pub fn new(graph_degree: usize, aligned_dim: usize) -> ANNResult<Self> {
        let aligned_pq_coord_scratch =
            AlignedBoxWithSlice::new(graph_degree * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?;
        let aligned_pqtable_dist_scratch =
            AlignedBoxWithSlice::new(256 * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?;
        let aligned_dist_scratch =
            AlignedBoxWithSlice::new(graph_degree, PQScratch::ALIGNED_ALLOC_256)?;
        let aligned_query_float = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::<f32>())?;
        let rotated_query = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::<f32>())?;

        Ok(Self {
            aligned_pqtable_dist_scratch,
            aligned_dist_scratch,
            aligned_pq_coord_scratch,
            rotated_query,
            aligned_query_float,
        })
    }

    /// Set rotated_query and aligned_query_float values
    pub fn set<T>(&mut self, dim: usize, query: &[T], norm: f32)
    where
        T: Into<f32> + Copy,
    {
        for (d, item) in query.iter().enumerate().take(dim) {
            let query_val: f32 = (*item).into();
            if (norm - 1.0).abs() > f32::EPSILON {
                self.rotated_query[d] = query_val / norm;
                self.aligned_query_float[d] = query_val / norm;
            } else {
                self.rotated_query[d] = query_val;
                self.aligned_query_float[d] = query_val;
            }
        }
    }
}

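// Worked example (added note): with dim = 2, query = [3.0f32, 4.0] and
// norm = 5.0, `set` writes [0.6, 0.8] into both rotated_query and
// aligned_query_float; with norm = 1.0 the values are copied through unchanged.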
#[cfg(test)]
mod tests {
    use crate::model::PQScratch;

    #[test]
    fn test_pq_scratch() {
        let graph_degree = 512;
        let aligned_dim = 8;

        let mut pq_scratch: PQScratch = PQScratch::new(graph_degree, aligned_dim).unwrap();

        // Check alignment
        assert_eq!(
            (pq_scratch.aligned_pqtable_dist_scratch.as_ptr() as usize) % 256,
            0
        );
        assert_eq!((pq_scratch.aligned_dist_scratch.as_ptr() as usize) % 256, 0);
        assert_eq!(
            (pq_scratch.aligned_pq_coord_scratch.as_ptr() as usize) % 256,
            0
        );
        assert_eq!((pq_scratch.rotated_query.as_ptr() as usize) % 32, 0);
        assert_eq!((pq_scratch.aligned_query_float.as_ptr() as usize) % 32, 0);

        // Test set() method
        let query = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
        let norm = 2.0f32;
        pq_scratch.set::<u8>(query.len(), &query, norm);

        (0..query.len()).for_each(|i| {
            assert_eq!(pq_scratch.rotated_query[i], query[i] as f32 / norm);
            assert_eq!(pq_scratch.aligned_query_float[i], query[i] as f32 / norm);
        });
    }
}
@@ -0,0 +1,84 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use crate::common::ANNResult;

use super::ArcConcurrentBoxedQueue;
use super::scratch_traits::Scratch;
use std::time::Duration;

pub struct ScratchStoreManager<T: Scratch> {
    scratch: Option<Box<T>>,
    scratch_pool: ArcConcurrentBoxedQueue<T>,
}

impl<T: Scratch> ScratchStoreManager<T> {
    pub fn new(scratch_pool: ArcConcurrentBoxedQueue<T>, wait_time: Duration) -> ANNResult<Self> {
        let mut scratch = scratch_pool.pop()?;
        while scratch.is_none() {
            scratch_pool.wait_for_push_notify(wait_time)?;
            scratch = scratch_pool.pop()?;
        }

        Ok(ScratchStoreManager {
            scratch,
            scratch_pool,
        })
    }

    pub fn scratch_space(&mut self) -> Option<&mut T> {
        self.scratch.as_deref_mut()
    }
}

impl<T: Scratch> Drop for ScratchStoreManager<T> {
    fn drop(&mut self) {
        if let Some(mut scratch) = self.scratch.take() {
            scratch.clear();
            let _ = self.scratch_pool.push(scratch);
        }
    }
}

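// Design note (added): dropping the manager clears the scratch and pushes it
// back into the pool, so a pool entry is checked out for exactly the lifetime
// of its ScratchStoreManager; the test below verifies the round trip.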
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug)]
    struct MyScratch {
        data: Vec<i32>,
    }

    impl Scratch for MyScratch {
        fn clear(&mut self) {
            self.data.clear();
        }
    }

    #[test]
    fn test_scratch_store_manager() {
        let wait_time = Duration::from_millis(100);

        let scratch_pool = ArcConcurrentBoxedQueue::new();
        for i in 1..3 {
            scratch_pool
                .push(Box::new(MyScratch {
                    data: vec![i, 2 * i, 3 * i],
                }))
                .unwrap();
        }

        let mut manager = ScratchStoreManager::new(scratch_pool.clone(), wait_time).unwrap();
        let scratch_space = manager.scratch_space().unwrap();

        assert_eq!(scratch_space.data, vec![1, 2, 3]);

        // At this point, the ScratchStoreManager will go out of scope,
        // causing the Drop implementation to be called, which should
        // call the clear method on MyScratch.
        drop(manager);

        let current_scratch = scratch_pool.pop().unwrap().unwrap();
        assert_eq!(current_scratch.data, vec![2, 4, 6]);
    }
}
@@ -0,0 +1,8 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub trait Scratch {
    fn clear(&mut self);
}
@@ -0,0 +1,38 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete.
use crate::common::ANNError;

use platform::{FileHandle, IOCompletionPort};

// The IOContext struct for disk I/O. One for each thread.
pub struct IOContext {
    pub status: Status,
    pub file_handle: FileHandle,
    pub io_completion_port: IOCompletionPort,
}

impl Default for IOContext {
    fn default() -> Self {
        IOContext {
            status: Status::ReadWait,
            file_handle: FileHandle::default(),
            io_completion_port: IOCompletionPort::default(),
        }
    }
}

impl IOContext {
    pub fn new() -> Self {
        Self::default()
    }
}

pub enum Status {
    ReadWait,
    ReadSuccess,
    ReadFailed(ANNError),
    ProcessComplete,
}
@@ -0,0 +1,132 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete.
use std::mem;
use std::vec::Vec;

use hashbrown::HashSet;

use crate::{
    common::{ANNResult, AlignedBoxWithSlice},
    model::{Neighbor, NeighborPriorityQueue},
    model::data_store::DiskScratchDataset,
};

use super::{PQScratch, Scratch, MAX_GRAPH_DEGREE, QUERY_ALIGNMENT_OF_T_SIZE};

// Scratch space for disk index based search.
pub struct SSDQueryScratch<T: Default + Copy, const N: usize> {
    // Disk scratch dataset storing fp vectors with aligned dim (N)
    pub scratch_dataset: DiskScratchDataset<T, N>,

    // The query scratch.
    pub query: AlignedBoxWithSlice<T>,

    /// The PQ Scratch.
    pub pq_scratch: Option<Box<PQScratch>>,

    // The visited set.
    pub id_scratch: HashSet<u32>,

    /// Best candidates, whose size is candidate_queue_size
    pub best_candidates: NeighborPriorityQueue,

    // Full return set.
    pub full_return_set: Vec<Neighbor>,
}

impl<T: Copy + Default, const N: usize> SSDQueryScratch<T, N> {
    pub fn new(
        visited_reserve: usize,
        candidate_queue_size: usize,
        init_pq_scratch: bool,
    ) -> ANNResult<Self> {
        let scratch_dataset = DiskScratchDataset::<T, N>::new()?;

        let query = AlignedBoxWithSlice::<T>::new(N, mem::size_of::<T>() * QUERY_ALIGNMENT_OF_T_SIZE)?;

        let id_scratch = HashSet::<u32>::with_capacity(visited_reserve);
        let full_return_set = Vec::<Neighbor>::with_capacity(visited_reserve);
        let best_candidates = NeighborPriorityQueue::with_capacity(candidate_queue_size);

        let pq_scratch = if init_pq_scratch {
            Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?))
        } else {
            None
        };

        Ok(Self {
            scratch_dataset,
            query,
            pq_scratch,
            id_scratch,
            best_candidates,
            full_return_set,
        })
    }

    pub fn pq_scratch(&mut self) -> &Option<Box<PQScratch>> {
        &self.pq_scratch
    }
}

impl<T: Default + Copy, const N: usize> Scratch for SSDQueryScratch<T, N> {
    fn clear(&mut self) {
        self.id_scratch.clear();
        self.best_candidates.clear();
        self.full_return_set.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        // Arrange
        let visited_reserve = 100;
        let candidate_queue_size = 10;
        let init_pq_scratch = true;

        // Act
        let result =
            SSDQueryScratch::<u32, 3>::new(visited_reserve, candidate_queue_size, init_pq_scratch);

        // Assert
        assert!(result.is_ok());

        let scratch = result.unwrap();

        // Assert the properties of the scratch instance
        assert!(scratch.pq_scratch.is_some());
        assert!(scratch.id_scratch.is_empty());
        assert!(scratch.best_candidates.size() == 0);
        assert!(scratch.full_return_set.is_empty());
    }

    #[test]
    fn test_clear() {
        // Arrange
        let mut scratch = SSDQueryScratch::<u32, 3>::new(100, 10, true).unwrap();

        // Add some data to scratch fields
        scratch.id_scratch.insert(1);
        scratch.best_candidates.insert(Neighbor::new(2, 0.5));
        scratch.full_return_set.push(Neighbor::new(3, 0.8));

        // Act
        scratch.clear();

        // Assert
        assert!(scratch.id_scratch.is_empty());
        assert!(scratch.best_candidates.size() == 0);
        assert!(scratch.full_return_set.is_empty());
    }
}
@@ -0,0 +1,92 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete.
use std::sync::Arc;

use super::{scratch_traits::Scratch, IOContext, SSDQueryScratch};
use crate::common::ANNResult;

// The thread data struct for SSD I/O. One for each thread, contains the ScratchSpace and the IOContext.
pub struct SSDThreadData<T: Default + Copy, const N: usize> {
    pub scratch: SSDQueryScratch<T, N>,
    pub io_context: Option<Arc<IOContext>>,
}

impl<T: Default + Copy, const N: usize> SSDThreadData<T, N> {
    pub fn new(
        aligned_dim: usize,
        visited_reserve: usize,
        init_pq_scratch: bool,
    ) -> ANNResult<Self> {
        let scratch = SSDQueryScratch::new(aligned_dim, visited_reserve, init_pq_scratch)?;
        Ok(SSDThreadData {
            scratch,
            io_context: None,
        })
    }

    pub fn clear(&mut self) {
        self.scratch.clear();
    }
}

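// Per-thread sketch (added note): each search thread owns one SSDThreadData and
// attaches an IOContext lazily, e.g. (reader is a hypothetical
// WindowsAlignedFileReader shared across threads):
//
//     let mut data = SSDThreadData::<f32, 128>::new(128, 4096, true)?;
//     data.io_context = Some(reader.get_ctx()?);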
#[cfg(test)]
mod tests {
    use crate::model::Neighbor;

    use super::*;

    #[test]
    fn test_new() {
        // Arrange
        let aligned_dim = 10;
        let visited_reserve = 100;
        let init_pq_scratch = true;

        // Act
        let result = SSDThreadData::<u32, 3>::new(aligned_dim, visited_reserve, init_pq_scratch);

        // Assert
        assert!(result.is_ok());

        let thread_data = result.unwrap();

        // Assert the properties of the thread data instance
        assert!(thread_data.io_context.is_none());

        let scratch = &thread_data.scratch;
        // Assert the properties of the scratch instance
        assert!(scratch.pq_scratch.is_some());
        assert!(scratch.id_scratch.is_empty());
        assert!(scratch.best_candidates.size() == 0);
        assert!(scratch.full_return_set.is_empty());
    }

    #[test]
    fn test_clear() {
        // Arrange
        let mut thread_data = SSDThreadData::<u32, 3>::new(10, 100, true).unwrap();

        // Add some data to scratch fields
        thread_data.scratch.id_scratch.insert(1);
        thread_data
            .scratch
            .best_candidates
            .insert(Neighbor::new(2, 0.5));
        thread_data
            .scratch
            .full_return_set
            .push(Neighbor::new(3, 0.8));

        // Act
        thread_data.clear();

        // Assert
        assert!(thread_data.scratch.id_scratch.is_empty());
        assert!(thread_data.scratch.best_candidates.size() == 0);
        assert!(thread_data.scratch.full_return_set.is_empty());
    }
}
22
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/vertex/dimension.rs
vendored
Normal file
22
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/vertex/dimension.rs
vendored
Normal file
@@ -0,0 +1,22 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Vertex dimension

/// 32 vertex dimension
pub const DIM_32: usize = 32;

/// 64 vertex dimension
pub const DIM_64: usize = 64;

/// 104 vertex dimension
pub const DIM_104: usize = 104;

/// 128 vertex dimension
pub const DIM_128: usize = 128;

/// 256 vertex dimension
pub const DIM_256: usize = 256;
10
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/vertex/mod.rs
vendored
Normal file
10
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/vertex/mod.rs
vendored
Normal file
@@ -0,0 +1,10 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod vertex;
pub use vertex::Vertex;

mod dimension;
pub use dimension::*;
68
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/vertex/vertex.rs
vendored
Normal file
68
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/model/vertex/vertex.rs
vendored
Normal file
@@ -0,0 +1,68 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Vertex

use std::array::TryFromSliceError;

use vector::{FullPrecisionDistance, Metric};

/// Vertex with data type T and dimension N
#[derive(Debug)]
pub struct Vertex<'a, T, const N: usize>
where
    [T; N]: FullPrecisionDistance<T, N>,
{
    /// Vertex value
    val: &'a [T; N],

    /// Vertex Id
    id: u32,
}

impl<'a, T, const N: usize> Vertex<'a, T, N>
where
    [T; N]: FullPrecisionDistance<T, N>,
{
    /// Create the vertex with data
    pub fn new(val: &'a [T; N], id: u32) -> Self {
        Self { val, id }
    }

    /// Compare the vertex with another.
    #[inline(always)]
    pub fn compare(&self, other: &Vertex<'a, T, N>, metric: Metric) -> f32 {
        <[T; N]>::distance_compare(self.val, other.val, metric)
    }

    /// Get the vector associated with the vertex.
    #[inline]
    pub fn vector(&self) -> &[T; N] {
        self.val
    }

    /// Get the vertex id.
    #[inline]
    pub fn vertex_id(&self) -> u32 {
        self.id
    }
}

impl<'a, T, const N: usize> TryFrom<(&'a [T], u32)> for Vertex<'a, T, N>
where
    [T; N]: FullPrecisionDistance<T, N>,
{
    type Error = TryFromSliceError;

    fn try_from((mem_slice, id): (&'a [T], u32)) -> Result<Self, Self::Error> {
        let array: &[T; N] = mem_slice.try_into()?;
        Ok(Vertex::new(array, id))
    }
}

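// Conversion sketch (added note): building vertices from raw slices and
// comparing them, assuming [f32; 32] implements FullPrecisionDistance.
//
//     let a = [0.0f32; 32];
//     let b = [1.0f32; 32];
//     let va = Vertex::<f32, 32>::try_from((&a[..], 0))?;
//     let vb = Vertex::<f32, 32>::try_from((&b[..], 1))?;
//     let dist = va.compare(&vb, Metric::L2);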
@@ -0,0 +1,7 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#[allow(clippy::module_inception)]
mod windows_aligned_file_reader;
pub use windows_aligned_file_reader::*;
@@ -0,0 +1,414 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::sync::Arc;
use std::time::Duration;
use std::{ptr, thread};

use crossbeam::sync::ShardedLock;
use hashbrown::HashMap;
use once_cell::sync::Lazy;

use platform::file_handle::{AccessMode, ShareMode};
use platform::{
    file_handle::FileHandle,
    file_io::{get_queued_completion_status, read_file_to_slice},
    io_completion_port::IOCompletionPort,
};

use winapi::{
    shared::{basetsd::ULONG_PTR, minwindef::DWORD},
    um::minwinbase::OVERLAPPED,
};

use crate::common::{ANNError, ANNResult};
use crate::model::IOContext;

pub const MAX_IO_CONCURRENCY: usize = 128; // TODO: explore the optimal value for this. The current value is taken from the C++ code.
pub const FILE_ATTRIBUTE_READONLY: DWORD = 0x00000001;
pub const IO_COMPLETION_TIMEOUT: DWORD = u32::MAX; // Infinite timeout.
pub const DISK_IO_ALIGNMENT: usize = 512;
pub const ASYNC_IO_COMPLETION_CHECK_INTERVAL: Duration = Duration::from_micros(5);

/// Aligned read request for disk IO. It borrows the aligned buffer mutably for
/// the duration of the read and exposes the data immutably via `aligned_buf`.
pub struct AlignedRead<'a, T> {
    /// where to read from
    /// offset needs to be aligned with DISK_IO_ALIGNMENT
    offset: u64,

    /// where to read into
    /// aligned_buf and its len need to be aligned with DISK_IO_ALIGNMENT
    aligned_buf: &'a mut [T],
}

impl<'a, T> AlignedRead<'a, T> {
    pub fn new(offset: u64, aligned_buf: &'a mut [T]) -> ANNResult<Self> {
        Self::assert_is_aligned(offset as usize)?;
        Self::assert_is_aligned(std::mem::size_of_val(aligned_buf))?;

        Ok(Self {
            offset,
            aligned_buf,
        })
    }

    fn assert_is_aligned(val: usize) -> ANNResult<()> {
        match val % DISK_IO_ALIGNMENT {
            0 => Ok(()),
            _ => Err(ANNError::log_disk_io_request_alignment_error(format!(
                "The offset or length of AlignedRead request is not {} bytes aligned",
                DISK_IO_ALIGNMENT
            ))),
        }
    }

    pub fn aligned_buf(&self) -> &[T] {
        self.aligned_buf
    }
}

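// Request-building sketch (added note): offsets and buffer lengths must both be
// multiples of DISK_IO_ALIGNMENT (512). This mirrors the helpers used in the
// tests at the bottom of this file.
//
//     let mut mem = AlignedBoxWithSlice::<u8>::new(512 * 4, 512)?;
//     let mut slices = mem.split_into_nonoverlapping_mut_slices(0..mem.len(), 512)?;
//     let reads: Vec<AlignedRead<u8>> = slices
//         .iter_mut()
//         .enumerate()
//         .map(|(i, s)| AlignedRead::new((i * 512) as u64, s))
//         .collect::<Result<_, _>>()?;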

pub struct WindowsAlignedFileReader {
    file_name: String,

    // ctx_map is the mapping from thread id to io context. It is a hashmap behind a sharded lock to allow concurrent access from multiple threads.
    // ShardedLock provides an implementation of a reader-writer lock that offers concurrent read access to the shared data while allowing exclusive write access.
    // It achieves better scalability by dividing the shared data into multiple shards, each with its own internal lock.
    // Multiple threads can read from different shards simultaneously, reducing contention.
    // https://docs.rs/crossbeam/0.8.2/crossbeam/sync/struct.ShardedLock.html
    // Compared to RwLock, ShardedLock provides higher concurrency for read operations and is suitable for read-heavy workloads.
    // The value of the hashmap is an Arc<IOContext> to allow immutable access to IOContext with automatic reference counting.
    ctx_map: Lazy<ShardedLock<HashMap<thread::ThreadId, Arc<IOContext>>>>,
}

impl WindowsAlignedFileReader {
    pub fn new(fname: &str) -> ANNResult<Self> {
        let reader: WindowsAlignedFileReader = WindowsAlignedFileReader {
            file_name: fname.to_string(),
            ctx_map: Lazy::new(|| ShardedLock::new(HashMap::new())),
        };

        reader.register_thread()?;
        Ok(reader)
    }

    // Register the io context for a thread if it hasn't been registered.
    pub fn register_thread(&self) -> ANNResult<()> {
        let mut ctx_map = self.ctx_map.write().map_err(|_| {
            ANNError::log_lock_poison_error("unable to acquire write lock on ctx_map".to_string())
        })?;

        let id = thread::current().id();
        if ctx_map.contains_key(&id) {
            println!(
                "Warning: duplicate registration for thread_id {:?}. Call get_ctx directly to get the thread context data.",
                id);

            return Ok(());
        }

        let mut ctx = IOContext::new();

        match unsafe { FileHandle::new(&self.file_name, AccessMode::Read, ShareMode::Read) } {
            Ok(file_handle) => ctx.file_handle = file_handle,
            Err(err) => {
                return Err(ANNError::log_io_error(err));
            }
        }

        // Create an io completion port for the file handle; later it will be used to get the completion status.
        match IOCompletionPort::new(&ctx.file_handle, None, 0, 0) {
            Ok(io_completion_port) => ctx.io_completion_port = io_completion_port,
            Err(err) => {
                return Err(ANNError::log_io_error(err));
            }
        }

        ctx_map.insert(id, Arc::new(ctx));

        Ok(())
    }

    // Get the reference counted io context for the current thread.
    pub fn get_ctx(&self) -> ANNResult<Arc<IOContext>> {
        let ctx_map = self.ctx_map.read().map_err(|_| {
            ANNError::log_lock_poison_error("unable to acquire read lock on ctx_map".to_string())
        })?;

        let id = thread::current().id();
        match ctx_map.get(&id) {
            Some(ctx) => Ok(Arc::clone(ctx)),
            None => Err(ANNError::log_index_error(format!(
                "unable to find IOContext for thread_id {:?}",
                id
            ))),
        }
    }

    // Read the data from the file by sending concurrent io requests in batches.
    pub fn read<T>(&self, read_requests: &mut [AlignedRead<T>], ctx: &IOContext) -> ANNResult<()> {
        let n_requests = read_requests.len();
        let n_batches = (n_requests + MAX_IO_CONCURRENCY - 1) / MAX_IO_CONCURRENCY;

        let mut overlapped_in_out =
            vec![unsafe { std::mem::zeroed::<OVERLAPPED>() }; MAX_IO_CONCURRENCY];

        for batch_idx in 0..n_batches {
            let batch_start = MAX_IO_CONCURRENCY * batch_idx;
            let batch_size = std::cmp::min(n_requests - batch_start, MAX_IO_CONCURRENCY);

            for j in 0..batch_size {
                let req = &mut read_requests[batch_start + j];
                let os = &mut overlapped_in_out[j];

                match unsafe {
                    read_file_to_slice(&ctx.file_handle, req.aligned_buf, os, req.offset)
                } {
                    Ok(_) => {}
                    Err(error) => {
                        return Err(ANNError::IOError { err: (error) });
                    }
                }
            }

            let mut n_read: DWORD = 0;
            let mut n_complete: u64 = 0;
            let mut completion_key: ULONG_PTR = 0;
            let mut lp_os: *mut OVERLAPPED = ptr::null_mut();
            while n_complete < batch_size as u64 {
                match unsafe {
                    get_queued_completion_status(
                        &ctx.io_completion_port,
                        &mut n_read,
                        &mut completion_key,
                        &mut lp_os,
                        IO_COMPLETION_TIMEOUT,
                    )
                } {
                    // An IO request completed.
                    Ok(true) => n_complete += 1,
                    // No IO request completed, continue to wait.
                    Ok(false) => {
                        thread::sleep(ASYNC_IO_COMPLETION_CHECK_INTERVAL);
                    }
                    // An error occurred.
                    Err(error) => return Err(ANNError::IOError { err: (error) }),
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{fs::File, io::BufReader};
|
||||
|
||||
use bincode::deserialize_from;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{common::AlignedBoxWithSlice, model::SECTOR_LEN};
|
||||
|
||||
use super::*;
|
||||
pub const TEST_INDEX_PATH: &str =
|
||||
"./tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index";
|
||||
pub const TRUTH_NODE_DATA_PATH: &str =
|
||||
"./tests/data/disk_index_node_data_aligned_reader_truth.bin";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NodeData {
|
||||
num_neighbors: u32,
|
||||
coordinates: Vec<f32>,
|
||||
neighbors: Vec<u32>,
|
||||
}
|
||||
|
||||
impl PartialEq for NodeData {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.num_neighbors == other.num_neighbors
|
||||
&& self.coordinates == other.coordinates
|
||||
&& self.neighbors == other.neighbors
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_aligned_file_reader() {
|
||||
// Replace "test_file_path" with actual file path
|
||||
let result = WindowsAlignedFileReader::new(TEST_INDEX_PATH);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let reader = result.unwrap();
|
||||
assert_eq!(reader.file_name, TEST_INDEX_PATH);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read() {
|
||||
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
|
||||
let ctx = reader.get_ctx().unwrap();
|
||||
|
||||
let read_length = 512; // adjust according to your logic
|
||||
let num_read = 10;
|
||||
let mut aligned_mem = AlignedBoxWithSlice::<u8>::new(read_length * num_read, 512).unwrap();
|
||||
|
||||
// create and add AlignedReads to the vector
|
||||
let mut mem_slices = aligned_mem
|
||||
.split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length)
|
||||
.unwrap();
|
||||
|
||||
let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(i, slice)| {
|
||||
let offset = (i * read_length) as u64;
|
||||
AlignedRead::new(offset, slice).unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let result = reader.read(&mut aligned_reads, &ctx);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_disk_index_by_sector() {
|
||||
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
|
||||
let ctx = reader.get_ctx().unwrap();
|
||||
|
||||
let read_length = SECTOR_LEN; // adjust according to your logic
|
||||
let num_sector = 10;
|
||||
let mut aligned_mem =
|
||||
AlignedBoxWithSlice::<u8>::new(read_length * num_sector, 512).unwrap();
|
||||
|
||||
// Each slice will be used as the buffer for a read request of a sector.
|
||||
let mut mem_slices = aligned_mem
|
||||
.split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length)
|
||||
.unwrap();
|
||||
|
||||
let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.map(|(sector_id, slice)| {
|
||||
let offset = (sector_id * read_length) as u64;
|
||||
AlignedRead::new(offset, slice).unwrap()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let result = reader.read(&mut aligned_reads, &ctx);
|
||||
assert!(result.is_ok());
|
||||
|
||||
        aligned_reads.iter().for_each(|read| {
            assert_eq!(read.aligned_buf.len(), SECTOR_LEN);
        });

        let disk_layout_meta = reconstruct_disk_meta(aligned_reads[0].aligned_buf);
        assert!(disk_layout_meta.len() > 9);

        let dims = disk_layout_meta[1];
        let num_pts = disk_layout_meta[0];
        let max_node_len = disk_layout_meta[3];
        let max_num_nodes_per_sector = disk_layout_meta[4];

        assert!(max_node_len * max_num_nodes_per_sector < SECTOR_LEN as u64);

        let num_nbrs_start = (dims as usize) * std::mem::size_of::<f32>();
        let nbrs_buf_start = num_nbrs_start + std::mem::size_of::<u32>();

        let mut node_data_array = Vec::with_capacity(max_num_nodes_per_sector as usize * 9);

        // Only validate sectors 1..=8, which hold graph nodes (sector 0 holds the metadata).
        (1..9).for_each(|sector_id| {
            let sector_data = &mem_slices[sector_id];
            for node_data in sector_data.chunks_exact(max_node_len as usize) {
                // Extract the coordinate data from the start of the node_data
                let coordinates_end = (dims as usize) * std::mem::size_of::<f32>();
                let coordinates = node_data[0..coordinates_end]
                    .chunks_exact(std::mem::size_of::<f32>())
                    .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
                    .collect();

                // Extract the number of neighbors from the node_data
                let neighbors_num = u32::from_le_bytes(
                    node_data[num_nbrs_start..nbrs_buf_start]
                        .try_into()
                        .unwrap(),
                );

                let nbors_buf_end =
                    nbrs_buf_start + (neighbors_num as usize) * std::mem::size_of::<u32>();

                // Extract the neighbors from the node data.
                let mut neighbors = Vec::new();
                for nbors_data in node_data[nbrs_buf_start..nbors_buf_end]
                    .chunks_exact(std::mem::size_of::<u32>())
                {
                    let nbors_id = u32::from_le_bytes(nbors_data.try_into().unwrap());
                    assert!(nbors_id < num_pts as u32);
                    neighbors.push(nbors_id);
                }

                // Create a NodeData struct and push it to the node_data_array
                node_data_array.push(NodeData {
                    num_neighbors: neighbors_num,
                    coordinates,
                    neighbors,
                });
            }
        });

        // Verify that each node read from the disk index matches the ground truth.
        let node_data_truth_file = File::open(TRUTH_NODE_DATA_PATH).unwrap();
        let reader = BufReader::new(node_data_truth_file);

        let node_data_vec: Vec<NodeData> = deserialize_from(reader).unwrap();
        for (node_from_node_data_file, node_from_disk_index) in
            node_data_vec.iter().zip(node_data_array.iter())
        {
            // The NodeData from the truth file must equal the NodeData parsed from the disk index
            assert_eq!(node_from_node_data_file, node_from_disk_index);
        }
    }

    #[test]
    fn test_read_fail_invalid_file() {
        let reader = WindowsAlignedFileReader::new("/invalid_path");
        assert!(reader.is_err());
    }

    #[test]
    fn test_read_no_requests() {
        let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
        let ctx = reader.get_ctx().unwrap();

        let mut read_requests = Vec::<AlignedRead<u8>>::new();
        let result = reader.read(&mut read_requests, &ctx);
        assert!(result.is_ok());
    }

    #[test]
    fn test_get_ctx() {
        let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
        let result = reader.get_ctx();
        assert!(result.is_ok());
    }

    #[test]
    fn test_register_thread() {
        let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
        let result = reader.register_thread();
        assert!(result.is_ok());
    }

    fn reconstruct_disk_meta(buffer: &[u8]) -> Vec<u64> {
        let size_of_u64 = std::mem::size_of::<u64>();

        let num_values = buffer.len() / size_of_u64;
        let mut disk_layout_meta = Vec::with_capacity(num_values);
        let meta_data = &buffer[8..];

        for chunk in meta_data.chunks_exact(size_of_u64) {
            let value = u64::from_le_bytes(chunk.try_into().unwrap());
            disk_layout_meta.push(value);
        }

        disk_layout_meta
    }
}
37
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/storage/disk_graph_storage.rs
vendored
Normal file
@@ -0,0 +1,37 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_docs)]

//! Disk graph storage

use std::sync::Arc;

use crate::{model::{WindowsAlignedFileReader, IOContext, AlignedRead}, common::ANNResult};

/// Graph storage for disk index
/// One thread has one storage instance
pub struct DiskGraphStorage {
    /// Disk graph reader
    disk_graph_reader: Arc<WindowsAlignedFileReader>,

    /// IOContext of current thread
    ctx: Arc<IOContext>,
}

impl DiskGraphStorage {
    /// Create a new DiskGraphStorage instance
    pub fn new(disk_graph_reader: Arc<WindowsAlignedFileReader>) -> ANNResult<Self> {
        let ctx = disk_graph_reader.get_ctx()?;
        Ok(Self {
            disk_graph_reader,
            ctx,
        })
    }

    /// Read disk graph data
    pub fn read<T>(&self, read_requests: &mut [AlignedRead<T>]) -> ANNResult<()> {
        self.disk_graph_reader.read(read_requests, &self.ctx)
    }
}
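
// A hedged usage sketch (not part of the vendored source): each worker thread
// builds its own DiskGraphStorage over a shared reader, so the IOContext stays
// thread-local. Assumes WindowsAlignedFileReader is Send + Sync; `worker` and
// the empty request vector are illustrative only.
#[allow(dead_code)]
fn worker(reader: Arc<WindowsAlignedFileReader>) -> ANNResult<()> {
    let storage = DiskGraphStorage::new(reader)?;
    // In real use, read_requests would hold sector-aligned buffers to fill.
    let mut read_requests: Vec<AlignedRead<u8>> = Vec::new();
    storage.read(&mut read_requests)
}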
363
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/storage/disk_index_storage.rs
vendored
Normal file
@@ -0,0 +1,363 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
use std::fs::File;
use std::io::Read;
use std::marker::PhantomData;
use std::{fs, mem};

use crate::common::{ANNError, ANNResult};
use crate::model::NUM_PQ_CENTROIDS;
use crate::storage::PQStorage;
use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, save_bin_u64};
use crate::utils::{
    file_exists, gen_sample_data, get_file_size, round_up, CachedReader, CachedWriter,
};

const SECTOR_LEN: usize = 4096;

/// TODO: Remove the allow(dead_code) when the disk search code is complete
#[allow(dead_code)]
pub struct PQPivotData {
    dim: usize,
    pq_table: Vec<f32>,
    centroids: Vec<f32>,
    chunk_offsets: Vec<usize>,
}

pub struct DiskIndexStorage<T> {
    /// Dataset file
    dataset_file: String,

    /// Index file path prefix
    index_path_prefix: String,

    // TODO: Only a placeholder for T, will be removed later
    _marker: PhantomData<T>,

    pq_storage: PQStorage,
}

impl<T> DiskIndexStorage<T> {
    /// Create a DiskIndexStorage instance
    pub fn new(dataset_file: String, index_path_prefix: String) -> ANNResult<Self> {
        let pq_storage: PQStorage = PQStorage::new(
            &(index_path_prefix.clone() + ".bin_pq_pivots.bin"),
            &(index_path_prefix.clone() + ".bin_pq_compressed.bin"),
            &dataset_file,
        )?;

        Ok(DiskIndexStorage {
            dataset_file,
            index_path_prefix,
            _marker: PhantomData,
            pq_storage,
        })
    }

    pub fn get_pq_storage(&mut self) -> &mut PQStorage {
        &mut self.pq_storage
    }

    pub fn dataset_file(&self) -> &String {
        &self.dataset_file
    }

    pub fn index_path_prefix(&self) -> &String {
        &self.index_path_prefix
    }

    /// Create the disk layout.
    /// Sector #1: disk_layout_meta
    /// Sector #n: num_nodes_per_sector nodes
    /// Each node's layout: {full precision vector: [T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]}
    ///
    /// Reads the full-precision vectors from `self.dataset_file` and the graph from the
    /// in-memory index file, and writes the disk layout file.
    pub fn create_disk_layout(&self) -> ANNResult<()> {
        let mem_index_file = self.mem_index_file();
        let disk_layout_file = self.disk_index_file();

        // amount to read or write in one shot
        let read_blk_size = 64 * 1024 * 1024;
        let write_blk_size = read_blk_size;
        let mut dataset_reader = CachedReader::new(self.dataset_file.as_str(), read_blk_size)?;

        let num_pts = dataset_reader.read_u32()? as u64;
        let dims = dataset_reader.read_u32()? as u64;

        // Create cached reader + writer
        let actual_file_size = get_file_size(mem_index_file.as_str())?;
        println!("Vamana index file size={}", actual_file_size);

        let mut vamana_reader = File::open(mem_index_file)?;
        let mut diskann_writer = CachedWriter::new(disk_layout_file.as_str(), write_blk_size)?;

        let index_file_size = vamana_reader.read_u64::<LittleEndian>()?;
        if index_file_size != actual_file_size {
            println!(
                "Vamana index file size does not match the expected size per its metadata. size from metadata: {}, actual file size: {}",
                index_file_size, actual_file_size
            );
        }

        let max_degree = vamana_reader.read_u32::<LittleEndian>()?;
        let medoid = vamana_reader.read_u32::<LittleEndian>()?;
        let vamana_frozen_num = vamana_reader.read_u64::<LittleEndian>()?;

        let mut vamana_frozen_loc = 0;
        if vamana_frozen_num == 1 {
            vamana_frozen_loc = medoid;
        }

        let max_node_len = ((max_degree as u64 + 1) * (mem::size_of::<u32>() as u64))
            + (dims * (mem::size_of::<T>() as u64));
        let num_nodes_per_sector = (SECTOR_LEN as u64) / max_node_len;

        println!("medoid: {}", medoid);
        println!("max_node_len: {}B", max_node_len);
        println!("num_nodes_per_sector: {}", num_nodes_per_sector);

        // SECTOR_LEN buffer for each sector
        let mut sector_buf = vec![0u8; SECTOR_LEN];
        let mut node_buf = vec![0u8; max_node_len as usize];

        let num_nbrs_start = (dims as usize) * mem::size_of::<T>();
        let nbrs_buf_start = num_nbrs_start + mem::size_of::<u32>();

        // number of data sectors (one extra sector below holds the meta data)
        let num_sectors = round_up(num_pts, num_nodes_per_sector) / num_nodes_per_sector;
        let disk_index_file_size = (num_sectors + 1) * (SECTOR_LEN as u64);

        let disk_layout_meta = vec![
            num_pts,
            dims,
            medoid as u64,
            max_node_len,
            num_nodes_per_sector,
            vamana_frozen_num,
            vamana_frozen_loc as u64,
            // append_reorder_data
            // We are not supporting this. Temporarily write it into the layout so that
            // we can leverage the C++ query driver to test the disk index
            false as u64,
            disk_index_file_size,
        ];

        diskann_writer.write(&sector_buf)?;

        let mut cur_node_coords = vec![0u8; (dims as usize) * mem::size_of::<T>()];
        let mut cur_node_id = 0u64;

        for sector in 0..num_sectors {
            if sector % 100_000 == 0 {
                println!("Sector #{} written", sector);
            }
            sector_buf.fill(0);

            for sector_node_id in 0..num_nodes_per_sector {
                if cur_node_id >= num_pts {
                    break;
                }

                node_buf.fill(0);

                // read cur node's num_nbrs
                let num_nbrs = vamana_reader.read_u32::<LittleEndian>()?;

                // sanity checks on num_nbrs
                debug_assert!(num_nbrs > 0);
                debug_assert!(num_nbrs <= max_degree);

                // write coords of node first
                dataset_reader.read(&mut cur_node_coords)?;
                node_buf[..cur_node_coords.len()].copy_from_slice(&cur_node_coords);

                // write num_nbrs
                LittleEndian::write_u32(
                    &mut node_buf[num_nbrs_start..(num_nbrs_start + mem::size_of::<u32>())],
                    num_nbrs,
                );

                // write neighbors
                let nbrs_buf = &mut node_buf[nbrs_buf_start
                    ..(nbrs_buf_start + (num_nbrs as usize) * mem::size_of::<u32>())];
                vamana_reader.read_exact(nbrs_buf)?;

                // get offset into sector_buf
                let sector_node_buf_start = (sector_node_id * max_node_len) as usize;
                let sector_node_buf = &mut sector_buf
                    [sector_node_buf_start..(sector_node_buf_start + max_node_len as usize)];
                sector_node_buf.copy_from_slice(&node_buf[..(max_node_len as usize)]);

                cur_node_id += 1;
            }

            // flush sector to disk
            diskann_writer.write(&sector_buf)?;
        }

        diskann_writer.flush()?;
        save_bin_u64(
            disk_layout_file.as_str(),
            &disk_layout_meta,
            disk_layout_meta.len(),
            1,
            0,
        )?;

        Ok(())
    }
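
    // A hedged sketch (not part of the vendored source) of how a node is located
    // given the disk_layout_meta written above: the first sector holds the metadata,
    // and every following sector packs num_nodes_per_sector fixed-size node records.
    #[allow(dead_code)]
    fn node_location(id: u64, max_node_len: u64, num_nodes_per_sector: u64) -> (u64, u64) {
        let sector = 1 + id / num_nodes_per_sector; // skip the metadata sector
        let offset_in_sector = (id % num_nodes_per_sector) * max_node_len;
        (sector * SECTOR_LEN as u64, offset_in_sector)
    }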

    pub fn index_build_cleanup(&self) -> ANNResult<()> {
        fs::remove_file(self.mem_index_file())?;
        Ok(())
    }

    pub fn gen_query_warmup_data(&self, sampling_rate: f64) -> ANNResult<()> {
        gen_sample_data::<T>(
            &self.dataset_file,
            &self.warmup_query_prefix(),
            sampling_rate,
        )?;
        Ok(())
    }

    /// Load the pre-trained pivot table
    pub fn load_pq_pivots_bin(&self, num_pq_chunks: &usize) -> ANNResult<PQPivotData> {
        let pq_pivots_path = &self.pq_pivot_file();
        if !file_exists(pq_pivots_path) {
            return Err(ANNError::log_pq_error(
                "ERROR: PQ k-means pivot file not found.".to_string(),
            ));
        }

        let (data, offset_num, offset_dim) = load_bin::<u64>(pq_pivots_path, 0)?;
        let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim);
        if offset_num != 4 {
            let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, pivot_num, dim) = load_bin::<f32>(pq_pivots_path, file_offset_data[0])?;
        let pq_table = data.to_vec();
        if pivot_num != NUM_PQ_CENTROIDS {
            let error_message = format!(
                "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.",
                pq_pivots_path, pivot_num, NUM_PQ_CENTROIDS
            );
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, centroid_dim, nc) = load_bin::<f32>(pq_pivots_path, file_offset_data[1])?;
        let centroids = data.to_vec();
        if centroid_dim != dim || nc != 1 {
            let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, chunk_offset_num, nc) = load_bin::<u32>(pq_pivots_path, file_offset_data[2])?;
        let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc);
        if chunk_offset_num != num_pq_chunks + 1 || nc != 1 {
            let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1);
            return Err(ANNError::log_pq_error(error_message));
        }

        Ok(PQPivotData {
            dim,
            pq_table,
            centroids,
            chunk_offsets,
        })
    }

    fn mem_index_file(&self) -> String {
        self.index_path_prefix.clone() + "_mem.index"
    }

    fn disk_index_file(&self) -> String {
        self.index_path_prefix.clone() + "_disk.index"
    }

    fn warmup_query_prefix(&self) -> String {
        self.index_path_prefix.clone() + "_sample"
    }

    pub fn pq_pivot_file(&self) -> String {
        self.index_path_prefix.clone() + ".bin_pq_pivots.bin"
    }

    pub fn compressed_pq_pivot_file(&self) -> String {
        self.index_path_prefix.clone() + ".bin_pq_compressed.bin"
    }
}

#[cfg(test)]
mod disk_index_storage_test {
    use std::fs;

    use crate::test_utils::get_test_file_path;

    use super::*;

    const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
    const DISK_INDEX_PATH_PREFIX: &str = "tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2";
    const TRUTH_DISK_LAYOUT: &str =
        "tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index";

    #[test]
    fn create_disk_layout_test() {
        let storage = DiskIndexStorage::<f32>::new(
            get_test_file_path(TEST_DATA_FILE),
            get_test_file_path(DISK_INDEX_PATH_PREFIX),
        )
        .unwrap();
        storage.create_disk_layout().unwrap();

        let disk_layout_file = storage.disk_index_file();
        let rust_disk_layout = fs::read(disk_layout_file.as_str()).unwrap();
        let truth_disk_layout = fs::read(get_test_file_path(TRUTH_DISK_LAYOUT).as_str()).unwrap();

        assert!(rust_disk_layout == truth_disk_layout);

        fs::remove_file(disk_layout_file.as_str()).expect("Failed to delete file");
    }

    #[test]
    fn load_pivot_test() {
        let dim: usize = 128;
        let num_pq_chunk: usize = 1;
        let pivot_file_prefix: &str = "tests/data/siftsmall_learn";
        let storage = DiskIndexStorage::<f32>::new(
            get_test_file_path(TEST_DATA_FILE),
            pivot_file_prefix.to_string(),
        )
        .unwrap();

        let pq_pivot_data = storage.load_pq_pivots_bin(&num_pq_chunk).unwrap();

        assert_eq!(pq_pivot_data.pq_table.len(), NUM_PQ_CENTROIDS * dim);
        assert_eq!(pq_pivot_data.centroids.len(), dim);

        assert_eq!(pq_pivot_data.chunk_offsets[0], 0);
        assert_eq!(pq_pivot_data.chunk_offsets[1], dim);
        assert_eq!(pq_pivot_data.chunk_offsets.len(), num_pq_chunk + 1);
    }

    #[test]
    #[should_panic(expected = "ERROR: PQ k-means pivot file not found.")]
    fn load_pivot_file_not_exist_test() {
        let num_pq_chunk: usize = 1;
        let pivot_file_prefix: &str = "tests/data/siftsmall_learn_file_not_exist";
        let storage = DiskIndexStorage::<f32>::new(
            get_test_file_path(TEST_DATA_FILE),
            pivot_file_prefix.to_string(),
        )
        .unwrap();
        let _ = storage.load_pq_pivots_bin(&num_pq_chunk).unwrap();
    }
}
12
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/storage/mod.rs
vendored
Normal file
@@ -0,0 +1,12 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
mod disk_index_storage;
pub use disk_index_storage::*;

mod disk_graph_storage;
pub use disk_graph_storage::*;

mod pq_storage;
pub use pq_storage::*;
367
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/storage/pq_storage.rs
vendored
Normal file
@@ -0,0 +1,367 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use byteorder::{LittleEndian, ReadBytesExt};
use rand::distributions::{Distribution, Uniform};
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::mem;

use crate::common::{ANNError, ANNResult};
use crate::utils::CachedReader;
use crate::utils::{
    convert_types_u32_usize, convert_types_u64_usize, convert_types_usize_u32,
    convert_types_usize_u64, convert_types_usize_u8, save_bin_f32, save_bin_u32, save_bin_u64,
};
use crate::utils::{file_exists, load_bin, open_file_to_write, METADATA_SIZE};

#[derive(Debug)]
pub struct PQStorage {
    /// Pivot table path
    pivot_file: String,

    /// Compressed pivot path
    compressed_pivot_file: String,

    /// Data used to construct the PQ table and PQ compressed table
    pq_data_file: String,

    /// PQ data reader
    pq_data_file_reader: File,
}

impl PQStorage {
    pub fn new(
        pivot_file: &str,
        compressed_pivot_file: &str,
        pq_data_file: &str,
    ) -> std::io::Result<Self> {
        let pq_data_file_reader = File::open(pq_data_file)?;
        Ok(Self {
            pivot_file: pivot_file.to_string(),
            compressed_pivot_file: compressed_pivot_file.to_string(),
            pq_data_file: pq_data_file.to_string(),
            pq_data_file_reader,
        })
    }

    pub fn write_compressed_pivot_metadata(&self, npts: i32, pq_chunk: i32) -> std::io::Result<()> {
        let mut writer = open_file_to_write(&self.compressed_pivot_file)?;
        writer.write_all(&npts.to_le_bytes())?;
        writer.write_all(&pq_chunk.to_le_bytes())?;
        Ok(())
    }

    pub fn write_compressed_pivot_data(
        &self,
        compressed_base: &[usize],
        num_centers: usize,
        block_size: usize,
        num_pq_chunks: usize,
    ) -> std::io::Result<()> {
        let mut writer = open_file_to_write(&self.compressed_pivot_file)?;
        writer.seek(SeekFrom::Start((std::mem::size_of::<i32>() * 2) as u64))?;
        if num_centers > 256 {
            writer.write_all(unsafe {
                std::slice::from_raw_parts(
                    compressed_base.as_ptr() as *const u8,
                    block_size * num_pq_chunks * std::mem::size_of::<usize>(),
                )
            })?;
        } else {
            let compressed_base_u8 =
                convert_types_usize_u8(compressed_base, block_size, num_pq_chunks);
            writer.write_all(&compressed_base_u8)?;
        }
        Ok(())
    }

    pub fn write_pivot_data(
        &self,
        full_pivot_data: &[f32],
        centroid: &[f32],
        chunk_offsets: &[usize],
        num_centers: usize,
        dim: usize,
    ) -> std::io::Result<()> {
        let mut cumul_bytes: Vec<usize> = vec![0; 4];
        cumul_bytes[0] = METADATA_SIZE;
        cumul_bytes[1] = cumul_bytes[0]
            + save_bin_f32(
                &self.pivot_file,
                full_pivot_data,
                num_centers,
                dim,
                cumul_bytes[0],
            )?;
        cumul_bytes[2] =
            cumul_bytes[1] + save_bin_f32(&self.pivot_file, centroid, dim, 1, cumul_bytes[1])?;

        // The writer can only write u32/u64, not usize, so convert the type first.
        let chunk_offsets_u32 = convert_types_usize_u32(chunk_offsets, chunk_offsets.len(), 1);
        cumul_bytes[3] = cumul_bytes[2]
            + save_bin_u32(
                &self.pivot_file,
                &chunk_offsets_u32,
                chunk_offsets.len(),
                1,
                cumul_bytes[2],
            )?;

        let cumul_bytes_u64 = convert_types_usize_u64(&cumul_bytes, 4, 1);
        save_bin_u64(&self.pivot_file, &cumul_bytes_u64, cumul_bytes.len(), 1, 0)?;

        Ok(())
    }

    pub fn pivot_data_exist(&self) -> bool {
        file_exists(&self.pivot_file)
    }

    pub fn read_pivot_metadata(&self) -> std::io::Result<(usize, usize)> {
        let (_, file_num_centers, file_dim) = load_bin::<f32>(&self.pivot_file, METADATA_SIZE)?;
        Ok((file_num_centers, file_dim))
    }

    pub fn load_pivot_data(
        &self,
        num_pq_chunks: &usize,
        num_centers: &usize,
        dim: &usize,
    ) -> ANNResult<(Vec<f32>, Vec<f32>, Vec<usize>)> {
        // Load the file offset data. The file is laid out as:
        // offset data (4 x 1) -> pivot data (num_centers x dim) -> centroid data (dim x 1) -> chunk offset data ((num_chunks + 1) x 1)
        // Because we can only write u64 rather than usize, the offsets are stored as u64 and converted to usize on use.
        let (data, offset_num, nc) = load_bin::<u64>(&self.pivot_file, 0)?;
        let file_offset_data = convert_types_u64_usize(&data, offset_num, nc);
        if offset_num != 4 {
            let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", &self.pivot_file, offset_num);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, pivot_num, pivot_dim) = load_bin::<f32>(&self.pivot_file, file_offset_data[0])?;
        let full_pivot_data = data;
        if pivot_num != *num_centers || pivot_dim != *dim {
            let error_message = format!("Error reading pq_pivots file {}. file_num_centers = {}, file_dim = {} but expecting {} centers in {} dimensions.", &self.pivot_file, pivot_num, pivot_dim, num_centers, dim);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, centroid_dim, nc) = load_bin::<f32>(&self.pivot_file, file_offset_data[1])?;
        let centroid = data;
        if centroid_dim != *dim || nc != 1 {
            let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", &self.pivot_file, centroid_dim, nc, dim);
            return Err(ANNError::log_pq_error(error_message));
        }

        let (data, chunk_offset_number, nc) =
            load_bin::<u32>(&self.pivot_file, file_offset_data[2])?;
        let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_number, nc);
        if chunk_offset_number != *num_pq_chunks + 1 || nc != 1 {
            let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_number, nc, num_pq_chunks + 1);
            return Err(ANNError::log_pq_error(error_message));
        }
        Ok((full_pivot_data, centroid, chunk_offsets))
    }
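
    // A hedged sketch (not part of the vendored source) of the pq_pivots file
    // layout that write_pivot_data produces. Each section is a save_bin blob:
    // an [npts: i32][dims: i32] header followed by the payload; METADATA_SIZE
    // reserves room at offset 0 for the four u64 section offsets themselves.
    #[allow(dead_code)]
    fn pivot_section_offsets(num_centers: usize, dim: usize, num_chunks: usize) -> [usize; 4] {
        let header = 2 * std::mem::size_of::<u32>(); // [npts: i32][dims: i32]
        let pivots = METADATA_SIZE; // num_centers x dim f32 pivot table
        let centroid = pivots + header + num_centers * dim * std::mem::size_of::<f32>();
        let chunk_offsets = centroid + header + dim * std::mem::size_of::<f32>();
        let end = chunk_offsets + header + (num_chunks + 1) * std::mem::size_of::<u32>();
        [pivots, centroid, chunk_offsets, end]
    }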

    pub fn read_pq_data_metadata(&mut self) -> std::io::Result<(usize, usize)> {
        let npts_i32 = self.pq_data_file_reader.read_i32::<LittleEndian>()?;
        let dim_i32 = self.pq_data_file_reader.read_i32::<LittleEndian>()?;
        let num_points = npts_i32 as usize;
        let dim = dim_i32 as usize;
        Ok((num_points, dim))
    }

    pub fn read_pq_block_data<T: Copy>(
        &mut self,
        cur_block_size: usize,
        dim: usize,
    ) -> std::io::Result<Vec<T>> {
        let mut buf = vec![0u8; cur_block_size * dim * std::mem::size_of::<T>()];
        self.pq_data_file_reader.read_exact(&mut buf)?;

        let ptr = buf.as_ptr() as *const T;
        let block_data = unsafe { std::slice::from_raw_parts(ptr, cur_block_size * dim) };
        Ok(block_data.to_vec())
    }

    /// Streams data from the file and samples each vector with probability `p_val`,
    /// returning a matrix of size slice_size * ndims as floating point type.
    /// slice_size and ndims are determined inside the function.
    /// # Arguments
    /// * `p_val` - probability with which each vector is sampled (clamped to 1.0)
    /// # Return
    /// * `sampled_vectors` - the sampled vectors, flattened
    /// * `slice_size` - number of sampled vectors
    /// * `dim` - dimension of each sampled vector
    pub fn gen_random_slice<T: Default + Copy + Into<f32>>(
        &self,
        mut p_val: f64,
    ) -> ANNResult<(Vec<f32>, usize, usize)> {
        let read_blk_size = 64 * 1024 * 1024;
        let mut reader = CachedReader::new(&self.pq_data_file, read_blk_size)?;

        let npts = reader.read_u32()? as usize;
        let dim = reader.read_u32()? as usize;
        let mut sampled_vectors: Vec<f32> = Vec::new();
        let mut slice_size = 0;
        p_val = if p_val < 1f64 { p_val } else { 1f64 };

        let mut generator = rand::thread_rng();
        let distribution = Uniform::from(0.0..1.0);

        for _ in 0..npts {
            let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::<T>()];
            reader.read(&mut cur_vector_bytes)?;
            let random_value = distribution.sample(&mut generator);
            if random_value < p_val {
                let ptr = cur_vector_bytes.as_ptr() as *const T;
                let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) };
                sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into()));
                slice_size += 1;
            }
        }

        Ok((sampled_vectors, slice_size, dim))
    }
}
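
// A hedged usage sketch (not part of the vendored source): sample roughly 10%
// of the training vectors for PQ pivot training. The three paths are illustrative.
#[allow(dead_code)]
fn sample_training_data() -> ANNResult<(Vec<f32>, usize, usize)> {
    let pq = PQStorage::new(
        "index.bin_pq_pivots.bin",
        "index.bin_pq_compressed.bin",
        "data/train.fbin",
    )?;
    // Each vector is kept with probability 0.1; the result is a flat
    // slice_size x ndims row-major f32 matrix.
    pq.gen_random_slice::<f32>(0.1)
}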

#[cfg(test)]
mod pq_storage_tests {
    use rand::Rng;

    use super::*;
    use crate::utils::gen_random_slice;

    const DATA_FILE: &str = "tests/data/siftsmall_learn.bin";
    const PQ_PIVOT_PATH: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
    const PQ_COMPRESSED_PATH: &str = "tests/data/empty_pq_compressed.bin";

    #[test]
    fn new_test() {
        let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE);
        assert!(result.is_ok());
    }

    #[test]
    fn write_compressed_pivot_metadata_test() {
        let compress_pivot_path = "write_compressed_pivot_metadata_test.bin";
        let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap();

        _ = result.write_compressed_pivot_metadata(100, 20);
        let mut result_reader = File::open(compress_pivot_path).unwrap();
        let npts_i32 = result_reader.read_i32::<LittleEndian>().unwrap();
        let dim_i32 = result_reader.read_i32::<LittleEndian>().unwrap();

        assert_eq!(npts_i32, 100);
        assert_eq!(dim_i32, 20);

        std::fs::remove_file(compress_pivot_path).unwrap();
    }

    #[test]
    fn write_compressed_pivot_data_test() {
        let compress_pivot_path = "write_compressed_pivot_data_test.bin";
        let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap();

        let mut rng = rand::thread_rng();

        let num_centers = 256;
        let block_size = 4;
        let num_pq_chunks = 2;
        let compressed_base: Vec<usize> = (0..block_size * num_pq_chunks)
            .map(|_| rng.gen_range(0..num_centers))
            .collect();
        _ = result.write_compressed_pivot_data(
            &compressed_base,
            num_centers,
            block_size,
            num_pq_chunks,
        );

        let mut result_reader = File::open(compress_pivot_path).unwrap();
        _ = result_reader.read_i32::<LittleEndian>().unwrap();
        _ = result_reader.read_i32::<LittleEndian>().unwrap();
        let mut buf = vec![0u8; block_size * num_pq_chunks * std::mem::size_of::<u8>()];
        result_reader.read_exact(&mut buf).unwrap();

        let ptr = buf.as_ptr() as *const u8;
        let block_data = unsafe { std::slice::from_raw_parts(ptr, block_size * num_pq_chunks) };

        for index in 0..block_data.len() {
            assert_eq!(compressed_base[index], block_data[index] as usize);
        }
        std::fs::remove_file(compress_pivot_path).unwrap();
    }

    #[test]
    fn pivot_data_exist_test() {
        let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
        assert!(result.pivot_data_exist());

        let pivot_path = "not_exist_pivot_path.bin";
        let result = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
        assert!(!result.pivot_data_exist());
    }

    #[test]
    fn read_pivot_metadata_test() {
        let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
        let (npt, dim) = result.read_pivot_metadata().unwrap();

        assert_eq!(npt, 256);
        assert_eq!(dim, 128);
    }

    #[test]
    fn load_pivot_data_test() {
        let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
        let (pq_pivot_data, centroids, chunk_offsets) =
            result.load_pivot_data(&1, &256, &128).unwrap();

        assert_eq!(pq_pivot_data.len(), 256 * 128);
        assert_eq!(centroids.len(), 128);
        assert_eq!(chunk_offsets.len(), 2);
    }

    #[test]
    fn read_pq_data_metadata_test() {
        let mut result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
        let (npt, dim) = result.read_pq_data_metadata().unwrap();

        assert_eq!(npt, 25000);
        assert_eq!(dim, 128);
    }

    #[test]
    fn gen_random_slice_test() {
        let file_name = "gen_random_slice_test.bin";
        // npoints=2, dim=8: two f32 vectors [1.0..8.0] and [9.0..16.0]
        let data: [u8; 72] = [
            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
        ];
        std::fs::write(file_name, data).expect("Failed to write sample file");

        let (sampled_vectors, slice_size, ndims) =
            gen_random_slice::<f32>(file_name, 1f64).unwrap();
        let mut start = 8;
        (0..sampled_vectors.len()).for_each(|i| {
            assert_eq!(sampled_vectors[i].to_le_bytes(), data[start..start + 4]);
            start += 4;
        });
        assert_eq!(sampled_vectors.len(), 16);
        assert_eq!(slice_size, 2);
        assert_eq!(ndims, 8);

        let (sampled_vectors, slice_size, ndims) =
            gen_random_slice::<f32>(file_name, 0f64).unwrap();
        assert_eq!(sampled_vectors.len(), 0);
        assert_eq!(slice_size, 0);
        assert_eq!(ndims, 8);

        std::fs::remove_file(file_name).expect("Failed to delete file");
    }
}
74
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/test_utils/inmem_index_initialization.rs
vendored
Normal file
@@ -0,0 +1,74 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use vector::Metric;

use crate::index::InmemIndex;
use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
use crate::model::IndexConfiguration;
use crate::model::vertex::DIM_128;
use crate::utils::{file_exists, load_metadata_from_file};

use super::get_test_file_path;

// f32, 128 DIM and 256 points source data
const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
const NUM_POINTS_TO_LOAD: usize = 256;

pub fn create_index_with_test_data() -> InmemIndex<f32, DIM_128> {
    let index_write_parameters = IndexWriteParametersBuilder::new(50, 4).with_alpha(1.2).build();
    let config = IndexConfiguration::new(
        Metric::L2,
        128,
        128,
        256,
        false,
        0,
        false,
        0,
        1.0f32,
        index_write_parameters,
    );
    let mut index: InmemIndex<f32, DIM_128> = InmemIndex::new(config).unwrap();

    build_test_index(
        &mut index,
        get_test_file_path(TEST_DATA_FILE).as_str(),
        NUM_POINTS_TO_LOAD,
    );

    index.start = index.dataset.calculate_medoid_point_id().unwrap();

    index
}

fn build_test_index(
    index: &mut InmemIndex<f32, DIM_128>,
    filename: &str,
    num_points_to_load: usize,
) {
    if !file_exists(filename) {
        panic!("ERROR: Data file {} does not exist.", filename);
    }

    let (file_num_points, file_dim) = load_metadata_from_file(filename).unwrap();
    if file_num_points > index.configuration.max_points {
        panic!(
            "ERROR: Driver requests loading {} points, and the file has {} points, \
             but the index can support only {} points as specified in the configuration.",
            num_points_to_load, file_num_points, index.configuration.max_points
        );
    }

    if num_points_to_load > file_num_points {
        panic!(
            "ERROR: Driver requests loading {} points, but the file has only {} points.",
            num_points_to_load, file_num_points
        );
    }

    if file_dim != index.configuration.dim {
        panic!(
            "ERROR: Driver requests loading {} dimensions, but the file has {} dimensions.",
            index.configuration.dim, file_dim
        );
    }

    index
        .dataset
        .build_from_file(filename, num_points_to_load)
        .unwrap();

    println!("Using only the first {} points from the file.", num_points_to_load);

    index.num_active_pts = num_points_to_load;
}
11
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/test_utils/mod.rs
vendored
Normal file
@@ -0,0 +1,11 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub mod inmem_index_initialization;

/// Test files should be placed under the tests folder
pub fn get_test_file_path(relative_path: &str) -> String {
    format!("{}/{}", env!("CARGO_MANIFEST_DIR"), relative_path)
}
45
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/bit_vec_extension.rs
vendored
Normal file
@@ -0,0 +1,45 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::cmp::Ordering;

use bit_vec::BitVec;

pub trait BitVecExtension {
    fn resize(&mut self, new_len: usize, value: bool);
}

impl BitVecExtension for BitVec {
    fn resize(&mut self, new_len: usize, value: bool) {
        let old_len = self.len();
        match new_len.cmp(&old_len) {
            Ordering::Less => self.truncate(new_len),
            Ordering::Greater => self.grow(new_len - old_len, value),
            Ordering::Equal => {}
        }
    }
}

#[cfg(test)]
mod bit_vec_extension_test {
    use super::*;

    #[test]
    fn resize_test() {
        let mut bitset = BitVec::new();

        bitset.resize(10, false);
        assert_eq!(bitset.len(), 10);
        assert!(bitset.none());

        bitset.resize(11, true);
        assert_eq!(bitset.len(), 11);
        assert!(bitset[10]);

        bitset.resize(5, false);
        assert_eq!(bitset.len(), 5);
        assert!(bitset.none());
    }
}
160
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/cached_reader.rs
vendored
Normal file
@@ -0,0 +1,160 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::fs::File;
use std::io::{Read, Seek};

use crate::common::{ANNError, ANNResult};

/// Sequential cached reads
pub struct CachedReader {
    /// File reader
    reader: File,

    /// # bytes to cache in one shot read
    cache_size: u64,

    /// Underlying buf for cache
    cache_buf: Vec<u8>,

    /// Offset into cache_buf for cur_pos
    cur_off: u64,

    /// File size
    fsize: u64,
}

impl CachedReader {
    pub fn new(filename: &str, cache_size: u64) -> std::io::Result<Self> {
        let mut reader = File::open(filename)?;
        let metadata = reader.metadata()?;
        let fsize = metadata.len();

        let cache_size = cache_size.min(fsize);
        let mut cache_buf = vec![0; cache_size as usize];
        reader.read_exact(&mut cache_buf)?;
        println!("Opened: {}, size: {}, cache_size: {}", filename, fsize, cache_size);

        Ok(Self {
            reader,
            cache_size,
            cache_buf,
            cur_off: 0,
            fsize,
        })
    }

    pub fn get_file_size(&self) -> u64 {
        self.fsize
    }

    pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> {
        let n_bytes = read_buf.len() as u64;
        if n_bytes <= (self.cache_size - self.cur_off) {
            // case 1: cache contains all data
            read_buf.copy_from_slice(
                &self.cache_buf[(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)],
            );
            self.cur_off += n_bytes;
        } else {
            // case 2: cache contains some data
            let cached_bytes = self.cache_size - self.cur_off;
            if n_bytes - cached_bytes > self.fsize - self.reader.stream_position()? {
                return Err(ANNError::log_index_error(format!(
                    "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}",
                    n_bytes, cached_bytes, self.fsize, self.reader.stream_position()?
                )));
            }

            read_buf[..cached_bytes as usize]
                .copy_from_slice(&self.cache_buf[self.cur_off as usize..]);
            // go to disk and fetch more data
            self.reader.read_exact(&mut read_buf[cached_bytes as usize..])?;
            // reset cur off
            self.cur_off = self.cache_size;

            let size_left = self.fsize - self.reader.stream_position()?;
            if size_left >= self.cache_size {
                self.reader.read_exact(&mut self.cache_buf)?;
                self.cur_off = 0;
            }
            // note that if size_left < cache_size, then cur_off = cache_size,
            // so subsequent reads will all be directly from file
        }
        Ok(())
    }

    pub fn read_u32(&mut self) -> ANNResult<u32> {
        let mut bytes = [0u8; 4];
        self.read(&mut bytes)?;
        Ok(u32::from_le_bytes(bytes))
    }
}

#[cfg(test)]
mod cached_reader_test {
    use std::fs;

    use super::*;

    #[test]
    fn cached_reader_works() {
        let file_name = "cached_reader_works_test.bin";
        // An 8-byte header followed by 16 f32 values (1.0..16.0); a few bytes are
        // deliberately perturbed for the byte-level assertions below.
        let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3,
            0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];
        std::fs::write(file_name, data).expect("Failed to write sample file");

        let mut reader = CachedReader::new(file_name, 8).unwrap();
        assert_eq!(reader.get_file_size(), 72);
        assert_eq!(reader.cache_size, 8);

        let mut all_from_cache_buf = vec![0; 4];
        reader.read(all_from_cache_buf.as_mut_slice()).unwrap();
        assert_eq!(all_from_cache_buf, [2, 0, 1, 2]);
        assert_eq!(reader.cur_off, 4);

        let mut partial_from_cache_buf = vec![0; 6];
        reader.read(partial_from_cache_buf.as_mut_slice()).unwrap();
        assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]);
        assert_eq!(reader.cur_off, 0);

        let mut over_cache_size_buf = vec![0; 60];
        reader.read(over_cache_size_buf.as_mut_slice()).unwrap();
        assert_eq!(
            over_cache_size_buf,
            [0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
             0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
             0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
             0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11]
        );

        let mut remaining_less_than_cache_size_buf = vec![0; 2];
        reader.read(remaining_less_than_cache_size_buf.as_mut_slice()).unwrap();
        assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]);
        assert_eq!(reader.cur_off, reader.cache_size);

        fs::remove_file(file_name).expect("Failed to delete file");
    }

    #[test]
    #[should_panic(expected = "n_bytes: 73 cached_bytes: 8 fsize: 72 current pos: 8")]
    fn failed_for_reading_beyond_end_of_file() {
        let file_name = "failed_for_reading_beyond_end_of_file_test.bin";
        // Same 72-byte test data as above.
        let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3,
            0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];
        std::fs::write(file_name, data).expect("Failed to write sample file");

        let mut reader = CachedReader::new(file_name, 8).unwrap();
        fs::remove_file(file_name).expect("Failed to delete file");

        let mut over_size_buf = vec![0; 73];
        reader.read(over_size_buf.as_mut_slice()).unwrap();
    }
}
142
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/cached_writer.rs
vendored
Normal file
@@ -0,0 +1,142 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::io::{Write, Seek, SeekFrom};
use std::fs::{OpenOptions, File};
use std::path::Path;

pub struct CachedWriter {
    /// File writer
    writer: File,

    /// # bytes to cache for one shot write
    cache_size: u64,

    /// Underlying buf for cache
    cache_buf: Vec<u8>,

    /// Offset into cache_buf for cur_pos
    cur_off: u64,

    /// File size
    fsize: u64,
}

impl CachedWriter {
    pub fn new(filename: &str, cache_size: u64) -> std::io::Result<Self> {
        let writer = OpenOptions::new()
            .write(true)
            .create(true)
            .open(Path::new(filename))?;

        if cache_size == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Cache size must be greater than 0",
            ));
        }

        println!("Opened: {}, cache_size: {}", filename, cache_size);
        Ok(Self {
            writer,
            cache_size,
            cache_buf: vec![0; cache_size as usize],
            cur_off: 0,
            fsize: 0,
        })
    }

    pub fn flush(&mut self) -> std::io::Result<()> {
        // dump any remaining data in memory
        if self.cur_off > 0 {
            self.flush_cache()?;
        }

        self.writer.flush()?;
        println!("Finished writing {}B", self.fsize);
        Ok(())
    }

    pub fn get_file_size(&self) -> u64 {
        self.fsize
    }

    /// Writes n_bytes from write_buf to the underlying cache
    pub fn write(&mut self, write_buf: &[u8]) -> std::io::Result<()> {
        let n_bytes = write_buf.len() as u64;
        if n_bytes <= (self.cache_size - self.cur_off) {
            // case 1: cache can take all data
            self.cache_buf[(self.cur_off as usize)..((self.cur_off + n_bytes) as usize)]
                .copy_from_slice(&write_buf[..n_bytes as usize]);
            self.cur_off += n_bytes;
        } else {
            // case 2: cache can't take all data
            // go to disk and write existing cache data
            self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?;
            self.fsize += self.cur_off;
            // write the new data to disk
            self.writer.write_all(write_buf)?;
            self.fsize += n_bytes;
            // clear cache data and reset cur_off
            self.cache_buf.fill(0);
            self.cur_off = 0;
        }
        Ok(())
    }

    pub fn reset(&mut self) -> std::io::Result<()> {
        self.flush_cache()?;
        self.writer.seek(SeekFrom::Start(0))?;
        Ok(())
    }

    fn flush_cache(&mut self) -> std::io::Result<()> {
        self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?;
        self.fsize += self.cur_off;
        self.cache_buf.fill(0);
        self.cur_off = 0;
        Ok(())
    }
}

impl Drop for CachedWriter {
    fn drop(&mut self) {
        let _ = self.flush();
    }
}

#[cfg(test)]
mod cached_writer_test {
    use std::fs;

    use super::*;

    #[test]
    fn cached_writer_works() {
        let file_name = "cached_writer_works_test.bin";
        // An 8-byte header followed by 16 f32 values (1.0..16.0); a few bytes are
        // deliberately perturbed.
        let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3,
            0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];

        let mut writer = CachedWriter::new(file_name, 8).unwrap();
        assert_eq!(writer.get_file_size(), 0);
        assert_eq!(writer.cache_size, 8);

        let cache_all_buf = &data[0..4];
        writer.write(cache_all_buf).unwrap();
        assert_eq!(&writer.cache_buf[..4], cache_all_buf);
        assert_eq!(&writer.cache_buf[4..], vec![0; 4]);
        assert_eq!(writer.cur_off, 4);
        assert_eq!(writer.get_file_size(), 0);

        let write_all_buf = &data[4..10];
        writer.write(write_all_buf).unwrap();
        assert_eq!(writer.cache_buf, vec![0; 8]);
        assert_eq!(writer.cur_off, 0);
        assert_eq!(writer.get_file_size(), 10);

        fs::remove_file(file_name).expect("Failed to delete file");
    }
}
377
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/file_util.rs
vendored
Normal file
@@ -0,0 +1,377 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! File operations

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::{mem, io};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, BufReader, Write, Seek, SeekFrom};
use std::path::Path;

use crate::model::data_store::DatasetDto;

/// Read the metadata of a data file.
pub fn load_metadata_from_file(file_name: &str) -> std::io::Result<(usize, usize)> {
    let file = File::open(file_name)?;
    let mut reader = BufReader::new(file);

    let npoints = reader.read_i32::<LittleEndian>()? as usize;
    let ndims = reader.read_i32::<LittleEndian>()? as usize;

    Ok((npoints, ndims))
}

/// Read the deleted vertex ids from file.
pub fn load_ids_to_delete_from_file(file_name: &str) -> std::io::Result<(usize, Vec<u32>)> {
    // The first 4 bytes are the number of vector ids.
    // The rest of the file are the vector ids, one little-endian u32 each.
    // The vector ids are sorted in ascending order.
    let mut file = File::open(file_name)?;
    let num_ids = file.read_u32::<LittleEndian>()? as usize;

    let mut ids = Vec::with_capacity(num_ids);
    for _ in 0..num_ids {
        let id = file.read_u32::<LittleEndian>()?;
        ids.push(id);
    }

    Ok((num_ids, ids))
}
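
// A hedged sketch (not part of the vendored source) of the writer side of the
// format read above: a u32 count followed by the ids, all little-endian u32,
// sorted in ascending order. The function name is illustrative.
#[allow(dead_code)]
fn save_ids_to_delete_to_file(file_name: &str, ids: &mut [u32]) -> std::io::Result<()> {
    ids.sort_unstable();
    let mut file = File::create(file_name)?;
    file.write_u32::<LittleEndian>(ids.len() as u32)?;
    for id in ids.iter() {
        file.write_u32::<LittleEndian>(*id)?;
    }
    file.flush()
}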
|
||||
/// Copy data from file
|
||||
/// # Arguments
|
||||
/// * `bin_file` - filename where the data is
|
||||
/// * `data` - destination dataset dto to which the data is copied
|
||||
/// * `pts_offset` - offset of points. data will be loaded after this point in dataset
|
||||
/// * `npts` - number of points read from bin_file
|
||||
/// * `dim` - point dimension read from bin_file
|
||||
/// * `rounded_dim` - rounded dimension (padding zero if it's > dim)
|
||||
/// # Return
|
||||
/// * `npts` - number of points read from bin_file
|
||||
/// * `dim` - point dimension read from bin_file
|
||||
pub fn copy_aligned_data_from_file<T: Default + Copy>(
|
||||
bin_file: &str,
|
||||
dataset_dto: DatasetDto<T>,
|
||||
pts_offset: usize,
|
||||
) -> std::io::Result<(usize, usize)> {
|
||||
let mut reader = File::open(bin_file)?;
|
||||
|
||||
let npts = reader.read_i32::<LittleEndian>()? as usize;
|
||||
let dim = reader.read_i32::<LittleEndian>()? as usize;
|
||||
let rounded_dim = dataset_dto.rounded_dim;
|
||||
let offset = pts_offset * rounded_dim;
|
||||
|
||||
for i in 0..npts {
|
||||
let data_slice = &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
|
||||
let mut buf = vec![0u8; dim * mem::size_of::<T>()];
|
||||
reader.read_exact(&mut buf)?;
|
||||
|
||||
let ptr = buf.as_ptr() as *const T;
|
||||
let temp_slice = unsafe { std::slice::from_raw_parts(ptr, dim) };
|
||||
data_slice.copy_from_slice(temp_slice);
|
||||
|
||||
(i * rounded_dim + dim..i * rounded_dim + rounded_dim).for_each(|j| {
|
||||
dataset_dto.data[j] = T::default();
|
||||
});
|
||||
}
|
||||
|
||||
Ok((npts, dim))
|
||||
}
|
||||
|
||||
/// Open a file to write
|
||||
/// # Arguments
|
||||
/// * `writer` - mutable File reference
|
||||
/// * `file_name` - file name
|
||||
#[inline]
|
||||
pub fn open_file_to_write(file_name: &str) -> std::io::Result<File> {
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.open(Path::new(file_name))
|
||||
}
|
||||
|
||||
/// Delete a file
|
||||
/// # Arguments
|
||||
/// * `file_name` - file name
|
||||
pub fn delete_file(file_name: &str) -> std::io::Result<()> {
|
||||
if file_exists(file_name) {
|
||||
fs::remove_file(file_name)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether file exists or not
|
||||
pub fn file_exists(filename: &str) -> bool {
|
||||
std::path::Path::new(filename).exists()
|
||||
}
|
||||
|
||||
/// Save data to file
|
||||
/// # Arguments
|
||||
/// * `filename` - filename where the data is
|
||||
/// * `data` - information data
|
||||
/// * `npts` - number of points
|
||||
/// * `ndims` - point dimension
|
||||
/// * `aligned_dim` - aligned dimension
|
||||
/// * `offset` - data offset in file
|
||||
pub fn save_data_in_base_dimensions<T: Default + Copy>(
|
||||
filename: &str,
|
||||
data: &mut [T],
|
||||
npts: usize,
|
||||
ndims: usize,
|
||||
aligned_dim: usize,
|
||||
offset: usize,
|
||||
) -> std::io::Result<usize> {
|
||||
let mut writer = open_file_to_write(filename)?;
|
||||
let npts_i32 = npts as i32;
|
||||
let ndims_i32 = ndims as i32;
|
||||
let bytes_written = 2 * std::mem::size_of::<u32>() + npts * ndims * (std::mem::size_of::<T>());
|
||||
|
||||
writer.seek(std::io::SeekFrom::Start(offset as u64))?;
|
||||
writer.write_all(&npts_i32.to_le_bytes())?;
|
||||
writer.write_all(&ndims_i32.to_le_bytes())?;
|
||||
let data_ptr = data.as_ptr() as *const u8;
|
||||
for i in 0..npts {
|
||||
let middle_offset = i * aligned_dim * std::mem::size_of::<T>();
|
||||
let middle_slice = unsafe { std::slice::from_raw_parts(data_ptr.add(middle_offset), ndims * std::mem::size_of::<T>()) };
|
||||
writer.write_all(middle_slice)?;
|
||||
}
|
||||
writer.flush()?;
|
||||
Ok(bytes_written)
|
||||
}
|
||||
|
||||
/// Read data file
|
||||
/// # Arguments
|
||||
/// * `bin_file` - filename where the data is
|
||||
/// * `file_offset` - data offset in file
|
||||
/// * `data` - information data
|
||||
/// * `npts` - number of points
|
||||
/// * `ndims` - point dimension
|
||||
pub fn load_bin<T: Copy>(
|
||||
bin_file: &str,
|
||||
file_offset: usize) -> std::io::Result<(Vec<T>, usize, usize)>
|
||||
{
|
||||
let mut reader = File::open(bin_file)?;
|
||||
reader.seek(std::io::SeekFrom::Start(file_offset as u64))?;
|
||||
let npts = reader.read_i32::<LittleEndian>()? as usize;
|
||||
let dim = reader.read_i32::<LittleEndian>()? as usize;
|
||||
|
||||
let size = npts * dim * std::mem::size_of::<T>();
|
||||
let mut buf = vec![0u8; size];
|
||||
reader.read_exact(&mut buf)?;
|
||||
|
||||
let ptr = buf.as_ptr() as *const T;
|
||||
let data = unsafe { std::slice::from_raw_parts(ptr, npts * dim)};
|
||||
|
||||
Ok((data.to_vec(), npts, dim))
|
||||
}
|
||||
|
||||
/// Get file size
|
||||
pub fn get_file_size(filename: &str) -> io::Result<u64> {
|
||||
let reader = File::open(filename)?;
|
||||
let metadata = reader.metadata()?;
|
||||
Ok(metadata.len())
|
||||
}
|
||||
|
||||
macro_rules! save_bin {
|
||||
($name:ident, $t:ty, $write_func:ident) => {
|
||||
/// Write data into file
|
||||
pub fn $name(filename: &str, data: &[$t], num_pts: usize, dims: usize, offset: usize) -> std::io::Result<usize> {
|
||||
let mut writer = open_file_to_write(filename)?;
|
||||
|
||||
println!("Writing bin: {}", filename);
|
||||
writer.seek(SeekFrom::Start(offset as u64))?;
|
||||
let num_pts_i32 = num_pts as i32;
|
||||
let dims_i32 = dims as i32;
|
||||
let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<u32>();
|
||||
|
||||
writer.write_i32::<LittleEndian>(num_pts_i32)?;
|
||||
writer.write_i32::<LittleEndian>(dims_i32)?;
|
||||
println!("bin: #pts = {}, #dims = {}, size = {}B", num_pts, dims, bytes_written);
|
||||
|
||||
for item in data.iter() {
|
||||
writer.$write_func::<LittleEndian>(*item)?;
|
||||
}
|
||||
|
||||
writer.flush()?;
|
||||
|
||||
println!("Finished writing bin.");
|
||||
Ok(bytes_written)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
save_bin!(save_bin_f32, f32, write_f32);
|
||||
save_bin!(save_bin_u64, u64, write_u64);
|
||||
save_bin!(save_bin_u32, u32, write_u32);
|
||||
|
||||
#[cfg(test)]
mod file_util_test {
    use crate::model::data_store::InmemDataset;
    use std::fs;
    use super::*;

    pub const DIM_8: usize = 8;

    #[test]
    fn load_metadata_test() {
        let file_name = "test_load_metadata_test.bin";
        let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little-endian bytes
        std::fs::write(file_name, data).expect("Failed to write sample file");
        match load_metadata_from_file(file_name) {
            Ok((npoints, ndims)) => {
                assert!(npoints == 200);
                assert!(ndims == 128);
            },
            Err(_e) => {},
        }
        fs::remove_file(file_name).expect("Failed to delete file");
    }

    #[test]
    fn load_data_test() {
        let file_name = "test_load_data_test.bin";
        // npoints=2, dim=8, vectors [1.0..8.0] and [9.0..16.0]
        let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
            0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
        std::fs::write(file_name, data).expect("Failed to write sample file");

        let mut dataset = InmemDataset::<f32, DIM_8>::new(2, 1f32).unwrap();

        match copy_aligned_data_from_file(file_name, dataset.into_dto(), 0) {
            Ok((num_points, dim)) => {
                fs::remove_file(file_name).expect("Failed to delete file");
                assert!(num_points == 2);
                assert!(dim == 8);
                assert!(dataset.data.len() == 16);

                let first_vertex = dataset.get_vertex(0).unwrap();
                let second_vertex = dataset.get_vertex(1).unwrap();

                assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
                assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);
            },
            Err(e) => {
                fs::remove_file(file_name).expect("Failed to delete file");
                panic!("{}", e)
            },
        }
    }

    #[test]
    fn open_file_to_write_test() {
        let file_name = "test_open_file_to_write_test.bin";
        let mut writer = File::create(file_name).unwrap();
        let data = [200, 0, 0, 0, 128, 0, 0, 0];
        writer.write_all(&data).expect("Failed to write sample file");

        let _ = open_file_to_write(file_name);

        fs::remove_file(file_name).expect("Failed to delete file");
    }

    #[test]
    fn delete_file_test() {
        let file_name = "test_delete_file_test.bin";
        let mut file = File::create(file_name).unwrap();
        writeln!(file, "test delete file").unwrap();

        let result = delete_file(file_name);

        assert!(result.is_ok());
        assert!(fs::metadata(file_name).is_err());
    }

    #[test]
    fn save_data_in_base_dimensions_test() {
        // npoints=2, dim=8
        let mut data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
            0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
        let num_points = 2;
        let dim = DIM_8;
        let data_file = "save_data_in_base_dimensions_test.data";
        match save_data_in_base_dimensions(data_file, &mut data, num_points, dim, DIM_8, 0) {
            Ok(num) => {
                assert!(file_exists(data_file));
                assert_eq!(num, 2 * std::mem::size_of::<u32>() + num_points * dim * std::mem::size_of::<u8>());
                fs::remove_file(data_file).expect("Failed to delete file");
            },
            Err(e) => {
                fs::remove_file(data_file).expect("Failed to delete file");
                panic!("{}", e)
            }
        }
    }

    #[test]
    fn save_bin_test() {
        let filename = "save_bin_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let bytes_written = save_bin_u64(filename, &data, num_pts, dims, 0).unwrap();
        assert_eq!(bytes_written, 32);

        let mut file = File::open(filename).unwrap();
        let mut buffer = vec![];

        let npts_read = file.read_i32::<LittleEndian>().unwrap() as usize;
        let dims_read = file.read_i32::<LittleEndian>().unwrap() as usize;

        file.read_to_end(&mut buffer).unwrap();
        let data_read: Vec<u64> = buffer
            .chunks_exact(8)
            .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
            .collect();

        std::fs::remove_file(filename).unwrap();

        assert_eq!(num_pts, npts_read);
        assert_eq!(dims, dims_read);
        assert_eq!(data, data_read);
    }

    #[test]
    fn load_bin_test() {
        let file_name = "load_bin_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, 0).unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) = load_bin::<u64>(file_name, 0).unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        std::fs::remove_file(file_name).unwrap();
    }

    #[test]
    fn load_bin_offset_test() {
        let offset: usize = 32;
        let file_name = "load_bin_offset_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, offset).unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) = load_bin::<u64>(file_name, offset).unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        std::fs::remove_file(file_name).unwrap();
    }
}
46
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/hashset_u32.rs
vendored
Normal file
46
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/hashset_u32.rs
vendored
Normal file
@@ -0,0 +1,46 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use hashbrown::HashSet;
use std::{hash::BuildHasherDefault, ops::{Deref, DerefMut}};
use fxhash::FxHasher;

lazy_static::lazy_static! {
    /// Singleton hasher.
    static ref HASHER: BuildHasherDefault<FxHasher> = {
        BuildHasherDefault::<FxHasher>::default()
    };
}

pub struct HashSetForU32 {
    hashset: HashSet<u32, BuildHasherDefault<FxHasher>>,
}

impl HashSetForU32 {
    pub fn with_capacity(capacity: usize) -> HashSetForU32 {
        let hashset = HashSet::<u32, BuildHasherDefault<FxHasher>>::with_capacity_and_hasher(capacity, HASHER.clone());
        HashSetForU32 { hashset }
    }
}

impl Deref for HashSetForU32 {
    type Target = HashSet<u32, BuildHasherDefault<FxHasher>>;

    fn deref(&self) -> &Self::Target {
        &self.hashset
    }
}

impl DerefMut for HashSetForU32 {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.hashset
    }
}
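
Because `HashSetForU32` implements `Deref` and `DerefMut` to the underlying `hashbrown` set, call sites use it exactly like a plain `HashSet<u32>` while always getting the cheaper `FxHasher`. A minimal usage sketch (illustrative; assumes the type is in scope):

let mut visited = HashSetForU32::with_capacity(1024);
visited.insert(42u32);           // resolved through DerefMut
assert!(visited.contains(&42));  // resolved through Deref
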
430
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/kmeans.rs
vendored
Normal file
430
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/kmeans.rs
vendored
Normal file
@@ -0,0 +1,430 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! K-means clustering

use rand::{distributions::Uniform, prelude::Distribution, thread_rng};
use rayon::prelude::*;
use std::cmp::min;

use crate::common::ANNResult;
use crate::utils::math_util::{calc_distance, compute_closest_centers, compute_vecs_l2sq};

/// Run one iteration of Lloyd's algorithm.
/// Given data in row-major num_points * dim, centers in row-major
/// num_centers * dim, and pre-computed squared lengths of the data points,
/// assign each point to its closest center, update the centers, and return
/// the residual. The inverted index (closest_docs) and the assignment
/// (closest_center) are filled in place and must be pre-allocated by the
/// caller (num_centers vectors and num_points slots, respectively).
#[allow(clippy::too_many_arguments)]
fn lloyds_iter(
    data: &[f32],
    num_points: usize,
    dim: usize,
    centers: &mut [f32],
    num_centers: usize,
    docs_l2sq: &[f32],
    mut closest_docs: &mut Vec<Vec<usize>>,
    closest_center: &mut [u32],
) -> ANNResult<f32> {
    let compute_residual = true;

    closest_docs.iter_mut().for_each(|doc| doc.clear());

    compute_closest_centers(
        data,
        num_points,
        dim,
        centers,
        num_centers,
        1,
        closest_center,
        Some(&mut closest_docs),
        Some(docs_l2sq),
    )?;

    centers.fill(0.0);

    centers
        .par_chunks_mut(dim)
        .enumerate()
        .for_each(|(c, center)| {
            let mut cluster_sum = vec![0.0; dim];
            for &doc_index in &closest_docs[c] {
                let current = &data[doc_index * dim..(doc_index + 1) * dim];
                for (j, current_val) in current.iter().enumerate() {
                    cluster_sum[j] += *current_val as f64;
                }
            }
            if !closest_docs[c].is_empty() {
                for (i, sum_val) in cluster_sum.iter().enumerate() {
                    center[i] = (*sum_val / closest_docs[c].len() as f64) as f32;
                }
            }
        });

    let mut residual = 0.0;
    if compute_residual {
        let buf_pad: usize = 32;
        let chunk_size: usize = 2 * 8192;
        let nchunks = num_points / chunk_size + if num_points % chunk_size == 0 { 0 } else { 1 };

        // One padded slot of buf_pad floats per chunk so that parallel writers
        // do not share a cache line; only the first float of each slot is used.
        let mut residuals: Vec<f32> = vec![0.0; nchunks * buf_pad];

        residuals
            .par_chunks_mut(buf_pad)
            .enumerate()
            .for_each(|(chunk, res)| {
                for d in (chunk * chunk_size)..min(num_points, (chunk + 1) * chunk_size) {
                    res[0] += calc_distance(
                        &data[d * dim..(d + 1) * dim],
                        &centers[closest_center[d] as usize * dim..],
                        dim,
                    );
                }
            });

        for chunk in 0..nchunks {
            residual += residuals[chunk * buf_pad];
        }
    }

    Ok(residual)
}

/// Run Lloyd's algorithm until max_reps iterations or until the residual
/// stops improving. Returns the inverted index (one Vec<usize> per center),
/// the per-point assignment (one u32 per point), and the final residual.
/// Final centers are output in centers as row-major num_centers * dim.
fn run_lloyds(
    data: &[f32],
    num_points: usize,
    dim: usize,
    centers: &mut [f32],
    num_centers: usize,
    max_reps: usize,
) -> ANNResult<(Vec<Vec<usize>>, Vec<u32>, f32)> {
    let mut residual = f32::MAX;

    let mut closest_docs = vec![Vec::new(); num_centers];
    let mut closest_center = vec![0; num_points];

    let mut docs_l2sq = vec![0.0; num_points];
    compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim);

    let mut old_residual;

    for i in 0..max_reps {
        old_residual = residual;

        residual = lloyds_iter(
            data,
            num_points,
            dim,
            centers,
            num_centers,
            &docs_l2sq,
            &mut closest_docs,
            &mut closest_center,
        )?;

        if (i != 0 && (old_residual - residual) / residual < 0.00001) || (residual < f32::EPSILON) {
            println!(
                "Residuals unchanged: {} becomes {}. Early termination.",
                old_residual, residual
            );
            break;
        }
    }

    Ok((closest_docs, closest_center, residual))
}

/// Select num_centers distinct points uniformly at random as pivots.
/// pivot_data must be pre-allocated to hold num_centers * dim floats.
fn selecting_pivots(
    data: &[f32],
    num_points: usize,
    dim: usize,
    pivot_data: &mut [f32],
    num_centers: usize,
) {
    let mut picked = Vec::new();
    let mut rng = thread_rng();
    let distribution = Uniform::from(0..num_points);

    for j in 0..num_centers {
        let mut tmp_pivot = distribution.sample(&mut rng);
        while picked.contains(&tmp_pivot) {
            tmp_pivot = distribution.sample(&mut rng);
        }
        picked.push(tmp_pivot);
        let data_offset = tmp_pivot * dim;
        let pivot_offset = j * dim;
        pivot_data[pivot_offset..pivot_offset + dim]
            .copy_from_slice(&data[data_offset..data_offset + dim]);
    }
}

/// Select pivots with the k-means++ algorithm.
/// Points that are farther away from the already chosen centroids
/// have a higher probability of being selected as the next centroid.
/// The k-means++ algorithm helps avoid poor initial centroid
/// placement that can result in suboptimal clustering.
fn k_meanspp_selecting_pivots(
    data: &[f32],
    num_points: usize,
    dim: usize,
    pivot_data: &mut [f32],
    num_centers: usize,
) {
    if num_points > (1 << 23) {
        println!("ERROR: n_pts {} currently not supported for k-means++, maximum is 8388608. Falling back to random pivot selection.", num_points);
        selecting_pivots(data, num_points, dim, pivot_data, num_centers);
        return;
    }

    let mut picked: Vec<usize> = Vec::new();
    let mut rng = thread_rng();
    let real_distribution = Uniform::from(0.0..1.0);
    let int_distribution = Uniform::from(0..num_points);

    let init_id = int_distribution.sample(&mut rng);
    let mut num_picked = 1;

    picked.push(init_id);
    let init_data_offset = init_id * dim;
    pivot_data[0..dim].copy_from_slice(&data[init_data_offset..init_data_offset + dim]);

    let mut dist = vec![0.0; num_points];

    dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| {
        *dist_i = calc_distance(
            &data[i * dim..(i + 1) * dim],
            &data[init_id * dim..(init_id + 1) * dim],
            dim,
        );
    });

    let mut dart_val: f64;
    let mut tmp_pivot = 0;
    let mut sum_flag = false;

    while num_picked < num_centers {
        dart_val = real_distribution.sample(&mut rng);

        let mut sum: f64 = 0.0;
        for item in dist.iter().take(num_points) {
            sum += *item as f64;
        }
        if sum == 0.0 {
            sum_flag = true;
        }

        dart_val *= sum;

        let mut prefix_sum: f64 = 0.0;
        for (i, pivot) in dist.iter().enumerate().take(num_points) {
            tmp_pivot = i;
            if dart_val >= prefix_sum && dart_val < (prefix_sum + *pivot as f64) {
                break;
            }

            prefix_sum += *pivot as f64;
        }

        if picked.contains(&tmp_pivot) && !sum_flag {
            continue;
        }

        picked.push(tmp_pivot);
        let pivot_offset = num_picked * dim;
        let data_offset = tmp_pivot * dim;
        pivot_data[pivot_offset..pivot_offset + dim]
            .copy_from_slice(&data[data_offset..data_offset + dim]);

        dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| {
            *dist_i = (*dist_i).min(calc_distance(
                &data[i * dim..(i + 1) * dim],
                &data[tmp_pivot * dim..(tmp_pivot + 1) * dim],
                dim,
            ));
        });

        num_picked += 1;
    }
}
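
The dart-throwing loop above is inverse-CDF sampling: draw a uniform value, scale it by the total weight, and walk the prefix sums until the dart lands, so each point is chosen with probability proportional to its squared distance from the already-chosen centers. A standalone sketch of that selection step (illustrative, not part of this commit):

use rand::{thread_rng, Rng};

/// Picks an index with probability proportional to weights[i]
/// (the D^2 weighting used by k-means++); weights must be non-negative.
fn sample_weighted(weights: &[f32]) -> usize {
    let total: f64 = weights.iter().map(|&w| w as f64).sum();
    let dart = thread_rng().gen_range(0.0..1.0) * total;
    let mut prefix = 0.0f64;
    for (i, &w) in weights.iter().enumerate() {
        prefix += w as f64;
        if dart < prefix {
            return i;
        }
    }
    weights.len() - 1 // numerical fallback when all weights are zero
}
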

/// k-means algorithm interface
pub fn k_means_clustering(
    data: &[f32],
    num_points: usize,
    dim: usize,
    centers: &mut [f32],
    num_centers: usize,
    max_reps: usize,
) -> ANNResult<(Vec<Vec<usize>>, Vec<u32>, f32)> {
    k_meanspp_selecting_pivots(data, num_points, dim, centers, num_centers);
    let (closest_docs, closest_center, residual) =
        run_lloyds(data, num_points, dim, centers, num_centers, max_reps)?;
    Ok((closest_docs, closest_center, residual))
}

#[cfg(test)]
mod kmeans_test {
    use super::*;
    use approx::assert_relative_eq;
    use rand::Rng;

    #[test]
    fn lloyds_iter_test() {
        let dim = 2;
        let num_points = 10;
        let num_centers = 3;

        let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
        let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0];

        let mut closest_docs: Vec<Vec<usize>> = vec![vec![]; num_centers];
        let mut closest_center: Vec<u32> = vec![0; num_points];
        let docs_l2sq: Vec<f32> = data
            .chunks(dim)
            .map(|chunk| chunk.iter().map(|val| val.powi(2)).sum())
            .collect();

        let residual = lloyds_iter(
            &data,
            num_points,
            dim,
            &mut centers,
            num_centers,
            &docs_l2sq,
            &mut closest_docs,
            &mut closest_center,
        )
        .unwrap();

        let expected_centers: [f32; 6] = [2.0, 3.0, 9.0, 10.0, 17.0, 18.0];
        let expected_closest_docs: Vec<Vec<usize>> =
            vec![vec![0, 1], vec![2, 3, 4, 5, 6], vec![7, 8, 9]];
        let expected_closest_center: [u32; 10] = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2];
        let expected_residual: f32 = 100.0;

        // sort data for assert
        centers.sort_by(|a, b| a.partial_cmp(b).unwrap());
        for inner_vec in &mut closest_docs {
            inner_vec.sort();
        }
        closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap());

        assert_eq!(centers, expected_centers);
        assert_eq!(closest_docs, expected_closest_docs);
        assert_eq!(closest_center, expected_closest_center);
        assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32);
    }

    #[test]
    fn run_lloyds_test() {
        let dim = 2;
        let num_points = 10;
        let num_centers = 3;
        let max_reps = 5;

        let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
        let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0];

        let (mut closest_docs, mut closest_center, residual) =
            run_lloyds(&data, num_points, dim, &mut centers, num_centers, max_reps).unwrap();

        let expected_centers: [f32; 6] = [3.0, 4.0, 10.0, 11.0, 17.0, 18.0];
        let expected_closest_docs: Vec<Vec<usize>> =
            vec![vec![0, 1, 2], vec![3, 4, 5, 6], vec![7, 8, 9]];
        let expected_closest_center: [u32; 10] = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2];
        let expected_residual: f32 = 72.0;

        // sort data for assert
        centers.sort_by(|a, b| a.partial_cmp(b).unwrap());
        for inner_vec in &mut closest_docs {
            inner_vec.sort();
        }
        closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap());

        assert_eq!(centers, expected_centers);
        assert_eq!(closest_docs, expected_closest_docs);
        assert_eq!(closest_center, expected_closest_center);
        assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32);
    }

    #[test]
    fn selecting_pivots_test() {
        let dim = 2;
        let num_points = 10;
        let num_centers = 3;

        // Generate some random data points
        let mut rng = rand::thread_rng();
        let data: Vec<f32> = (0..num_points * dim).map(|_| rng.gen()).collect();

        let mut pivot_data = vec![0.0; num_centers * dim];

        selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers);

        // Verify that each pivot point corresponds to a point in the data
        for i in 0..num_centers {
            let pivot_offset = i * dim;
            let pivot = &pivot_data[pivot_offset..(pivot_offset + dim)];

            // Make sure the pivot is found in the data
            let mut found = false;
            for j in 0..num_points {
                let data_offset = j * dim;
                let point = &data[data_offset..(data_offset + dim)];

                if pivot == point {
                    found = true;
                    break;
                }
            }
            assert!(found, "Pivot not found in data");
        }
    }

    #[test]
    fn k_meanspp_selecting_pivots_test() {
        let dim = 2;
        let num_points = 10;
        let num_centers = 3;

        // Generate some random data points
        let mut rng = rand::thread_rng();
        let data: Vec<f32> = (0..num_points * dim).map(|_| rng.gen()).collect();

        let mut pivot_data = vec![0.0; num_centers * dim];

        k_meanspp_selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers);

        // Verify that each pivot point corresponds to a point in the data
        for i in 0..num_centers {
            let pivot_offset = i * dim;
            let pivot = &pivot_data[pivot_offset..pivot_offset + dim];

            // Make sure the pivot is found in the data
            let mut found = false;
            for j in 0..num_points {
                let data_offset = j * dim;
                let point = &data[data_offset..data_offset + dim];

                if pivot == point {
                    found = true;
                    break;
                }
            }
            assert!(found, "Pivot not found in data");
        }
    }
}
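
A minimal caller sketch for the public entry point above (illustrative; the `diskann::utils` path assumes the re-exports in utils/mod.rs). Note that centers doubles as the output buffer: it receives the k-means++ seeds first and the final centroids after Lloyd's iterations:

use diskann::utils::k_means_clustering;

fn cluster_example(data: &[f32], num_points: usize, dim: usize) {
    let num_centers = 16;
    let mut centers = vec![0.0f32; num_centers * dim];
    let (closest_docs, closest_center, residual) =
        k_means_clustering(data, num_points, dim, &mut centers, num_centers, 12)
            .expect("k-means failed");
    println!(
        "residual = {}, cluster 0 holds {} points, point 0 -> center {}",
        residual,
        closest_docs[0].len(),
        closest_center[0]
    );
}
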
481
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/math_util.rs
vendored
Normal file
481
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/math_util.rs
vendored
Normal file
@@ -0,0 +1,481 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
#![warn(missing_debug_implementations, missing_docs)]

//! Math utilities (BLAS-backed norms, distances, and closest-center search)

extern crate cblas;
extern crate openblas_src;

use cblas::{sgemm, snrm2, Layout, Transpose};
use rayon::prelude::*;
use std::{
    cmp::{min, Ordering},
    collections::BinaryHeap,
    sync::{Arc, Mutex},
};

use crate::common::{ANNError, ANNResult};

struct PivotContainer {
    piv_id: usize,
    piv_dist: f32,
}

impl PartialOrd for PivotContainer {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        self.piv_dist.partial_cmp(&other.piv_dist)
    }
}

impl Ord for PivotContainer {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Treat NaN as less than all other values.
        // piv_dist should never be NaN.
        self.partial_cmp(other).unwrap_or(Ordering::Less)
    }
}

impl PartialEq for PivotContainer {
    fn eq(&self, other: &Self) -> bool {
        self.piv_dist == other.piv_dist
    }
}

impl Eq for PivotContainer {}

/// Calculate the squared Euclidean distance between two vectors
pub fn calc_distance(vec_1: &[f32], vec_2: &[f32], dim: usize) -> f32 {
    let mut dist = 0.0;
    for j in 0..dim {
        let diff = vec_1[j] - vec_2[j];
        dist += diff * diff;
    }
    dist
}

/// Compute the L2-squared norms of data stored row-major as num_points * dim;
/// vecs_l2sq must be pre-allocated with num_points entries
pub fn compute_vecs_l2sq(vecs_l2sq: &mut [f32], data: &[f32], num_points: usize, dim: usize) {
    assert_eq!(vecs_l2sq.len(), num_points);

    vecs_l2sq
        .par_iter_mut()
        .enumerate()
        .for_each(|(n_iter, vec_l2sq)| {
            let slice = &data[n_iter * dim..(n_iter + 1) * dim];
            let norm = unsafe { snrm2(dim as i32, slice, 1) };
            *vec_l2sq = norm * norm;
        });
}

/// Calculate the k closest centers to data of num_points * dim (row-major);
/// centers is num_centers * dim (row-major).
/// docs_l2sq and centers_l2sq hold the pre-computed squared norms of the
/// data points and centers, respectively.
/// The pre-allocated center_index receives the ids of the k nearest centers
/// per point, and the pre-allocated dist_matrix (num_points * num_centers)
/// receives the squared distances.
/// The default value of k is 1; ideally used only by compute_closest_centers.
#[allow(clippy::too_many_arguments)]
pub fn compute_closest_centers_in_block(
    data: &[f32],
    num_points: usize,
    dim: usize,
    centers: &[f32],
    num_centers: usize,
    docs_l2sq: &[f32],
    centers_l2sq: &[f32],
    center_index: &mut [u32],
    dist_matrix: &mut [f32],
    k: usize,
) -> ANNResult<()> {
    if k > num_centers {
        return Err(ANNError::log_index_error(format!(
            "ERROR: k ({}) > num_centers({})",
            k, num_centers
        )));
    }

    let ones_a: Vec<f32> = vec![1.0; num_centers];
    let ones_b: Vec<f32> = vec![1.0; num_points];

    unsafe {
        sgemm(
            Layout::RowMajor,
            Transpose::None,
            Transpose::Ordinary,
            num_points as i32,
            num_centers as i32,
            1,
            1.0,
            docs_l2sq,
            1,
            &ones_a,
            1,
            0.0,
            dist_matrix,
            num_centers as i32,
        );
    }

    unsafe {
        sgemm(
            Layout::RowMajor,
            Transpose::None,
            Transpose::Ordinary,
            num_points as i32,
            num_centers as i32,
            1,
            1.0,
            &ones_b,
            1,
            centers_l2sq,
            1,
            1.0,
            dist_matrix,
            num_centers as i32,
        );
    }

    unsafe {
        sgemm(
            Layout::RowMajor,
            Transpose::None,
            Transpose::Ordinary,
            num_points as i32,
            num_centers as i32,
            dim as i32,
            -2.0,
            data,
            dim as i32,
            centers,
            dim as i32,
            1.0,
            dist_matrix,
            num_centers as i32,
        );
    }

    if k == 1 {
        center_index
            .par_iter_mut()
            .enumerate()
            .for_each(|(i, center_idx)| {
                let mut min = f32::MAX;
                let current = &dist_matrix[i * num_centers..(i + 1) * num_centers];
                let mut min_idx = 0;
                for (j, &distance) in current.iter().enumerate() {
                    if distance < min {
                        min = distance;
                        min_idx = j;
                    }
                }
                *center_idx = min_idx as u32;
            });
    } else {
        center_index
            .par_chunks_mut(k)
            .enumerate()
            .for_each(|(i, center_chunk)| {
                let current = &dist_matrix[i * num_centers..(i + 1) * num_centers];
                // Max-heap on distance holding the k closest centers seen so
                // far; the root is the farthest of the kept pivots.
                let mut top_k_queue = BinaryHeap::new();
                for (j, &distance) in current.iter().enumerate() {
                    let this_piv = PivotContainer {
                        piv_id: j,
                        piv_dist: distance,
                    };
                    if top_k_queue.len() < k {
                        top_k_queue.push(this_piv);
                    } else {
                        // Safe unwrap, top_k_queue is not empty
                        #[allow(clippy::unwrap_used)]
                        let mut top = top_k_queue.peek_mut().unwrap();
                        if this_piv.piv_dist < top.piv_dist {
                            *top = this_piv;
                        }
                    }
                }
                // Popping the max-heap yields farthest-first, so fill the
                // chunk back-to-front to leave the closest center in slot 0.
                for center_idx in center_chunk.iter_mut().rev() {
                    if let Some(this_piv) = top_k_queue.pop() {
                        *center_idx = this_piv.piv_id as u32;
                    } else {
                        break;
                    }
                }
            });
    }

    Ok(())
}
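
The three `sgemm` calls above assemble the full squared-distance matrix from the standard expansion, which is why both squared-norm vectors are pre-computed:

\[ D_{ij} = \lVert x_i - c_j \rVert^2 = \lVert x_i \rVert^2 + \lVert c_j \rVert^2 - 2\, x_i \cdot c_j \]

The first rank-1 update broadcasts \(\lVert x_i \rVert^2\) across row \(i\), the second adds \(\lVert c_j \rVert^2\) down column \(j\), and the final call accumulates \(-2 X C^T\) in a single BLAS matrix multiply.
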

/// Given data in num_points * dim row-major order, and pivots stored in
/// pivot_data as num_centers * dim row-major, calculate the k closest pivots
/// for each point and store them in closest_centers_ivf (row-major,
/// num_points * k, allocated by the caller). If inverted_index is Some
/// (pre-allocated with one empty Vec per center), it is filled with the
/// points assigned to each center. If pts_norms_squared is Some, the point
/// norms are assumed to be pre-computed and are used as-is.
#[allow(clippy::too_many_arguments)]
pub fn compute_closest_centers(
    data: &[f32],
    num_points: usize,
    dim: usize,
    pivot_data: &[f32],
    num_centers: usize,
    k: usize,
    closest_centers_ivf: &mut [u32],
    mut inverted_index: Option<&mut Vec<Vec<usize>>>,
    pts_norms_squared: Option<&[f32]>,
) -> ANNResult<()> {
    if k > num_centers {
        return Err(ANNError::log_index_error(format!(
            "ERROR: k ({}) > num_centers({})",
            k, num_centers
        )));
    }

    let _is_norm_given_for_pts = pts_norms_squared.is_some();

    let mut pivs_norms_squared = vec![0.0; num_centers];

    let mut pts_norms_squared = if let Some(pts_norms) = pts_norms_squared {
        pts_norms.to_vec()
    } else {
        let mut norms_squared = vec![0.0; num_points];
        compute_vecs_l2sq(&mut norms_squared, data, num_points, dim);
        norms_squared
    };

    compute_vecs_l2sq(&mut pivs_norms_squared, pivot_data, num_centers, dim);

    let par_block_size = num_points;
    let n_blocks = if num_points % par_block_size == 0 {
        num_points / par_block_size
    } else {
        num_points / par_block_size + 1
    };

    let mut closest_centers = vec![0u32; par_block_size * k];
    let mut distance_matrix = vec![0.0; num_centers * par_block_size];

    for cur_blk in 0..n_blocks {
        let data_cur_blk = &data[cur_blk * par_block_size * dim..];
        let num_pts_blk = min(par_block_size, num_points - cur_blk * par_block_size);
        let pts_norms_blk = &mut pts_norms_squared[cur_blk * par_block_size..];

        compute_closest_centers_in_block(
            data_cur_blk,
            num_pts_blk,
            dim,
            pivot_data,
            num_centers,
            pts_norms_blk,
            &pivs_norms_squared,
            &mut closest_centers,
            &mut distance_matrix,
            k,
        )?;

        closest_centers_ivf.clone_from_slice(&closest_centers);

        if let Some(inverted_index_inner) = inverted_index.as_mut() {
            let inverted_index_arc = Arc::new(Mutex::new(inverted_index_inner));

            (0..num_points)
                .into_par_iter()
                .try_for_each(|j| -> ANNResult<()> {
                    let this_center_id = closest_centers[j] as usize;
                    let mut guard = inverted_index_arc.lock().map_err(|err| {
                        ANNError::log_index_error(format!(
                            "PoisonError: Lock poisoned when acquiring inverted_index_arc, err={}",
                            err
                        ))
                    })?;
                    guard[this_center_id].push(j);

                    Ok(())
                })?;
        }
    }

    Ok(())
}

/// If to_subtract is true, subtract the nearest center from each row;
/// otherwise add it.
/// Output is written into data_load itself.
/// The nearest centers must be provided in closest_centers.
pub fn process_residuals(
    data_load: &mut [f32],
    num_points: usize,
    dim: usize,
    cur_pivot_data: &[f32],
    num_centers: usize,
    closest_centers: &[u32],
    to_subtract: bool,
) {
    println!(
        "Processing residuals of {} points in {} dimensions using {} centers",
        num_points, dim, num_centers
    );

    data_load
        .par_chunks_mut(dim)
        .enumerate()
        .for_each(|(n_iter, chunk)| {
            let cur_pivot_index = closest_centers[n_iter] as usize * dim;
            for d_iter in 0..dim {
                if to_subtract {
                    chunk[d_iter] -= cur_pivot_data[cur_pivot_index + d_iter];
                } else {
                    chunk[d_iter] += cur_pivot_data[cur_pivot_index + d_iter];
                }
            }
        });
}

#[cfg(test)]
mod math_util_test {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn calc_distance_test() {
        let vec1 = vec![1.0, 2.0, 3.0];
        let vec2 = vec![4.0, 5.0, 6.0];
        let dim = vec1.len();

        let dist = calc_distance(&vec1, &vec2, dim);

        let expected = 27.0;

        assert_eq!(dist, expected);
    }

    #[test]
    fn compute_vecs_l2sq_test() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_points = 2;
        let dim = 3;
        let mut vecs_l2sq = vec![0.0; num_points];

        compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim);

        let expected = vec![14.0, 77.0];

        assert_eq!(vecs_l2sq.len(), num_points);
        assert_abs_diff_eq!(vecs_l2sq[0], expected[0], epsilon = 1e-6);
        assert_abs_diff_eq!(vecs_l2sq[1], expected[1], epsilon = 1e-6);
    }

    #[test]
    fn compute_closest_centers_in_block_test() {
        let num_points = 10;
        let dim = 5;
        let num_centers = 3;
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
            31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0,
            45.0, 46.0, 47.0, 48.0, 49.0, 50.0,
        ];
        let centers = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 31.0, 32.0, 33.0, 34.0, 35.0,
        ];
        let mut docs_l2sq = vec![0.0; num_points];
        compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim);
        let mut centers_l2sq = vec![0.0; num_centers];
        compute_vecs_l2sq(&mut centers_l2sq, &centers, num_centers, dim);
        let mut center_index = vec![0; num_points];
        let mut dist_matrix = vec![0.0; num_points * num_centers];
        let k = 1;

        compute_closest_centers_in_block(
            &data,
            num_points,
            dim,
            &centers,
            num_centers,
            &docs_l2sq,
            &centers_l2sq,
            &mut center_index,
            &mut dist_matrix,
            k,
        )
        .unwrap();

        assert_eq!(center_index.len(), num_points);
        let expected_center_index = vec![0, 0, 0, 1, 1, 1, 2, 2, 2, 2];
        assert_abs_diff_eq!(*center_index, expected_center_index);

        assert_eq!(dist_matrix.len(), num_points * num_centers);
        let expected_dist_matrix = vec![
            0.0, 2000.0, 4500.0, 125.0, 1125.0, 3125.0, 500.0, 500.0, 2000.0, 1125.0, 125.0,
            1125.0, 2000.0, 0.0, 500.0, 3125.0, 125.0, 125.0, 4500.0, 500.0, 0.0, 6125.0, 1125.0,
            125.0, 8000.0, 2000.0, 500.0, 10125.0, 3125.0, 1125.0,
        ];
        assert_abs_diff_eq!(*dist_matrix, expected_dist_matrix, epsilon = 1e-2);
    }

    #[test]
    fn test_compute_closest_centers() {
        let num_points = 4;
        let dim = 3;
        let num_centers = 2;
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let pivot_data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
        let k = 1;

        let mut closest_centers_ivf = vec![0u32; num_points * k];
        let mut inverted_index: Vec<Vec<usize>> = vec![vec![], vec![]];

        compute_closest_centers(
            &data,
            num_points,
            dim,
            &pivot_data,
            num_centers,
            k,
            &mut closest_centers_ivf,
            Some(&mut inverted_index),
            None,
        )
        .unwrap();

        assert_eq!(closest_centers_ivf, vec![0, 0, 1, 1]);

        for vec in inverted_index.iter_mut() {
            vec.sort_unstable();
        }
        assert_eq!(inverted_index, vec![vec![0, 1], vec![2, 3]]);
    }

    #[test]
    fn process_residuals_test() {
        let mut data_load = vec![1.0, 2.0, 3.0, 4.0];
        let num_points = 2;
        let dim = 2;
        let cur_pivot_data = vec![0.5, 1.5, 2.5, 3.5];
        let num_centers = 2;
        let closest_centers = vec![0, 1];
        let to_subtract = true;

        process_residuals(
            &mut data_load,
            num_points,
            dim,
            &cur_pivot_data,
            num_centers,
            &closest_centers,
            to_subtract,
        );

        assert_eq!(data_load, vec![0.5, 0.5, 0.5, 0.5]);
    }
}
34
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/mod.rs
vendored
Normal file
34
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/mod.rs
vendored
Normal file
@@ -0,0 +1,34 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
pub mod file_util;
pub use file_util::*;

#[allow(clippy::module_inception)]
pub mod utils;
pub use utils::*;

pub mod bit_vec_extension;
pub use bit_vec_extension::*;

pub mod rayon_util;
pub use rayon_util::*;

pub mod timer;
pub use timer::*;

pub mod cached_reader;
pub use cached_reader::*;

pub mod cached_writer;
pub use cached_writer::*;

pub mod partition;
pub use partition::*;

pub mod math_util;
pub use math_util::*;

pub mod kmeans;
pub use kmeans::*;
151
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/partition.rs
vendored
Normal file
151
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/partition.rs
vendored
Normal file
@@ -0,0 +1,151 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::mem;
use std::{fs::File, path::Path};
use std::io::{Write, Seek, SeekFrom};
use rand::distributions::{Distribution, Uniform};

use crate::common::ANNResult;

use super::CachedReader;

/// Streams data from the file and samples each vector with probability p_val,
/// returning a row-major matrix of slice_size * dim as f32; slice_size and
/// dim are determined inside the function.
/// # Arguments
/// * `data_file` - file the data is read from
/// * `p_val` - probability with which each vector is sampled (clamped to 1.0)
///
/// Returns `(sampled_vectors, slice_size, dim)`: the sampled vectors, how
/// many vectors were sampled, and the dimension of each vector.
pub fn gen_random_slice<T: Default + Copy + Into<f32>>(data_file: &str, mut p_val: f64) -> ANNResult<(Vec<f32>, usize, usize)> {
    let read_blk_size = 64 * 1024 * 1024;
    let mut reader = CachedReader::new(data_file, read_blk_size)?;

    let npts = reader.read_u32()? as usize;
    let dim = reader.read_u32()? as usize;
    let mut sampled_vectors: Vec<f32> = Vec::new();
    let mut slice_size = 0;
    p_val = if p_val < 1f64 { p_val } else { 1f64 };

    let mut generator = rand::thread_rng();
    let distribution = Uniform::from(0.0..1.0);

    for _ in 0..npts {
        let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::<T>()];
        reader.read(&mut cur_vector_bytes)?;
        let random_value = distribution.sample(&mut generator);
        if random_value < p_val {
            let ptr = cur_vector_bytes.as_ptr() as *const T;
            let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) };
            sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into()));
            slice_size += 1;
        }
    }

    Ok((sampled_vectors, slice_size, dim))
}

/// Generate random sample data and write it into output_file
pub fn gen_sample_data<T>(data_file: &str, output_file: &str, sampling_rate: f64) -> ANNResult<()> {
    let read_blk_size = 64 * 1024 * 1024;
    let mut reader = CachedReader::new(data_file, read_blk_size)?;

    let sample_data_path = format!("{}_data.bin", output_file);
    let sample_ids_path = format!("{}_ids.bin", output_file);
    let mut sample_data_writer = File::create(Path::new(&sample_data_path))?;
    let mut sample_id_writer = File::create(Path::new(&sample_ids_path))?;

    let mut num_sampled_pts = 0u32;
    let one_const = 1u32;
    let mut generator = rand::thread_rng();
    let distribution = Uniform::from(0.0..1.0);

    let npts_u32 = reader.read_u32()?;
    let dim_u32 = reader.read_u32()?;
    let dim = dim_u32 as usize;
    sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?;
    sample_data_writer.write_all(&dim_u32.to_le_bytes())?;
    sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?;
    sample_id_writer.write_all(&one_const.to_le_bytes())?;

    for id in 0..npts_u32 {
        let mut cur_row_bytes = vec![0u8; dim * mem::size_of::<T>()];
        reader.read(&mut cur_row_bytes)?;
        let random_value = distribution.sample(&mut generator);
        if random_value < sampling_rate {
            sample_data_writer.write_all(&cur_row_bytes)?;
            sample_id_writer.write_all(&id.to_le_bytes())?;
            num_sampled_pts += 1;
        }
    }

    sample_data_writer.seek(SeekFrom::Start(0))?;
    sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?;
    sample_id_writer.seek(SeekFrom::Start(0))?;
    sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?;
    println!("Wrote {} points to sample file: {}", num_sampled_pts, sample_data_path);

    Ok(())
}

#[cfg(test)]
mod partition_test {
    use std::{fs, io::Read};
    use byteorder::{ReadBytesExt, LittleEndian};

    use crate::utils::file_exists;

    use super::*;

    #[test]
    fn gen_sample_data_test() {
        let file_name = "gen_sample_data_test.bin";
        // npoints=2, dim=8
        let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
            0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
            0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
            0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
            0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
        std::fs::write(file_name, data).expect("Failed to write sample file");

        let sample_file_prefix = file_name.to_string() + "_sample";
        gen_sample_data::<f32>(file_name, sample_file_prefix.as_str(), 1f64).unwrap();

        let sample_data_path = format!("{}_data.bin", sample_file_prefix);
        let sample_ids_path = format!("{}_ids.bin", sample_file_prefix);
        assert!(file_exists(sample_data_path.as_str()));
        assert!(file_exists(sample_ids_path.as_str()));

        let mut data_file_reader = File::open(sample_data_path.as_str()).unwrap();
        let mut ids_file_reader = File::open(sample_ids_path.as_str()).unwrap();

        let mut num_sampled_pts = data_file_reader.read_u32::<LittleEndian>().unwrap();
        assert_eq!(num_sampled_pts, 2);
        num_sampled_pts = ids_file_reader.read_u32::<LittleEndian>().unwrap();
        assert_eq!(num_sampled_pts, 2);

        let dim = data_file_reader.read_u32::<LittleEndian>().unwrap() as usize;
        assert_eq!(dim, 8);
        assert_eq!(ids_file_reader.read_u32::<LittleEndian>().unwrap(), 1);

        let mut start = 8;
        for i in 0..num_sampled_pts {
            let mut data_bytes = vec![0u8; dim * 4];
            data_file_reader.read_exact(&mut data_bytes).unwrap();
            assert_eq!(data_bytes, data[start..start + dim * 4]);

            let id = ids_file_reader.read_u32::<LittleEndian>().unwrap();
            assert_eq!(id, i);

            start += dim * 4;
        }

        fs::remove_file(file_name).expect("Failed to delete file");
        fs::remove_file(sample_data_path.as_str()).expect("Failed to delete file");
        fs::remove_file(sample_ids_path.as_str()).expect("Failed to delete file");
    }
}
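
The sampler above writes two companion files, `<output_file>_data.bin` and `<output_file>_ids.bin`, and seeks back to patch the point count into both headers once the pass is done. A hedged call sketch (illustrative; the `diskann::utils` path assumes the re-exports in utils/mod.rs):

use diskann::utils::gen_sample_data;

fn sample_example() {
    // Keep ~10% of the f32 vectors from base.bin; writes
    // base_sample_data.bin and base_sample_ids.bin.
    gen_sample_data::<f32>("base.bin", "base_sample", 0.1).expect("sampling failed");
}
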
33
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/rayon_util.rs
vendored
Normal file
33
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/rayon_util.rs
vendored
Normal file
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::ops::Range;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};

use crate::common::ANNResult;

/// Execute the task over the range, serially when num_threads == 1 and in
/// parallel through Rayon otherwise.
#[inline]
pub fn execute_with_rayon<F>(range: Range<usize>, num_threads: u32, f: F) -> ANNResult<()>
where
    F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy,
{
    if num_threads == 1 {
        for i in range {
            f(i)?;
        }
        Ok(())
    } else {
        range.into_par_iter().try_for_each(f)
    }
}

/// Set Rayon's thread count; otherwise Rayon uses as many threads as there
/// are logical cores.
#[inline]
pub fn set_rayon_num_threads(num_threads: u32) {
    std::env::set_var(
        "RAYON_NUM_THREADS",
        num_threads.to_string(),
    );
}
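
A usage sketch for the helper above (illustrative; assumes the `diskann` crate paths). The closure must be `Copy` because the serial path invokes it by value once per index:

use diskann::common::ANNResult;
use diskann::utils::execute_with_rayon;

fn process_all(n: usize, num_threads: u32) -> ANNResult<()> {
    execute_with_rayon(0..n, num_threads, |i| {
        // Per-item work goes here; returning Err stops the traversal.
        println!("processing item {}", i);
        Ok(())
    })
}
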
101
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/timer.rs
vendored
Normal file
101
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/timer.rs
vendored
Normal file
@@ -0,0 +1,101 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use platform::*;
use std::time::{Duration, Instant};

#[derive(Clone)]
pub struct Timer {
    check_point: Instant,
    pid: Option<usize>,
    cycles: Option<u64>,
}

impl Default for Timer {
    fn default() -> Self {
        Self::new()
    }
}

impl Timer {
    pub fn new() -> Timer {
        let pid = get_process_handle();
        let cycles = get_process_cycle_time(pid);
        Timer {
            check_point: Instant::now(),
            pid,
            cycles,
        }
    }

    pub fn reset(&mut self) {
        self.check_point = Instant::now();
        self.cycles = get_process_cycle_time(self.pid);
    }

    pub fn elapsed(&self) -> Duration {
        Instant::now().duration_since(self.check_point)
    }

    pub fn elapsed_seconds(&self) -> f64 {
        self.elapsed().as_secs_f64()
    }

    pub fn elapsed_gcycles(&self) -> f32 {
        let cur_cycles = get_process_cycle_time(self.pid);
        if let (Some(cur_cycles), Some(cycles)) = (cur_cycles, self.cycles) {
            let spent_cycles = (cur_cycles - cycles) as f64 / (1024 * 1024 * 1024) as f64;
            return spent_cycles as f32;
        }

        0.0
    }

    pub fn elapsed_seconds_for_step(&self, step: &str) -> String {
        format!(
            "Time for {}: {:.3} seconds, {:.3}B cycles",
            step,
            self.elapsed_seconds(),
            self.elapsed_gcycles()
        )
    }
}

#[cfg(test)]
mod timer_tests {
    use super::*;
    use std::{thread, time};

    #[test]
    fn test_new() {
        let timer = Timer::new();
        assert!(timer.check_point.elapsed().as_secs() < 1);
        if cfg!(windows) {
            assert!(timer.pid.is_some());
            assert!(timer.cycles.is_some());
        } else {
            assert!(timer.pid.is_none());
            assert!(timer.cycles.is_none());
        }
    }

    #[test]
    fn test_reset() {
        let mut timer = Timer::new();
        thread::sleep(time::Duration::from_millis(100));
        timer.reset();
        assert!(timer.check_point.elapsed().as_millis() < 10);
    }

    #[test]
    fn test_elapsed() {
        let timer = Timer::new();
        thread::sleep(time::Duration::from_millis(100));
        assert!(timer.elapsed().as_millis() > 100);
        assert!(timer.elapsed_seconds() > 0.1);
    }
}
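
A small usage sketch (illustrative; `get_process_cycle_time` returns Some only on Windows, so the cycle figure reads 0.000 elsewhere):

use diskann::utils::Timer;

fn timed_step() {
    let timer = Timer::new();
    // ... the work being measured ...
    println!("{}", timer.elapsed_seconds_for_step("build"));
}
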
154
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/utils.rs
vendored
Normal file
154
packages/leann-backend-diskann/third_party/DiskANN/rust/diskann/src/utils/utils.rs
vendored
Normal file
@@ -0,0 +1,154 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use std::sync::Mutex;
use num_traits::Num;

/// Non-recursive mutex
pub type NonRecursiveMutex = Mutex<()>;

/// Round up x to the nearest multiple of y
#[inline]
pub fn round_up<T>(x: T, y: T) -> T
where
    T: Num + Copy,
{
    div_round_up(x, y) * y
}

/// Rounded-up division
#[inline]
pub fn div_round_up<T>(x: T, y: T) -> T
where
    T: Num + Copy,
{
    (x / y) + if x % y != T::zero() { T::one() } else { T::zero() }
}

/// Round down x to the nearest multiple of y
#[inline]
pub fn round_down<T>(x: T, y: T) -> T
where
    T: Num + Copy,
{
    (x / y) * y
}

/// Is x aligned to a multiple of y
#[inline]
pub fn is_aligned<T>(x: T, y: T) -> bool
where
    T: Num + Copy,
{
    x % y == T::zero()
}

#[inline]
pub fn is_512_aligned(x: u64) -> bool {
    is_aligned(x, 512)
}

#[inline]
pub fn is_4096_aligned(x: u64) -> bool {
    is_aligned(x, 4096)
}

/// All metadata of individual sub-component files is written in the first 4KB for unified files
pub const METADATA_SIZE: usize = 4096;

pub const BUFFER_SIZE_FOR_CACHED_IO: usize = 1024 * 1048576;

pub const PBSTR: &str = "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||";

pub const PBWIDTH: usize = 60;

macro_rules! convert_types {
    ($name:ident, $input_type:ty, $output_type:ty) => {
        /// Convert a row-major npts * dim matrix between numeric types
        pub fn $name(srcmat: &[$input_type], npts: usize, dim: usize) -> Vec<$output_type> {
            let mut destmat: Vec<$output_type> = Vec::new();
            for i in 0..npts {
                for j in 0..dim {
                    destmat.push(srcmat[i * dim + j] as $output_type);
                }
            }
            destmat
        }
    };
}
convert_types!(convert_types_usize_u8, usize, u8);
convert_types!(convert_types_usize_u32, usize, u32);
convert_types!(convert_types_usize_u64, usize, u64);
convert_types!(convert_types_u64_usize, u64, usize);
convert_types!(convert_types_u32_usize, u32, usize);
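
Each `convert_types!` line above stamps out one concrete conversion function. For reference, `convert_types_u64_usize` expands to roughly the following (sketch of the macro output, not additional code in this commit):

pub fn convert_types_u64_usize(srcmat: &[u64], npts: usize, dim: usize) -> Vec<usize> {
    let mut destmat: Vec<usize> = Vec::new();
    for i in 0..npts {
        for j in 0..dim {
            destmat.push(srcmat[i * dim + j] as usize);
        }
    }
    destmat
}
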
#[cfg(test)]
mod utils_test {
    use super::*;
    use std::any::type_name;

    #[test]
    fn round_up_test() {
        assert_eq!(round_up(252, 8), 256);
        assert_eq!(round_up(256, 8), 256);
    }

    #[test]
    fn div_round_up_test() {
        assert_eq!(div_round_up(252, 8), 32);
        assert_eq!(div_round_up(256, 8), 32);
    }

    #[test]
    fn round_down_test() {
        assert_eq!(round_down(252, 8), 248);
        assert_eq!(round_down(256, 8), 256);
    }

    #[test]
    fn is_aligned_test() {
        assert!(!is_aligned(252, 8));
        assert!(is_aligned(256, 8));
    }

    #[test]
    fn is_512_aligned_test() {
        assert!(!is_512_aligned(520));
        assert!(is_512_aligned(512));
    }

    #[test]
    fn is_4096_aligned_test() {
        assert!(!is_4096_aligned(4090));
        assert!(is_4096_aligned(4096));
    }

    #[test]
    fn convert_types_test() {
        let data = vec![0u64, 1u64, 2u64];
        let output = convert_types_u64_usize(&data, 3, 1);
        assert_eq!(output.len(), 3);
        assert_eq!(type_of(output[0]), "usize");
        assert_eq!(output[0], 0usize);

        let data = vec![0usize, 1usize, 2usize];
        let output = convert_types_usize_u8(&data, 3, 1);
        assert_eq!(output.len(), 3);
        assert_eq!(type_of(output[0]), "u8");
        assert_eq!(output[0], 0u8);

        let data = vec![0usize, 1usize, 2usize];
        let output = convert_types_usize_u64(&data, 3, 1);
        assert_eq!(output.len(), 3);
        assert_eq!(type_of(output[0]), "u64");
        assert_eq!(output[0], 0u64);

        let data = vec![0u32, 1u32, 2u32];
        let output = convert_types_u32_usize(&data, 3, 1);
        assert_eq!(output.len(), 3);
        assert_eq!(type_of(output[0]), "usize");
        assert_eq!(output[0], 0usize);
    }

    fn type_of<T>(_: T) -> &'static str {
        type_name::<T>()
    }
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.