Initial commit
This commit is contained in:
24
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/Cargo.toml
vendored
Normal file
24
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/Cargo.toml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
[package]
|
||||
name = "vector"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
half = "2.2.1"
|
||||
thiserror = "1.0.40"
|
||||
bytemuck = "1.7.0"
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1.0.79"
|
||||
|
||||
[dev-dependencies]
|
||||
base64 = "0.21.2"
|
||||
bincode = "1.3.3"
|
||||
serde = "1.0.163"
|
||||
approx = "0.5.1"
|
||||
rand = "0.8.5"
|
||||
|
||||
29
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/build.rs
vendored
Normal file
29
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/build.rs
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
fn main() {
|
||||
println!("cargo:rerun-if-changed=distance.c");
|
||||
if cfg!(target_os = "macos") {
|
||||
std::env::set_var("CFLAGS", "-mavx2 -mfma -Wno-error -MP -O2 -D NDEBUG -D MKL_ILP64 -D USE_AVX2 -D USE_ACCELERATED_PQ -D NOMINMAX -D _TARGET_ARM_APPLE_DARWIN");
|
||||
|
||||
cc::Build::new()
|
||||
.file("distance.c")
|
||||
.warnings_into_errors(true)
|
||||
.debug(false)
|
||||
.target("x86_64-apple-darwin")
|
||||
.compile("nativefunctions.lib");
|
||||
} else {
|
||||
std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /O2 /Ob2 /Zc:inline /fp:fast /D NDEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX2 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot");
|
||||
// std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /Obd /Zc:inline /fp:fast /D DEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX512 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot");
|
||||
|
||||
cc::Build::new()
|
||||
.file("distance.c")
|
||||
.warnings_into_errors(true)
|
||||
.debug(false)
|
||||
.compile("nativefunctions");
|
||||
|
||||
println!("cargo:rustc-link-arg=nativefunctions.lib");
|
||||
}
|
||||
}
|
||||
|
||||
35
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/distance.c
vendored
Normal file
35
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/distance.c
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
#include <immintrin.h>
#include <math.h>
#include <stddef.h>
#include <string.h>
|
||||
|
||||
inline __m256i load_128bit_to_256bit(const __m128i *ptr)
|
||||
{
|
||||
__m128i value128 = _mm_loadu_si128(ptr);
|
||||
__m256i value256 = _mm256_castsi128_si256(value128);
|
||||
return _mm256_inserti128_si256(value256, _mm_setzero_si128(), 1);
|
||||
}
|
||||
|
||||
float distance_compare_avx512f_f16(const unsigned char *vec1, const unsigned char *vec2, size_t size)
|
||||
{
|
||||
__m512 sum_squared_diff = _mm512_setzero_ps();
|
||||
|
||||
for (int i = 0; i < size / 16; i += 1)
|
||||
{
|
||||
__m512 v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec1 + i * 2 * 16)));
|
||||
__m512 v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec2 + i * 2 * 16)));
|
||||
|
||||
__m512 diff = _mm512_sub_ps(v1, v2);
|
||||
sum_squared_diff = _mm512_fmadd_ps(diff, diff, sum_squared_diff);
|
||||
}
|
||||
|
||||
size_t i = (size / 16) * 16;
|
||||
|
||||
if (i != size)
|
||||
{
|
||||
__m512 va = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec1 + i * 2)));
|
||||
__m512 vb = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec2 + i * 2)));
|
||||
__m512 diff512 = _mm512_sub_ps(va, vb);
|
||||
sum_squared_diff = _mm512_fmadd_ps(diff512, diff512, sum_squared_diff);
|
||||
}
|
||||
|
||||
return _mm512_reduce_add_ps(sum_squared_diff);
|
||||
}
|
||||
442
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/distance.rs
vendored
Normal file
442
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/distance.rs
vendored
Normal file
@@ -0,0 +1,442 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32};
|
||||
use crate::{Half, Metric};
|
||||
|
||||
/// Distance contract for full-precision vertex
|
||||
pub trait FullPrecisionDistance<T, const N: usize> {
|
||||
/// Get the distance between vertex a and vertex b
|
||||
fn distance_compare(a: &[T; N], b: &[T; N], vec_type: Metric) -> f32;
|
||||
}
|
||||
|
||||
// reason = "Not supported Metric type Metric::Cosine"
|
||||
#[allow(clippy::panic)]
|
||||
impl<const N: usize> FullPrecisionDistance<f32, N> for [f32; N] {
|
||||
/// Calculate distance between two f32 Vertex
|
||||
#[inline(always)]
|
||||
fn distance_compare(a: &[f32; N], b: &[f32; N], metric: Metric) -> f32 {
|
||||
match metric {
|
||||
Metric::L2 => distance_l2_vector_f32::<N>(a, b),
|
||||
_ => panic!("Not supported Metric type {:?}", metric),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reason = "Not supported Metric type Metric::Cosine"
|
||||
#[allow(clippy::panic)]
|
||||
impl<const N: usize> FullPrecisionDistance<Half, N> for [Half; N] {
|
||||
fn distance_compare(a: &[Half; N], b: &[Half; N], metric: Metric) -> f32 {
|
||||
match metric {
|
||||
Metric::L2 => distance_l2_vector_f16::<N>(a, b),
|
||||
_ => panic!("Not supported Metric type {:?}", metric),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reason = "Not yet supported Vector i8"
|
||||
#[allow(clippy::panic)]
|
||||
impl<const N: usize> FullPrecisionDistance<i8, N> for [i8; N] {
|
||||
fn distance_compare(_a: &[i8; N], _b: &[i8; N], _metric: Metric) -> f32 {
|
||||
panic!("Not supported VectorType i8")
|
||||
}
|
||||
}
|
||||
|
||||
// reason = "Not yet supported Vector u8"
|
||||
#[allow(clippy::panic)]
|
||||
impl<const N: usize> FullPrecisionDistance<u8, N> for [u8; N] {
|
||||
fn distance_compare(_a: &[u8; N], _b: &[u8; N], _metric: Metric) -> f32 {
|
||||
panic!("Not supported VectorType u8")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod distance_test {
|
||||
use super::*;
|
||||
|
||||
#[repr(C, align(32))]
|
||||
pub struct F32Slice112([f32; 112]);
|
||||
|
||||
#[repr(C, align(32))]
|
||||
pub struct F16Slice112([Half; 112]);
|
||||
|
||||
fn get_turing_test_data() -> (F32Slice112, F32Slice112) {
|
||||
let a_slice: [f32; 112] = [
|
||||
0.13961786,
|
||||
-0.031577103,
|
||||
-0.09567415,
|
||||
0.06695563,
|
||||
-0.1588727,
|
||||
0.089852564,
|
||||
-0.019837005,
|
||||
0.07497972,
|
||||
0.010418192,
|
||||
-0.054594643,
|
||||
0.08613386,
|
||||
-0.05103466,
|
||||
0.16568437,
|
||||
-0.02703799,
|
||||
0.00728657,
|
||||
-0.15313251,
|
||||
0.16462992,
|
||||
-0.030570814,
|
||||
0.11635703,
|
||||
0.23938893,
|
||||
0.018022912,
|
||||
-0.12646551,
|
||||
0.018048918,
|
||||
-0.035986554,
|
||||
0.031986624,
|
||||
-0.015286017,
|
||||
0.010117953,
|
||||
-0.032691937,
|
||||
0.12163067,
|
||||
-0.04746277,
|
||||
0.010213069,
|
||||
-0.043672588,
|
||||
-0.099362016,
|
||||
0.06599016,
|
||||
-0.19397286,
|
||||
-0.13285528,
|
||||
-0.22040887,
|
||||
0.017690737,
|
||||
-0.104262285,
|
||||
-0.0044555613,
|
||||
-0.07383778,
|
||||
-0.108652934,
|
||||
0.13399786,
|
||||
0.054912474,
|
||||
0.20181285,
|
||||
0.1795591,
|
||||
-0.05425621,
|
||||
-0.10765217,
|
||||
0.1405377,
|
||||
-0.14101997,
|
||||
-0.12017701,
|
||||
0.011565498,
|
||||
0.06952187,
|
||||
0.060136646,
|
||||
0.0023214167,
|
||||
0.04204699,
|
||||
0.048470616,
|
||||
0.17398086,
|
||||
0.024218207,
|
||||
-0.15626553,
|
||||
-0.11291045,
|
||||
-0.09688122,
|
||||
0.14393932,
|
||||
-0.14713104,
|
||||
-0.108876854,
|
||||
0.035279203,
|
||||
-0.05440188,
|
||||
0.017205412,
|
||||
0.011413814,
|
||||
0.04009471,
|
||||
0.11070237,
|
||||
-0.058998976,
|
||||
0.07260045,
|
||||
-0.057893746,
|
||||
-0.0036240944,
|
||||
-0.0064988653,
|
||||
-0.13842176,
|
||||
-0.023219328,
|
||||
0.0035885905,
|
||||
-0.0719257,
|
||||
-0.21335067,
|
||||
0.11415403,
|
||||
-0.0059823603,
|
||||
0.12091869,
|
||||
0.08136634,
|
||||
-0.10769281,
|
||||
0.024518685,
|
||||
0.0009200326,
|
||||
-0.11628049,
|
||||
0.07448965,
|
||||
0.13736208,
|
||||
-0.04144517,
|
||||
-0.16426727,
|
||||
-0.06380103,
|
||||
-0.21386267,
|
||||
0.022373492,
|
||||
-0.05874115,
|
||||
0.017314062,
|
||||
-0.040344074,
|
||||
0.01059176,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
];
|
||||
let b_slice: [f32; 112] = [
|
||||
-0.07209058,
|
||||
-0.17755842,
|
||||
-0.030627966,
|
||||
0.163028,
|
||||
-0.2233766,
|
||||
0.057412963,
|
||||
0.0076995124,
|
||||
-0.017121306,
|
||||
-0.015759075,
|
||||
-0.026947778,
|
||||
-0.010282468,
|
||||
-0.23968373,
|
||||
-0.021486737,
|
||||
-0.09903155,
|
||||
0.09361805,
|
||||
0.0042711576,
|
||||
-0.08695552,
|
||||
-0.042165346,
|
||||
0.064218745,
|
||||
-0.06707651,
|
||||
0.07846054,
|
||||
0.12235762,
|
||||
-0.060716823,
|
||||
0.18496591,
|
||||
-0.13023394,
|
||||
0.022469055,
|
||||
0.056764495,
|
||||
0.07168404,
|
||||
-0.08856144,
|
||||
-0.15343173,
|
||||
0.099879816,
|
||||
-0.033529017,
|
||||
0.0795304,
|
||||
-0.009242254,
|
||||
-0.10254546,
|
||||
0.13086525,
|
||||
-0.101518914,
|
||||
-0.1031299,
|
||||
-0.056826904,
|
||||
0.033196196,
|
||||
0.044143833,
|
||||
-0.049787212,
|
||||
-0.018148342,
|
||||
-0.11172959,
|
||||
-0.06776237,
|
||||
-0.09185828,
|
||||
-0.24171598,
|
||||
0.05080982,
|
||||
-0.0727684,
|
||||
0.045031235,
|
||||
-0.11363879,
|
||||
-0.063389264,
|
||||
0.105850354,
|
||||
-0.19847773,
|
||||
0.08828623,
|
||||
-0.087071925,
|
||||
0.033512704,
|
||||
0.16118294,
|
||||
0.14111553,
|
||||
0.020884402,
|
||||
-0.088860825,
|
||||
0.018745849,
|
||||
0.047522716,
|
||||
-0.03665169,
|
||||
0.15726231,
|
||||
-0.09930561,
|
||||
0.057844743,
|
||||
-0.10532736,
|
||||
-0.091297254,
|
||||
0.067029804,
|
||||
0.04153976,
|
||||
0.06393326,
|
||||
0.054578528,
|
||||
0.0038539872,
|
||||
0.1023088,
|
||||
-0.10653885,
|
||||
-0.108500294,
|
||||
-0.046606563,
|
||||
0.020439683,
|
||||
-0.120957725,
|
||||
-0.13334097,
|
||||
-0.13425854,
|
||||
-0.20481694,
|
||||
0.07009538,
|
||||
0.08660361,
|
||||
-0.0096641015,
|
||||
0.095316306,
|
||||
-0.002898167,
|
||||
-0.19680002,
|
||||
0.08466311,
|
||||
0.04812689,
|
||||
-0.028978813,
|
||||
0.04780206,
|
||||
-0.2001506,
|
||||
-0.036866356,
|
||||
-0.023720587,
|
||||
0.10731964,
|
||||
0.05517358,
|
||||
-0.09580819,
|
||||
0.14595725,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
0.0,
|
||||
];
|
||||
|
||||
(F32Slice112(a_slice), F32Slice112(b_slice))
|
||||
}
|
||||
|
||||
fn get_turing_test_data_f16() -> (F16Slice112, F16Slice112) {
|
||||
let (a_slice, b_slice) = get_turing_test_data();
|
||||
let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
|
||||
let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
|
||||
|
||||
(
|
||||
F16Slice112(a_data.collect::<Vec<Half>>().try_into().unwrap()),
|
||||
F16Slice112(b_data.collect::<Vec<Half>>().try_into().unwrap()),
|
||||
)
|
||||
}
|
||||
|
||||
use crate::test_util::*;
|
||||
use approx::assert_abs_diff_eq;
|
||||
|
||||
/// The SIMD L2 distance on the `f32` fixtures must match the scalar reference within 1e-6.
#[test]
fn test_dist_l2_float_turing() {
    // Each fixture is its own 32-byte aligned value.
    let (lhs, rhs) = get_turing_test_data();

    let simd = <[f32; 112] as FullPrecisionDistance<f32, 112>>::distance_compare(
        &lhs.0,
        &rhs.0,
        Metric::L2,
    );
    let scalar = no_vector_compare_f32(&lhs.0, &rhs.0);

    assert_abs_diff_eq!(simd, scalar, epsilon = 1e-6);
}
|
||||
|
||||
/// The SIMD L2 distance on the half-precision fixtures must equal the scalar reference exactly.
#[test]
fn test_dist_l2_f16_turing() {
    // Each fixture is its own 32-byte aligned value.
    let (lhs, rhs) = get_turing_test_data_f16();

    let simd = <[Half; 112] as FullPrecisionDistance<Half, 112>>::distance_compare(
        &lhs.0,
        &rhs.0,
        Metric::L2,
    );
    let scalar = no_vector_compare_f16(&lhs.0, &rhs.0);

    // Both sides work on the 16-bit inputs, so the result may differ from the
    // full 32-bit precision distance of the original data.
    assert_eq!(simd, scalar);
}
|
||||
|
||||
/// L2 distance between the two 256-element halves of one aligned buffer must match a known value.
#[test]
fn distance_test() {
    /// Both operands stored back to back in a single 32-byte aligned buffer.
    #[repr(C, align(32))]
    struct AlignedPair {
        values: [f32; 512],
    }

    // Boxed, so the two vectors sit contiguously in one heap allocation.
    let pair = Box::new(AlignedPair {
        values: [
            69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
            20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
            65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
            95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
            71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
            87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
            40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
            66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
            27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
            26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
            35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
            1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
            72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
            62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
            70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
            35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
            31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
            93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
            94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
            80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
            14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
            14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
            64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
            55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
            69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
            76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
            70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
            87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
            57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
            23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
            34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
            94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
            40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
            11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
            95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
            73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
            39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
            14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
            2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
            18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
            23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
            63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
            26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
            28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
            99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
            77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
            72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
            66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
            77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
            18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
            7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
            11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
            51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
            46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
            6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
            3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
            96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
            12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
            44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
            54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
            26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
            40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
            25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
            34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
            86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
            32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
            9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
            50.41517, 28.156603, 42.644154,
        ],
    });

    // The first 256 values form vector `a`, the next 256 form vector `b`.
    let actual = compare::<f32, 256>(256, Metric::L2, &pair.values);

    assert_eq!(actual, 429141.2);
}
|
||||
|
||||
fn compare<T, const N: usize>(dim: usize, metric: Metric, v: &[f32]) -> f32
|
||||
where
|
||||
for<'a> [T; N]: FullPrecisionDistance<T, N>,
|
||||
{
|
||||
let a_ptr = v.as_ptr();
|
||||
let b_ptr = unsafe { a_ptr.add(dim) };
|
||||
|
||||
let a_ref =
|
||||
<&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(a_ptr, dim) }).unwrap();
|
||||
let b_ref =
|
||||
<&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(b_ptr, dim) }).unwrap();
|
||||
|
||||
<[f32; N]>::distance_compare(a_ref, b_ref, metric)
|
||||
}
|
||||
}
|
||||
152
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/distance_test.rs
vendored
Normal file
152
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/distance_test.rs
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
#[cfg(test)]
|
||||
mod e2e_test {
|
||||
|
||||
#[repr(C, align(32))]
|
||||
pub struct F32Slice104([f32; 104]);
|
||||
|
||||
#[repr(C, align(32))]
|
||||
pub struct F16Slice104([Half; 104]);
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
|
||||
use crate::half::Half;
|
||||
use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32};
|
||||
|
||||
/// Scalar reference: sum of squared element-wise differences of `a` and `b`.
///
/// Iterates over the length of `a`; panics if `b` is shorter than `a`.
fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().enumerate().fold(0.0_f32, |acc, (idx, &lhs)| {
        let delta = lhs - b[idx];
        acc + delta * delta
    })
}
|
||||
|
||||
/// Scalar reference for half-precision input: every element is widened to `f32`, then the
/// squared element-wise differences are summed.
///
/// Iterates over the length of `a`; panics if `b` is shorter than `a`.
fn no_vector_compare(a: &[Half], b: &[Half]) -> f32 {
    a.iter().enumerate().fold(0.0_f32, |acc, (idx, lhs)| {
        let delta = lhs.to_f32() - b[idx].to_f32();
        acc + delta * delta
    })
}
|
||||
|
||||
/// The SIMD `f32` kernel must agree with the scalar reference on the stored fixtures
/// (row 0 against rows 1 and 2).
#[test]
fn avx2_matches_novector() {
    for other in 1..3 {
        let (lhs, rhs) = get_test_data(0, other);

        let simd = distance_l2_vector_f32::<104>(&lhs.0, &rhs.0);
        let scalar = no_vector_compare_f32(&lhs.0, &rhs.0);

        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-6);
    }
}
|
||||
|
||||
/// The SIMD `f32` kernel must agree with the scalar reference on random vectors.
#[test]
fn avx2_matches_novector_random() {
    let (lhs, rhs) = get_test_data_random();

    let simd = distance_l2_vector_f32::<104>(&lhs.0, &rhs.0);
    let scalar = no_vector_compare_f32(&lhs.0, &rhs.0);

    assert_abs_diff_eq!(scalar, simd, epsilon = 1e-4);
}
|
||||
|
||||
/// The SIMD half-precision kernel must agree with the scalar reference on the stored
/// fixtures (row 0 against rows 1 and 2).
#[test]
fn avx_f16_matches_novector() {
    for i in 1..3 {
        let (f1, f2) = get_test_data_f16(0, i);

        // The unused debug strings formerly built here were dropped: they allocated
        // 104 strings per vector on every iteration and were never read.
        let expected = no_vector_compare(&f1.0, &f2.0);
        let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0);

        assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4);
    }
}
|
||||
|
||||
/// The SIMD half-precision kernel must agree with the scalar reference on random vectors.
#[test]
fn avx_f16_matches_novector_random() {
    let (lhs, rhs) = get_test_data_f16_random();

    let scalar = no_vector_compare(&lhs.0[..], &rhs.0[..]);
    let simd = distance_l2_vector_f16::<104>(&lhs.0, &rhs.0);

    assert_abs_diff_eq!(simd, scalar, epsilon = 1e-4);
}
|
||||
|
||||
fn get_test_data_f16(i1: usize, i2: usize) -> (F16Slice104, F16Slice104) {
|
||||
let (a_slice, b_slice) = get_test_data(i1, i2);
|
||||
let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
|
||||
let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
|
||||
|
||||
(
|
||||
F16Slice104(a_data.collect::<Vec<Half>>().try_into().unwrap()),
|
||||
F16Slice104(b_data.collect::<Vec<Half>>().try_into().unwrap()),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_test_data(i1: usize, i2: usize) -> (F32Slice104, F32Slice104) {
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
|
||||
let b64 = general_purpose::STANDARD.decode(TEST_DATA).unwrap();
|
||||
|
||||
let decoded: Vec<Vec<f32>> = bincode::deserialize(&b64).unwrap();
|
||||
debug_assert!(decoded.len() > i1);
|
||||
debug_assert!(decoded.len() > i2);
|
||||
|
||||
let mut f1 = F32Slice104([0.0; 104]);
|
||||
let v1 = &decoded[i1];
|
||||
debug_assert!(v1.len() == 104);
|
||||
f1.0.copy_from_slice(v1);
|
||||
|
||||
let mut f2 = F32Slice104([0.0; 104]);
|
||||
let v2 = &decoded[i2];
|
||||
debug_assert!(v2.len() == 104);
|
||||
f2.0.copy_from_slice(v2);
|
||||
|
||||
(f1, f2)
|
||||
}
|
||||
|
||||
fn get_test_data_f16_random() -> (F16Slice104, F16Slice104) {
|
||||
let (a_slice, b_slice) = get_test_data_random();
|
||||
let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
|
||||
let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
|
||||
|
||||
(
|
||||
F16Slice104(a_data.collect::<Vec<Half>>().try_into().unwrap()),
|
||||
F16Slice104(b_data.collect::<Vec<Half>>().try_into().unwrap()),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_test_data_random() -> (F32Slice104, F32Slice104) {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut f1 = F32Slice104([0.0; 104]);
|
||||
|
||||
for i in 0..104 {
|
||||
f1.0[i] = rng.gen_range(-1.0..1.0);
|
||||
}
|
||||
|
||||
let mut f2 = F32Slice104([0.0; 104]);
|
||||
|
||||
for i in 0..104 {
|
||||
f2.0[i] = rng.gen_range(-1.0..1.0);
|
||||
}
|
||||
|
||||
(f1, f2)
|
||||
}
|
||||
|
||||
const TEST_DATA: &str = "BQAAAAAAAABoAAAAAAAAAPz3Dj7+VgG9z/DDvQkgiT2GryK+nwS4PTeBorz4jpk9ELEqPKKeX73zZrA9uAlRvSqpKT7Gft28LsTuO8XOHL6/lCg+pW/6vJhM7j1fInU+yaSTPC2AAb5T25M8o2YTvWgEAz00cnq8xcUlPPvnBb2AGfk9UmhCvbdUJzwH4jK9UH7Lvdklhz3SoEa+NwsIvt2yYb4q7JA8d4fVvfX/kbtDOJe9boXevbw2CT7n62A9B6hOPlfeNz7CO169vnjcvR3pDz6KZxC+XR/2vTd9PTx7YY492FF2PekiGDt3OSw9IIlGPQooMj5DZcY8EgQgvpg9572paca91GQTPoWpFr7U+t697YAQPYHUXr1d8ow8AQE7PFo6JD3tt+I96ahxvYuvlD3+IW29N4Jtu2/01Ltvvg2+dja+vI8uazvITZO9mXhavpfJ6T2tB8S7OKT3PWWjpj0Mjty9advIPFgucTp3JO69CI6YPaWoDD5pwim9rjUovh2qgr3R/lq+nUi3PI+acL041o081D8lvRCJLTwAAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAA6pJO94NE1voDn+rzQ8CY+1rxkvtspaz0xTPw7+0GMvC0ZgbyWwdy8zHcovKdvdb70BLC8DtHKvdK6vz0R9Ys7vBWyvZK1LL0ehYM9aV+JveuvoD2ilvo9NLJ4vbRnPT4MXAW+BhG4POOBaD0Vz5I9s1+1vTUdHb7Kjcw9uVUJvdbgoj3TbBe8WwPSvYoBBj4m6c+9xTXTvVTDaL28+Ac9KtA0Pa3tS73Vq5S8fNLkvf/Gir0yILy9ZYR3vvUdUD2ZB5W9rHI4PXS76L070oG9EsjYPb89S75pz7Q9xFKyvZ5ECT0kDSU+l4AQPsQVqzyq/LW95ZCZPC6nQj0VIBa9XwkhPr1gy72c7mw937XXvQ76ur3sRok9mCUqPXHvgj28jV89LZN8O0eH0T0KMdq9ZzXevYbmPr0fcac8r7j3vYmKCL4Sewm+iLtRviuOjz08XbE9LlYevDI1wz0s7z278oVJvtpjrT20IEU9+mTtvBjMQz1H9Ey+LQEXva1Rwrxmyts9sf1hPRY3xL3RdRU+AAAAAAAAAAAAAAAAAAAAAGgAAAAAAAAARqSTvbYJpLx1x869cW67PeeJhb7/cBu9m0eFPQO3oL0I+L49YQDavTYSez3SmTg96hBGPuh4oL2x2ow6WdCUO6XUSz4xcU88GReAvVfekj0Ph3Y9z43hvBzT5z1I2my9UVy3vAj8jL08Gtm9CfJcPRihTr1+8Yu9TiP+PNrJa77Dfa09IhpEPesJNr0XzFU8yye3PZKFyz3uzJ09FLRUvYq3l73X4X07DDUzvq9VXjwWtg8+JrzYPcFCkr0jDCg9T9zlvZbZjz4Y8pM89xo8PgAcfbvYSnY8XoFKvO05/L36yzE8J+5yPqfe5r2AZFq8ULRDvnkTgrw+S7q9qGYLvQDZYL1T8d09bFikvZw3+jsYLdO8H3GVveHBYT4gnsE8ZBIJPpzOEj7OSDC+ZYu+vFc1Erzko4M9GqLtPBHH5TwpeRs+miC4PBHH5Tw9Z9k9VUsUPjnppj0oC5C9mcqDvY7y1rxdvZU8PdFAPov9lz0bOmq94kdyPBBokTxtOj89fu4avSsazj1P7iE+x8YkPAAAAAAAAAAAAAAAAAAAAABoAAAAAAAAAHEruT3mgKM8JnEvvAsfHL63906+ifhgvldl1r14OeO9waUyuw3yUzx+PDW9UbDhPQP4Lb4KRRk+Oky2vaLfaT30mrA9YMeZPfzPMz4h42M+XfCHva4AGr6MOSM+iBOzvdsaE7xFxgI+gJGXvVMzE75kHY+8oAWNvVqNK7yOx589fU3lvVVPg730Cwk+DKkEPWYtxjqQ2MK9H0T+vTnGQj2yq5w8L49BvrEJrzyB4Yo9AXV7PYGCLr3MxsG9oWM7PTy
u8TzEOhW+dyWrvUTxHD2nL+c9+VKFPcthhLsc0PM8FdyPPeLj/z1WAHS8ZvW2PGg4Cb5u3IU9g4CovSHW+L2CWoG++nZnPAi2ST3HmUC9P5rJuxQbU765lwU+7FLBPUPTfL0uGgk+yKy2PYwXaT1I4I+9AU6VPQ5QaDx9mdE8Qg8zPfGCUjzD/io9rr+BvTNDqT0MFNi9mHatvS1iJD0nVrK78WmIPE0QsL3PAQq9cMRgPWXmmr3yTcw9UcXrPccwa76+cBq+5iVOvUg9c70AAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAB/K7k9hCsnPUJXJr2Wg4a9MEtXve33Sj0VJZ89pciEvWLqwLzUgyu8ADTGPAVenL2UZ/c96YtMved+Wr3LUro9H8a7vGTSA77C5n69Lf3pPQj4KD5cFKq9fZ0uvvYQCT7b23G9XGMCPrGuy736Z9A9kZzFPSuCSD7/9/07Y4/6POxLir3/JBS9qFKMvkSzjryPgVY+ugq8PC9yhbsXaiq+O6WfPcvFK7vZXAy+goAQvXpHHj5jwPI87eokvrySET5QoOm8h8ixOhXzKb5s8+A9sjcJPjiLAz598yQ9yCYSPq6eGz4rvjE82lvGvWuIOLx23zK9hHg8vTWOv70/Tse81fA6Pr2wNz34Eza+2Uj3PZ3trr0aXAI9PCkKPiybe721P9U9QkNLO927jT3LpRA+mpJUvUeU6rwC/Qa+lr4Cvgrpnj1pQ/i9TxhSvJqYr72RS6y8aQLTPQzPiz3vSRY94NfrPJl6LL2adjO8iYfPuhRzZz2f7R8+iVskPcUeXr12ZiI+nd3xvIYv8bwqYlg+AAAAAAAAAAAAAAAAAAAAAA==";
|
||||
}
|
||||
|
||||
82
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/half.rs
vendored
Normal file
82
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/half.rs
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use bytemuck::{Pod, Zeroable};
|
||||
use half::f16;
|
||||
use std::convert::AsRef;
|
||||
use std::fmt;
|
||||
|
||||
// Define the Half type as a new type over f16.
|
||||
// the memory layout of the Half struct will be the same as the memory layout of the f16 type itself.
|
||||
// The Half struct serves as a simple wrapper around the f16 type and does not introduce any additional memory overhead.
|
||||
// Test function:
|
||||
// use half::f16;
|
||||
// pub struct Half(f16);
|
||||
// fn main() {
|
||||
// let size_of_half = std::mem::size_of::<Half>();
|
||||
// let alignment_of_half = std::mem::align_of::<Half>();
|
||||
// println!("Size of Half: {} bytes", size_of_half);
|
||||
// println!("Alignment of Half: {} bytes", alignment_of_half);
|
||||
// }
|
||||
// Output:
|
||||
// Size of Half: 2 bytes
|
||||
// Alignment of Half: 2 bytes
|
||||
pub struct Half(f16);
|
||||
|
||||
unsafe impl Pod for Half {}
|
||||
unsafe impl Zeroable for Half {}
|
||||
|
||||
// Implement From<f32> for Half
|
||||
impl From<Half> for f32 {
|
||||
fn from(val: Half) -> Self {
|
||||
val.0.to_f32()
|
||||
}
|
||||
}
|
||||
|
||||
// Implement AsRef<f16> for Half so that it can be used in distance_compare.
|
||||
impl AsRef<f16> for Half {
|
||||
fn as_ref(&self) -> &f16 {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
// Implement From<f32> for Half.
|
||||
impl Half {
|
||||
pub fn from_f32(value: f32) -> Self {
|
||||
Self(f16::from_f32(value))
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Default for Half.
|
||||
impl Default for Half {
|
||||
fn default() -> Self {
|
||||
Self(f16::from_f32(Default::default()))
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Clone for Half.
|
||||
impl Clone for Half {
|
||||
fn clone(&self) -> Self {
|
||||
Half(self.0)
|
||||
}
|
||||
}
|
||||
|
||||
// Implement PartialEq for Half.
|
||||
impl fmt::Debug for Half {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Half({:?})", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
// `Half` wraps a single `f16` (used by value in `Clone`), so it can be copied bitwise.
impl Copy for Half {}
|
||||
|
||||
impl Half {
|
||||
pub fn to_f32(&self) -> f32 {
|
||||
self.0.to_f32()
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for Half {}
|
||||
unsafe impl Sync for Half {}
|
||||
|
||||
78
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/l2_float_distance.rs
vendored
Normal file
78
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/l2_float_distance.rs
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
#![warn(missing_debug_implementations, missing_docs)]
|
||||
|
||||
//! Distance calculation for L2 Metric
|
||||
|
||||
#[cfg(not(target_feature = "avx2"))]
|
||||
compile_error!("Library must be compiled with -C target-feature=+avx2");
|
||||
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
use crate::Half;
|
||||
|
||||
/// Calculate the distance by vector arithmetic
|
||||
#[inline(never)]
|
||||
pub fn distance_l2_vector_f16<const N: usize>(a: &[Half; N], b: &[Half; N]) -> f32 {
|
||||
debug_assert_eq!(N % 8, 0);
|
||||
|
||||
// make sure the addresses are bytes aligned
|
||||
debug_assert_eq!(a.as_ptr().align_offset(32), 0);
|
||||
debug_assert_eq!(b.as_ptr().align_offset(32), 0);
|
||||
|
||||
unsafe {
|
||||
let mut sum = _mm256_setzero_ps();
|
||||
let a_ptr = a.as_ptr() as *const __m128i;
|
||||
let b_ptr = b.as_ptr() as *const __m128i;
|
||||
|
||||
// Iterate over the elements in steps of 8
|
||||
for i in (0..N).step_by(8) {
|
||||
let a_vec = _mm256_cvtph_ps(_mm_load_si128(a_ptr.add(i / 8)));
|
||||
let b_vec = _mm256_cvtph_ps(_mm_load_si128(b_ptr.add(i / 8)));
|
||||
|
||||
let diff = _mm256_sub_ps(a_vec, b_vec);
|
||||
sum = _mm256_fmadd_ps(diff, diff, sum);
|
||||
}
|
||||
|
||||
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum));
|
||||
/* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
|
||||
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
|
||||
/* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
|
||||
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
|
||||
/* Conversion to float is a no-op on x86-64 */
|
||||
_mm_cvtss_f32(x32)
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the squared Euclidean (L2) distance between two `f32` vectors.
///
/// Eight floats are processed per iteration: the element-wise difference is
/// squared and accumulated with a fused multiply-add, and the eight partial
/// sums are finally reduced to a single `f32`.
///
/// Unaligned loads are used, so `a` and `b` may have any alignment; 32-byte
/// aligned inputs still work exactly as before.
///
/// NOTE(review): besides AVX2 (checked at the top of this file) the
/// `_mm256_fmadd_ps` intrinsic needs the FMA CPU feature -- confirm that every
/// supported build target enables it.
///
/// # Panics
///
/// Panics if `N` is not a multiple of 8. `N` is a compile-time constant, so the
/// check is expected to fold away for valid instantiations.
#[inline(never)]
pub fn distance_l2_vector_f32<const N: usize>(a: &[f32; N], b: &[f32; N]) -> f32 {
    // Hard assert rather than `debug_assert!`: a trailing partial chunk must
    // never reach the SIMD loop in a release build.
    assert_eq!(N % 8, 0, "vector length must be a multiple of 8");

    // SAFETY: each chunk produced by `chunks_exact(8)` is exactly 8 readable
    // `f32` values (32 bytes), which is what one `_mm256_loadu_ps` reads, and
    // the unaligned load places no alignment requirement on the pointer.
    unsafe {
        let mut sum = _mm256_setzero_ps();

        // Walk both vectors in lock-step, 8 elements at a time. The pointer is
        // taken from the whole chunk, not from a single-element reference.
        for (a_chunk, b_chunk) in a.chunks_exact(8).zip(b.chunks_exact(8)) {
            let a_vec = _mm256_loadu_ps(a_chunk.as_ptr());
            let b_vec = _mm256_loadu_ps(b_chunk.as_ptr());
            let diff = _mm256_sub_ps(a_vec, b_vec);
            sum = _mm256_fmadd_ps(diff, diff, sum);
        }

        // Fold the upper 128 bits onto the lower: ( s3+s7, s2+s6, s1+s5, s0+s4 ).
        let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum));
        // Fold the upper two lanes onto the lower two:
        // ( -, -, s1+s3+s5+s7, s0+s2+s4+s6 ).
        let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
        // 0x55 broadcasts lane 1; adding it to lane 0 yields the full sum there.
        let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
        // Extract lane 0 as a scalar.
        _mm_cvtss_f32(x32)
    }
}
|
||||
|
||||
26
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/lib.rs
vendored
Normal file
26
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/lib.rs
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
#![cfg_attr(
|
||||
not(test),
|
||||
warn(clippy::panic, clippy::unwrap_used, clippy::expect_used)
|
||||
)]
|
||||
|
||||
// #![feature(stdsimd)]
|
||||
// mod f32x16;
|
||||
// Uncomment above 2 to experiment with f32x16
|
||||
mod distance;
|
||||
mod half;
|
||||
mod l2_float_distance;
|
||||
mod metric;
|
||||
mod utils;
|
||||
|
||||
pub use crate::half::Half;
|
||||
pub use distance::FullPrecisionDistance;
|
||||
pub use metric::Metric;
|
||||
pub use utils::prefetch_vector;
|
||||
|
||||
#[cfg(test)]
|
||||
mod distance_test;
|
||||
mod test_util;
|
||||
36
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/metric.rs
vendored
Normal file
36
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/metric.rs
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
#![warn(missing_debug_implementations, missing_docs)]
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Distance metrics that can be selected for vector comparison.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Metric {
    /// Squared Euclidean distance (L2-squared).
    L2,

    /// Cosine similarity.
    ///
    /// TODO: T should be float for Cosine distance
    Cosine,
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ParseMetricError {
|
||||
#[error("Invalid format for Metric: {0}")]
|
||||
InvalidFormat(String),
|
||||
}
|
||||
|
||||
impl FromStr for Metric {
|
||||
type Err = ParseMetricError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"l2" => Ok(Metric::L2),
|
||||
"cosine" => Ok(Metric::Cosine),
|
||||
_ => Err(ParseMetricError::InvalidFormat(String::from(s))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
29
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/test_util.rs
vendored
Normal file
29
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/test_util.rs
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
#[cfg(test)]
|
||||
use crate::Half;
|
||||
|
||||
/// Scalar reference implementation of the squared L2 distance for `Half` slices.
///
/// Each element is converted to `f32` before subtracting; the squared
/// differences are summed in index order. Debug builds assert that both slices
/// have the same length.
#[cfg(test)]
pub fn no_vector_compare_f16(a: &[Half], b: &[Half]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    // `b` is indexed (not zipped) so a shorter `b` still panics.
    a.iter().enumerate().fold(0.0, |acc, (idx, lhs)| {
        acc + (lhs.to_f32() - b[idx].to_f32()).powi(2)
    })
}
|
||||
|
||||
/// Scalar reference implementation of the squared L2 distance for `f32` slices.
///
/// The squared element-wise differences are summed in index order. Debug
/// builds assert that both slices have the same length.
#[cfg(test)]
pub fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    // `b` is indexed (not zipped) so a shorter `b` still panics.
    a.iter()
        .enumerate()
        .fold(0.0, |acc, (idx, lhs)| acc + (lhs - b[idx]).powi(2))
}
|
||||
|
||||
21
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/utils.rs
vendored
Normal file
21
packages/leann-backend-diskann/third_party/DiskANN/rust/vector/src/utils.rs
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT license.
|
||||
*/
|
||||
use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
|
||||
|
||||
/// Issues a `_MM_HINT_T0` prefetch for every complete 64-byte chunk of `vec`.
///
/// 64 bytes is treated as the cache-line size. Only whole chunks are covered:
/// a trailing partial chunk (and therefore any slice smaller than 64 bytes) is
/// not prefetched, so this works best when the slice's byte size is an
/// integral multiple of 64.
#[inline]
pub fn prefetch_vector<T>(vec: &[T]) {
    let base = vec.as_ptr() as *const i8;
    // Number of complete 64-byte chunks in the slice's byte representation.
    let full_chunks = std::mem::size_of_val(vec) / 64;

    for chunk in 0..full_chunks {
        // SAFETY: `chunk * 64` is strictly less than the slice's size in bytes,
        // so the offset pointer stays inside the memory borrowed by `vec`.
        unsafe {
            _mm_prefetch(base.add(chunk * 64), _MM_HINT_T0);
        }
    }
}
|
||||
|
||||
Reference in New Issue
Block a user