aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--contrib/kann/kann.c12
-rw-r--r--src/lua/lua_kann.c170
2 files changed, 178 insertions, 4 deletions
diff --git a/contrib/kann/kann.c b/contrib/kann/kann.c
index 3fbf139cc..43227bdc6 100644
--- a/contrib/kann/kann.c
+++ b/contrib/kann/kann.c
@@ -670,7 +670,8 @@ kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return
kad_node_t *kann_layer_input(int n1)
{
kad_node_t *t;
- t = kad_feed(2, 1, n1), t->ext_flag |= KANN_F_IN;
+ t = kad_feed(2, 1, n1);
+ t->ext_flag |= KANN_F_IN;
return t;
}
@@ -761,6 +762,7 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
t = kann_layer_dense(t, n_out);
truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;
+
if (cost_type == KANN_C_MSE) {
cost = kad_mse(t, truth);
} else if (cost_type == KANN_C_CEB) {
@@ -773,7 +775,13 @@ kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
t = kad_softmax(t);
cost = kad_ce_multi(t, truth);
}
- t->ext_flag |= KANN_F_OUT, cost->ext_flag |= KANN_F_COST;
+ else {
+ assert (0);
+ }
+
+ t->ext_flag |= KANN_F_OUT;
+ cost->ext_flag |= KANN_F_COST;
+
return cost;
}
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index a1b31014d..609f05539 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -295,7 +295,7 @@ void luaopen_kann (lua_State *L)
int fl = 0; \
if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags (L, (pos)); } \
else if (lua_type(L, (pos)) == LUA_TNUMBER) { fl = lua_tointeger (L, (pos)); } \
- (n)->ext_flag = fl; \
+ (n)->ext_flag |= fl; \
}while(0)
/***
@@ -984,12 +984,168 @@ lua_kann_load (lua_State *L)
return 1;
}
+struct rspamd_kann_train_cbdata {
+ lua_State *L;
+ kann_t *k;
+ gint cbref;
+};
+
+static void
+lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
+{
+ struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud;
+
+ if (cbd->cbref != -1) {
+ gint err_idx;
+ lua_State *L = cbd->L;
+
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+
+ lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref);
+ lua_pushinteger (L, iter);
+ lua_pushnumber (L, train_cost);
+ lua_pushnumber (L, val_cost);
+
+ if (lua_pcall (L, 3, 0, err_idx) != 0) {
+ msg_err ("cannot run lua train callback: %s",
+ lua_tostring (L, -1));
+ }
+
+ lua_settop (L, err_idx - 1);
+ }
+}
+
+#define FREE_VEC(a, n) do { for(int i = 0; i < (n); i ++) g_free((a)[i]); g_free(a); } while(0)
+
static int
lua_kann_train1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
- g_assert_not_reached (); /* TODO: implement */
+ /* Default train params */
+ double lr = 0.001;
+ gint64 mini_size = 64;
+ gint64 max_epoch = 25;
+ gint64 max_drop_streak = 10;
+ double frac_val = 0.1;
+ gint cbref = -1;
+
+ if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
+ int n = rspamd_lua_table_size (L, 2);
+ int n_in = kann_dim_in (k);
+ int n_out = kann_dim_out (k);
+
+ if (n_in <= 0) {
+ return luaL_error (L, "invalid inputs count: %d", n_in);
+ }
+
+ if (n_out <= 0) {
+ return luaL_error (L, "invalid outputs count: %d", n_in);
+ }
+
+ if (n != rspamd_lua_table_size (L, 3) || n == 0) {
+ return luaL_error (L, "invalid dimensions: outputs size must be "
+ "equal to inputs and non zero");
+ }
+
+ if (lua_istable (L, 4)) {
+ GError *err = NULL;
+
+ if (!rspamd_lua_parse_table_arguments (L, 4, &err,
+ RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
+ "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
+ &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+ n = luaL_error (L, "invalid params: %s",
+ err ? err->message : "unknown error");
+ g_error_free (err);
+
+ return n;
+ }
+ }
+
+ float **x, **y;
+
+ /* Fill vectors */
+ x = (float **)g_malloc (sizeof (float *) * n);
+ y = (float **)g_malloc (sizeof (float *) * n);
+
+ for (int s = 0; s < n; s ++) {
+ /* Inputs */
+ lua_rawgeti (L, 2, s + 1);
+ x[s] = (float *)g_malloc (sizeof (float) * n_in);
+
+ if (rspamd_lua_table_size (L, -1) != n_in) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+
+ n = luaL_error (L, "invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int)rspamd_lua_table_size (L, -1),
+ n_in);
+
+ return n;
+ }
+
+ for (int i = 0; i < n_in; i ++) {
+ lua_rawgeti (L, -1, i + 1);
+ x[s][i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1);
+
+ /* Outputs */
+ y[s] = (float *)g_malloc (sizeof (float) * n_out);
+ lua_rawgeti (L, 3, s + 1);
+
+ if (rspamd_lua_table_size (L, -1) != n_out) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+
+ n = luaL_error (L, "invalid params at pos %d: "
+ "bad output dimension %d; "
+ "%d expected",
+ s + 1,
+ (int)rspamd_lua_table_size (L, -1),
+ n_out);
+
+ return n;
+ }
+
+ for (int i = 0; i < n_out; i ++) {
+ lua_rawgeti (L, -1, i + 1);
+ y[s][i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
+
+ lua_pop (L, 1);
+ }
+
+ struct rspamd_kann_train_cbdata cbd;
+
+ cbd.cbref = cbref;
+ cbd.k = k;
+ cbd.L = L;
+
+ int niters = kann_train_fnn1 (k, lr,
+ mini_size, max_epoch, max_drop_streak,
+ frac_val, n, x, y, lua_kann_train_cb, &cbd);
+
+ lua_pushinteger (L, niters);
+
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+ }
+ else {
+ return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
+ " optional params are expected");
+ }
+
+ return 1;
}
static int
@@ -1001,6 +1157,16 @@ lua_kann_apply1 (lua_State *L)
gsize vec_len = rspamd_lua_table_size (L, 2);
float *vec = (float *)g_malloc (sizeof (float) * vec_len);
int i_out;
+ int n_in = kann_dim_in (k);
+
+ if (n_in <= 0) {
+ return luaL_error (L, "invalid inputs count: %d", n_in);
+ }
+
+ if (n_in != vec_len) {
+ return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+ (int)vec_len, n_in);
+ }
for (gsize i = 0; i < vec_len; i ++) {
lua_rawgeti (L, 2, i + 1);