aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_kann.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua/lua_kann.c')
-rw-r--r--src/lua/lua_kann.c51
1 files changed, 44 insertions, 7 deletions
diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c
index 3d50cc587..a1b31014d 100644
--- a/src/lua/lua_kann.c
+++ b/src/lua/lua_kann.c
@@ -143,13 +143,13 @@ static luaL_reg rspamd_kann_new_f[] = {
LUA_FUNCTION_DEF (kann, load);
LUA_FUNCTION_DEF (kann, destroy);
LUA_FUNCTION_DEF (kann, save);
-LUA_FUNCTION_DEF (kann, train);
-LUA_FUNCTION_DEF (kann, forward);
+LUA_FUNCTION_DEF (kann, train1);
+LUA_FUNCTION_DEF (kann, apply1);
static luaL_reg rspamd_kann_m[] = {
LUA_INTERFACE_DEF (kann, save),
- LUA_INTERFACE_DEF (kann, train),
- LUA_INTERFACE_DEF (kann, forward),
+ LUA_INTERFACE_DEF (kann, train1),
+ LUA_INTERFACE_DEF (kann, apply1),
{"__gc", lua_kann_destroy},
{NULL, NULL},
};
@@ -985,7 +985,7 @@ lua_kann_load (lua_State *L)
}
static int
-lua_kann_train (lua_State *L)
+lua_kann_train1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
@@ -993,9 +993,46 @@ lua_kann_train (lua_State *L)
}
static int
-lua_kann_forward (lua_State *L)
+lua_kann_apply1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
- g_assert_not_reached (); /* TODO: implement */
+ if (k && lua_istable (L, 2)) {
+ gsize vec_len = rspamd_lua_table_size (L, 2);
+ float *vec = (float *)g_malloc (sizeof (float) * vec_len);
+ int i_out;
+
+ for (gsize i = 0; i < vec_len; i ++) {
+ lua_rawgeti (L, 2, i + 1);
+ vec[i] = lua_tonumber (L, -1);
+ lua_pop (L, 1);
+ }
+
+ i_out = kann_find (k, KANN_F_OUT, 0);
+
+ if (i_out <= 0) {
+ g_free (vec);
+ return luaL_error (L, "invalid ANN: output layer is missing or is "
+ "at the input pos");
+ }
+
+ kann_set_batch_size (k, 1);
+ kann_feed_bind (k, KANN_F_IN, 0, &vec);
+ kad_eval_at (k->n, k->v, i_out);
+
+ gsize outlen = kad_len (k->v[i_out]);
+ lua_createtable (L, outlen, 0);
+
+ for (gsize i = 0; i < outlen; i ++) {
+ lua_pushnumber (L, k->v[i_out]->x[i]);
+ lua_rawseti (L, -2, i + 1);
+ }
+
+ g_free (vec);
+ }
+ else {
+ return luaL_error (L, "invalid arguments: rspamd{kann} expected");
+ }
+
+ return 1;
} \ No newline at end of file