#include "lua_common.h"
#include "../expressions.h"
+#include "../map.h"
+#include "../radix.h"
/* Config file methods */
LUA_FUNCTION_DEF (config, get_module_opt);
LUA_FUNCTION_DEF (config, get_metric);
LUA_FUNCTION_DEF (config, get_all_opt);
LUA_FUNCTION_DEF (config, register_function);
+LUA_FUNCTION_DEF (config, add_radix_map);
+LUA_FUNCTION_DEF (config, add_hash_map);
static const struct luaL_reg configlib_m[] = {
LUA_INTERFACE_DEF (config, get_module_opt),
LUA_INTERFACE_DEF (config, get_metric),
LUA_INTERFACE_DEF (config, get_all_opt),
LUA_INTERFACE_DEF (config, register_function),
+ LUA_INTERFACE_DEF (config, add_radix_map),
+ LUA_INTERFACE_DEF (config, add_hash_map),
{"__tostring", lua_class_tostring},
{NULL, NULL}
};
{NULL, NULL}
};
+/* Radix tree */
+LUA_FUNCTION_DEF (radix, get_key);
+
+static const struct luaL_reg radixlib_m[] = {
+ LUA_INTERFACE_DEF (radix, get_key),
+ {"__tostring", lua_class_tostring},
+ {NULL, NULL}
+};
+
+/* Hash table */
+LUA_FUNCTION_DEF (hash_table, get_key);
+
+static const struct luaL_reg hashlib_m[] = {
+ LUA_INTERFACE_DEF (hash_table, get_key),
+ {"__tostring", lua_class_tostring},
+ {NULL, NULL}
+};
+
static struct config_file *
lua_check_config (lua_State * L)
{
return *((struct metric **)ud);
}
+static radix_tree_t *
+lua_check_radix (lua_State * L)
+{
+ void *ud = luaL_checkudata (L, 1, "rspamd{radix}");
+ luaL_argcheck (L, ud != NULL, 1, "'radix' expected");
+ return **((radix_tree_t ***)ud);
+}
+
+static GHashTable *
+lua_check_hash_table (lua_State * L)
+{
+ void *ud = luaL_checkudata (L, 1, "rspamd{hash_table}");
+ luaL_argcheck (L, ud != NULL, 1, "'hash_table' expected");
+ return **((GHashTable ***)ud);
+}
+
/*** Config functions ***/
static int
lua_config_get_module_opt (lua_State * L)
return 0;
}
+static int
+lua_config_add_radix_map (lua_State *L)
+{
+ struct config_file *cfg = lua_check_config (L);
+ const char *map_line;
+ radix_tree_t **r, ***ud;
+
+ if (cfg) {
+ map_line = luaL_checkstring (L, 2);
+ r = g_malloc (sizeof (radix_tree_t *));
+ *r = radix_tree_create ();
+ if (!add_map (map_line, read_radix_list, fin_radix_list, (void **)r)) {
+ msg_warn ("add_radix_map: invalid radix map %s", map_line);
+ radix_tree_free (*r);
+ g_free (r);
+ lua_pushnil (L);
+ return 1;
+ }
+ ud = lua_newuserdata (L, sizeof (radix_tree_t *));
+ *ud = r;
+ lua_setclass (L, "rspamd{radix}", -1);
+
+ return 1;
+ }
+
+ lua_pushnil (L);
+ return 1;
+
+}
+
+static int
+lua_config_add_hash_map (lua_State *L)
+{
+ struct config_file *cfg = lua_check_config (L);
+ const char *map_line;
+ GHashTable **r, ***ud;
+
+ if (cfg) {
+ map_line = luaL_checkstring (L, 2);
+ r = g_malloc (sizeof (GHashTable *));
+ *r = g_hash_table_new (rspamd_strcase_hash, rspamd_strcase_equal);
+ if (!add_map (map_line, read_host_list, fin_host_list, (void **)r)) {
+ msg_warn ("add_radix_map: invalid hash map %s", map_line);
+ g_hash_table_destroy (*r);
+ g_free (r);
+ lua_pushnil (L);
+ return 1;
+ }
+ ud = lua_newuserdata (L, sizeof (GHashTable *));
+ *ud = r;
+ lua_setclass (L, "rspamd{hash_table}", -1);
+
+ return 1;
+ }
+
+ lua_pushnil (L);
+ return 1;
+
+}
+
/*** Metric functions ***/
cd = g_malloc (sizeof (struct lua_callback_data));
cd->name = g_strdup (callback);
cd->L = L;
- register_symbol (&metric->cache, cd->name, weight, lua_metric_symbol_callback, cd);
+ register_symbol (&metric->cache, name, weight, lua_metric_symbol_callback, cd);
+ }
+ }
+ return 1;
+}
+
+/* Radix and hash table functions */
+static int
+lua_radix_get_key (lua_State * L)
+{
+ radix_tree_t *radix = lua_check_radix (L);
+ uint32_t key;
+
+ if (radix) {
+ key = luaL_checkint (L, 2);
+
+ if (radix32tree_find (radix, key) != RADIX_NO_VALUE) {
+ lua_pushboolean (L, 1);
+ return 1;
+ }
+ }
+
+ lua_pushboolean (L, 0);
+ return 1;
+}
+
+static int
+lua_hash_table_get_key (lua_State * L)
+{
+ GHashTable *tbl = lua_check_hash_table (L);
+ const char *key;
+
+ if (tbl) {
+ key = luaL_checkstring (L, 2);
+
+ if (g_hash_table_lookup (tbl, key) != NULL) {
+ lua_pushboolean (L, 1);
+ return 1;
}
}
+
+ lua_pushboolean (L, 0);
return 1;
}
return 1;
}
+
+int
+luaopen_radix (lua_State * L)
+{
+ lua_newclass (L, "rspamd{radix}", radixlib_m);
+ luaL_openlib (L, "rspamd_radix", null_reg, 0);
+
+ return 1;
+}
+
+int
+luaopen_hash_table (lua_State * L)
+{
+ lua_newclass (L, "rspamd{hash_table}", hashlib_m);
+ luaL_openlib (L, "rspamd_hash_table", null_reg, 0);
+
+ return 1;
+}
LUA_FUNCTION_DEF (task, get_recipients);
LUA_FUNCTION_DEF (task, get_from);
LUA_FUNCTION_DEF (task, get_from_ip);
+LUA_FUNCTION_DEF (task, get_from_ip_num);
+LUA_FUNCTION_DEF (task, get_client_ip_num);
LUA_FUNCTION_DEF (task, get_helo);
static const struct luaL_reg tasklib_m[] = {
LUA_INTERFACE_DEF (task, get_recipients),
LUA_INTERFACE_DEF (task, get_from),
LUA_INTERFACE_DEF (task, get_from_ip),
+ LUA_INTERFACE_DEF (task, get_from_ip_num),
+ LUA_INTERFACE_DEF (task, get_client_ip_num),
LUA_INTERFACE_DEF (task, get_helo),
{"__tostring", lua_class_tostring},
{NULL, NULL}
return 1;
}
+static int
+lua_task_get_from_ip_num (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+
+ if (task) {
+ if (task->from_addr.s_addr != 0) {
+ lua_pushinteger (L, ntohl (task->from_addr.s_addr));
+ return 1;
+ }
+ }
+
+ lua_pushnil (L);
+ return 1;
+}
+
+static int
+lua_task_get_client_ip_num (lua_State *L)
+{
+ struct worker_task *task = lua_check_task (L);
+
+ if (task) {
+ if (task->client_addr.s_addr != 0) {
+ lua_pushinteger (L, ntohl (task->client_addr.s_addr));
+ return 1;
+ }
+ }
+
+ lua_pushnil (L);
+ return 1;
+}
+
static int
lua_task_get_helo (lua_State *L)
{
exit (-errno);
}
-#ifndef WITHOUT_PERL
- /* Init perl interpreter */
- dTHXa (perl_interpreter);
- PERL_SYS_INIT3 (&argc, &argv, &env);
- perl_interpreter = perl_alloc ();
- if (perl_interpreter == NULL) {
- msg_err ("main: cannot allocate perl interpreter, %s", strerror (errno));
- exit (-errno);
- }
-
- PERL_SET_CONTEXT (perl_interpreter);
- perl_construct (perl_interpreter);
- perl_parse (perl_interpreter, xs_init, 3, args, NULL);
- init_perl_filters (cfg);
-#elif defined(WITH_LUA)
- init_lua_filters (cfg);
-#endif
-
/* Block signals to use sigsuspend in future */
sigprocmask (SIG_BLOCK, &signals.sa_mask, NULL);
l = g_list_next (l);
}
+#ifndef WITHOUT_PERL
+ /* Init perl interpreter */
+ dTHXa (perl_interpreter);
+ PERL_SYS_INIT3 (&argc, &argv, &env);
+ perl_interpreter = perl_alloc ();
+ if (perl_interpreter == NULL) {
+ msg_err ("main: cannot allocate perl interpreter, %s", strerror (errno));
+ exit (-errno);
+ }
+
+ PERL_SET_CONTEXT (perl_interpreter);
+ perl_construct (perl_interpreter);
+ perl_parse (perl_interpreter, xs_init, 3, args, NULL);
+ init_perl_filters (cfg);
+#elif defined(WITH_LUA)
+ init_lua_filters (cfg);
+#endif
+
rspamd->workers = g_hash_table_new (g_direct_hash, g_direct_equal);
spawn_workers (rspamd, TRUE);
--- /dev/null
+-- Module that add symbols to those hosts or from domains that are contained in whitelist
+
+local metric = 'default'
+local symbol_ip = nil
+local symbol_from = nil
+
+local r = nil
+local h = nil -- radix tree and hash table
+
+function check_whitelist (task)
+ if symbol_ip then
+ -- check client's ip
+ local ipn = task:get_from_ip_num()
+ if ipn then
+ local key = r:get_key(ipn)
+ if key then
+ task:insert_result(metric, symbol_ip, 1)
+ end
+ end
+ end
+
+ if symbol_from then
+ -- check client's from domain
+ local from = task:get_from()
+ if from then
+ local _,_,domain = string.find(from, '@(.+)>?$')
+ local key = h:get_key(domain)
+ if key then
+ task:insert_result(metric, symbol_from, 1)
+ end
+ end
+ end
+
+end
+
+
+-- Configuration
+local opts = rspamd_config:get_all_opt('whitelist')
+if opts then
+ if opts['symbol_ip'] or opts['symbol_from'] then
+ symbol_ip = opts['symbol_ip']
+ symbol_from = opts['symbol_from']
+
+ if symbol_ip then
+ if opts['ip_whitelist'] then
+ r = rspamd_config:add_radix_map (opts['ip_whitelist'])
+ else
+ -- No whitelist defined
+ symbol_ip = nil
+ end
+ end
+ if symbol_from then
+ if opts['from_whitelist'] then
+ h = rspamd_config:add_host_map (opts['from_whitelist'])
+ else
+ -- No whitelist defined
+ symbol_from = nil
+ end
+ end
+
+ if opts['metric'] then
+ metric = opts['metric']
+ end
+
+ -- Register symbol's callback
+ local m = rspamd_config:get_metric(metric)
+ if symbol_ip then
+ m:register_symbol(symbol_ip, 1.0, 'check_whitelist')
+ elseif symbol_from then
+ m:register_symbol(symbol_from, 1.0, 'check_whitelist')
+ end
+ end
+end