Initial commit
This commit is contained in:
110
packages/leann-backend-hnsw/third_party/faiss/misc/test_blas.cpp
vendored
Normal file
110
packages/leann-backend-hnsw/third_party/faiss/misc/test_blas.cpp
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
/*
|
||||
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
*
|
||||
* This source code is licensed under the MIT license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <random>
|
||||
|
||||
#undef FINTEGER
|
||||
#define FINTEGER long
|
||||
|
||||
extern "C" {
|
||||
|
||||
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
||||
|
||||
int sgemm_(
|
||||
const char* transa,
|
||||
const char* transb,
|
||||
FINTEGER* m,
|
||||
FINTEGER* n,
|
||||
FINTEGER* k,
|
||||
const float* alpha,
|
||||
const float* a,
|
||||
FINTEGER* lda,
|
||||
const float* b,
|
||||
FINTEGER* ldb,
|
||||
float* beta,
|
||||
float* c,
|
||||
FINTEGER* ldc);
|
||||
|
||||
/* Lapack functions, see http://www.netlib.org/clapack/old/single/sgeqrf.c */
|
||||
|
||||
int sgeqrf_(
|
||||
FINTEGER* m,
|
||||
FINTEGER* n,
|
||||
float* a,
|
||||
FINTEGER* lda,
|
||||
float* tau,
|
||||
float* work,
|
||||
FINTEGER* lwork,
|
||||
FINTEGER* info);
|
||||
}
|
||||
|
||||
float* new_random_vec(int size) {
|
||||
float* x = new float[size];
|
||||
std::mt19937 rng;
|
||||
std::uniform_real_distribution<> distrib;
|
||||
for (int i = 0; i < size; i++)
|
||||
x[i] = distrib(rng);
|
||||
return x;
|
||||
}
|
||||
|
||||
int main() {
|
||||
FINTEGER m = 10, n = 20, k = 30;
|
||||
float *a = new_random_vec(m * k), *b = new_random_vec(n * k),
|
||||
*c = new float[n * m];
|
||||
float one = 1.0, zero = 0.0;
|
||||
|
||||
printf("BLAS test\n");
|
||||
|
||||
sgemm_("Not transposed",
|
||||
"Not transposed",
|
||||
&m,
|
||||
&n,
|
||||
&k,
|
||||
&one,
|
||||
a,
|
||||
&m,
|
||||
b,
|
||||
&k,
|
||||
&zero,
|
||||
c,
|
||||
&m);
|
||||
|
||||
printf("errors=\n");
|
||||
|
||||
for (int i = 0; i < m; i++) {
|
||||
for (int j = 0; j < n; j++) {
|
||||
float accu = 0;
|
||||
for (int l = 0; l < k; l++)
|
||||
accu += a[i + l * m] * b[l + j * k];
|
||||
printf("%6.3f ", accu - c[i + j * m]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
long info = 0x64bL << 32;
|
||||
long mi = 0x64bL << 32 | m;
|
||||
float* tau = new float[m];
|
||||
FINTEGER lwork = -1;
|
||||
|
||||
float work1;
|
||||
|
||||
printf("Intentional Lapack error (appears only for 64-bit INTEGER):\n");
|
||||
sgeqrf_(&mi, &n, c, &m, tau, &work1, &lwork, (FINTEGER*)&info);
|
||||
|
||||
// sgeqrf_ (&m, &n, c, &zeroi, tau, &work1, &lwork, (FINTEGER*)&info);
|
||||
printf("info=%016lx\n", info);
|
||||
|
||||
if (info >> 32 == 0x64b) {
|
||||
printf("Lapack uses 32-bit integers\n");
|
||||
} else {
|
||||
printf("Lapack uses 64-bit integers\n");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user