Initial commit

This commit is contained in:
yichuan520030910320
2025-06-30 09:05:05 +00:00
commit 46f6cc100b
1231 changed files with 278432 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
[package]
name = "vector"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
half = "2.2.1"
thiserror = "1.0.40"
bytemuck = "1.7.0"
[build-dependencies]
cc = "1.0.79"
[dev-dependencies]
base64 = "0.21.2"
bincode = "1.3.3"
serde = "1.0.163"
approx = "0.5.1"
rand = "0.8.5"

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
fn main() {
println!("cargo:rerun-if-changed=distance.c");
if cfg!(target_os = "macos") {
std::env::set_var("CFLAGS", "-mavx2 -mfma -Wno-error -MP -O2 -D NDEBUG -D MKL_ILP64 -D USE_AVX2 -D USE_ACCELERATED_PQ -D NOMINMAX -D _TARGET_ARM_APPLE_DARWIN");
cc::Build::new()
.file("distance.c")
.warnings_into_errors(true)
.debug(false)
.target("x86_64-apple-darwin")
.compile("nativefunctions.lib");
} else {
std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /O2 /Ob2 /Zc:inline /fp:fast /D NDEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX2 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot");
// std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /Obd /Zc:inline /fp:fast /D DEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX512 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot");
cc::Build::new()
.file("distance.c")
.warnings_into_errors(true)
.debug(false)
.compile("nativefunctions");
println!("cargo:rustc-link-arg=nativefunctions.lib");
}
}

View File

@@ -0,0 +1,35 @@
#include <immintrin.h>
#include <math.h>
inline __m256i load_128bit_to_256bit(const __m128i *ptr)
{
__m128i value128 = _mm_loadu_si128(ptr);
__m256i value256 = _mm256_castsi128_si256(value128);
return _mm256_inserti128_si256(value256, _mm_setzero_si128(), 1);
}
float distance_compare_avx512f_f16(const unsigned char *vec1, const unsigned char *vec2, size_t size)
{
__m512 sum_squared_diff = _mm512_setzero_ps();
for (int i = 0; i < size / 16; i += 1)
{
__m512 v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec1 + i * 2 * 16)));
__m512 v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec2 + i * 2 * 16)));
__m512 diff = _mm512_sub_ps(v1, v2);
sum_squared_diff = _mm512_fmadd_ps(diff, diff, sum_squared_diff);
}
size_t i = (size / 16) * 16;
if (i != size)
{
__m512 va = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec1 + i * 2)));
__m512 vb = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec2 + i * 2)));
__m512 diff512 = _mm512_sub_ps(va, vb);
sum_squared_diff = _mm512_fmadd_ps(diff512, diff512, sum_squared_diff);
}
return _mm512_reduce_add_ps(sum_squared_diff);
}

View File

@@ -0,0 +1,442 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32};
use crate::{Half, Metric};
/// Distance contract for full-precision vertex
pub trait FullPrecisionDistance<T, const N: usize> {
/// Get the distance between vertex a and vertex b
fn distance_compare(a: &[T; N], b: &[T; N], vec_type: Metric) -> f32;
}
// reason = "Not supported Metric type Metric::Cosine"
#[allow(clippy::panic)]
impl<const N: usize> FullPrecisionDistance<f32, N> for [f32; N] {
/// Calculate distance between two f32 Vertex
#[inline(always)]
fn distance_compare(a: &[f32; N], b: &[f32; N], metric: Metric) -> f32 {
match metric {
Metric::L2 => distance_l2_vector_f32::<N>(a, b),
_ => panic!("Not supported Metric type {:?}", metric),
}
}
}
// reason = "Not supported Metric type Metric::Cosine"
#[allow(clippy::panic)]
impl<const N: usize> FullPrecisionDistance<Half, N> for [Half; N] {
fn distance_compare(a: &[Half; N], b: &[Half; N], metric: Metric) -> f32 {
match metric {
Metric::L2 => distance_l2_vector_f16::<N>(a, b),
_ => panic!("Not supported Metric type {:?}", metric),
}
}
}
// reason = "Not yet supported Vector i8"
#[allow(clippy::panic)]
impl<const N: usize> FullPrecisionDistance<i8, N> for [i8; N] {
fn distance_compare(_a: &[i8; N], _b: &[i8; N], _metric: Metric) -> f32 {
panic!("Not supported VectorType i8")
}
}
// reason = "Not yet supported Vector u8"
#[allow(clippy::panic)]
impl<const N: usize> FullPrecisionDistance<u8, N> for [u8; N] {
fn distance_compare(_a: &[u8; N], _b: &[u8; N], _metric: Metric) -> f32 {
panic!("Not supported VectorType u8")
}
}
#[cfg(test)]
mod distance_test {
use super::*;
#[repr(C, align(32))]
pub struct F32Slice112([f32; 112]);
#[repr(C, align(32))]
pub struct F16Slice112([Half; 112]);
fn get_turing_test_data() -> (F32Slice112, F32Slice112) {
let a_slice: [f32; 112] = [
0.13961786,
-0.031577103,
-0.09567415,
0.06695563,
-0.1588727,
0.089852564,
-0.019837005,
0.07497972,
0.010418192,
-0.054594643,
0.08613386,
-0.05103466,
0.16568437,
-0.02703799,
0.00728657,
-0.15313251,
0.16462992,
-0.030570814,
0.11635703,
0.23938893,
0.018022912,
-0.12646551,
0.018048918,
-0.035986554,
0.031986624,
-0.015286017,
0.010117953,
-0.032691937,
0.12163067,
-0.04746277,
0.010213069,
-0.043672588,
-0.099362016,
0.06599016,
-0.19397286,
-0.13285528,
-0.22040887,
0.017690737,
-0.104262285,
-0.0044555613,
-0.07383778,
-0.108652934,
0.13399786,
0.054912474,
0.20181285,
0.1795591,
-0.05425621,
-0.10765217,
0.1405377,
-0.14101997,
-0.12017701,
0.011565498,
0.06952187,
0.060136646,
0.0023214167,
0.04204699,
0.048470616,
0.17398086,
0.024218207,
-0.15626553,
-0.11291045,
-0.09688122,
0.14393932,
-0.14713104,
-0.108876854,
0.035279203,
-0.05440188,
0.017205412,
0.011413814,
0.04009471,
0.11070237,
-0.058998976,
0.07260045,
-0.057893746,
-0.0036240944,
-0.0064988653,
-0.13842176,
-0.023219328,
0.0035885905,
-0.0719257,
-0.21335067,
0.11415403,
-0.0059823603,
0.12091869,
0.08136634,
-0.10769281,
0.024518685,
0.0009200326,
-0.11628049,
0.07448965,
0.13736208,
-0.04144517,
-0.16426727,
-0.06380103,
-0.21386267,
0.022373492,
-0.05874115,
0.017314062,
-0.040344074,
0.01059176,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
];
let b_slice: [f32; 112] = [
-0.07209058,
-0.17755842,
-0.030627966,
0.163028,
-0.2233766,
0.057412963,
0.0076995124,
-0.017121306,
-0.015759075,
-0.026947778,
-0.010282468,
-0.23968373,
-0.021486737,
-0.09903155,
0.09361805,
0.0042711576,
-0.08695552,
-0.042165346,
0.064218745,
-0.06707651,
0.07846054,
0.12235762,
-0.060716823,
0.18496591,
-0.13023394,
0.022469055,
0.056764495,
0.07168404,
-0.08856144,
-0.15343173,
0.099879816,
-0.033529017,
0.0795304,
-0.009242254,
-0.10254546,
0.13086525,
-0.101518914,
-0.1031299,
-0.056826904,
0.033196196,
0.044143833,
-0.049787212,
-0.018148342,
-0.11172959,
-0.06776237,
-0.09185828,
-0.24171598,
0.05080982,
-0.0727684,
0.045031235,
-0.11363879,
-0.063389264,
0.105850354,
-0.19847773,
0.08828623,
-0.087071925,
0.033512704,
0.16118294,
0.14111553,
0.020884402,
-0.088860825,
0.018745849,
0.047522716,
-0.03665169,
0.15726231,
-0.09930561,
0.057844743,
-0.10532736,
-0.091297254,
0.067029804,
0.04153976,
0.06393326,
0.054578528,
0.0038539872,
0.1023088,
-0.10653885,
-0.108500294,
-0.046606563,
0.020439683,
-0.120957725,
-0.13334097,
-0.13425854,
-0.20481694,
0.07009538,
0.08660361,
-0.0096641015,
0.095316306,
-0.002898167,
-0.19680002,
0.08466311,
0.04812689,
-0.028978813,
0.04780206,
-0.2001506,
-0.036866356,
-0.023720587,
0.10731964,
0.05517358,
-0.09580819,
0.14595725,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
];
(F32Slice112(a_slice), F32Slice112(b_slice))
}
fn get_turing_test_data_f16() -> (F16Slice112, F16Slice112) {
let (a_slice, b_slice) = get_turing_test_data();
let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
(
F16Slice112(a_data.collect::<Vec<Half>>().try_into().unwrap()),
F16Slice112(b_data.collect::<Vec<Half>>().try_into().unwrap()),
)
}
use crate::test_util::*;
use approx::assert_abs_diff_eq;
#[test]
fn test_dist_l2_float_turing() {
// two vectors are allocated in the contiguous heap memory
let (a_slice, b_slice) = get_turing_test_data();
let distance = <[f32; 112] as FullPrecisionDistance<f32, 112>>::distance_compare(
&a_slice.0,
&b_slice.0,
Metric::L2,
);
assert_abs_diff_eq!(
distance,
no_vector_compare_f32(&a_slice.0, &b_slice.0),
epsilon = 1e-6
);
}
#[test]
fn test_dist_l2_f16_turing() {
// two vectors are allocated in the contiguous heap memory
let (a_slice, b_slice) = get_turing_test_data_f16();
let distance = <[Half; 112] as FullPrecisionDistance<Half, 112>>::distance_compare(
&a_slice.0,
&b_slice.0,
Metric::L2,
);
// Note the variance between the full 32 bit precision and the 16 bit precision
assert_eq!(distance, no_vector_compare_f16(&a_slice.0, &b_slice.0));
}
#[test]
fn distance_test() {
#[repr(C, align(32))]
struct Vector32ByteAligned {
v: [f32; 512],
}
// two vectors are allocated in the contiguous heap memory
let two_vec = Box::new(Vector32ByteAligned {
v: [
69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287,
20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589,
65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637,
95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185,
71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162,
87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187,
40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308,
66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993,
27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395,
26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126,
35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152,
1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572,
72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074,
62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238,
70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598,
35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985,
31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636,
93.623604, 95.95479, 29.571129, 22.721554, 26.73875, 52.075504, 56.783104,
94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884,
80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782,
14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327,
14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626,
64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891,
55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901,
69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995,
76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308,
70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312,
87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276,
57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271,
23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204,
34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144,
94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746,
40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894,
11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394,
95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036,
73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184,
39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751,
14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984,
2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616,
18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644,
23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518,
63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 4.4014096, 89.28944,
26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664,
28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457,
99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704,
77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221,
72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812,
66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263,
77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177,
18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946,
7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347,
11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583,
51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745,
46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412,
6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434,
3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655,
96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231,
12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898,
44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808,
54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903,
26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206,
40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142,
25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276,
34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146,
86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614,
32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845,
9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 97.75528, 20.257462, 30.122698,
50.41517, 28.156603, 42.644154,
],
});
let distance = compare::<f32, 256>(256, Metric::L2, &two_vec.v);
assert_eq!(distance, 429141.2);
}
fn compare<T, const N: usize>(dim: usize, metric: Metric, v: &[f32]) -> f32
where
for<'a> [T; N]: FullPrecisionDistance<T, N>,
{
let a_ptr = v.as_ptr();
let b_ptr = unsafe { a_ptr.add(dim) };
let a_ref =
<&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(a_ptr, dim) }).unwrap();
let b_ref =
<&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(b_ptr, dim) }).unwrap();
<[f32; N]>::distance_compare(a_ref, b_ref, metric)
}
}

View File

@@ -0,0 +1,152 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[cfg(test)]
mod e2e_test {
#[repr(C, align(32))]
pub struct F32Slice104([f32; 104]);
#[repr(C, align(32))]
pub struct F16Slice104([Half; 104]);
use approx::assert_abs_diff_eq;
use crate::half::Half;
use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32};
fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 {
let mut sum = 0.0;
for i in 0..a.len() {
let a_f32 = a[i];
let b_f32 = b[i];
let diff = a_f32 - b_f32;
sum += diff * diff;
}
sum
}
fn no_vector_compare(a: &[Half], b: &[Half]) -> f32 {
let mut sum = 0.0;
for i in 0..a.len() {
let a_f32 = a[i].to_f32();
let b_f32 = b[i].to_f32();
let diff = a_f32 - b_f32;
sum += diff * diff;
}
sum
}
#[test]
fn avx2_matches_novector() {
for i in 1..3 {
let (f1, f2) = get_test_data(0, i);
let distance_f32x8 = distance_l2_vector_f32::<104>(&f1.0, &f2.0);
let distance = no_vector_compare_f32(&f1.0, &f2.0);
assert_abs_diff_eq!(distance, distance_f32x8, epsilon = 1e-6);
}
}
#[test]
fn avx2_matches_novector_random() {
let (f1, f2) = get_test_data_random();
let distance_f32x8 = distance_l2_vector_f32::<104>(&f1.0, &f2.0);
let distance = no_vector_compare_f32(&f1.0, &f2.0);
assert_abs_diff_eq!(distance, distance_f32x8, epsilon = 1e-4);
}
#[test]
fn avx_f16_matches_novector() {
for i in 1..3 {
let (f1, f2) = get_test_data_f16(0, i);
let _a_slice = f1.0.map(|x| x.to_f32().to_string()).join(", ");
let _b_slice = f2.0.map(|x| x.to_f32().to_string()).join(", ");
let expected = no_vector_compare(f1.0[0..].as_ref(), f2.0[0..].as_ref());
let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0);
assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4);
}
}
#[test]
fn avx_f16_matches_novector_random() {
let (f1, f2) = get_test_data_f16_random();
let expected = no_vector_compare(f1.0[0..].as_ref(), f2.0[0..].as_ref());
let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0);
assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4);
}
fn get_test_data_f16(i1: usize, i2: usize) -> (F16Slice104, F16Slice104) {
let (a_slice, b_slice) = get_test_data(i1, i2);
let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
(
F16Slice104(a_data.collect::<Vec<Half>>().try_into().unwrap()),
F16Slice104(b_data.collect::<Vec<Half>>().try_into().unwrap()),
)
}
fn get_test_data(i1: usize, i2: usize) -> (F32Slice104, F32Slice104) {
use base64::{engine::general_purpose, Engine as _};
let b64 = general_purpose::STANDARD.decode(TEST_DATA).unwrap();
let decoded: Vec<Vec<f32>> = bincode::deserialize(&b64).unwrap();
debug_assert!(decoded.len() > i1);
debug_assert!(decoded.len() > i2);
let mut f1 = F32Slice104([0.0; 104]);
let v1 = &decoded[i1];
debug_assert!(v1.len() == 104);
f1.0.copy_from_slice(v1);
let mut f2 = F32Slice104([0.0; 104]);
let v2 = &decoded[i2];
debug_assert!(v2.len() == 104);
f2.0.copy_from_slice(v2);
(f1, f2)
}
fn get_test_data_f16_random() -> (F16Slice104, F16Slice104) {
let (a_slice, b_slice) = get_test_data_random();
let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x));
let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x));
(
F16Slice104(a_data.collect::<Vec<Half>>().try_into().unwrap()),
F16Slice104(b_data.collect::<Vec<Half>>().try_into().unwrap()),
)
}
fn get_test_data_random() -> (F32Slice104, F32Slice104) {
use rand::Rng;
let mut rng = rand::thread_rng();
let mut f1 = F32Slice104([0.0; 104]);
for i in 0..104 {
f1.0[i] = rng.gen_range(-1.0..1.0);
}
let mut f2 = F32Slice104([0.0; 104]);
for i in 0..104 {
f2.0[i] = rng.gen_range(-1.0..1.0);
}
(f1, f2)
}
const TEST_DATA: &str = "BQAAAAAAAABoAAAAAAAAAPz3Dj7+VgG9z/DDvQkgiT2GryK+nwS4PTeBorz4jpk9ELEqPKKeX73zZrA9uAlRvSqpKT7Gft28LsTuO8XOHL6/lCg+pW/6vJhM7j1fInU+yaSTPC2AAb5T25M8o2YTvWgEAz00cnq8xcUlPPvnBb2AGfk9UmhCvbdUJzwH4jK9UH7Lvdklhz3SoEa+NwsIvt2yYb4q7JA8d4fVvfX/kbtDOJe9boXevbw2CT7n62A9B6hOPlfeNz7CO169vnjcvR3pDz6KZxC+XR/2vTd9PTx7YY492FF2PekiGDt3OSw9IIlGPQooMj5DZcY8EgQgvpg9572paca91GQTPoWpFr7U+t697YAQPYHUXr1d8ow8AQE7PFo6JD3tt+I96ahxvYuvlD3+IW29N4Jtu2/01Ltvvg2+dja+vI8uazvITZO9mXhavpfJ6T2tB8S7OKT3PWWjpj0Mjty9advIPFgucTp3JO69CI6YPaWoDD5pwim9rjUovh2qgr3R/lq+nUi3PI+acL041o081D8lvRCJLTwAAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAA6pJO94NE1voDn+rzQ8CY+1rxkvtspaz0xTPw7+0GMvC0ZgbyWwdy8zHcovKdvdb70BLC8DtHKvdK6vz0R9Ys7vBWyvZK1LL0ehYM9aV+JveuvoD2ilvo9NLJ4vbRnPT4MXAW+BhG4POOBaD0Vz5I9s1+1vTUdHb7Kjcw9uVUJvdbgoj3TbBe8WwPSvYoBBj4m6c+9xTXTvVTDaL28+Ac9KtA0Pa3tS73Vq5S8fNLkvf/Gir0yILy9ZYR3vvUdUD2ZB5W9rHI4PXS76L070oG9EsjYPb89S75pz7Q9xFKyvZ5ECT0kDSU+l4AQPsQVqzyq/LW95ZCZPC6nQj0VIBa9XwkhPr1gy72c7mw937XXvQ76ur3sRok9mCUqPXHvgj28jV89LZN8O0eH0T0KMdq9ZzXevYbmPr0fcac8r7j3vYmKCL4Sewm+iLtRviuOjz08XbE9LlYevDI1wz0s7z278oVJvtpjrT20IEU9+mTtvBjMQz1H9Ey+LQEXva1Rwrxmyts9sf1hPRY3xL3RdRU+AAAAAAAAAAAAAAAAAAAAAGgAAAAAAAAARqSTvbYJpLx1x869cW67PeeJhb7/cBu9m0eFPQO3oL0I+L49YQDavTYSez3SmTg96hBGPuh4oL2x2ow6WdCUO6XUSz4xcU88GReAvVfekj0Ph3Y9z43hvBzT5z1I2my9UVy3vAj8jL08Gtm9CfJcPRihTr1+8Yu9TiP+PNrJa77Dfa09IhpEPesJNr0XzFU8yye3PZKFyz3uzJ09FLRUvYq3l73X4X07DDUzvq9VXjwWtg8+JrzYPcFCkr0jDCg9T9zlvZbZjz4Y8pM89xo8PgAcfbvYSnY8XoFKvO05/L36yzE8J+5yPqfe5r2AZFq8ULRDvnkTgrw+S7q9qGYLvQDZYL1T8d09bFikvZw3+jsYLdO8H3GVveHBYT4gnsE8ZBIJPpzOEj7OSDC+ZYu+vFc1Erzko4M9GqLtPBHH5TwpeRs+miC4PBHH5Tw9Z9k9VUsUPjnppj0oC5C9mcqDvY7y1rxdvZU8PdFAPov9lz0bOmq94kdyPBBokTxtOj89fu4avSsazj1P7iE+x8YkPAAAAAAAAAAAAAAAAAAAAABoAAAAAAAAAHEruT3mgKM8JnEvvAsfHL63906+ifhgvldl1r14OeO9waUyuw3yUzx+PDW9UbDhPQP4Lb4KRRk+Oky2vaLfaT30mrA9YMeZPfzPMz4h42M+XfCHva4AGr6MOSM+iBOzvdsaE7xFxgI+gJGXvVMzE75kHY+8oAWNvVqNK7yOx589fU3lvVVPg730Cwk+DKkEPWYtxjqQ2MK9H0T+vTnGQj2yq5w8L49BvrEJrzyB4Yo9AXV7PYGCLr3MxsG9oWM7PTyu8TzEOhW+dyWrvUTxHD2nL+c9+VKFPcthhLsc0PM8FdyPPeLj/z1WAHS8ZvW2PGg4Cb5u3IU9g4CovSHW+L2CWoG++nZnPAi2ST3HmUC9P5rJuxQbU765lwU+7FLBPUPTfL0uGgk+yKy2PYwXaT1I4I+9AU6VPQ5QaDx9mdE8Qg8zPfGCUjzD/io9rr+BvTNDqT0MFNi9mHatvS1iJD0nVrK78WmIPE0QsL3PAQq9cMRgPWXmmr3yTcw9UcXrPccwa76+cBq+5iVOvUg9c70AAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAB/K7k9hCsnPUJXJr2Wg4a9MEtXve33Sj0VJZ89pciEvWLqwLzUgyu8ADTGPAVenL2UZ/c96YtMved+Wr3LUro9H8a7vGTSA77C5n69Lf3pPQj4KD5cFKq9fZ0uvvYQCT7b23G9XGMCPrGuy736Z9A9kZzFPSuCSD7/9/07Y4/6POxLir3/JBS9qFKMvkSzjryPgVY+ugq8PC9yhbsXaiq+O6WfPcvFK7vZXAy+goAQvXpHHj5jwPI87eokvrySET5QoOm8h8ixOhXzKb5s8+A9sjcJPjiLAz598yQ9yCYSPq6eGz4rvjE82lvGvWuIOLx23zK9hHg8vTWOv70/Tse81fA6Pr2wNz34Eza+2Uj3PZ3trr0aXAI9PCkKPiybe721P9U9QkNLO927jT3LpRA+mpJUvUeU6rwC/Qa+lr4Cvgrpnj1pQ/i9TxhSvJqYr72RS6y8aQLTPQzPiz3vSRY94NfrPJl6LL2adjO8iYfPuhRzZz2f7R8+iVskPcUeXr12ZiI+nd3xvIYv8bwqYlg+AAAAAAAAAAAAAAAAAAAAAA==";
}

View File

@@ -0,0 +1,82 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use bytemuck::{Pod, Zeroable};
use half::f16;
use std::convert::AsRef;
use std::fmt;
// Define the Half type as a new type over f16.
// the memory layout of the Half struct will be the same as the memory layout of the f16 type itself.
// The Half struct serves as a simple wrapper around the f16 type and does not introduce any additional memory overhead.
// Test function:
// use half::f16;
// pub struct Half(f16);
// fn main() {
// let size_of_half = std::mem::size_of::<Half>();
// let alignment_of_half = std::mem::align_of::<Half>();
// println!("Size of Half: {} bytes", size_of_half);
// println!("Alignment of Half: {} bytes", alignment_of_half);
// }
// Output:
// Size of Half: 2 bytes
// Alignment of Half: 2 bytes
pub struct Half(f16);
unsafe impl Pod for Half {}
unsafe impl Zeroable for Half {}
// Implement From<f32> for Half
impl From<Half> for f32 {
fn from(val: Half) -> Self {
val.0.to_f32()
}
}
// Implement AsRef<f16> for Half so that it can be used in distance_compare.
impl AsRef<f16> for Half {
fn as_ref(&self) -> &f16 {
&self.0
}
}
// Implement From<f32> for Half.
impl Half {
pub fn from_f32(value: f32) -> Self {
Self(f16::from_f32(value))
}
}
// Implement Default for Half.
impl Default for Half {
fn default() -> Self {
Self(f16::from_f32(Default::default()))
}
}
// Implement Clone for Half.
impl Clone for Half {
fn clone(&self) -> Self {
Half(self.0)
}
}
// Implement PartialEq for Half.
impl fmt::Debug for Half {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Half({:?})", self.0)
}
}
impl Copy for Half {}
impl Half {
pub fn to_f32(&self) -> f32 {
self.0.to_f32()
}
}
unsafe impl Send for Half {}
unsafe impl Sync for Half {}

View File

@@ -0,0 +1,78 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
//! Distance calculation for L2 Metric
#[cfg(not(target_feature = "avx2"))]
compile_error!("Library must be compiled with -C target-feature=+avx2");
use std::arch::x86_64::*;
use crate::Half;
/// Calculate the distance by vector arithmetic
#[inline(never)]
pub fn distance_l2_vector_f16<const N: usize>(a: &[Half; N], b: &[Half; N]) -> f32 {
debug_assert_eq!(N % 8, 0);
// make sure the addresses are bytes aligned
debug_assert_eq!(a.as_ptr().align_offset(32), 0);
debug_assert_eq!(b.as_ptr().align_offset(32), 0);
unsafe {
let mut sum = _mm256_setzero_ps();
let a_ptr = a.as_ptr() as *const __m128i;
let b_ptr = b.as_ptr() as *const __m128i;
// Iterate over the elements in steps of 8
for i in (0..N).step_by(8) {
let a_vec = _mm256_cvtph_ps(_mm_load_si128(a_ptr.add(i / 8)));
let b_vec = _mm256_cvtph_ps(_mm_load_si128(b_ptr.add(i / 8)));
let diff = _mm256_sub_ps(a_vec, b_vec);
sum = _mm256_fmadd_ps(diff, diff, sum);
}
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum));
/* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
/* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
/* Conversion to float is a no-op on x86-64 */
_mm_cvtss_f32(x32)
}
}
/// Calculate the distance by vector arithmetic
#[inline(never)]
pub fn distance_l2_vector_f32<const N: usize>(a: &[f32; N], b: &[f32; N]) -> f32 {
debug_assert_eq!(N % 8, 0);
// make sure the addresses are bytes aligned
debug_assert_eq!(a.as_ptr().align_offset(32), 0);
debug_assert_eq!(b.as_ptr().align_offset(32), 0);
unsafe {
let mut sum = _mm256_setzero_ps();
// Iterate over the elements in steps of 8
for i in (0..N).step_by(8) {
let a_vec = _mm256_load_ps(&a[i]);
let b_vec = _mm256_load_ps(&b[i]);
let diff = _mm256_sub_ps(a_vec, b_vec);
sum = _mm256_fmadd_ps(diff, diff, sum);
}
let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum));
/* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
/* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
/* Conversion to float is a no-op on x86-64 */
_mm_cvtss_f32(x32)
}
}

View File

@@ -0,0 +1,26 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![cfg_attr(
not(test),
warn(clippy::panic, clippy::unwrap_used, clippy::expect_used)
)]
// #![feature(stdsimd)]
// mod f32x16;
// Uncomment above 2 to experiment with f32x16
mod distance;
mod half;
mod l2_float_distance;
mod metric;
mod utils;
pub use crate::half::Half;
pub use distance::FullPrecisionDistance;
pub use metric::Metric;
pub use utils::prefetch_vector;
#[cfg(test)]
mod distance_test;
mod test_util;

View File

@@ -0,0 +1,36 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#![warn(missing_debug_implementations, missing_docs)]
use std::str::FromStr;
/// Distance metric
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Metric {
/// Squared Euclidean (L2-Squared)
L2,
/// Cosine similarity
/// TODO: T should be float for Cosine distance
Cosine,
}
#[derive(thiserror::Error, Debug)]
pub enum ParseMetricError {
#[error("Invalid format for Metric: {0}")]
InvalidFormat(String),
}
impl FromStr for Metric {
type Err = ParseMetricError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"l2" => Ok(Metric::L2),
"cosine" => Ok(Metric::Cosine),
_ => Err(ParseMetricError::InvalidFormat(String::from(s))),
}
}
}

View File

@@ -0,0 +1,29 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
#[cfg(test)]
use crate::Half;
#[cfg(test)]
pub fn no_vector_compare_f16(a: &[Half], b: &[Half]) -> f32 {
let mut sum = 0.0;
debug_assert_eq!(a.len(), b.len());
for i in 0..a.len() {
sum += (a[i].to_f32() - b[i].to_f32()).powi(2);
}
sum
}
#[cfg(test)]
pub fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 {
let mut sum = 0.0;
debug_assert_eq!(a.len(), b.len());
for i in 0..a.len() {
sum += (a[i] - b[i]).powi(2);
}
sum
}

View File

@@ -0,0 +1,21 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT license.
*/
use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
/// Prefetch the given vector in chunks of 64 bytes, which is a cache line size
/// NOTE: good efficiency when total_vec_size is integral multiple of 64
#[inline]
pub fn prefetch_vector<T>(vec: &[T]) {
let vec_ptr = vec.as_ptr() as *const i8;
let vecsize = std::mem::size_of_val(vec);
let max_prefetch_size = (vecsize / 64) * 64;
for d in (0..max_prefetch_size).step_by(64) {
unsafe {
_mm_prefetch(vec_ptr.add(d), _MM_HINT_T0);
}
}
}