]> source.dussan.org Git - rspamd.git/commitdiff
* Add ability to add maps from lua scripts and access theirs elements
authorVsevolod Stakhov <vsevolod@rambler-co.ru>
Mon, 2 Nov 2009 16:37:06 +0000 (19:37 +0300)
committerVsevolod Stakhov <vsevolod@rambler-co.ru>
Mon, 2 Nov 2009 16:37:06 +0000 (19:37 +0300)
* Add whitelist module for whitelisting score for some ip/from addresses

src/lua/lua_common.c
src/lua/lua_common.h
src/lua/lua_config.c
src/lua/lua_task.c
src/main.c
src/plugins/lua/whitelist.lua [new file with mode: 0644]

index 05829db975a75f7f4ee14480b6d0ddc0f4713f22..7e7caefda59ddfdeeb562e91a89771c8f8563cca 100644 (file)
@@ -191,6 +191,8 @@ init_lua ()
                (void)luaopen_logger (L);
                (void)luaopen_config (L);
                (void)luaopen_metric (L);
+               (void)luaopen_radix (L);
+               (void)luaopen_hash_table (L);
                (void)luaopen_task (L);
                (void)luaopen_textpart (L);
                (void)luaopen_message (L);
index ab06166b3b7f137dbd766ce3e36cb5060fa31e47..9cd966e230ffd4cbd8038097d4fef2b550f078c5 100644 (file)
@@ -24,6 +24,8 @@ int luaopen_message (lua_State *L);
 int luaopen_task (lua_State *L);
 int luaopen_config (lua_State *L);
 int luaopen_metric (lua_State *L);
+int luaopen_radix (lua_State *L);
+int luaopen_hash_table (lua_State *L);
 int luaopen_textpart (lua_State *L);
 void init_lua_filters (struct config_file *cfg);
 
index c25103f77da028727b6a2825c907e7d54db9aa12..defc9508f9d4b4a9ab3f98eec087a71d14d8afb4 100644 (file)
 
 #include "lua_common.h"
 #include "../expressions.h"
+#include "../map.h"
+#include "../radix.h"
 
 /* Config file methods */
 LUA_FUNCTION_DEF (config, get_module_opt);
 LUA_FUNCTION_DEF (config, get_metric);
 LUA_FUNCTION_DEF (config, get_all_opt);
 LUA_FUNCTION_DEF (config, register_function);
+LUA_FUNCTION_DEF (config, add_radix_map);
+LUA_FUNCTION_DEF (config, add_hash_map);
 
 static const struct luaL_reg    configlib_m[] = {
        LUA_INTERFACE_DEF (config, get_module_opt),
        LUA_INTERFACE_DEF (config, get_metric),
        LUA_INTERFACE_DEF (config, get_all_opt),
        LUA_INTERFACE_DEF (config, register_function),
+       LUA_INTERFACE_DEF (config, add_radix_map),
+       LUA_INTERFACE_DEF (config, add_hash_map),
        {"__tostring", lua_class_tostring},
        {NULL, NULL}
 };
@@ -50,6 +56,24 @@ static const struct luaL_reg    metriclib_m[] = {
        {NULL, NULL}
 };
 
+/* Radix tree */
+LUA_FUNCTION_DEF (radix, get_key);
+
+static const struct luaL_reg    radixlib_m[] = {
+       LUA_INTERFACE_DEF (radix, get_key),
+       {"__tostring", lua_class_tostring},
+       {NULL, NULL}
+};
+
+/* Hash table */
+LUA_FUNCTION_DEF (hash_table, get_key);
+
+static const struct luaL_reg    hashlib_m[] = {
+       LUA_INTERFACE_DEF (hash_table, get_key),
+       {"__tostring", lua_class_tostring},
+       {NULL, NULL}
+};
+
 static struct config_file      *
 lua_check_config (lua_State * L)
 {
@@ -66,6 +90,22 @@ lua_check_metric (lua_State * L)
        return *((struct metric **)ud);
 }
 
+static radix_tree_t           *
+lua_check_radix (lua_State * L)
+{
+       void                           *ud = luaL_checkudata (L, 1, "rspamd{radix}");
+       luaL_argcheck (L, ud != NULL, 1, "'radix' expected");
+       return **((radix_tree_t ***)ud);
+}
+
+static GHashTable           *
+lua_check_hash_table (lua_State * L)
+{
+       void                           *ud = luaL_checkudata (L, 1, "rspamd{hash_table}");
+       luaL_argcheck (L, ud != NULL, 1, "'hash_table' expected");
+       return **((GHashTable ***)ud);
+}
+
 /*** Config functions ***/
 static int
 lua_config_get_module_opt (lua_State * L)
@@ -205,6 +245,66 @@ lua_config_register_function (lua_State *L)
        return 0;
 }
 
+static int
+lua_config_add_radix_map (lua_State *L)
+{
+       struct config_file             *cfg = lua_check_config (L);
+       const char                     *map_line;
+       radix_tree_t                   **r, ***ud;
+
+       if (cfg) {
+               map_line = luaL_checkstring (L, 2);
+               r = g_malloc (sizeof (radix_tree_t *));
+               *r = radix_tree_create ();
+               if (!add_map (map_line, read_radix_list, fin_radix_list, (void **)r)) {
+                       msg_warn ("add_radix_map: invalid radix map %s", map_line);
+                       radix_tree_free (*r);
+                       g_free (r);
+                       lua_pushnil (L);
+                       return 1;
+               }
+               ud = lua_newuserdata (L, sizeof (radix_tree_t *));
+               *ud = r;
+               lua_setclass (L, "rspamd{radix}", -1);
+
+               return 1;
+       }
+
+       lua_pushnil (L);
+       return 1;
+
+}
+
+static int
+lua_config_add_hash_map (lua_State *L)
+{
+       struct config_file             *cfg = lua_check_config (L);
+       const char                     *map_line;
+       GHashTable                    **r, ***ud;
+
+       if (cfg) {
+               map_line = luaL_checkstring (L, 2);
+               r = g_malloc (sizeof (GHashTable *));
+               *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal);
+               if (!add_map (map_line, read_host_list, fin_host_list, (void **)r)) {
+                       msg_warn ("add_radix_map: invalid hash map %s", map_line);
+                       g_hash_table_destroy (*r);
+                       g_free (r);
+                       lua_pushnil (L);
+                       return 1;
+               }
+               ud = lua_newuserdata (L, sizeof (GHashTable *));
+               *ud = r;
+               lua_setclass (L, "rspamd{hash_table}", -1);
+
+               return 1;
+       }
+
+       lua_pushnil (L);
+       return 1;
+
+}
+
 /*** Metric functions ***/
 
 
@@ -240,9 +340,48 @@ lua_metric_register_symbol (lua_State * L)
                        cd = g_malloc (sizeof (struct lua_callback_data));
                        cd->name = g_strdup (callback);
                        cd->L = L;
-                       register_symbol (&metric->cache, cd->name, weight, lua_metric_symbol_callback, cd);
+                       register_symbol (&metric->cache, name, weight, lua_metric_symbol_callback, cd);
+               }
+       }
+       return 1;
+}
+
+/* Radix and hash table functions */
+static int
+lua_radix_get_key (lua_State * L)
+{
+       radix_tree_t                  *radix = lua_check_radix (L);
+       uint32_t                       key;
+
+       if (radix) {
+               key = luaL_checkint (L, 2);
+
+               if (radix32tree_find (radix, key) != RADIX_NO_VALUE) {
+                       lua_pushboolean (L, 1);
+                       return 1;
+               }
+       }
+
+       lua_pushboolean (L, 0);
+       return 1;
+}
+
+static int
+lua_hash_table_get_key (lua_State * L)
+{
+       GHashTable                    *tbl = lua_check_hash_table (L);
+       const char                    *key;
+
+       if (tbl) {
+               key = luaL_checkstring (L, 2);
+
+               if (g_hash_table_lookup (tbl, key) != NULL) {
+                       lua_pushboolean (L, 1);
+                       return 1;
                }
        }
+
+       lua_pushboolean (L, 0);
        return 1;
 }
 
@@ -263,3 +402,21 @@ luaopen_metric (lua_State * L)
 
        return 1;
 }
+
+int
+luaopen_radix (lua_State * L)
+{
+       lua_newclass (L, "rspamd{radix}", radixlib_m);
+       luaL_openlib (L, "rspamd_radix", null_reg, 0);
+
+       return 1;
+}
+
+int
+luaopen_hash_table (lua_State * L)
+{
+       lua_newclass (L, "rspamd{hash_table}", hashlib_m);
+       luaL_openlib (L, "rspamd_hash_table", null_reg, 0);
+
+       return 1;
+}
index 7912b7de047e88c4519beda9e18a6d2c00f6c7d4..cfacf3f2f8f2074d1181b6578f3daafdd20188df 100644 (file)
@@ -41,6 +41,8 @@ LUA_FUNCTION_DEF (task, call_rspamd_function);
 LUA_FUNCTION_DEF (task, get_recipients);
 LUA_FUNCTION_DEF (task, get_from);
 LUA_FUNCTION_DEF (task, get_from_ip);
+LUA_FUNCTION_DEF (task, get_from_ip_num);
+LUA_FUNCTION_DEF (task, get_client_ip_num);
 LUA_FUNCTION_DEF (task, get_helo);
 
 static const struct luaL_reg    tasklib_m[] = {
@@ -56,6 +58,8 @@ static const struct luaL_reg    tasklib_m[] = {
        LUA_INTERFACE_DEF (task, get_recipients),
        LUA_INTERFACE_DEF (task, get_from),
        LUA_INTERFACE_DEF (task, get_from_ip),
+       LUA_INTERFACE_DEF (task, get_from_ip_num),
+       LUA_INTERFACE_DEF (task, get_client_ip_num),
        LUA_INTERFACE_DEF (task, get_helo),
        {"__tostring", lua_class_tostring},
        {NULL, NULL}
@@ -436,6 +440,38 @@ lua_task_get_from_ip (lua_State *L)
        return 1;
 }
 
+static int
+lua_task_get_from_ip_num (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+       
+       if (task) {
+               if (task->from_addr.s_addr != 0) {
+                       lua_pushinteger (L, ntohl (task->from_addr.s_addr));
+                       return 1;
+               }
+       }
+
+       lua_pushnil (L);
+       return 1;
+}
+
+static int
+lua_task_get_client_ip_num (lua_State *L)
+{
+       struct worker_task             *task = lua_check_task (L);
+       
+       if (task) {
+               if (task->client_addr.s_addr != 0) {
+                       lua_pushinteger (L, ntohl (task->client_addr.s_addr));
+                       return 1;
+               }
+       }
+
+       lua_pushnil (L);
+       return 1;
+}
+
 static int
 lua_task_get_helo (lua_State *L)
 {
index 4b811bb36c6095c55d0863a56b09051c890766ae..4a02fa1596de19f1f5e2d743cafd33b6e5c41a5d 100644 (file)
@@ -674,24 +674,6 @@ main (int argc, char **argv, char **env)
                exit (-errno);
        }
 
-#ifndef WITHOUT_PERL
-       /* Init perl interpreter */
-       dTHXa (perl_interpreter);
-       PERL_SYS_INIT3 (&argc, &argv, &env);
-       perl_interpreter = perl_alloc ();
-       if (perl_interpreter == NULL) {
-               msg_err ("main: cannot allocate perl interpreter, %s", strerror (errno));
-               exit (-errno);
-       }
-
-       PERL_SET_CONTEXT (perl_interpreter);
-       perl_construct (perl_interpreter);
-       perl_parse (perl_interpreter, xs_init, 3, args, NULL);
-       init_perl_filters (cfg);
-#elif defined(WITH_LUA)
-       init_lua_filters (cfg);
-#endif
-
        /* Block signals to use sigsuspend in future */
        sigprocmask (SIG_BLOCK, &signals.sa_mask, NULL);
 
@@ -726,6 +708,24 @@ main (int argc, char **argv, char **env)
                l = g_list_next (l);
        }
 
+#ifndef WITHOUT_PERL
+       /* Init perl interpreter */
+       dTHXa (perl_interpreter);
+       PERL_SYS_INIT3 (&argc, &argv, &env);
+       perl_interpreter = perl_alloc ();
+       if (perl_interpreter == NULL) {
+               msg_err ("main: cannot allocate perl interpreter, %s", strerror (errno));
+               exit (-errno);
+       }
+
+       PERL_SET_CONTEXT (perl_interpreter);
+       perl_construct (perl_interpreter);
+       perl_parse (perl_interpreter, xs_init, 3, args, NULL);
+       init_perl_filters (cfg);
+#elif defined(WITH_LUA)
+       init_lua_filters (cfg);
+#endif
+
        rspamd->workers = g_hash_table_new (g_direct_hash, g_direct_equal);
        spawn_workers (rspamd, TRUE);
 
diff --git a/src/plugins/lua/whitelist.lua b/src/plugins/lua/whitelist.lua
new file mode 100644 (file)
index 0000000..e95aca6
--- /dev/null
@@ -0,0 +1,73 @@
+-- Module that add symbols to those hosts or from domains that are contained in whitelist
+
+local metric = 'default'
+local symbol_ip = nil
+local symbol_from = nil
+
+local r = nil
+local h = nil  -- radix tree and hash table
+
+function check_whitelist (task)
+       if symbol_ip then
+               -- check client's ip
+               local ipn = task:get_from_ip_num()
+               if ipn then
+                       local key = r:get_key(ipn)
+                       if key then
+                               task:insert_result(metric, symbol_ip, 1)
+                       end
+               end
+       end
+
+       if symbol_from then
+               -- check client's from domain
+               local from = task:get_from()
+               if from then
+                       local _,_,domain = string.find(from, '@(.+)>?$')
+                       local key = h:get_key(domain)
+                       if key then
+                               task:insert_result(metric, symbol_from, 1)
+                       end
+               end
+       end
+
+end
+
+
+-- Configuration
+local opts =  rspamd_config:get_all_opt('whitelist')
+if opts then
+    if opts['symbol_ip'] or opts['symbol_from'] then
+        symbol_ip = opts['symbol_ip']
+        symbol_from = opts['symbol_from']
+               
+               if symbol_ip then
+                       if opts['ip_whitelist'] then
+                               r = rspamd_config:add_radix_map (opts['ip_whitelist'])
+                       else
+                               -- No whitelist defined
+                               symbol_ip = nil
+                       end
+               end
+               if symbol_from then
+                       if opts['from_whitelist'] then
+                               h = rspamd_config:add_host_map (opts['from_whitelist'])
+                       else
+                               -- No whitelist defined
+                               symbol_from = nil
+                       end
+               end
+
+               if opts['metric'] then
+                       metric = opts['metric']
+               end
+
+               -- Register symbol's callback
+               local m = rspamd_config:get_metric(metric)
+               if symbol_ip then
+                       m:register_symbol(symbol_ip, 1.0, 'check_whitelist')
+               elseif symbol_from then
+                       m:register_symbol(symbol_from, 1.0, 'check_whitelist')
+               end
+       end
+end