diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-24 15:03:29 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-24 15:03:29 +0100 |
commit | e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed (patch) | |
tree | 718bee26b22d6582e63a7bffc8d50104b866c131 /src/libstat/stat_process.c | |
parent | 701a711049ee01373bc3862cc441fc3065c8dbc2 (diff) | |
download | rspamd-e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed.tar.gz rspamd-e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed.zip |
[Feature] Improve autolearning
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r-- | src/libstat/stat_process.c | 76 |
1 file changed, 75 insertions, 1 deletion
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 034e1a5be..d720a77ab 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -906,6 +906,19 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task, return FALSE; } +struct cl_cbref_dtor_data { + lua_State *L; + gint ref_idx; +}; + +static void +rspamd_stat_cbref_dtor (void *d) +{ + struct cl_cbref_dtor_data *data = (struct cl_cbref_dtor_data *)d; + + luaL_unref (data->L, LUA_REGISTRYINDEX, data->ref_idx); +} + gboolean rspamd_stat_check_autolearn (struct rspamd_task *task) { @@ -925,6 +938,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) st_ctx = rspamd_stat_get_ctx (); g_assert (st_ctx != NULL); + L = task->cfg->lua_state; + for (i = 0; i < st_ctx->classifiers->len; i ++) { cl = g_ptr_array_index (st_ctx->classifiers, i); ret = FALSE; @@ -933,6 +948,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) obj = ucl_object_lookup (cl->cfg->opts, "autolearn"); if (ucl_object_type (obj) == UCL_BOOLEAN) { + /* Legacy true/false */ if (ucl_object_toboolean (obj)) { /* * Default learning algorithm: @@ -956,6 +972,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) } } else if (ucl_object_type (obj) == UCL_ARRAY && obj->len == 2) { + /* Legacy thresholds */ /* * We have an array of 2 elements, treat it as a * ham_score, spam_score @@ -994,8 +1011,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) } } else if (ucl_object_type (obj) == UCL_STRING) { + /* Legacy sript */ lua_script = ucl_object_tostring (obj); - L = task->cfg->lua_state; if (luaL_dostring (L, lua_script) != 0) { msg_err_task ("cannot execute lua script for autolearn " @@ -1018,6 +1035,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) else { lua_ret = lua_tostring (L, -1); + /* We can have immediate results */ if (lua_ret) { if (strcmp (lua_ret, "ham") == 0) { task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; @@ -1041,6 +1059,62 @@ rspamd_stat_check_autolearn (struct 
rspamd_task *task) } } } + else if (ucl_object_type (obj) == UCL_OBJECT) { + /* Try to find autolearn callback */ + if (cl->autolearn_cbref == 0) { + /* We don't have preprocessed cb id, so try to get it */ + if (!rspamd_lua_require_function (L, "lua_bayes_learn", + "autolearn")) { + msg_err_task ("cannot get autolearn library from " + "`lua_bayes_learn`"); + } + else { + struct cl_cbref_dtor_data *dtor_data; + + dtor_data = (struct cl_cbref_dtor_data *) + rspamd_mempool_alloc (task->cfg->cfg_pool, + sizeof (*dtor_data)); + cl->autolearn_cbref = luaL_ref (L, LUA_REGISTRYINDEX); + dtor_data->L = L; + dtor_data->ref_idx = cl->autolearn_cbref; + rspamd_mempool_add_destructor (task->cfg->cfg_pool, + rspamd_stat_cbref_dtor, dtor_data); + } + } + + if (cl->autolearn_cbref != -1) { + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + lua_rawgeti (L, LUA_REGISTRYINDEX, cl->autolearn_cbref); + + ptask = lua_newuserdata (L, sizeof (struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass (L, "rspamd{task}", -1); + /* Push the whole object as well */ + ucl_object_push_lua (L, obj, true); + + if (lua_pcall (L, 2, 1, err_idx) != 0) { + msg_err_task ("call to autolearn script failed: " + "%s", lua_tostring (L, -1)); + } + else { + lua_ret = lua_tostring (L, -1); + + if (lua_ret) { + if (strcmp (lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp (lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + lua_settop (L, err_idx - 1); + } + } if (ret) { /* Do not autolearn if we have this symbol already */ |