aboutsummaryrefslogtreecommitdiffstats
path: root/src/libstat/stat_process.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-24 15:03:29 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-24 15:03:29 +0100
commite1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed (patch)
tree718bee26b22d6582e63a7bffc8d50104b866c131 /src/libstat/stat_process.c
parent701a711049ee01373bc3862cc441fc3065c8dbc2 (diff)
downloadrspamd-e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed.tar.gz
rspamd-e1fadcc80b5f6a3d566224b0ed1a74d7a9dbc9ed.zip
[Feature] Improve autolearning
Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r--src/libstat/stat_process.c76
1 files changed, 75 insertions, 1 deletions
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index 034e1a5be..d720a77ab 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -906,6 +906,19 @@ rspamd_stat_has_classifier_symbols (struct rspamd_task *task,
return FALSE;
}
+struct cl_cbref_dtor_data { /* Argument passed to rspamd_stat_cbref_dtor */
+ lua_State *L; /* Lua state whose registry holds the reference */
+ gint ref_idx; /* Index in LUA_REGISTRYINDEX, as returned by luaL_ref */
+};
+
+static void /* Registered as a memory pool destructor for the autolearn callback ref */
+rspamd_stat_cbref_dtor (void *d) /* d points to a struct cl_cbref_dtor_data */
+{
+ struct cl_cbref_dtor_data *data = (struct cl_cbref_dtor_data *)d;
+
+ luaL_unref (data->L, LUA_REGISTRYINDEX, data->ref_idx); /* Release the registry reference */
+}
+
gboolean
rspamd_stat_check_autolearn (struct rspamd_task *task)
{
@@ -925,6 +938,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
st_ctx = rspamd_stat_get_ctx ();
g_assert (st_ctx != NULL);
+ L = task->cfg->lua_state;
+
for (i = 0; i < st_ctx->classifiers->len; i ++) {
cl = g_ptr_array_index (st_ctx->classifiers, i);
ret = FALSE;
@@ -933,6 +948,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
obj = ucl_object_lookup (cl->cfg->opts, "autolearn");
if (ucl_object_type (obj) == UCL_BOOLEAN) {
+ /* Legacy true/false */
if (ucl_object_toboolean (obj)) {
/*
* Default learning algorithm:
@@ -956,6 +972,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
}
}
else if (ucl_object_type (obj) == UCL_ARRAY && obj->len == 2) {
+ /* Legacy thresholds */
/*
* We have an array of 2 elements, treat it as a
* ham_score, spam_score
@@ -994,8 +1011,8 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
}
}
else if (ucl_object_type (obj) == UCL_STRING) {
+ /* Legacy script */
lua_script = ucl_object_tostring (obj);
- L = task->cfg->lua_state;
if (luaL_dostring (L, lua_script) != 0) {
msg_err_task ("cannot execute lua script for autolearn "
@@ -1018,6 +1035,7 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
else {
lua_ret = lua_tostring (L, -1);
+ /* We can have immediate results */
if (lua_ret) {
if (strcmp (lua_ret, "ham") == 0) {
task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
@@ -1041,6 +1059,62 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
}
}
}
+ else if (ucl_object_type (obj) == UCL_OBJECT) {
+ /* Try to find autolearn callback */
+ if (cl->autolearn_cbref == 0) {
+ /* We don't have preprocessed cb id, so try to get it */
+ if (!rspamd_lua_require_function (L, "lua_bayes_learn",
+ "autolearn")) {
+ msg_err_task ("cannot get autolearn library from "
+ "`lua_bayes_learn`");
+ }
+ else {
+ struct cl_cbref_dtor_data *dtor_data;
+
+ dtor_data = (struct cl_cbref_dtor_data *)
+ rspamd_mempool_alloc (task->cfg->cfg_pool,
+ sizeof (*dtor_data));
+ cl->autolearn_cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+ dtor_data->L = L;
+ dtor_data->ref_idx = cl->autolearn_cbref;
+ rspamd_mempool_add_destructor (task->cfg->cfg_pool,
+ rspamd_stat_cbref_dtor, dtor_data);
+ }
+ }
+
+ if (cl->autolearn_cbref != -1) {
+ lua_pushcfunction (L, &rspamd_lua_traceback);
+ err_idx = lua_gettop (L);
+ lua_rawgeti (L, LUA_REGISTRYINDEX, cl->autolearn_cbref);
+
+ ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+ *ptask = task;
+ rspamd_lua_setclass (L, "rspamd{task}", -1);
+ /* Push the whole object as well */
+ ucl_object_push_lua (L, obj, true);
+
+ if (lua_pcall (L, 2, 1, err_idx) != 0) {
+ msg_err_task ("call to autolearn script failed: "
+ "%s", lua_tostring (L, -1));
+ }
+ else {
+ lua_ret = lua_tostring (L, -1);
+
+ if (lua_ret) {
+ if (strcmp (lua_ret, "ham") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+ ret = TRUE;
+ }
+ else if (strcmp (lua_ret, "spam") == 0) {
+ task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+ ret = TRUE;
+ }
+ }
+ }
+
+ lua_settop (L, err_idx - 1);
+ }
+ }
if (ret) {
/* Do not autolearn if we have this symbol already */