Browse Source

[Project] Add ssyev method interface

tags/2.6
Vsevolod Stakhov 3 years ago
parent
commit
9fd03abf5d
3 changed files with 57 additions and 4 deletions
  1. 29
    0
      contrib/kann/kautodiff.c
  2. 8
    1
      contrib/kann/kautodiff.h
  3. 20
    3
      lualib/lua_ffi/linalg.lua

+ 29
- 0
contrib/kann/kautodiff.c View File

@@ -900,6 +900,7 @@ void kad_vec_mul_sum(int n, float *a, const float *b, const float *c)
void kad_saxpy(int n, float a, const float *x, float *y) { kad_saxpy_inlined(n, a, x, y); }

#ifdef HAVE_CBLAS
extern void ssyev(const char* jobz, const char* uplo, int* n, float* a, int* lda, float* w, float* work, int* lwork, int* info);
#ifdef HAVE_CBLAS_H
#include "cblas.h"
#else
@@ -947,6 +948,34 @@ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float
}
#endif

bool kad_ssyev_simple(int N, float *A, float *eugenvals)
{
#ifndef HAVE_CBLAS
return false;
#else
int n = N, lda = N, info, lwork;
float wkopt;
float *work;

/* Query and allocate the optimal workspace */
lwork = -1;
ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, &wkopt, &lwork, &info);
lwork = wkopt;
work = (float*) g_malloc(lwork * sizeof(double));
ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, work, &lwork, &info);
/* Check for convergence */
if (info > 0) {
g_free (work);

return false;
}

g_free (work);

return true;
#endif
}

/***************************
* Random number generator *
***************************/

+ 8
- 1
contrib/kann/kautodiff.h View File

@@ -244,6 +244,13 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */
}

/* Additions by Rspamd */
void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
void kad_sgemm_simple (int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
/**
* Calculate eugenvectors and eugenvalues
* @param N dimensions of A (must be NxN)
* @param A input matrix (part of it will be destroyed, so copy if needed), on finish the first `nwork` columns will have eugenvectors
* @param eugenvals eugenvalues, must be N elements vector
*/
bool kad_ssyev_simple (int N, float *A, float *eugenvals);

#endif

+ 20
- 3
lualib/lua_ffi/linalg.lua View File

@@ -25,13 +25,14 @@ local exports = {}

ffi.cdef[[
void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
bool kad_ssyev_simple (int N, float *A, float *output);
]]

local function table_to_ffi(a, m, n)
local a_conv = ffi.new(string.format("float[%d][%d]", m, n), {})
local a_conv = ffi.new("float[?]", m * n)
for i=1,m or #a do
for j=1,n or #a[1] do
a_conv[i - 1][j - 1] = a[i][j]
a_conv[(i - 1) * n + (j - 1)] = a[i][j]
end
end
return a_conv
@@ -58,12 +59,28 @@ exports.sgemm = function(a, m, b, n, k, trans_a, trans_b)
if type(b) == 'table' then
b = table_to_ffi(b, k, n)
end
local res = ffi.new(string.format("float[%d][%d]", m, n), {})
local res = ffi.new("float[?]", m * n)
ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a),
ffi.cast('const float*', b), ffi.cast('float*', res))
return res
end

exports.eugen = function(a, n)
if type(a) == 'table' then
-- Need to convert, slow!
n = n or #a
a = table_to_ffi(a, n, n)
end

local res = ffi.new("float[?]", n)

if ffi.C.kad_ssyev_simple(n, ffi.cast('float*', a), res) then
return res,a
end

return nil
end

exports.ffi_to_table = ffi_to_table
exports.table_to_ffi = table_to_ffi


Loading…
Cancel
Save