]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Add preliminary bindings for kann
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 28 Jun 2019 17:06:40 +0000 (18:06 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 28 Jun 2019 17:06:40 +0000 (18:06 +0100)
src/lua/lua_kann.c [new file with mode: 0644]

diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
new file mode 100644 (file)
index 0000000..cec75ac
--- /dev/null
@@ -0,0 +1,564 @@
+/*-
+ * Copyright 2019 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua_common.h"
+#include "contrib/kann/kann.h"
+
+/***
+ * @module rspamd_kann
+ * `rspamd_kann` is a Lua interface to kann library
+ */
+
+#define KANN_NODE_CLASS "rspamd{kann_node}"
+
+/* Simple macros to define behaviour */
+#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
+#define KANN_LAYER_INTERFACE(name) {#name, lua_kann_layer_ ## name}
+
+#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_ ## name (lua_State *L)
+#define KANN_TRANSFORM_INTERFACE(name) {#name, lua_kann_transform_ ## name}
+
+#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
+#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
+
+/*
+ * Forwarded declarations
+ */
+static kad_node_t *lua_check_kann_node (lua_State *L, int pos);
+
+/* Layers */
+KANN_LAYER_DEF(input);
+KANN_LAYER_DEF(dense);
+KANN_LAYER_DEF(layernorm);
+KANN_LAYER_DEF(rnn);
+KANN_LAYER_DEF(lstm);
+KANN_LAYER_DEF(gru);
+KANN_LAYER_DEF(conv2d);
+KANN_LAYER_DEF(conv1d);
+KANN_LAYER_DEF(cost);
+
+static luaL_reg rspamd_kann_layers_f[] = {
+               KANN_LAYER_INTERFACE(input),
+               KANN_LAYER_INTERFACE(dense),
+               KANN_LAYER_INTERFACE(layernorm),
+               KANN_LAYER_INTERFACE(rnn),
+               KANN_LAYER_INTERFACE(lstm),
+               KANN_LAYER_INTERFACE(gru),
+               KANN_LAYER_INTERFACE(conv2d),
+               KANN_LAYER_INTERFACE(conv1d),
+               KANN_LAYER_INTERFACE(cost),
+               {NULL, NULL},
+};
+
+/* Transition and composition functions */
+
+/* General transform */
+KANN_TRANSFORM_DEF (add);
+KANN_TRANSFORM_DEF (sub);
+KANN_TRANSFORM_DEF (mul);
+KANN_TRANSFORM_DEF (cmul);
+KANN_TRANSFORM_DEF (matmul);
+
+KANN_TRANSFORM_DEF (square);
+KANN_TRANSFORM_DEF (sigm);
+KANN_TRANSFORM_DEF (tanh);
+KANN_TRANSFORM_DEF (relu);
+KANN_TRANSFORM_DEF (softmax);
+KANN_TRANSFORM_DEF (1minus);
+KANN_TRANSFORM_DEF (exp);
+KANN_TRANSFORM_DEF (log);
+KANN_TRANSFORM_DEF (sin);
+static luaL_reg rspamd_kann_transform_f[] = {
+               KANN_TRANSFORM_INTERFACE (add),
+               KANN_TRANSFORM_INTERFACE (sub),
+               KANN_TRANSFORM_INTERFACE (mul),
+               KANN_TRANSFORM_INTERFACE (cmul),
+               KANN_TRANSFORM_INTERFACE (matmul),
+
+               KANN_TRANSFORM_INTERFACE (square),
+               KANN_TRANSFORM_INTERFACE (sigm),
+               KANN_TRANSFORM_INTERFACE (tanh),
+               KANN_TRANSFORM_INTERFACE (relu),
+               KANN_TRANSFORM_INTERFACE (softmax),
+               KANN_TRANSFORM_INTERFACE (1minus),
+               KANN_TRANSFORM_INTERFACE (exp),
+               KANN_TRANSFORM_INTERFACE (log),
+               KANN_TRANSFORM_INTERFACE (sin),
+               {NULL, NULL},
+};
+
+/* Loss functions */
+KANN_LOSS_DEF (mse);
+KANN_LOSS_DEF (ce_multi);
+KANN_LOSS_DEF (ce_bin);
+KANN_LOSS_DEF (ce_bin_neg);
+KANN_LOSS_DEF (ce_multi_weighted);
+static luaL_reg rspamd_kann_loss_f[] = {
+               KANN_LOSS_INTERFACE (mse),
+               KANN_LOSS_INTERFACE (ce_multi),
+               KANN_LOSS_INTERFACE (ce_bin),
+               KANN_LOSS_INTERFACE (ce_bin_neg),
+               KANN_LOSS_INTERFACE (ce_multi_weighted),
+               {NULL, NULL},
+};
+
+static int
+rspamd_kann_table_to_flags (lua_State *L, int table_pos)
+{
+       int result = 0;
+
+       lua_pushvalue (L, table_pos);
+
+       for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+               int fl = lua_tointeger (L, -1);
+
+               result |= fl;
+       }
+
+       lua_pop (L, 1);
+
+       return result;
+}
+
+static gint
+lua_load_kann (lua_State * L)
+{
+       lua_newtable (L);
+
+       /* Flags */
+       lua_pushstring (L, "flag");
+       lua_newtable (L);
+       lua_pushinteger (L, KANN_F_IN);
+       lua_setfield (L, -2, "in");
+       lua_pushinteger (L, KANN_F_COST);
+       lua_setfield (L, -2, "cost");
+       lua_pushinteger (L, KANN_F_OUT);
+       lua_setfield (L, -2, "out");
+       lua_pushinteger (L, KANN_F_TRUTH);
+       lua_setfield (L, -2, "truth");
+       lua_settable (L, -3);
+
+       /* Cost type */
+       lua_pushstring (L, "cost");
+       lua_newtable (L);
+       /* binary cross-entropy cost, used with sigmoid */
+       lua_pushinteger (L, KANN_C_CEB);
+       lua_setfield (L, -2, "ceb");
+       /* multi-class cross-entropy cost, used with softmax */
+       lua_pushinteger (L, KANN_C_CEM);
+       lua_setfield (L, -2, "cem");
+       /* binary cross-entropy-like cost, used with tanh */
+       lua_pushinteger (L, KANN_C_CEB_NEG);
+       lua_setfield (L, -2, "ceb_neg");
+       lua_pushinteger (L, KANN_C_MSE);
+       lua_setfield (L, -2, "mse");
+       lua_settable (L, -3);
+
+       /* RNN flag */
+       lua_pushstring (L, "rnn");
+       lua_newtable (L);
+       /* apply layer normalization */
+       lua_pushinteger (L, KANN_RNN_NORM);
+       lua_setfield (L, -2, "norm");
+       /* take the initial hidden values as variables */
+       lua_pushinteger (L, KANN_RNN_VAR_H0);
+       lua_setfield (L, -2, "var_h0");
+       lua_settable (L, -3);
+
+       /* Layers */
+       lua_pushstring (L, "layer");
+       lua_newtable (L);
+       luaL_register (L, NULL, rspamd_kann_layers_f);
+       lua_settable (L, -3);
+
+       /* Transforms */
+       lua_pushstring (L, "transform");
+       lua_newtable (L);
+       luaL_register (L, NULL, rspamd_kann_transform_f);
+       lua_settable (L, -3);
+
+       /* Cost */
+       lua_pushstring (L, "loss");
+       lua_newtable (L);
+       luaL_register (L, NULL, rspamd_kann_loss_f);
+       lua_settable (L, -3);
+
+       return 1;
+}
+
+static kad_node_t *
+lua_check_kann_node (lua_State *L, int pos)
+{
+       void *ud = rspamd_lua_check_udata (L, pos, KANN_NODE_CLASS);
+       luaL_argcheck (L, ud != NULL, pos, "'kann_node' expected");
+       return ud ? *((kad_node_t **)ud) : NULL;
+}
+
+void luaopen_kann (lua_State *L)
+{
+       /* Metatables */
+       rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+       lua_pop (L, 1); /* No need in metatable... */
+       rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
+       lua_settop (L, 0);
+}
+
+/* Layers implementation */
+#define PUSH_KAD_NODE(n) do { \
+       kad_node_t **pt; \
+       pt = lua_newuserdata (L, sizeof (kad_node_t *)); \
+       *pt = (n); \
+       rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
+} while(0)
+
+#define PROCESS_KAD_FLAGS(n, pos) do { \
+       int fl = 0; \
+       if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
+       else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
+       (n)->ext_flag = fl; \
+}while(0)
+
+static int
+lua_kann_layer_input (lua_State *L)
+{
+       gint nnodes = luaL_checkinteger (L, 1);
+
+       if (nnodes > 0) {
+               kad_node_t *t;
+
+               t = kann_layer_input (nnodes);
+
+               PROCESS_KAD_FLAGS (t, 2);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, nnodes required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_dense (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       gint nnodes = luaL_checkinteger (L, 2);
+
+       if (in != NULL && nnodes > 0) {
+               kad_node_t *t;
+
+               t = kann_layer_dense (in, nnodes);
+
+               PROCESS_KAD_FLAGS (t, 3);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input + nnodes required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_layerdropout (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       double r = luaL_checknumber (L, 2);
+
+       if (in != NULL) {
+               kad_node_t *t;
+
+               t = kann_layer_dropout (in, r);
+
+               PROCESS_KAD_FLAGS (t, 3);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input + rate required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_layernorm (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+
+       if (in != NULL) {
+               kad_node_t *t;
+
+               t = kann_layer_layernorm (in);
+
+               PROCESS_KAD_FLAGS (t, 2);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_rnn (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       gint nnodes = luaL_checkinteger (L, 2);
+       gint rnnflags = 0;
+
+       if (in != NULL && nnodes > 0) {
+               kad_node_t *t;
+
+               if (lua_type (L, 3) == LUA_TNUMBER) {
+                       rnnflags = lua_tointeger (L, 3);
+               }
+
+               t = kann_layer_rnn (in, nnodes, rnnflags);
+
+               PROCESS_KAD_FLAGS (t, 4);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input + nnodes required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_lstm (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       gint nnodes = luaL_checkinteger (L, 2);
+       gint rnnflags = 0;
+
+       if (in != NULL && nnodes > 0) {
+               kad_node_t *t;
+
+               if (lua_type (L, 3) == LUA_TNUMBER) {
+                       rnnflags = lua_tointeger (L, 3);
+               }
+
+               t = kann_layer_lstm (in, nnodes, rnnflags);
+
+               PROCESS_KAD_FLAGS (t, 4);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input + nnodes required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_gru (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       gint nnodes = luaL_checkinteger (L, 2);
+       gint rnnflags = 0;
+
+       if (in != NULL && nnodes > 0) {
+               kad_node_t *t;
+
+               if (lua_type (L, 3) == LUA_TNUMBER) {
+                       rnnflags = lua_tointeger (L, 3);
+               }
+
+               t = kann_layer_gru (in, nnodes, rnnflags);
+
+               PROCESS_KAD_FLAGS (t, 4);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input + nnodes required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_conv2d (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       int n_flt = luaL_checkinteger (L, 2);
+       int k_rows = luaL_checkinteger (L, 3);
+       int k_cols =  luaL_checkinteger (L, 4);
+       int stride_r = luaL_checkinteger (L, 5);
+       int stride_c = luaL_checkinteger (L, 6);
+       int pad_r = luaL_checkinteger (L, 7);
+       int pad_c = luaL_checkinteger (L, 8);
+
+       if (in != NULL) {
+               kad_node_t *t;
+               t = kann_layer_conv2d (in, n_flt, k_rows, k_cols, stride_r, stride_c,
+                               pad_r, pad_c);
+
+               PROCESS_KAD_FLAGS (t, 9);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_conv1d (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       int n_flt = luaL_checkinteger (L, 2);
+       int k_size = luaL_checkinteger (L, 3);
+       int stride = luaL_checkinteger (L, 4);
+       int pad = luaL_checkinteger (L, 5);
+
+       if (in != NULL) {
+               kad_node_t *t;
+               t = kann_layer_conv1d (in, n_flt, k_size, stride, pad);
+
+               PROCESS_KAD_FLAGS (t, 6);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input, nflt, k, stride, pad required");
+       }
+
+       return 1;
+}
+
+static int
+lua_kann_layer_cost (lua_State *L)
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+       int nout = luaL_checkinteger (L, 2);
+       int cost_type = luaL_checkinteger (L, 3);
+
+       if (in != NULL && nout > 0) {
+               kad_node_t *t;
+               t = kann_layer_cost (in, nout, cost_type);
+
+               PROCESS_KAD_FLAGS (t, 4);
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, input, nout and cost_type are required");
+       }
+
+       return 1;
+}
+
+/* Generic helpers */
+static int
+lua_kann_call_unary_function (lua_State *L, const char *name,
+               kad_node_t *(*func)(kad_node_t *))
+{
+       kad_node_t *in = lua_check_kann_node (L, 1);
+
+       if (in != NULL) {
+               kad_node_t *t;
+               t = func (in);
+
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments for %s, input required", name);
+       }
+
+       return 1;
+}
+static int
+lua_kann_call_binary_function (lua_State *L, const char *name,
+                                                         kad_node_t *(*func)(kad_node_t *, kad_node_t *))
+{
+       kad_node_t *x = lua_check_kann_node (L, 1);
+       kad_node_t *y = lua_check_kann_node (L, 2);
+
+       if (x != NULL && y != NULL) {
+               kad_node_t *t;
+               t = func (x, y);
+
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments for %s, 2 inputs required", name);
+       }
+
+       return 1;
+}
+
+#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name)                                                                    \
+static int lua_kann_transform_ ##name (lua_State *L)                                           \
+{                                                                                                                                                      \
+       return lua_kann_call_unary_function(L, #name, kad_##name);                              \
+}
+
+#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name)                                                           \
+static int lua_kann_transform_ ##name (lua_State *L)                                           \
+{                                                                                                                                                      \
+       return lua_kann_call_binary_function(L, #name, kad_##name);                             \
+}
+
+#define LUA_LOSS_FUNC_IMPL(name)                                                                                       \
+static int lua_kann_loss_ ##name (lua_State *L)                                                                \
+{                                                                                                                                                      \
+       return lua_kann_call_binary_function(L, #name, kad_##name);                             \
+}
+
+/* Transform functions registered via macro helpers */
+LUA_BINARY_TRANSFORM_FUNC_IMPL (add)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (sub)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (mul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (cmul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (matmul)
+
+LUA_UNARY_TRANSFORM_FUNC_IMPL (square)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (sigm)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (tanh)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (softmax)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (1minus)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (exp)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (log)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (sin)
+
+/* Generic cost functions */
+LUA_LOSS_FUNC_IMPL (mse)
+LUA_LOSS_FUNC_IMPL (ce_multi)
+LUA_LOSS_FUNC_IMPL (ce_bin)
+LUA_LOSS_FUNC_IMPL (ce_bin_neg)
+
+/* The only case of ternary weight function */
+static int
+lua_kann_loss_ce_multi_weighted (lua_State *L)
+{
+       kad_node_t *pred = lua_check_kann_node (L, 1);
+       kad_node_t *truth = lua_check_kann_node (L, 2);
+       kad_node_t *weight = lua_check_kann_node (L, 3);
+
+       if (pred != NULL && truth != NULL && weight != NULL) {
+               kad_node_t *t;
+               t = kad_ce_multi_weighted (pred, truth, weight);
+
+               PUSH_KAD_NODE (t);
+       }
+       else {
+               return luaL_error (L, "invalid arguments for ce_multi_weighted, 3 inputs required");
+       }
+
+       return 1;
+}
\ No newline at end of file