diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-04 14:17:01 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-04 14:17:01 +0100 |
commit | d7d71002117e4ec30d96ca91c54f971b8e835325 (patch) | |
tree | 06795869dfe3d09f95572d78264f947271c34e29 | |
parent | 98591e36dbdfa483130d200dbe3423611dfbab81 (diff) | |
download | rspamd-d7d71002117e4ec30d96ca91c54f971b8e835325.tar.gz rspamd-d7d71002117e4ec30d96ca91c54f971b8e835325.zip |
[Project] Add linalg ffi library for prototyping
-rw-r--r-- | contrib/kann/kautodiff.h | 5 | ||||
-rw-r--r-- | lualib/lua_ffi/init.lua | 1 | ||||
-rw-r--r-- | lualib/lua_ffi/linalg.lua | 70 |
3 files changed, 75 insertions, 1 deletions
diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h index a2c648835..e51176c84 100644 --- a/contrib/kann/kautodiff.h +++ b/contrib/kann/kautodiff.h @@ -102,7 +102,7 @@ void kad_delete(int n, kad_node_t **a); /* deallocate a compiled/linearized grap /** * Compute the value at a node - * + * * @param n number of nodes * @param a list of nodes * @param from compute the value at this node, 0<=from<n @@ -243,4 +243,7 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */ return n; } +/* Additions by Rspamd */ +void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); + #endif diff --git a/lualib/lua_ffi/init.lua b/lualib/lua_ffi/init.lua index 02b54f932..08a6763bb 100644 --- a/lualib/lua_ffi/init.lua +++ b/lualib/lua_ffi/init.lua @@ -49,6 +49,7 @@ pcall(ffi.load, "rspamd-server", true) exports.common = require "lua_ffi/common" exports.dkim = require "lua_ffi/dkim" exports.spf = require "lua_ffi/spf" +exports.linalg = require "lua_ffi/linalg" for k,v in pairs(ffi) do -- Preserve all stuff to use lua_ffi as ffi itself diff --git a/lualib/lua_ffi/linalg.lua b/lualib/lua_ffi/linalg.lua new file mode 100644 index 000000000..c3f6eff5a --- /dev/null +++ b/lualib/lua_ffi/linalg.lua @@ -0,0 +1,70 @@ +--[[ +Copyright (c) 2020, Vsevolod Stakhov <vsevolod@highsecure.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +--[[[ +-- @module lua_ffi/linalg +-- This module contains ffi interfaces to linear algebra routines +--]] + +local ffi = require 'ffi' + +local exports = {} + +ffi.cdef[[ + void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C); +]] + +local function table_to_ffi(a, m, n) + local a_conv = ffi.new(string.format("float[%d][%d]", m, n), {}) + for i=1,m or #a do + for j=1,n or #a[1] do + a_conv[i - 1][j - 1] = a[i][j] + end + end + return a_conv +end + +local function ffi_to_table(a, m, n) + local res = {} + + for i=0,m-1 do + res[i + 1] = {} + for j=0,n-1 do + res[i + 1][j + 1] = a[i][j] + end + end + + return res +end + +exports.sgemm = function(a, m, b, n, k, trans_a, trans_b) + if type(a) == 'table' then + -- Need to convert, slow! + a = table_to_ffi(a, m, k) + end + if type(b) == 'table' then + b = table_to_ffi(b, k, n) + end + local res = ffi.new(string.format("float[%d][%d]", m, n), {}) + ffi.C.kad_sgemm_simple(trans_a or 0, trans_b or 0, m, n, k, ffi.cast('const float*', a), + ffi.cast('const float*', b), ffi.cast('float*', res)) + return res +end + +exports.ffi_to_table = ffi_to_table +exports.table_to_ffi = table_to_ffi + +return exports
\ No newline at end of file |