diff options
-rw-r--r-- | contrib/kann/kautodiff.c | 29 | ||||
-rw-r--r-- | contrib/kann/kautodiff.h | 9 | ||||
-rw-r--r-- | lualib/lua_ffi/linalg.lua | 23 |
3 files changed, 57 insertions, 4 deletions
diff --git a/contrib/kann/kautodiff.c b/contrib/kann/kautodiff.c index 47a86a71e..7b0bf8e93 100644 --- a/contrib/kann/kautodiff.c +++ b/contrib/kann/kautodiff.c @@ -900,6 +900,7 @@ void kad_vec_mul_sum(int n, float *a, const float *b, const float *c) void kad_saxpy(int n, float a, const float *x, float *y) { kad_saxpy_inlined(n, a, x, y); } #ifdef HAVE_CBLAS +extern void ssyev(const char* jobz, const char* uplo, int* n, float* a, int* lda, float* w, float* work, int* lwork, int* info); #ifdef HAVE_CBLAS_H #include "cblas.h" #else @@ -947,6 +948,34 @@ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float } #endif +bool kad_ssyev_simple(int N, float *A, float *eugenvals) +{ +#ifndef HAVE_CBLAS + return false; +#else + int n = N, lda = N, info, lwork; + float wkopt; + float *work; + + /* Query and allocate the optimal workspace */ + lwork = -1; + ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, &wkopt, &lwork, &info); + lwork = wkopt; + work = (float*) g_malloc(lwork * sizeof(double)); + ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, work, &lwork, &info); + /* Check for convergence */ + if (info > 0) { + g_free (work); + + return false; + } + + g_free (work); + + return true; +#endif +} + /*************************** * Random number generator * ***************************/ diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h index e51176c84..8c797205c 100644 --- a/contrib/kann/kautodiff.h +++ b/contrib/kann/kautodiff.h @@ -244,6 +244,13 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */ } /* Additions by Rspamd */ -void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); +void kad_sgemm_simple (int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); +/** + * Calculate eugenvectors and eugenvalues + * @param N dimensions of A (must be NxN) + * @param A input matrix (part of it will be destroyed, so copy if needed), on finish the first `nwork` columns will have eugenvectors + * @param eugenvals eugenvalues, must be N elements vector + */ +bool kad_ssyev_simple (int N, float *A, float *eugenvals); #endif diff --git a/lualib/lua_ffi/linalg.lua b/lualib/lua_ffi/linalg.lua index c3f6eff5a..85e84b5ac 100644 --- a/lualib/lua_ffi/linalg.lua +++ b/lualib/lua_ffi/linalg.lua @@ -25,13 +25,14 @@ local exports = {} ffi.cdef[[ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); + bool kad_ssyev_simple (int N, float *A, float *output); ]] local function table_to_ffi(a, m, n) - local a_conv = ffi.new(string.format("float[%d][%d]", m, n), {}) + local a_conv = ffi.new("float[?]", m * n) for i=1,m or #a do for j=1,n or #a[1] do - a_conv[i - 1][j - 1] = a[i][j] + a_conv[(i - 1) * n + (j - 1)] = a[i][j] end end return a_conv @@ -58,12 +59,28 @@ exports.sgemm = function(a, m, b, n, k, trans_a, trans_b) if type(b) == 'table' then b = table_to_ffi(b, k, n) end - local res = ffi.new(string.format("float[%d][%d]", m, n), {}) + local res = ffi.new("float[?]", m * n) ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a), ffi.cast('const float*', b), ffi.cast('float*', res)) return res end +exports.eugen = function(a, n) + if type(a) == 'table' then + -- Need to convert, slow! + n = n or #a + a = table_to_ffi(a, n, n) + end + + local res = ffi.new("float[?]", n) + + if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then + return res,a + end + + return nil +end + exports.ffi_to_table = ffi_to_table exports.table_to_ffi = table_to_ffi |