Initial commit
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "build_and_insert_delete_memory_index"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diskann = { path = "../../diskann" }
logger = { path = "../../logger" }
vector = { path = "../../vector" }
@@ -0,0 +1,420 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use std::env;
|
||||
|
||||
use diskann::{
|
||||
common::{ANNError, ANNResult},
|
||||
index::create_inmem_index,
|
||||
model::{
|
||||
configuration::index_write_parameters::IndexWriteParametersBuilder,
|
||||
vertex::{DIM_104, DIM_128, DIM_256},
|
||||
IndexConfiguration,
|
||||
},
|
||||
utils::round_up,
|
||||
utils::{file_exists, load_ids_to_delete_from_file, load_metadata_from_file, Timer},
|
||||
};
|
||||
|
||||
use vector::{FullPrecisionDistance, Half, Metric};
|
||||
|
||||
// Build an in-memory index from data_path, insert vectors from delta_path, and optionally soft-delete the vertex IDs listed in delete_path
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_and_insert_delete_in_memory_index<T>(
|
||||
metric: Metric,
|
||||
data_path: &str,
|
||||
delta_path: &str,
|
||||
r: u32,
|
||||
l: u32,
|
||||
alpha: f32,
|
||||
save_path: &str,
|
||||
num_threads: u32,
|
||||
_use_pq_build: bool,
|
||||
_num_pq_bytes: usize,
|
||||
use_opq: bool,
|
||||
delete_path: &str,
|
||||
) -> ANNResult<()>
|
||||
where
|
||||
T: Default + Copy + Sync + Send + Into<f32>,
|
||||
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
|
||||
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
|
||||
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
|
||||
{
|
||||
let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
let (data_num, data_dim) = load_metadata_from_file(data_path)?;
|
||||
|
||||
let config = IndexConfiguration::new(
|
||||
metric,
|
||||
data_dim,
|
||||
round_up(data_dim as u64, 8_u64) as usize,
|
||||
data_num,
|
||||
false,
|
||||
0,
|
||||
use_opq,
|
||||
0,
|
||||
2.0f32,
|
||||
index_write_parameters,
|
||||
);
|
||||
let mut index = create_inmem_index::<T>(config)?;
|
||||
|
||||
let timer = Timer::new();
|
||||
|
||||
index.build(data_path, data_num)?;
|
||||
|
||||
let diff = timer.elapsed();
|
||||
|
||||
println!("Initial indexing time: {}", diff.as_secs_f64());
|
||||
|
||||
let (delta_data_num, _) = load_metadata_from_file(delta_path)?;
|
||||
|
||||
index.insert(delta_path, delta_data_num)?;
|
||||
|
||||
if !delete_path.is_empty() {
|
||||
if !file_exists(delete_path) {
|
||||
return Err(ANNError::log_index_error(format!(
|
||||
"ERROR: Data file for delete {} does not exist.",
|
||||
delete_path
|
||||
)));
|
||||
}
|
||||
|
||||
let (num_points_to_delete, vertex_ids_to_delete) =
|
||||
load_ids_to_delete_from_file(delete_path)?;
|
||||
index.soft_delete(vertex_ids_to_delete, num_points_to_delete)?;
|
||||
}
|
||||
|
||||
index.save(save_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> ANNResult<()> {
|
||||
let mut data_type = String::new();
|
||||
let mut dist_fn = String::new();
|
||||
let mut data_path = String::new();
|
||||
let mut insert_path = String::new();
|
||||
let mut index_path_prefix = String::new();
|
||||
let mut delete_path = String::new();
|
||||
|
||||
let mut num_threads = 0u32;
|
||||
let mut r = 64u32;
|
||||
let mut l = 100u32;
|
||||
|
||||
let mut alpha = 1.2f32;
|
||||
let mut build_pq_bytes = 0u32;
|
||||
let mut _use_pq_build = false;
|
||||
let mut use_opq = false;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let mut iter = args.iter().skip(1).peekable();
|
||||
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--help" | "-h" => {
|
||||
print_help();
|
||||
return Ok(());
|
||||
}
|
||||
"--data_type" => {
|
||||
data_type = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"data_type".to_string(),
|
||||
"Missing data type".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--dist_fn" => {
|
||||
dist_fn = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"dist_fn".to_string(),
|
||||
"Missing distance function".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--data_path" => {
|
||||
data_path = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"data_path".to_string(),
|
||||
"Missing data path".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--insert_path" => {
|
||||
insert_path = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"insert_path".to_string(),
|
||||
"Missing insert path".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--index_path_prefix" => {
|
||||
index_path_prefix = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"index_path_prefix".to_string(),
|
||||
"Missing index path prefix".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--max_degree" | "-R" => {
|
||||
r = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
"Missing max degree".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--Lbuild" | "-L" => {
|
||||
l = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
"Missing build complexity".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--alpha" => {
|
||||
alpha = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"alpha".to_string(),
|
||||
"Missing alpha".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"alpha".to_string(),
|
||||
format!("ParseFloatError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--num_threads" | "-T" => {
|
||||
num_threads = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
"Missing number of threads".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--build_PQ_bytes" => {
|
||||
build_pq_bytes = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
"Missing PQ bytes".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--use_opq" => {
|
||||
use_opq = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
"Missing use_opq flag".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
format!("ParseBoolError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--delete_path" => {
|
||||
delete_path = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"delete_path".to_string(),
|
||||
"Missing delete_path".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
_ => {
|
||||
return Err(ANNError::log_index_config_error(
|
||||
String::from(""),
|
||||
format!("Unknown argument: {}", arg),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if data_type.is_empty()
|
||||
|| dist_fn.is_empty()
|
||||
|| data_path.is_empty()
|
||||
|| index_path_prefix.is_empty()
|
||||
{
|
||||
return Err(ANNError::log_index_config_error(
|
||||
String::from(""),
|
||||
"Missing required arguments".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
_use_pq_build = build_pq_bytes > 0;
|
||||
|
||||
let metric = dist_fn
|
||||
.parse::<Metric>()
|
||||
.map_err(|err| ANNError::log_index_config_error("dist_fn".to_string(), err.to_string()))?;
|
||||
|
||||
println!(
|
||||
"Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}",
|
||||
r, l, alpha, num_threads
|
||||
);
|
||||
|
||||
match data_type.as_str() {
|
||||
"int8" => {
|
||||
build_and_insert_delete_in_memory_index::<i8>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
&delete_path,
|
||||
)?;
|
||||
}
|
||||
"uint8" => {
|
||||
build_and_insert_delete_in_memory_index::<u8>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
&delete_path,
|
||||
)?;
|
||||
}
|
||||
"float" => {
|
||||
build_and_insert_delete_in_memory_index::<f32>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
&delete_path,
|
||||
)?;
|
||||
}
|
||||
"f16" => {
|
||||
build_and_insert_delete_in_memory_index::<Half>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
&delete_path,
|
||||
)?;
|
||||
}
|
||||
_ => {
|
||||
println!("Unsupported type. Use one of int8, uint8, float or f16.");
|
||||
return Err(ANNError::log_index_config_error(
|
||||
"data_type".to_string(),
|
||||
"Invalid data type".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!("Arguments");
|
||||
println!("--help, -h Print information on arguments");
|
||||
println!("--data_type data type <int8/uint8/float/f16> (required)");
|
||||
println!("--dist_fn distance function <l2/cosine> (required)");
|
||||
println!(
|
||||
"--data_path Input data file in bin format for initial build (required)"
|
||||
);
|
||||
println!("--insert_path Input data file in bin format for insert (required)");
|
||||
println!("--index_path_prefix Path prefix for saving index file components (required)");
|
||||
println!("--max_degree, -R Maximum graph degree (default: 64)");
|
||||
println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)");
|
||||
println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)");
|
||||
println!("--num_threads, -T Number of threads used for building index (defaults to the number of logical CPU cores)");
|
||||
println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)");
|
||||
println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "build_and_insert_memory_index"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diskann = { path = "../../diskann" }
logger = { path = "../../logger" }
vector = { path = "../../vector" }
@@ -0,0 +1,382 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use std::env;
|
||||
|
||||
use diskann::{
|
||||
common::{ANNResult, ANNError},
|
||||
index::create_inmem_index,
|
||||
utils::round_up,
|
||||
model::{
|
||||
IndexWriteParametersBuilder,
|
||||
IndexConfiguration,
|
||||
vertex::{DIM_128, DIM_256, DIM_104}
|
||||
},
|
||||
utils::{load_metadata_from_file, Timer},
|
||||
};
|
||||
|
||||
use vector::{Metric, FullPrecisionDistance, Half};
|
||||
|
||||
// Build an in-memory index from data_path and insert additional vectors from delta_path
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_and_insert_in_memory_index<T> (
|
||||
metric: Metric,
|
||||
data_path: &str,
|
||||
delta_path: &str,
|
||||
r: u32,
|
||||
l: u32,
|
||||
alpha: f32,
|
||||
save_path: &str,
|
||||
num_threads: u32,
|
||||
_use_pq_build: bool,
|
||||
_num_pq_bytes: usize,
|
||||
use_opq: bool
|
||||
) -> ANNResult<()>
|
||||
where
|
||||
T: Default + Copy + Sync + Send + Into<f32>,
|
||||
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
|
||||
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
|
||||
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>
|
||||
{
|
||||
let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
let (data_num, data_dim) = load_metadata_from_file(data_path)?;
|
||||
|
||||
let config = IndexConfiguration::new(
|
||||
metric,
|
||||
data_dim,
|
||||
round_up(data_dim as u64, 8_u64) as usize,
|
||||
data_num,
|
||||
false,
|
||||
0,
|
||||
use_opq,
|
||||
0,
|
||||
2.0f32,
|
||||
index_write_parameters,
|
||||
);
|
||||
let mut index = create_inmem_index::<T>(config)?;
|
||||
|
||||
let timer = Timer::new();
|
||||
|
||||
index.build(data_path, data_num)?;
|
||||
|
||||
let diff = timer.elapsed();
|
||||
|
||||
println!("Initial indexing time: {}", diff.as_secs_f64());
|
||||
|
||||
let (delta_data_num, _) = load_metadata_from_file(delta_path)?;
|
||||
|
||||
index.insert(delta_path, delta_data_num)?;
|
||||
|
||||
index.save(save_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> ANNResult<()> {
|
||||
let mut data_type = String::new();
|
||||
let mut dist_fn = String::new();
|
||||
let mut data_path = String::new();
|
||||
let mut insert_path = String::new();
|
||||
let mut index_path_prefix = String::new();
|
||||
|
||||
let mut num_threads = 0u32;
|
||||
let mut r = 64u32;
|
||||
let mut l = 100u32;
|
||||
|
||||
let mut alpha = 1.2f32;
|
||||
let mut build_pq_bytes = 0u32;
|
||||
let mut _use_pq_build = false;
|
||||
let mut use_opq = false;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let mut iter = args.iter().skip(1).peekable();
|
||||
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--help" | "-h" => {
|
||||
print_help();
|
||||
return Ok(());
|
||||
}
|
||||
"--data_type" => {
|
||||
data_type = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"data_type".to_string(),
|
||||
"Missing data type".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--dist_fn" => {
|
||||
dist_fn = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"dist_fn".to_string(),
|
||||
"Missing distance function".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--data_path" => {
|
||||
data_path = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"data_path".to_string(),
|
||||
"Missing data path".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--insert_path" => {
|
||||
insert_path = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"insert_path".to_string(),
|
||||
"Missing insert path".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--index_path_prefix" => {
|
||||
index_path_prefix = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"index_path_prefix".to_string(),
|
||||
"Missing index path prefix".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--max_degree" | "-R" => {
|
||||
r = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
"Missing max degree".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--Lbuild" | "-L" => {
|
||||
l = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
"Missing build complexity".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--alpha" => {
|
||||
alpha = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"alpha".to_string(),
|
||||
"Missing alpha".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"alpha".to_string(),
|
||||
format!("ParseFloatError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--num_threads" | "-T" => {
|
||||
num_threads = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
"Missing number of threads".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--build_PQ_bytes" => {
|
||||
build_pq_bytes = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
"Missing PQ bytes".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--use_opq" => {
|
||||
use_opq = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
"Missing use_opq flag".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
format!("ParseBoolError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ANNError::log_index_config_error(
|
||||
String::from(""),
|
||||
format!("Unknown argument: {}", arg),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if data_type.is_empty()
|
||||
|| dist_fn.is_empty()
|
||||
|| data_path.is_empty()
|
||||
|| index_path_prefix.is_empty()
|
||||
{
|
||||
return Err(ANNError::log_index_config_error(
|
||||
String::from(""),
|
||||
"Missing required arguments".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
_use_pq_build = build_pq_bytes > 0;
|
||||
|
||||
let metric = dist_fn
|
||||
.parse::<Metric>()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"dist_fn".to_string(),
|
||||
err.to_string(),
|
||||
))?;
|
||||
|
||||
println!(
|
||||
"Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}",
|
||||
r, l, alpha, num_threads
|
||||
);
|
||||
|
||||
match data_type.as_str() {
|
||||
"int8" => {
|
||||
build_and_insert_in_memory_index::<i8>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
)?;
|
||||
}
|
||||
"uint8" => {
|
||||
build_and_insert_in_memory_index::<u8>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
)?;
|
||||
}
|
||||
"float" => {
|
||||
build_and_insert_in_memory_index::<f32>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
)?;
|
||||
}
|
||||
"f16" => {
|
||||
build_and_insert_in_memory_index::<Half>(
|
||||
metric,
|
||||
&data_path,
|
||||
&insert_path,
|
||||
r,
|
||||
l,
|
||||
alpha,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
_use_pq_build,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
)?;
|
||||
}
|
||||
_ => {
|
||||
println!("Unsupported type. Use one of int8, uint8, float or f16.");
|
||||
return Err(ANNError::log_index_config_error("data_type".to_string(), "Invalid data type".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!("Arguments");
|
||||
println!("--help, -h Print information on arguments");
|
||||
println!("--data_type data type <int8/uint8/float/f16> (required)");
|
||||
println!("--dist_fn distance function <l2/cosine> (required)");
|
||||
println!("--data_path Input data file in bin format for initial build (required)");
|
||||
println!("--insert_path Input data file in bin format for insert (required)");
|
||||
println!("--index_path_prefix Path prefix for saving index file components (required)");
|
||||
println!("--max_degree, -R Maximum graph degree (default: 64)");
|
||||
println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)");
|
||||
println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)");
|
||||
println!("--num_threads, -T Number of threads used for building index (defaults to the number of logical CPU cores)");
|
||||
println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)");
|
||||
println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)");
|
||||
}
|
||||
|
||||
packages/leann-backend-diskann/third_party/DiskANN/rust/cmd_drivers/build_disk_index/Cargo.toml
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "build_disk_index"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diskann = { path = "../../diskann" }
logger = { path = "../../logger" }
vector = { path = "../../vector" }
openblas-src = { version = "0.10.8", features = ["system", "static"] }
packages/leann-backend-diskann/third_party/DiskANN/rust/cmd_drivers/build_disk_index/src/main.rs
@@ -0,0 +1,377 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use std::env;
|
||||
|
||||
use diskann::{
|
||||
common::{ANNError, ANNResult},
|
||||
index::ann_disk_index::create_disk_index,
|
||||
model::{
|
||||
default_param_vals::ALPHA,
|
||||
vertex::{DIM_104, DIM_128, DIM_256},
|
||||
DiskIndexBuildParameters, IndexConfiguration, IndexWriteParametersBuilder,
|
||||
},
|
||||
storage::DiskIndexStorage,
|
||||
utils::round_up,
|
||||
utils::{load_metadata_from_file, Timer},
|
||||
};
|
||||
|
||||
use vector::{FullPrecisionDistance, Half, Metric};
|
||||
|
||||
/// The main function to build a disk index
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_disk_index<T>(
|
||||
metric: Metric,
|
||||
data_path: &str,
|
||||
r: u32,
|
||||
l: u32,
|
||||
index_path_prefix: &str,
|
||||
num_threads: u32,
|
||||
search_ram_limit_gb: f64,
|
||||
index_build_ram_limit_gb: f64,
|
||||
num_pq_chunks: usize,
|
||||
use_opq: bool,
|
||||
) -> ANNResult<()>
|
||||
where
|
||||
T: Default + Copy + Sync + Send + Into<f32>,
|
||||
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
|
||||
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
|
||||
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
|
||||
{
|
||||
let disk_index_build_parameters =
|
||||
DiskIndexBuildParameters::new(search_ram_limit_gb, index_build_ram_limit_gb)?;
|
||||
|
||||
let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
|
||||
.with_saturate_graph(true)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
let (data_num, data_dim) = load_metadata_from_file(data_path)?;
|
||||
|
||||
let config = IndexConfiguration::new(
|
||||
metric,
|
||||
data_dim,
|
||||
round_up(data_dim as u64, 8_u64) as usize,
|
||||
data_num,
|
||||
num_pq_chunks > 0,
|
||||
num_pq_chunks,
|
||||
use_opq,
|
||||
0,
|
||||
1f32,
|
||||
index_write_parameters,
|
||||
);
|
||||
let storage = DiskIndexStorage::new(data_path.to_string(), index_path_prefix.to_string())?;
|
||||
let mut index = create_disk_index::<T>(Some(disk_index_build_parameters), config, storage)?;
|
||||
|
||||
let timer = Timer::new();
|
||||
|
||||
index.build("")?;
|
||||
|
||||
let diff = timer.elapsed();
|
||||
println!("Indexing time: {}", diff.as_secs_f64());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> ANNResult<()> {
|
||||
let mut data_type = String::new();
|
||||
let mut dist_fn = String::new();
|
||||
let mut data_path = String::new();
|
||||
let mut index_path_prefix = String::new();
|
||||
|
||||
let mut num_threads = 0u32;
|
||||
let mut r = 64u32;
|
||||
let mut l = 100u32;
|
||||
let mut search_ram_limit_gb = 0f64;
|
||||
let mut index_build_ram_limit_gb = 0f64;
|
||||
|
||||
let mut build_pq_bytes = 0u32;
|
||||
let mut use_opq = false;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let mut iter = args.iter().skip(1).peekable();
|
||||
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--help" | "-h" => {
|
||||
print_help();
|
||||
return Ok(());
|
||||
}
|
||||
"--data_type" => {
|
||||
data_type = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"data_type".to_string(),
|
||||
"Missing data type".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--dist_fn" => {
|
||||
dist_fn = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"dist_fn".to_string(),
|
||||
"Missing distance function".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--data_path" => {
|
||||
data_path = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"data_path".to_string(),
|
||||
"Missing data path".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--index_path_prefix" => {
|
||||
index_path_prefix = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"index_path_prefix".to_string(),
|
||||
"Missing index path prefix".to_string(),
|
||||
)
|
||||
})?
|
||||
.to_owned();
|
||||
}
|
||||
"--max_degree" | "-R" => {
|
||||
r = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
"Missing max degree".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--Lbuild" | "-L" => {
|
||||
l = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
"Missing build complexity".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--num_threads" | "-T" => {
|
||||
num_threads = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
"Missing number of threads".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--build_PQ_bytes" => {
|
||||
build_pq_bytes = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
"Missing PQ bytes".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
format!("ParseIntError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--use_opq" => {
|
||||
use_opq = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
"Missing use_opq flag".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
format!("ParseBoolError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--search_DRAM_budget" | "-B" => {
|
||||
search_ram_limit_gb = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"search_DRAM_budget".to_string(),
|
||||
"Missing search_DRAM_budget flag".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"search_DRAM_budget".to_string(),
|
||||
format!("ParseFloatError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
"--build_DRAM_budget" | "-M" => {
|
||||
index_build_ram_limit_gb = iter
|
||||
.next()
|
||||
.ok_or_else(|| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_DRAM_budget".to_string(),
|
||||
"Missing build_DRAM_budget flag".to_string(),
|
||||
)
|
||||
})?
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
ANNError::log_index_config_error(
|
||||
"build_DRAM_budget".to_string(),
|
||||
format!("ParseFloatError: {}", err),
|
||||
)
|
||||
})?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ANNError::log_index_config_error(
|
||||
String::from(""),
|
||||
format!("Unknown argument: {}", arg),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if data_type.is_empty()
|
||||
|| dist_fn.is_empty()
|
||||
|| data_path.is_empty()
|
||||
|| index_path_prefix.is_empty()
|
||||
{
|
||||
return Err(ANNError::log_index_config_error(
|
||||
String::from(""),
|
||||
"Missing required arguments".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let metric = dist_fn
|
||||
.parse::<Metric>()
|
||||
.map_err(|err| ANNError::log_index_config_error("dist_fn".to_string(), err.to_string()))?;
|
||||
|
||||
println!(
|
||||
"Starting index build with R: {} Lbuild: {} alpha: {} #threads: {} search_DRAM_budget: {} build_DRAM_budget: {}",
|
||||
r, l, ALPHA, num_threads, search_ram_limit_gb, index_build_ram_limit_gb
|
||||
);
|
||||
|
||||
let err = match data_type.as_str() {
|
||||
"int8" => build_disk_index::<i8>(
|
||||
metric,
|
||||
&data_path,
|
||||
r,
|
||||
l,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
search_ram_limit_gb,
|
||||
index_build_ram_limit_gb,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
),
|
||||
"uint8" => build_disk_index::<u8>(
|
||||
metric,
|
||||
&data_path,
|
||||
r,
|
||||
l,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
search_ram_limit_gb,
|
||||
index_build_ram_limit_gb,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
),
|
||||
"float" => build_disk_index::<f32>(
|
||||
metric,
|
||||
&data_path,
|
||||
r,
|
||||
l,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
search_ram_limit_gb,
|
||||
index_build_ram_limit_gb,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
),
|
||||
"f16" => build_disk_index::<Half>(
|
||||
metric,
|
||||
&data_path,
|
||||
r,
|
||||
l,
|
||||
&index_path_prefix,
|
||||
num_threads,
|
||||
search_ram_limit_gb,
|
||||
index_build_ram_limit_gb,
|
||||
build_pq_bytes as usize,
|
||||
use_opq,
|
||||
),
|
||||
_ => {
|
||||
println!("Unsupported type. Use one of int8, uint8, float or f16.");
|
||||
return Err(ANNError::log_index_config_error(
|
||||
"data_type".to_string(),
|
||||
"Invalid data type".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match err {
|
||||
Ok(_) => {
|
||||
println!("Index build completed successfully");
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Error: {:?}", err);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn print_help() {
|
||||
println!("Arguments");
|
||||
println!("--help, -h Print information on arguments");
|
||||
println!("--data_type data type <int8/uint8/float/f16> (required)");
|
||||
println!("--dist_fn distance function <l2/cosine> (required)");
|
||||
println!("--data_path Input data file in bin format (required)");
|
||||
println!("--index_path_prefix Path prefix for saving index file components (required)");
|
||||
println!("--max_degree, -R Maximum graph degree (default: 64)");
|
||||
println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)");
|
||||
println!("--search_DRAM_budget Bound on the memory footprint of the index at search time in GB. Once built, the index will use up only the specified RAM limit, the rest will reside on disk");
|
||||
println!("--build_DRAM_budget Limit on the memory allowed for building the index in GB");
|
||||
println!("--num_threads, -T Number of threads used for building index (defaults to the number of logical CPU cores)");
|
||||
println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)");
|
||||
println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)");
|
||||
}
|
||||
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "build_memory_index"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version = "4.3.8", features = ["derive"] }
diskann = { path = "../../diskann" }
logger = { path = "../../logger" }
vector = { path = "../../vector" }
@@ -0,0 +1,62 @@
use clap::{Parser, ValueEnum};

use vector::Metric;

#[derive(Debug, Clone, Copy, ValueEnum)]
enum DataType {
    /// Float data type.
    Float,

    /// Half data type.
    FP16,
}

#[derive(Debug, Clone, Copy, ValueEnum)]
enum DistanceFunction {
    /// Euclidean distance.
    L2,

    /// Cosine distance.
    Cosine,
}

#[derive(Debug, Parser)]
struct BuildMemoryIndexArgs {
    /// Data type of the vectors.
    #[clap(long, default_value = "float")]
    pub data_type: DataType,

    /// Distance function to use.
    #[clap(long, default_value = "l2")]
    pub dist_fn: Metric,

    /// Path to the data file. The file should be in the format specified by the `data_type` argument.
    #[clap(long, short, required = true)]
    pub data_path: String,

    /// Path to the index file. The index will be saved to this prefixed name.
    #[clap(long, short, required = true)]
    pub index_path_prefix: String,

    /// Number of max out degree from a vertex.
    #[clap(long, default_value = "32")]
    pub max_degree: usize,

    /// Number of candidates to consider when building out edges.
    #[clap(long, short, default_value = "50")]
    pub l_build: usize,

    /// Alpha to use to build diverse edges.
    #[clap(long, short, default_value = "1.0")]
    pub alpha: f32,

    /// Number of threads to use.
    #[clap(long, short, default_value = "1")]
    pub num_threads: u8,

    /// Number of PQ bytes to use.
    #[clap(long, short, default_value = "8")]
    pub build_pq_bytes: usize,

    /// Use opq?
    #[clap(long, short, default_value = "false")]
    pub use_opq: bool,
}
packages/leann-backend-diskann/third_party/DiskANN/rust/cmd_drivers/build_memory_index/src/main.rs
@@ -0,0 +1,174 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use diskann::{
|
||||
common::ANNResult,
|
||||
index::create_inmem_index,
|
||||
model::{
|
||||
vertex::{DIM_104, DIM_128, DIM_256},
|
||||
IndexConfiguration, IndexWriteParametersBuilder,
|
||||
},
|
||||
utils::round_up,
|
||||
utils::{load_metadata_from_file, Timer},
|
||||
};
|
||||
|
||||
use vector::{FullPrecisionDistance, Half, Metric};
|
||||
|
||||
/// The main function to build an in-memory index
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn build_in_memory_index<T>(
|
||||
metric: Metric,
|
||||
data_path: &str,
|
||||
r: u32,
|
||||
l: u32,
|
||||
alpha: f32,
|
||||
save_path: &str,
|
||||
num_threads: u32,
|
||||
_use_pq_build: bool,
|
||||
_num_pq_bytes: usize,
|
||||
use_opq: bool,
|
||||
) -> ANNResult<()>
|
||||
where
|
||||
T: Default + Copy + Sync + Send + Into<f32>,
|
||||
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
|
||||
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
|
||||
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
|
||||
{
|
||||
let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
let (data_num, data_dim) = load_metadata_from_file(data_path)?;
|
||||
|
||||
let config = IndexConfiguration::new(
|
||||
metric,
|
||||
data_dim,
|
||||
round_up(data_dim as u64, 8_u64) as usize,
|
||||
data_num,
|
||||
false,
|
||||
0,
|
||||
use_opq,
|
||||
0,
|
||||
1f32,
|
||||
index_write_parameters,
|
||||
);
|
||||
let mut index = create_inmem_index::<T>(config)?;
|
||||
|
||||
let timer = Timer::new();
|
||||
|
||||
index.build(data_path, data_num)?;
|
||||
|
||||
let diff = timer.elapsed();
|
||||
|
||||
println!("Indexing time: {}", diff.as_secs_f64());
|
||||
index.save(save_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> ANNResult<()> {
|
||||
let args = BuildMemoryIndexArgs::parse();
|
||||
|
||||
let _use_pq_build = args.build_pq_bytes > 0;
|
||||
|
||||
println!(
|
||||
"Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}",
|
||||
args.max_degree, args.l_build, args.alpha, args.num_threads
|
||||
);
|
||||
|
||||
let err = match args.data_type {
|
||||
DataType::Float => build_in_memory_index::<f32>(
|
||||
args.dist_fn,
|
||||
&args.data_path.to_string_lossy(),
|
||||
args.max_degree,
|
||||
args.l_build,
|
||||
args.alpha,
|
||||
&args.index_path_prefix,
|
||||
args.num_threads,
|
||||
_use_pq_build,
|
||||
args.build_pq_bytes,
|
||||
args.use_opq,
|
||||
),
|
||||
DataType::FP16 => build_in_memory_index::<Half>(
|
||||
args.dist_fn,
|
||||
&args.data_path.to_string_lossy(),
|
||||
args.max_degree,
|
||||
args.l_build,
|
||||
args.alpha,
|
||||
&args.index_path_prefix,
|
||||
args.num_threads,
|
||||
_use_pq_build,
|
||||
args.build_pq_bytes,
|
||||
args.use_opq,
|
||||
),
|
||||
};
|
||||
|
||||
match err {
|
||||
Ok(_) => {
|
||||
println!("Index build completed successfully");
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Error: {:?}", err);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
|
||||
enum DataType {
|
||||
/// Float data type.
|
||||
Float,
|
||||
|
||||
/// Half data type.
|
||||
FP16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
struct BuildMemoryIndexArgs {
|
||||
/// Data type of the vectors <float/fp16>.
|
||||
#[arg(long = "data_type", default_value = "float")]
|
||||
pub data_type: DataType,
|
||||
|
||||
/// Distance function to use.
|
||||
#[arg(long = "dist_fn", default_value = "l2")]
|
||||
pub dist_fn: Metric,
|
||||
|
||||
/// Path to the data file. The file should be in the format specified by the `data_type` argument.
|
||||
#[arg(long = "data_path", short, required = true)]
|
||||
pub data_path: PathBuf,
|
||||
|
||||
/// Path to the index file. The index will be saved to this prefixed name.
|
||||
#[arg(long = "index_path_prefix", short, required = true)]
|
||||
pub index_path_prefix: String,
|
||||
|
||||
/// Number of max out degree from a vertex.
|
||||
#[arg(long = "max_degree", short = 'R', default_value = "64")]
|
||||
pub max_degree: u32,
|
||||
|
||||
/// Number of candidates to consider when building out edges
|
||||
#[arg(long = "l_build", short = 'L', default_value = "100")]
|
||||
pub l_build: u32,
|
||||
|
||||
/// alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter
|
||||
#[arg(long, short, default_value = "1.2")]
|
||||
pub alpha: f32,
|
||||
|
||||
/// Number of threads to use.
|
||||
#[arg(long = "num_threads", short = 'T', default_value = "1")]
|
||||
pub num_threads: u32,
|
||||
|
||||
/// Number of PQ bytes to build the index; 0 for full precision build
|
||||
#[arg(long = "build_pq_bytes", short, default_value = "0")]
|
||||
pub build_pq_bytes: usize,
|
||||
|
||||
/// Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression
|
||||
#[arg(long = "use_opq", short, default_value = "false")]
|
||||
pub use_opq: bool,
|
||||
}
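
// Illustrative sketch only: shows how the arguments above parse with the clap 4
// derive API used in this file. The binary name and sample paths are hypothetical,
// and nothing else in the file references this test module.
#[cfg(test)]
mod arg_parsing_sketch {
    use super::BuildMemoryIndexArgs;
    use clap::Parser;

    #[test]
    fn parses_required_args_and_falls_back_to_defaults() {
        // Only the two required flags are supplied; every other field takes the
        // default_value declared on the struct above.
        let args = BuildMemoryIndexArgs::parse_from([
            "build_memory_index",
            "--data_path",
            "vectors.fbin",
            "--index_path_prefix",
            "index/prefix",
        ]);
        assert_eq!(args.max_degree, 64);
        assert_eq!(args.l_build, 100);
        assert_eq!(args.num_threads, 1);
        assert!(!args.use_opq);
    }
}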
|
||||
@@ -0,0 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "convert_f32_to_bf16"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
half = "2.2.1"
packages/leann-backend-diskann/third_party/DiskANN/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs
@@ -0,0 +1,154 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use half::{bf16, f16};
|
||||
use std::env;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{self, Read, Write, BufReader, BufWriter};
|
||||
|
||||
enum F16OrBF16 {
|
||||
F16(f16),
|
||||
BF16(bf16),
|
||||
}
|
||||
|
||||
fn main() -> io::Result<()> {
|
||||
// Retrieve command-line arguments
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
match args.len() {
|
||||
3 | 4 | 5 | 6 => {},
|
||||
_ => {
|
||||
print_usage();
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve the input and output file paths from the arguments
|
||||
let input_file_path = &args[1];
|
||||
let output_file_path = &args[2];
|
||||
let use_f16 = args.len() >= 4 && args[3] == "f16";
|
||||
let save_as_float = args.len() >= 5 && args[4] == "save_as_float";
|
||||
let batch_size = if args.len() >= 6 { args[5].parse::<i32>().unwrap() } else { 100000 };
|
||||
println!("use_f16: {}", use_f16);
|
||||
println!("save_as_float: {}", save_as_float);
|
||||
println!("batch_size: {}", batch_size);
|
||||
|
||||
// Open the input file for reading
|
||||
let mut input_file = BufReader::new(File::open(input_file_path)?);
|
||||
|
||||
// Open the output file for writing
|
||||
let mut output_file = BufWriter::new(OpenOptions::new().write(true).create(true).open(output_file_path)?);
|
||||
|
||||
// Read the first 8 bytes as metadata
|
||||
let mut metadata = [0; 8];
|
||||
input_file.read_exact(&mut metadata)?;
|
||||
|
||||
// Write the metadata to the output file
|
||||
output_file.write_all(&metadata)?;
|
||||
|
||||
// Extract the number of points and dimension from the metadata
|
||||
let num_points = i32::from_le_bytes(metadata[..4].try_into().unwrap());
|
||||
let dimension = i32::from_le_bytes(metadata[4..].try_into().unwrap());
|
||||
let num_batches = num_points / batch_size;
|
||||
// Calculate the size of one data point in bytes
|
||||
let data_point_size = (dimension * 4 * batch_size) as usize;
|
||||
let mut batches_processed = 0;
|
||||
let numbers_to_print = 2;
|
||||
let mut numbers_printed = 0;
|
||||
let mut num_fb16_wins = 0;
|
||||
let mut num_f16_wins = 0;
|
||||
let mut bf16_overflow = 0;
|
||||
let mut f16_overflow = 0;
|
||||
|
||||
// Process each data point
|
||||
for _ in 0..num_batches {
|
||||
// Read one data point from the input file
|
||||
let mut buffer = vec![0; data_point_size];
|
||||
match input_file.read_exact(&mut buffer){
|
||||
Ok(()) => {
|
||||
// Convert the float32 data to bf16
|
||||
let half_data: Vec<F16OrBF16> = buffer
|
||||
.chunks_exact(4)
|
||||
.map(|chunk| {
|
||||
let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
|
||||
let converted_bf16 = bf16::from_f32(value);
|
||||
let converted_f16 = f16::from_f32(value);
|
||||
let distance_f16 = (converted_f16.to_f32() - value).abs();
|
||||
let distance_bf16 = (converted_bf16.to_f32() - value).abs();
|
||||
|
||||
if distance_f16 < distance_bf16 {
|
||||
num_f16_wins += 1;
|
||||
} else {
|
||||
num_fb16_wins += 1;
|
||||
}
|
||||
|
||||
if (converted_bf16 == bf16::INFINITY) || (converted_bf16 == bf16::NEG_INFINITY) {
|
||||
bf16_overflow += 1;
|
||||
}
|
||||
|
||||
if (converted_f16 == f16::INFINITY) || (converted_f16 == f16::NEG_INFINITY) {
|
||||
f16_overflow += 1;
|
||||
}
|
||||
|
||||
if numbers_printed < numbers_to_print {
|
||||
numbers_printed += 1;
|
||||
println!("f32 value: {} f16 value: {} | distance {}, bf16 value: {} | distance {},",
|
||||
value, converted_f16, converted_f16.to_f32() - value, converted_bf16, converted_bf16.to_f32() - value);
|
||||
}
|
||||
|
||||
if use_f16 {
|
||||
F16OrBF16::F16(converted_f16)
|
||||
} else {
|
||||
F16OrBF16::BF16(converted_bf16)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
batches_processed += 1;
|
||||
|
||||
match save_as_float {
|
||||
true => {
|
||||
for float_val in half_data {
|
||||
match float_val {
|
||||
F16OrBF16::F16(f16_val) => output_file.write_all(&f16_val.to_f32().to_le_bytes())?,
|
||||
F16OrBF16::BF16(bf16_val) => output_file.write_all(&bf16_val.to_f32().to_le_bytes())?,
|
||||
}
|
||||
}
|
||||
}
|
||||
false => {
|
||||
for float_val in half_data {
|
||||
match float_val {
|
||||
F16OrBF16::F16(f16_val) => output_file.write_all(&f16_val.to_le_bytes())?,
|
||||
F16OrBF16::BF16(bf16_val) => output_file.write_all(&bf16_val.to_le_bytes())?,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Print the number of points processed
|
||||
println!("Processed {} points out of {}", batches_processed * batch_size, num_points);
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => {
|
||||
println!("Conversion completed! f16 won {} times (overflow count {}), bf16 won {} times (overflow count {})",
|
||||
num_f16_wins, f16_overflow, num_fb16_wins, bf16_overflow);
|
||||
break;
|
||||
}
|
||||
Err(err) => {
|
||||
println!("Error: {}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prints the usage information
|
||||
fn print_usage() {
|
||||
println!("Usage: program_name input_file output_file [f16] [save_as_float] [batch_size]");
println!("Specify f16 to downscale to f16; otherwise, the data is downscaled to bf16.");
println!("Specify save_as_float to downcast to f16 or bf16 and then upcast back to float before saving the output data; otherwise, the data is saved as the half type.");
println!("Specify the batch_size as an int; the default value is 100000.");
|
||||
}
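
// Illustrative sketch only: the per-value conversion performed inside the loop
// above, isolated as a helper. It relies solely on the half-crate calls already
// used in this file (f16::from_f32, bf16::from_f32, to_le_bytes); the helper name
// is hypothetical and main() does not call it.
#[allow(dead_code)]
fn convert_f32_le(bytes: [u8; 4], use_f16: bool) -> [u8; 2] {
    // Decode the little-endian f32, narrow it to the requested half-precision
    // format, and re-encode it as two little-endian bytes.
    let value = f32::from_le_bytes(bytes);
    if use_f16 {
        f16::from_f32(value).to_le_bytes()
    } else {
        bf16::from_f32(value).to_le_bytes()
    }
}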
|
||||
|
||||
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "load_and_insert_memory_index"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
diskann = { path = "../../diskann" }
logger = { path = "../../logger" }
vector = { path = "../../vector" }
@@ -0,0 +1,313 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use std::env;
|
||||
|
||||
use diskann::{
|
||||
common::{ANNResult, ANNError},
|
||||
index::create_inmem_index,
|
||||
utils::round_up,
|
||||
model::{
|
||||
IndexWriteParametersBuilder,
|
||||
IndexConfiguration,
|
||||
vertex::{DIM_128, DIM_256, DIM_104}
|
||||
},
|
||||
utils::{Timer, load_metadata_from_file},
|
||||
};
|
||||
|
||||
use vector::{Metric, FullPrecisionDistance, Half};
|
||||
|
||||
// Load an in-memory index from data_path, insert additional vectors from delta_path, and save the result to save_path
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn load_and_insert_in_memory_index<T> (
|
||||
metric: Metric,
|
||||
data_path: &str,
|
||||
delta_path: &str,
|
||||
r: u32,
|
||||
l: u32,
|
||||
alpha: f32,
|
||||
save_path: &str,
|
||||
num_threads: u32,
|
||||
_use_pq_build: bool,
|
||||
_num_pq_bytes: usize,
|
||||
use_opq: bool
|
||||
) -> ANNResult<()>
|
||||
where
|
||||
T: Default + Copy + Sync + Send + Into<f32>,
|
||||
[T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
|
||||
[T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
|
||||
[T; DIM_256]: FullPrecisionDistance<T, DIM_256>
|
||||
{
|
||||
let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
|
||||
.with_alpha(alpha)
|
||||
.with_saturate_graph(false)
|
||||
.with_num_threads(num_threads)
|
||||
.build();
|
||||
|
||||
let (data_num, data_dim) = load_metadata_from_file(&format!("{}.data", data_path))?;
|
||||
|
||||
let config = IndexConfiguration::new(
|
||||
metric,
|
||||
data_dim,
|
||||
round_up(data_dim as u64, 8_u64) as usize,
|
||||
data_num,
|
||||
false,
|
||||
0,
|
||||
use_opq,
|
||||
0,
|
||||
2.0f32,
|
||||
index_write_parameters,
|
||||
);
|
||||
let mut index = create_inmem_index::<T>(config)?;
|
||||
|
||||
let timer = Timer::new();
|
||||
|
||||
index.load(data_path, data_num)?;
|
||||
|
||||
let diff = timer.elapsed();
|
||||
|
||||
println!("Initial indexing time: {}", diff.as_secs_f64());
|
||||
|
||||
let (delta_data_num, _) = load_metadata_from_file(delta_path)?;
|
||||
|
||||
index.insert(delta_path, delta_data_num)?;
|
||||
|
||||
index.save(save_path)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> ANNResult<()> {
|
||||
let mut data_type = String::new();
|
||||
let mut dist_fn = String::new();
|
||||
let mut data_path = String::new();
|
||||
let mut insert_path = String::new();
|
||||
let mut index_path_prefix = String::new();
|
||||
|
||||
let mut num_threads = 0u32;
|
||||
let mut r = 64u32;
|
||||
let mut l = 100u32;
|
||||
|
||||
let mut alpha = 1.2f32;
|
||||
let mut build_pq_bytes = 0u32;
|
||||
let mut _use_pq_build = false;
|
||||
let mut use_opq = false;
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let mut iter = args.iter().skip(1).peekable();
|
||||
|
||||
while let Some(arg) = iter.next() {
|
||||
match arg.as_str() {
|
||||
"--help" | "-h" => {
|
||||
print_help();
|
||||
return Ok(());
|
||||
}
|
||||
"--data_type" => {
|
||||
data_type = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"data_type".to_string(),
|
||||
"Missing data type".to_string())
|
||||
)?
|
||||
.to_owned();
|
||||
}
|
||||
"--dist_fn" => {
|
||||
dist_fn = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"dist_fn".to_string(),
|
||||
"Missing distance function".to_string())
|
||||
)?
|
||||
.to_owned();
|
||||
}
|
||||
"--data_path" => {
|
||||
data_path = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"data_path".to_string(),
|
||||
"Missing data path".to_string())
|
||||
)?
|
||||
.to_owned();
|
||||
}
|
||||
"--insert_path" => {
|
||||
insert_path = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"insert_path".to_string(),
|
||||
"Missing insert path".to_string())
|
||||
)?
|
||||
.to_owned();
|
||||
}
|
||||
"--index_path_prefix" => {
|
||||
index_path_prefix = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"index_path_prefix".to_string(),
|
||||
"Missing index path prefix".to_string()))?
|
||||
.to_owned();
|
||||
}
|
||||
"--max_degree" | "-R" => {
|
||||
r = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
"Missing max degree".to_string()))?
|
||||
.parse()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"max_degree".to_string(),
|
||||
format!("ParseIntError: {}", err))
|
||||
)?;
|
||||
}
|
||||
"--Lbuild" | "-L" => {
|
||||
l = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
"Missing build complexity".to_string()))?
|
||||
.parse()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"Lbuild".to_string(),
|
||||
format!("ParseIntError: {}", err))
|
||||
)?;
|
||||
}
|
||||
"--alpha" => {
|
||||
alpha = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"alpha".to_string(),
|
||||
"Missing alpha".to_string()))?
|
||||
.parse()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"alpha".to_string(),
|
||||
format!("ParseFloatError: {}", err))
|
||||
)?;
|
||||
}
|
||||
"--num_threads" | "-T" => {
|
||||
num_threads = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
"Missing number of threads".to_string()))?
|
||||
.parse()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"num_threads".to_string(),
|
||||
format!("ParseIntError: {}", err))
|
||||
)?;
|
||||
}
|
||||
"--build_PQ_bytes" => {
|
||||
build_pq_bytes = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
"Missing PQ bytes".to_string()))?
|
||||
.parse()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"build_PQ_bytes".to_string(),
|
||||
format!("ParseIntError: {}", err))
|
||||
)?;
|
||||
}
|
||||
"--use_opq" => {
|
||||
use_opq = iter.next().ok_or_else(|| ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
"Missing use_opq flag".to_string()))?
|
||||
.parse()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"use_opq".to_string(),
|
||||
format!("ParseBoolError: {}", err))
|
||||
)?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ANNError::log_index_config_error(String::from(""), format!("Unknown argument: {}", arg)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if data_type.is_empty()
|
||||
|| dist_fn.is_empty()
|
||||
|| data_path.is_empty()
|
||||
|| index_path_prefix.is_empty()
|
||||
{
|
||||
return Err(ANNError::log_index_config_error(String::from(""), "Missing required arguments".to_string()));
|
||||
}
|
||||
|
||||
_use_pq_build = build_pq_bytes > 0;
|
||||
|
||||
let metric = dist_fn
|
||||
.parse::<Metric>()
|
||||
.map_err(|err| ANNError::log_index_config_error(
|
||||
"dist_fn".to_string(),
|
||||
err.to_string(),
|
||||
))?;
|
||||
|
||||
println!(
|
||||
"Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}",
|
||||
r, l, alpha, num_threads
|
||||
);

    match data_type.as_str() {
        "int8" => {
            build_and_insert_delete_in_memory_index::<i8>(
                metric,
                &data_path,
                &insert_path,
                r,
                l,
                alpha,
                &index_path_prefix,
                num_threads,
                _use_pq_build,
                build_pq_bytes as usize,
                use_opq,
                &delete_path,
            )?;
        }
        "uint8" => {
            build_and_insert_delete_in_memory_index::<u8>(
                metric,
                &data_path,
                &insert_path,
                r,
                l,
                alpha,
                &index_path_prefix,
                num_threads,
                _use_pq_build,
                build_pq_bytes as usize,
                use_opq,
                &delete_path,
            )?;
        }
        "float" => {
            build_and_insert_delete_in_memory_index::<f32>(
                metric,
                &data_path,
                &insert_path,
                r,
                l,
                alpha,
                &index_path_prefix,
                num_threads,
                _use_pq_build,
                build_pq_bytes as usize,
                use_opq,
                &delete_path,
            )?;
        }
        "f16" => {
            build_and_insert_delete_in_memory_index::<Half>(
                metric,
                &data_path,
                &insert_path,
                r,
                l,
                alpha,
                &index_path_prefix,
                num_threads,
                _use_pq_build,
                build_pq_bytes as usize,
                use_opq,
                &delete_path,
            )?;
        }
        _ => {
            println!("Unsupported type. Use one of int8, uint8, float or f16.");
            return Err(ANNError::log_index_config_error("data_type".to_string(), "Invalid data type".to_string()));
        }
    }

    Ok(())
}

fn print_help() {
    println!("Arguments");
    println!("--help, -h                Print information on arguments");
    println!("--data_type               data type <int8/uint8/float/f16> (required)");
    println!("--dist_fn                 distance function <l2/cosine> (required)");
    println!("--data_path               Input data file in bin format for initial build (required)");
    println!("--insert_path             Input data file in bin format for insert (required)");
    println!("--index_path_prefix       Path prefix for saving index file components (required)");
    println!("--max_degree, -R          Maximum graph degree (default: 64)");
    println!("--Lbuild, -L              Build complexity, higher value results in better graphs (default: 100)");
    println!("--alpha                   alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)");
    println!("--num_threads, -T         Number of threads used for building index (defaults to number of logical CPU cores)");
    println!("--build_PQ_bytes          Number of PQ bytes to build the index; 0 for full precision build (default: 0)");
    println!("--use_opq                 Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)");
}
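
// A minimal, self-contained sketch of the "--flag value" parsing pattern these drivers use:
// flags are pulled off a plain args iterator, a missing value becomes an error, and numeric
// values are parsed with a mapped error. The flag name and default below are illustrative only
// and are not part of the DiskANN API; errors are plain Strings instead of ANNError.
fn parse_args_sketch() -> Result<(), String> {
    let args: Vec<String> = std::env::args().collect();
    let mut iter = args.iter().skip(1);
    let mut max_degree: u32 = 64; // illustrative default

    while let Some(arg) = iter.next() {
        match arg.as_str() {
            "--max_degree" | "-R" => {
                max_degree = iter
                    .next()
                    .ok_or_else(|| "Missing max degree".to_string())?
                    .parse()
                    .map_err(|err| format!("ParseIntError: {}", err))?;
            }
            _ => return Err(format!("Unknown argument: {}", arg)),
        }
    }

    println!("max_degree = {}", max_degree);
    Ok(())
}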
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "search_memory_index"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bytemuck = "1.13.1"
diskann = { path = "../../diskann" }
num_cpus = "1.15.0"
rayon = "1.7.0"
vector = { path = "../../vector" }

packages/leann-backend-diskann/third_party/DiskANN/rust/cmd_drivers/search_memory_index/src/main.rs (430 lines, vendored, new file)

@@ -0,0 +1,430 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
mod search_index_utils;
use bytemuck::Pod;
use diskann::{
    common::{ANNError, ANNResult},
    index,
    model::{
        configuration::index_write_parameters::{default_param_vals, IndexWriteParametersBuilder},
        vertex::{DIM_104, DIM_128, DIM_256},
        IndexConfiguration,
    },
    utils::{load_metadata_from_file, save_bin_u32},
};
use std::{env, path::Path, process::exit, time::Instant};
use vector::{FullPrecisionDistance, Half, Metric};

use rayon::prelude::*;

#[allow(clippy::too_many_arguments)]
fn search_memory_index<T>(
    metric: Metric,
    index_path: &str,
    result_path_prefix: &str,
    query_file: &str,
    truthset_file: &str,
    num_threads: u32,
    recall_at: u32,
    print_all_recalls: bool,
    l_vec: &Vec<u32>,
    show_qps_per_thread: bool,
    fail_if_recall_below: f32,
) -> ANNResult<i32>
where
    T: Default + Copy + Sized + Pod + Sync + Send + Into<f32>,
    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
{
    // Load the query file
    let (query, query_num, query_dim, query_aligned_dim) =
        search_index_utils::load_aligned_bin::<T>(query_file)?;
    let mut gt_dim: usize = 0;
    let mut gt_ids: Option<Vec<u32>> = None;
    let mut gt_dists: Option<Vec<f32>> = None;

    // Check for ground truth
    let mut calc_recall_flag = false;
    if !truthset_file.is_empty() && Path::new(truthset_file).exists() {
        let ret = search_index_utils::load_truthset(truthset_file)?;
        gt_ids = Some(ret.0);
        gt_dists = ret.1;
        let gt_num = ret.2;
        gt_dim = ret.3;

        if gt_num != query_num {
            println!("Error. Mismatch in number of queries and ground truth data");
        }

        calc_recall_flag = true;
    } else {
        println!(
            "Truthset file {} not found. Not computing recall",
            truthset_file
        );
    }

    let num_frozen_pts = search_index_utils::get_graph_num_frozen_points(index_path)?;

    // C++ uses the max given L value, so we do the same here. Max degree is never specified in C++ so use the rust default
    let index_write_params = IndexWriteParametersBuilder::new(
        *l_vec.iter().max().unwrap(),
        default_param_vals::MAX_DEGREE,
    )
    .with_num_threads(num_threads)
    .build();

    let (index_num_points, _) = load_metadata_from_file(&format!("{}.data", index_path))?;

    let index_config = IndexConfiguration::new(
        metric,
        query_dim,
        query_aligned_dim,
        index_num_points,
        false,
        0,
        false,
        num_frozen_pts,
        1f32,
        index_write_params,
    );
    let mut index = index::create_inmem_index::<T>(index_config)?;

    index.load(index_path, index_num_points)?;

    println!("Using {} threads to search", num_threads);
    let qps_title = if show_qps_per_thread {
        "QPS/thread"
    } else {
        "QPS"
    };
    let mut table_width = 4 + 12 + 18 + 20 + 15;
    let mut table_header_str = format!(
        "{:>4}{:>12}{:>18}{:>20}{:>15}",
        "Ls", qps_title, "Avg dist cmps", "Mean Latency (mus)", "99.9 Latency"
    );

    let first_recall: u32 = if print_all_recalls { 1 } else { recall_at };
    let mut recalls_to_print: usize = 0;
    if calc_recall_flag {
        for curr_recall in first_recall..=recall_at {
            let recall_str = format!("Recall@{}", curr_recall);
            table_header_str.push_str(&format!("{:>12}", recall_str));
        }
        // One extra 12-character column per printed recall position; widen the table once,
        // not once per loop iteration.
        recalls_to_print = (recall_at + 1 - first_recall) as usize;
        table_width += recalls_to_print * 12;
    }

    println!("{}", table_header_str);
    println!("{}", "=".repeat(table_width));

    let mut query_result_ids: Vec<Vec<u32>> =
        vec![vec![0; query_num * recall_at as usize]; l_vec.len()];
    let mut latency_stats: Vec<f32> = vec![0.0; query_num];
    let mut cmp_stats: Vec<u32> = vec![0; query_num];
    let mut best_recall = 0.0;

    std::env::set_var("RAYON_NUM_THREADS", num_threads.to_string());

    for test_id in 0..l_vec.len() {
        let l_value = l_vec[test_id];

        if l_value < recall_at {
            println!(
                "Ignoring search with L:{} since it's smaller than K:{}",
                l_value, recall_at
            );
            continue;
        }

        let zipped = cmp_stats
            .par_iter_mut()
            .zip(latency_stats.par_iter_mut())
            .zip(query_result_ids[test_id].par_chunks_mut(recall_at as usize))
            .zip(query.par_chunks(query_aligned_dim));

        let start = Instant::now();
        zipped.for_each(|(((cmp, latency), query_result), query_chunk)| {
            let query_start = Instant::now();
            *cmp = index
                .search(query_chunk, recall_at as usize, l_value, query_result)
                .unwrap();

            let query_end = Instant::now();
            let diff = query_end.duration_since(query_start);
            *latency = diff.as_micros() as f32;
        });
        let diff = Instant::now().duration_since(start);

        let mut displayed_qps: f32 = query_num as f32 / diff.as_secs_f32();
        if show_qps_per_thread {
            displayed_qps /= num_threads as f32;
        }

        let mut recalls: Vec<f32> = Vec::new();
        if calc_recall_flag {
            recalls.reserve(recalls_to_print);
            for curr_recall in first_recall..=recall_at {
                recalls.push(search_index_utils::calculate_recall(
                    query_num,
                    gt_ids.as_ref().unwrap(),
                    &gt_dists,
                    gt_dim,
                    &query_result_ids[test_id],
                    recall_at,
                    curr_recall,
                )? as f32);
            }
        }

        latency_stats.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mean_latency = latency_stats.iter().sum::<f32>() / query_num as f32;
        let avg_cmps = cmp_stats.iter().sum::<u32>() as f32 / query_num as f32;

        let mut stat_str = format!(
            "{: >4}{: >12.2}{: >18.2}{: >20.2}{: >15.2}",
            l_value,
            displayed_qps,
            avg_cmps,
            mean_latency,
            latency_stats[(0.999 * query_num as f32).round() as usize]
        );

        for recall in recalls.iter() {
            stat_str.push_str(&format!("{: >12.2}", recall));
            best_recall = f32::max(best_recall, *recall);
        }

        println!("{}", stat_str);
    }

    println!("Done searching. Now saving results");
    for (test_id, l_value) in l_vec.iter().enumerate() {
        if *l_value < recall_at {
            println!(
                "Ignoring all search with L: {} since it's smaller than K: {}",
                l_value, recall_at
            );
        }

        let cur_result_path = format!("{}_{}_idx_uint32.bin", result_path_prefix, l_value);
        save_bin_u32(
            &cur_result_path,
            query_result_ids[test_id].as_slice(),
            query_num,
            recall_at as usize,
            0,
        )?;
    }

    if best_recall >= fail_if_recall_below {
        Ok(0)
    } else {
        Ok(-1)
    }
}
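
// The search loop above fans each query out across a rayon thread pool by zipping per-query
// output slots with per-query input chunks. The following is a minimal sketch of that zip shape,
// assuming only a rayon dependency; the squared-norm computation stands in for index.search(),
// and all data here is made up for illustration.
fn rayon_zip_sketch() {
    use rayon::prelude::*;
    use std::time::Instant;

    // Toy data: 4 "queries" of dimension 2, flattened like the aligned query buffer.
    let queries: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let dim = 2;
    let mut results: Vec<f32> = vec![0.0; queries.len() / dim];
    let mut latencies_us: Vec<f32> = vec![0.0; queries.len() / dim];

    // One mutable result cell and one latency cell per query chunk, processed in parallel.
    results
        .par_iter_mut()
        .zip(latencies_us.par_iter_mut())
        .zip(queries.par_chunks(dim))
        .for_each(|((res, latency), q)| {
            let start = Instant::now();
            *res = q.iter().map(|x| x * x).sum(); // stand-in for index.search()
            *latency = start.elapsed().as_micros() as f32;
        });

    println!("results = {:?}", results);
}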

fn main() -> ANNResult<()> {
    let return_val: i32;
    {
        let mut data_type: String = String::new();
        let mut metric: Option<Metric> = None;
        let mut index_path: String = String::new();
        let mut result_path_prefix: String = String::new();
        let mut query_file: String = String::new();
        let mut truthset_file: String = String::new();
        let mut num_cpus: u32 = num_cpus::get() as u32;
        let mut recall_at: Option<u32> = None;
        let mut print_all_recalls: bool = false;
        let mut l_vec: Vec<u32> = Vec::new();
        let mut show_qps_per_thread: bool = false;
        let mut fail_if_recall_below: f32 = 0.0;

        let args: Vec<String> = env::args().collect();
        let mut iter = args.iter().skip(1).peekable();
        while let Some(arg) = iter.next() {
            let ann_error =
                || ANNError::log_index_config_error(String::from(arg), format!("Missing {}", arg));
            match arg.as_str() {
                "--help" | "-h" => {
                    print_help();
                    return Ok(());
                }
                "--data_type" => {
                    data_type = iter.next().ok_or_else(ann_error)?.to_owned();
                }
                "--dist_fn" => {
                    metric = Some(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
                        ANNError::log_index_config_error(
                            String::from(arg),
                            format!("ParseError: {}", err),
                        )
                    })?);
                }
                "--index_path_prefix" => {
                    index_path = iter.next().ok_or_else(ann_error)?.to_owned();
                }
                "--result_path" => {
                    result_path_prefix = iter.next().ok_or_else(ann_error)?.to_owned();
                }
                "--query_file" => {
                    query_file = iter.next().ok_or_else(ann_error)?.to_owned();
                }
                "--gt_file" => {
                    truthset_file = iter.next().ok_or_else(ann_error)?.to_owned();
                }
                "--recall_at" | "-K" => {
                    recall_at =
                        Some(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
                            ANNError::log_index_config_error(
                                String::from(arg),
                                format!("ParseError: {}", err),
                            )
                        })?);
                }
                "--print_all_recalls" => {
                    print_all_recalls = true;
                }
                "--search_list" | "-L" => {
                    while iter.peek().is_some() && !iter.peek().unwrap().starts_with('-') {
                        l_vec.push(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
                            ANNError::log_index_config_error(
                                String::from(arg),
                                format!("ParseError: {}", err),
                            )
                        })?);
                    }
                }
                "--num_threads" => {
                    num_cpus = iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
                        ANNError::log_index_config_error(
                            String::from(arg),
                            format!("ParseError: {}", err),
                        )
                    })?;
                }
                "--qps_per_thread" => {
                    show_qps_per_thread = true;
                }
                "--fail_if_recall_below" => {
                    fail_if_recall_below =
                        iter.next().ok_or_else(ann_error)?.parse().map_err(|err| {
                            ANNError::log_index_config_error(
                                String::from(arg),
                                format!("ParseError: {}", err),
                            )
                        })?;
                }
                _ => {
                    return Err(ANNError::log_index_error(format!(
                        "Unknown argument: {}",
                        arg
                    )));
                }
            }
        }

        if metric.is_none() {
            return Err(ANNError::log_index_error(String::from("No metric given!")));
        } else if recall_at.is_none() {
            return Err(ANNError::log_index_error(String::from(
                "No recall_at given!",
            )));
        }

        // Seems like float is the only supported data type for FullPrecisionDistance right now,
        // but keep the structure in place here for future data types
        match data_type.as_str() {
            "float" => {
                return_val = search_memory_index::<f32>(
                    metric.unwrap(),
                    &index_path,
                    &result_path_prefix,
                    &query_file,
                    &truthset_file,
                    num_cpus,
                    recall_at.unwrap(),
                    print_all_recalls,
                    &l_vec,
                    show_qps_per_thread,
                    fail_if_recall_below,
                )?;
            }
            "int8" => {
                return_val = search_memory_index::<i8>(
                    metric.unwrap(),
                    &index_path,
                    &result_path_prefix,
                    &query_file,
                    &truthset_file,
                    num_cpus,
                    recall_at.unwrap(),
                    print_all_recalls,
                    &l_vec,
                    show_qps_per_thread,
                    fail_if_recall_below,
                )?;
            }
            "uint8" => {
                return_val = search_memory_index::<u8>(
                    metric.unwrap(),
                    &index_path,
                    &result_path_prefix,
                    &query_file,
                    &truthset_file,
                    num_cpus,
                    recall_at.unwrap(),
                    print_all_recalls,
                    &l_vec,
                    show_qps_per_thread,
                    fail_if_recall_below,
                )?;
            }
            "f16" => {
                return_val = search_memory_index::<Half>(
                    metric.unwrap(),
                    &index_path,
                    &result_path_prefix,
                    &query_file,
                    &truthset_file,
                    num_cpus,
                    recall_at.unwrap(),
                    print_all_recalls,
                    &l_vec,
                    show_qps_per_thread,
                    fail_if_recall_below,
                )?;
            }
            _ => {
                return Err(ANNError::log_index_error(format!(
                    "Unknown data type: {}!",
                    data_type
                )));
            }
        }
    }

    // exit() is the only way to return a custom process exit code here, but it terminates the
    // program immediately without running destructors still on the stack. To get around this, the
    // main logic is enclosed in the block above so that all destructors have already run by the
    // time we reach this point.
    exit(return_val);
}
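
// A minimal standalone sketch of the block-before-exit pattern described in the comment above,
// using only the standard library: the exit code is computed inside an inner scope so every value
// is dropped before process::exit() is reached. The Guard type is purely illustrative.
mod exit_pattern_sketch {
    use std::process::exit;

    struct Guard;
    impl Drop for Guard {
        fn drop(&mut self) {
            println!("Guard dropped"); // runs because the inner block ends before exit()
        }
    }

    pub fn run() -> ! {
        let code: i32;
        {
            let _guard = Guard;
            code = 0; // stand-in for the real search result
        } // _guard is dropped here, before the process terminates
        exit(code);
    }
}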

fn print_help() {
    println!("Arguments");
    println!("--help, -h                  Print information on arguments");
    println!("--data_type                 data type <int8/uint8/float/f16> (required)");
    println!("--dist_fn                   distance function <l2/cosine> (required)");
    println!("--index_path_prefix         Path prefix to the index (required)");
    println!("--result_path               Path prefix for saving results of the queries (required)");
    println!("--query_file                Query file in binary format");
    println!("--gt_file                   Ground truth file for the queryset");
    println!("--recall_at, -K             Number of neighbors to be returned");
    println!("--print_all_recalls         Print recalls at all positions, from 1 up to specified recall_at value");
    println!("--search_list, -L           List of L values for search");
    println!("--num_threads               Number of threads used for searching the index (defaults to num_cpus::get())");
    println!("--qps_per_thread            Print overall QPS divided by the number of threads in the output table");
    println!("--fail_if_recall_below      If set to a value >0 and <100%, program returns -1 if best recall found is below this threshold");
}
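
// The latency columns in the results table are simple aggregates over the per-query latency
// vector: sort it, take the mean, and index at the 99.9th percentile. A small sketch of those
// calculations on made-up numbers; the .min() clamp is added here for safety on tiny samples and
// is not present in the driver above.
fn latency_stats_sketch() {
    let mut latency_us: Vec<f32> =
        vec![120.0, 80.0, 95.0, 400.0, 110.0, 90.0, 85.0, 100.0, 105.0, 98.0];
    let n = latency_us.len();

    latency_us.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let mean = latency_us.iter().sum::<f32>() / n as f32;
    // 99.9th percentile index, clamped so a tiny sample cannot index past the end.
    let p999_idx = ((0.999 * n as f32).round() as usize).min(n - 1);

    println!("mean latency: {:.2} us", mean);
    println!("99.9 latency: {:.2} us", latency_us[p999_idx]);
    println!("QPS (single thread, toy): {:.2}", 1_000_000.0 / mean);
}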
@@ -0,0 +1,186 @@
/*
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the MIT license.
 */
use bytemuck::{cast_slice, Pod};
use diskann::{
    common::{ANNError, ANNResult, AlignedBoxWithSlice},
    model::data_store::DatasetDto,
    utils::{copy_aligned_data_from_file, is_aligned, round_up},
};
use std::collections::HashSet;
use std::fs::File;
use std::io::Read;
use std::mem::size_of;

pub(crate) fn calculate_recall(
    num_queries: usize,
    gold_std: &[u32],
    gs_dist: &Option<Vec<f32>>,
    dim_gs: usize,
    our_results: &[u32],
    dim_or: u32,
    recall_at: u32,
) -> ANNResult<f64> {
    let mut total_recall: f64 = 0.0;
    let (mut gt, mut res): (HashSet<u32>, HashSet<u32>) = (HashSet::new(), HashSet::new());

    for i in 0..num_queries {
        gt.clear();
        res.clear();

        let gt_slice = &gold_std[dim_gs * i..];
        let res_slice = &our_results[dim_or as usize * i..];
        let mut tie_breaker = recall_at as usize;

        if gs_dist.is_some() {
            tie_breaker = (recall_at - 1) as usize;
            let gt_dist_vec = &gs_dist.as_ref().unwrap()[dim_gs * i..];
            while tie_breaker < dim_gs
                && gt_dist_vec[tie_breaker] == gt_dist_vec[(recall_at - 1) as usize]
            {
                tie_breaker += 1;
            }
        }

        (0..tie_breaker).for_each(|idx| {
            gt.insert(gt_slice[idx]);
        });

        // Only the top recall_at returned ids are compared; the tie-extended range above applies
        // to the ground-truth set only.
        (0..recall_at as usize).for_each(|idx| {
            res.insert(res_slice[idx]);
        });

        let mut cur_recall: u32 = 0;
        for v in gt.iter() {
            if res.contains(v) {
                cur_recall += 1;
            }
        }

        total_recall += cur_recall as f64;
    }

    Ok(total_recall / num_queries as f64 * (100.0 / recall_at as f64))
}
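
// calculate_recall above averages, per query, the overlap between the returned ids and the
// (tie-extended) ground-truth ids, then scales by 100 / recall_at. A tiny worked example for a
// single query at recall@2: ground truth {1, 2}, results {2, 9}, overlap 1, so recall is 50%.
// The values are made up; this does not call the crate function, it just redoes the arithmetic.
#[cfg(test)]
mod recall_example {
    use std::collections::HashSet;

    #[test]
    fn recall_at_two_worked_example() {
        let gt: HashSet<u32> = [1, 2].into_iter().collect();
        let res: HashSet<u32> = [2, 9].into_iter().collect();
        let recall_at = 2u32;
        let num_queries = 1.0f64;

        let hits = gt.iter().filter(|v| res.contains(v)).count() as f64;
        // Same scaling as calculate_recall: average overlap over queries, times 100 / recall_at.
        let recall = hits / num_queries * (100.0 / recall_at as f64);
        assert_eq!(recall, 50.0);
    }
}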

pub(crate) fn get_graph_num_frozen_points(graph_file: &str) -> ANNResult<usize> {
    let mut file = File::open(graph_file)?;
    let mut usize_buffer = [0; size_of::<usize>()];
    let mut u32_buffer = [0; size_of::<u32>()];

    // The frozen-point count is the fourth field of the graph file header; the three reads
    // before it only skip over the earlier header fields.
    file.read_exact(&mut usize_buffer)?;
    file.read_exact(&mut u32_buffer)?;
    file.read_exact(&mut u32_buffer)?;
    file.read_exact(&mut usize_buffer)?;
    let file_frozen_pts = usize::from_le_bytes(usize_buffer);

    Ok(file_frozen_pts)
}

#[inline]
pub(crate) fn load_truthset(
    bin_file: &str,
) -> ANNResult<(Vec<u32>, Option<Vec<f32>>, usize, usize)> {
    let mut file = File::open(bin_file)?;
    let actual_file_size = file.metadata()?.len() as usize;

    let mut buffer = [0; size_of::<i32>()];
    file.read_exact(&mut buffer)?;
    let npts = i32::from_le_bytes(buffer) as usize;

    file.read_exact(&mut buffer)?;
    let dim = i32::from_le_bytes(buffer) as usize;

    println!("Metadata: #pts = {npts}, #dims = {dim}... ");

    let expected_file_size_with_dists: usize =
        2 * npts * dim * size_of::<u32>() + 2 * size_of::<u32>();
    let expected_file_size_just_ids: usize = npts * dim * size_of::<u32>() + 2 * size_of::<u32>();

    let truthset_type: i32 = match actual_file_size {
        // This is in the C++ code, but nothing is done in this case. Keeping it here for future reference just in case.
        // expected_file_size_just_ids => 2,
        x if x == expected_file_size_with_dists => 1,
        _ => return Err(ANNError::log_index_error(format!("Error. File size mismatch. File should have bin format, with npts followed by ngt
followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}",
            actual_file_size,
            expected_file_size_with_dists,
            expected_file_size_just_ids)))
    };

    let mut ids: Vec<u32> = vec![0; npts * dim];
    let mut buffer = vec![0; npts * dim * size_of::<u32>()];
    file.read_exact(&mut buffer)?;
    ids.clone_from_slice(cast_slice::<u8, u32>(&buffer));

    if truthset_type == 1 {
        let mut dists: Vec<f32> = vec![0.0; npts * dim];
        let mut buffer = vec![0; npts * dim * size_of::<f32>()];
        file.read_exact(&mut buffer)?;
        dists.clone_from_slice(cast_slice::<u8, f32>(&buffer));

        return Ok((ids, Some(dists), npts, dim));
    }

    Ok((ids, None, npts, dim))
}
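
// load_truthset above expects a little-endian bin layout: i32 npts, i32 dim, then npts*dim u32
// neighbor ids, optionally followed by npts*dim f32 distances. A hedged sketch that writes a tiny
// file in that layout using only the standard library; the path and the values are made up for
// illustration and are not produced by any DiskANN tool here.
fn write_toy_truthset() -> std::io::Result<()> {
    use std::fs::File;
    use std::io::Write;

    let (npts, dim) = (2i32, 3i32);
    let ids: Vec<u32> = vec![7, 1, 4, 2, 9, 5];
    let dists: Vec<f32> = vec![0.1, 0.4, 0.9, 0.0, 0.3, 0.8];

    let mut f = File::create("toy_truthset.bin")?; // illustrative path
    f.write_all(&npts.to_le_bytes())?;
    f.write_all(&dim.to_le_bytes())?;
    for id in &ids {
        f.write_all(&id.to_le_bytes())?;
    }
    for d in &dists {
        f.write_all(&d.to_le_bytes())?; // the distance block is optional in the format
    }
    Ok(())
}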

#[inline]
pub(crate) fn load_aligned_bin<T: Default + Copy + Sized + Pod>(
    bin_file: &str,
) -> ANNResult<(AlignedBoxWithSlice<T>, usize, usize, usize)> {
    let t_size = size_of::<T>();
    let (npts, dim, file_size): (usize, usize, usize);
    {
        println!("Reading (with alignment) bin file: {bin_file}");
        let mut file = File::open(bin_file)?;
        file_size = file.metadata()?.len() as usize;

        let mut buffer = [0; size_of::<i32>()];
        file.read_exact(&mut buffer)?;
        npts = i32::from_le_bytes(buffer) as usize;

        file.read_exact(&mut buffer)?;
        dim = i32::from_le_bytes(buffer) as usize;
    }

    let rounded_dim = round_up(dim, 8);
    let expected_actual_file_size = npts * dim * size_of::<T>() + 2 * size_of::<u32>();

    if file_size != expected_actual_file_size {
        return Err(ANNError::log_index_error(format!(
            "ERROR: File size mismatch. Actual size is {} while expected size is {}
npts = {}, #dims = {}, aligned_dim = {}",
            file_size, expected_actual_file_size, npts, dim, rounded_dim
        )));
    }

    println!("Metadata: #pts = {npts}, #dims = {dim}, aligned_dim = {rounded_dim}...");

    let alloc_size = npts * rounded_dim;
    let alignment = 8 * t_size;
    println!(
        "allocating aligned memory of {} bytes... ",
        alloc_size * t_size
    );
    if !is_aligned(alloc_size * t_size, alignment) {
        return Err(ANNError::log_index_error(format!(
            "Requested memory size is not a multiple of {}. Can not be allocated.",
            alignment
        )));
    }

    let mut data = AlignedBoxWithSlice::<T>::new(alloc_size, alignment)?;
    let dto = DatasetDto {
        data: &mut data,
        rounded_dim,
    };

    println!("done. Copying data to mem_aligned buffer...");

    let (_, _) = copy_aligned_data_from_file(bin_file, dto, 0)?;

    Ok((data, npts, dim, rounded_dim))
}
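
// load_aligned_bin copies each point into a buffer whose per-point stride is the dimension
// rounded up to a multiple of 8, so downstream distance kernels can assume an aligned layout.
// A small standalone sketch of that repacking; round_up is reimplemented locally here purely for
// illustration (the real helper lives in diskann::utils), and the data is made up.
fn aligned_repack_sketch() {
    fn round_up(x: usize, multiple: usize) -> usize {
        (x + multiple - 1) / multiple * multiple
    }

    let (npts, dim) = (2usize, 5usize);
    let rounded_dim = round_up(dim, 8); // 8
    let flat: Vec<f32> = (0..npts * dim).map(|i| i as f32).collect();

    // Repack: each point occupies rounded_dim slots; the extra slots stay zero as padding.
    let mut aligned = vec![0.0f32; npts * rounded_dim];
    for p in 0..npts {
        aligned[p * rounded_dim..p * rounded_dim + dim]
            .copy_from_slice(&flat[p * dim..(p + 1) * dim]);
    }

    assert_eq!(aligned.len(), npts * rounded_dim);
    println!("stride per point: {rounded_dim} floats");
}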