--- /dev/null
+/*-
+ * Copyright 2019 Vsevolod Stakhov
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lua_common.h"
+#include "contrib/kann/kann.h"
+
+/***
+ * @module rspamd_kann
+ * `rspamd_kann` is a Lua interface to kann library
+ */
+
+#define KANN_NODE_CLASS "rspamd{kann_node}"
+
+/* Simple macros to define behaviour */
+#define KANN_LAYER_DEF(name) static int lua_kann_layer_ ## name (lua_State *L)
+#define KANN_LAYER_INTERFACE(name) {#name, lua_kann_layer_ ## name}
+
+#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_ ## name (lua_State *L)
+#define KANN_TRANSFORM_INTERFACE(name) {#name, lua_kann_transform_ ## name}
+
+#define KANN_LOSS_DEF(name) static int lua_kann_loss_ ## name (lua_State *L)
+#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_ ## name}
+
+/*
+ * Forwarded declarations
+ */
+static kad_node_t *lua_check_kann_node (lua_State *L, int pos);
+
+/* Layers */
+KANN_LAYER_DEF(input);
+KANN_LAYER_DEF(dense);
+KANN_LAYER_DEF(layernorm);
+KANN_LAYER_DEF(rnn);
+KANN_LAYER_DEF(lstm);
+KANN_LAYER_DEF(gru);
+KANN_LAYER_DEF(conv2d);
+KANN_LAYER_DEF(conv1d);
+KANN_LAYER_DEF(cost);
+
+static luaL_reg rspamd_kann_layers_f[] = {
+ KANN_LAYER_INTERFACE(input),
+ KANN_LAYER_INTERFACE(dense),
+ KANN_LAYER_INTERFACE(layernorm),
+ KANN_LAYER_INTERFACE(rnn),
+ KANN_LAYER_INTERFACE(lstm),
+ KANN_LAYER_INTERFACE(gru),
+ KANN_LAYER_INTERFACE(conv2d),
+ KANN_LAYER_INTERFACE(conv1d),
+ KANN_LAYER_INTERFACE(cost),
+ {NULL, NULL},
+};
+
+/* Transition and composition functions */
+
+/* General transform */
+KANN_TRANSFORM_DEF (add);
+KANN_TRANSFORM_DEF (sub);
+KANN_TRANSFORM_DEF (mul);
+KANN_TRANSFORM_DEF (cmul);
+KANN_TRANSFORM_DEF (matmul);
+
+KANN_TRANSFORM_DEF (square);
+KANN_TRANSFORM_DEF (sigm);
+KANN_TRANSFORM_DEF (tanh);
+KANN_TRANSFORM_DEF (relu);
+KANN_TRANSFORM_DEF (softmax);
+KANN_TRANSFORM_DEF (1minus);
+KANN_TRANSFORM_DEF (exp);
+KANN_TRANSFORM_DEF (log);
+KANN_TRANSFORM_DEF (sin);
+static luaL_reg rspamd_kann_transform_f[] = {
+ KANN_TRANSFORM_INTERFACE (add),
+ KANN_TRANSFORM_INTERFACE (sub),
+ KANN_TRANSFORM_INTERFACE (mul),
+ KANN_TRANSFORM_INTERFACE (cmul),
+ KANN_TRANSFORM_INTERFACE (matmul),
+
+ KANN_TRANSFORM_INTERFACE (square),
+ KANN_TRANSFORM_INTERFACE (sigm),
+ KANN_TRANSFORM_INTERFACE (tanh),
+ KANN_TRANSFORM_INTERFACE (relu),
+ KANN_TRANSFORM_INTERFACE (softmax),
+ KANN_TRANSFORM_INTERFACE (1minus),
+ KANN_TRANSFORM_INTERFACE (exp),
+ KANN_TRANSFORM_INTERFACE (log),
+ KANN_TRANSFORM_INTERFACE (sin),
+ {NULL, NULL},
+};
+
+/* Loss functions */
+KANN_LOSS_DEF (mse);
+KANN_LOSS_DEF (ce_multi);
+KANN_LOSS_DEF (ce_bin);
+KANN_LOSS_DEF (ce_bin_neg);
+KANN_LOSS_DEF (ce_multi_weighted);
+static luaL_reg rspamd_kann_loss_f[] = {
+ KANN_LOSS_INTERFACE (mse),
+ KANN_LOSS_INTERFACE (ce_multi),
+ KANN_LOSS_INTERFACE (ce_bin),
+ KANN_LOSS_INTERFACE (ce_bin_neg),
+ KANN_LOSS_INTERFACE (ce_multi_weighted),
+ {NULL, NULL},
+};
+
+static int
+rspamd_kann_table_to_flags (lua_State *L, int table_pos)
+{
+ int result = 0;
+
+ lua_pushvalue (L, table_pos);
+
+ for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
+ int fl = lua_tointeger (L, -1);
+
+ result |= fl;
+ }
+
+ lua_pop (L, 1);
+
+ return result;
+}
+
+static gint
+lua_load_kann (lua_State * L)
+{
+ lua_newtable (L);
+
+ /* Flags */
+ lua_pushstring (L, "flag");
+ lua_newtable (L);
+ lua_pushinteger (L, KANN_F_IN);
+ lua_setfield (L, -2, "in");
+ lua_pushinteger (L, KANN_F_COST);
+ lua_setfield (L, -2, "cost");
+ lua_pushinteger (L, KANN_F_OUT);
+ lua_setfield (L, -2, "out");
+ lua_pushinteger (L, KANN_F_TRUTH);
+ lua_setfield (L, -2, "truth");
+ lua_settable (L, -3);
+
+ /* Cost type */
+ lua_pushstring (L, "cost");
+ lua_newtable (L);
+ /* binary cross-entropy cost, used with sigmoid */
+ lua_pushinteger (L, KANN_C_CEB);
+ lua_setfield (L, -2, "ceb");
+ /* multi-class cross-entropy cost, used with softmax */
+ lua_pushinteger (L, KANN_C_CEM);
+ lua_setfield (L, -2, "cem");
+ /* binary cross-entropy-like cost, used with tanh */
+ lua_pushinteger (L, KANN_C_CEB_NEG);
+ lua_setfield (L, -2, "ceb_neg");
+ lua_pushinteger (L, KANN_C_MSE);
+ lua_setfield (L, -2, "mse");
+ lua_settable (L, -3);
+
+ /* RNN flag */
+ lua_pushstring (L, "rnn");
+ lua_newtable (L);
+ /* apply layer normalization */
+ lua_pushinteger (L, KANN_RNN_NORM);
+ lua_setfield (L, -2, "norm");
+ /* take the initial hidden values as variables */
+ lua_pushinteger (L, KANN_RNN_VAR_H0);
+ lua_setfield (L, -2, "var_h0");
+ lua_settable (L, -3);
+
+ /* Layers */
+ lua_pushstring (L, "layer");
+ lua_newtable (L);
+ luaL_register (L, NULL, rspamd_kann_layers_f);
+ lua_settable (L, -3);
+
+ /* Transforms */
+ lua_pushstring (L, "transform");
+ lua_newtable (L);
+ luaL_register (L, NULL, rspamd_kann_transform_f);
+ lua_settable (L, -3);
+
+ /* Cost */
+ lua_pushstring (L, "loss");
+ lua_newtable (L);
+ luaL_register (L, NULL, rspamd_kann_loss_f);
+ lua_settable (L, -3);
+
+ return 1;
+}
+
+static kad_node_t *
+lua_check_kann_node (lua_State *L, int pos)
+{
+ void *ud = rspamd_lua_check_udata (L, pos, KANN_NODE_CLASS);
+ luaL_argcheck (L, ud != NULL, pos, "'kann_node' expected");
+ return ud ? *((kad_node_t **)ud) : NULL;
+}
+
+void luaopen_kann (lua_State *L)
+{
+ /* Metatables */
+ rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
+ lua_pop (L, 1); /* No need in metatable... */
+ rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
+ lua_settop (L, 0);
+}
+
+/* Layers implementation */
+#define PUSH_KAD_NODE(n) do { \
+ kad_node_t **pt; \
+ pt = lua_newuserdata (L, sizeof (kad_node_t *)); \
+ *pt = (n); \
+ rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
+} while(0)
+
+#define PROCESS_KAD_FLAGS(n, pos) do { \
+ int fl = 0; \
+ if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
+ else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
+ (n)->ext_flag = fl; \
+}while(0)
+
+static int
+lua_kann_layer_input (lua_State *L)
+{
+ gint nnodes = luaL_checkinteger (L, 1);
+
+ if (nnodes > 0) {
+ kad_node_t *t;
+
+ t = kann_layer_input (nnodes);
+
+ PROCESS_KAD_FLAGS (t, 2);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, nnodes required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_dense (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ gint nnodes = luaL_checkinteger (L, 2);
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ t = kann_layer_dense (in, nnodes);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_layerdropout (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ double r = luaL_checknumber (L, 2);
+
+ if (in != NULL) {
+ kad_node_t *t;
+
+ t = kann_layer_dropout (in, r);
+
+ PROCESS_KAD_FLAGS (t, 3);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input + rate required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_layernorm (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+
+ if (in != NULL) {
+ kad_node_t *t;
+
+ t = kann_layer_layernorm (in);
+
+ PROCESS_KAD_FLAGS (t, 2);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_rnn (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ gint nnodes = luaL_checkinteger (L, 2);
+ gint rnnflags = 0;
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ if (lua_type (L, 3) == LUA_TNUMBER) {
+ rnnflags = lua_tointeger (L, 3);
+ }
+
+ t = kann_layer_rnn (in, nnodes, rnnflags);
+
+ PROCESS_KAD_FLAGS (t, 4);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_lstm (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ gint nnodes = luaL_checkinteger (L, 2);
+ gint rnnflags = 0;
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ if (lua_type (L, 3) == LUA_TNUMBER) {
+ rnnflags = lua_tointeger (L, 3);
+ }
+
+ t = kann_layer_lstm (in, nnodes, rnnflags);
+
+ PROCESS_KAD_FLAGS (t, 4);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_gru (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ gint nnodes = luaL_checkinteger (L, 2);
+ gint rnnflags = 0;
+
+ if (in != NULL && nnodes > 0) {
+ kad_node_t *t;
+
+ if (lua_type (L, 3) == LUA_TNUMBER) {
+ rnnflags = lua_tointeger (L, 3);
+ }
+
+ t = kann_layer_gru (in, nnodes, rnnflags);
+
+ PROCESS_KAD_FLAGS (t, 4);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input + nnodes required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_conv2d (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ int n_flt = luaL_checkinteger (L, 2);
+ int k_rows = luaL_checkinteger (L, 3);
+ int k_cols = luaL_checkinteger (L, 4);
+ int stride_r = luaL_checkinteger (L, 5);
+ int stride_c = luaL_checkinteger (L, 6);
+ int pad_r = luaL_checkinteger (L, 7);
+ int pad_c = luaL_checkinteger (L, 8);
+
+ if (in != NULL) {
+ kad_node_t *t;
+ t = kann_layer_conv2d (in, n_flt, k_rows, k_cols, stride_r, stride_c,
+ pad_r, pad_c);
+
+ PROCESS_KAD_FLAGS (t, 9);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_conv1d (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ int n_flt = luaL_checkinteger (L, 2);
+ int k_size = luaL_checkinteger (L, 3);
+ int stride = luaL_checkinteger (L, 4);
+ int pad = luaL_checkinteger (L, 5);
+
+ if (in != NULL) {
+ kad_node_t *t;
+ t = kann_layer_conv1d (in, n_flt, k_size, stride, pad);
+
+ PROCESS_KAD_FLAGS (t, 6);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input, nflt, k, stride, pad required");
+ }
+
+ return 1;
+}
+
+static int
+lua_kann_layer_cost (lua_State *L)
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+ int nout = luaL_checkinteger (L, 2);
+ int cost_type = luaL_checkinteger (L, 3);
+
+ if (in != NULL && nout > 0) {
+ kad_node_t *t;
+ t = kann_layer_cost (in, nout, cost_type);
+
+ PROCESS_KAD_FLAGS (t, 4);
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments, input, nout and cost_type are required");
+ }
+
+ return 1;
+}
+
+/* Generic helpers */
+static int
+lua_kann_call_unary_function (lua_State *L, const char *name,
+ kad_node_t *(*func)(kad_node_t *))
+{
+ kad_node_t *in = lua_check_kann_node (L, 1);
+
+ if (in != NULL) {
+ kad_node_t *t;
+ t = func (in);
+
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for %s, input required", name);
+ }
+
+ return 1;
+}
+static int
+lua_kann_call_binary_function (lua_State *L, const char *name,
+ kad_node_t *(*func)(kad_node_t *, kad_node_t *))
+{
+ kad_node_t *x = lua_check_kann_node (L, 1);
+ kad_node_t *y = lua_check_kann_node (L, 2);
+
+ if (x != NULL && y != NULL) {
+ kad_node_t *t;
+ t = func (x, y);
+
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for %s, 2 inputs required", name);
+ }
+
+ return 1;
+}
+
+#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name) \
+static int lua_kann_transform_ ##name (lua_State *L) \
+{ \
+ return lua_kann_call_unary_function(L, #name, kad_##name); \
+}
+
+#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name) \
+static int lua_kann_transform_ ##name (lua_State *L) \
+{ \
+ return lua_kann_call_binary_function(L, #name, kad_##name); \
+}
+
+#define LUA_LOSS_FUNC_IMPL(name) \
+static int lua_kann_loss_ ##name (lua_State *L) \
+{ \
+ return lua_kann_call_binary_function(L, #name, kad_##name); \
+}
+
+/* Transform functions registered via macro helpers */
+LUA_BINARY_TRANSFORM_FUNC_IMPL (add)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (sub)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (mul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (cmul)
+LUA_BINARY_TRANSFORM_FUNC_IMPL (matmul)
+
+LUA_UNARY_TRANSFORM_FUNC_IMPL (square)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (sigm)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (tanh)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (relu)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (softmax)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (1minus)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (exp)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (log)
+LUA_UNARY_TRANSFORM_FUNC_IMPL (sin)
+
+/* Generic cost functions */
+LUA_LOSS_FUNC_IMPL (mse)
+LUA_LOSS_FUNC_IMPL (ce_multi)
+LUA_LOSS_FUNC_IMPL (ce_bin)
+LUA_LOSS_FUNC_IMPL (ce_bin_neg)
+
+/* The only case of ternary weight function */
+static int
+lua_kann_loss_ce_multi_weighted (lua_State *L)
+{
+ kad_node_t *pred = lua_check_kann_node (L, 1);
+ kad_node_t *truth = lua_check_kann_node (L, 2);
+ kad_node_t *weight = lua_check_kann_node (L, 3);
+
+ if (pred != NULL && truth != NULL && weight != NULL) {
+ kad_node_t *t;
+ t = kad_ce_multi_weighted (pred, truth, weight);
+
+ PUSH_KAD_NODE (t);
+ }
+ else {
+ return luaL_error (L, "invalid arguments for ce_multi_weighted, 3 inputs required");
+ }
+
+ return 1;
+}
\ No newline at end of file