Initial commit

Author: yichuan520030910320
Date: 2025-06-30 09:05:05 +00:00
Commit: 46f6cc100b
1231 changed files with 278432 additions and 0 deletions


@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "diskann"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bincode = "1.3.3"
bit-vec = "0.6.3"
byteorder = "1.4.3"
cblas = "0.4.0"
crossbeam = "0.8.2"
half = "2.2.1"
hashbrown = "0.13.2"
num-traits = "0.2.15"
once_cell = "1.17.1"
openblas-src = { version = "0.10.8", features = ["system"] }
rand = { version = "0.8.5", features = [ "small_rng" ] }
rayon = "1.7.0"
serde = { version = "1.0.130", features = ["derive"] }
thiserror = "1.0.40"
winapi = { version = "0.3.9", features = ["errhandlingapi", "fileapi", "ioapiset", "handleapi", "winnt", "minwindef", "basetsd", "winerror", "winbase"] }
logger = { path = "../logger" }
platform = { path = "../platform" }
vector = { path = "../vector" }
[build-dependencies]
cc = "1.0.79"
[dev-dependencies]
approx = "0.5.1"
criterion = "0.5.1"
[[bench]]
name = "distance_bench"
harness = false
[[bench]]
name = "neighbor_bench"
harness = false


@@ -0,0 +1,47 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::{thread_rng, Rng};
use vector::{FullPrecisionDistance, Metric};
// make sure the vector is 32-byte (256-bit) aligned, as required by _mm256_load_ps
#[repr(C, align(32))]
struct Vector32ByteAligned {
v: [f32; 256],
}
fn benchmark_l2_distance_float_rust(c: &mut Criterion) {
let (a, b) = prepare_random_aligned_vectors();
let mut group = c.benchmark_group("avx-computation");
group.sample_size(5000);
group.bench_function("AVX Rust run", |f| {
f.iter(|| {
black_box(<[f32; 256]>::distance_compare(
black_box(&a.v),
black_box(&b.v),
Metric::L2,
))
})
});
}
// make sure the vector is 32-byte (256-bit) aligned, as required by _mm256_load_ps
fn prepare_random_aligned_vectors() -> (Box<Vector32ByteAligned>, Box<Vector32ByteAligned>) {
let a = Box::new(Vector32ByteAligned {
v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)),
});
let b = Box::new(Vector32ByteAligned {
v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)),
});
(a, b)
}
criterion_group!(benches, benchmark_l2_distance_float_rust,);
criterion_main!(benches);
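
The #[repr(C, align(32))] wrapper above exists because _mm256_load_ps faults on pointers that are not 32-byte aligned. A minimal sanity check along these lines, reusing the helper above (illustrative, not part of the original bench):

#[test]
fn vectors_are_32_byte_aligned() {
    let (a, b) = prepare_random_aligned_vectors();
    // A pointer is 32-byte aligned iff its address is a multiple of 32.
    assert_eq!(a.v.as_ptr() as usize % 32, 0);
    assert_eq!(b.v.as_ptr() as usize % 32, 0);
}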


@@ -0,0 +1,70 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use criterion::{criterion_group, criterion_main, Criterion};
use diskann::utils::k_means_clustering;
use rand::Rng;
const NUM_POINTS: usize = 10000;
const DIM: usize = 100;
const NUM_CENTERS: usize = 256;
const MAX_KMEANS_REPS: usize = 12;
fn benchmark_kmeans_rust(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let data: Vec<f32> = (0..NUM_POINTS * DIM)
.map(|_| rng.gen_range(-1.0..1.0))
.collect();
let centers: Vec<f32> = vec![0.0; NUM_CENTERS * DIM];
let mut group = c.benchmark_group("kmeans-computation");
group.sample_size(500);
group.bench_function("K-Means Rust run", |f| {
f.iter(|| {
let data_copy = data.clone();
let mut centers_copy = centers.clone();
k_means_clustering(
&data_copy,
NUM_POINTS,
DIM,
&mut centers_copy,
NUM_CENTERS,
MAX_KMEANS_REPS,
)
})
});
}
fn benchmark_kmeans_c(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let data: Vec<f32> = (0..NUM_POINTS * DIM)
.map(|_| rng.gen_range(-1.0..1.0))
.collect();
let centers: Vec<f32> = vec![0.0; NUM_CENTERS * DIM];
let mut group = c.benchmark_group("kmeans-computation");
group.sample_size(500);
group.bench_function("K-Means C++ Run", |f| {
f.iter(|| {
let data_copy = data.clone();
let mut centers_copy = centers.clone();
let _ = k_means_clustering(
data_copy.as_slice(),
NUM_POINTS,
DIM,
centers_copy.as_mut_slice(),
NUM_CENTERS,
MAX_KMEANS_REPS,
);
})
});
}
criterion_group!(benches, benchmark_kmeans_rust, benchmark_kmeans_c);
criterion_main!(benches);
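
For reference, a direct call outside Criterion looks like the body of the closures above. This sketch reuses the file's constants and assumes only the k_means_clustering signature already used in the benchmarks; the finite-centers assertion is an illustrative sanity check, not part of the API:

fn run_kmeans_once() {
    let mut rng = rand::thread_rng();
    let data: Vec<f32> = (0..NUM_POINTS * DIM)
        .map(|_| rng.gen_range(-1.0..1.0))
        .collect();
    // Centers are allocated by the caller and overwritten in place.
    let mut centers = vec![0.0_f32; NUM_CENTERS * DIM];
    let _ = k_means_clustering(&data, NUM_POINTS, DIM, &mut centers, NUM_CENTERS, MAX_KMEANS_REPS);
    // Illustrative check only: every center coordinate should be finite.
    assert!(centers.iter().all(|c| c.is_finite()));
}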


@@ -0,0 +1,49 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::time::Duration;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use diskann::model::{Neighbor, NeighborPriorityQueue};
use rand::distributions::{Distribution, Uniform};
use rand::rngs::StdRng;
use rand::SeedableRng;
fn benchmark_priority_queue_insert(c: &mut Criterion) {
let vec = generate_random_floats();
let mut group = c.benchmark_group("neighborqueue-insert");
group.measurement_time(Duration::from_secs(3)).sample_size(500);
let mut queue = NeighborPriorityQueue::with_capacity(64_usize);
group.bench_function("Neighbor Priority Queue Insert", |f| {
f.iter(|| {
queue.clear();
for n in vec.iter() {
queue.insert(*n);
}
black_box(&1)
});
});
}
fn generate_random_floats() -> Vec<Neighbor> {
let seed: [u8; 32] = [73; 32];
let mut rng: StdRng = SeedableRng::from_seed(seed);
let range = Uniform::new(0.0, 1.0);
let mut random_floats = Vec::with_capacity(100);
for i in 0..100 {
let random_float = range.sample(&mut rng) as f32;
let n = Neighbor::new(i, random_float);
random_floats.push(n);
}
random_floats
}
criterion_group!(benches, benchmark_priority_queue_insert);
criterion_main!(benches);
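
The bench above inserts 100 neighbors into a queue with capacity 64, so inserts beyond capacity must displace worse entries. A small sketch of that behavior, assuming the queue retains the `capacity` best neighbors ordered by distance (as its use in the search code suggests):

fn bounded_queue_keeps_best() {
    let mut queue = NeighborPriorityQueue::with_capacity(2);
    queue.insert(Neighbor::new(0, 3.0));
    queue.insert(Neighbor::new(1, 1.0));
    // The queue is full, so the worst entry (id 0, distance 3.0) is displaced.
    queue.insert(Neighbor::new(2, 2.0));
    assert_eq!(queue.size(), 2);
    // Entries are ordered by distance, closest first.
    assert_eq!(queue[0].id, 1);
}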


@@ -0,0 +1,7 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub mod search;
pub mod prune;


@@ -0,0 +1,6 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
pub mod prune;


@@ -0,0 +1,288 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use hashbrown::HashSet;
use vector::{FullPrecisionDistance, Metric};
use crate::common::{ANNError, ANNResult};
use crate::index::InmemIndex;
use crate::model::graph::AdjacencyList;
use crate::model::neighbor::SortedNeighborVector;
use crate::model::scratch::InMemQueryScratch;
use crate::model::Neighbor;
impl<T, const N: usize> InmemIndex<T, N>
where
T: Default + Copy + Sync + Send + Into<f32>,
[T; N]: FullPrecisionDistance<T, N>,
{
/// A method that occludes a list of neighbors based on some criteria
#[allow(clippy::too_many_arguments)]
fn occlude_list(
&self,
location: u32,
pool: &mut SortedNeighborVector,
alpha: f32,
degree: u32,
max_candidate_size: usize,
result: &mut AdjacencyList,
scratch: &mut InMemQueryScratch<T, N>,
delete_set_ptr: Option<&HashSet<u32>>,
) -> ANNResult<()> {
if pool.is_empty() {
return Ok(());
}
if !result.is_empty() {
return Err(ANNError::log_index_error(
"result is not empty.".to_string(),
));
}
// Truncate pool at max_candidate_size and initialize scratch spaces
if pool.len() > max_candidate_size {
pool.truncate(max_candidate_size);
}
let occlude_factor = &mut scratch.occlude_factor;
// occlude_list can be called with the same scratch more than once by
// search_for_point_and_add_link through inter_insert.
occlude_factor.clear();
// Initialize occlude_factor to pool.len() many 0.0 values for correctness
occlude_factor.resize(pool.len(), 0.0);
let mut cur_alpha = 1.0;
while cur_alpha <= alpha && result.len() < degree as usize {
for (i, neighbor) in pool.iter().enumerate() {
if result.len() >= degree as usize {
break;
}
if occlude_factor[i] > cur_alpha {
continue;
}
// Set the entry to f32::MAX so that it is not considered again
occlude_factor[i] = f32::MAX;
// Add the entry to the result if it has not been deleted and doesn't
// create a self-loop
if delete_set_ptr.map_or(true, |delete_set| !delete_set.contains(&neighbor.id))
&& neighbor.id != location
{
result.push(neighbor.id);
}
// Update occlude factor for points from i+1 to pool.len()
for (j, neighbor2) in pool.iter().enumerate().skip(i + 1) {
if occlude_factor[j] > alpha {
continue;
}
// todo - self.filtered_index
let djk = self.get_distance(neighbor2.id, neighbor.id)?;
match self.configuration.dist_metric {
Metric::L2 | Metric::Cosine => {
occlude_factor[j] = if djk == 0.0 {
f32::MAX
} else {
occlude_factor[j].max(neighbor2.distance / djk)
};
}
}
}
}
cur_alpha *= 1.2;
}
Ok(())
}
/// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids.
///
/// # Arguments
///
/// * `location` - The id of the data point whose neighbors are to be pruned.
/// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point.
/// * `pruned_list` - A vector to store the ids of the pruned neighbors.
/// * `scratch` - A mutable reference to a scratch space for in-memory queries.
///
/// # Panics
///
/// Panics if `pruned_list` contains more than `range` elements after pruning.
pub fn prune_neighbors(
&self,
location: u32,
pool: &mut Vec<Neighbor>,
pruned_list: &mut AdjacencyList,
scratch: &mut InMemQueryScratch<T, N>,
) -> ANNResult<()> {
self.robust_prune(
location,
pool,
self.configuration.index_write_parameter.max_degree,
self.configuration.index_write_parameter.max_occlusion_size,
self.configuration.index_write_parameter.alpha,
pruned_list,
scratch,
)
}
/// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids.
///
/// # Arguments
///
/// * `location` - The id of the data point whose neighbors are to be pruned.
/// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point.
/// * `range` - The maximum number of neighbors to keep after pruning.
/// * `max_candidate_size` - The maximum number of candidates to consider for pruning.
/// * `alpha` - A parameter that controls the occlusion pruning strategy.
/// * `pruned_list` - A vector to store the ids of the pruned neighbors.
/// * `scratch` - A mutable reference to a scratch space for in-memory queries.
///
/// # Error
///
/// Return error if `pruned_list` contains more than `range` elements after pruning.
#[allow(clippy::too_many_arguments)]
fn robust_prune(
&self,
location: u32,
pool: &mut Vec<Neighbor>,
range: u32,
max_candidate_size: u32,
alpha: f32,
pruned_list: &mut AdjacencyList,
scratch: &mut InMemQueryScratch<T, N>,
) -> ANNResult<()> {
if pool.is_empty() {
// if the pool is empty, behave like a noop
pruned_list.clear();
return Ok(());
}
// If using _pq_build, over-write the PQ distances with actual distances
// todo : pq_dist
// sort the pool based on distance to query and prune it with occlude_list
let mut pool = SortedNeighborVector::new(pool);
pruned_list.clear();
self.occlude_list(
location,
&mut pool,
alpha,
range,
max_candidate_size as usize,
pruned_list,
scratch,
Option::None,
)?;
if pruned_list.len() > range as usize {
return Err(ANNError::log_index_error(format!(
"pruned_list's len {} is over range {}.",
pruned_list.len(),
range
)));
}
if self.configuration.index_write_parameter.saturate_graph && alpha > 1.0f32 {
for neighbor in pool.iter() {
if pruned_list.len() >= (range as usize) {
break;
}
if !pruned_list.contains(&neighbor.id) && neighbor.id != location {
pruned_list.push(neighbor.id);
}
}
}
Ok(())
}
/// A method that inserts a point n into the graph of its neighbors and their neighbors,
/// pruning the graph if necessary to keep it within the specified range
/// * `n` - The index of the new point
/// * `pruned_list` is a vector of the neighbors of n that have been pruned by a previous step
/// * `range` is the target number of neighbors for each point
/// * `scratch` is a mutable reference to a scratch space that can be reused for intermediate computations
pub fn inter_insert(
&self,
n: u32,
pruned_list: &Vec<u32>,
range: u32,
scratch: &mut InMemQueryScratch<T, N>,
) -> ANNResult<()> {
// Borrow the pruned_list as a source pool of neighbors
let src_pool = pruned_list;
if src_pool.is_empty() {
return Err(ANNError::log_index_error("src_pool is empty.".to_string()));
}
for &vertex_id in src_pool {
// vertex is the index of a neighbor of n
// Assert that vertex is within the valid range of points
if (vertex_id as usize)
>= self.configuration.max_points + self.configuration.num_frozen_pts
{
return Err(ANNError::log_index_error(format!(
"vertex_id {} is out of valid range of points {}",
vertex_id,
self.configuration.max_points + self.configuration.num_frozen_pts,
)));
}
let neighbors = self.add_to_neighbors(vertex_id, n, range)?;
if let Some(copy_of_neighbors) = neighbors {
// Pruning is needed, create a dummy set and a dummy vector to store the unique neighbors of vertex_id
let mut dummy_pool = self.get_unique_neighbors(&copy_of_neighbors, vertex_id)?;
// Create a new vector to store the pruned neighbors of vertex_id
let mut new_out_neighbors =
AdjacencyList::for_range(self.configuration.write_range());
// Prune the neighbors of vertex_id using a helper method
self.prune_neighbors(vertex_id, &mut dummy_pool, &mut new_out_neighbors, scratch)?;
self.set_neighbors(vertex_id, new_out_neighbors)?;
}
}
Ok(())
}
/// Adds a node to the list of neighbors for the given node.
///
/// # Arguments
///
/// * `vertex_id` - The ID of the node to add the neighbor to.
/// * `node_id` - The ID of the node to add.
/// * `range` - The range of the graph.
///
/// # Return
///
/// Returns `None` if the node is already in the list of neighbors, or a `Vec` containing the updated list of neighbors if the list of neighbors is full.
fn add_to_neighbors(
&self,
vertex_id: u32,
node_id: u32,
range: u32,
) -> ANNResult<Option<Vec<u32>>> {
// vertex contains a vector of the neighbors of vertex_id
let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?;
Ok(vertex_guard.add_to_neighbors(node_id, range))
}
fn set_neighbors(&self, vertex_id: u32, new_out_neighbors: AdjacencyList) -> ANNResult<()> {
// vertex contains a vector of the neighbors of vertex_id
let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?;
vertex_guard.set_neighbors(new_out_neighbors);
Ok(())
}
}
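
The core of occlude_list is the alpha-occlusion rule from the DiskANN/Vamana construction: a kept neighbor i suppresses a later candidate j whenever d(p, j) / d(i, j) exceeds the current alpha, i.e. j is reachable cheaply through i. A standalone sketch of that loop on plain slices, omitting the delete-set and self-loop checks and the SortedNeighborVector/scratch machinery:

fn occlude_sketch(
    pool: &[(u32, f32)], // (id, distance to p), sorted ascending by distance
    dist: impl Fn(u32, u32) -> f32,
    alpha: f32,
    degree: usize,
) -> Vec<u32> {
    let mut factor = vec![0.0_f32; pool.len()];
    let mut result = Vec::new();
    let mut cur_alpha = 1.0_f32;
    while cur_alpha <= alpha && result.len() < degree {
        for i in 0..pool.len() {
            if result.len() >= degree {
                break;
            }
            if factor[i] > cur_alpha {
                continue;
            }
            factor[i] = f32::MAX; // a kept candidate is never reconsidered
            result.push(pool[i].0);
            // Penalize every later candidate j by how cheaply it can be
            // reached through i: a large d(p, j) / d(i, j) ratio means i
            // already covers j, so j is skipped at this alpha.
            for j in (i + 1)..pool.len() {
                if factor[j] > alpha {
                    continue;
                }
                let djk = dist(pool[j].0, pool[i].0);
                factor[j] = if djk == 0.0 { f32::MAX } else { factor[j].max(pool[j].1 / djk) };
            }
        }
        cur_alpha *= 1.2; // relax the threshold, as in occlude_list above
    }
    result
}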


@@ -0,0 +1,7 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
pub mod search;


@@ -0,0 +1,359 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Search algorithm for index construction and query
use crate::common::{ANNError, ANNResult};
use crate::index::InmemIndex;
use crate::model::{scratch::InMemQueryScratch, Neighbor, Vertex};
use hashbrown::hash_set::Entry::*;
use vector::FullPrecisionDistance;
impl<T, const N: usize> InmemIndex<T, N>
where
T: Default + Copy + Sync + Send + Into<f32>,
[T; N]: FullPrecisionDistance<T, N>,
{
/// Search for query using given L value, for benchmarking purposes
/// # Arguments
/// * `query` - query vertex
/// * `scratch` - in-memory query scratch
/// * `search_list_size` - search list size to use for the benchmark
pub fn search_with_l_override(
&self,
query: &Vertex<T, N>,
scratch: &mut InMemQueryScratch<T, N>,
search_list_size: usize,
) -> ANNResult<u32> {
let init_ids = self.get_init_ids()?;
self.init_graph_for_point(query, init_ids, scratch)?;
// Scratch is created using the largest L value from search_memory_index, so we artificially make it smaller here
// This allows us to use the same scratch for all L values without having to rebuild the query scratch
scratch.best_candidates.set_capacity(search_list_size);
let (_, cmp) = self.greedy_search(query, scratch)?;
Ok(cmp)
}
/// search for point
/// # Arguments
/// * `query` - query vertex
/// * `scratch` - in-memory query scratch
/// TODO: use_filter, filteredLindex
pub fn search_for_point(
&self,
query: &Vertex<T, N>,
scratch: &mut InMemQueryScratch<T, N>,
) -> ANNResult<Vec<Neighbor>> {
let init_ids = self.get_init_ids()?;
self.init_graph_for_point(query, init_ids, scratch)?;
let (mut visited_nodes, _) = self.greedy_search(query, scratch)?;
visited_nodes.retain(|&element| element.id != query.vertex_id());
Ok(visited_nodes)
}
/// Returns the locations of start point and frozen points suitable for use with iterate_to_fixed_point.
fn get_init_ids(&self) -> ANNResult<Vec<u32>> {
let mut init_ids = Vec::with_capacity(1 + self.configuration.num_frozen_pts);
init_ids.push(self.start);
for frozen in self.configuration.max_points
..(self.configuration.max_points + self.configuration.num_frozen_pts)
{
let frozen_u32 = frozen.try_into()?;
if frozen_u32 != self.start {
init_ids.push(frozen_u32);
}
}
Ok(init_ids)
}
/// Initialize graph for point
/// # Arguments
/// * `query` - query vertex
/// * `init_ids` - initial nodes from which search starts
/// * `scratch` - in-memory query scratch
/// * `search_list_size_override` - override for search list size in index config
fn init_graph_for_point(
&self,
query: &Vertex<T, N>,
init_ids: Vec<u32>,
scratch: &mut InMemQueryScratch<T, N>,
) -> ANNResult<()> {
scratch
.best_candidates
.reserve(self.configuration.index_write_parameter.search_list_size as usize);
scratch.query.memcpy(query.vector())?;
if !scratch.id_scratch.is_empty() {
return Err(ANNError::log_index_error(
"id_scratch is not empty.".to_string(),
));
}
let query_vertex = Vertex::<T, N>::try_from((&scratch.query[..], query.vertex_id()))
.map_err(|err| {
ANNError::log_index_error(format!(
"TryFromSliceError: failed to get Vertex for query, err={}",
err
))
})?;
for id in init_ids {
if (id as usize) >= self.configuration.max_points + self.configuration.num_frozen_pts {
return Err(ANNError::log_index_error(format!(
"vertex_id {} is out of valid range of points {}",
id,
self.configuration.max_points + self.configuration.num_frozen_pts
)));
}
if let Vacant(entry) = scratch.node_visited_robinset.entry(id) {
entry.insert();
let vertex = self.dataset.get_vertex(id)?;
let distance = vertex.compare(&query_vertex, self.configuration.dist_metric);
let neighbor = Neighbor::new(id, distance);
scratch.best_candidates.insert(neighbor);
}
}
Ok(())
}
/// GreedySearch against query node
/// Returns visited nodes
/// # Arguments
/// * `query` - query vertex
/// * `scratch` - in-memory query scratch
/// TODO: use_filter, filter_label, search_invocation
fn greedy_search(
&self,
query: &Vertex<T, N>,
scratch: &mut InMemQueryScratch<T, N>,
) -> ANNResult<(Vec<Neighbor>, u32)> {
let mut visited_nodes =
Vec::with_capacity((3 * scratch.candidate_size + scratch.max_degree) as usize);
// TODO: uncomment hops?
// let mut hops: u32 = 0;
let mut cmps: u32 = 0;
let query_vertex = Vertex::<T, N>::try_from((&scratch.query[..], query.vertex_id()))
.map_err(|err| {
ANNError::log_index_error(format!(
"TryFromSliceError: failed to get Vertex for query, err={}",
err
))
})?;
while scratch.best_candidates.has_notvisited_node() {
let closest_node = scratch.best_candidates.closest_notvisited();
// Add node to visited nodes to create pool for prune later
// TODO: search_invocation and use_filter
visited_nodes.push(closest_node);
// Find which of the nodes in des have not been visited before
scratch.id_scratch.clear();
let max_vertex_id = self.configuration.max_points + self.configuration.num_frozen_pts;
for id in self
.final_graph
.read_vertex_and_neighbors(closest_node.id)?
.get_neighbors()
{
let current_vertex_id = *id;
debug_assert!(
(current_vertex_id as usize) < max_vertex_id,
"current_vertex_id {} is out of valid range of points {}",
current_vertex_id,
max_vertex_id
);
if current_vertex_id as usize >= max_vertex_id {
continue;
}
// quickly de-dup. Remember, we are in a read lock
// we want to exit out of it quickly
if scratch.node_visited_robinset.insert(current_vertex_id) {
scratch.id_scratch.push(current_vertex_id);
}
}
let len = scratch.id_scratch.len();
for (m, &id) in scratch.id_scratch.iter().enumerate() {
if m + 1 < len {
let next_node = unsafe { *scratch.id_scratch.get_unchecked(m + 1) };
self.dataset.prefetch_vector(next_node);
}
let vertex = self.dataset.get_vertex(id)?;
let distance = query_vertex.compare(&vertex, self.configuration.dist_metric);
// Insert <id, dist> pairs into the pool of candidates
scratch.best_candidates.insert(Neighbor::new(id, distance));
}
cmps += len as u32;
}
Ok((visited_nodes, cmps))
}
}
#[cfg(test)]
mod search_test {
use vector::Metric;
use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
use crate::model::graph::AdjacencyList;
use crate::model::IndexConfiguration;
use crate::test_utils::inmem_index_initialization::create_index_with_test_data;
use super::*;
#[test]
fn get_init_ids_no_frozen_pts() {
let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
.with_alpha(1.2)
.build();
let config = IndexConfiguration::new(
Metric::L2,
256,
256,
256,
false,
0,
false,
0,
1f32,
index_write_parameters,
);
let index = InmemIndex::<f32, 256>::new(config).unwrap();
let init_ids = index.get_init_ids().unwrap();
assert_eq!(init_ids.len(), 1);
assert_eq!(init_ids[0], 256);
}
#[test]
fn get_init_ids_with_frozen_pts() {
let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
.with_alpha(1.2)
.build();
let config = IndexConfiguration::new(
Metric::L2,
256,
256,
256,
false,
0,
false,
2,
1f32,
index_write_parameters,
);
let index = InmemIndex::<f32, 256>::new(config).unwrap();
let init_ids = index.get_init_ids().unwrap();
assert_eq!(init_ids.len(), 2);
assert_eq!(init_ids[0], 256);
assert_eq!(init_ids[1], 257);
}
#[test]
fn search_for_point_initial_call() {
let index = create_index_with_test_data();
let query = index.dataset.get_vertex(0).unwrap();
let mut scratch = InMemQueryScratch::new(
index.configuration.index_write_parameter.search_list_size,
&index.configuration.index_write_parameter,
false,
)
.unwrap();
let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap();
assert_eq!(visited_nodes.len(), 1);
assert_eq!(scratch.best_candidates.size(), 1);
assert_eq!(scratch.best_candidates[0].id, 72);
assert_eq!(scratch.best_candidates[0].distance, 125678.0_f32);
assert!(scratch.best_candidates[0].visited);
}
fn set_neighbors(index: &InmemIndex<f32, 128>, vertex_id: u32, neighbors: Vec<u32>) {
index
.final_graph
.write_vertex_and_neighbors(vertex_id)
.unwrap()
.set_neighbors(AdjacencyList::from(neighbors));
}
#[test]
fn search_for_point_works_with_edges() {
let index = create_index_with_test_data();
let query = index.dataset.get_vertex(14).unwrap();
set_neighbors(&index, 0, vec![12, 72, 5, 9]);
set_neighbors(&index, 1, vec![2, 12, 10, 4]);
set_neighbors(&index, 2, vec![1, 72, 9]);
set_neighbors(&index, 3, vec![13, 6, 5, 11]);
set_neighbors(&index, 4, vec![1, 3, 7, 9]);
set_neighbors(&index, 5, vec![3, 0, 8, 11, 13]);
set_neighbors(&index, 6, vec![3, 72, 7, 10, 13]);
set_neighbors(&index, 7, vec![72, 4, 6]);
set_neighbors(&index, 8, vec![72, 5, 9, 12]);
set_neighbors(&index, 9, vec![8, 4, 0, 2]);
set_neighbors(&index, 10, vec![72, 1, 9, 6]);
set_neighbors(&index, 11, vec![3, 0, 5]);
set_neighbors(&index, 12, vec![1, 0, 8, 9]);
set_neighbors(&index, 13, vec![3, 72, 5, 6]);
set_neighbors(&index, 72, vec![7, 2, 10, 8, 13]);
let mut scratch = InMemQueryScratch::new(
index.configuration.index_write_parameter.search_list_size,
&index.configuration.index_write_parameter,
false,
)
.unwrap();
let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap();
assert_eq!(visited_nodes.len(), 15);
assert_eq!(scratch.best_candidates.size(), 15);
assert_eq!(scratch.best_candidates[0].id, 2);
assert_eq!(scratch.best_candidates[0].distance, 120899.0_f32);
assert_eq!(scratch.best_candidates[1].id, 8);
assert_eq!(scratch.best_candidates[1].distance, 145538.0_f32);
assert_eq!(scratch.best_candidates[2].id, 72);
assert_eq!(scratch.best_candidates[2].distance, 146046.0_f32);
assert_eq!(scratch.best_candidates[3].id, 4);
assert_eq!(scratch.best_candidates[3].distance, 148462.0_f32);
assert_eq!(scratch.best_candidates[4].id, 7);
assert_eq!(scratch.best_candidates[4].distance, 148912.0_f32);
assert_eq!(scratch.best_candidates[5].id, 10);
assert_eq!(scratch.best_candidates[5].distance, 154570.0_f32);
assert_eq!(scratch.best_candidates[6].id, 1);
assert_eq!(scratch.best_candidates[6].distance, 159448.0_f32);
assert_eq!(scratch.best_candidates[7].id, 12);
assert_eq!(scratch.best_candidates[7].distance, 170698.0_f32);
assert_eq!(scratch.best_candidates[8].id, 9);
assert_eq!(scratch.best_candidates[8].distance, 177205.0_f32);
assert_eq!(scratch.best_candidates[9].id, 0);
assert_eq!(scratch.best_candidates[9].distance, 259996.0_f32);
assert_eq!(scratch.best_candidates[10].id, 6);
assert_eq!(scratch.best_candidates[10].distance, 371819.0_f32);
assert_eq!(scratch.best_candidates[11].id, 5);
assert_eq!(scratch.best_candidates[11].distance, 385240.0_f32);
assert_eq!(scratch.best_candidates[12].id, 3);
assert_eq!(scratch.best_candidates[12].distance, 413899.0_f32);
assert_eq!(scratch.best_candidates[13].id, 13);
assert_eq!(scratch.best_candidates[13].distance, 416386.0_f32);
assert_eq!(scratch.best_candidates[14].id, 11);
assert_eq!(scratch.best_candidates[14].distance, 449266.0_f32);
}
}
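
greedy_search above is a best-first beam search: repeatedly expand the closest not-yet-expanded candidate, push its unseen neighbors, and keep only the best L candidates. A self-contained sketch on a plain adjacency list (the real code additionally prefetches vectors, counts comparisons, and works through NeighborPriorityQueue and scratch buffers):

use std::collections::HashSet;

fn greedy_search_sketch(
    graph: &[Vec<u32>],                 // adjacency list
    dist_to_query: impl Fn(u32) -> f32, // distance of a node to the query
    start: u32,
    l: usize,                           // search list size L
) -> Vec<(u32, f32)> {
    let mut visited: HashSet<u32> = HashSet::new();
    visited.insert(start);
    let mut expanded: HashSet<u32> = HashSet::new();
    // (distance, id), kept sorted ascending; the real code uses a bounded
    // priority queue with capacity `l` instead of sort + truncate.
    let mut candidates = vec![(dist_to_query(start), start)];
    loop {
        // closest candidate that has not been expanded yet
        let Some(pos) = candidates.iter().position(|&(_, id)| !expanded.contains(&id)) else {
            break;
        };
        let (_, id) = candidates[pos];
        expanded.insert(id);
        for &nbr in &graph[id as usize] {
            // de-dup before computing distances, like node_visited_robinset
            if visited.insert(nbr) {
                candidates.push((dist_to_query(nbr), nbr));
            }
        }
        candidates.sort_by(|a, b| a.0.total_cmp(&b.0));
        candidates.truncate(l); // keep only the best L
    }
    candidates.into_iter().map(|(d, id)| (id, d)).collect()
}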


@@ -0,0 +1,281 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Aligned allocator
use std::alloc::Layout;
use std::ops::{Deref, DerefMut, Range};
use std::ptr::copy_nonoverlapping;
use super::{ANNResult, ANNError};
#[derive(Debug)]
/// A box that holds a slice but is aligned to the specified layout.
///
/// This type is useful for working with types that require a certain alignment,
/// such as SIMD vectors or FFI structs. It allocates memory using the global allocator
/// and frees it when dropped. It also implements Deref and DerefMut to allow access
/// to the underlying slice.
pub struct AlignedBoxWithSlice<T> {
/// The layout of the allocated memory.
layout: Layout,
/// The slice that points to the allocated memory.
val: Box<[T]>,
}
impl<T> AlignedBoxWithSlice<T> {
/// Creates a new `AlignedBoxWithSlice` with the given capacity and alignment.
/// The allocated memory is zero-initialized.
///
/// # Error
///
/// Return IndexError if the alignment is not a power of two or if the layout is invalid.
///
/// Internally this allocates zeroed memory with the global allocator and
/// casts it to a slice of `T`. The caller must ensure that the capacity and
/// alignment are valid for the type `T`, and that an all-zero bit pattern is
/// a valid value of `T`.
pub fn new(capacity: usize, alignment: usize) -> ANNResult<Self> {
let allocsize = capacity.checked_mul(std::mem::size_of::<T>())
.ok_or_else(|| ANNError::log_index_error("capacity overflow".to_string()))?;
let layout = Layout::from_size_align(allocsize, alignment)
.map_err(ANNError::log_mem_alloc_layout_error)?;
let val = unsafe {
let mem = std::alloc::alloc_zeroed(layout);
let ptr = mem as *mut T;
let slice = std::slice::from_raw_parts_mut(ptr, capacity);
std::boxed::Box::from_raw(slice)
};
Ok(Self { layout, val })
}
/// Returns a reference to the slice.
pub fn as_slice(&self) -> &[T] {
&self.val
}
/// Returns a mutable reference to the slice.
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.val
}
/// Copies data from the source slice to the destination box.
pub fn memcpy(&mut self, src: &[T]) -> ANNResult<()> {
if src.len() > self.val.len() {
return Err(ANNError::log_index_error(format!("source slice is too large (src:{}, dst:{})", src.len(), self.val.len())));
}
// Check that they don't overlap
let src_ptr = src.as_ptr();
let src_end = unsafe { src_ptr.add(src.len()) };
let dst_ptr = self.val.as_mut_ptr();
let dst_end = unsafe { dst_ptr.add(self.val.len()) };
if src_ptr < dst_end && src_end > dst_ptr {
return Err(ANNError::log_index_error("Source and destination overlap".to_string()));
}
unsafe {
copy_nonoverlapping(src.as_ptr(), self.val.as_mut_ptr(), src.len());
}
Ok(())
}
/// Split the range of memory into nonoverlapping mutable slices.
/// The number of returned slices is (range length / slice_len) and each has a length of slice_len.
pub fn split_into_nonoverlapping_mut_slices(&mut self, range: Range<usize>, slice_len: usize) -> ANNResult<Vec<&mut [T]>> {
if range.len() % slice_len != 0 || range.end > self.len() {
return Err(ANNError::log_index_error(format!(
"Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}",
range,
self.len(),
slice_len,
)));
}
let mut slices = Vec::with_capacity(range.len() / slice_len);
let mut remaining_slice = &mut self.val[range];
while remaining_slice.len() >= slice_len {
let (left, right) = remaining_slice.split_at_mut(slice_len);
slices.push(left);
remaining_slice = right;
}
Ok(slices)
}
}
impl<T> Drop for AlignedBoxWithSlice<T> {
/// Frees the memory allocated for the slice using the global allocator.
fn drop(&mut self) {
let val = std::mem::take(&mut self.val);
let mut val2 = std::mem::ManuallyDrop::new(val);
let ptr = val2.as_mut_ptr();
unsafe {
// let nonNull = NonNull::new_unchecked(ptr as *mut u8);
std::alloc::dealloc(ptr as *mut u8, self.layout)
}
}
}
impl<T> Deref for AlignedBoxWithSlice<T> {
type Target = [T];
fn deref(&self) -> &Self::Target {
&self.val
}
}
impl<T> DerefMut for AlignedBoxWithSlice<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.val
}
}
#[cfg(test)]
mod tests {
use rand::Rng;
use crate::utils::is_aligned;
use super::*;
#[test]
fn create_alignedvec_works_32() {
(0..100).for_each(|_| {
let size = 1_000_000;
println!("Attempting {}", size);
let data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
assert_eq!(data.len(), size, "Capacity should match");
let ptr = data.as_ptr() as usize;
assert_eq!(ptr % 32, 0, "Ptr should be aligned to 32");
// assert that the slice is initialized.
(0..size).for_each(|i| {
assert_eq!(data[i], f32::default());
});
drop(data);
});
}
#[test]
fn create_alignedvec_works_256() {
let mut rng = rand::thread_rng();
(0..100).for_each(|_| {
let n = rng.gen::<u8>();
let size = usize::from(n) + 1;
println!("Attempting {}", size);
let data = AlignedBoxWithSlice::<u8>::new(size, 256).unwrap();
assert_eq!(data.len(), size, "Capacity should match");
let ptr = data.as_ptr() as usize;
assert_eq!(ptr % 256, 0, "Ptr should be aligned to 256");
// assert that the slice is initialized.
(0..size).for_each(|i| {
assert_eq!(data[i], u8::default());
});
drop(data);
});
}
#[test]
fn as_slice_test() {
let size = 1_000_000;
let data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
// assert that the slice is initialized.
(0..size).for_each(|i| {
assert_eq!(data[i], f32::default());
});
let slice = data.as_slice();
(0..size).for_each(|i| {
assert_eq!(slice[i], f32::default());
});
}
#[test]
fn as_mut_slice_test() {
let size = 1_000_000;
let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
let mut_slice = data.as_mut_slice();
(0..size).for_each(|i| {
assert_eq!(mut_slice[i], f32::default());
});
}
#[test]
fn memcpy_test() {
let size = 1_000_000;
let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
let mut destination = AlignedBoxWithSlice::<f32>::new(size-2, 32).unwrap();
let mut_destination = destination.as_mut_slice();
data.memcpy(mut_destination).unwrap();
(0..size-2).for_each(|i| {
assert_eq!(data[i], mut_destination[i]);
});
}
#[test]
#[should_panic(expected = "source slice is too large (src:1000000, dst:999998)")]
fn memcpy_panic_test() {
let size = 1_000_000;
let mut data = AlignedBoxWithSlice::<f32>::new(size-2, 32).unwrap();
let mut destination = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
let mut_destination = destination.as_mut_slice();
data.memcpy(mut_destination).unwrap();
}
#[test]
fn is_aligned_test() {
assert!(is_aligned(256,256));
assert!(!is_aligned(255,256));
}
#[test]
fn split_into_nonoverlapping_mut_slices_test() {
let size = 10;
let slice_len = 2;
let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
let slices = data.split_into_nonoverlapping_mut_slices(2..8, slice_len).unwrap();
assert_eq!(slices.len(), 3);
for (i, slice) in slices.into_iter().enumerate() {
assert_eq!(slice.len(), slice_len);
slice[0] = i as f32 + 1.0;
slice[1] = i as f32 + 1.0;
}
let expected_arr = [0.0f32, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0];
assert_eq!(data.as_ref(), &expected_arr);
}
#[test]
fn split_into_nonoverlapping_mut_slices_error_when_indivisible() {
let size = 10;
let slice_len = 2;
let range = 2..7;
let mut data = AlignedBoxWithSlice::<f32>::new(size, 32).unwrap();
let result = data.split_into_nonoverlapping_mut_slices(range.clone(), slice_len);
let expected_err_str = format!(
"IndexError: Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}",
range,
size,
slice_len,
);
assert!(result.is_err_and(|e| e.to_string() == expected_err_str));
}
}
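
A minimal usage sketch of AlignedBoxWithSlice, combining the operations exercised by the tests above:

fn aligned_buffer_roundtrip() -> ANNResult<()> {
    // 512 f32s, 32-byte aligned for AVX loads; the memory starts zeroed.
    let mut buf = AlignedBoxWithSlice::<f32>::new(512, 32)?;
    let src: Vec<f32> = (0..512).map(|i| i as f32).collect();
    buf.memcpy(&src)?;
    assert_eq!(buf.as_slice()[511], 511.0);
    // Carve the first 8 elements into four disjoint 2-element windows.
    let slices = buf.split_into_nonoverlapping_mut_slices(0..8, 2)?;
    assert_eq!(slices.len(), 4);
    Ok(())
}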


@@ -0,0 +1,179 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::alloc::LayoutError;
use std::array::TryFromSliceError;
use std::io;
use std::num::TryFromIntError;
use logger::error_logger::log_error;
use logger::log_error::LogError;
/// Result type returned by DiskANN operations
pub type ANNResult<T> = Result<T, ANNError>;
/// DiskANN Error
/// ANNError is `Send` (i.e., safe to send across threads)
#[derive(thiserror::Error, Debug)]
pub enum ANNError {
/// Index construction and search error
#[error("IndexError: {err}")]
IndexError { err: String },
/// Index configuration error
#[error("IndexConfigError: {parameter} is invalid, err={err}")]
IndexConfigError { parameter: String, err: String },
/// Integer conversion error
#[error("TryFromIntError: {err}")]
TryFromIntError {
#[from]
err: TryFromIntError,
},
/// IO error
#[error("IOError: {err}")]
IOError {
#[from]
err: io::Error,
},
/// Layout error in memory allocation
#[error("MemoryAllocLayoutError: {err}")]
MemoryAllocLayoutError {
#[from]
err: LayoutError,
},
/// PoisonError which can be returned whenever a lock is acquired
/// Both Mutexes and RwLocks are poisoned whenever a thread fails while the lock is held
#[error("LockPoisonError: {err}")]
LockPoisonError { err: String },
/// DiskIOAlignmentError which can be returned when the Windows API CreateFileA call for the disk index file fails.
#[error("DiskIOAlignmentError: {err}")]
DiskIOAlignmentError { err: String },
/// Logging error
#[error("LogError: {err}")]
LogError {
#[from]
err: LogError,
},
/// PQ construction error
/// Returned when constructing the PQ pivot or PQ compressed table fails
#[error("PQError: {err}")]
PQError { err: String },
/// Array conversion error
#[error("Error try creating array from slice: {err}")]
TryFromSliceError {
#[from]
err: TryFromSliceError,
},
}
impl ANNError {
/// Create, log and return IndexError
#[inline]
pub fn log_index_error(err: String) -> Self {
let ann_err = ANNError::IndexError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return IndexConfigError
#[inline]
pub fn log_index_config_error(parameter: String, err: String) -> Self {
let ann_err = ANNError::IndexConfigError { parameter, err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return TryFromIntError
#[inline]
pub fn log_try_from_int_error(err: TryFromIntError) -> Self {
let ann_err = ANNError::TryFromIntError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return IOError
#[inline]
pub fn log_io_error(err: io::Error) -> Self {
let ann_err = ANNError::IOError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return DiskIOAlignmentError
#[inline]
pub fn log_disk_io_request_alignment_error(err: String) -> Self {
let ann_err: ANNError = ANNError::DiskIOAlignmentError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return MemoryAllocLayoutError
#[inline]
pub fn log_mem_alloc_layout_error(err: LayoutError) -> Self {
let ann_err = ANNError::MemoryAllocLayoutError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return LockPoisonError
#[inline]
pub fn log_lock_poison_error(err: String) -> Self {
let ann_err = ANNError::LockPoisonError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return PQError
#[inline]
pub fn log_pq_error(err: String) -> Self {
let ann_err = ANNError::PQError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
/// Create, log and return TryFromSliceError
#[inline]
pub fn log_try_from_slice_error(err: TryFromSliceError) -> Self {
let ann_err = ANNError::TryFromSliceError { err };
match log_error(ann_err.to_string()) {
Ok(()) => ann_err,
Err(log_err) => ANNError::LogError { err: log_err },
}
}
}
#[cfg(test)]
mod ann_result_test {
use super::*;
#[test]
fn ann_err_is_send() {
fn assert_send<T: Send>() {}
assert_send::<ANNError>();
}
}
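
Two small sketches of how the pieces above compose: integer conversions flow through the #[from] impls with ?, while string-based variants are constructed through the logging helpers. parse_degree and its "max_degree" parameter name are hypothetical:

fn to_u32(x: usize) -> ANNResult<u32> {
    // TryFromIntError converts into ANNError::TryFromIntError via #[from].
    Ok(u32::try_from(x)?)
}

fn parse_degree(raw: &str) -> ANNResult<u32> {
    // ParseIntError has no #[from] impl here, so construct and log explicitly.
    raw.parse::<u32>()
        .map_err(|e| ANNError::log_index_config_error("max_degree".to_string(), e.to_string()))
}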


@@ -0,0 +1,9 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
mod aligned_allocator;
pub use aligned_allocator::AlignedBoxWithSlice;
mod ann_result;
pub use ann_result::*;


@@ -0,0 +1,54 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_docs)]
//! ANN disk index abstraction
use vector::FullPrecisionDistance;
use crate::model::{IndexConfiguration, DiskIndexBuildParameters};
use crate::storage::DiskIndexStorage;
use crate::model::vertex::{DIM_128, DIM_256, DIM_104};
use crate::common::{ANNResult, ANNError};
use super::DiskIndex;
/// ANN disk index abstraction for custom <T, N>
pub trait ANNDiskIndex<T> : Sync + Send
where T : Default + Copy + Sync + Send + Into<f32>
{
/// Build index
fn build(&mut self, codebook_prefix: &str) -> ANNResult<()>;
}
/// Create Index<T, N> based on configuration
pub fn create_disk_index<'a, T>(
disk_build_param: Option<DiskIndexBuildParameters>,
config: IndexConfiguration,
storage: DiskIndexStorage<T>,
) -> ANNResult<Box<dyn ANNDiskIndex<T> + 'a>>
where
T: Default + Copy + Sync + Send + Into<f32> + 'a,
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
{
match config.aligned_dim {
DIM_104 => {
let index = Box::new(DiskIndex::<T, DIM_104>::new(disk_build_param, config, storage));
Ok(index as Box<dyn ANNDiskIndex<T>>)
},
DIM_128 => {
let index = Box::new(DiskIndex::<T, DIM_128>::new(disk_build_param, config, storage));
Ok(index as Box<dyn ANNDiskIndex<T>>)
},
DIM_256 => {
let index = Box::new(DiskIndex::<T, DIM_256>::new(disk_build_param, config, storage));
Ok(index as Box<dyn ANNDiskIndex<T>>)
},
_ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))),
}
}
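
The match above is a dispatch from a runtime aligned_dim to a compile-time const generic, which lets the distance kernels operate on fixed-size arrays. The same pattern in miniature (illustrative only, not part of the crate):

trait AnyIndex {
    fn dim(&self) -> usize;
}

struct FixedDim<const N: usize>;

impl<const N: usize> AnyIndex for FixedDim<N> {
    fn dim(&self) -> usize {
        N
    }
}

fn make_index(dim: usize) -> Option<Box<dyn AnyIndex>> {
    match dim {
        104 => Some(Box::new(FixedDim::<104>)),
        128 => Some(Box::new(FixedDim::<128>)),
        256 => Some(Box::new(FixedDim::<256>)),
        _ => None, // the real code returns an ANNError here
    }
}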


@@ -0,0 +1,161 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::mem;
use logger::logger::indexlog::DiskIndexConstructionCheckpoint;
use vector::FullPrecisionDistance;
use crate::common::{ANNResult, ANNError};
use crate::index::{InmemIndex, ANNInmemIndex};
use crate::instrumentation::DiskIndexBuildLogger;
use crate::model::configuration::DiskIndexBuildParameters;
use crate::model::{IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE, MAX_PQ_CHUNKS, generate_quantized_data, GRAPH_SLACK_FACTOR};
use crate::storage::DiskIndexStorage;
use crate::utils::set_rayon_num_threads;
use super::ann_disk_index::ANNDiskIndex;
pub const OVERHEAD_FACTOR: f64 = 1.1f64;
pub const MAX_SAMPLE_POINTS_FOR_WARMUP: usize = 100_000;
pub struct DiskIndex<T, const N: usize>
where
[T; N]: FullPrecisionDistance<T, N>,
{
/// Parameters for index construction
/// None for query path
disk_build_param: Option<DiskIndexBuildParameters>,
configuration: IndexConfiguration,
pub storage: DiskIndexStorage<T>,
}
impl<T, const N: usize> DiskIndex<T, N>
where
T: Default + Copy + Sync + Send + Into<f32>,
[T; N]: FullPrecisionDistance<T, N>,
{
pub fn new(
disk_build_param: Option<DiskIndexBuildParameters>,
configuration: IndexConfiguration,
storage: DiskIndexStorage<T>,
) -> Self {
Self {
disk_build_param,
configuration,
storage,
}
}
pub fn disk_build_param(&self) -> &Option<DiskIndexBuildParameters> {
&self.disk_build_param
}
pub fn index_configuration(&self) -> &IndexConfiguration {
&self.configuration
}
fn build_inmem_index(&self, num_points: usize, data_path: &str, inmem_index_path: &str) -> ANNResult<()> {
let estimated_index_ram = self.estimate_ram_usage(num_points);
if estimated_index_ram >= self.fetch_disk_build_param()?.index_build_ram_limit() * 1024_f64 * 1024_f64 * 1024_f64 {
return Err(ANNError::log_index_error(format!(
"Insufficient memory budget for index build, index_build_ram_limit={}GB estimated_index_ram={}GB",
self.fetch_disk_build_param()?.index_build_ram_limit(),
estimated_index_ram / (1024_f64 * 1024_f64 * 1024_f64),
)));
}
let mut index = InmemIndex::<T, N>::new(self.configuration.clone())?;
index.build(data_path, num_points)?;
index.save(inmem_index_path)?;
Ok(())
}
#[inline]
fn estimate_ram_usage(&self, size: usize) -> f64 {
let degree = self.configuration.index_write_parameter.max_degree as usize;
let datasize = mem::size_of::<T>();
let dataset_size = (size * N * datasize) as f64;
let graph_size = (size * degree * mem::size_of::<u32>()) as f64 * GRAPH_SLACK_FACTOR;
OVERHEAD_FACTOR * (dataset_size + graph_size)
}
#[inline]
fn fetch_disk_build_param(&self) -> ANNResult<&DiskIndexBuildParameters> {
self.disk_build_param
.as_ref()
.ok_or_else(|| ANNError::log_index_config_error(
"disk_build_param".to_string(),
"disk_build_param is None".to_string()))
}
}
impl<T, const N: usize> ANNDiskIndex<T> for DiskIndex<T, N>
where
T: Default + Copy + Sync + Send + Into<f32>,
[T; N]: FullPrecisionDistance<T, N>,
{
fn build(&mut self, codebook_prefix: &str) -> ANNResult<()> {
if self.configuration.index_write_parameter.num_threads > 0 {
set_rayon_num_threads(self.configuration.index_write_parameter.num_threads);
}
println!("Starting index build: R={} L={} Query RAM budget={} Indexing RAM budget={} T={}",
self.configuration.index_write_parameter.max_degree,
self.configuration.index_write_parameter.search_list_size,
self.fetch_disk_build_param()?.search_ram_limit(),
self.fetch_disk_build_param()?.index_build_ram_limit(),
self.configuration.index_write_parameter.num_threads
);
let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction);
// PQ memory consumption = PQ pivots + PQ compressed table
// PQ pivots: dim * num_centroids * sizeof::<T>()
// PQ compressed table: num_pts * num_pq_chunks * (dim / num_pq_chunks) * sizeof::<u8>()
// * Because num_centroids is 256, centroid id can be represented by u8
let num_points = self.configuration.max_points;
let dim = self.configuration.dim;
let p_val = MAX_PQ_TRAINING_SET_SIZE / (num_points as f64);
let mut num_pq_chunks = ((self.fetch_disk_build_param()?.search_ram_limit() / (num_points as f64)).floor()) as usize;
num_pq_chunks = if num_pq_chunks == 0 { 1 } else { num_pq_chunks };
num_pq_chunks = if num_pq_chunks > dim { dim } else { num_pq_chunks };
num_pq_chunks = if num_pq_chunks > MAX_PQ_CHUNKS { MAX_PQ_CHUNKS } else { num_pq_chunks };
println!("Compressing {}-dimensional data into {} bytes per vector.", dim, num_pq_chunks);
// TODO: Decouple PQ from file access
generate_quantized_data::<T>(
p_val,
num_pq_chunks,
codebook_prefix,
self.storage.get_pq_storage(),
)?;
logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild)?;
// TODO: Decouple index from file access
let inmem_index_path = self.storage.index_path_prefix().clone() + "_mem.index";
self.build_inmem_index(num_points, self.storage.dataset_file(), inmem_index_path.as_str())?;
logger.log_checkpoint(DiskIndexConstructionCheckpoint::DiskLayout)?;
self.storage.create_disk_layout()?;
logger.log_checkpoint(DiskIndexConstructionCheckpoint::None)?;
let ten_percent_points = ((num_points as f64) * 0.1_f64).ceil();
let num_sample_points = if ten_percent_points > (MAX_SAMPLE_POINTS_FOR_WARMUP as f64) { MAX_SAMPLE_POINTS_FOR_WARMUP as f64 } else { ten_percent_points };
let sample_sampling_rate = num_sample_points / (num_points as f64);
self.storage.gen_query_warmup_data(sample_sampling_rate)?;
self.storage.index_build_cleanup()?;
Ok(())
}
}
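
To make the two budget computations in this file concrete, here is the arithmetic with assumed example numbers. The true GRAPH_SLACK_FACTOR and MAX_PQ_CHUNKS values live in the model module, so the 1.3 slack and the clamp bound below are placeholders:

fn budget_examples() {
    // estimate_ram_usage: 1M points, N = 128, R = 64, f32 data (assumed).
    let (size, n, degree) = (1_000_000_usize, 128_usize, 64_usize);
    let dataset_size = (size * n * std::mem::size_of::<f32>()) as f64; // 512 MB
    let graph_size = (size * degree * std::mem::size_of::<u32>()) as f64 * 1.3; // slack assumed 1.3
    let estimated = 1.1 * (dataset_size + graph_size); // OVERHEAD_FACTOR = 1.1
    println!("estimated index RAM: {:.2} GiB", estimated / (1024.0 * 1024.0 * 1024.0));
    // num_pq_chunks: search RAM budget / point count = bytes per compressed
    // vector, clamped to [1, dim] (and to MAX_PQ_CHUNKS in the real code).
    let search_ram_limit = 32.0 * 1024.0 * 1024.0; // assumed 32 MB budget
    let num_pq_chunks = ((search_ram_limit / size as f64).floor() as usize).clamp(1, n);
    println!("PQ bytes per vector: {}", num_pq_chunks); // 33 with these numbers
}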


@@ -0,0 +1,9 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod disk_index;
pub use disk_index::DiskIndex;
pub mod ann_disk_index;


@@ -0,0 +1,97 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_docs)]
//! ANN in-memory index abstraction
use vector::FullPrecisionDistance;
use crate::model::{vertex::{DIM_128, DIM_256, DIM_104}, IndexConfiguration};
use crate::common::{ANNResult, ANNError};
use super::InmemIndex;
/// ANN inmem-index abstraction for custom <T, N>
pub trait ANNInmemIndex<T> : Sync + Send
where T : Default + Copy + Sync + Send + Into<f32>
{
/// Build index
fn build(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()>;
/// Save index
fn save(&mut self, filename: &str) -> ANNResult<()>;
/// Load index
fn load(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<()>;
/// insert index
fn insert(&mut self, filename: &str, num_points_to_insert: usize) -> ANNResult<()>;
/// Search the index for K nearest neighbors of query using given L value, for benchmarking purposes
fn search(&self, query : &[T], k_value : usize, l_value : u32, indices : &mut[u32]) -> ANNResult<u32>;
/// Soft deletes the nodes with the ids in the given array.
fn soft_delete(&mut self, vertex_ids_to_delete: Vec<u32>, num_points_to_delete: usize) -> ANNResult<()>;
}
/// Create Index<T, N> based on configuration
pub fn create_inmem_index<'a, T>(config: IndexConfiguration) -> ANNResult<Box<dyn ANNInmemIndex<T> + 'a>>
where
T: Default + Copy + Sync + Send + Into<f32> + 'a,
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
{
match config.aligned_dim {
DIM_104 => {
let index = Box::new(InmemIndex::<T, DIM_104>::new(config)?);
Ok(index as Box<dyn ANNInmemIndex<T>>)
},
DIM_128 => {
let index = Box::new(InmemIndex::<T, DIM_128>::new(config)?);
Ok(index as Box<dyn ANNInmemIndex<T>>)
},
DIM_256 => {
let index = Box::new(InmemIndex::<T, DIM_256>::new(config)?);
Ok(index as Box<dyn ANNInmemIndex<T>>)
},
_ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))),
}
}
#[cfg(test)]
mod dataset_test {
use vector::Metric;
use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
use super::*;
#[test]
#[should_panic(expected = "ERROR: Data file fake_file does not exist.")]
fn create_index_test() {
let index_write_parameters = IndexWriteParametersBuilder::new(50, 4)
.with_alpha(1.2)
.with_saturate_graph(false)
.with_num_threads(1)
.build();
let config = IndexConfiguration::new(
Metric::L2,
128,
256,
1_000_000,
false,
0,
false,
0,
1f32,
index_write_parameters,
);
let mut index = create_inmem_index::<f32>(config).unwrap();
index.build("fake_file", 100).unwrap();
}
}
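
A sketch of driving the trait end to end, with the config built as in the test above. The data path and point counts are placeholders, and only the trait methods declared above are used:

fn build_and_search(config: IndexConfiguration) -> ANNResult<()> {
    let mut index = create_inmem_index::<f32>(config)?;
    // "points.fbin" is a hypothetical data file holding at least 1_000 points.
    index.build("points.fbin", 1_000)?;
    let query = vec![0.0_f32; 128];
    let mut indices = vec![0_u32; 10];
    // k = 10 nearest neighbors, searched with list size L = 50.
    let cmps = index.search(&query, 10, 50, &mut indices)?;
    println!("ids: {:?}, distance comparisons: {}", indices, cmps);
    Ok(())
}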


File diff suppressed because it is too large.


@@ -0,0 +1,304 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::fs::File;
use std::io::{BufReader, BufWriter, Seek, SeekFrom, Write};
use std::path::Path;
use byteorder::{LittleEndian, ReadBytesExt};
use vector::FullPrecisionDistance;
use crate::common::{ANNError, ANNResult};
use crate::model::graph::AdjacencyList;
use crate::model::InMemoryGraph;
use crate::utils::{file_exists, save_data_in_base_dimensions};
use super::InmemIndex;
impl<T, const N: usize> InmemIndex<T, N>
where
T: Default + Copy + Sync + Send + Into<f32>,
[T; N]: FullPrecisionDistance<T, N>,
{
pub fn load_graph(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<usize> {
// let file_offset = 0; // will need this for single file format support
let mut in_file = BufReader::new(File::open(Path::new(filename))?);
// in_file.seek(SeekFrom::Start(file_offset as u64))?;
let expected_file_size: usize = in_file.read_u64::<LittleEndian>()? as usize;
self.max_observed_degree = in_file.read_u32::<LittleEndian>()?;
self.start = in_file.read_u32::<LittleEndian>()?;
let file_frozen_pts: usize = in_file.read_u64::<LittleEndian>()? as usize;
let vamana_metadata_size = 24;
println!("From graph header, expected_file_size: {}, max_observed_degree: {}, start: {}, file_frozen_pts: {}",
expected_file_size, self.max_observed_degree, self.start, file_frozen_pts);
if file_frozen_pts != self.configuration.num_frozen_pts {
if file_frozen_pts == 1 {
return Err(ANNError::log_index_config_error(
"num_frozen_pts".to_string(),
"ERROR: When loading index, detected dynamic index, but constructor asks for static index. Exitting.".to_string())
);
} else {
return Err(ANNError::log_index_config_error(
"num_frozen_pts".to_string(),
"ERROR: When loading index, detected static index, but constructor asks for dynamic index. Exitting.".to_string())
);
}
}
println!("Loading vamana graph {}...", filename);
let expected_max_points = expected_num_points - file_frozen_pts;
// If user provides more points than max_points
// resize the _final_graph to the larger size.
if self.configuration.max_points < expected_max_points {
println!("Number of points in data: {} is greater than max_points: {} Setting max points to: {}", expected_max_points, self.configuration.max_points, expected_max_points);
self.configuration.max_points = expected_max_points;
self.final_graph = InMemoryGraph::new(
self.configuration.max_points + self.configuration.num_frozen_pts,
self.configuration.index_write_parameter.max_degree,
);
}
let mut bytes_read = vamana_metadata_size;
let mut num_edges = 0;
let mut nodes_read = 0;
let mut max_observed_degree = 0;
while bytes_read != expected_file_size {
let num_nbrs = in_file.read_u32::<LittleEndian>()?;
max_observed_degree = if num_nbrs > max_observed_degree {
num_nbrs
} else {
max_observed_degree
};
if num_nbrs == 0 {
return Err(ANNError::log_index_error(format!(
"ERROR: Point found with no out-neighbors, point# {}",
nodes_read
)));
}
num_edges += num_nbrs;
nodes_read += 1;
let mut tmp: Vec<u32> = Vec::with_capacity(num_nbrs as usize);
for _ in 0..num_nbrs {
tmp.push(in_file.read_u32::<LittleEndian>()?);
}
self.final_graph
.write_vertex_and_neighbors(nodes_read - 1)?
.set_neighbors(AdjacencyList::from(tmp));
bytes_read += 4 * (num_nbrs as usize + 1);
}
println!(
"Done. Index has {} nodes and {} out-edges, _start is set to {}",
nodes_read, num_edges, self.start
);
self.max_observed_degree = max_observed_degree;
Ok(nodes_read as usize)
}
/// Save the graph index on a file as an adjacency list.
/// For each point, first store the number of neighbors,
/// and then the neighbor list (each as 4 byte u32)
pub fn save_graph(&mut self, graph_file: &str) -> ANNResult<u64> {
let file: File = File::create(graph_file)?;
let mut out = BufWriter::new(file);
let file_offset: u64 = 0;
out.seek(SeekFrom::Start(file_offset))?;
let mut index_size: u64 = 24;
let mut max_degree: u32 = 0;
out.write_all(&index_size.to_le_bytes())?;
out.write_all(&self.max_observed_degree.to_le_bytes())?;
out.write_all(&self.start.to_le_bytes())?;
out.write_all(&(self.configuration.num_frozen_pts as u64).to_le_bytes())?;
// At this point, either nd == max_points or any frozen points have
// been temporarily moved to nd, so nd + num_frozen_points is the valid
// location limit
for i in 0..self.num_active_pts + self.configuration.num_frozen_pts {
let idx = i as u32;
let gk: u32 = self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32;
out.write_all(&gk.to_le_bytes())?;
for neighbor in self
.final_graph
.read_vertex_and_neighbors(idx)?
.get_neighbors()
.iter()
{
out.write_all(&neighbor.to_le_bytes())?;
}
max_degree =
if self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32 > max_degree {
self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32
} else {
max_degree
};
index_size += (std::mem::size_of::<u32>() * (gk as usize + 1)) as u64;
}
out.seek(SeekFrom::Start(file_offset))?;
out.write_all(&index_size.to_le_bytes())?;
out.write_all(&max_degree.to_le_bytes())?;
out.flush()?;
Ok(index_size)
}
/// Save the data on a file.
pub fn save_data(&mut self, data_file: &str) -> ANNResult<usize> {
// Note: at this point, either _nd == _max_points or any frozen points have
// been temporarily moved to _nd, so _nd + _num_frozen_points is the valid
// location limit.
Ok(save_data_in_base_dimensions(
data_file,
&mut self.dataset.data,
self.num_active_pts + self.configuration.num_frozen_pts,
self.configuration.dim,
self.configuration.aligned_dim,
0,
)?)
}
/// Save the delete list to a file only if the delete list length is not zero.
pub fn save_delete_list(&mut self, delete_list_file: &str) -> ANNResult<usize> {
let mut delete_file_size = 0;
if let Ok(delete_set) = self.delete_set.read() {
let delete_set_len = delete_set.len() as u32;
if delete_set_len != 0 {
let file: File = File::create(delete_list_file)?;
let mut writer = BufWriter::new(file);
// Write the length of the set.
writer.write_all(&delete_set_len.to_le_bytes())?;
delete_file_size += std::mem::size_of::<u32>();
// Write the elements of the set.
for &item in delete_set.iter() {
// Write little-endian to match load_delete_list's read_u32::<LittleEndian>.
writer.write_all(&item.to_le_bytes())?;
delete_file_size += std::mem::size_of::<u32>();
}
writer.flush()?;
}
} else {
return Err(ANNError::log_lock_poison_error(
"Poisoned lock on delete set. Can't save deleted list.".to_string(),
));
}
Ok(delete_file_size)
}
// load the deleted list from the delete file if it exists.
pub fn load_delete_list(&mut self, delete_list_file: &str) -> ANNResult<usize> {
let mut len = 0;
if file_exists(delete_list_file) {
let file = File::open(delete_list_file)?;
let mut reader = BufReader::new(file);
len = reader.read_u32::<LittleEndian>()? as usize;
if let Ok(mut delete_set) = self.delete_set.write() {
for _ in 0..len {
let item = reader.read_u32::<LittleEndian>()?;
delete_set.insert(item);
}
} else {
return Err(ANNError::log_lock_poison_error(
"Poisoned lock on delete set. Can't load deleted list.".to_string(),
));
}
}
Ok(len)
}
}
#[cfg(test)]
mod index_test {
use std::fs;
use vector::Metric;
use super::*;
use crate::{
index::ANNInmemIndex,
model::{
configuration::index_write_parameters::IndexWriteParametersBuilder, vertex::DIM_128,
IndexConfiguration,
},
utils::{load_metadata_from_file, round_up},
};
const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
const R: u32 = 4;
const L: u32 = 50;
const ALPHA: f32 = 1.2;
#[cfg_attr(not(coverage), test)]
fn save_graph_test() {
let parameters = IndexWriteParametersBuilder::new(50, 4)
.with_alpha(1.2)
.build();
let config =
IndexConfiguration::new(Metric::L2, 10, 16, 16, false, 0, false, 8, 1f32, parameters);
let mut index = InmemIndex::<f32, 3>::new(config).unwrap();
let final_graph = InMemoryGraph::new(10, 3);
let num_active_pts = 2_usize;
index.final_graph = final_graph;
index.num_active_pts = num_active_pts;
let graph_file = "test_save_graph_data.bin";
let result = index.save_graph(graph_file);
assert!(result.is_ok());
fs::remove_file(graph_file).expect("Failed to delete file");
}
#[test]
fn save_data_test() {
let (data_num, dim) = load_metadata_from_file(TEST_DATA_FILE).unwrap();
let index_write_parameters = IndexWriteParametersBuilder::new(L, R)
.with_alpha(ALPHA)
.build();
let config = IndexConfiguration::new(
Metric::L2,
dim,
round_up(dim as u64, 16_u64) as usize,
data_num,
false,
0,
false,
0,
1f32,
index_write_parameters,
);
let mut index: InmemIndex<f32, DIM_128> = InmemIndex::new(config).unwrap();
index.build(TEST_DATA_FILE, data_num).unwrap();
let data_file = "test.data";
let result = index.save_data(data_file);
assert_eq!(
result.unwrap(),
2 * std::mem::size_of::<u32>()
+ (index.num_active_pts + index.configuration.num_frozen_pts)
* index.configuration.dim
* (std::mem::size_of::<f32>())
);
fs::remove_file(data_file).expect("Failed to delete file");
}
}
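
The on-disk format implied by save_graph/load_graph is a 24-byte little-endian header followed by one length-prefixed adjacency list per node. A minimal header reader, mirroring the first four reads in load_graph:

// Layout: [u64 index_size][u32 max_observed_degree][u32 start][u64 num_frozen_pts],
// then per node: [u32 num_nbrs][u32 neighbor_id; num_nbrs].
use byteorder::{LittleEndian, ReadBytesExt};
use std::{fs::File, io::BufReader};

fn read_graph_header(path: &str) -> std::io::Result<(u64, u32, u32, u64)> {
    let mut f = BufReader::new(File::open(path)?);
    Ok((
        f.read_u64::<LittleEndian>()?, // total index size in bytes
        f.read_u32::<LittleEndian>()?, // max observed degree
        f.read_u32::<LittleEndian>()?, // start vertex id
        f.read_u64::<LittleEndian>()?, // number of frozen points
    ))
}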


@@ -0,0 +1,12 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod inmem_index;
pub use inmem_index::InmemIndex;
mod inmem_index_storage;
pub mod ann_inmem_index;


@@ -0,0 +1,11 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
mod inmem_index;
pub use inmem_index::ann_inmem_index::*;
pub use inmem_index::InmemIndex;
mod disk_index;
pub use disk_index::*;

View File

@@ -0,0 +1,57 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use logger::logger::indexlog::DiskIndexConstructionCheckpoint;
use logger::logger::indexlog::DiskIndexConstructionLog;
use logger::logger::indexlog::Log;
use logger::logger::indexlog::LogLevel;
use logger::message_handler::send_log;
use crate::{utils::Timer, common::ANNResult};
pub struct DiskIndexBuildLogger {
timer: Timer,
checkpoint: DiskIndexConstructionCheckpoint,
}
impl DiskIndexBuildLogger {
pub fn new(checkpoint: DiskIndexConstructionCheckpoint) -> Self {
Self {
timer: Timer::new(),
checkpoint,
}
}
pub fn log_checkpoint(&mut self, next_checkpoint: DiskIndexConstructionCheckpoint) -> ANNResult<()> {
if self.checkpoint == DiskIndexConstructionCheckpoint::None {
return Ok(());
}
let mut log = Log::default();
let disk_index_construction_log = DiskIndexConstructionLog {
checkpoint: self.checkpoint as i32,
time_spent_in_seconds: self.timer.elapsed().as_secs_f32(),
g_cycles_spent: self.timer.elapsed_gcycles(),
log_level: LogLevel::Info as i32,
};
log.disk_index_construction_log = Some(disk_index_construction_log);
send_log(log)?;
self.checkpoint = next_checkpoint;
self.timer.reset();
Ok(())
}
}
#[cfg(test)]
mod disk_index_build_logger_test {
use super::*;
#[test]
fn test_log() {
let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction);
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild).unwrap();
        logger.log_checkpoint(DiskIndexConstructionCheckpoint::DiskLayout).unwrap();
}
}

View File

@@ -0,0 +1,47 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::sync::atomic::{AtomicUsize, Ordering};
use logger::logger::indexlog::IndexConstructionLog;
use logger::logger::indexlog::Log;
use logger::logger::indexlog::LogLevel;
use logger::message_handler::send_log;
use crate::common::ANNResult;
use crate::utils::Timer;
pub struct IndexLogger {
items_processed: AtomicUsize,
timer: Timer,
range: usize,
}
impl IndexLogger {
pub fn new(range: usize) -> Self {
Self {
items_processed: AtomicUsize::new(0),
timer: Timer::new(),
range,
}
}
pub fn vertex_processed(&self) -> ANNResult<()> {
let count = self.items_processed.fetch_add(1, Ordering::Relaxed);
if count % 100_000 == 0 {
let mut log = Log::default();
let index_construction_log = IndexConstructionLog {
percentage_complete: (100_f32 * count as f32) / (self.range as f32),
time_spent_in_seconds: self.timer.elapsed().as_secs_f32(),
g_cycles_spent: self.timer.elapsed_gcycles(),
log_level: LogLevel::Info as i32,
};
log.index_construction_log = Some(index_construction_log);
send_log(log)?;
}
Ok(())
}
}

View File

@@ -0,0 +1,9 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
mod index_logger;
pub use index_logger::IndexLogger;
mod disk_index_build_logger;
pub use disk_index_build_logger::DiskIndexBuildLogger;

View File

@@ -0,0 +1,26 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![cfg_attr(
not(test),
warn(clippy::panic, clippy::unwrap_used, clippy::expect_used)
)]
#![cfg_attr(test, allow(clippy::unused_io_amount))]
pub mod utils;
pub mod algorithm;
pub mod model;
pub mod common;
pub mod index;
pub mod storage;
pub mod instrumentation;
#[cfg(test)]
pub mod test_utils;

View File

@@ -0,0 +1,85 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Parameters for disk index construction.
use crate::common::{ANNResult, ANNError};
/// Cached nodes size in GB
const SPACE_FOR_CACHED_NODES_IN_GB: f64 = 0.25;
/// Threshold for caching in GB
const THRESHOLD_FOR_CACHING_IN_GB: f64 = 1.0;
/// Parameters specific for disk index construction.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct DiskIndexBuildParameters {
/// Bound on the memory footprint of the index at search time in bytes.
/// Once built, the index will use up only the specified RAM limit, the rest will reside on disk.
/// This will dictate how aggressively we compress the data vectors to store in memory.
/// Larger will yield better performance at search time.
search_ram_limit: f64,
/// Limit on the memory allowed for building the index in bytes.
index_build_ram_limit: f64,
}
impl DiskIndexBuildParameters {
/// Create DiskIndexBuildParameters instance
pub fn new(search_ram_limit_gb: f64, index_build_ram_limit_gb: f64) -> ANNResult<Self> {
let param = Self {
search_ram_limit: Self::get_memory_budget(search_ram_limit_gb),
index_build_ram_limit: index_build_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64,
};
if param.search_ram_limit <= 0f64 {
return Err(ANNError::log_index_config_error("search_ram_limit".to_string(), "RAM budget should be > 0".to_string()))
}
if param.index_build_ram_limit <= 0f64 {
return Err(ANNError::log_index_config_error("index_build_ram_limit".to_string(), "RAM budget should be > 0".to_string()))
}
Ok(param)
}
/// Get search_ram_limit
pub fn search_ram_limit(&self) -> f64 {
self.search_ram_limit
}
/// Get index_build_ram_limit
pub fn index_build_ram_limit(&self) -> f64 {
self.index_build_ram_limit
}
fn get_memory_budget(mut index_ram_limit_gb: f64) -> f64 {
if index_ram_limit_gb - SPACE_FOR_CACHED_NODES_IN_GB > THRESHOLD_FOR_CACHING_IN_GB {
// slack for space used by cached nodes
index_ram_limit_gb -= SPACE_FOR_CACHED_NODES_IN_GB;
}
index_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64
}
}
#[cfg(test)]
mod disk_index_build_parameters_test {
use super::*;
#[test]
fn sufficient_ram_for_caching() {
let param = DiskIndexBuildParameters::new(1.26_f64, 1.0_f64).unwrap();
assert_eq!(param.search_ram_limit, 1.01_f64 * 1024_f64 * 1024_f64 * 1024_f64);
}
#[test]
fn insufficient_ram_for_caching() {
let param = DiskIndexBuildParameters::new(0.03_f64, 1.0_f64).unwrap();
assert_eq!(param.search_ram_limit, 0.03_f64 * 1024_f64 * 1024_f64 * 1024_f64);
}
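    // An added sanity check (values assumed): a zero search RAM budget must be
    // rejected by DiskIndexBuildParameters::new.
    #[test]
    fn zero_search_ram_budget_is_rejected() {
        assert!(DiskIndexBuildParameters::new(0.0_f64, 1.0_f64).is_err());
    }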
}

View File

@@ -0,0 +1,92 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Index configuration.
use vector::Metric;
use super::index_write_parameters::IndexWriteParameters;
/// The index configuration
#[derive(Debug, Clone)]
pub struct IndexConfiguration {
/// Index write parameter
pub index_write_parameter: IndexWriteParameters,
/// Distance metric
pub dist_metric: Metric,
/// Dimension of the raw data
pub dim: usize,
/// Aligned dimension - round up dim to the nearest multiple of 8
pub aligned_dim: usize,
/// Total number of points in given data set
pub max_points: usize,
/// Number of points which are used as initial candidates when iterating to
/// closest point(s). These are not visible externally and won't be returned
/// by search. DiskANN forces at least 1 frozen point for dynamic index.
/// The frozen points have consecutive locations.
pub num_frozen_pts: usize,
/// Calculate distance by PQ or not
pub use_pq_dist: bool,
/// Number of PQ chunks
pub num_pq_chunks: usize,
/// Use optimized product quantization
/// Currently not supported
pub use_opq: bool,
    /// Potential for growth. 1.2 means the index can grow by up to 20%.
pub growth_potential: f32,
// TODO: below settings are not supported in current iteration
// pub concurrent_consolidate: bool,
// pub has_built: bool,
// pub save_as_one_file: bool,
// pub dynamic_index: bool,
// pub enable_tags: bool,
// pub normalize_vecs: bool,
}
impl IndexConfiguration {
/// Create IndexConfiguration instance
#[allow(clippy::too_many_arguments)]
pub fn new(
dist_metric: Metric,
dim: usize,
aligned_dim: usize,
max_points: usize,
use_pq_dist: bool,
num_pq_chunks: usize,
use_opq: bool,
num_frozen_pts: usize,
growth_potential: f32,
index_write_parameter: IndexWriteParameters
) -> Self {
Self {
index_write_parameter,
dist_metric,
dim,
aligned_dim,
max_points,
num_frozen_pts,
use_pq_dist,
num_pq_chunks,
use_opq,
growth_potential,
}
}
/// Get the size of adjacency list that we build out.
pub fn write_range(&self) -> usize {
self.index_write_parameter.max_degree as usize
}
}

View File

@@ -0,0 +1,245 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Index write parameters.
/// Default parameter values.
pub mod default_param_vals {
/// Default value of alpha.
pub const ALPHA: f32 = 1.2;
/// Default value of number of threads.
pub const NUM_THREADS: u32 = 0;
/// Default value of number of rounds.
pub const NUM_ROUNDS: u32 = 2;
/// Default value of max occlusion size.
pub const MAX_OCCLUSION_SIZE: u32 = 750;
/// Default value of filter list size.
pub const FILTER_LIST_SIZE: u32 = 0;
/// Default value of number of frozen points.
pub const NUM_FROZEN_POINTS: u32 = 0;
/// Default value of max degree.
pub const MAX_DEGREE: u32 = 64;
/// Default value of build list size.
pub const BUILD_LIST_SIZE: u32 = 100;
/// Default value of saturate graph.
pub const SATURATE_GRAPH: bool = false;
/// Default value of search list size.
pub const SEARCH_LIST_SIZE: u32 = 100;
}
/// Index write parameters.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct IndexWriteParameters {
/// Search list size - L.
pub search_list_size: u32,
/// Max degree - R.
pub max_degree: u32,
/// Saturate graph.
pub saturate_graph: bool,
/// Max occlusion size - C.
pub max_occlusion_size: u32,
/// Alpha.
pub alpha: f32,
/// Number of rounds.
pub num_rounds: u32,
/// Number of threads.
pub num_threads: u32,
/// Number of frozen points.
pub num_frozen_points: u32,
}
impl Default for IndexWriteParameters {
/// Create IndexWriteParameters with default values
fn default() -> Self {
Self {
search_list_size: default_param_vals::SEARCH_LIST_SIZE,
max_degree: default_param_vals::MAX_DEGREE,
saturate_graph: default_param_vals::SATURATE_GRAPH,
max_occlusion_size: default_param_vals::MAX_OCCLUSION_SIZE,
alpha: default_param_vals::ALPHA,
num_rounds: default_param_vals::NUM_ROUNDS,
num_threads: default_param_vals::NUM_THREADS,
num_frozen_points: default_param_vals::NUM_FROZEN_POINTS
}
}
}
/// The builder for IndexWriteParameters.
#[derive(Debug)]
pub struct IndexWriteParametersBuilder {
search_list_size: u32,
max_degree: u32,
max_occlusion_size: Option<u32>,
saturate_graph: Option<bool>,
alpha: Option<f32>,
num_rounds: Option<u32>,
num_threads: Option<u32>,
// filter_list_size: Option<u32>,
num_frozen_points: Option<u32>,
}
impl IndexWriteParametersBuilder {
/// Initialize IndexWriteParametersBuilder
pub fn new(search_list_size: u32, max_degree: u32) -> Self {
Self {
search_list_size,
max_degree,
max_occlusion_size: None,
saturate_graph: None,
alpha: None,
num_rounds: None,
num_threads: None,
// filter_list_size: None,
num_frozen_points: None,
}
}
/// Set max occlusion size.
pub fn with_max_occlusion_size(mut self, max_occlusion_size: u32) -> Self {
self.max_occlusion_size = Some(max_occlusion_size);
self
}
/// Set saturate graph.
pub fn with_saturate_graph(mut self, saturate_graph: bool) -> Self {
self.saturate_graph = Some(saturate_graph);
self
}
/// Set alpha.
pub fn with_alpha(mut self, alpha: f32) -> Self {
self.alpha = Some(alpha);
self
}
/// Set number of rounds.
pub fn with_num_rounds(mut self, num_rounds: u32) -> Self {
self.num_rounds = Some(num_rounds);
self
}
/// Set number of threads.
pub fn with_num_threads(mut self, num_threads: u32) -> Self {
self.num_threads = Some(num_threads);
self
}
/*
pub fn with_filter_list_size(mut self, filter_list_size: u32) -> Self {
self.filter_list_size = Some(filter_list_size);
self
}
*/
/// Set number of frozen points.
pub fn with_num_frozen_points(mut self, num_frozen_points: u32) -> Self {
self.num_frozen_points = Some(num_frozen_points);
self
}
/// Build IndexWriteParameters from IndexWriteParametersBuilder.
pub fn build(self) -> IndexWriteParameters {
IndexWriteParameters {
search_list_size: self.search_list_size,
max_degree: self.max_degree,
saturate_graph: self.saturate_graph.unwrap_or(default_param_vals::SATURATE_GRAPH),
max_occlusion_size: self.max_occlusion_size.unwrap_or(default_param_vals::MAX_OCCLUSION_SIZE),
alpha: self.alpha.unwrap_or(default_param_vals::ALPHA),
num_rounds: self.num_rounds.unwrap_or(default_param_vals::NUM_ROUNDS),
num_threads: self.num_threads.unwrap_or(default_param_vals::NUM_THREADS),
// filter_list_size: self.filter_list_size.unwrap_or(default_param_vals::FILTER_LIST_SIZE),
num_frozen_points: self.num_frozen_points.unwrap_or(default_param_vals::NUM_FROZEN_POINTS),
}
}
}
/// Construct IndexWriteParametersBuilder from IndexWriteParameters.
impl From<IndexWriteParameters> for IndexWriteParametersBuilder {
fn from(param: IndexWriteParameters) -> Self {
Self {
search_list_size: param.search_list_size,
max_degree: param.max_degree,
max_occlusion_size: Some(param.max_occlusion_size),
saturate_graph: Some(param.saturate_graph),
alpha: Some(param.alpha),
num_rounds: Some(param.num_rounds),
num_threads: Some(param.num_threads),
// filter_list_size: Some(param.filter_list_size),
num_frozen_points: Some(param.num_frozen_points),
}
}
}
#[cfg(test)]
mod parameters_test {
use crate::model::configuration::index_write_parameters::*;
#[test]
fn test_default_index_params() {
let wp1 = IndexWriteParameters::default();
assert_eq!(wp1.search_list_size, default_param_vals::SEARCH_LIST_SIZE);
assert_eq!(wp1.max_degree, default_param_vals::MAX_DEGREE);
assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH);
assert_eq!(wp1.max_occlusion_size, default_param_vals::MAX_OCCLUSION_SIZE);
assert_eq!(wp1.alpha, default_param_vals::ALPHA);
assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS);
assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS);
assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS);
}
#[test]
fn test_index_write_parameters_builder() {
// default value
let wp1 = IndexWriteParametersBuilder::new(10, 20).build();
assert_eq!(wp1.search_list_size, 10);
assert_eq!(wp1.max_degree, 20);
assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH);
assert_eq!(wp1.max_occlusion_size, default_param_vals::MAX_OCCLUSION_SIZE);
assert_eq!(wp1.alpha, default_param_vals::ALPHA);
assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS);
assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS);
assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS);
// build with custom values
let wp2 = IndexWriteParametersBuilder::new(10, 20)
.with_max_occlusion_size(30)
.with_saturate_graph(true)
.with_alpha(0.5)
.with_num_rounds(40)
.with_num_threads(50)
.with_num_frozen_points(60)
.build();
assert_eq!(wp2.search_list_size, 10);
assert_eq!(wp2.max_degree, 20);
assert!(wp2.saturate_graph);
assert_eq!(wp2.max_occlusion_size, 30);
assert_eq!(wp2.alpha, 0.5);
assert_eq!(wp2.num_rounds, 40);
assert_eq!(wp2.num_threads, 50);
assert_eq!(wp2.num_frozen_points, 60);
// test from
let wp3 = IndexWriteParametersBuilder::from(wp2).build();
assert_eq!(wp3, wp2);
}
}

View File

@@ -0,0 +1,12 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub mod index_configuration;
pub use index_configuration::IndexConfiguration;
pub mod index_write_parameters;
pub use index_write_parameters::*;
pub mod disk_index_build_parameter;
pub use disk_index_build_parameter::DiskIndexBuildParameters;

View File

@@ -0,0 +1,76 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Disk scratch dataset
use std::mem::{size_of, size_of_val};
use std::ptr;
use crate::common::{AlignedBoxWithSlice, ANNResult};
use crate::model::MAX_N_CMPS;
use crate::utils::round_up;
/// DiskScratchDataset alignment
pub const DISK_SCRATCH_DATASET_ALIGN: usize = 256;
/// Disk scratch dataset storing fp vectors with aligned dim
#[derive(Debug)]
pub struct DiskScratchDataset<T, const N: usize>
{
/// fp vectors with aligned dim
pub data: AlignedBoxWithSlice<T>,
/// current index to store the next fp vector
pub cur_index: usize,
}
impl<T, const N: usize> DiskScratchDataset<T, N>
{
/// Create DiskScratchDataset instance
pub fn new() -> ANNResult<Self> {
Ok(Self {
// C++ code allocates round_up(MAX_N_CMPS * N, 256) bytes, shouldn't it be round_up(MAX_N_CMPS * N, 256) * size_of::<T> bytes?
data: AlignedBoxWithSlice::new(
round_up(MAX_N_CMPS * N, DISK_SCRATCH_DATASET_ALIGN),
DISK_SCRATCH_DATASET_ALIGN)?,
cur_index: 0,
})
}
/// memcpy from fp vector bytes (its len should be `dim * size_of::<T>()`) to self.data
/// The dest slice is a fp vector with aligned dim
/// * fp_vector_buf's dim might not be aligned dim (N)
/// # Safety
/// Behavior is undefined if any of the following conditions are violated:
///
/// * `fp_vector_buf`'s len must be `dim * size_of::<T>()` bytes
///
/// * `fp_vector_buf` must be smaller than or equal to `N * size_of::<T>()` bytes.
///
/// * `fp_vector_buf` and `self.data` must be nonoverlapping.
pub unsafe fn memcpy_from_fp_vector_buf(&mut self, fp_vector_buf: &[u8]) -> &[T] {
if self.cur_index == MAX_N_CMPS {
self.cur_index = 0;
}
let aligned_dim_vector = &mut self.data[self.cur_index * N..(self.cur_index + 1) * N];
assert!(fp_vector_buf.len() % size_of::<T>() == 0);
assert!(fp_vector_buf.len() <= size_of_val(aligned_dim_vector));
// memcpy from fp_vector_buf to aligned_dim_vector
unsafe {
ptr::copy_nonoverlapping(
fp_vector_buf.as_ptr(),
aligned_dim_vector.as_mut_ptr() as *mut u8,
fp_vector_buf.len(),
);
}
self.cur_index += 1;
aligned_dim_vector
}
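#[cfg(test)]
mod disk_scratch_dataset_test {
    use super::*;
    // A minimal sketch (dim and values assumed, not from the original suite):
    // copying a true-dim-2 f32 vector into an aligned dim of N = 4 places the
    // bytes at the front of the returned aligned slice.
    #[test]
    fn memcpy_lands_at_front_of_aligned_slice() {
        let mut scratch = DiskScratchDataset::<f32, 4>::new().unwrap();
        let fp: [f32; 2] = [1.0, 2.0];
        let buf = unsafe {
            std::slice::from_raw_parts(fp.as_ptr() as *const u8, std::mem::size_of_val(&fp))
        };
        // Safety: buf is 8 bytes (a multiple of size_of::<f32>()), no larger
        // than the 16-byte destination, and does not overlap scratch.data.
        let aligned = unsafe { scratch.memcpy_from_fp_vector_buf(buf) };
        assert_eq!(aligned.len(), 4);
        assert_eq!(&aligned[..2], &[1.0, 2.0]);
    }
}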
}

View File

@@ -0,0 +1,285 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! In-memory Dataset
use rayon::prelude::*;
use std::mem;
use vector::{FullPrecisionDistance, Metric};
use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice};
use crate::model::Vertex;
use crate::utils::copy_aligned_data_from_file;
/// Dataset of all in-memory FP points
#[derive(Debug)]
pub struct InmemDataset<T, const N: usize>
where
[T; N]: FullPrecisionDistance<T, N>,
{
/// All in-memory points
pub data: AlignedBoxWithSlice<T>,
/// Number of points we anticipate to have
pub num_points: usize,
/// Number of active points i.e. existing in the graph
pub num_active_pts: usize,
/// Capacity of the dataset
pub capacity: usize,
}
impl<'a, T, const N: usize> InmemDataset<T, N>
where
T: Default + Copy + Sync + Send + Into<f32>,
[T; N]: FullPrecisionDistance<T, N>,
{
/// Create the dataset with size num_points and growth factor.
/// growth factor=1 means no growth (provision 100% space of num_points)
/// growth factor=1.2 means provision 120% space of num_points (20% extra space)
pub fn new(num_points: usize, index_growth_factor: f32) -> ANNResult<Self> {
let capacity = (((num_points * N) as f32) * index_growth_factor) as usize;
Ok(Self {
data: AlignedBoxWithSlice::new(capacity, mem::size_of::<T>() * 16)?,
num_points,
num_active_pts: num_points,
capacity,
})
}
/// get immutable data slice
pub fn get_data(&self) -> &[T] {
&self.data
}
/// Build the dataset from file
pub fn build_from_file(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()> {
println!(
"Loading {} vectors from file {} into dataset...",
num_points_to_load, filename
);
self.num_active_pts = num_points_to_load;
copy_aligned_data_from_file(filename, self.into_dto(), 0)?;
println!("Dataset loaded.");
Ok(())
}
/// Append the dataset from file
pub fn append_from_file(
&mut self,
filename: &str,
num_points_to_append: usize,
) -> ANNResult<()> {
println!(
"Appending {} vectors from file {} into dataset...",
num_points_to_append, filename
);
if self.num_points + num_points_to_append > self.capacity {
return Err(ANNError::log_index_error(format!(
"Cannot append {} points to dataset of capacity {}",
num_points_to_append, self.capacity
)));
}
let pts_offset = self.num_active_pts;
copy_aligned_data_from_file(filename, self.into_dto(), pts_offset)?;
self.num_active_pts += num_points_to_append;
self.num_points += num_points_to_append;
println!("Dataset appended.");
Ok(())
}
/// Get vertex by id
pub fn get_vertex(&'a self, id: u32) -> ANNResult<Vertex<'a, T, N>> {
let start = id as usize * N;
let end = start + N;
if end <= self.data.len() {
let val = <&[T; N]>::try_from(&self.data[start..end]).map_err(|err| {
ANNError::log_index_error(format!("Failed to get vertex {}, err={}", id, err))
})?;
Ok(Vertex::new(val, id))
} else {
Err(ANNError::log_index_error(format!(
"Invalid vertex id {}.",
id
)))
}
}
/// Get full precision distance between two nodes
pub fn get_distance(&self, id1: u32, id2: u32, metric: Metric) -> ANNResult<f32> {
let vertex1 = self.get_vertex(id1)?;
let vertex2 = self.get_vertex(id2)?;
Ok(vertex1.compare(&vertex2, metric))
}
/// find out the medoid, the vertex in the dataset that is closest to the centroid
pub fn calculate_medoid_point_id(&self) -> ANNResult<u32> {
Ok(self.find_nearest_point_id(self.calculate_centroid_point()?))
}
/// calculate centroid, average of all vertices in the dataset
fn calculate_centroid_point(&self) -> ANNResult<[f32; N]> {
// Allocate and initialize the centroid vector
let mut center: [f32; N] = [0.0; N];
// Sum the data points' components
for i in 0..self.num_active_pts {
let vertex = self.get_vertex(i as u32)?;
let vertex_slice = vertex.vector();
for j in 0..N {
center[j] += vertex_slice[j].into();
}
}
// Divide by the number of points to calculate the centroid
let capacity = self.num_active_pts as f32;
for item in center.iter_mut().take(N) {
*item /= capacity;
}
Ok(center)
}
/// find out the vertex closest to the given point
fn find_nearest_point_id(&self, point: [f32; N]) -> u32 {
// compute all to one distance
let mut distances = vec![0f32; self.num_active_pts];
let slice = &self.data[..];
distances.par_iter_mut().enumerate().for_each(|(i, dist)| {
let start = i * N;
for j in 0..N {
let diff: f32 = (point.as_slice()[j] - slice[start + j].into())
* (point.as_slice()[j] - slice[start + j].into());
*dist += diff;
}
});
let mut min_idx = 0;
let mut min_dist = f32::MAX;
for (i, distance) in distances.iter().enumerate().take(self.num_active_pts) {
if *distance < min_dist {
min_idx = i;
min_dist = *distance;
}
}
min_idx as u32
}
/// Prefetch vertex data in the memory hierarchy
/// NOTE: good efficiency when total_vec_size is integral multiple of 64
#[inline]
pub fn prefetch_vector(&self, id: u32) {
let start = id as usize * N;
let end = start + N;
if end <= self.data.len() {
let vec = &self.data[start..end];
vector::prefetch_vector(vec);
}
}
/// Convert into dto object
pub fn into_dto(&mut self) -> DatasetDto<T> {
DatasetDto {
data: &mut self.data,
rounded_dim: N,
}
}
}
/// Dataset dto used for other layer, such as storage
/// N is the aligned dimension
#[derive(Debug)]
pub struct DatasetDto<'a, T> {
/// data slice borrow from dataset
pub data: &'a mut [T],
/// rounded dimension
pub rounded_dim: usize,
}
#[cfg(test)]
mod dataset_test {
use std::fs;
use super::*;
use crate::model::vertex::DIM_128;
#[test]
fn get_vertex_within_range() {
let num_points = 1_000_000;
let id = 999_999;
let dataset = InmemDataset::<f32, DIM_128>::new(num_points, 1f32).unwrap();
let vertex = dataset.get_vertex(999_999).unwrap();
assert_eq!(vertex.vertex_id(), id);
assert_eq!(vertex.vector().len(), DIM_128);
assert_eq!(vertex.vector().as_ptr(), unsafe {
dataset.data.as_ptr().add((id as usize) * DIM_128)
});
}
#[test]
fn get_vertex_out_of_range() {
let num_points = 1_000_000;
let invalid_id = 1_000_000;
let dataset = InmemDataset::<f32, DIM_128>::new(num_points, 1f32).unwrap();
if dataset.get_vertex(invalid_id).is_ok() {
panic!("id ({}) should be out of range", invalid_id)
};
}
#[test]
fn load_data_test() {
let file_name = "dataset_test_load_data_test.bin";
//npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8]
let data: [u8; 72] = [
2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
];
std::fs::write(file_name, data).expect("Failed to write sample file");
let mut dataset = InmemDataset::<f32, 8>::new(2, 1f32).unwrap();
match copy_aligned_data_from_file(
file_name,
dataset.into_dto(),
0,
) {
Ok((npts, dim)) => {
fs::remove_file(file_name).expect("Failed to delete file");
assert!(npts == 2);
assert!(dim == 8);
assert!(dataset.data.len() == 16);
let first_vertex = dataset.get_vertex(0).unwrap();
let second_vertex = dataset.get_vertex(1).unwrap();
assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);
}
Err(e) => {
fs::remove_file(file_name).expect("Failed to delete file");
panic!("{}", e)
}
}
}
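    // A hedged sketch (sizes assumed) of the growth-factor provisioning
    // described on InmemDataset::new: 4 points of aligned dim 8 with growth
    // factor 1.5 provision 4 * 8 * 1.5 = 48 elements of backing storage.
    #[test]
    fn growth_factor_provisions_extra_capacity() {
        let dataset = InmemDataset::<f32, 8>::new(4, 1.5f32).unwrap();
        assert_eq!(dataset.capacity, 48);
        assert_eq!(dataset.data.len(), 48);
    }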
}

View File

@@ -0,0 +1,11 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod inmem_dataset;
pub use inmem_dataset::InmemDataset;
pub use inmem_dataset::DatasetDto;
mod disk_scratch_dataset;
pub use disk_scratch_dataset::*;

View File

@@ -0,0 +1,64 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Adjacency List
use std::ops::{Deref, DerefMut};
#[derive(Debug, Eq, PartialEq)]
/// Represents the out neighbors of a vertex
pub struct AdjacencyList {
edges: Vec<u32>,
}
/// In-mem index related limits
const GRAPH_SLACK_FACTOR: f32 = 1.3_f32;
impl AdjacencyList {
/// Create AdjacencyList with capacity slack for a range.
pub fn for_range(range: usize) -> Self {
let capacity = (range as f32 * GRAPH_SLACK_FACTOR).ceil() as usize;
Self {
edges: Vec::with_capacity(capacity),
}
}
/// Push a node to the list of neighbors for the given node.
pub fn push(&mut self, node_id: u32) {
debug_assert!(self.edges.len() < self.edges.capacity());
self.edges.push(node_id);
}
}
impl From<Vec<u32>> for AdjacencyList {
fn from(edges: Vec<u32>) -> Self {
Self { edges }
}
}
impl Deref for AdjacencyList {
type Target = Vec<u32>;
fn deref(&self) -> &Self::Target {
&self.edges
}
}
impl DerefMut for AdjacencyList {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.edges
}
}
impl<'a> IntoIterator for &'a AdjacencyList {
type Item = &'a u32;
type IntoIter = std::slice::Iter<'a, u32>;
fn into_iter(self) -> Self::IntoIter {
self.edges.iter()
}
}
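#[cfg(test)]
mod adjacency_list_test {
    use super::*;
    // A small sketch of the slack provisioning above: for_range(10) reserves
    // ceil(10 * GRAPH_SLACK_FACTOR) = 13 edge slots up front, with no edges yet.
    #[test]
    fn for_range_applies_slack_factor() {
        let list = AdjacencyList::for_range(10);
        assert_eq!(list.len(), 0);
        assert_eq!(list.capacity(), 13);
    }
}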

View File

@@ -0,0 +1,179 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_docs)]
//! Disk graph
use byteorder::{LittleEndian, ByteOrder};
use vector::FullPrecisionDistance;
use crate::common::{ANNResult, ANNError};
use crate::model::data_store::DiskScratchDataset;
use crate::model::Vertex;
use crate::storage::DiskGraphStorage;
use super::{VertexAndNeighbors, SectorGraph, AdjacencyList};
/// Disk graph
pub struct DiskGraph {
/// dim of fp vector in disk sector
dim: usize,
/// number of nodes per sector
num_nodes_per_sector: u64,
/// max node length in bytes
max_node_len: u64,
/// the len of fp vector
fp_vector_len: u64,
/// list of nodes (vertex_id) to fetch from disk
nodes_to_fetch: Vec<u32>,
/// Sector graph
sector_graph: SectorGraph,
}
impl<'a> DiskGraph {
/// Create DiskGraph instance
pub fn new(
dim: usize,
num_nodes_per_sector: u64,
max_node_len: u64,
fp_vector_len: u64,
beam_width: usize,
graph_storage: DiskGraphStorage,
) -> ANNResult<Self> {
let graph = Self {
dim,
num_nodes_per_sector,
max_node_len,
fp_vector_len,
nodes_to_fetch: Vec::with_capacity(2 * beam_width),
sector_graph: SectorGraph::new(graph_storage)?,
};
Ok(graph)
}
/// Add vertex_id into the list to fetch from disk
pub fn add_vertex(&mut self, id: u32) {
self.nodes_to_fetch.push(id);
}
/// Fetch nodes from disk index
pub fn fetch_nodes(&mut self) -> ANNResult<()> {
let sectors_to_fetch: Vec<u64> = self.nodes_to_fetch.iter().map(|&id| self.node_sector_index(id)).collect();
self.sector_graph.read_graph(&sectors_to_fetch)?;
Ok(())
}
/// Copy disk fp vector to DiskScratchDataset
/// Return the fp vector with aligned dim from DiskScratchDataset
pub fn copy_fp_vector_to_disk_scratch_dataset<T, const N: usize>(
&self,
node_index: usize,
disk_scratch_dataset: &'a mut DiskScratchDataset<T, N>
) -> ANNResult<Vertex<'a, T, N>>
where
[T; N]: FullPrecisionDistance<T, N>,
{
if self.dim > N {
return Err(ANNError::log_index_error(format!(
"copy_sector_fp_to_aligned_dataset: dim {} is greater than aligned dim {}",
self.dim, N)));
}
let fp_vector_buf = self.node_fp_vector_buf(node_index);
// Safety condition is met here
let aligned_dim_vector = unsafe { disk_scratch_dataset.memcpy_from_fp_vector_buf(fp_vector_buf) };
Vertex::<'a, T, N>::try_from((aligned_dim_vector, self.nodes_to_fetch[node_index]))
.map_err(|err| ANNError::log_index_error(format!("TryFromSliceError: failed to get Vertex for disk index node, err={}", err)))
}
/// Reset graph
pub fn reset(&mut self) {
self.nodes_to_fetch.clear();
self.sector_graph.reset();
}
fn get_vertex_and_neighbors(&self, node_index: usize) -> VertexAndNeighbors {
let node_disk_buf = self.node_disk_buf(node_index);
let buf = &node_disk_buf[self.fp_vector_len as usize..];
let num_neighbors = LittleEndian::read_u32(&buf[0..4]) as usize;
let neighbors_buf = &buf[4..4 + num_neighbors * 4];
let mut adjacency_list = AdjacencyList::for_range(num_neighbors);
for chunk in neighbors_buf.chunks(4) {
let neighbor_id = LittleEndian::read_u32(chunk);
adjacency_list.push(neighbor_id);
}
VertexAndNeighbors::new(self.nodes_to_fetch[node_index], adjacency_list)
}
#[inline]
fn node_sector_index(&self, vertex_id: u32) -> u64 {
vertex_id as u64 / self.num_nodes_per_sector + 1
}
#[inline]
fn node_disk_buf(&self, node_index: usize) -> &[u8] {
let vertex_id = self.nodes_to_fetch[node_index];
// get sector_buf where this node is located
let sector_buf = self.sector_graph.get_sector_buf(node_index);
let node_offset = (vertex_id as u64 % self.num_nodes_per_sector * self.max_node_len) as usize;
&sector_buf[node_offset..node_offset + self.max_node_len as usize]
}
#[inline]
fn node_fp_vector_buf(&self, node_index: usize) -> &[u8] {
let node_disk_buf = self.node_disk_buf(node_index);
&node_disk_buf[..self.fp_vector_len as usize]
}
}
/// Iterator for DiskGraph
pub struct DiskGraphIntoIterator<'a> {
graph: &'a DiskGraph,
index: usize,
}
impl<'a> IntoIterator for &'a DiskGraph
{
type IntoIter = DiskGraphIntoIterator<'a>;
type Item = ANNResult<(usize, VertexAndNeighbors)>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
DiskGraphIntoIterator {
graph: self,
index: 0,
}
}
}
impl<'a> Iterator for DiskGraphIntoIterator<'a>
{
type Item = ANNResult<(usize, VertexAndNeighbors)>;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.graph.nodes_to_fetch.len() {
return None;
}
let node_index = self.index;
let vertex_and_neighbors = self.graph.get_vertex_and_neighbors(self.index);
self.index += 1;
Some(Ok((node_index, vertex_and_neighbors)))
}
}

View File

@@ -0,0 +1,141 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! In-memory graph
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use crate::common::ANNError;
use super::VertexAndNeighbors;
/// The entire graph of in-memory index
#[derive(Debug)]
pub struct InMemoryGraph {
/// The entire graph
pub final_graph: Vec<RwLock<VertexAndNeighbors>>,
}
impl InMemoryGraph {
/// Create InMemoryGraph instance
pub fn new(size: usize, max_degree: u32) -> Self {
let mut graph = Vec::with_capacity(size);
for id in 0..size {
graph.push(RwLock::new(VertexAndNeighbors::for_range(
id as u32,
max_degree as usize,
)));
}
Self { final_graph: graph }
}
/// Size of graph
pub fn size(&self) -> usize {
self.final_graph.len()
}
/// Extend the graph by size vectors
pub fn extend(&mut self, size: usize, max_degree: u32) {
for id in 0..size {
self.final_graph
.push(RwLock::new(VertexAndNeighbors::for_range(
id as u32,
max_degree as usize,
)));
}
}
/// Get read guard of vertex_id
pub fn read_vertex_and_neighbors(
&self,
vertex_id: u32,
) -> Result<RwLockReadGuard<VertexAndNeighbors>, ANNError> {
self.final_graph[vertex_id as usize].read().map_err(|err| {
ANNError::log_lock_poison_error(format!(
"PoisonError: Lock poisoned when reading final_graph for vertex_id {}, err={}",
vertex_id, err
))
})
}
/// Get write guard of vertex_id
pub fn write_vertex_and_neighbors(
&self,
vertex_id: u32,
) -> Result<RwLockWriteGuard<VertexAndNeighbors>, ANNError> {
self.final_graph[vertex_id as usize].write().map_err(|err| {
ANNError::log_lock_poison_error(format!(
"PoisonError: Lock poisoned when writing final_graph for vertex_id {}, err={}",
vertex_id, err
))
})
}
}
#[cfg(test)]
mod graph_tests {
use crate::model::{graph::AdjacencyList, GRAPH_SLACK_FACTOR};
use super::*;
#[test]
fn test_new() {
let graph = InMemoryGraph::new(10, 10);
let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize;
assert_eq!(graph.final_graph.len(), 10);
for i in 0..10 {
let neighbor = graph.final_graph[i].read().unwrap();
assert_eq!(neighbor.vertex_id, i as u32);
assert_eq!(neighbor.get_neighbors().capacity(), capacity);
}
}
#[test]
fn test_size() {
let graph = InMemoryGraph::new(10, 10);
assert_eq!(graph.size(), 10);
}
#[test]
fn test_extend() {
let mut graph = InMemoryGraph::new(10, 10);
graph.extend(10, 10);
assert_eq!(graph.size(), 20);
let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize;
let mut id: u32 = 0;
for i in 10..20 {
let neighbor = graph.final_graph[i].read().unwrap();
assert_eq!(neighbor.vertex_id, id);
assert_eq!(neighbor.get_neighbors().capacity(), capacity);
id += 1;
}
}
#[test]
fn test_read_vertex_and_neighbors() {
let graph = InMemoryGraph::new(10, 10);
let neighbor = graph.read_vertex_and_neighbors(0);
assert!(neighbor.is_ok());
assert_eq!(neighbor.unwrap().vertex_id, 0);
}
#[test]
fn test_write_vertex_and_neighbors() {
let graph = InMemoryGraph::new(10, 10);
{
let neighbor = graph.write_vertex_and_neighbors(0);
assert!(neighbor.is_ok());
neighbor.unwrap().add_to_neighbors(10, 10);
}
let neighbor = graph.read_vertex_and_neighbors(0).unwrap();
assert_eq!(neighbor.get_neighbors(), &AdjacencyList::from(vec![10_u32]));
}
}

View File

@@ -0,0 +1,20 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod inmem_graph;
pub use inmem_graph::InMemoryGraph;
pub mod vertex_and_neighbors;
pub use vertex_and_neighbors::VertexAndNeighbors;
mod adjacency_list;
pub use adjacency_list::AdjacencyList;
mod sector_graph;
pub use sector_graph::*;
mod disk_graph;
pub use disk_graph::*;

View File

@@ -0,0 +1,87 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_docs)]
//! Sector graph
use std::ops::Deref;
use crate::common::{AlignedBoxWithSlice, ANNResult, ANNError};
use crate::model::{MAX_N_SECTOR_READS, SECTOR_LEN, AlignedRead};
use crate::storage::DiskGraphStorage;
/// Sector graph read from disk index
pub struct SectorGraph {
/// Sector bytes from disk
/// One sector has num_nodes_per_sector nodes
/// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]}
/// The fp vector is not aligned
sectors_data: AlignedBoxWithSlice<u8>,
/// Graph storage to read sectors
graph_storage: DiskGraphStorage,
/// Current sector index into which the next read reads data
cur_sector_idx: u64,
}
impl SectorGraph {
/// Create SectorGraph instance
pub fn new(graph_storage: DiskGraphStorage) -> ANNResult<Self> {
Ok(Self {
sectors_data: AlignedBoxWithSlice::new(MAX_N_SECTOR_READS * SECTOR_LEN, SECTOR_LEN)?,
graph_storage,
cur_sector_idx: 0,
})
}
/// Reset SectorGraph
pub fn reset(&mut self) {
self.cur_sector_idx = 0;
}
/// Read sectors into sectors_data
/// They are in the same order as sectors_to_fetch
pub fn read_graph(&mut self, sectors_to_fetch: &[u64]) -> ANNResult<()> {
let cur_sector_idx_usize: usize = self.cur_sector_idx.try_into()?;
if sectors_to_fetch.len() > MAX_N_SECTOR_READS - cur_sector_idx_usize {
return Err(ANNError::log_index_error(format!(
"Trying to read too many sectors. number of sectors to read: {}, max number of sectors can read: {}",
sectors_to_fetch.len(),
MAX_N_SECTOR_READS - cur_sector_idx_usize,
)));
}
let mut sector_slices = self.sectors_data.split_into_nonoverlapping_mut_slices(
cur_sector_idx_usize * SECTOR_LEN..(cur_sector_idx_usize + sectors_to_fetch.len()) * SECTOR_LEN,
SECTOR_LEN)?;
let mut read_requests = Vec::with_capacity(sector_slices.len());
for (local_sector_idx, slice) in sector_slices.iter_mut().enumerate() {
let sector_id = sectors_to_fetch[local_sector_idx];
read_requests.push(AlignedRead::new(sector_id * SECTOR_LEN as u64, slice)?);
}
self.graph_storage.read(&mut read_requests)?;
self.cur_sector_idx += sectors_to_fetch.len() as u64;
Ok(())
}
/// Get sector data by local index
#[inline]
pub fn get_sector_buf(&self, local_sector_idx: usize) -> &[u8] {
&self.sectors_data[local_sector_idx * SECTOR_LEN..(local_sector_idx + 1) * SECTOR_LEN]
}
}
impl Deref for SectorGraph {
type Target = [u8];
fn deref(&self) -> &Self::Target {
&self.sectors_data
}
}

View File

@@ -0,0 +1,159 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Vertex and its Adjacency List
use crate::model::GRAPH_SLACK_FACTOR;
use super::AdjacencyList;
/// The out neighbors of vertex_id
#[derive(Debug)]
pub struct VertexAndNeighbors {
/// The id of the vertex
pub vertex_id: u32,
/// All out neighbors (id) of vertex_id
neighbors: AdjacencyList,
}
impl VertexAndNeighbors {
/// Create VertexAndNeighbors with id and capacity
pub fn for_range(id: u32, range: usize) -> Self {
Self {
vertex_id: id,
neighbors: AdjacencyList::for_range(range),
}
}
/// Create VertexAndNeighbors with id and neighbors
pub fn new(vertex_id: u32, neighbors: AdjacencyList) -> Self {
Self {
vertex_id,
neighbors,
}
}
/// Get size of neighbors
#[inline(always)]
pub fn size(&self) -> usize {
self.neighbors.len()
}
/// Update the neighbors vector (post a pruning exercise)
#[inline(always)]
pub fn set_neighbors(&mut self, new_neighbors: AdjacencyList) {
// Replace the graph entry with the pruned neighbors
self.neighbors = new_neighbors;
}
/// Get the neighbors
#[inline(always)]
pub fn get_neighbors(&self) -> &AdjacencyList {
&self.neighbors
}
/// Adds a node to the list of neighbors for the given node.
///
/// # Arguments
///
/// * `node_id` - The ID of the node to add.
/// * `range` - The range of the graph.
///
/// # Return
///
    /// Returns `None` if the node is already in the list of neighbors or was appended in place; returns a `Vec` with the updated list of neighbors if the list is full (the stored list itself is left unchanged).
pub fn add_to_neighbors(&mut self, node_id: u32, range: u32) -> Option<Vec<u32>> {
// Check if n is already in the graph entry
if self.neighbors.contains(&node_id) {
return None;
}
let neighbor_len = self.neighbors.len();
// If not, check if the graph entry has enough space
if neighbor_len < (GRAPH_SLACK_FACTOR * range as f64) as usize {
// If yes, add n to the graph entry
self.neighbors.push(node_id);
return None;
}
let mut copy_of_neighbors = Vec::with_capacity(neighbor_len + 1);
unsafe {
let dst = copy_of_neighbors.as_mut_ptr();
std::ptr::copy_nonoverlapping(self.neighbors.as_ptr(), dst, neighbor_len);
dst.add(neighbor_len).write(node_id);
copy_of_neighbors.set_len(neighbor_len + 1);
}
Some(copy_of_neighbors)
}
}
#[cfg(test)]
mod vertex_and_neighbors_tests {
use crate::model::GRAPH_SLACK_FACTOR;
use super::*;
#[test]
fn test_set_with_capacity() {
let neighbors = VertexAndNeighbors::for_range(20, 10);
assert_eq!(neighbors.vertex_id, 20);
assert_eq!(
neighbors.neighbors.capacity(),
(10_f32 * GRAPH_SLACK_FACTOR as f32).ceil() as usize
);
}
#[test]
fn test_size() {
let mut neighbors = VertexAndNeighbors::for_range(20, 10);
for i in 0..5 {
neighbors.neighbors.push(i);
}
assert_eq!(neighbors.size(), 5);
}
#[test]
fn test_set_neighbors() {
let mut neighbors = VertexAndNeighbors::for_range(20, 10);
let new_vec = AdjacencyList::from(vec![1, 2, 3, 4, 5]);
neighbors.set_neighbors(AdjacencyList::from(new_vec.clone()));
assert_eq!(neighbors.neighbors, new_vec);
}
#[test]
fn test_get_neighbors() {
let mut neighbors = VertexAndNeighbors::for_range(20, 10);
neighbors.set_neighbors(AdjacencyList::from(vec![1, 2, 3, 4, 5]));
let neighbor_ref = neighbors.get_neighbors();
assert!(std::ptr::eq(&neighbors.neighbors, neighbor_ref))
}
#[test]
fn test_add_to_neighbors() {
let mut neighbors = VertexAndNeighbors::for_range(20, 10);
assert_eq!(neighbors.add_to_neighbors(1, 1), None);
assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1]));
assert_eq!(neighbors.add_to_neighbors(1, 1), None);
assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1]));
let ret = neighbors.add_to_neighbors(2, 1);
assert!(ret.is_some());
assert_eq!(ret.unwrap(), vec![1, 2]);
assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1]));
assert_eq!(neighbors.add_to_neighbors(2, 2), None);
assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1, 2]));
}
}

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub mod neighbor;
pub use neighbor::Neighbor;
pub use neighbor::NeighborPriorityQueue;
pub mod data_store;
pub use data_store::InmemDataset;
pub mod graph;
pub use graph::InMemoryGraph;
pub use graph::VertexAndNeighbors;
pub mod configuration;
pub use configuration::*;
pub mod scratch;
pub use scratch::*;
pub mod vertex;
pub use vertex::Vertex;
pub mod pq;
pub use pq::*;
pub mod windows_aligned_file_reader;
pub use windows_aligned_file_reader::*;

View File

@@ -0,0 +1,13 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod neighbor;
pub use neighbor::*;
mod neighbor_priority_queue;
pub use neighbor_priority_queue::*;
mod sorted_neighbor_vector;
pub use sorted_neighbor_vector::SortedNeighborVector;

View File

@@ -0,0 +1,104 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::cmp::Ordering;
/// Neighbor node
#[derive(Debug, Clone, Copy)]
pub struct Neighbor {
/// The id of the node
pub id: u32,
/// The distance from the query node to current node
pub distance: f32,
/// Whether the current is visited or not
pub visited: bool,
}
impl Neighbor {
    /// Create a neighbor node that has not been visited yet
    pub fn new(id: u32, distance: f32) -> Self {
Self {
id,
distance,
visited: false
}
}
}
impl Default for Neighbor {
fn default() -> Self {
Self { id: 0, distance: 0.0_f32, visited: false }
}
}
impl PartialEq for Neighbor {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.id == other.id
}
}
impl Eq for Neighbor {}
impl Ord for Neighbor {
fn cmp(&self, other: &Self) -> Ordering {
let ord = self.distance.partial_cmp(&other.distance).unwrap_or(std::cmp::Ordering::Equal);
if ord == Ordering::Equal {
return self.id.cmp(&other.id);
}
ord
}
}
impl PartialOrd for Neighbor {
#[inline]
fn lt(&self, other: &Self) -> bool {
self.distance < other.distance || (self.distance == other.distance && self.id < other.id)
}
// Reason for allowing panic = "Does not support comparing Neighbor with partial_cmp"
#[allow(clippy::panic)]
fn partial_cmp(&self, _: &Self) -> Option<std::cmp::Ordering> {
panic!("Neighbor only allows eq and lt")
}
}
#[cfg(test)]
mod neighbor_test {
use super::*;
#[test]
fn eq_lt_works() {
let n1 = Neighbor::new(1, 1.1);
let n2 = Neighbor::new(2, 2.0);
let n3 = Neighbor::new(1, 1.1);
assert!(n1 != n2);
assert!(n1 < n2);
assert!(n1 == n3);
}
#[test]
#[should_panic]
fn gt_should_panic() {
let n1 = Neighbor::new(1, 1.1);
let n2 = Neighbor::new(2, 2.0);
assert!(n2 > n1);
}
#[test]
#[should_panic]
fn le_should_panic() {
let n1 = Neighbor::new(1, 1.1);
let n2 = Neighbor::new(2, 2.0);
assert!(n1 <= n2);
}
}

View File

@@ -0,0 +1,241 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use crate::model::Neighbor;
/// Neighbor priority Queue based on the distance to the query node
#[derive(Debug)]
pub struct NeighborPriorityQueue {
/// The size of the priority queue
size: usize,
/// The capacity of the priority queue
capacity: usize,
    /// Index of the current notvisited neighbor whose distance is smallest among all notvisited neighbors
cur: usize,
/// The neighbor collection
data: Vec<Neighbor>,
}
impl Default for NeighborPriorityQueue {
fn default() -> Self {
Self::new()
}
}
impl NeighborPriorityQueue {
/// Create NeighborPriorityQueue without capacity
pub fn new() -> Self {
Self {
size: 0,
capacity: 0,
cur: 0,
data: Vec::new(),
}
}
/// Create NeighborPriorityQueue with capacity
pub fn with_capacity(capacity: usize) -> Self {
Self {
size: 0,
capacity,
cur: 0,
data: vec![Neighbor::default(); capacity + 1],
}
}
    /// Inserts the item in distance order.
    /// The item will be dropped if the queue is full / already exists in the queue / has a greater distance than the last item.
    /// The cursor that is used to pop() the next item will be set to the lowest index of an unvisited item.
pub fn insert(&mut self, nbr: Neighbor) {
if self.size == self.capacity && self.get_at(self.size - 1) < &nbr {
return;
}
let mut lo = 0;
let mut hi = self.size;
while lo < hi {
let mid = (lo + hi) >> 1;
if &nbr < self.get_at(mid) {
hi = mid;
} else if self.get_at(mid).id == nbr.id {
// Make sure the same neighbor isn't inserted into the set
return;
} else {
lo = mid + 1;
}
}
if lo < self.capacity {
self.data.copy_within(lo..self.size, lo + 1);
}
self.data[lo] = Neighbor::new(nbr.id, nbr.distance);
if self.size < self.capacity {
self.size += 1;
}
if lo < self.cur {
self.cur = lo;
}
}
/// Get the neighbor at index - SAFETY: index must be less than size
fn get_at(&self, index: usize) -> &Neighbor {
unsafe { self.data.get_unchecked(index) }
}
/// Get the closest and notvisited neighbor
pub fn closest_notvisited(&mut self) -> Neighbor {
self.data[self.cur].visited = true;
let pre = self.cur;
while self.cur < self.size && self.get_at(self.cur).visited {
self.cur += 1;
}
self.data[pre]
}
/// Whether there is notvisited node or not
pub fn has_notvisited_node(&self) -> bool {
self.cur < self.size
}
/// Get the size of the NeighborPriorityQueue
pub fn size(&self) -> usize {
self.size
}
/// Get the capacity of the NeighborPriorityQueue
pub fn capacity(&self) -> usize {
self.capacity
}
/// Sets an artificial capacity of the NeighborPriorityQueue. For benchmarking purposes only.
pub fn set_capacity(&mut self, capacity: usize) {
if capacity < self.data.len() {
self.capacity = capacity;
}
}
/// Reserve capacity
pub fn reserve(&mut self, capacity: usize) {
if capacity > self.capacity {
self.data.resize(capacity + 1, Neighbor::default());
self.capacity = capacity;
}
}
/// Set size and cur to 0
pub fn clear(&mut self) {
self.size = 0;
self.cur = 0;
}
}
impl std::ops::Index<usize> for NeighborPriorityQueue {
type Output = Neighbor;
fn index(&self, i: usize) -> &Self::Output {
&self.data[i]
}
}
#[cfg(test)]
mod neighbor_priority_queue_test {
use super::*;
#[test]
fn test_reserve_capacity() {
let mut queue = NeighborPriorityQueue::with_capacity(10);
assert_eq!(queue.capacity(), 10);
queue.reserve(20);
assert_eq!(queue.capacity(), 20);
}
#[test]
fn test_insert() {
let mut queue = NeighborPriorityQueue::with_capacity(3);
assert_eq!(queue.size(), 0);
queue.insert(Neighbor::new(1, 1.0));
queue.insert(Neighbor::new(2, 0.5));
assert_eq!(queue.size(), 2);
queue.insert(Neighbor::new(2, 0.5)); // should be ignored as the same neighbor
assert_eq!(queue.size(), 2);
queue.insert(Neighbor::new(3, 0.9));
assert_eq!(queue.size(), 3);
assert_eq!(queue[2].id, 1);
queue.insert(Neighbor::new(4, 2.0)); // should be dropped as queue is full and distance is greater than last item
assert_eq!(queue.size(), 3);
assert_eq!(queue[0].id, 2); // node id in queue should be [2,3,1]
assert_eq!(queue[1].id, 3);
assert_eq!(queue[2].id, 1);
println!("{:?}", queue);
}
#[test]
fn test_index() {
let mut queue = NeighborPriorityQueue::with_capacity(3);
queue.insert(Neighbor::new(1, 1.0));
queue.insert(Neighbor::new(2, 0.5));
queue.insert(Neighbor::new(3, 1.5));
assert_eq!(queue[0].id, 2);
assert_eq!(queue[0].distance, 0.5);
}
#[test]
fn test_visit() {
let mut queue = NeighborPriorityQueue::with_capacity(3);
queue.insert(Neighbor::new(1, 1.0));
queue.insert(Neighbor::new(2, 0.5));
queue.insert(Neighbor::new(3, 1.5)); // node id in queue should be [2,1,3]
assert!(queue.has_notvisited_node());
let nbr = queue.closest_notvisited();
assert_eq!(nbr.id, 2);
assert_eq!(nbr.distance, 0.5);
assert!(nbr.visited);
assert!(queue.has_notvisited_node());
let nbr = queue.closest_notvisited();
assert_eq!(nbr.id, 1);
assert_eq!(nbr.distance, 1.0);
assert!(nbr.visited);
assert!(queue.has_notvisited_node());
let nbr = queue.closest_notvisited();
assert_eq!(nbr.id, 3);
assert_eq!(nbr.distance, 1.5);
assert!(nbr.visited);
assert!(!queue.has_notvisited_node());
}
#[test]
fn test_clear_queue() {
let mut queue = NeighborPriorityQueue::with_capacity(3);
queue.insert(Neighbor::new(1, 1.0));
queue.insert(Neighbor::new(2, 0.5));
assert_eq!(queue.size(), 2);
assert!(queue.has_notvisited_node());
queue.clear();
assert_eq!(queue.size(), 0);
assert!(!queue.has_notvisited_node());
}
#[test]
fn test_reserve() {
let mut queue = NeighborPriorityQueue::new();
queue.reserve(10);
assert_eq!(queue.data.len(), 11);
assert_eq!(queue.capacity, 10);
}
#[test]
fn test_set_capacity() {
let mut queue = NeighborPriorityQueue::with_capacity(10);
queue.set_capacity(5);
assert_eq!(queue.capacity, 5);
assert_eq!(queue.data.len(), 11);
queue.set_capacity(11);
assert_eq!(queue.capacity, 5);
}
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Sorted Neighbor Vector
use std::ops::{Deref, DerefMut};
use super::Neighbor;
/// A newtype over a vector of neighbors, kept sorted by distance
#[derive(Debug)]
pub struct SortedNeighborVector<'a>(&'a mut Vec<Neighbor>);
impl<'a> SortedNeighborVector<'a> {
/// Create a new SortedNeighborVector
pub fn new(vec: &'a mut Vec<Neighbor>) -> Self {
vec.sort_unstable();
Self(vec)
}
}
impl<'a> Deref for SortedNeighborVector<'a> {
type Target = Vec<Neighbor>;
fn deref(&self) -> &Self::Target {
self.0
}
}
impl<'a> DerefMut for SortedNeighborVector<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0
}
}
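#[cfg(test)]
mod sorted_neighbor_vector_test {
    use super::*;
    // A small sketch (ids and distances assumed): wrapping a vector sorts it
    // in place by distance, with ties broken by id.
    #[test]
    fn new_sorts_by_distance() {
        let mut vec = vec![
            Neighbor::new(3, 1.5),
            Neighbor::new(1, 0.5),
            Neighbor::new(2, 1.0),
        ];
        let sorted = SortedNeighborVector::new(&mut vec);
        let ids: Vec<u32> = sorted.iter().map(|n| n.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }
}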

View File

@@ -0,0 +1,483 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations)]
use hashbrown::HashMap;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut,
};
use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
use crate::{
common::{ANNError, ANNResult},
model::NUM_PQ_CENTROIDS,
};
/// PQ Pivot table loading and calculate distance
#[derive(Debug)]
pub struct FixedChunkPQTable {
/// pq_tables = float array of size [256 * ndims]
pq_table: Vec<f32>,
/// ndims = true dimension of vectors
dim: usize,
/// num_pq_chunks = the pq chunk number
num_pq_chunks: usize,
    /// chunk_offsets = the offset of each chunk, starting from 0
chunk_offsets: Vec<usize>,
/// centroid of each dimension
centroids: Vec<f32>,
    /// Because we're using L2 distance, this is not needed now.
    /// Transpose of pq_table. transport_pq_table = float array of size [ndims * 256].
    /// e.g. if pq_table is 2 centroids * 3 dims
/// [ 1, 2, 3,
/// 4, 5, 6]
/// then transport_pq_table would be 3 dims * 2 centroids
/// [ 1, 4,
/// 2, 5,
/// 3, 6]
/// transport_pq_table: Vec<f32>,
    /// Map dim offset to chunk index, e.g., 8 dims into 2 chunks
/// then would be [(0,0), (1,0), (2,0), (3,0), (4,1), (5,1), (6,1), (7,1)]
dimoffset_chunk_mapping: HashMap<usize, usize>,
}
impl FixedChunkPQTable {
    /// Create the FixedChunkPQTable with dim and chunk numbers and pivot file data (pivot table + centroids + chunk offsets)
pub fn new(
dim: usize,
num_pq_chunks: usize,
pq_table: Vec<f32>,
centroids: Vec<f32>,
chunk_offsets: Vec<usize>,
) -> Self {
let mut dimoffset_chunk_mapping = HashMap::new();
for chunk_index in 0..num_pq_chunks {
for dim_offset in chunk_offsets[chunk_index]..chunk_offsets[chunk_index + 1] {
dimoffset_chunk_mapping.insert(dim_offset, chunk_index);
}
}
Self {
pq_table,
dim,
num_pq_chunks,
chunk_offsets,
centroids,
dimoffset_chunk_mapping,
}
}
/// Get chunk number
pub fn get_num_chunks(&self) -> usize {
self.num_pq_chunks
}
    /// Shift the query according to the mean of the whole corpus
pub fn preprocess_query(&self, query_vec: &mut [f32]) {
for (query, &centroid) in query_vec.iter_mut().zip(self.centroids.iter()) {
*query -= centroid;
}
}
    /// Pre-calculate the distance between the query and each centroid using L2 distance
    /// * `query_vec` - query vector: 1 * dim
    /// * `dist_vec` - pre-calculated distances between the query and each centroid: chunk_size * num_centroids
#[allow(clippy::needless_range_loop)]
pub fn populate_chunk_distances(&self, query_vec: &[f32]) -> Vec<f32> {
let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
for centroid_index in 0..NUM_PQ_CENTROIDS {
for chunk_index in 0..self.num_pq_chunks {
for dim_offset in
self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
{
let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
- query_vec[dim_offset];
dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] += diff * diff;
}
}
}
dist_vec
}
    /// Pre-calculate the distance between the query and each centroid using inner product
    /// * `query_vec` - query vector: 1 * dim
    /// * `dist_vec` - pre-calculated distances between the query and each centroid: chunk_size * num_centroids
///
/// Reason to allow clippy::needless_range_loop:
/// The inner loop is operating over a range that is different for each iteration of the outer loop.
/// This isn't a scenario where using iter().enumerate() would be easily applicable,
/// because the inner loop isn't iterating directly over the contents of a slice or array.
/// Thus, using indexing might be the most straightforward way to express this logic.
#[allow(clippy::needless_range_loop)]
pub fn populate_chunk_inner_products(&self, query_vec: &[f32]) -> Vec<f32> {
let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
for centroid_index in 0..NUM_PQ_CENTROIDS {
for chunk_index in 0..self.num_pq_chunks {
for dim_offset in
self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
{
// assumes that we are not shifting the vectors to mean zero, i.e., centroid
// array should be all zeros returning negative to keep the search code
// clean (max inner product vs min distance)
let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
* query_vec[dim_offset];
dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] -= diff;
}
}
}
dist_vec
}
/// Calculate the distance between query and given centroid by l2 distance
/// * `query_vec` - query vector: 1 * dim
/// * `base_vec` - given centroid array: 1 * num_pq_chunks
#[allow(clippy::needless_range_loop)]
pub fn l2_distance(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 {
let mut res_vec: Vec<f32> = vec![0.0; self.num_pq_chunks];
res_vec
.par_iter_mut()
.enumerate()
.for_each(|(chunk_index, chunk_diff)| {
for dim_offset in
self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
{
let diff = self.pq_table
[self.dim * base_vec[chunk_index] as usize + dim_offset]
- query_vec[dim_offset];
*chunk_diff += diff * diff;
}
});
let res: f32 = res_vec.iter().sum::<f32>();
res
}
/// Calculate the distance between query and given centroid by inner product
/// * `query_vec` - query vector: 1 * dim
/// * `base_vec` - given centroid array: 1 * num_pq_chunks
#[allow(clippy::needless_range_loop)]
pub fn inner_product(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 {
let mut res_vec: Vec<f32> = vec![0.0; self.num_pq_chunks];
res_vec
.par_iter_mut()
.enumerate()
.for_each(|(chunk_index, chunk_diff)| {
for dim_offset in
self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
{
*chunk_diff += self.pq_table
[self.dim * base_vec[chunk_index] as usize + dim_offset]
* query_vec[dim_offset];
}
});
let res: f32 = res_vec.iter().sum::<f32>();
// returns negative value to simulate distances (max -> min conversion)
-res
}
/// Revert vector by adding centroid
/// * `base_vec` - given centroid array: 1 * num_pq_chunks
/// * `out_vec` - reverted vector
pub fn inflate_vector(&self, base_vec: &[u8]) -> ANNResult<Vec<f32>> {
let mut out_vec: Vec<f32> = vec![0.0; self.dim];
for (dim_offset, value) in out_vec.iter_mut().enumerate() {
let chunk_index =
self.dimoffset_chunk_mapping
.get(&dim_offset)
.ok_or(ANNError::log_pq_error(
"ERROR: dim_offset not found in dimoffset_chunk_mapping".to_string(),
))?;
*value = self.pq_table[self.dim * base_vec[*chunk_index] as usize + dim_offset]
+ self.centroids[dim_offset];
}
Ok(out_vec)
}
}
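// Illustrative sketch (added for this edit, not part of the original API):
// reconstructing a quantized point and measuring its quantization error.
// `table`, `codes`, and `original` are hypothetical inputs: a valid table,
// one row of PQ codes (one u8 per chunk), and the original full-precision point.
#[cfg(test)]
#[allow(dead_code)]
fn inflate_error_sketch(
table: &FixedChunkPQTable,
codes: &[u8],
original: &[f32],
) -> ANNResult<f32> {
let approx = table.inflate_vector(codes)?;
// Squared L2 between the reconstruction and the original point.
Ok(approx
.iter()
.zip(original.iter())
.map(|(x, y)| (x - y) * (x - y))
.sum())
}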
/// Given a batch of input nodes, return a batch of PQ distances
/// * `pq_ids` - PQ codes of the batch nodes: n_pts * pq_nchunks
/// * `n_pts` - number of points in the batch
/// * `pq_nchunks` - number of PQ chunks
/// * `pq_dists` - precomputed distances between the query and each centroid: pq_nchunks * num_centroids
///
/// Returns `dists_out` - one accumulated distance per point: n_pts * 1
pub fn pq_dist_lookup(
pq_ids: &[u8],
n_pts: usize,
pq_nchunks: usize,
pq_dists: &[f32],
) -> Vec<f32> {
let mut dists_out: Vec<f32> = vec![0.0; n_pts];
unsafe {
_mm_prefetch(dists_out.as_ptr() as *const i8, _MM_HINT_T0);
_mm_prefetch(pq_ids.as_ptr() as *const i8, _MM_HINT_T0);
_mm_prefetch(pq_ids.as_ptr().add(64) as *const i8, _MM_HINT_T0);
_mm_prefetch(pq_ids.as_ptr().add(128) as *const i8, _MM_HINT_T0);
}
for chunk in 0..pq_nchunks {
let chunk_dists = &pq_dists[256 * chunk..];
if chunk < pq_nchunks - 1 {
unsafe {
_mm_prefetch(
// Prefetch the next chunk's distance table; chunk_dists already starts
// at the current chunk, so the next one begins 256 entries ahead.
chunk_dists.as_ptr().add(256) as *const i8,
_MM_HINT_T0,
);
}
}
dists_out
.par_iter_mut()
.enumerate()
.for_each(|(n_iter, dist)| {
let pq_centerid = pq_ids[pq_nchunks * n_iter + chunk];
*dist += chunk_dists[pq_centerid as usize];
});
}
dists_out
}
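// Illustrative sketch (added for this edit, not part of the original API):
// the end-to-end PQ distance pipeline during search. Assumes `table` is an
// initialized `FixedChunkPQTable` and `pq_ids` stores one u8 code per chunk
// per candidate, row-major. The per-chunk lookup table is computed once and
// reused for every candidate, which is what makes PQ search cheap.
#[cfg(test)]
#[allow(dead_code)]
fn pq_search_distances_sketch(
table: &FixedChunkPQTable,
mut query: Vec<f32>,
pq_ids: &[u8],
n_pts: usize,
) -> Vec<f32> {
// Center the query into the same space as the trained pivots.
table.preprocess_query(&mut query);
// One entry per (chunk, centroid) pair: num_pq_chunks * 256 distances.
let chunk_dists = table.populate_chunk_distances(&query);
// Each candidate's distance is the sum of the table entries its codes select.
pq_dist_lookup(pq_ids, n_pts, table.get_num_chunks(), &chunk_dists)
}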
pub fn aggregate_coords(ids: &[u32], all_coords: &[u8], ndims: usize) -> Vec<u8> {
let mut out: Vec<u8> = vec![0u8; ids.len() * ndims];
let ndim_u32 = ndims as u32;
out.par_chunks_mut(ndims)
.enumerate()
.for_each(|(index, chunk)| {
let start = (ids[index] * ndim_u32) as usize;
// The sub-slice is already exactly `ndims` long, so no unsafe
// re-slicing through raw pointers is needed.
chunk.copy_from_slice(&all_coords[start..start + ndims]);
});
out
}
#[cfg(test)]
mod fixed_chunk_pq_table_test {
use super::*;
use crate::common::{ANNError, ANNResult};
use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, file_exists, load_bin};
const DIM: usize = 128;
#[test]
fn load_pivot_test() {
let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
let (dim, pq_table, centroids, chunk_offsets) =
load_pq_pivots_bin(pq_pivots_path, &1).unwrap();
let fixed_chunk_pq_table =
FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets);
assert_eq!(dim, DIM);
assert_eq!(fixed_chunk_pq_table.pq_table.len(), DIM * NUM_PQ_CENTROIDS);
assert_eq!(fixed_chunk_pq_table.centroids.len(), DIM);
assert_eq!(fixed_chunk_pq_table.chunk_offsets[0], 0);
assert_eq!(fixed_chunk_pq_table.chunk_offsets[1], DIM);
assert_eq!(fixed_chunk_pq_table.chunk_offsets.len(), 2);
}
#[test]
fn get_num_chunks_test() {
let num_chunks = 7;
let pq_table = vec![0.0; DIM * NUM_PQ_CENTROIDS];
let centroids = vec![0.0; DIM];
// Offsets must partition 0..DIM, so the final entry is DIM (128).
let chunk_offsets = vec![0, 7, 9, 11, 22, 34, 78, 128];
let fixed_chunk_pq_table =
FixedChunkPQTable::new(DIM, num_chunks, pq_table, centroids, chunk_offsets);
let chunk: usize = fixed_chunk_pq_table.get_num_chunks();
assert_eq!(chunk, num_chunks);
}
#[test]
fn preprocess_query_test() {
let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
let (dim, pq_table, centroids, chunk_offsets) =
load_pq_pivots_bin(pq_pivots_path, &1).unwrap();
let fixed_chunk_pq_table =
FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets);
let mut query_vec: Vec<f32> = vec![
32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32,
68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32,
7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32,
65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32,
11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32,
59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32,
7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32,
48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32,
47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32,
74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32,
65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32,
41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32,
84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32,
74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32,
37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32,
];
fixed_chunk_pq_table.preprocess_query(&mut query_vec);
assert_eq!(query_vec[0], 32.39f32 - fixed_chunk_pq_table.centroids[0]);
assert_eq!(
query_vec[127],
35.6f32 - fixed_chunk_pq_table.centroids[127]
);
}
#[test]
fn calculate_distances_tests() {
let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
let (dim, pq_table, centroids, chunk_offsets) =
load_pq_pivots_bin(pq_pivots_path, &1).unwrap();
let fixed_chunk_pq_table =
FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets);
let query_vec: Vec<f32> = vec![
32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32,
68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32,
7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32,
65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32,
11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32,
59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32,
7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32,
48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32,
47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32,
74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32,
65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32,
41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32,
84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32,
74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32,
37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32,
];
let dist_vec = fixed_chunk_pq_table.populate_chunk_distances(&query_vec);
assert_eq!(dist_vec.len(), 256);
// populate_chunk_distances_test
let mut sampled_output = 0.0;
(0..DIM).for_each(|dim_offset| {
let diff = fixed_chunk_pq_table.pq_table[dim_offset] - query_vec[dim_offset];
sampled_output += diff * diff;
});
assert_eq!(sampled_output, dist_vec[0]);
// populate_chunk_inner_products_test
let dist_vec = fixed_chunk_pq_table.populate_chunk_inner_products(&query_vec);
assert_eq!(dist_vec.len(), 256);
let mut sampled_output = 0.0;
(0..DIM).for_each(|dim_offset| {
sampled_output -= fixed_chunk_pq_table.pq_table[dim_offset] * query_vec[dim_offset];
});
assert_eq!(sampled_output, dist_vec[0]);
// l2_distance_test
let base_vec: Vec<u8> = vec![3u8];
let dist = fixed_chunk_pq_table.l2_distance(&query_vec, &base_vec);
let mut l2_output = 0.0;
(0..DIM).for_each(|dim_offset| {
let diff = fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] - query_vec[dim_offset];
l2_output += diff * diff;
});
assert_eq!(l2_output, dist);
// inner_product_test
let dist = fixed_chunk_pq_table.inner_product(&query_vec, &base_vec);
let mut l2_output = 0.0;
(0..DIM).for_each(|dim_offset| {
l2_output -=
fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] * query_vec[dim_offset];
});
assert_eq!(l2_output, dist);
// inflate_vector_test
let inflate_vector = fixed_chunk_pq_table.inflate_vector(&base_vec).unwrap();
assert_eq!(inflate_vector.len(), DIM);
assert_eq!(
inflate_vector[0],
fixed_chunk_pq_table.pq_table[3 * DIM] + fixed_chunk_pq_table.centroids[0]
);
assert_eq!(
inflate_vector[1],
fixed_chunk_pq_table.pq_table[3 * DIM + 1] + fixed_chunk_pq_table.centroids[1]
);
assert_eq!(
inflate_vector[127],
fixed_chunk_pq_table.pq_table[3 * DIM + 127] + fixed_chunk_pq_table.centroids[127]
);
}
fn load_pq_pivots_bin(
pq_pivots_path: &str,
num_pq_chunks: &usize,
) -> ANNResult<(usize, Vec<f32>, Vec<f32>, Vec<usize>)> {
if !file_exists(pq_pivots_path) {
return Err(ANNError::log_pq_error(
"ERROR: PQ k-means pivot file not found.".to_string(),
));
}
let (data, offset_num, offset_dim) = load_bin::<u64>(pq_pivots_path, 0)?;
let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim);
if offset_num != 4 {
let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num);
return Err(ANNError::log_pq_error(error_message));
}
let (data, pq_center_num, dim) = load_bin::<f32>(pq_pivots_path, file_offset_data[0])?;
let pq_table = data.to_vec();
if pq_center_num != NUM_PQ_CENTROIDS {
let error_message = format!(
"Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.",
pq_pivots_path, pq_center_num, NUM_PQ_CENTROIDS
);
return Err(ANNError::log_pq_error(error_message));
}
let (data, centroid_dim, nc) = load_bin::<f32>(pq_pivots_path, file_offset_data[1])?;
let centroids = data.to_vec();
if centroid_dim != dim || nc != 1 {
let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim);
return Err(ANNError::log_pq_error(error_message));
}
let (data, chunk_offset_num, nc) = load_bin::<u32>(pq_pivots_path, file_offset_data[2])?;
let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc);
if chunk_offset_num != num_pq_chunks + 1 || nc != 1 {
let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1);
return Err(ANNError::log_pq_error(error_message));
}
Ok((dim, pq_table, centroids, chunk_offsets))
}
}
#[cfg(test)]
mod pq_index_prune_query_test {
use super::*;
#[test]
fn pq_dist_lookup_test() {
let pq_ids: Vec<u8> = vec![1u8, 3u8, 2u8, 2u8];
let mut pq_dists: Vec<f32> = Vec::with_capacity(256 * 2);
for _ in 0..pq_dists.capacity() {
pq_dists.push(rand::random());
}
let dists_out = pq_dist_lookup(&pq_ids, 2, 2, &pq_dists);
assert_eq!(dists_out.len(), 2);
assert_eq!(dists_out[0], pq_dists[0 + 1] + pq_dists[256 + 3]);
assert_eq!(dists_out[1], pq_dists[0 + 2] + pq_dists[256 + 2]);
}
}

View File

@@ -0,0 +1,9 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
mod fixed_chunk_pq_table;
pub use fixed_chunk_pq_table::*;
mod pq_construction;
pub use pq_construction::*;

View File

@@ -0,0 +1,398 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations)]
use rayon::prelude::{IndexedParallelIterator, ParallelIterator};
use rayon::slice::ParallelSliceMut;
use crate::common::{ANNError, ANNResult};
use crate::storage::PQStorage;
use crate::utils::{compute_closest_centers, file_exists, k_means_clustering};
/// Max size of PQ training set
pub const MAX_PQ_TRAINING_SET_SIZE: f64 = 256_000f64;
/// Max number of PQ chunks
pub const MAX_PQ_CHUNKS: usize = 512;
pub const NUM_PQ_CENTROIDS: usize = 256;
/// block size for reading/processing large files and matrices in blocks
const BLOCK_SIZE: usize = 5000000;
const NUM_KMEANS_REPS_PQ: usize = 12;
/// Given training data in train_data of dimensions num_train * dim, generates
/// PQ pivots by partitioning the coordinates into num_pq_chunks chunks (split
/// evenly if num_pq_chunks divides dim, else rounded), running k-means within
/// each chunk to compute the PQ pivots, and storing them in binary format in
/// the file pq_pivots_path as a num_centers * dim floating point binary file.
/// PQ pivot table layout: {pivot offsets data: METADATA_SIZE}{pivot vector:[dim; num_centroid]}{centroid vector:[dim; 1]}{chunk offsets:[chunk_num+1; 1]}
fn generate_pq_pivots(
train_data: &mut [f32],
num_train: usize,
dim: usize,
num_centers: usize,
num_pq_chunks: usize,
max_k_means_reps: usize,
pq_storage: &mut PQStorage,
) -> ANNResult<()> {
if num_pq_chunks > dim {
return Err(ANNError::log_pq_error(
"Error: number of chunks more than dimension.".to_string(),
));
}
if pq_storage.pivot_data_exist() {
let (file_num_centers, file_dim) = pq_storage.read_pivot_metadata()?;
if file_dim == dim && file_num_centers == num_centers {
// PQ pivot file exists. Not generating again.
return Ok(());
}
}
// Calculate the centroid and center the training data.
// For L2 distance, vectors can be translated so they are centered on the
// origin before computing PQ. This must be disabled when using PQ for MIPS,
// as such translations don't preserve inner products.
// L2 is currently the default, so centering is applied here.
let mut centroid: Vec<f32> = vec![0.0; dim];
for dim_index in 0..dim {
for train_data_index in 0..num_train {
centroid[dim_index] += train_data[train_data_index * dim + dim_index];
}
centroid[dim_index] /= num_train as f32;
}
for dim_index in 0..dim {
for train_data_index in 0..num_train {
train_data[train_data_index * dim + dim_index] -= centroid[dim_index];
}
}
// Calculate each chunk's offset
// If we have 8 dimension and 3 chunk then offsets would be [0,3,6,8]
let mut chunk_offsets: Vec<usize> = vec![0; num_pq_chunks + 1];
let mut chunk_offset: usize = 0;
for chunk_index in 0..num_pq_chunks {
chunk_offset += dim / num_pq_chunks;
if chunk_index < (dim % num_pq_chunks) {
chunk_offset += 1;
}
chunk_offsets[chunk_index + 1] = chunk_offset;
}
let mut full_pivot_data: Vec<f32> = vec![0.0; num_centers * dim];
for chunk_index in 0..num_pq_chunks {
let chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];
let mut cur_train_data: Vec<f32> = vec![0.0; num_train * chunk_size];
let mut cur_pivot_data: Vec<f32> = vec![0.0; num_centers * chunk_size];
cur_train_data
.par_chunks_mut(chunk_size)
.enumerate()
.for_each(|(train_data_index, chunk)| {
for (dim_offset, item) in chunk.iter_mut().enumerate() {
*item = train_data
[train_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
}
});
// Run kmeans to get the centroids of this chunk.
let (_closest_docs, _closest_center, _residual) = k_means_clustering(
&cur_train_data,
num_train,
chunk_size,
&mut cur_pivot_data,
num_centers,
max_k_means_reps,
)?;
// Copy centroids from this chunk table to full table
for center_index in 0..num_centers {
full_pivot_data[center_index * dim + chunk_offsets[chunk_index]
..center_index * dim + chunk_offsets[chunk_index + 1]]
.copy_from_slice(
&cur_pivot_data[center_index * chunk_size..(center_index + 1) * chunk_size],
);
}
}
pq_storage.write_pivot_data(
&full_pivot_data,
&centroid,
&chunk_offsets,
num_centers,
dim,
)?;
Ok(())
}
/// Streams the base file (data_file) and computes the closest centers in each
/// chunk to generate the compressed data, which is stored in
/// pq_compressed_vectors_path.
/// If the number of centers is < 256, each code is stored as a byte; otherwise
/// as a 4-byte value in binary format.
/// Compressed PQ table layout: {num_points: usize}{num_chunks: usize}{compressed pq table: [num_points; num_chunks]}
fn generate_pq_data_from_pivots<T: Copy + Into<f32>>(
num_centers: usize,
num_pq_chunks: usize,
pq_storage: &mut PQStorage,
) -> ANNResult<()> {
let (num_points, dim) = pq_storage.read_pq_data_metadata()?;
if !pq_storage.pivot_data_exist() {
return Err(ANNError::log_pq_error(
"ERROR: PQ k-means pivot file not found.".to_string(),
));
}
let (full_pivot_data, centroid, chunk_offsets) =
pq_storage.load_pivot_data(&num_pq_chunks, &num_centers, &dim)?;
pq_storage.write_compressed_pivot_metadata(num_points as i32, num_pq_chunks as i32)?;
let block_size = if num_points <= BLOCK_SIZE {
num_points
} else {
BLOCK_SIZE
};
let num_blocks = (num_points / block_size) + (num_points % block_size != 0) as usize;
for block_index in 0..num_blocks {
let start_index: usize = block_index * block_size;
let end_index: usize = std::cmp::min((block_index + 1) * block_size, num_points);
let cur_block_size: usize = end_index - start_index;
let mut block_compressed_base: Vec<usize> = vec![0; cur_block_size * num_pq_chunks];
let block_data: Vec<T> = pq_storage.read_pq_block_data(cur_block_size, dim)?;
let mut adjusted_block_data: Vec<f32> = vec![0.0; cur_block_size * dim];
for block_data_index in 0..cur_block_size {
for dim_index in 0..dim {
adjusted_block_data[block_data_index * dim + dim_index] =
block_data[block_data_index * dim + dim_index].into() - centroid[dim_index];
}
}
for chunk_index in 0..num_pq_chunks {
let cur_chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];
if cur_chunk_size == 0 {
continue;
}
let mut cur_pivot_data: Vec<f32> = vec![0.0; num_centers * cur_chunk_size];
let mut cur_data: Vec<f32> = vec![0.0; cur_block_size * cur_chunk_size];
let mut closest_center: Vec<u32> = vec![0; cur_block_size];
// Divide the data into chunks and process each chunk in parallel.
cur_data
.par_chunks_mut(cur_chunk_size)
.enumerate()
.for_each(|(block_data_index, chunk)| {
for (dim_offset, item) in chunk.iter_mut().enumerate() {
*item = adjusted_block_data
[block_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
}
});
cur_pivot_data
.par_chunks_mut(cur_chunk_size)
.enumerate()
.for_each(|(center_index, chunk)| {
for (dim_offset, item) in chunk.iter_mut().enumerate() {
*item = full_pivot_data
[center_index * dim + chunk_offsets[chunk_index] + dim_offset];
}
});
// Compute the closest centers
compute_closest_centers(
&cur_data,
cur_block_size,
cur_chunk_size,
&cur_pivot_data,
num_centers,
1,
&mut closest_center,
None,
None,
)?;
block_compressed_base
.par_chunks_mut(num_pq_chunks)
.enumerate()
.for_each(|(block_data_index, slice)| {
slice[chunk_index] = closest_center[block_data_index] as usize;
});
}
pq_storage.write_compressed_pivot_data(
&block_compressed_base,
num_centers,
cur_block_size,
num_pq_chunks,
)?;
}
Ok(())
}
/// Generate quantized data and save it to file.
/// # Arguments
/// * `p_val` - sampling ratio: the fraction of the data used as training data for pivot generation
/// * `num_pq_chunks` - number of PQ chunks
/// * `codebook_prefix` - path to a predefined pivots (codebook) file; training is skipped if it exists
/// * `pq_storage` - PQ file accessor
pub fn generate_quantized_data<T: Default + Copy + Into<f32>>(
p_val: f64,
num_pq_chunks: usize,
codebook_prefix: &str,
pq_storage: &mut PQStorage,
) -> ANNResult<()> {
// If predefined pivots already exists, skip training.
if !file_exists(codebook_prefix) {
// Instantiates train data with random sample updates train_data_vector
// Training data with train_size samples loaded.
// Each sampled file has train_dim.
let (mut train_data_vector, train_size, train_dim) =
pq_storage.gen_random_slice::<T>(p_val)?;
generate_pq_pivots(
&mut train_data_vector,
train_size,
train_dim,
NUM_PQ_CENTROIDS,
num_pq_chunks,
NUM_KMEANS_REPS_PQ,
pq_storage,
)?;
}
generate_pq_data_from_pivots::<T>(NUM_PQ_CENTROIDS, num_pq_chunks, pq_storage)?;
Ok(())
}
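// Illustrative sketch (added for this edit, not part of the original API):
// training pivots and compressing a dataset in one call. All three paths are
// hypothetical placeholders, and `f32` is assumed to match the element type
// of the data file.
#[cfg(test)]
#[allow(dead_code)]
fn quantize_sketch() -> ANNResult<()> {
let mut pq_storage = PQStorage::new(
"example_pq_pivots.bin",     // hypothetical pivot output path
"example_pq_compressed.bin", // hypothetical compressed output path
"example_data.bin",          // hypothetical input data file
)?;
// Sample half of the points for training and compress to 2 chunks per vector;
// a non-existent codebook path forces pivots to be trained from scratch.
generate_quantized_data::<f32>(0.5, 2, "example_codebook_missing", &mut pq_storage)
}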
#[cfg(test)]
mod pq_test {
use std::fs::File;
use std::io::Write;
use super::*;
use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, METADATA_SIZE};
#[test]
fn generate_pq_pivots_test() {
let pivot_file_name = "generate_pq_pivots_test.bin";
let compressed_file_name = "compressed.bin";
let pq_training_file_name = "tests/data/siftsmall_learn.bin";
let mut pq_storage =
PQStorage::new(pivot_file_name, compressed_file_name, pq_training_file_name).unwrap();
let mut train_data: Vec<f32> = vec![
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
];
generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap();
let (data, nr, nc) = load_bin::<u64>(pivot_file_name, 0).unwrap();
let file_offset_data = convert_types_u64_usize(&data, nr, nc);
assert_eq!(file_offset_data[0], METADATA_SIZE);
assert_eq!(nr, 4);
assert_eq!(nc, 1);
let (data, nr, nc) = load_bin::<f32>(pivot_file_name, file_offset_data[0]).unwrap();
let full_pivot_data = data.to_vec();
assert_eq!(full_pivot_data.len(), 16);
assert_eq!(nr, 2);
assert_eq!(nc, 8);
let (data, nr, nc) = load_bin::<f32>(pivot_file_name, file_offset_data[1]).unwrap();
let centroid = data.to_vec();
assert_eq!(
centroid[0],
(1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32
);
assert_eq!(nr, 8);
assert_eq!(nc, 1);
let (data, nr, nc) = load_bin::<u32>(pivot_file_name, file_offset_data[2]).unwrap();
let chunk_offsets = convert_types_u32_usize(&data, nr, nc);
assert_eq!(chunk_offsets[0], 0);
assert_eq!(chunk_offsets[1], 4);
assert_eq!(chunk_offsets[2], 8);
assert_eq!(nr, 3);
assert_eq!(nc, 1);
std::fs::remove_file(pivot_file_name).unwrap();
}
#[test]
fn generate_pq_data_from_pivots_test() {
let data_file = "generate_pq_data_from_pivots_test_data.bin";
//npoints=5, dim=8, 5 vectors [1.0;8] [2.0;8] [2.1;8] [2.2;8] [100.0;8]
let mut train_data: Vec<f32> = vec![
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32,
2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32,
2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32,
100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32,
];
let my_nums_unstructured: &[u8] = unsafe {
std::slice::from_raw_parts(train_data.as_ptr() as *const u8, train_data.len() * 4)
};
let meta: Vec<i32> = vec![5, 8];
let meta_unstructured: &[u8] =
unsafe { std::slice::from_raw_parts(meta.as_ptr() as *const u8, meta.len() * 4) };
let mut data_file_writer = File::create(data_file).unwrap();
data_file_writer
.write_all(meta_unstructured)
.expect("Failed to write sample file");
data_file_writer
.write_all(my_nums_unstructured)
.expect("Failed to write sample file");
let pq_pivots_path = "generate_pq_data_from_pivots_test_pivot.bin";
let pq_compressed_vectors_path = "generate_pq_data_from_pivots_test.bin";
let mut pq_storage =
PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap();
generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap();
generate_pq_data_from_pivots::<f32>(2, 2, &mut pq_storage).unwrap();
let (data, nr, nc) = load_bin::<u8>(pq_compressed_vectors_path, 0).unwrap();
assert_eq!(nr, 5);
assert_eq!(nc, 2);
assert_eq!(data[0], data[2]);
assert_ne!(data[0], data[8]);
std::fs::remove_file(data_file).unwrap();
std::fs::remove_file(pq_pivots_path).unwrap();
std::fs::remove_file(pq_compressed_vectors_path).unwrap();
}
#[test]
fn pq_end_to_end_validation_with_codebook_test() {
let data_file = "tests/data/siftsmall_learn.bin";
let pq_pivots_path = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
let ground_truth_path = "tests/data/siftsmall_learn.bin_pq_compressed.bin";
let pq_compressed_vectors_path = "validation.bin";
let mut pq_storage =
PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap();
generate_quantized_data::<f32>(0.5, 1, pq_pivots_path, &mut pq_storage).unwrap();
let (data, nr, nc) = load_bin::<u8>(pq_compressed_vectors_path, 0).unwrap();
let (gt_data, gt_nr, gt_nc) = load_bin::<u8>(ground_truth_path, 0).unwrap();
assert_eq!(nr, gt_nr);
assert_eq!(nc, gt_nc);
for i in 0..data.len() {
assert_eq!(data[i], gt_data[i]);
}
std::fs::remove_file(pq_compressed_vectors_path).unwrap();
}
}

View File

@@ -0,0 +1,312 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Concurrent queue used to pool query scratch data
use std::collections::VecDeque;
use std::ops::Deref;
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::time::Duration;
use crate::common::{ANNError, ANNResult};
#[derive(Debug)]
/// A thread-safe FIFO queue with push notifications, used to pool query scratch data structures
pub struct ConcurrentQueue<T> {
q: Mutex<VecDeque<T>>,
c: Mutex<bool>,
push_cv: Condvar,
}
impl Default for ConcurrentQueue<usize> {
fn default() -> Self {
Self::new()
}
}
impl<T> ConcurrentQueue<T> {
/// Create a concurrent queue
pub fn new() -> Self {
Self {
q: Mutex::new(VecDeque::new()),
c: Mutex::new(false),
push_cv: Condvar::new(),
}
}
/// Reserve capacity for at least `size` more elements, blocking until the queue mutex is acquired
pub fn reserve(&self, size: usize) -> ANNResult<()> {
let mut guard = lock(&self.q)?;
guard.reserve(size);
Ok(())
}
/// Number of elements currently in the queue
pub fn size(&self) -> ANNResult<usize> {
let guard = lock(&self.q)?;
Ok(guard.len())
}
/// Whether the queue is empty
pub fn is_empty(&self) -> ANNResult<bool> {
Ok(self.size()? == 0)
}
/// push back
pub fn push(&self, new_val: T) -> ANNResult<()> {
let mut guard = lock(&self.q)?;
self.push_internal(&mut guard, new_val);
self.push_cv.notify_all();
Ok(())
}
/// push back
fn push_internal(&self, guard: &mut MutexGuard<VecDeque<T>>, new_val: T) {
guard.push_back(new_val);
}
/// Push all items from the iterator into the queue, notifying waiters once
pub fn insert<I>(&self, iter: I) -> ANNResult<()>
where
I: IntoIterator<Item = T>,
{
let mut guard = lock(&self.q)?;
for item in iter {
self.push_internal(&mut guard, item);
}
self.push_cv.notify_all();
Ok(())
}
/// pop front
pub fn pop(&self) -> ANNResult<Option<T>> {
let mut guard = lock(&self.q)?;
Ok(guard.pop_front())
}
/// Drain all elements from the queue (TODO: is this necessary?)
pub fn empty_queue(&self) -> ANNResult<()> {
let mut guard = lock(&self.q)?;
while !guard.is_empty() {
let _ = guard.pop_front();
}
Ok(())
}
/// Wait up to `wait_time` for a push notification
pub fn wait_for_push_notify(&self, wait_time: Duration) -> ANNResult<()> {
let guard_lock = lock(&self.c)?;
let _ = self
.push_cv
.wait_timeout(guard_lock, wait_time)
.map_err(|err| {
ANNError::log_lock_poison_error(format!(
"ConcurrentQueue Lock is poisoned, err={}",
err
))
})?;
Ok(())
}
}
fn lock<T>(mutex: &Mutex<T>) -> ANNResult<MutexGuard<T>> {
let guard = mutex.lock().map_err(|err| {
ANNError::log_lock_poison_error(format!("ConcurrentQueue lock is poisoned, err={}", err))
})?;
Ok(guard)
}
/// A thread-safe queue that holds instances of `T`.
/// Each instance is stored in a `Box` to keep the size of the queue node constant.
#[derive(Debug)]
pub struct ArcConcurrentBoxedQueue<T> {
internal_queue: Arc<ConcurrentQueue<Box<T>>>,
}
impl<T> ArcConcurrentBoxedQueue<T> {
/// Create a new `ArcConcurrentBoxedQueue`.
pub fn new() -> Self {
Self {
internal_queue: Arc::new(ConcurrentQueue::new()),
}
}
}
impl<T> Default for ArcConcurrentBoxedQueue<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Clone for ArcConcurrentBoxedQueue<T> {
/// Create a new `ArcConcurrentBoxedQueue` that shares the same internal queue
/// with the existing one. This allows multiple `ArcConcurrentBoxedQueue` to
/// operate on the same underlying queue.
fn clone(&self) -> Self {
Self {
internal_queue: Arc::clone(&self.internal_queue),
}
}
}
/// Deref to the ConcurrentQueue.
impl<T> Deref for ArcConcurrentBoxedQueue<T> {
type Target = ConcurrentQueue<Box<T>>;
fn deref(&self) -> &Self::Target {
&self.internal_queue
}
}
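// Illustrative sketch (added for this edit, not part of the original API):
// the intended pooling pattern. Clones share one underlying queue, so a
// scratch object popped through one handle can be pushed back through another
// for reuse by a different thread.
#[cfg(test)]
#[allow(dead_code)]
fn pool_pattern_sketch() -> ANNResult<()> {
let pool: ArcConcurrentBoxedQueue<Vec<u8>> = ArcConcurrentBoxedQueue::new();
pool.push(Box::new(vec![0u8; 1024]))?; // seed the pool with one buffer
let worker_handle = pool.clone(); // shares the same underlying queue
if let Some(buffer) = worker_handle.pop()? {
// ... use the scratch buffer, then return it to the pool ...
worker_handle.push(buffer)?;
}
Ok(())
}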
#[cfg(test)]
mod tests {
use crate::model::ConcurrentQueue;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
#[test]
fn test_push_pop() {
let queue = ConcurrentQueue::<i32>::new();
queue.push(1).unwrap();
queue.push(2).unwrap();
queue.push(3).unwrap();
assert_eq!(queue.pop().unwrap(), Some(1));
assert_eq!(queue.pop().unwrap(), Some(2));
assert_eq!(queue.pop().unwrap(), Some(3));
assert_eq!(queue.pop().unwrap(), None);
}
#[test]
fn test_size_empty() {
let queue = ConcurrentQueue::new();
assert_eq!(queue.size().unwrap(), 0);
assert!(queue.is_empty().unwrap());
queue.push(1).unwrap();
queue.push(2).unwrap();
assert_eq!(queue.size().unwrap(), 2);
assert!(!queue.is_empty().unwrap());
queue.pop().unwrap();
queue.pop().unwrap();
assert_eq!(queue.size().unwrap(), 0);
assert!(queue.is_empty().unwrap());
}
#[test]
fn test_insert() {
let queue = ConcurrentQueue::new();
let data = vec![1, 2, 3];
queue.insert(data.into_iter()).unwrap();
assert_eq!(queue.pop().unwrap(), Some(1));
assert_eq!(queue.pop().unwrap(), Some(2));
assert_eq!(queue.pop().unwrap(), Some(3));
assert_eq!(queue.pop().unwrap(), None);
}
#[test]
fn test_notifications() {
let queue = Arc::new(ConcurrentQueue::new());
let queue_clone = Arc::clone(&queue);
let producer = thread::spawn(move || {
for i in 0..3 {
thread::sleep(Duration::from_millis(50));
queue_clone.push(i).unwrap();
}
});
let consumer = thread::spawn(move || {
let mut values = vec![];
for _ in 0..3 {
let mut val = -1;
while val == -1 {
queue
.wait_for_push_notify(Duration::from_millis(10))
.unwrap();
val = queue.pop().unwrap().unwrap_or(-1);
}
values.push(val);
}
values
});
producer.join().unwrap();
let consumer_results = consumer.join().unwrap();
assert_eq!(consumer_results, vec![0, 1, 2]);
}
#[test]
fn test_multithreaded_push_pop() {
let queue = Arc::new(ConcurrentQueue::new());
let queue_clone = Arc::clone(&queue);
let producer = thread::spawn(move || {
for i in 0..10 {
queue_clone.push(i).unwrap();
thread::sleep(Duration::from_millis(50));
}
});
let consumer = thread::spawn(move || {
let mut values = vec![];
for _ in 0..10 {
let mut val = -1;
while val == -1 {
val = queue.pop().unwrap().unwrap_or(-1);
thread::sleep(Duration::from_millis(10));
}
values.push(val);
}
values
});
producer.join().unwrap();
let consumer_results = consumer.join().unwrap();
assert_eq!(consumer_results, (0..10).collect::<Vec<_>>());
}
/// This is a single-value test. It avoids the potentially unbounded wait of the previous test when the queue stays empty.
/// It makes sure the signaling mutex matches the mutex used for waiting.
#[test]
fn test_wait_for_push_notify() {
let queue = Arc::new(ConcurrentQueue::<usize>::new());
let queue_clone = Arc::clone(&queue);
let producer = thread::spawn(move || {
thread::sleep(Duration::from_millis(100));
queue_clone.push(1).unwrap();
});
let consumer = thread::spawn(move || {
queue
.wait_for_push_notify(Duration::from_millis(200))
.unwrap();
assert_eq!(queue.pop().unwrap(), Some(1));
});
producer.join().unwrap();
consumer.join().unwrap();
}
}

View File

@@ -0,0 +1,186 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Scratch space for in-memory index based search
use std::cmp::max;
use std::mem;
use hashbrown::HashSet;
use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice};
use crate::model::configuration::index_write_parameters::IndexWriteParameters;
use crate::model::{Neighbor, NeighborPriorityQueue, PQScratch};
use super::Scratch;
/// Slack factor applied to the graph degree when sizing in-memory index scratch
pub const GRAPH_SLACK_FACTOR: f64 = 1.3_f64;
/// Max number of points for using bitset
pub const MAX_POINTS_FOR_USING_BITSET: usize = 100000;
/// Max graph degree for the SSD index (TODO: revisit this limit)
pub const MAX_GRAPH_DEGREE: usize = 512;
/// Max number of distance comparisons for the SSD index (TODO: revisit this limit)
pub const MAX_N_CMPS: usize = 16384;
/// Disk sector length in bytes for the SSD index (TODO: revisit this limit)
pub const SECTOR_LEN: usize = 4096;
/// Max number of sector reads for the SSD index (TODO: revisit this limit)
pub const MAX_N_SECTOR_READS: usize = 128;
/// The alignment required for memory access. This will be multiplied with size of T to get the actual alignment
pub const QUERY_ALIGNMENT_OF_T_SIZE: usize = 16;
/// Scratch space for in-memory index based search
#[derive(Debug)]
pub struct InMemQueryScratch<T, const N: usize> {
/// Size of the candidate queue
pub candidate_size: u32,
/// Max degree for each vertex
pub max_degree: u32,
/// Max occlusion size
pub max_occlusion_size: u32,
/// Query node
pub query: AlignedBoxWithSlice<T>,
/// Best candidates, whose size is candidate_queue_size
pub best_candidates: NeighborPriorityQueue,
/// Occlude factor
pub occlude_factor: Vec<f32>,
/// Visited neighbor id
pub id_scratch: Vec<u32>,
/// The distance between visited neighbor and query node
pub dist_scratch: Vec<f32>,
/// The PQ scratch, boxed so this struct owns the memory. Use the pq_scratch function to get a reference
pub pq_scratch: Option<Box<PQScratch>>,
/// Buffers used in process delete, capacity increases as needed
pub expanded_nodes_set: HashSet<u32>,
/// Expanded neighbors
pub expanded_neighbors_vector: Vec<Neighbor>,
/// Occlude list
pub occlude_list_output: Vec<u32>,
/// RobinSet for larger dataset
pub node_visited_robinset: HashSet<u32>,
}
impl<T: Default + Copy, const N: usize> InMemQueryScratch<T, N> {
/// Create InMemQueryScratch instance
pub fn new(
search_candidate_size: u32,
index_write_parameter: &IndexWriteParameters,
init_pq_scratch: bool,
) -> ANNResult<Self> {
let indexing_candidate_size = index_write_parameter.search_list_size;
let max_degree = index_write_parameter.max_degree;
let max_occlusion_size = index_write_parameter.max_occlusion_size;
if search_candidate_size == 0 || indexing_candidate_size == 0 || max_degree == 0 || N == 0 {
return Err(ANNError::log_index_error(format!(
"In InMemQueryScratch, one of search_candidate_size = {}, indexing_candidate_size = {}, dim = {} or max_degree = {} is zero.",
search_candidate_size, indexing_candidate_size, N, max_degree)));
}
let query = AlignedBoxWithSlice::new(N, mem::size_of::<T>() * QUERY_ALIGNMENT_OF_T_SIZE)?;
let pq_scratch = if init_pq_scratch {
Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?))
} else {
None
};
let occlude_factor = Vec::with_capacity(max_occlusion_size as usize);
let capacity = (1.5 * GRAPH_SLACK_FACTOR * (max_degree as f64)).ceil() as usize;
let id_scratch = Vec::with_capacity(capacity);
let dist_scratch = Vec::with_capacity(capacity);
let expanded_nodes_set = HashSet::<u32>::new();
let expanded_neighbors_vector = Vec::<Neighbor>::new();
let occlude_list_output = Vec::<u32>::new();
let candidate_size = max(search_candidate_size, indexing_candidate_size);
let node_visited_robinset = HashSet::<u32>::with_capacity(20 * candidate_size as usize);
let scratch = Self {
candidate_size,
max_degree,
max_occlusion_size,
query,
best_candidates: NeighborPriorityQueue::with_capacity(candidate_size as usize),
occlude_factor,
id_scratch,
dist_scratch,
pq_scratch,
expanded_nodes_set,
expanded_neighbors_vector,
occlude_list_output,
node_visited_robinset,
};
Ok(scratch)
}
/// Resize the scratch with new candidate size
pub fn resize_for_new_candidate_size(&mut self, new_candidate_size: u32) {
if new_candidate_size > self.candidate_size {
let delta = new_candidate_size - self.candidate_size;
self.candidate_size = new_candidate_size;
self.best_candidates.reserve(delta as usize);
self.node_visited_robinset.reserve((20 * delta) as usize);
}
}
}
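// Illustrative sketch (added for this edit, not part of the original API):
// growing an existing scratch when a later search needs a larger candidate
// queue. `scratch` is a hypothetical instance built as in the test module below.
#[cfg(test)]
#[allow(dead_code)]
fn resize_sketch(scratch: &mut InMemQueryScratch<f32, 32>) {
let old_size = scratch.candidate_size;
// Growing reserves proportional room in the candidate queue and visited set;
// requesting a smaller size is a no-op by design.
scratch.resize_for_new_candidate_size(old_size + 50);
assert_eq!(scratch.candidate_size, old_size + 50);
}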
impl<T: Default + Copy, const N: usize> Scratch for InMemQueryScratch<T, N> {
fn clear(&mut self) {
self.best_candidates.clear();
self.occlude_factor.clear();
self.node_visited_robinset.clear();
self.id_scratch.clear();
self.dist_scratch.clear();
self.expanded_nodes_set.clear();
self.expanded_neighbors_vector.clear();
self.occlude_list_output.clear();
}
}
#[cfg(test)]
mod inmemory_query_scratch_test {
use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
use super::*;
#[test]
fn node_visited_robinset_test() {
let index_write_parameter = IndexWriteParametersBuilder::new(10, 10)
.with_max_occlusion_size(5)
.build();
let mut scratch =
InMemQueryScratch::<f32, 32>::new(100, &index_write_parameter, false).unwrap();
assert_eq!(scratch.node_visited_robinset.len(), 0);
scratch.clear();
assert_eq!(scratch.node_visited_robinset.len(), 0);
}
}

View File

@@ -0,0 +1,28 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub mod scratch_traits;
pub use scratch_traits::*;
pub mod concurrent_queue;
pub use concurrent_queue::*;
pub mod pq_scratch;
pub use pq_scratch::*;
pub mod inmem_query_scratch;
pub use inmem_query_scratch::*;
pub mod scratch_store_manager;
pub use scratch_store_manager::*;
pub mod ssd_query_scratch;
pub use ssd_query_scratch::*;
pub mod ssd_thread_data;
pub use ssd_thread_data::*;
pub mod ssd_io_context;
pub use ssd_io_context::*;

View File

@@ -0,0 +1,105 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! PQ scratch space backed by aligned allocations
use std::mem::size_of;
use crate::common::{ANNResult, AlignedBoxWithSlice};
const MAX_PQ_CHUNKS: usize = 512;
#[derive(Debug)]
/// PQ scratch
pub struct PQScratch {
/// Aligned pq table dist scratch, must be at least [256 * NCHUNKS]
pub aligned_pqtable_dist_scratch: AlignedBoxWithSlice<f32>,
/// Aligned dist scratch, must be at least diskann MAX_DEGREE
pub aligned_dist_scratch: AlignedBoxWithSlice<f32>,
/// Aligned pq coord scratch, must be at least [N_CHUNKS * MAX_DEGREE]
pub aligned_pq_coord_scratch: AlignedBoxWithSlice<u8>,
/// Rotated query
pub rotated_query: AlignedBoxWithSlice<f32>,
/// Aligned query float
pub aligned_query_float: AlignedBoxWithSlice<f32>,
}
impl PQScratch {
const ALIGNED_ALLOC_256: usize = 256;
/// Create a new pq scratch
pub fn new(graph_degree: usize, aligned_dim: usize) -> ANNResult<Self> {
let aligned_pq_coord_scratch =
AlignedBoxWithSlice::new(graph_degree * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?;
let aligned_pqtable_dist_scratch =
AlignedBoxWithSlice::new(256 * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?;
let aligned_dist_scratch =
AlignedBoxWithSlice::new(graph_degree, PQScratch::ALIGNED_ALLOC_256)?;
let aligned_query_float = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::<f32>())?;
let rotated_query = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::<f32>())?;
Ok(Self {
aligned_pqtable_dist_scratch,
aligned_dist_scratch,
aligned_pq_coord_scratch,
rotated_query,
aligned_query_float,
})
}
/// Set rotated_query and aligned_query_float values
pub fn set<T>(&mut self, dim: usize, query: &[T], norm: f32)
where
T: Into<f32> + Copy,
{
for (d, item) in query.iter().enumerate().take(dim) {
let query_val: f32 = (*item).into();
if (norm - 1.0).abs() > f32::EPSILON {
self.rotated_query[d] = query_val / norm;
self.aligned_query_float[d] = query_val / norm;
} else {
self.rotated_query[d] = query_val;
self.aligned_query_float[d] = query_val;
}
}
}
}
#[cfg(test)]
mod tests {
use crate::model::PQScratch;
#[test]
fn test_pq_scratch() {
let graph_degree = 512;
let aligned_dim = 8;
let mut pq_scratch: PQScratch = PQScratch::new(graph_degree, aligned_dim).unwrap();
// Check alignment
assert_eq!(
(pq_scratch.aligned_pqtable_dist_scratch.as_ptr() as usize) % 256,
0
);
assert_eq!((pq_scratch.aligned_dist_scratch.as_ptr() as usize) % 256, 0);
assert_eq!(
(pq_scratch.aligned_pq_coord_scratch.as_ptr() as usize) % 256,
0
);
assert_eq!((pq_scratch.rotated_query.as_ptr() as usize) % 32, 0);
assert_eq!((pq_scratch.aligned_query_float.as_ptr() as usize) % 32, 0);
// Test set() method
let query = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
let norm = 2.0f32;
pq_scratch.set::<u8>(query.len(), &query, norm);
(0..query.len()).for_each(|i| {
assert_eq!(pq_scratch.rotated_query[i], query[i] as f32 / norm);
assert_eq!(pq_scratch.aligned_query_float[i], query[i] as f32 / norm);
});
}
}

View File

@@ -0,0 +1,84 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use crate::common::ANNResult;
use super::ArcConcurrentBoxedQueue;
use super::scratch_traits::Scratch;
use std::time::Duration;
pub struct ScratchStoreManager<T: Scratch> {
scratch: Option<Box<T>>,
scratch_pool: ArcConcurrentBoxedQueue<T>,
}
impl<T: Scratch> ScratchStoreManager<T> {
pub fn new(scratch_pool: ArcConcurrentBoxedQueue<T>, wait_time: Duration) -> ANNResult<Self> {
let mut scratch = scratch_pool.pop()?;
while scratch.is_none() {
scratch_pool.wait_for_push_notify(wait_time)?;
scratch = scratch_pool.pop()?;
}
Ok(ScratchStoreManager {
scratch,
scratch_pool,
})
}
pub fn scratch_space(&mut self) -> Option<&mut T> {
self.scratch.as_deref_mut()
}
}
impl<T: Scratch> Drop for ScratchStoreManager<T> {
fn drop(&mut self) {
if let Some(mut scratch) = self.scratch.take() {
scratch.clear();
let _ = self.scratch_pool.push(scratch);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug)]
struct MyScratch {
data: Vec<i32>,
}
impl Scratch for MyScratch {
fn clear(&mut self) {
self.data.clear();
}
}
#[test]
fn test_scratch_store_manager() {
let wait_time = Duration::from_millis(100);
let scratch_pool = ArcConcurrentBoxedQueue::new();
for i in 1..3 {
scratch_pool.push(Box::new(MyScratch {
data: vec![i, 2 * i, 3 * i],
})).unwrap();
}
let mut manager = ScratchStoreManager::new(scratch_pool.clone(), wait_time).unwrap();
let scratch_space = manager.scratch_space().unwrap();
assert_eq!(scratch_space.data, vec![1, 2, 3]);
// At this point, the ScratchStoreManager will go out of scope,
// causing the Drop implementation to be called, which should
// call the clear method on MyScratch.
drop(manager);
let current_scratch = scratch_pool.pop().unwrap().unwrap();
assert_eq!(current_scratch.data, vec![2, 4, 6]);
}
}

View File

@@ -0,0 +1,8 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub trait Scratch {
fn clear(&mut self);
}

View File

@@ -0,0 +1,38 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete.
use crate::common::ANNError;
use platform::{FileHandle, IOCompletionPort};
// The IOContext struct for disk I/O. One for each thread.
pub struct IOContext {
pub status: Status,
pub file_handle: FileHandle,
pub io_completion_port: IOCompletionPort,
}
impl Default for IOContext {
fn default() -> Self {
IOContext {
status: Status::ReadWait,
file_handle: FileHandle::default(),
io_completion_port: IOCompletionPort::default(),
}
}
}
impl IOContext {
pub fn new() -> Self {
Self::default()
}
}
pub enum Status {
ReadWait,
ReadSuccess,
ReadFailed(ANNError),
ProcessComplete,
}

View File

@@ -0,0 +1,132 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete.
use std::mem;
use std::vec::Vec;
use hashbrown::HashSet;
use crate::{
common::{ANNResult, AlignedBoxWithSlice},
model::{Neighbor, NeighborPriorityQueue},
model::data_store::DiskScratchDataset,
};
use super::{PQScratch, Scratch, MAX_GRAPH_DEGREE, QUERY_ALIGNMENT_OF_T_SIZE};
// Scratch space for disk index based search.
pub struct SSDQueryScratch<T: Default + Copy, const N: usize> {
// Disk scratch dataset storing fp vectors with aligned dim (N)
pub scratch_dataset: DiskScratchDataset<T, N>,
// The query scratch.
pub query: AlignedBoxWithSlice<T>,
/// The PQ Scratch.
pub pq_scratch: Option<Box<PQScratch>>,
// The visited set.
pub id_scratch: HashSet<u32>,
/// Best candidates, whose size is candidate_queue_size
pub best_candidates: NeighborPriorityQueue,
// Full return set.
pub full_return_set: Vec<Neighbor>,
}
impl<T: Copy + Default, const N: usize> SSDQueryScratch<T, N> {
pub fn new(
visited_reserve: usize,
candidate_queue_size: usize,
init_pq_scratch: bool,
) -> ANNResult<Self> {
let scratch_dataset = DiskScratchDataset::<T, N>::new()?;
let query = AlignedBoxWithSlice::<T>::new(N, mem::size_of::<T>() * QUERY_ALIGNMENT_OF_T_SIZE)?;
let id_scratch = HashSet::<u32>::with_capacity(visited_reserve);
let full_return_set = Vec::<Neighbor>::with_capacity(visited_reserve);
let best_candidates = NeighborPriorityQueue::with_capacity(candidate_queue_size);
let pq_scratch = if init_pq_scratch {
Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?))
} else {
None
};
Ok(Self {
scratch_dataset,
query,
pq_scratch,
id_scratch,
best_candidates,
full_return_set,
})
}
pub fn pq_scratch(&mut self) -> &Option<Box<PQScratch>> {
&self.pq_scratch
}
}
impl<T: Default + Copy, const N: usize> Scratch for SSDQueryScratch<T, N> {
fn clear(&mut self) {
self.id_scratch.clear();
self.best_candidates.clear();
self.full_return_set.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new() {
// Arrange
let visited_reserve = 100;
let candidate_queue_size = 10;
let init_pq_scratch = true;
// Act
let result =
SSDQueryScratch::<u32, 3>::new(visited_reserve, candidate_queue_size, init_pq_scratch);
// Assert
assert!(result.is_ok());
let scratch = result.unwrap();
// Assert the properties of the scratch instance
assert!(scratch.pq_scratch.is_some());
assert!(scratch.id_scratch.is_empty());
assert!(scratch.best_candidates.size() == 0);
assert!(scratch.full_return_set.is_empty());
}
#[test]
fn test_clear() {
// Arrange
let mut scratch = SSDQueryScratch::<u32, 3>::new(100, 10, true).unwrap();
// Add some data to scratch fields
scratch.id_scratch.insert(1);
scratch.best_candidates.insert(Neighbor::new(2, 0.5));
scratch.full_return_set.push(Neighbor::new(3, 0.8));
// Act
scratch.clear();
// Assert
assert!(scratch.id_scratch.is_empty());
assert!(scratch.best_candidates.size() == 0);
assert!(scratch.full_return_set.is_empty());
}
}

View File

@@ -0,0 +1,92 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete.
use std::sync::Arc;
use super::{scratch_traits::Scratch, IOContext, SSDQueryScratch};
use crate::common::ANNResult;
// The thread data struct for SSD I/O. One for each thread, contains the ScratchSpace and the IOContext.
pub struct SSDThreadData<T: Default + Copy, const N: usize> {
pub scratch: SSDQueryScratch<T, N>,
pub io_context: Option<Arc<IOContext>>,
}
impl<T: Default + Copy, const N: usize> SSDThreadData<T, N> {
pub fn new(
visited_reserve: usize,
candidate_queue_size: usize,
init_pq_scratch: bool,
) -> ANNResult<Self> {
// Parameters are forwarded in the order SSDQueryScratch::new expects:
// (visited_reserve, candidate_queue_size, init_pq_scratch).
let scratch = SSDQueryScratch::new(visited_reserve, candidate_queue_size, init_pq_scratch)?;
Ok(SSDThreadData {
scratch,
io_context: None,
})
}
pub fn clear(&mut self) {
self.scratch.clear();
}
}
#[cfg(test)]
mod tests {
use crate::model::Neighbor;
use super::*;
#[test]
fn test_new() {
// Arrange
let visited_reserve = 100;
let candidate_queue_size = 10;
let init_pq_scratch = true;
// Act
let result =
SSDThreadData::<u32, 3>::new(visited_reserve, candidate_queue_size, init_pq_scratch);
// Assert
assert!(result.is_ok());
let thread_data = result.unwrap();
// Assert the properties of the thread data instance
assert!(thread_data.io_context.is_none());
let scratch = &thread_data.scratch;
// Assert the properties of the scratch instance
assert!(scratch.pq_scratch.is_some());
assert!(scratch.id_scratch.is_empty());
assert!(scratch.best_candidates.size() == 0);
assert!(scratch.full_return_set.is_empty());
}
#[test]
fn test_clear() {
// Arrange
let mut thread_data = SSDThreadData::<u32, 3>::new(10, 100, true).unwrap();
// Add some data to scratch fields
thread_data.scratch.id_scratch.insert(1);
thread_data
.scratch
.best_candidates
.insert(Neighbor::new(2, 0.5));
thread_data
.scratch
.full_return_set
.push(Neighbor::new(3, 0.8));
// Act
thread_data.clear();
// Assert
assert!(thread_data.scratch.id_scratch.is_empty());
assert!(thread_data.scratch.best_candidates.size() == 0);
assert!(thread_data.scratch.full_return_set.is_empty());
}
}

View File

@@ -0,0 +1,22 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Vertex dimension
/// 32 vertex dimension
pub const DIM_32: usize = 32;
/// 64 vertex dimension
pub const DIM_64: usize = 64;
/// 104 vertex dimension
pub const DIM_104: usize = 104;
/// 128 vertex dimension
pub const DIM_128: usize = 128;
/// 256 vertex dimension
pub const DIM_256: usize = 256;

View File

@@ -0,0 +1,10 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod vertex;
pub use vertex::Vertex;
mod dimension;
pub use dimension::*;

View File

@@ -0,0 +1,68 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Vertex
use std::array::TryFromSliceError;
use vector::{FullPrecisionDistance, Metric};
/// Vertex with data type T and dimension N
#[derive(Debug)]
pub struct Vertex<'a, T, const N: usize>
where
[T; N]: FullPrecisionDistance<T, N>,
{
/// Vertex value
val: &'a [T; N],
/// Vertex Id
id: u32,
}
impl<'a, T, const N: usize> Vertex<'a, T, N>
where
[T; N]: FullPrecisionDistance<T, N>,
{
/// Create the vertex with data
pub fn new(val: &'a [T; N], id: u32) -> Self {
Self {
val,
id,
}
}
/// Compare the vertex with another.
#[inline(always)]
pub fn compare(&self, other: &Vertex<'a, T, N>, metric: Metric) -> f32 {
<[T; N]>::distance_compare(self.val, other.val, metric)
}
/// Get the vector associated with the vertex.
#[inline]
pub fn vector(&self) -> &[T; N] {
self.val
}
/// Get the vertex id.
#[inline]
pub fn vertex_id(&self) -> u32 {
self.id
}
}
impl<'a, T, const N: usize> TryFrom<(&'a [T], u32)> for Vertex<'a, T, N>
where
[T; N]: FullPrecisionDistance<T, N>,
{
type Error = TryFromSliceError;
fn try_from((mem_slice, id): (&'a [T], u32)) -> Result<Self, Self::Error> {
let array: &[T; N] = mem_slice.try_into()?;
Ok(Vertex::new(array, id))
}
}
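// Illustrative sketch (added for this edit, not part of the original API):
// constructing vertices from raw slices and comparing them. Assumes the
// `vector` crate implements `FullPrecisionDistance` for `[f32; 256]`.
#[cfg(test)]
#[allow(dead_code)]
fn vertex_usage_sketch() {
let a = [0.0f32; 256];
let b = [2.0f32; 256];
let va = Vertex::<f32, 256>::try_from((&a[..], 0)).unwrap();
let vb = Vertex::<f32, 256>::try_from((&b[..], 1)).unwrap();
// The exact scale (squared L2 or not) follows the vector crate's metric.
assert!(va.compare(&vb, Metric::L2) > 0.0);
}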

View File

@@ -0,0 +1,7 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[allow(clippy::module_inception)]
mod windows_aligned_file_reader;
pub use windows_aligned_file_reader::*;

View File

@@ -0,0 +1,414 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::sync::Arc;
use std::time::Duration;
use std::{ptr, thread};
use crossbeam::sync::ShardedLock;
use hashbrown::HashMap;
use once_cell::sync::Lazy;
use platform::file_handle::{AccessMode, ShareMode};
use platform::{
file_handle::FileHandle,
file_io::{get_queued_completion_status, read_file_to_slice},
io_completion_port::IOCompletionPort,
};
use winapi::{
shared::{basetsd::ULONG_PTR, minwindef::DWORD},
um::minwinbase::OVERLAPPED,
};
use crate::common::{ANNError, ANNResult};
use crate::model::IOContext;
pub const MAX_IO_CONCURRENCY: usize = 128; // To do: explore the optimal value for this. The current value is taken from C++ code.
pub const FILE_ATTRIBUTE_READONLY: DWORD = 0x00000001;
pub const IO_COMPLETION_TIMEOUT: DWORD = u32::MAX; // Infinite timeout.
pub const DISK_IO_ALIGNMENT: usize = 512;
pub const ASYNC_IO_COMPLETION_CHECK_INTERVAL: Duration = Duration::from_micros(5);
/// Aligned read request for disk IO. It mutably borrows the aligned buffer for the duration of the read and exposes the data immutably via aligned_buf().
pub struct AlignedRead<'a, T> {
/// where to read from
/// offset needs to be aligned with DISK_IO_ALIGNMENT
offset: u64,
/// where to read into
/// aligned_buf and its len need to be aligned with DISK_IO_ALIGNMENT
aligned_buf: &'a mut [T],
}
impl<'a, T> AlignedRead<'a, T> {
pub fn new(offset: u64, aligned_buf: &'a mut [T]) -> ANNResult<Self> {
Self::assert_is_aligned(offset as usize)?;
Self::assert_is_aligned(std::mem::size_of_val(aligned_buf))?;
Ok(Self {
offset,
aligned_buf,
})
}
fn assert_is_aligned(val: usize) -> ANNResult<()> {
match val % DISK_IO_ALIGNMENT {
0 => Ok(()),
_ => Err(ANNError::log_disk_io_request_alignment_error(format!(
"The offset or length of AlignedRead request is not {} bytes aligned",
DISK_IO_ALIGNMENT
))),
}
}
pub fn aligned_buf(&self) -> &[T] {
self.aligned_buf
}
}
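// Illustrative sketch (added for this edit, not part of the original API):
// the constructor rejects requests whose offset or byte length is not
// DISK_IO_ALIGNMENT-aligned. Note that only offset and length are validated
// here; pointer alignment of the buffer itself is the caller's job (e.g. via
// AlignedBoxWithSlice).
#[cfg(test)]
#[allow(dead_code)]
fn alignment_check_sketch() {
let mut buf = vec![0u8; DISK_IO_ALIGNMENT]; // 512 bytes long
assert!(AlignedRead::new(0, &mut buf[..]).is_ok());
assert!(AlignedRead::new(1, &mut buf[..]).is_err()); // misaligned offset
}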
pub struct WindowsAlignedFileReader {
file_name: String,
// ctx_map is the mapping from thread id to io context. It is hashmap behind a sharded lock to allow concurrent access from multiple threads.
// ShardedLock: shardedlock provides an implementation of a reader-writer lock that offers concurrent read access to the shared data while allowing exclusive write access.
// It achieves better scalability by dividing the shared data into multiple shards, and each with its own internal lock.
// Multiple threads can read from different shards simultaneously, reducing contention.
// https://docs.rs/crossbeam/0.8.2/crossbeam/sync/struct.ShardedLock.html
// Comparing to RwLock, ShardedLock provides higher concurrency for read operations and is suitable for read heavy workloads.
// The value of the hashmap is an Arc<IOContext> to allow immutable access to IOContext with automatic reference counting.
ctx_map: Lazy<ShardedLock<HashMap<thread::ThreadId, Arc<IOContext>>>>,
}
impl WindowsAlignedFileReader {
pub fn new(fname: &str) -> ANNResult<Self> {
let reader: WindowsAlignedFileReader = WindowsAlignedFileReader {
file_name: fname.to_string(),
ctx_map: Lazy::new(|| ShardedLock::new(HashMap::new())),
};
reader.register_thread()?;
Ok(reader)
}
// Register the io context for a thread if it hasn't been registered.
pub fn register_thread(&self) -> ANNResult<()> {
let mut ctx_map = self.ctx_map.write().map_err(|_| {
ANNError::log_lock_poison_error("unable to acquire write lock on ctx_map".to_string())
})?;
let id = thread::current().id();
if ctx_map.contains_key(&id) {
println!(
"Warning: duplicate registration for thread_id {:?}. Call get_ctx directly to get the thread context data.",
id);
return Ok(());
}
let mut ctx = IOContext::new();
match unsafe { FileHandle::new(&self.file_name, AccessMode::Read, ShareMode::Read) } {
Ok(file_handle) => ctx.file_handle = file_handle,
Err(err) => {
return Err(ANNError::log_io_error(err));
}
}
// Create a io completion port for the file handle, later it will be used to get the completion status.
match IOCompletionPort::new(&ctx.file_handle, None, 0, 0) {
Ok(io_completion_port) => ctx.io_completion_port = io_completion_port,
Err(err) => {
return Err(ANNError::log_io_error(err));
}
}
ctx_map.insert(id, Arc::new(ctx));
Ok(())
}
// Get the reference counted io context for the current thread.
pub fn get_ctx(&self) -> ANNResult<Arc<IOContext>> {
let ctx_map = self.ctx_map.read().map_err(|_| {
ANNError::log_lock_poison_error("unable to acquire read lock on ctx_map".to_string())
})?;
let id = thread::current().id();
match ctx_map.get(&id) {
Some(ctx) => Ok(Arc::clone(ctx)),
None => Err(ANNError::log_index_error(format!(
"unable to find IOContext for thread_id {:?}",
id
))),
}
}
// Read the data from the file by sending concurrent io requests in batches.
pub fn read<T>(&self, read_requests: &mut [AlignedRead<T>], ctx: &IOContext) -> ANNResult<()> {
let n_requests = read_requests.len();
let n_batches = (n_requests + MAX_IO_CONCURRENCY - 1) / MAX_IO_CONCURRENCY;
let mut overlapped_in_out =
vec![unsafe { std::mem::zeroed::<OVERLAPPED>() }; MAX_IO_CONCURRENCY];
for batch_idx in 0..n_batches {
let batch_start = MAX_IO_CONCURRENCY * batch_idx;
let batch_size = std::cmp::min(n_requests - batch_start, MAX_IO_CONCURRENCY);
for j in 0..batch_size {
let req = &mut read_requests[batch_start + j];
let os = &mut overlapped_in_out[j];
match unsafe {
read_file_to_slice(&ctx.file_handle, req.aligned_buf, os, req.offset)
} {
Ok(_) => {}
Err(error) => {
return Err(ANNError::IOError { err: (error) });
}
}
}
let mut n_read: DWORD = 0;
let mut n_complete: u64 = 0;
let mut completion_key: ULONG_PTR = 0;
let mut lp_os: *mut OVERLAPPED = ptr::null_mut();
while n_complete < batch_size as u64 {
match unsafe {
get_queued_completion_status(
&ctx.io_completion_port,
&mut n_read,
&mut completion_key,
&mut lp_os,
IO_COMPLETION_TIMEOUT,
)
} {
// An IO request completed.
Ok(true) => n_complete += 1,
// No IO request completed, continue to wait.
Ok(false) => {
thread::sleep(ASYNC_IO_COMPLETION_CHECK_INTERVAL);
}
// An error occurred.
Err(error) => return Err(ANNError::IOError { err: (error) }),
}
}
}
Ok(())
}
}
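// Illustrative sketch (added for this edit, not part of the original API):
// per-thread setup for concurrent reads. Each worker thread calls
// register_thread() once to get its own FileHandle and IO completion port
// before issuing reads. The single 512-byte read below is a hypothetical
// minimal request; slice splitting follows the pattern in the tests below.
#[cfg(test)]
#[allow(dead_code)]
fn per_thread_read_sketch(reader: &WindowsAlignedFileReader) -> ANNResult<()> {
reader.register_thread()?; // warns (and is a no-op) if already registered
let ctx = reader.get_ctx()?;
let mut buf =
crate::common::AlignedBoxWithSlice::<u8>::new(DISK_IO_ALIGNMENT, DISK_IO_ALIGNMENT)?;
let mut slices =
buf.split_into_nonoverlapping_mut_slices(0..buf.len(), DISK_IO_ALIGNMENT)?;
let mut requests = vec![AlignedRead::new(0, slices.remove(0))?];
reader.read(&mut requests, &ctx)
}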
#[cfg(test)]
mod tests {
use std::{fs::File, io::BufReader};
use bincode::deserialize_from;
use serde::{Deserialize, Serialize};
use crate::{common::AlignedBoxWithSlice, model::SECTOR_LEN};
use super::*;
pub const TEST_INDEX_PATH: &str =
"./tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index";
pub const TRUTH_NODE_DATA_PATH: &str =
"./tests/data/disk_index_node_data_aligned_reader_truth.bin";
#[derive(Debug, Serialize, Deserialize)]
struct NodeData {
num_neighbors: u32,
coordinates: Vec<f32>,
neighbors: Vec<u32>,
}
impl PartialEq for NodeData {
fn eq(&self, other: &Self) -> bool {
self.num_neighbors == other.num_neighbors
&& self.coordinates == other.coordinates
&& self.neighbors == other.neighbors
}
}
#[test]
fn test_new_aligned_file_reader() {
// Replace "test_file_path" with actual file path
let result = WindowsAlignedFileReader::new(TEST_INDEX_PATH);
assert!(result.is_ok());
let reader = result.unwrap();
assert_eq!(reader.file_name, TEST_INDEX_PATH);
}
#[test]
fn test_read() {
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
let ctx = reader.get_ctx().unwrap();
        let read_length = 512; // bytes per read request
let num_read = 10;
let mut aligned_mem = AlignedBoxWithSlice::<u8>::new(read_length * num_read, 512).unwrap();
// create and add AlignedReads to the vector
let mut mem_slices = aligned_mem
.split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length)
.unwrap();
let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
.iter_mut()
.enumerate()
.map(|(i, slice)| {
let offset = (i * read_length) as u64;
AlignedRead::new(offset, slice).unwrap()
})
.collect();
let result = reader.read(&mut aligned_reads, &ctx);
assert!(result.is_ok());
}
#[test]
fn test_read_disk_index_by_sector() {
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
let ctx = reader.get_ctx().unwrap();
        let read_length = SECTOR_LEN; // one sector per read request
let num_sector = 10;
let mut aligned_mem =
AlignedBoxWithSlice::<u8>::new(read_length * num_sector, 512).unwrap();
// Each slice will be used as the buffer for a read request of a sector.
let mut mem_slices = aligned_mem
.split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length)
.unwrap();
let mut aligned_reads: Vec<AlignedRead<'_, u8>> = mem_slices
.iter_mut()
.enumerate()
.map(|(sector_id, slice)| {
let offset = (sector_id * read_length) as u64;
AlignedRead::new(offset, slice).unwrap()
})
.collect();
let result = reader.read(&mut aligned_reads, &ctx);
assert!(result.is_ok());
aligned_reads.iter().for_each(|read| {
assert_eq!(read.aligned_buf.len(), SECTOR_LEN);
});
let disk_layout_meta = reconstruct_disk_meta(aligned_reads[0].aligned_buf);
assert!(disk_layout_meta.len() > 9);
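        // Meta entries (see create_disk_layout): [num_pts, dims, medoid, max_node_len,
        // num_nodes_per_sector, vamana_frozen_num, vamana_frozen_loc, append_reorder_data, file_size].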
let dims = disk_layout_meta[1];
let num_pts = disk_layout_meta[0];
let max_node_len = disk_layout_meta[3];
let max_num_nodes_per_sector = disk_layout_meta[4];
assert!(max_node_len * max_num_nodes_per_sector < SECTOR_LEN as u64);
let num_nbrs_start = (dims as usize) * std::mem::size_of::<f32>();
let nbrs_buf_start = num_nbrs_start + std::mem::size_of::<u32>();
let mut node_data_array = Vec::with_capacity(max_num_nodes_per_sector as usize * 9);
        // Only validate sectors 1..=8, which hold graph nodes (sector 0 is the metadata sector).
(1..9).for_each(|sector_id| {
let sector_data = &mem_slices[sector_id];
for node_data in sector_data.chunks_exact(max_node_len as usize) {
// Extract coordinates data from the start of the node_data
let coordinates_end = (dims as usize) * std::mem::size_of::<f32>();
let coordinates = node_data[0..coordinates_end]
.chunks_exact(std::mem::size_of::<f32>())
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
.collect();
// Extract number of neighbors from the node_data
let neighbors_num = u32::from_le_bytes(
node_data[num_nbrs_start..nbrs_buf_start]
.try_into()
.unwrap(),
);
let nbors_buf_end =
nbrs_buf_start + (neighbors_num as usize) * std::mem::size_of::<u32>();
// Extract neighbors from the node data.
let mut neighbors = Vec::new();
for nbors_data in node_data[nbrs_buf_start..nbors_buf_end]
.chunks_exact(std::mem::size_of::<u32>())
{
let nbors_id = u32::from_le_bytes(nbors_data.try_into().unwrap());
assert!(nbors_id < num_pts as u32);
neighbors.push(nbors_id);
}
// Create NodeData struct and push it to the node_data_array
node_data_array.push(NodeData {
num_neighbors: neighbors_num,
coordinates,
neighbors,
});
}
});
        // Verify that each node read from the disk index matches the truth data.
let node_data_truth_file = File::open(TRUTH_NODE_DATA_PATH).unwrap();
let reader = BufReader::new(node_data_truth_file);
let node_data_vec: Vec<NodeData> = deserialize_from(reader).unwrap();
for (node_from_node_data_file, node_from_disk_index) in
node_data_vec.iter().zip(node_data_array.iter())
{
// Verify that the NodeData from the file is equal to the NodeData in node_data_array
assert_eq!(node_from_node_data_file, node_from_disk_index);
}
}
#[test]
fn test_read_fail_invalid_file() {
let reader = WindowsAlignedFileReader::new("/invalid_path");
assert!(reader.is_err());
}
#[test]
fn test_read_no_requests() {
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
let ctx = reader.get_ctx().unwrap();
let mut read_requests = Vec::<AlignedRead<u8>>::new();
let result = reader.read(&mut read_requests, &ctx);
assert!(result.is_ok());
}
#[test]
fn test_get_ctx() {
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
let result = reader.get_ctx();
assert!(result.is_ok());
}
#[test]
fn test_register_thread() {
let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap();
let result = reader.register_thread();
assert!(result.is_ok());
}
fn reconstruct_disk_meta(buffer: &[u8]) -> Vec<u64> {
let size_of_u64 = std::mem::size_of::<u64>();
let num_values = buffer.len() / size_of_u64;
let mut disk_layout_meta = Vec::with_capacity(num_values);
let meta_data = &buffer[8..];
for chunk in meta_data.chunks_exact(size_of_u64) {
let value = u64::from_le_bytes(chunk.try_into().unwrap());
disk_layout_meta.push(value);
}
disk_layout_meta
}
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_docs)]
//! Disk graph storage
use std::sync::Arc;
use crate::{model::{WindowsAlignedFileReader, IOContext, AlignedRead}, common::ANNResult};
/// Graph storage for disk index
/// One thread has one storage instance
pub struct DiskGraphStorage {
/// Disk graph reader
disk_graph_reader: Arc<WindowsAlignedFileReader>,
/// IOContext of current thread
ctx: Arc<IOContext>,
}
impl DiskGraphStorage {
/// Create a new DiskGraphStorage instance
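    ///
    /// # Example (hypothetical sketch; the index path is illustrative)
    /// ```ignore
    /// let reader = Arc::new(WindowsAlignedFileReader::new("sample_disk.index")?);
    /// let storage = DiskGraphStorage::new(reader)?;
    /// // storage.read(&mut read_requests)? then serves reads on this thread's IOContext
    /// ```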
pub fn new(disk_graph_reader: Arc<WindowsAlignedFileReader>) -> ANNResult<Self> {
let ctx = disk_graph_reader.get_ctx()?;
Ok(Self {
disk_graph_reader,
ctx,
})
}
/// Read disk graph data
pub fn read<T>(&self, read_requests: &mut [AlignedRead<T>]) -> ANNResult<()> {
self.disk_graph_reader.read(read_requests, &self.ctx)
}
}

View File

@@ -0,0 +1,363 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
use std::fs::File;
use std::io::Read;
use std::marker::PhantomData;
use std::{fs, mem};
use crate::common::{ANNError, ANNResult};
use crate::model::NUM_PQ_CENTROIDS;
use crate::storage::PQStorage;
use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, save_bin_u64};
use crate::utils::{
file_exists, gen_sample_data, get_file_size, round_up, CachedReader, CachedWriter,
};
const SECTOR_LEN: usize = 4096;
/// Todo: Remove the allow(dead_code) when the disk search code is complete
#[allow(dead_code)]
pub struct PQPivotData {
dim: usize,
pq_table: Vec<f32>,
centroids: Vec<f32>,
chunk_offsets: Vec<usize>,
}
pub struct DiskIndexStorage<T> {
/// Dataset file
dataset_file: String,
/// Index file path prefix
index_path_prefix: String,
// TODO: Only a placeholder for T, will be removed later
_marker: PhantomData<T>,
pq_storage: PQStorage,
}
impl<T> DiskIndexStorage<T> {
/// Create DiskIndexStorage instance
pub fn new(dataset_file: String, index_path_prefix: String) -> ANNResult<Self> {
let pq_storage: PQStorage = PQStorage::new(
&(index_path_prefix.clone() + ".bin_pq_pivots.bin"),
&(index_path_prefix.clone() + ".bin_pq_compressed.bin"),
&dataset_file,
)?;
Ok(DiskIndexStorage {
dataset_file,
index_path_prefix,
_marker: PhantomData,
pq_storage,
})
}
pub fn get_pq_storage(&mut self) -> &mut PQStorage {
&mut self.pq_storage
}
pub fn dataset_file(&self) -> &String {
&self.dataset_file
}
pub fn index_path_prefix(&self) -> &String {
&self.index_path_prefix
}
/// Create disk layout
/// Sector #1: disk_layout_meta
/// Sector #n: num_nodes_per_sector nodes
/// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]}
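    /// e.g. (illustrative numbers): with T = f32, dims = 128 and max_degree = 4,
    /// max_node_len = (4 + 1) * 4 + 128 * 4 = 532B, so num_nodes_per_sector = 4096 / 532 = 7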
/// # Arguments
/// * `dataset_file` - dataset file containing full precision vectors
/// * `mem_index_file` - in-memory index graph file
/// * `disk_layout_file` - output disk layout file
pub fn create_disk_layout(&self) -> ANNResult<()> {
let mem_index_file = self.mem_index_file();
let disk_layout_file = self.disk_index_file();
// amount to read or write in one shot
let read_blk_size = 64 * 1024 * 1024;
let write_blk_size = read_blk_size;
let mut dataset_reader = CachedReader::new(self.dataset_file.as_str(), read_blk_size)?;
let num_pts = dataset_reader.read_u32()? as u64;
let dims = dataset_reader.read_u32()? as u64;
// Create cached reader + writer
let actual_file_size = get_file_size(mem_index_file.as_str())?;
println!("Vamana index file size={}", actual_file_size);
let mut vamana_reader = File::open(mem_index_file)?;
let mut diskann_writer = CachedWriter::new(disk_layout_file.as_str(), write_blk_size)?;
let index_file_size = vamana_reader.read_u64::<LittleEndian>()?;
if index_file_size != actual_file_size {
            println!(
                "Vamana index file size does not match the size recorded in its metadata. size from metadata: {}, actual file size: {}",
                index_file_size, actual_file_size
            );
}
let max_degree = vamana_reader.read_u32::<LittleEndian>()?;
let medoid = vamana_reader.read_u32::<LittleEndian>()?;
let vamana_frozen_num = vamana_reader.read_u64::<LittleEndian>()?;
let mut vamana_frozen_loc = 0;
if vamana_frozen_num == 1 {
vamana_frozen_loc = medoid;
}
let max_node_len = ((max_degree as u64 + 1) * (mem::size_of::<u32>() as u64))
+ (dims * (mem::size_of::<T>() as u64));
let num_nodes_per_sector = (SECTOR_LEN as u64) / max_node_len;
println!("medoid: {}B", medoid);
println!("max_node_len: {}B", max_node_len);
println!("num_nodes_per_sector: {}B", num_nodes_per_sector);
// SECTOR_LEN buffer for each sector
let mut sector_buf = vec![0u8; SECTOR_LEN];
let mut node_buf = vec![0u8; max_node_len as usize];
let num_nbrs_start = (dims as usize) * mem::size_of::<T>();
let nbrs_buf_start = num_nbrs_start + mem::size_of::<u32>();
// number of sectors (1 for meta data)
let num_sectors = round_up(num_pts, num_nodes_per_sector) / num_nodes_per_sector;
let disk_index_file_size = (num_sectors + 1) * (SECTOR_LEN as u64);
let disk_layout_meta = vec![
num_pts,
dims,
medoid as u64,
max_node_len,
num_nodes_per_sector,
vamana_frozen_num,
vamana_frozen_loc as u64,
// append_reorder_data
// We are not supporting this. Temporarily write it into the layout so that
// we can leverage C++ query driver to test the disk index
false as u64,
disk_index_file_size,
];
diskann_writer.write(&sector_buf)?;
let mut cur_node_coords = vec![0u8; (dims as usize) * mem::size_of::<T>()];
let mut cur_node_id = 0u64;
for sector in 0..num_sectors {
if sector % 100_000 == 0 {
println!("Sector #{} written", sector);
}
sector_buf.fill(0);
for sector_node_id in 0..num_nodes_per_sector {
if cur_node_id >= num_pts {
break;
}
node_buf.fill(0);
// read cur node's num_nbrs
let num_nbrs = vamana_reader.read_u32::<LittleEndian>()?;
// sanity checks on num_nbrs
debug_assert!(num_nbrs > 0);
debug_assert!(num_nbrs <= max_degree);
// write coords of node first
dataset_reader.read(&mut cur_node_coords)?;
node_buf[..cur_node_coords.len()].copy_from_slice(&cur_node_coords);
// write num_nbrs
LittleEndian::write_u32(
&mut node_buf[num_nbrs_start..(num_nbrs_start + mem::size_of::<u32>())],
num_nbrs,
);
// write neighbors
let nbrs_buf = &mut node_buf[nbrs_buf_start
..(nbrs_buf_start + (num_nbrs as usize) * mem::size_of::<u32>())];
vamana_reader.read_exact(nbrs_buf)?;
// get offset into sector_buf
let sector_node_buf_start = (sector_node_id * max_node_len) as usize;
let sector_node_buf = &mut sector_buf
[sector_node_buf_start..(sector_node_buf_start + max_node_len as usize)];
sector_node_buf.copy_from_slice(&node_buf[..(max_node_len as usize)]);
cur_node_id += 1;
}
// flush sector to disk
diskann_writer.write(&sector_buf)?;
}
diskann_writer.flush()?;
save_bin_u64(
disk_layout_file.as_str(),
&disk_layout_meta,
disk_layout_meta.len(),
1,
0,
)?;
Ok(())
}
pub fn index_build_cleanup(&self) -> ANNResult<()> {
fs::remove_file(self.mem_index_file())?;
Ok(())
}
pub fn gen_query_warmup_data(&self, sampling_rate: f64) -> ANNResult<()> {
gen_sample_data::<T>(
&self.dataset_file,
&self.warmup_query_prefix(),
sampling_rate,
)?;
Ok(())
}
/// Load pre-trained pivot table
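    ///
    /// The pivot file stores, in order: an offset table (4 x u64), the pivot table
    /// (NUM_PQ_CENTROIDS x dim, f32), the centroid (dim x 1, f32), and the chunk
    /// offsets ((num_pq_chunks + 1) x 1, u32).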
pub fn load_pq_pivots_bin(
&self,
num_pq_chunks: &usize,
) -> ANNResult<PQPivotData> {
let pq_pivots_path = &self.pq_pivot_file();
if !file_exists(pq_pivots_path) {
return Err(ANNError::log_pq_error(
"ERROR: PQ k-means pivot file not found.".to_string(),
));
}
let (data, offset_num, offset_dim) = load_bin::<u64>(pq_pivots_path, 0)?;
let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim);
if offset_num != 4 {
let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num);
return Err(ANNError::log_pq_error(error_message));
}
let (data, pivot_num, dim) = load_bin::<f32>(pq_pivots_path, file_offset_data[0])?;
let pq_table = data.to_vec();
if pivot_num != NUM_PQ_CENTROIDS {
let error_message = format!(
"Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.",
pq_pivots_path, pivot_num, NUM_PQ_CENTROIDS
);
return Err(ANNError::log_pq_error(error_message));
}
let (data, centroid_dim, nc) = load_bin::<f32>(pq_pivots_path, file_offset_data[1])?;
let centroids = data.to_vec();
if centroid_dim != dim || nc != 1 {
let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim);
return Err(ANNError::log_pq_error(error_message));
}
let (data, chunk_offset_num, nc) = load_bin::<u32>(pq_pivots_path, file_offset_data[2])?;
let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc);
if chunk_offset_num != num_pq_chunks + 1 || nc != 1 {
let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1);
return Err(ANNError::log_pq_error(error_message));
}
Ok(PQPivotData {
dim,
pq_table,
centroids,
chunk_offsets
})
}
fn mem_index_file(&self) -> String {
self.index_path_prefix.clone() + "_mem.index"
}
fn disk_index_file(&self) -> String {
self.index_path_prefix.clone() + "_disk.index"
}
fn warmup_query_prefix(&self) -> String {
self.index_path_prefix.clone() + "_sample"
}
pub fn pq_pivot_file(&self) -> String {
self.index_path_prefix.clone() + ".bin_pq_pivots.bin"
}
pub fn compressed_pq_pivot_file(&self) -> String {
self.index_path_prefix.clone() + ".bin_pq_compressed.bin"
}
}
#[cfg(test)]
mod disk_index_storage_test {
use std::fs;
use crate::test_utils::get_test_file_path;
use super::*;
const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
const DISK_INDEX_PATH_PREFIX: &str = "tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2";
const TRUTH_DISK_LAYOUT: &str =
"tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index";
#[test]
fn create_disk_layout_test() {
let storage = DiskIndexStorage::<f32>::new(
get_test_file_path(TEST_DATA_FILE),
get_test_file_path(DISK_INDEX_PATH_PREFIX),
).unwrap();
storage.create_disk_layout().unwrap();
let disk_layout_file = storage.disk_index_file();
let rust_disk_layout = fs::read(disk_layout_file.as_str()).unwrap();
let truth_disk_layout = fs::read(get_test_file_path(TRUTH_DISK_LAYOUT).as_str()).unwrap();
assert!(rust_disk_layout == truth_disk_layout);
fs::remove_file(disk_layout_file.as_str()).expect("Failed to delete file");
}
#[test]
fn load_pivot_test() {
let dim: usize = 128;
let num_pq_chunk: usize = 1;
let pivot_file_prefix: &str = "tests/data/siftsmall_learn";
let storage = DiskIndexStorage::<f32>::new(
get_test_file_path(TEST_DATA_FILE),
pivot_file_prefix.to_string(),
).unwrap();
let pq_pivot_data =
storage.load_pq_pivots_bin(&num_pq_chunk).unwrap();
assert_eq!(pq_pivot_data.pq_table.len(), NUM_PQ_CENTROIDS * dim);
assert_eq!(pq_pivot_data.centroids.len(), dim);
assert_eq!(pq_pivot_data.chunk_offsets[0], 0);
assert_eq!(pq_pivot_data.chunk_offsets[1], dim);
assert_eq!(pq_pivot_data.chunk_offsets.len(), num_pq_chunk + 1);
}
#[test]
#[should_panic(expected = "ERROR: PQ k-means pivot file not found.")]
fn load_pivot_file_not_exist_test() {
let num_pq_chunk: usize = 1;
let pivot_file_prefix: &str = "tests/data/siftsmall_learn_file_not_exist";
let storage = DiskIndexStorage::<f32>::new(
get_test_file_path(TEST_DATA_FILE),
pivot_file_prefix.to_string(),
).unwrap();
let _ = storage.load_pq_pivots_bin(&num_pq_chunk).unwrap();
}
}

View File

@@ -0,0 +1,12 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
mod disk_index_storage;
pub use disk_index_storage::*;
mod disk_graph_storage;
pub use disk_graph_storage::*;
mod pq_storage;
pub use pq_storage::*;

View File

@@ -0,0 +1,367 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use byteorder::{LittleEndian, ReadBytesExt};
use rand::distributions::{Distribution, Uniform};
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::mem;
use crate::common::{ANNError, ANNResult};
use crate::utils::CachedReader;
use crate::utils::{
convert_types_u32_usize, convert_types_u64_usize, convert_types_usize_u32,
convert_types_usize_u64, convert_types_usize_u8, save_bin_f32, save_bin_u32, save_bin_u64,
};
use crate::utils::{file_exists, load_bin, open_file_to_write, METADATA_SIZE};
#[derive(Debug)]
pub struct PQStorage {
/// Pivot table path
pivot_file: String,
/// Compressed pivot path
compressed_pivot_file: String,
/// Data used to construct PQ table and PQ compressed table
pq_data_file: String,
/// PQ data reader
pq_data_file_reader: File,
}
impl PQStorage {
pub fn new(
pivot_file: &str,
compressed_pivot_file: &str,
pq_data_file: &str,
) -> std::io::Result<Self> {
let pq_data_file_reader = File::open(pq_data_file)?;
Ok(Self {
pivot_file: pivot_file.to_string(),
compressed_pivot_file: compressed_pivot_file.to_string(),
pq_data_file: pq_data_file.to_string(),
pq_data_file_reader,
})
}
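    /// Write the compressed pivot file header: the number of points followed by the
    /// number of PQ chunks, both as little-endian i32.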
pub fn write_compressed_pivot_metadata(&self, npts: i32, pq_chunk: i32) -> std::io::Result<()> {
let mut writer = open_file_to_write(&self.compressed_pivot_file)?;
writer.write_all(&npts.to_le_bytes())?;
writer.write_all(&pq_chunk.to_le_bytes())?;
Ok(())
}
pub fn write_compressed_pivot_data(
&self,
compressed_base: &[usize],
num_centers: usize,
block_size: usize,
num_pq_chunks: usize,
) -> std::io::Result<()> {
let mut writer = open_file_to_write(&self.compressed_pivot_file)?;
writer.seek(SeekFrom::Start((std::mem::size_of::<i32>() * 2) as u64))?;
if num_centers > 256 {
writer.write_all(unsafe {
std::slice::from_raw_parts(
compressed_base.as_ptr() as *const u8,
block_size * num_pq_chunks * std::mem::size_of::<usize>(),
)
})?;
} else {
let compressed_base_u8 =
convert_types_usize_u8(compressed_base, block_size, num_pq_chunks);
writer.write_all(&compressed_base_u8)?;
}
Ok(())
}
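    /// Write the pivot file: a 4-entry offset table at the head (written last, once
    /// the section offsets are known), followed by the full pivot table, the centroid,
    /// and the chunk offsets.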
pub fn write_pivot_data(
&self,
full_pivot_data: &[f32],
centroid: &[f32],
chunk_offsets: &[usize],
num_centers: usize,
dim: usize,
) -> std::io::Result<()> {
let mut cumul_bytes: Vec<usize> = vec![0; 4];
cumul_bytes[0] = METADATA_SIZE;
cumul_bytes[1] = cumul_bytes[0]
+ save_bin_f32(
&self.pivot_file,
full_pivot_data,
num_centers,
dim,
cumul_bytes[0],
)?;
cumul_bytes[2] =
cumul_bytes[1] + save_bin_f32(&self.pivot_file, centroid, dim, 1, cumul_bytes[1])?;
        // The writer can only write u32/u64, not usize, so convert the chunk offsets first.
        let chunk_offsets_u32 = convert_types_usize_u32(chunk_offsets, chunk_offsets.len(), 1);
cumul_bytes[3] = cumul_bytes[2]
+ save_bin_u32(
&self.pivot_file,
&chunk_offsets_u64,
chunk_offsets.len(),
1,
cumul_bytes[2],
)?;
let cumul_bytes_u64 = convert_types_usize_u64(&cumul_bytes, 4, 1);
save_bin_u64(&self.pivot_file, &cumul_bytes_u64, cumul_bytes.len(), 1, 0)?;
Ok(())
}
pub fn pivot_data_exist(&self) -> bool {
file_exists(&self.pivot_file)
}
pub fn read_pivot_metadata(&self) -> std::io::Result<(usize, usize)> {
let (_, file_num_centers, file_dim) = load_bin::<f32>(&self.pivot_file, METADATA_SIZE)?;
Ok((file_num_centers, file_dim))
}
pub fn load_pivot_data(
&self,
num_pq_chunks: &usize,
num_centers: &usize,
dim: &usize,
) -> ANNResult<(Vec<f32>, Vec<f32>, Vec<usize>)> {
        // Load the file offset table. The file layout is: offset table (4 x 1) -> pivot data (num_centers x dim)
        // -> centroid (dim x 1) -> chunk offsets ((num_chunks + 1) x 1).
        // Offsets are stored as u64 (usize cannot be written directly), so convert to usize on load.
let (data, offset_num, nc) = load_bin::<u64>(&self.pivot_file, 0)?;
let file_offset_data = convert_types_u64_usize(&data, offset_num, nc);
if offset_num != 4 {
let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", &self.pivot_file, offset_num);
return Err(ANNError::log_pq_error(error_message));
}
let (data, pivot_num, pivot_dim) = load_bin::<f32>(&self.pivot_file, file_offset_data[0])?;
let full_pivot_data = data;
if pivot_num != *num_centers || pivot_dim != *dim {
let error_message = format!("Error reading pq_pivots file {}. file_num_centers = {}, file_dim = {} but expecting {} centers in {} dimensions.", &self.pivot_file, pivot_num, pivot_dim, num_centers, dim);
return Err(ANNError::log_pq_error(error_message));
}
let (data, centroid_dim, nc) = load_bin::<f32>(&self.pivot_file, file_offset_data[1])?;
let centroid = data;
if centroid_dim != *dim || nc != 1 {
let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", &self.pivot_file, centroid_dim, nc, dim);
return Err(ANNError::log_pq_error(error_message));
}
let (data, chunk_offset_number, nc) =
load_bin::<u32>(&self.pivot_file, file_offset_data[2])?;
let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_number, nc);
if chunk_offset_number != *num_pq_chunks + 1 || nc != 1 {
let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_number, nc, num_pq_chunks + 1);
return Err(ANNError::log_pq_error(error_message));
}
Ok((full_pivot_data, centroid, chunk_offsets))
}
pub fn read_pq_data_metadata(&mut self) -> std::io::Result<(usize, usize)> {
let npts_i32 = self.pq_data_file_reader.read_i32::<LittleEndian>()?;
let dim_i32 = self.pq_data_file_reader.read_i32::<LittleEndian>()?;
let num_points = npts_i32 as usize;
let dim = dim_i32 as usize;
Ok((num_points, dim))
}
pub fn read_pq_block_data<T: Copy>(
&mut self,
cur_block_size: usize,
dim: usize,
) -> std::io::Result<Vec<T>> {
let mut buf = vec![0u8; cur_block_size * dim * std::mem::size_of::<T>()];
self.pq_data_file_reader.read_exact(&mut buf)?;
let ptr = buf.as_ptr() as *const T;
let block_data = unsafe { std::slice::from_raw_parts(ptr, cur_block_size * dim) };
Ok(block_data.to_vec())
}
    /// Streams data from the PQ data file, sampling each vector with probability `p_val`,
    /// and returns the sampled vectors as a flat f32 matrix of size slice_size x dim.
    /// # Arguments
    /// * `p_val` - probability with which each vector is sampled (clamped to 1.0)
    /// # Return
    /// * `sampled_vectors` - the sampled vectors, converted to f32
    /// * `slice_size` - number of vectors sampled
    /// * `dim` - dimension of each vector
pub fn gen_random_slice<T: Default + Copy + Into<f32>>(
&self,
mut p_val: f64,
) -> ANNResult<(Vec<f32>, usize, usize)> {
let read_blk_size = 64 * 1024 * 1024;
let mut reader = CachedReader::new(&self.pq_data_file, read_blk_size)?;
let npts = reader.read_u32()? as usize;
let dim = reader.read_u32()? as usize;
let mut sampled_vectors: Vec<f32> = Vec::new();
let mut slice_size = 0;
p_val = if p_val < 1f64 { p_val } else { 1f64 };
let mut generator = rand::thread_rng();
let distribution = Uniform::from(0.0..1.0);
for _ in 0..npts {
let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::<T>()];
reader.read(&mut cur_vector_bytes)?;
let random_value = distribution.sample(&mut generator);
if random_value < p_val {
let ptr = cur_vector_bytes.as_ptr() as *const T;
let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) };
sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into()));
slice_size += 1;
}
}
Ok((sampled_vectors, slice_size, dim))
}
}
#[cfg(test)]
mod pq_storage_tests {
use rand::Rng;
use super::*;
use crate::utils::gen_random_slice;
const DATA_FILE: &str = "tests/data/siftsmall_learn.bin";
const PQ_PIVOT_PATH: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin";
const PQ_COMPRESSED_PATH: &str = "tests/data/empty_pq_compressed.bin";
#[test]
fn new_test() {
let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE);
assert!(result.is_ok());
}
#[test]
fn write_compressed_pivot_metadata_test() {
let compress_pivot_path = "write_compressed_pivot_metadata_test.bin";
let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap();
_ = result.write_compressed_pivot_metadata(100, 20);
let mut result_reader = File::open(compress_pivot_path).unwrap();
let npts_i32 = result_reader.read_i32::<LittleEndian>().unwrap();
let dim_i32 = result_reader.read_i32::<LittleEndian>().unwrap();
assert_eq!(npts_i32, 100);
assert_eq!(dim_i32, 20);
std::fs::remove_file(compress_pivot_path).unwrap();
}
#[test]
fn write_compressed_pivot_data_test() {
let compress_pivot_path = "write_compressed_pivot_data_test.bin";
let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap();
let mut rng = rand::thread_rng();
let num_centers = 256;
let block_size = 4;
let num_pq_chunks = 2;
let compressed_base: Vec<usize> = (0..block_size * num_pq_chunks)
.map(|_| rng.gen_range(0..num_centers))
.collect();
_ = result.write_compressed_pivot_data(
&compressed_base,
num_centers,
block_size,
num_pq_chunks,
);
let mut result_reader = File::open(compress_pivot_path).unwrap();
_ = result_reader.read_i32::<LittleEndian>().unwrap();
_ = result_reader.read_i32::<LittleEndian>().unwrap();
let mut buf = vec![0u8; block_size * num_pq_chunks * std::mem::size_of::<u8>()];
result_reader.read_exact(&mut buf).unwrap();
let ptr = buf.as_ptr() as *const u8;
let block_data = unsafe { std::slice::from_raw_parts(ptr, block_size * num_pq_chunks) };
for index in 0..block_data.len() {
assert_eq!(compressed_base[index], block_data[index] as usize);
}
std::fs::remove_file(compress_pivot_path).unwrap();
}
#[test]
fn pivot_data_exist_test() {
let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
assert!(result.pivot_data_exist());
let pivot_path = "not_exist_pivot_path.bin";
let result = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
assert!(!result.pivot_data_exist());
}
#[test]
fn read_pivot_metadata_test() {
let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
let (npt, dim) = result.read_pivot_metadata().unwrap();
assert_eq!(npt, 256);
assert_eq!(dim, 128);
}
#[test]
fn load_pivot_data_test() {
let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
let (pq_pivot_data, centroids, chunk_offsets) =
result.load_pivot_data(&1, &256, &128).unwrap();
assert_eq!(pq_pivot_data.len(), 256 * 128);
assert_eq!(centroids.len(), 128);
assert_eq!(chunk_offsets.len(), 2);
}
#[test]
fn read_pq_data_metadata_test() {
let mut result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap();
let (npt, dim) = result.read_pq_data_metadata().unwrap();
assert_eq!(npt, 25000);
assert_eq!(dim, 128);
}
#[test]
fn gen_random_slice_test() {
let file_name = "gen_random_slice_test.bin";
//npoints=2, dim=8
let data: [u8; 72] = [
2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
];
std::fs::write(file_name, data).expect("Failed to write sample file");
let (sampled_vectors, slice_size, ndims) =
gen_random_slice::<f32>(file_name, 1f64).unwrap();
let mut start = 8;
(0..sampled_vectors.len()).for_each(|i| {
assert_eq!(sampled_vectors[i].to_le_bytes(), data[start..start + 4]);
start += 4;
});
assert_eq!(sampled_vectors.len(), 16);
assert_eq!(slice_size, 2);
assert_eq!(ndims, 8);
let (sampled_vectors, slice_size, ndims) =
gen_random_slice::<f32>(file_name, 0f64).unwrap();
assert_eq!(sampled_vectors.len(), 0);
assert_eq!(slice_size, 0);
assert_eq!(ndims, 8);
std::fs::remove_file(file_name).expect("Failed to delete file");
}
}

View File

@@ -0,0 +1,74 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use vector::Metric;
use crate::index::InmemIndex;
use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder;
use crate::model::{IndexConfiguration};
use crate::model::vertex::DIM_128;
use crate::utils::{file_exists, load_metadata_from_file};
use super::get_test_file_path;
// f32, 128 DIM and 256 points source data
const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin";
const NUM_POINTS_TO_LOAD: usize = 256;
pub fn create_index_with_test_data() -> InmemIndex<f32, DIM_128> {
let index_write_parameters = IndexWriteParametersBuilder::new(50, 4).with_alpha(1.2).build();
let config = IndexConfiguration::new(
Metric::L2,
128,
128,
256,
false,
0,
false,
0,
1.0f32,
index_write_parameters);
let mut index: InmemIndex<f32, DIM_128> = InmemIndex::new(config).unwrap();
build_test_index(&mut index, get_test_file_path(TEST_DATA_FILE).as_str(), NUM_POINTS_TO_LOAD);
index.start = index.dataset.calculate_medoid_point_id().unwrap();
index
}
fn build_test_index(index: &mut InmemIndex<f32, DIM_128>, filename: &str, num_points_to_load: usize) {
if !file_exists(filename) {
panic!("ERROR: Data file {} does not exist.", filename);
}
let (file_num_points, file_dim) = load_metadata_from_file(filename).unwrap();
if file_num_points > index.configuration.max_points {
panic!(
"ERROR: Driver requests loading {} points and file has {} points,
but index can support only {} points as specified in configuration.",
num_points_to_load, file_num_points, index.configuration.max_points
);
}
if num_points_to_load > file_num_points {
panic!(
"ERROR: Driver requests loading {} points and file has only {} points.",
num_points_to_load, file_num_points
);
}
if file_dim != index.configuration.dim {
panic!(
"ERROR: Driver requests loading {} dimension, but file has {} dimension.",
index.configuration.dim, file_dim
);
}
index.dataset.build_from_file(filename, num_points_to_load).unwrap();
println!("Using only first {} from file.", num_points_to_load);
index.num_active_pts = num_points_to_load;
}

View File

@@ -0,0 +1,11 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub mod inmem_index_initialization;
/// Test files should be placed under the tests folder
pub fn get_test_file_path(relative_path: &str) -> String {
format!("{}/{}", env!("CARGO_MANIFEST_DIR"), relative_path)
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::cmp::Ordering;
use bit_vec::BitVec;
pub trait BitVecExtension {
fn resize(&mut self, new_len: usize, value: bool);
}
impl BitVecExtension for BitVec {
fn resize(&mut self, new_len: usize, value: bool) {
let old_len = self.len();
match new_len.cmp(&old_len) {
Ordering::Less => self.truncate(new_len),
Ordering::Greater => self.grow(new_len - old_len, value),
Ordering::Equal => {}
}
}
}
#[cfg(test)]
mod bit_vec_extension_test {
use super::*;
#[test]
fn resize_test() {
let mut bitset = BitVec::new();
bitset.resize(10, false);
assert_eq!(bitset.len(), 10);
assert!(bitset.none());
bitset.resize(11, true);
assert_eq!(bitset.len(), 11);
assert!(bitset[10]);
bitset.resize(5, false);
assert_eq!(bitset.len(), 5);
assert!(bitset.none());
}
}

View File

@@ -0,0 +1,160 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::fs::File;
use std::io::{Seek, Read};
use crate::common::{ANNResult, ANNError};
/// Sequential cached reads
pub struct CachedReader {
/// File reader
reader: File,
/// # bytes to cache in one shot read
cache_size: u64,
/// Underlying buf for cache
cache_buf: Vec<u8>,
/// Offset into cache_buf for cur_pos
cur_off: u64,
/// File size
fsize: u64,
}
impl CachedReader {
pub fn new(filename: &str, cache_size: u64) -> std::io::Result<Self> {
let mut reader = File::open(filename)?;
let metadata = reader.metadata()?;
let fsize = metadata.len();
let cache_size = cache_size.min(fsize);
let mut cache_buf = vec![0; cache_size as usize];
reader.read_exact(&mut cache_buf)?;
println!("Opened: {}, size: {}, cache_size: {}", filename, fsize, cache_size);
Ok(Self {
reader,
cache_size,
cache_buf,
cur_off: 0,
fsize,
})
}
pub fn get_file_size(&self) -> u64 {
self.fsize
}
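    /// Fill `read_buf` sequentially from the file, serving bytes from the in-memory
    /// cache when possible and refilling the cache from disk otherwise.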
pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> {
let n_bytes = read_buf.len() as u64;
if n_bytes <= (self.cache_size - self.cur_off) {
// case 1: cache contains all data
read_buf.copy_from_slice(&self.cache_buf[(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)]);
self.cur_off += n_bytes;
} else {
// case 2: cache contains some data
let cached_bytes = self.cache_size - self.cur_off;
if n_bytes - cached_bytes > self.fsize - self.reader.stream_position()? {
return Err(ANNError::log_index_error(format!(
"Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}",
n_bytes, cached_bytes, self.fsize, self.reader.stream_position()?))
);
}
read_buf[..cached_bytes as usize].copy_from_slice(&self.cache_buf[self.cur_off as usize..]);
// go to disk and fetch more data
self.reader.read_exact(&mut read_buf[cached_bytes as usize..])?;
// reset cur off
self.cur_off = self.cache_size;
let size_left = self.fsize - self.reader.stream_position()?;
if size_left >= self.cache_size {
self.reader.read_exact(&mut self.cache_buf)?;
self.cur_off = 0;
}
// note that if size_left < cache_size, then cur_off = cache_size,
// so subsequent reads will all be directly from file
}
Ok(())
}
pub fn read_u32(&mut self) -> ANNResult<u32> {
let mut bytes = [0u8; 4];
self.read(&mut bytes)?;
Ok(u32::from_le_bytes(bytes))
}
}
#[cfg(test)]
mod cached_reader_test {
use std::fs;
use super::*;
#[test]
fn cached_reader_works() {
let file_name = "cached_reader_works_test.bin";
        // 72 bytes of sample data (the 8-byte header is intentionally scrambled; the test reads raw bytes)
let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3,
0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];
std::fs::write(file_name, data).expect("Failed to write sample file");
let mut reader = CachedReader::new(file_name, 8).unwrap();
assert_eq!(reader.get_file_size(), 72);
assert_eq!(reader.cache_size, 8);
let mut all_from_cache_buf = vec![0; 4];
reader.read(all_from_cache_buf.as_mut_slice()).unwrap();
assert_eq!(all_from_cache_buf, [2, 0, 1, 2]);
assert_eq!(reader.cur_off, 4);
let mut partial_from_cache_buf = vec![0; 6];
reader.read(partial_from_cache_buf.as_mut_slice()).unwrap();
assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]);
assert_eq!(reader.cur_off, 0);
let mut over_cache_size_buf = vec![0; 60];
reader.read(over_cache_size_buf.as_mut_slice()).unwrap();
assert_eq!(
over_cache_size_buf,
[0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11]
);
let mut remaining_less_than_cache_size_buf = vec![0; 2];
reader.read(remaining_less_than_cache_size_buf.as_mut_slice()).unwrap();
assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]);
assert_eq!(reader.cur_off, reader.cache_size);
fs::remove_file(file_name).expect("Failed to delete file");
}
#[test]
#[should_panic(expected = "n_bytes: 73 cached_bytes: 8 fsize: 72 current pos: 8")]
fn failed_for_reading_beyond_end_of_file() {
let file_name = "failed_for_reading_beyond_end_of_file_test.bin";
        // 72 bytes of sample data (the 8-byte header is intentionally scrambled; the test reads raw bytes)
let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3,
0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];
std::fs::write(file_name, data).expect("Failed to write sample file");
let mut reader = CachedReader::new(file_name, 8).unwrap();
fs::remove_file(file_name).expect("Failed to delete file");
let mut over_size_buf = vec![0; 73];
reader.read(over_size_buf.as_mut_slice()).unwrap();
}
}

View File

@@ -0,0 +1,142 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::io::{Write, Seek, SeekFrom};
use std::fs::{OpenOptions, File};
use std::path::Path;
pub struct CachedWriter {
/// File writer
writer: File,
/// # bytes to cache for one shot write
cache_size: u64,
/// Underlying buf for cache
cache_buf: Vec<u8>,
/// Offset into cache_buf for cur_pos
cur_off: u64,
/// File size
fsize: u64,
}
impl CachedWriter {
pub fn new(filename: &str, cache_size: u64) -> std::io::Result<Self> {
let writer = OpenOptions::new()
.write(true)
.create(true)
.open(Path::new(filename))?;
if cache_size == 0 {
return Err(std::io::Error::new(std::io::ErrorKind::Other, "Cache size must be greater than 0"));
}
println!("Opened: {}, cache_size: {}", filename, cache_size);
Ok(Self {
writer,
cache_size,
cache_buf: vec![0; cache_size as usize],
cur_off: 0,
fsize: 0,
})
}
pub fn flush(&mut self) -> std::io::Result<()> {
// dump any remaining data in memory
if self.cur_off > 0 {
self.flush_cache()?;
}
self.writer.flush()?;
println!("Finished writing {}B", self.fsize);
Ok(())
}
pub fn get_file_size(&self) -> u64 {
self.fsize
}
/// Writes n_bytes from write_buf to the underlying cache
pub fn write(&mut self, write_buf: &[u8]) -> std::io::Result<()> {
let n_bytes = write_buf.len() as u64;
if n_bytes <= (self.cache_size - self.cur_off) {
// case 1: cache can take all data
self.cache_buf[(self.cur_off as usize)..((self.cur_off + n_bytes) as usize)].copy_from_slice(&write_buf[..n_bytes as usize]);
self.cur_off += n_bytes;
} else {
// case 2: cache cant take all data
// go to disk and write existing cache data
self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?;
self.fsize += self.cur_off;
// write the new data to disk
self.writer.write_all(write_buf)?;
self.fsize += n_bytes;
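            // note: an oversized write goes straight to disk and is never staged,
            // so the cache only ever holds the tail of small sequential writes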
// clear cache data and reset cur_off
self.cache_buf.fill(0);
self.cur_off = 0;
}
Ok(())
}
pub fn reset(&mut self) -> std::io::Result<()> {
self.flush_cache()?;
self.writer.seek(SeekFrom::Start(0))?;
Ok(())
}
fn flush_cache(&mut self) -> std::io::Result<()> {
self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?;
self.fsize += self.cur_off;
self.cache_buf.fill(0);
self.cur_off = 0;
Ok(())
}
}
impl Drop for CachedWriter {
fn drop(&mut self) {
let _ = self.flush();
}
}
#[cfg(test)]
mod cached_writer_test {
use std::fs;
use super::*;
#[test]
fn cached_writer_works() {
let file_name = "cached_writer_works_test.bin";
        // 72 bytes of sample data; only the raw byte values matter for this test
let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3,
0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41];
let mut writer = CachedWriter::new(file_name, 8).unwrap();
assert_eq!(writer.get_file_size(), 0);
assert_eq!(writer.cache_size, 8);
assert_eq!(writer.get_file_size(), 0);
let cache_all_buf = &data[0..4];
writer.write(cache_all_buf).unwrap();
assert_eq!(&writer.cache_buf[..4], cache_all_buf);
assert_eq!(&writer.cache_buf[4..], vec![0; 4]);
assert_eq!(writer.cur_off, 4);
assert_eq!(writer.get_file_size(), 0);
let write_all_buf = &data[4..10];
writer.write(write_all_buf).unwrap();
assert_eq!(writer.cache_buf, vec![0; 8]);
assert_eq!(writer.cur_off, 0);
assert_eq!(writer.get_file_size(), 10);
fs::remove_file(file_name).expect("Failed to delete file");
}
}

View File

@@ -0,0 +1,377 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! File operations
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::{mem, io};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, BufReader, Write, Seek, SeekFrom};
use std::path::Path;
use crate::model::data_store::DatasetDto;
/// Read metadata of data file.
pub fn load_metadata_from_file(file_name: &str) -> std::io::Result<(usize, usize)> {
let file = File::open(file_name)?;
let mut reader = BufReader::new(file);
let npoints = reader.read_i32::<LittleEndian>()? as usize;
let ndims = reader.read_i32::<LittleEndian>()? as usize;
Ok((npoints, ndims))
}
/// Read the deleted vertex ids from file.
pub fn load_ids_to_delete_from_file(file_name: &str) -> std::io::Result<(usize, Vec<u32>)> {
// The first 4 bytes are the number of vector ids.
// The rest of the file are the vector ids in the format of usize.
// The vector ids are sorted in ascending order.
let mut file = File::open(file_name)?;
let num_ids = file.read_u32::<LittleEndian>()? as usize;
let mut ids = Vec::with_capacity(num_ids);
for _ in 0..num_ids {
let id = file.read_u32::<LittleEndian>()?;
ids.push(id);
}
Ok((num_ids, ids))
}
/// Copy data from file into an aligned dataset
/// # Arguments
/// * `bin_file` - filename where the data is
/// * `dataset_dto` - destination dataset dto to which the data is copied; each point
///   is zero-padded from `dim` up to `rounded_dim`
/// * `pts_offset` - offset of points; data will be loaded after this point in the dataset
/// # Return
/// * `npts` - number of points read from bin_file
/// * `dim` - point dimension read from bin_file
pub fn copy_aligned_data_from_file<T: Default + Copy>(
bin_file: &str,
dataset_dto: DatasetDto<T>,
pts_offset: usize,
) -> std::io::Result<(usize, usize)> {
let mut reader = File::open(bin_file)?;
let npts = reader.read_i32::<LittleEndian>()? as usize;
let dim = reader.read_i32::<LittleEndian>()? as usize;
let rounded_dim = dataset_dto.rounded_dim;
let offset = pts_offset * rounded_dim;
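    // Each point occupies rounded_dim slots in the destination; the first dim slots
    // hold the vector and the remainder is zero padding.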
for i in 0..npts {
let data_slice = &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
let mut buf = vec![0u8; dim * mem::size_of::<T>()];
reader.read_exact(&mut buf)?;
let ptr = buf.as_ptr() as *const T;
let temp_slice = unsafe { std::slice::from_raw_parts(ptr, dim) };
data_slice.copy_from_slice(temp_slice);
        // zero-pad the rest of the rounded dimension (relative to the same offset as data_slice)
        (offset + i * rounded_dim + dim..offset + (i + 1) * rounded_dim).for_each(|j| {
            dataset_dto.data[j] = T::default();
        });
}
Ok((npts, dim))
}
/// Open a file to write
/// # Arguments
/// * `file_name` - file name
#[inline]
pub fn open_file_to_write(file_name: &str) -> std::io::Result<File> {
OpenOptions::new()
.write(true)
.create(true)
.open(Path::new(file_name))
}
/// Delete a file
/// # Arguments
/// * `file_name` - file name
pub fn delete_file(file_name: &str) -> std::io::Result<()> {
if file_exists(file_name) {
fs::remove_file(file_name)?;
}
Ok(())
}
/// Check whether file exists or not
pub fn file_exists(filename: &str) -> bool {
std::path::Path::new(filename).exists()
}
/// Save data to file
/// # Arguments
/// * `filename` - filename where the data is
/// * `data` - information data
/// * `npts` - number of points
/// * `ndims` - point dimension
/// * `aligned_dim` - aligned dimension
/// * `offset` - data offset in file
pub fn save_data_in_base_dimensions<T: Default + Copy>(
filename: &str,
data: &mut [T],
npts: usize,
ndims: usize,
aligned_dim: usize,
offset: usize,
) -> std::io::Result<usize> {
let mut writer = open_file_to_write(filename)?;
let npts_i32 = npts as i32;
let ndims_i32 = ndims as i32;
let bytes_written = 2 * std::mem::size_of::<u32>() + npts * ndims * (std::mem::size_of::<T>());
writer.seek(std::io::SeekFrom::Start(offset as u64))?;
writer.write_all(&npts_i32.to_le_bytes())?;
writer.write_all(&ndims_i32.to_le_bytes())?;
let data_ptr = data.as_ptr() as *const u8;
for i in 0..npts {
let middle_offset = i * aligned_dim * std::mem::size_of::<T>();
let middle_slice = unsafe { std::slice::from_raw_parts(data_ptr.add(middle_offset), ndims * std::mem::size_of::<T>()) };
writer.write_all(middle_slice)?;
}
writer.flush()?;
Ok(bytes_written)
}
/// Read a binary data file
/// # Arguments
/// * `bin_file` - filename where the data is
/// * `file_offset` - data offset in file
/// # Return
/// * `data` - data read from the file
/// * `npts` - number of points
/// * `dim` - point dimension
pub fn load_bin<T: Copy>(
bin_file: &str,
file_offset: usize) -> std::io::Result<(Vec<T>, usize, usize)>
{
let mut reader = File::open(bin_file)?;
reader.seek(std::io::SeekFrom::Start(file_offset as u64))?;
let npts = reader.read_i32::<LittleEndian>()? as usize;
let dim = reader.read_i32::<LittleEndian>()? as usize;
let size = npts * dim * std::mem::size_of::<T>();
let mut buf = vec![0u8; size];
reader.read_exact(&mut buf)?;
let ptr = buf.as_ptr() as *const T;
let data = unsafe { std::slice::from_raw_parts(ptr, npts * dim)};
Ok((data.to_vec(), npts, dim))
}
/// Get file size
pub fn get_file_size(filename: &str) -> io::Result<u64> {
let reader = File::open(filename)?;
let metadata = reader.metadata()?;
Ok(metadata.len())
}
macro_rules! save_bin {
($name:ident, $t:ty, $write_func:ident) => {
/// Write data into file
pub fn $name(filename: &str, data: &[$t], num_pts: usize, dims: usize, offset: usize) -> std::io::Result<usize> {
let mut writer = open_file_to_write(filename)?;
println!("Writing bin: {}", filename);
writer.seek(SeekFrom::Start(offset as u64))?;
let num_pts_i32 = num_pts as i32;
let dims_i32 = dims as i32;
let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<u32>();
writer.write_i32::<LittleEndian>(num_pts_i32)?;
writer.write_i32::<LittleEndian>(dims_i32)?;
println!("bin: #pts = {}, #dims = {}, size = {}B", num_pts, dims, bytes_written);
for item in data.iter() {
writer.$write_func::<LittleEndian>(*item)?;
}
writer.flush()?;
println!("Finished writing bin.");
Ok(bytes_written)
}
};
}
save_bin!(save_bin_f32, f32, write_f32);
save_bin!(save_bin_u64, u64, write_u64);
save_bin!(save_bin_u32, u32, write_u32);
#[cfg(test)]
mod file_util_test {
use crate::model::data_store::InmemDataset;
use std::fs;
use super::*;
pub const DIM_8: usize = 8;
#[test]
fn load_metadata_test() {
let file_name = "test_load_metadata_test.bin";
let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little endian bytes
std::fs::write(file_name, data).expect("Failed to write sample file");
        match load_metadata_from_file(file_name) {
            Ok((npoints, ndims)) => {
                assert!(npoints == 200);
                assert!(ndims == 128);
            },
            Err(e) => panic!("{}", e),
        }
fs::remove_file(file_name).expect("Failed to delete file");
}
#[test]
fn load_data_test() {
let file_name = "test_load_data_test.bin";
//npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8]
let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
std::fs::write(file_name, data).expect("Failed to write sample file");
let mut dataset = InmemDataset::<f32, DIM_8>::new(2, 1f32).unwrap();
match copy_aligned_data_from_file(file_name, dataset.into_dto(), 0) {
Ok((num_points, dim)) => {
fs::remove_file(file_name).expect("Failed to delete file");
assert!(num_points == 2);
assert!(dim == 8);
assert!(dataset.data.len() == 16);
let first_vertex = dataset.get_vertex(0).unwrap();
let second_vertex = dataset.get_vertex(1).unwrap();
assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);
},
Err(e) => {
fs::remove_file(file_name).expect("Failed to delete file");
panic!("{}", e)
},
}
}
#[test]
fn open_file_to_write_test() {
let file_name = "test_open_file_to_write_test.bin";
let mut writer = File::create(file_name).unwrap();
let data = [200, 0, 0, 0, 128, 0, 0, 0];
writer.write(&data).expect("Failed to write sample file");
let _ = open_file_to_write(file_name);
fs::remove_file(file_name).expect("Failed to delete file");
}
#[test]
fn delete_file_test() {
let file_name = "test_delete_file_test.bin";
let mut file = File::create(file_name).unwrap();
writeln!(file, "test delete file").unwrap();
let result = delete_file(file_name);
assert!(result.is_ok());
assert!(fs::metadata(file_name).is_err());
}
#[test]
fn save_data_in_base_dimensions_test() {
//npoints=2, dim=8
let mut data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
let num_points = 2;
let dim = DIM_8;
let data_file = "save_data_in_base_dimensions_test.data";
match save_data_in_base_dimensions(data_file, &mut data, num_points, dim, DIM_8, 0) {
Ok(num) => {
assert!(file_exists(data_file));
assert_eq!(num, 2 * std::mem::size_of::<u32>() + num_points * dim * std::mem::size_of::<u8>());
fs::remove_file(data_file).expect("Failed to delete file");
},
Err(e) => {
fs::remove_file(data_file).expect("Failed to delete file");
panic!("{}", e)
}
}
}
#[test]
fn save_bin_test() {
let filename = "save_bin_test";
let data = vec![0u64, 1u64, 2u64];
let num_pts = data.len();
let dims = 1;
let bytes_written = save_bin_u64(filename, &data, num_pts, dims, 0).unwrap();
assert_eq!(bytes_written, 32);
let mut file = File::open(filename).unwrap();
let mut buffer = vec![];
let npts_read = file.read_i32::<LittleEndian>().unwrap() as usize;
let dims_read = file.read_i32::<LittleEndian>().unwrap() as usize;
file.read_to_end(&mut buffer).unwrap();
let data_read: Vec<u64> = buffer
.chunks_exact(8)
.map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
.collect();
std::fs::remove_file(filename).unwrap();
assert_eq!(num_pts, npts_read);
assert_eq!(dims, dims_read);
assert_eq!(data, data_read);
}
#[test]
fn load_bin_test() {
let file_name = "load_bin_test";
let data = vec![0u64, 1u64, 2u64];
let num_pts = data.len();
let dims = 1;
let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, 0).unwrap();
assert_eq!(bytes_written, 32);
let (load_data, load_num_pts, load_dims) = load_bin::<u64>(file_name, 0).unwrap();
assert_eq!(load_num_pts, num_pts);
assert_eq!(load_dims, dims);
assert_eq!(load_data, data);
std::fs::remove_file(file_name).unwrap();
}
#[test]
fn load_bin_offset_test() {
let offset:usize = 32;
let file_name = "load_bin_offset_test";
let data = vec![0u64, 1u64, 2u64];
let num_pts = data.len();
let dims = 1;
let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, offset).unwrap();
assert_eq!(bytes_written, 32);
let (load_data, load_num_pts, load_dims) = load_bin::<u64>(file_name, offset).unwrap();
assert_eq!(load_num_pts, num_pts);
assert_eq!(load_dims, dims);
assert_eq!(load_data, data);
std::fs::remove_file(file_name).unwrap();
}
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use hashbrown::HashSet;
use std::{hash::BuildHasherDefault, ops::{Deref, DerefMut}};
use fxhash::FxHasher;
lazy_static::lazy_static! {
/// Singleton hasher.
static ref HASHER: BuildHasherDefault<FxHasher> = {
BuildHasherDefault::<FxHasher>::default()
};
}
pub struct HashSetForU32 {
hashset: HashSet::<u32, BuildHasherDefault<FxHasher>>,
}
impl HashSetForU32 {
pub fn with_capacity(capacity: usize) -> HashSetForU32 {
let hashset = HashSet::<u32, BuildHasherDefault<FxHasher>>::with_capacity_and_hasher(capacity, HASHER.clone());
HashSetForU32 {
hashset
}
}
}
impl Deref for HashSetForU32 {
type Target = HashSet::<u32, BuildHasherDefault<FxHasher>>;
fn deref(&self) -> &Self::Target {
&self.hashset
}
}
impl DerefMut for HashSetForU32 {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.hashset
}
}

View File

@@ -0,0 +1,430 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Aligned allocator
use rand::{distributions::Uniform, prelude::Distribution, thread_rng};
use rayon::prelude::*;
use std::cmp::min;
use crate::common::ANNResult;
use crate::utils::math_util::{calc_distance, compute_closest_centers, compute_vecs_l2sq};
/// Run one iteration of Lloyd's algorithm.
/// Given data in row-major num_points x dim, centers in row-major
/// num_centers x dim, and the squared lengths of the data points, compute the
/// closest center for each data point, update the centers, and return the residual.
/// The inverted index is written into closest_docs and the per-point assignments
/// into closest_center (both must be pre-allocated by the caller).
#[allow(clippy::too_many_arguments)]
fn lloyds_iter(
data: &[f32],
num_points: usize,
dim: usize,
centers: &mut [f32],
num_centers: usize,
docs_l2sq: &[f32],
mut closest_docs: &mut Vec<Vec<usize>>,
closest_center: &mut [u32],
) -> ANNResult<f32> {
let compute_residual = true;
closest_docs.iter_mut().for_each(|doc| doc.clear());
compute_closest_centers(
data,
num_points,
dim,
centers,
num_centers,
1,
closest_center,
Some(&mut closest_docs),
Some(docs_l2sq),
)?;
centers.fill(0.0);
centers
.par_chunks_mut(dim)
.enumerate()
.for_each(|(c, center)| {
let mut cluster_sum = vec![0.0; dim];
for &doc_index in &closest_docs[c] {
let current = &data[doc_index * dim..(doc_index + 1) * dim];
for (j, current_val) in current.iter().enumerate() {
cluster_sum[j] += *current_val as f64;
}
}
if !closest_docs[c].is_empty() {
for (i, sum_val) in cluster_sum.iter().enumerate() {
center[i] = (*sum_val / closest_docs[c].len() as f64) as f32;
}
}
});
let mut residual = 0.0;
if compute_residual {
let buf_pad: usize = 32;
let chunk_size: usize = 2 * 8192;
        // ceil(num_points / chunk_size) chunks, one padded accumulator slot per chunk;
        // buf_pad spaces the slots across cache lines so parallel writers don't share a line
        let nchunks = (num_points + chunk_size - 1) / chunk_size;
        let mut residuals: Vec<f32> = vec![0.0; nchunks * buf_pad];
        residuals
            .par_chunks_mut(buf_pad)
            .enumerate()
            .for_each(|(chunk, res)| {
                for d in (chunk * chunk_size)..min(num_points, (chunk + 1) * chunk_size) {
                    res[0] += calc_distance(
                        &data[d * dim..(d + 1) * dim],
                        &centers[closest_center[d] as usize * dim..],
                        dim,
                    );
                }
            });
for chunk in 0..nchunks {
residual += residuals[chunk * buf_pad];
}
}
Ok(residual)
}
/// Run Lloyd's iterations until max_reps is reached or the residual stops improving.
/// Returns the inverted index (closest_docs, one Vec<usize> per center), the
/// per-point center assignments (closest_center, one u32 per point), and the
/// final residual.
/// Final centers are output in centers as row-major num_centers x dim.
fn run_lloyds(
data: &[f32],
num_points: usize,
dim: usize,
centers: &mut [f32],
num_centers: usize,
max_reps: usize,
) -> ANNResult<(Vec<Vec<usize>>, Vec<u32>, f32)> {
let mut residual = f32::MAX;
let mut closest_docs = vec![Vec::new(); num_centers];
let mut closest_center = vec![0; num_points];
let mut docs_l2sq = vec![0.0; num_points];
compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim);
let mut old_residual;
for i in 0..max_reps {
old_residual = residual;
residual = lloyds_iter(
data,
num_points,
dim,
centers,
num_centers,
&docs_l2sq,
&mut closest_docs,
&mut closest_center,
)?;
if (i != 0 && (old_residual - residual) / residual < 0.00001) || (residual < f32::EPSILON) {
println!(
"Residuals unchanged: {} becomes {}. Early termination.",
old_residual, residual
);
break;
}
}
Ok((closest_docs, closest_center, residual))
}
/// Select num_centers distinct points uniformly at random as pivots.
/// pivot_data must be pre-allocated with num_centers * dim floats.
fn selecting_pivots(
data: &[f32],
num_points: usize,
dim: usize,
pivot_data: &mut [f32],
num_centers: usize,
) {
let mut picked = Vec::new();
let mut rng = thread_rng();
let distribution = Uniform::from(0..num_points);
for j in 0..num_centers {
let mut tmp_pivot = distribution.sample(&mut rng);
while picked.contains(&tmp_pivot) {
tmp_pivot = distribution.sample(&mut rng);
}
picked.push(tmp_pivot);
let data_offset = tmp_pivot * dim;
let pivot_offset = j * dim;
pivot_data[pivot_offset..pivot_offset + dim]
.copy_from_slice(&data[data_offset..data_offset + dim]);
}
}
/// Select pivots in k-means++ algorithm
/// Points that are farther away from the already chosen centroids
/// have a higher probability of being selected as the next centroid.
/// The k-means++ algorithm helps avoid poor initial centroid
/// placement that can result in suboptimal clustering.
fn k_meanspp_selecting_pivots(
data: &[f32],
num_points: usize,
dim: usize,
pivot_data: &mut [f32],
num_centers: usize,
) {
if num_points > (1 << 23) {
println!("ERROR: n_pts {} currently not supported for k-means++, maximum is 8388608. Falling back to random pivot selection.", num_points);
selecting_pivots(data, num_points, dim, pivot_data, num_centers);
return;
}
let mut picked: Vec<usize> = Vec::new();
let mut rng = thread_rng();
let real_distribution = Uniform::from(0.0..1.0);
let int_distribution = Uniform::from(0..num_points);
let init_id = int_distribution.sample(&mut rng);
let mut num_picked = 1;
picked.push(init_id);
let init_data_offset = init_id * dim;
pivot_data[0..dim].copy_from_slice(&data[init_data_offset..init_data_offset + dim]);
let mut dist = vec![0.0; num_points];
dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| {
*dist_i = calc_distance(
&data[i * dim..(i + 1) * dim],
&data[init_id * dim..(init_id + 1) * dim],
dim,
);
});
let mut dart_val: f64;
let mut tmp_pivot = 0;
let mut sum_flag = false;
while num_picked < num_centers {
dart_val = real_distribution.sample(&mut rng);
let mut sum: f64 = 0.0;
for item in dist.iter().take(num_points) {
sum += *item as f64;
}
if sum == 0.0 {
sum_flag = true;
}
dart_val *= sum;
let mut prefix_sum: f64 = 0.0;
for (i, pivot) in dist.iter().enumerate().take(num_points) {
tmp_pivot = i;
if dart_val >= prefix_sum && dart_val < (prefix_sum + *pivot as f64) {
break;
}
prefix_sum += *pivot as f64;
}
if picked.contains(&tmp_pivot) && !sum_flag {
continue;
}
picked.push(tmp_pivot);
let pivot_offset = num_picked * dim;
let data_offset = tmp_pivot * dim;
pivot_data[pivot_offset..pivot_offset + dim]
.copy_from_slice(&data[data_offset..data_offset + dim]);
dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| {
*dist_i = (*dist_i).min(calc_distance(
&data[i * dim..(i + 1) * dim],
&data[tmp_pivot * dim..(tmp_pivot + 1) * dim],
dim,
));
});
num_picked += 1;
}
}
/// k-means algorithm interface
pub fn k_means_clustering(
data: &[f32],
num_points: usize,
dim: usize,
centers: &mut [f32],
num_centers: usize,
max_reps: usize,
) -> ANNResult<(Vec<Vec<usize>>, Vec<u32>, f32)> {
k_meanspp_selecting_pivots(data, num_points, dim, centers, num_centers);
let (closest_docs, closest_center, residual) =
run_lloyds(data, num_points, dim, centers, num_centers, max_reps)?;
Ok((closest_docs, closest_center, residual))
}
#[cfg(test)]
mod kmeans_test {
use super::*;
use approx::assert_relative_eq;
use rand::Rng;
#[test]
fn lloyds_iter_test() {
let dim = 2;
let num_points = 10;
let num_centers = 3;
let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0];
let mut closest_docs: Vec<Vec<usize>> = vec![vec![]; num_centers];
let mut closest_center: Vec<u32> = vec![0; num_points];
let docs_l2sq: Vec<f32> = data
.chunks(dim)
.map(|chunk| chunk.iter().map(|val| val.powi(2)).sum())
.collect();
let residual = lloyds_iter(
&data,
num_points,
dim,
&mut centers,
num_centers,
&docs_l2sq,
&mut closest_docs,
&mut closest_center,
)
.unwrap();
let expected_centers: [f32; 6] = [2.0, 3.0, 9.0, 10.0, 17.0, 18.0];
let expected_closest_docs: Vec<Vec<usize>> =
vec![vec![0, 1], vec![2, 3, 4, 5, 6], vec![7, 8, 9]];
let expected_closest_center: [u32; 10] = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2];
let expected_residual: f32 = 100.0;
// sort data for assert
centers.sort_by(|a, b| a.partial_cmp(b).unwrap());
for inner_vec in &mut closest_docs {
inner_vec.sort();
}
closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap());
assert_eq!(centers, expected_centers);
assert_eq!(closest_docs, expected_closest_docs);
assert_eq!(closest_center, expected_closest_center);
assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32);
}
#[test]
fn run_lloyds_test() {
let dim = 2;
let num_points = 10;
let num_centers = 3;
let max_reps = 5;
let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0];
let (mut closest_docs, mut closest_center, residual) =
run_lloyds(&data, num_points, dim, &mut centers, num_centers, max_reps).unwrap();
let expected_centers: [f32; 6] = [3.0, 4.0, 10.0, 11.0, 17.0, 18.0];
let expected_closest_docs: Vec<Vec<usize>> =
vec![vec![0, 1, 2], vec![3, 4, 5, 6], vec![7, 8, 9]];
let expected_closest_center: [u32; 10] = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2];
let expected_residual: f32 = 72.0;
// sort data for assert
centers.sort_by(|a, b| a.partial_cmp(b).unwrap());
for inner_vec in &mut closest_docs {
inner_vec.sort();
}
closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap());
assert_eq!(centers, expected_centers);
assert_eq!(closest_docs, expected_closest_docs);
assert_eq!(closest_center, expected_closest_center);
assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32);
}
#[test]
fn selecting_pivots_test() {
let dim = 2;
let num_points = 10;
let num_centers = 3;
// Generate some random data points
let mut rng = rand::thread_rng();
let data: Vec<f32> = (0..num_points * dim).map(|_| rng.gen()).collect();
let mut pivot_data = vec![0.0; num_centers * dim];
selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers);
// Verify that each pivot point corresponds to a point in the data
for i in 0..num_centers {
let pivot_offset = i * dim;
let pivot = &pivot_data[pivot_offset..(pivot_offset + dim)];
// Make sure the pivot is found in the data
let mut found = false;
for j in 0..num_points {
let data_offset = j * dim;
let point = &data[data_offset..(data_offset + dim)];
if pivot == point {
found = true;
break;
}
}
assert!(found, "Pivot not found in data");
}
}
#[test]
fn k_meanspp_selecting_pivots_test() {
let dim = 2;
let num_points = 10;
let num_centers = 3;
// Generate some random data points
let mut rng = rand::thread_rng();
let data: Vec<f32> = (0..num_points * dim).map(|_| rng.gen()).collect();
let mut pivot_data = vec![0.0; num_centers * dim];
k_meanspp_selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers);
// Verify that each pivot point corresponds to a point in the data
for i in 0..num_centers {
let pivot_offset = i * dim;
let pivot = &pivot_data[pivot_offset..pivot_offset + dim];
// Make sure the pivot is found in the data
let mut found = false;
for j in 0..num_points {
let data_offset = j * dim;
let point = &data[data_offset..data_offset + dim];
if pivot == point {
found = true;
break;
}
}
assert!(found, "Pivot not found in data");
}
}
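// A minimal end-to-end sketch of the public k_means_clustering interface.
// Illustrative only: the exact centers depend on the random k-means++
// seeding, so only structural invariants are asserted here.
#[test]
fn k_means_clustering_test() {
let dim = 2;
let num_points = 10;
let num_centers = 2;
let max_reps = 5;
let data: Vec<f32> = (1..=num_points * dim).map(|x| x as f32).collect();
let mut centers = vec![0.0; num_centers * dim];
let (closest_docs, closest_center, residual) =
k_means_clustering(&data, num_points, dim, &mut centers, num_centers, max_reps).unwrap();
// Every point is assigned to exactly one center.
assert_eq!(closest_center.len(), num_points);
assert_eq!(closest_docs.iter().map(|d| d.len()).sum::<usize>(), num_points);
assert!(residual.is_finite());
}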
}

View File

@@ -0,0 +1,481 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Math utilities (distance computation and closest-center search)
extern crate cblas;
extern crate openblas_src;
use cblas::{sgemm, snrm2, Layout, Transpose};
use rayon::prelude::*;
use std::{
cmp::{min, Ordering},
collections::BinaryHeap,
sync::{Arc, Mutex},
};
use crate::common::{ANNError, ANNResult};
struct PivotContainer {
piv_id: usize,
piv_dist: f32,
}
impl PartialOrd for PivotContainer {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
// Natural order by distance, so BinaryHeap is a max-heap on piv_dist and
// peek_mut() exposes the largest of the k distances kept so far.
self.piv_dist.partial_cmp(&other.piv_dist)
}
}
impl Ord for PivotContainer {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// Treat NaN as less than all other values.
// piv_dist should never be NaN.
self.partial_cmp(other).unwrap_or(Ordering::Less)
}
}
impl PartialEq for PivotContainer {
fn eq(&self, other: &Self) -> bool {
self.piv_dist == other.piv_dist
}
}
impl Eq for PivotContainer {}
/// Calculate the squared Euclidean distance between two vectors
pub fn calc_distance(vec_1: &[f32], vec_2: &[f32], dim: usize) -> f32 {
let mut dist = 0.0;
for j in 0..dim {
let diff = vec_1[j] - vec_2[j];
dist += diff * diff;
}
dist
}
/// Compute the squared L2 norms of data stored row-major as num_points * dim;
/// vecs_l2sq must be pre-allocated with num_points entries
pub fn compute_vecs_l2sq(vecs_l2sq: &mut [f32], data: &[f32], num_points: usize, dim: usize) {
assert_eq!(vecs_l2sq.len(), num_points);
vecs_l2sq
.par_iter_mut()
.enumerate()
.for_each(|(n_iter, vec_l2sq)| {
let slice = &data[n_iter * dim..(n_iter + 1) * dim];
let norm = unsafe { snrm2(dim as i32, slice, 1) };
*vec_l2sq = norm * norm;
});
}
/// Calculate the k closest centers for data of num_points * dim (row-major).
/// centers is num_centers * dim (row-major).
/// docs_l2sq and centers_l2sq hold pre-computed squared norms of the data
/// points and the centers, respectively.
/// Pre-allocated center_index (num_points * k) receives the ids of the k
/// nearest centers per point, nearest first.
/// Pre-allocated dist_matrix (num_points * num_centers) receives the squared
/// distances.
/// Ideally used only by compute_closest_centers.
#[allow(clippy::too_many_arguments)]
pub fn compute_closest_centers_in_block(
data: &[f32],
num_points: usize,
dim: usize,
centers: &[f32],
num_centers: usize,
docs_l2sq: &[f32],
centers_l2sq: &[f32],
center_index: &mut [u32],
dist_matrix: &mut [f32],
k: usize,
) -> ANNResult<()> {
if k > num_centers {
return Err(ANNError::log_index_error(format!(
"ERROR: k ({}) > num_centers({})",
k, num_centers
)));
}
let ones_a: Vec<f32> = vec![1.0; num_centers];
let ones_b: Vec<f32> = vec![1.0; num_points];
unsafe {
sgemm(
Layout::RowMajor,
Transpose::None,
Transpose::Ordinary,
num_points as i32,
num_centers as i32,
1,
1.0,
docs_l2sq,
1,
&ones_a,
1,
0.0,
dist_matrix,
num_centers as i32,
);
}
unsafe {
sgemm(
Layout::RowMajor,
Transpose::None,
Transpose::Ordinary,
num_points as i32,
num_centers as i32,
1,
1.0,
&ones_b,
1,
centers_l2sq,
1,
1.0,
dist_matrix,
num_centers as i32,
);
}
unsafe {
sgemm(
Layout::RowMajor,
Transpose::None,
Transpose::Ordinary,
num_points as i32,
num_centers as i32,
dim as i32,
-2.0,
data,
dim as i32,
centers,
dim as i32,
1.0,
dist_matrix,
num_centers as i32,
);
}
if k == 1 {
center_index
.par_iter_mut()
.enumerate()
.for_each(|(i, center_idx)| {
let mut min = f32::MAX;
let current = &dist_matrix[i * num_centers..(i + 1) * num_centers];
let mut min_idx = 0;
for (j, &distance) in current.iter().enumerate() {
if distance < min {
min = distance;
min_idx = j;
}
}
*center_idx = min_idx as u32;
});
} else {
center_index
.par_chunks_mut(k)
.enumerate()
.for_each(|(i, center_chunk)| {
let current = &dist_matrix[i * num_centers..(i + 1) * num_centers];
let mut top_k_queue = BinaryHeap::new();
for (j, &distance) in current.iter().enumerate() {
let this_piv = PivotContainer {
piv_id: j,
piv_dist: distance,
};
if top_k_queue.len() < k {
top_k_queue.push(this_piv);
} else {
// Safe unwrap, top_k_queue is not empty
#[allow(clippy::unwrap_used)]
let mut top = top_k_queue.peek_mut().unwrap();
if this_piv.piv_dist < top.piv_dist {
*top = this_piv;
}
}
}
// Pops come out largest-distance first, so fill the chunk back-to-front
// to leave the k nearest centers in ascending distance order.
for center_idx in center_chunk.iter_mut().rev() {
if let Some(this_piv) = top_k_queue.pop() {
*center_idx = this_piv.piv_id as u32;
} else {
break;
}
}
});
}
Ok(())
}
/// Given data in num_points * dim row-major order and pivots in pivot_data
/// as num_centers * dim row-major, calculate the k closest pivots for each
/// point and store them in closest_centers_ivf (row-major, num_points * k,
/// allocated by the caller).
/// If inverted_index is provided (pre-allocated with one empty vector per
/// center), it is filled with the points assigned to each center.
/// If pts_norms_squared is provided, the point norms are assumed to be
/// pre-computed and are used as-is.
#[allow(clippy::too_many_arguments)]
pub fn compute_closest_centers(
data: &[f32],
num_points: usize,
dim: usize,
pivot_data: &[f32],
num_centers: usize,
k: usize,
closest_centers_ivf: &mut [u32],
mut inverted_index: Option<&mut Vec<Vec<usize>>>,
pts_norms_squared: Option<&[f32]>,
) -> ANNResult<()> {
if k > num_centers {
return Err(ANNError::log_index_error(format!(
"ERROR: k ({}) > num_centers({})",
k, num_centers
)));
}
let mut pivs_norms_squared = vec![0.0; num_centers];
let mut pts_norms_squared = if let Some(pts_norms) = pts_norms_squared {
pts_norms.to_vec()
} else {
let mut norms_squared = vec![0.0; num_points];
compute_vecs_l2sq(&mut norms_squared, data, num_points, dim);
norms_squared
};
compute_vecs_l2sq(&mut pivs_norms_squared, pivot_data, num_centers, dim);
let par_block_size = num_points;
let n_blocks = if num_points % par_block_size == 0 {
num_points / par_block_size
} else {
num_points / par_block_size + 1
};
let mut closest_centers = vec![0u32; par_block_size * k];
let mut distance_matrix = vec![0.0; num_centers * par_block_size];
for cur_blk in 0..n_blocks {
let data_cur_blk = &data[cur_blk * par_block_size * dim..];
let num_pts_blk = min(par_block_size, num_points - cur_blk * par_block_size);
let pts_norms_blk = &mut pts_norms_squared[cur_blk * par_block_size..];
compute_closest_centers_in_block(
data_cur_blk,
num_pts_blk,
dim,
pivot_data,
num_centers,
pts_norms_blk,
&pivs_norms_squared,
&mut closest_centers,
&mut distance_matrix,
k,
)?;
closest_centers_ivf.clone_from_slice(&closest_centers);
if let Some(inverted_index_inner) = inverted_index.as_mut() {
let inverted_index_arc = Arc::new(Mutex::new(inverted_index_inner));
(0..num_points)
.into_par_iter()
.try_for_each(|j| -> ANNResult<()> {
let this_center_id = closest_centers[j] as usize;
let mut guard = inverted_index_arc.lock().map_err(|err| {
ANNError::log_index_error(format!(
"PoisonError: Lock poisoned when acquiring inverted_index_arc, err={}",
err
))
})?;
guard[this_center_id].push(j);
Ok(())
})?;
}
}
Ok(())
}
/// If to_subtract is true, will subtract nearest center from each row.
/// Else will add.
/// Output will be in data_load itself.
/// Nearest centers need to be provided in closest_centers.
pub fn process_residuals(
data_load: &mut [f32],
num_points: usize,
dim: usize,
cur_pivot_data: &[f32],
num_centers: usize,
closest_centers: &[u32],
to_subtract: bool,
) {
println!(
"Processing residuals of {} points in {} dimensions using {} centers",
num_points, dim, num_centers
);
data_load
.par_chunks_mut(dim)
.enumerate()
.for_each(|(n_iter, chunk)| {
let cur_pivot_index = closest_centers[n_iter] as usize * dim;
for d_iter in 0..dim {
if to_subtract {
chunk[d_iter] -= cur_pivot_data[cur_pivot_index + d_iter];
} else {
chunk[d_iter] += cur_pivot_data[cur_pivot_index + d_iter];
}
}
});
}
#[cfg(test)]
mod math_util_test {
use super::*;
use approx::assert_abs_diff_eq;
#[test]
fn calc_distance_test() {
let vec1 = vec![1.0, 2.0, 3.0];
let vec2 = vec![4.0, 5.0, 6.0];
let dim = vec1.len();
let dist = calc_distance(&vec1, &vec2, dim);
let expected = 27.0;
assert_eq!(dist, expected);
}
#[test]
fn compute_vecs_l2sq_test() {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let num_points = 2;
let dim = 3;
let mut vecs_l2sq = vec![0.0; num_points];
compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim);
let expected = vec![14.0, 77.0];
assert_eq!(vecs_l2sq.len(), num_points);
assert_abs_diff_eq!(vecs_l2sq[0], expected[0], epsilon = 1e-6);
assert_abs_diff_eq!(vecs_l2sq[1], expected[1], epsilon = 1e-6);
}
#[test]
fn compute_closest_centers_in_block_test() {
let num_points = 10;
let dim = 5;
let num_centers = 3;
let data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0,
31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0,
45.0, 46.0, 47.0, 48.0, 49.0, 50.0,
];
let centers = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 31.0, 32.0, 33.0, 34.0, 35.0,
];
let mut docs_l2sq = vec![0.0; num_points];
compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim);
let mut centers_l2sq = vec![0.0; num_centers];
compute_vecs_l2sq(&mut centers_l2sq, &centers, num_centers, dim);
let mut center_index = vec![0; num_points];
let mut dist_matrix = vec![0.0; num_points * num_centers];
let k = 1;
compute_closest_centers_in_block(
&data,
num_points,
dim,
&centers,
num_centers,
&docs_l2sq,
&centers_l2sq,
&mut center_index,
&mut dist_matrix,
k,
)
.unwrap();
assert_eq!(center_index.len(), num_points);
let expected_center_index = vec![0, 0, 0, 1, 1, 1, 2, 2, 2, 2];
assert_abs_diff_eq!(*center_index, expected_center_index);
assert_eq!(dist_matrix.len(), num_points * num_centers);
let expected_dist_matrix = vec![
0.0, 2000.0, 4500.0, 125.0, 1125.0, 3125.0, 500.0, 500.0, 2000.0, 1125.0, 125.0,
1125.0, 2000.0, 0.0, 500.0, 3125.0, 125.0, 125.0, 4500.0, 500.0, 0.0, 6125.0, 1125.0,
125.0, 8000.0, 2000.0, 500.0, 10125.0, 3125.0, 1125.0,
];
assert_abs_diff_eq!(*dist_matrix, expected_dist_matrix, epsilon = 1e-2);
}
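// Regression sketch for the k > 1 branch: with the max-heap ordered by
// distance, the k nearest centers come back nearest-first. Minimal
// one-dimensional example: point 0.0 against centers 10.0, 1.0 and 5.0.
#[test]
fn compute_closest_centers_in_block_top_k_test() {
let data = vec![0.0f32];
let centers = vec![10.0f32, 1.0, 5.0];
let docs_l2sq = vec![0.0f32];
let centers_l2sq = vec![100.0f32, 1.0, 25.0];
let k = 2;
let mut center_index = vec![0u32; k];
let mut dist_matrix = vec![0.0f32; 3];
compute_closest_centers_in_block(
&data,
1,
1,
&centers,
3,
&docs_l2sq,
&centers_l2sq,
&mut center_index,
&mut dist_matrix,
k,
)
.unwrap();
// Nearest center (squared distance 1.0) first, then the second-nearest (25.0).
assert_eq!(center_index, vec![1, 2]);
}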
#[test]
fn test_compute_closest_centers() {
let num_points = 4;
let dim = 3;
let num_centers = 2;
let mut data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
];
let pivot_data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0];
let k = 1;
let mut closest_centers_ivf = vec![0u32; num_points * k];
let mut inverted_index: Vec<Vec<usize>> = vec![vec![], vec![]];
compute_closest_centers(
&data,
num_points,
dim,
&pivot_data,
num_centers,
k,
&mut closest_centers_ivf,
Some(&mut inverted_index),
None,
)
.unwrap();
assert_eq!(closest_centers_ivf, vec![0, 0, 1, 1]);
for vec in inverted_index.iter_mut() {
vec.sort_unstable();
}
assert_eq!(inverted_index, vec![vec![0, 1], vec![2, 3]]);
}
#[test]
fn process_residuals_test() {
let mut data_load = vec![1.0, 2.0, 3.0, 4.0];
let num_points = 2;
let dim = 2;
let cur_pivot_data = vec![0.5, 1.5, 2.5, 3.5];
let num_centers = 2;
let closest_centers = vec![0, 1];
let to_subtract = true;
process_residuals(
&mut data_load,
num_points,
dim,
&cur_pivot_data,
num_centers,
&closest_centers,
to_subtract,
);
assert_eq!(data_load, vec![0.5, 0.5, 0.5, 0.5]);
}
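// Sketch of the algebraic identity behind the three sgemm calls in
// compute_closest_centers_in_block: the squared-distance matrix is assembled
// as ||x||^2 + ||c||^2 - 2 * x.c rather than by subtracting vectors directly.
#[test]
fn sgemm_distance_identity_test() {
let x = [1.0f32, 2.0, 3.0];
let c = [4.0f32, 6.0, 8.0];
let x_l2sq: f32 = x.iter().map(|v| v * v).sum();
let c_l2sq: f32 = c.iter().map(|v| v * v).sum();
let dot: f32 = x.iter().zip(c.iter()).map(|(a, b)| a * b).sum();
let expanded = x_l2sq + c_l2sq - 2.0 * dot;
assert_abs_diff_eq!(expanded, calc_distance(&x, &c, 3), epsilon = 1e-6);
}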
}

View File

@@ -0,0 +1,34 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
pub mod file_util;
pub use file_util::*;
#[allow(clippy::module_inception)]
pub mod utils;
pub use utils::*;
pub mod bit_vec_extension;
pub use bit_vec_extension::*;
pub mod rayon_util;
pub use rayon_util::*;
pub mod timer;
pub use timer::*;
pub mod cached_reader;
pub use cached_reader::*;
pub mod cached_writer;
pub use cached_writer::*;
pub mod partition;
pub use partition::*;
pub mod math_util;
pub use math_util::*;
pub mod kmeans;
pub use kmeans::*;

View File

@@ -0,0 +1,151 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::mem;
use std::{fs::File, path::Path};
use std::io::{Write, Seek, SeekFrom};
use rand::distributions::{Distribution, Uniform};
use crate::common::ANNResult;
use super::CachedReader;
/// Stream data from the file and sample each vector with probability p_val,
/// returning the sampled vectors as f32 in row-major order.
/// # Arguments
/// * `data_file` - file containing the data
/// * `p_val` - probability with which each vector is sampled (clamped to 1.0)
/// # Returns
/// (sampled_vectors, slice_size, dim): the sampled vectors as a row-major
/// slice_size * dim matrix, the number of sampled vectors, and the dimension.
pub fn gen_random_slice<T: Default + Copy + Into<f32>>(data_file: &str, mut p_val: f64) -> ANNResult<(Vec<f32>, usize, usize)> {
let read_blk_size = 64 * 1024 * 1024;
let mut reader = CachedReader::new(data_file, read_blk_size)?;
let npts = reader.read_u32()? as usize;
let dim = reader.read_u32()? as usize;
let mut sampled_vectors: Vec<f32> = Vec::new();
let mut slice_size = 0;
p_val = if p_val < 1f64 { p_val } else { 1f64 };
let mut generator = rand::thread_rng();
let distribution = Uniform::from(0.0..1.0);
for _ in 0..npts {
let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::<T>()];
reader.read(&mut cur_vector_bytes)?;
let random_value = distribution.sample(&mut generator);
if random_value < p_val {
let ptr = cur_vector_bytes.as_ptr() as *const T;
let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) };
sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into()));
slice_size += 1;
}
}
Ok((sampled_vectors, slice_size, dim))
}
/// Randomly sample vectors from data_file at sampling_rate and write the
/// sampled vectors and their ids to "{output_file}_data.bin" and
/// "{output_file}_ids.bin"
pub fn gen_sample_data<T>(data_file: &str, output_file: &str, sampling_rate: f64) -> ANNResult<()> {
let read_blk_size = 64 * 1024 * 1024;
let mut reader = CachedReader::new(data_file, read_blk_size)?;
let sample_data_path = format!("{}_data.bin", output_file);
let sample_ids_path = format!("{}_ids.bin", output_file);
let mut sample_data_writer = File::create(Path::new(&sample_data_path))?;
let mut sample_id_writer = File::create(Path::new(&sample_ids_path))?;
let mut num_sampled_pts = 0u32;
let one_const = 1u32;
let mut generator = rand::thread_rng();
let distribution = Uniform::from(0.0..1.0);
let npts_u32 = reader.read_u32()?;
let dim_u32 = reader.read_u32()?;
let dim = dim_u32 as usize;
sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?;
sample_data_writer.write_all(&dim_u32.to_le_bytes())?;
sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?;
sample_id_writer.write_all(&one_const.to_le_bytes())?;
for id in 0..npts_u32 {
let mut cur_row_bytes = vec![0u8; dim * mem::size_of::<T>()];
reader.read(&mut cur_row_bytes)?;
let random_value = distribution.sample(&mut generator);
if random_value < sampling_rate {
sample_data_writer.write_all(&cur_row_bytes)?;
sample_id_writer.write_all(&id.to_le_bytes())?;
num_sampled_pts += 1;
}
}
sample_data_writer.seek(SeekFrom::Start(0))?;
sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?;
sample_id_writer.seek(SeekFrom::Start(0))?;
sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?;
println!("Wrote {} points to sample file: {}", num_sampled_pts, sample_data_path);
Ok(())
}
#[cfg(test)]
mod partition_test {
use std::{fs, io::Read};
use byteorder::{ReadBytesExt, LittleEndian};
use crate::utils::file_exists;
use super::*;
#[test]
fn gen_sample_data_test() {
let file_name = "gen_sample_data_test.bin";
//npoints=2, dim=8
let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0,
0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40,
0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41,
0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41,
0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41];
std::fs::write(file_name, data).expect("Failed to write sample file");
let sample_file_prefix = file_name.to_string() + "_sample";
gen_sample_data::<f32>(file_name, sample_file_prefix.as_str(), 1f64).unwrap();
let sample_data_path = format!("{}_data.bin", sample_file_prefix);
let sample_ids_path = format!("{}_ids.bin", sample_file_prefix);
assert!(file_exists(sample_data_path.as_str()));
assert!(file_exists(sample_ids_path.as_str()));
let mut data_file_reader = File::open(sample_data_path.as_str()).unwrap();
let mut ids_file_reader = File::open(sample_ids_path.as_str()).unwrap();
let mut num_sampled_pts = data_file_reader.read_u32::<LittleEndian>().unwrap();
assert_eq!(num_sampled_pts, 2);
num_sampled_pts = ids_file_reader.read_u32::<LittleEndian>().unwrap();
assert_eq!(num_sampled_pts, 2);
let dim = data_file_reader.read_u32::<LittleEndian>().unwrap() as usize;
assert_eq!(dim, 8);
assert_eq!(ids_file_reader.read_u32::<LittleEndian>().unwrap(), 1);
let mut start = 8;
for i in 0..num_sampled_pts {
let mut data_bytes = vec![0u8; dim * 4];
data_file_reader.read_exact(&mut data_bytes).unwrap();
assert_eq!(data_bytes, data[start..start + dim * 4]);
let id = ids_file_reader.read_u32::<LittleEndian>().unwrap();
assert_eq!(id, i);
start += dim * 4;
}
fs::remove_file(file_name).expect("Failed to delete file");
fs::remove_file(sample_data_path.as_str()).expect("Failed to delete file");
fs::remove_file(sample_ids_path.as_str()).expect("Failed to delete file");
}
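// Hedged sketch of gen_random_slice on the same kind of fixture: with
// p_val = 1.0 every vector is sampled, so the returned slice covers the
// whole file (npoints=2, dim=2 here).
#[test]
fn gen_random_slice_test() {
let file_name = "gen_random_slice_test.bin";
//npoints=2, dim=2, vectors (1.0, 2.0) and (3.0, 4.0) as little-endian f32
let data: [u8; 24] = [2, 0, 0, 0, 2, 0, 0, 0,
0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40,
0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40];
std::fs::write(file_name, data).expect("Failed to write sample file");
let (sampled_vectors, slice_size, dim) =
gen_random_slice::<f32>(file_name, 1f64).unwrap();
assert_eq!(slice_size, 2);
assert_eq!(dim, 2);
assert_eq!(sampled_vectors, vec![1.0f32, 2.0, 3.0, 4.0]);
fs::remove_file(file_name).expect("Failed to delete file");
}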
}

View File

@@ -0,0 +1,33 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::ops::Range;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::common::ANNResult;
/// Based on num_threads, execute the task either serially or in parallel using Rayon
#[inline]
pub fn execute_with_rayon<F>(range: Range<usize>, num_threads: u32, f: F) -> ANNResult<()>
where F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy
{
if num_threads == 1 {
for i in range {
f(i)?;
}
Ok(())
} else {
range.into_par_iter().try_for_each(f)
}
}
/// Set the Rayon thread count via the RAYON_NUM_THREADS environment variable;
/// otherwise Rayon uses as many threads as there are logical cores.
/// Takes effect only if called before Rayon's global thread pool is first used.
#[inline]
pub fn set_rayon_num_threads(num_threads: u32) {
std::env::set_var(
"RAYON_NUM_THREADS",
num_threads.to_string(),
);
}
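#[cfg(test)]
mod rayon_util_test {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
// Minimal sketch of execute_with_rayon: the closure runs once per index
// whether the serial (num_threads == 1) or the Rayon path is taken.
#[test]
fn execute_with_rayon_test() {
let counter = AtomicUsize::new(0);
let task = |_i: usize| -> ANNResult<()> {
counter.fetch_add(1, Ordering::Relaxed);
Ok(())
};
execute_with_rayon(0..100, 1, task).unwrap();
assert_eq!(counter.load(Ordering::Relaxed), 100);
execute_with_rayon(0..100, 4, task).unwrap();
assert_eq!(counter.load(Ordering::Relaxed), 200);
}
}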

View File

@@ -0,0 +1,101 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use platform::*;
use std::time::{Duration, Instant};
#[derive(Clone)]
pub struct Timer {
check_point: Instant,
pid: Option<usize>,
cycles: Option<u64>,
}
impl Default for Timer {
fn default() -> Self {
Self::new()
}
}
impl Timer {
pub fn new() -> Timer {
let pid = get_process_handle();
let cycles = get_process_cycle_time(pid);
Timer {
check_point: Instant::now(),
pid,
cycles,
}
}
pub fn reset(&mut self) {
self.check_point = Instant::now();
self.cycles = get_process_cycle_time(self.pid);
}
pub fn elapsed(&self) -> Duration {
Instant::now().duration_since(self.check_point)
}
pub fn elapsed_seconds(&self) -> f64 {
self.elapsed().as_secs_f64()
}
pub fn elapsed_gcycles(&self) -> f32 {
let cur_cycles = get_process_cycle_time(self.pid);
if let (Some(cur_cycles), Some(cycles)) = (cur_cycles, self.cycles) {
let spent_cycles =
((cur_cycles - cycles) as f64 * 1.0f64) / (1024 * 1024 * 1024) as f64;
return spent_cycles as f32;
}
0.0
}
pub fn elapsed_seconds_for_step(&self, step: &str) -> String {
format!(
"Time for {}: {:.3} seconds, {:.3}B cycles",
step,
self.elapsed_seconds(),
self.elapsed_gcycles()
)
}
}
#[cfg(test)]
mod timer_tests {
use super::*;
use std::{thread, time};
#[test]
fn test_new() {
let timer = Timer::new();
assert!(timer.check_point.elapsed().as_secs() < 1);
if cfg!(windows) {
assert!(timer.pid.is_some());
assert!(timer.cycles.is_some());
}
else {
assert!(timer.pid.is_none());
assert!(timer.cycles.is_none());
}
}
#[test]
fn test_reset() {
let mut timer = Timer::new();
thread::sleep(time::Duration::from_millis(100));
timer.reset();
assert!(timer.check_point.elapsed().as_millis() < 10);
}
#[test]
fn test_elapsed() {
let timer = Timer::new();
thread::sleep(time::Duration::from_millis(100));
assert!(timer.elapsed().as_millis() >= 100);
assert!(timer.elapsed_seconds() >= 0.1);
}
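// Sketch of elapsed_seconds_for_step: it renders the step name together with
// elapsed seconds and giga-cycles in a fixed human-readable format.
#[test]
fn test_elapsed_seconds_for_step() {
let timer = Timer::new();
thread::sleep(time::Duration::from_millis(10));
let message = timer.elapsed_seconds_for_step("test step");
assert!(message.starts_with("Time for test step:"));
}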
}

View File

@@ -0,0 +1,154 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::sync::Mutex;
use num_traits::Num;
/// Non recursive mutex
pub type NonRecursiveMutex = Mutex<()>;
/// Round up X to the nearest multiple of Y
#[inline]
pub fn round_up<T>(x: T, y: T) -> T
where T : Num + Copy
{
div_round_up(x, y) * y
}
/// Rounded-up division
#[inline]
pub fn div_round_up<T>(x: T, y: T) -> T
where T : Num + Copy
{
(x / y) + if x % y != T::zero() {T::one()} else {T::zero()}
}
/// Round down X to the nearest multiple of Y
#[inline]
pub fn round_down<T>(x: T, y: T) -> T
where T : Num + Copy
{
(x / y) * y
}
/// Is aligned
#[inline]
pub fn is_aligned<T>(x: T, y: T) -> bool
where T : Num + Copy
{
x % y == T::zero()
}
#[inline]
pub fn is_512_aligned(x: u64) -> bool {
is_aligned(x, 512)
}
#[inline]
pub fn is_4096_aligned(x: u64) -> bool {
is_aligned(x, 4096)
}
/// all metadata of individual sub-component files is written in first 4KB for unified files
pub const METADATA_SIZE: usize = 4096;
pub const BUFFER_SIZE_FOR_CACHED_IO: usize = 1024 * 1048576;
pub const PBSTR: &str = "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||";
pub const PBWIDTH: usize = 60;
macro_rules! convert_types {
($name:ident, $input_type:ty, $output_type:ty) => {
/// Convert a row-major npts * dim matrix from the input element type to the output element type
pub fn $name(srcmat: &[$input_type], npts: usize, dim: usize) -> Vec<$output_type> {
let mut destmat: Vec<$output_type> = Vec::new();
for i in 0..npts {
for j in 0..dim {
destmat.push(srcmat[i * dim + j] as $output_type);
}
}
destmat
}
};
}
convert_types!(convert_types_usize_u8, usize, u8);
convert_types!(convert_types_usize_u32, usize, u32);
convert_types!(convert_types_usize_u64, usize, u64);
convert_types!(convert_types_u64_usize, u64, usize);
convert_types!(convert_types_u32_usize, u32, usize);
#[cfg(test)]
mod file_util_test {
use super::*;
use std::any::type_name;
#[test]
fn round_up_test() {
assert_eq!(round_up(252, 8), 256);
assert_eq!(round_up(256, 8), 256);
}
#[test]
fn div_round_up_test() {
assert_eq!(div_round_up(252, 8), 32);
assert_eq!(div_round_up(256, 8), 32);
}
#[test]
fn round_down_test() {
assert_eq!(round_down(252, 8), 248);
assert_eq!(round_down(256, 8), 256);
}
#[test]
fn is_aligned_test() {
assert!(!is_aligned(252, 8));
assert!(is_aligned(256, 8));
}
#[test]
fn is_512_aligned_test() {
assert!(!is_512_aligned(520));
assert!(is_512_aligned(512));
}
#[test]
fn is_4096_aligned_test() {
assert!(!is_4096_aligned(4090));
assert!(is_4096_aligned(4096));
}
#[test]
fn convert_types_test() {
let data = vec![0u64, 1u64, 2u64];
let output = convert_types_u64_usize(&data, 3, 1);
assert_eq!(output.len(), 3);
assert_eq!(type_of(output[0]), "usize");
assert_eq!(output[0], 0usize);
let data = vec![0usize, 1usize, 2usize];
let output = convert_types_usize_u8(&data, 3, 1);
assert_eq!(output.len(), 3);
assert_eq!(type_of(output[0]), "u8");
assert_eq!(output[0], 0u8);
let data = vec![0usize, 1usize, 2usize];
let output = convert_types_usize_u64(&data, 3, 1);
assert_eq!(output.len(), 3);
assert_eq!(type_of(output[0]), "u64");
assert_eq!(output[0], 0u64);
let data = vec![0u32, 1u32, 2u32];
let output = convert_types_u32_usize(&data, 3, 1);
assert_eq!(output.len(), 3);
assert_eq!(type_of(output[0]), "usize");
assert_eq!(output[0],0usize);
}
fn type_of<T>(_: T) -> &'static str {
type_name::<T>()
}
}