@@ -900,6 +900,7 @@ void kad_vec_mul_sum(int n, float *a, const float *b, const float *c) | |||
/* Out-of-line entry point: forwards its arguments unchanged to kad_saxpy_inlined(). */
void kad_saxpy(int n, float a, const float *x, float *y)
{
	kad_saxpy_inlined(n, a, x, y);
}
#ifdef HAVE_CBLAS | |||
extern void ssyev(const char* jobz, const char* uplo, int* n, float* a, int* lda, float* w, float* work, int* lwork, int* info); | |||
#ifdef HAVE_CBLAS_H | |||
#include "cblas.h" | |||
#else | |||
@@ -947,6 +948,34 @@ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float | |||
} | |||
#endif | |||
/**
 * Compute eigenvalues and eigenvectors of a symmetric N x N matrix by calling
 * LAPACK `ssyev` with jobz = "Vectors" and uplo = "Upper".
 *
 * @param N order of the matrix A (also used as its leading dimension)
 * @param A matrix of N * N floats; it is handed to ssyev as an in/out
 *          argument, so its original contents are overwritten
 * @param eugenvals output vector of N floats that receives the eigenvalues
 * @return true when ssyev reports success (info == 0); false when the code is
 *         built without HAVE_CBLAS, or when either the workspace query or the
 *         decomposition itself reports a non-zero `info`
 */
bool kad_ssyev_simple(int N, float *A, float *eugenvals)
{
#ifndef HAVE_CBLAS
	return false;
#else
	int n = N, lda = N, info = 0, lwork;
	float wkopt = 0.0f;
	float *work;

	/* Workspace query: lwork == -1 asks ssyev to store the optimal size in wkopt */
	lwork = -1;
	ssyev("Vectors", "Upper", &n, A, &lda, eugenvals, &wkopt, &lwork, &info);

	if (info != 0) {
		/* The query itself was rejected, so wkopt cannot be trusted */
		return false;
	}

	lwork = (int) wkopt;

	if (lwork < 1) {
		/* Never hand ssyev an empty (NULL) workspace */
		lwork = 1;
	}

	/* The workspace is an array of floats, so size it accordingly */
	work = (float *) g_malloc(lwork * sizeof(float));
	ssyev("Vectors", "Upper", &n, A, &lda, eugenvals, work, &lwork, &info);
	g_free(work);

	/* info > 0: the algorithm failed to converge; info < 0: illegal argument */
	return info == 0;
#endif
}
/*************************** | |||
* Random number generator * | |||
***************************/ |
@@ -244,6 +244,13 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */ | |||
} | |||
/* Additions by Rspamd */ | |||
void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); | |||
void kad_sgemm_simple (int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); | |||
/**
 * Calculate eigenvectors and eigenvalues of a symmetric matrix
 * @param N dimensions of A (must be NxN)
 * @param A input matrix (part of it will be destroyed, so copy if needed), on finish its columns will hold the eigenvectors
 * @param eugenvals eigenvalues, must be N elements vector
 * @return true on success, false on failure or when built without CBLAS support
 */
bool kad_ssyev_simple (int N, float *A, float *eugenvals); | |||
#endif |
@@ -25,13 +25,14 @@ local exports = {} | |||
ffi.cdef[[ | |||
void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); | |||
bool kad_ssyev_simple (int N, float *A, float *output); | |||
]] | |||
local function table_to_ffi(a, m, n) | |||
local a_conv = ffi.new(string.format("float[%d][%d]", m, n), {}) | |||
local a_conv = ffi.new("float[?]", m * n) | |||
for i=1,m or #a do | |||
for j=1,n or #a[1] do | |||
a_conv[i - 1][j - 1] = a[i][j] | |||
a_conv[(i - 1) * n + (j - 1)] = a[i][j] | |||
end | |||
end | |||
return a_conv | |||
@@ -58,12 +59,28 @@ exports.sgemm = function(a, m, b, n, k, trans_a, trans_b) | |||
if type(b) == 'table' then | |||
b = table_to_ffi(b, k, n) | |||
end | |||
local res = ffi.new(string.format("float[%d][%d]", m, n), {}) | |||
local res = ffi.new("float[?]", m * n) | |||
ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a), | |||
ffi.cast('const float*', b), ffi.cast('float*', res)) | |||
return res | |||
end | |||
--- Computes eigenvalues and eigenvectors of a square matrix.
-- Thin wrapper around the C function `kad_ssyev_simple`.
-- @param a matrix, either a Lua table of `n` rows holding `n` numbers each
--   (converted with `table_to_ffi`, which is slow) or FFI data castable to
--   `float*`; FFI input is passed to C as is and is modified in place
-- @param n matrix dimension; optional for table input (defaults to `#a`),
--   mandatory for FFI input
-- @return on success: FFI `float[n]` array filled by the C call (eigenvalues),
--   followed by the float buffer that was passed to C
-- @return on failure: `nil` followed by an error message
exports.eugen = function(a, n)
  if type(a) == 'table' then
    -- Need to convert, slow!
    n = n or #a
    a = table_to_ffi(a, n, n)
  end

  if n == nil then
    -- The size of raw FFI data cannot be derived here; fail with a clear message
    -- instead of the obscure error raised by ffi.new on a nil length
    error('eugen: matrix dimension `n` is required for non-table input', 2)
  end

  local res = ffi.new("float[?]", n)

  if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then
    return res, a
  end

  return nil, 'kad_ssyev_simple failed'
end
exports.ffi_to_table = ffi_to_table | |||
exports.table_to_ffi = table_to_ffi | |||