]> source.dussan.org Git - rspamd.git/commitdiff
Allow lua script call for autolearn
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 7 Jan 2016 09:19:48 +0000 (09:19 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 7 Jan 2016 09:19:48 +0000 (09:19 +0000)
src/libstat/stat_process.c

index cc34a53336417bfcb7816411e02a1aae09344881..9eeaeaf10528ed6d0202517c925670b603229703 100644 (file)
@@ -681,9 +681,14 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
        struct rspamd_classifier *cl;
        const ucl_object_t *obj, *elt1, *elt2;
        struct metric_result *mres;
+       struct rspamd_task **ptask;
+       lua_State *L;
+       GString *tb;
        guint i;
+       gint err_idx;
        gboolean ret = FALSE;
        gdouble ham_score, spam_score;
+       const gchar *lua_script, *lua_ret;
 
        g_assert (RSPAMD_TASK_IS_CLASSIFIED (task));
        st_ctx = rspamd_stat_get_ctx ();
@@ -696,7 +701,6 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                if (cl->cfg->opts) {
                        obj = ucl_object_find_key (cl->cfg->opts, "autolearn");
 
-                       /* TODO: support range and lua for this option */
                        if (ucl_object_type (obj) == UCL_BOOLEAN) {
                                if (ucl_object_toboolean (obj)) {
                                        /*
@@ -763,6 +767,56 @@ rspamd_stat_check_autolearn (struct rspamd_task *task)
                                        }
                                }
                        }
+                       else if (ucl_object_type (obj) == UCL_STRING) {
+                               lua_script = ucl_object_tostring (obj);
+                               L = task->cfg->lua_state;
+
+                               if (luaL_dostring (L, lua_script) != 0) {
+                                       msg_err_task ("cannot execute lua script for autolearn "
+                                                       "extraction: %s", lua_tostring (L, -1));
+                               }
+                               else {
+                                       if (lua_type (L, -1) == LUA_TFUNCTION) {
+                                               lua_pushcfunction (L, &rspamd_lua_traceback);
+                                               err_idx = lua_gettop (L);
+                                               lua_pushvalue (L, -2); /* Function itself */
+
+                                               ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+                                               *ptask = task;
+                                               rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+                                               if (lua_pcall (L, 1, 1, err_idx) != 0) {
+                                                       tb = lua_touserdata (L, -1);
+                                                       msg_err_task ("call to autolearn script failed: "
+                                                                       "%v", tb);
+                                                       g_string_free (tb, TRUE);
+                                               }
+                                               else {
+                                                       lua_ret = lua_tostring (L, -1);
+
+                                                       if (lua_ret) {
+                                                               if (strcmp (lua_ret, "ham") == 0) {
+                                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM;
+                                                                       ret = TRUE;
+                                                               }
+                                                               else if (strcmp (lua_ret, "spam") == 0) {
+                                                                       task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM;
+                                                                       ret = TRUE;
+                                                               }
+                                                       }
+                                               }
+
+                                               /* Result + error function + original function */
+                                               lua_pop (L, 3);
+                                       }
+                                       else {
+                                               msg_err_task ("lua script must return "
+                                                               "function(task) and not %s",
+                                                               lua_typename (L, lua_type (
+                                                                               L, -1)));
+                                       }
+                               }
+                       }
 
                        if (ret) {
                                /* Do not autolearn if we have this symbol already */