From 1c1215868969dc988d522f8858ba9444da8a0997 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 7 Jan 2016 09:19:48 +0000 Subject: Allow lua script call for autolearn --- src/libstat/stat_process.c | 56 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) (limited to 'src/libstat/stat_process.c') diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index cc34a5333..9eeaeaf10 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -681,9 +681,14 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) struct rspamd_classifier *cl; const ucl_object_t *obj, *elt1, *elt2; struct metric_result *mres; + struct rspamd_task **ptask; + lua_State *L; + GString *tb; guint i; + gint err_idx; gboolean ret = FALSE; gdouble ham_score, spam_score; + const gchar *lua_script, *lua_ret; g_assert (RSPAMD_TASK_IS_CLASSIFIED (task)); st_ctx = rspamd_stat_get_ctx (); @@ -696,7 +701,6 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) if (cl->cfg->opts) { obj = ucl_object_find_key (cl->cfg->opts, "autolearn"); - /* TODO: support range and lua for this option */ if (ucl_object_type (obj) == UCL_BOOLEAN) { if (ucl_object_toboolean (obj)) { /* @@ -763,6 +767,56 @@ rspamd_stat_check_autolearn (struct rspamd_task *task) } } } + else if (ucl_object_type (obj) == UCL_STRING) { + lua_script = ucl_object_tostring (obj); + L = task->cfg->lua_state; + + if (luaL_dostring (L, lua_script) != 0) { + msg_err_task ("cannot execute lua script for autolearn " + "extraction: %s", lua_tostring (L, -1)); + } + else { + if (lua_type (L, -1) == LUA_TFUNCTION) { + lua_pushcfunction (L, &rspamd_lua_traceback); + err_idx = lua_gettop (L); + lua_pushvalue (L, -2); /* Function itself */ + + ptask = lua_newuserdata (L, sizeof (struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass (L, "rspamd{task}", -1); + + if (lua_pcall (L, 1, 1, err_idx) != 0) { + tb = lua_touserdata (L, -1); + msg_err_task ("call to autolearn script failed: " + "%v", tb); + g_string_free (tb, TRUE); + } + else { + lua_ret = lua_tostring (L, -1); + + if (lua_ret) { + if (strcmp (lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp (lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + /* Result + error function + original function */ + lua_pop (L, 3); + } + else { + msg_err_task ("lua script must return " + "function(task) and not %s", + lua_typename (L, lua_type ( + L, -1))); + } + } + } if (ret) { /* Do not autolearn if we have this symbol already */ -- cgit v1.2.3