]> source.dussan.org Git - rspamd.git/commitdiff
Allow custom lua scripts for users/languages extraction.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 2 Oct 2015 12:23:03 +0000 (13:23 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 2 Oct 2015 12:23:03 +0000 (13:23 +0100)
Issue: #388

src/libstat/backends/sqlite3_backend.c

index 33aac049456048beca97c3025998e647b1b7b5f6..215a80bfc667626c2515f6859224065cdac98e51 100644 (file)
 #include "libutil/sqlite_utils.h"
 #include "libstat/stat_internal.h"
 #include "libmime/message.h"
+#include "lua/lua_common.h"
 
 #define SQLITE3_BACKEND_TYPE "sqlite3"
 #define SQLITE3_SCHEMA_VERSION "1"
 #define SQLITE3_DEFAULT "default"
 
 struct rspamd_stat_sqlite3_db {
+       struct rspamd_stat_sqlite3_ctx *ctx;
        sqlite3 *sqlite;
        gchar *fname;
        GArray *prstmt;
        gboolean in_transaction;
        gboolean enable_users;
        gboolean enable_languages;
+       gint cbref_user;
+       gint cbref_language;
 };
 
 struct rspamd_stat_sqlite3_ctx {
        GHashTable *files;
        rspamd_mempool_t *pool;
+       lua_State *L;
 };
 
 struct rspamd_stat_sqlite3_rt {
@@ -282,31 +287,60 @@ rspamd_sqlite3_get_user (struct rspamd_stat_sqlite3_db *db,
                struct rspamd_task *task, gboolean learn)
 {
        gint64 id = 0; /* Default user is 0 */
-       gint rc;
+       gint rc, err_idx;
        const gchar *user = NULL;
        const InternetAddress *ia;
-
-       if (task->deliver_to != NULL) {
-               /* Use deliver-to value if presented */
-               user = task->deliver_to;
-       }
-       if (task->user != NULL) {
-               /* Use user value if presented */
-               user = task->user;
-       }
-       else if (task->rcpt_envelope != NULL) {
-               /* Check envelope recipients */
-               if (internet_address_list_length (task->rcpt_envelope) == 1) {
-                       /* XXX: we support now merely single recipient statistics */
-                       ia = internet_address_list_get_address (task->rcpt_envelope, 0);
-
-                       if (ia != NULL) {
-                               user = internet_address_mailbox_get_addr (INTERNET_ADDRESS_MAILBOX (ia));
+       struct rspamd_task **ptask;
+       lua_State *L = db->ctx->L;
+       GString *tb;
+
+       if (db->cbref_user == -1) {
+               if (task->deliver_to != NULL) {
+                       /* Use deliver-to value if presented */
+                       user = task->deliver_to;
+               }
+               if (task->user != NULL) {
+                       /* Use user value if presented */
+                       user = task->user;
+               }
+               else if (task->rcpt_envelope != NULL) {
+                       /* Check envelope recipients */
+                       if (internet_address_list_length (task->rcpt_envelope) == 1) {
+                               /* XXX: we support now merely single recipient statistics */
+                               ia = internet_address_list_get_address (task->rcpt_envelope, 0);
+
+                               if (ia != NULL) {
+                                       user = internet_address_mailbox_get_addr (
+                                                       INTERNET_ADDRESS_MAILBOX (ia));
+                               }
                        }
                }
+               /* XXX: We ignore now mime recipients as they could be easily forged */
        }
+       else {
+               /* Execute lua function to get userdata */
+               lua_pushcfunction (L, &rspamd_lua_traceback);
+               err_idx = lua_gettop (L);
+
+               lua_rawgeti (L, LUA_REGISTRYINDEX, db->cbref_user);
+               ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+               *ptask = task;
+               rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+               if (lua_pcall (L, 1, 1, err_idx) != 0) {
+                       tb = lua_touserdata (L, -1);
+                       msg_err_task ("call to user extraction script failed: %v", tb);
+                       g_string_free (tb, TRUE);
+               }
+               else {
+                       user = rspamd_mempool_strdup (task->task_pool, lua_tostring (L, -1));
+               }
+
+               /* Result + error function */
+               lua_pop (L, 2);
+       }
+
 
-       /* XXX: We ignore now mime recipients as they could be easily forged */
        if (user != NULL) {
                rc = rspamd_sqlite3_run_prstmt (task->task_pool, db->sqlite, db->prstmt,
                                RSPAMD_STAT_BACKEND_GET_USER, user, &id);
@@ -332,21 +366,50 @@ rspamd_sqlite3_get_language (struct rspamd_stat_sqlite3_db *db,
                struct rspamd_task *task, gboolean learn)
 {
        gint64 id = 0; /* Default language is 0 */
-       gint rc;
+       gint rc, err_idx;
        guint i;
        const gchar *language = NULL;
        struct mime_text_part *tp;
-
-       for (i = 0; i < task->text_parts->len; i ++) {
-               tp = g_ptr_array_index (task->text_parts, i);
-
-               if (tp->lang_code != NULL && tp->lang_code[0] != '\0' &&
-                               strcmp (tp->lang_code, "en") != 0) {
-                       language = tp->language;
-                       break;
+       struct rspamd_task **ptask;
+       lua_State *L = db->ctx->L;
+       GString *tb;
+
+       if (db->cbref_language == -1) {
+               for (i = 0; i < task->text_parts->len; i++) {
+                       tp = g_ptr_array_index (task->text_parts, i);
+
+                       if (tp->lang_code != NULL && tp->lang_code[0] != '\0' &&
+                                       strcmp (tp->lang_code, "en") != 0) {
+                               language = tp->language;
+                               break;
+                       }
+               }
+       }
+       else {
+               /* Execute lua function to get userdata */
+               lua_pushcfunction (L, &rspamd_lua_traceback);
+               err_idx = lua_gettop (L);
+
+               lua_rawgeti (L, LUA_REGISTRYINDEX, db->cbref_language);
+               ptask = lua_newuserdata (L, sizeof (struct rspamd_task *));
+               *ptask = task;
+               rspamd_lua_setclass (L, "rspamd{task}", -1);
+
+               if (lua_pcall (L, 1, 1, err_idx) != 0) {
+                       tb = lua_touserdata (L, -1);
+                       msg_err_task ("call to language extraction script failed: %v", tb);
+                       g_string_free (tb, TRUE);
+               }
+               else {
+                       language = rspamd_mempool_strdup (task->task_pool,
+                                       lua_tostring (L, -1));
                }
+
+               /* Result + error function */
+               lua_pop (L, 2);
        }
 
+
        /* XXX: We ignore multiple languages but default + extra */
        if (language != NULL) {
                rc = rspamd_sqlite3_run_prstmt (task->task_pool, db->sqlite, db->prstmt,
@@ -445,13 +508,14 @@ rspamd_sqlite3_init (struct rspamd_stat_ctx *ctx,
        struct rspamd_statfile_config *stf;
        GList *cur, *curst;
        const ucl_object_t *filenameo, *lang_enabled, *users_enabled;
-       const gchar *filename;
+       const gchar *filename, *lua_script;
        struct rspamd_stat_sqlite3_db *bk;
        GError *err = NULL;
 
        new = rspamd_mempool_alloc0 (cfg->cfg_pool, sizeof (*new));
        new->files = g_hash_table_new (g_direct_hash, g_direct_equal);
        new->pool = cfg->cfg_pool;
+       new->L = cfg->lua_state;
 
        /* Iterate over all classifiers and load matching statfiles */
        cur = cfg->classifiers;
@@ -484,6 +548,7 @@ rspamd_sqlite3_init (struct rspamd_stat_ctx *ctx,
                                }
 
                                if (bk != NULL) {
+                                       bk->ctx = new;
                                        g_hash_table_insert (new->files, stf, bk);
                                }
                                else {
@@ -496,7 +561,30 @@ rspamd_sqlite3_init (struct rspamd_stat_ctx *ctx,
                                users_enabled = ucl_object_find_any_key (clf->opts, "per_user",
                                                "users_enabled", NULL);
                                if (users_enabled != NULL) {
-                                       bk->enable_users = ucl_object_toboolean (users_enabled);
+                                       if (ucl_object_type (users_enabled) == UCL_BOOLEAN) {
+                                               bk->enable_users = ucl_object_toboolean (users_enabled);
+                                               bk->cbref_user = -1;
+                                       }
+                                       else if (ucl_object_type (users_enabled) == UCL_STRING) {
+                                               lua_script = ucl_object_tostring (users_enabled);
+
+                                               if (luaL_dostring (new->L, lua_script) != 0) {
+                                                       msg_err_config ("cannot execute lua script for users "
+                                                                       "extraction: %s", lua_tostring (new->L, -1));
+                                               }
+                                               else {
+                                                       if (lua_type (new->L, -1) == LUA_TFUNCTION) {
+                                                               bk->enable_users = TRUE;
+                                                               bk->cbref_user = luaL_ref (new->L,
+                                                                               LUA_REGISTRYINDEX);
+                                                       }
+                                                       else {
+                                                               msg_err_config ("lua script must return "
+                                                                               "function(task) and not %s",
+                                                                               lua_typename (new->L, lua_type (new->L, -1)));
+                                                       }
+                                               }
+                                       }
                                }
                                else {
                                        bk->enable_users = FALSE;
@@ -505,7 +593,33 @@ rspamd_sqlite3_init (struct rspamd_stat_ctx *ctx,
                                lang_enabled = ucl_object_find_any_key (clf->opts,
                                                "per_language", "languages_enabled", NULL);
                                if (lang_enabled != NULL) {
-                                       bk->enable_languages = ucl_object_toboolean (lang_enabled);
+                                       if (ucl_object_type (lang_enabled) == UCL_BOOLEAN) {
+                                               bk->enable_languages = ucl_object_toboolean (lang_enabled);
+                                               bk->cbref_language = -1;
+                                       }
+                                       else if (ucl_object_type (lang_enabled) == UCL_STRING) {
+                                               lua_script = ucl_object_tostring (lang_enabled);
+
+                                               if (luaL_dostring (new->L, lua_script) != 0) {
+                                                       msg_err_config (
+                                                                       "cannot execute lua script for languages "
+                                                                                       "extraction: %s",
+                                                                       lua_tostring (new->L, -1));
+                                               }
+                                               else {
+                                                       if (lua_type (new->L, -1) == LUA_TFUNCTION) {
+                                                               bk->enable_languages = TRUE;
+                                                               bk->cbref_language = luaL_ref (new->L,
+                                                                               LUA_REGISTRYINDEX);
+                                                       }
+                                                       else {
+                                                               msg_err_config ("lua script must return "
+                                                                               "function(task) and not %s",
+                                                                               lua_typename (new->L,
+                                                                                               lua_type (new->L, -1)));
+                                                       }
+                                               }
+                                       }
                                }
                                else {
                                        bk->enable_languages = FALSE;