aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/libcryptobox/cryptobox.h4
-rw-r--r--src/libcryptobox/keypair.c119
-rw-r--r--src/libcryptobox/keypair.h34
-rw-r--r--src/libserver/cfg_rcl.c80
-rw-r--r--src/libserver/symbols_cache.c12
-rw-r--r--src/libstat/stat_process.c19
-rw-r--r--src/libutil/map.c6
-rw-r--r--src/libutil/mem_pool.h2
-rw-r--r--src/libutil/util.c12
-rw-r--r--src/plugins/lua/asn.lua9
-rw-r--r--src/plugins/lua/emails.lua2
-rw-r--r--src/plugins/lua/forged_recipients.lua9
-rw-r--r--src/plugins/lua/neural.lua80
-rw-r--r--src/plugins/surbl.c3
14 files changed, 331 insertions, 60 deletions
diff --git a/src/libcryptobox/cryptobox.h b/src/libcryptobox/cryptobox.h
index 9fc5c7fe9..1045547a2 100644
--- a/src/libcryptobox/cryptobox.h
+++ b/src/libcryptobox/cryptobox.h
@@ -312,7 +312,7 @@ guint rspamd_cryptobox_mac_bytes (enum rspamd_cryptobox_mode mode);
guint rspamd_cryptobox_signature_bytes (enum rspamd_cryptobox_mode mode);
/* Hash IUF interface */
-typedef struct RSPAMD_ALIGNED(32) rspamd_cryptobox_hash_state_s {
+typedef struct RSPAMD_ALIGNED(16) rspamd_cryptobox_hash_state_s {
unsigned char opaque[256];
} rspamd_cryptobox_hash_state_t;
@@ -343,7 +343,7 @@ void rspamd_cryptobox_hash (guchar *out,
gsize keylen);
/* Non crypto hash IUF interface */
-typedef struct RSPAMD_ALIGNED(32) rspamd_cryptobox_fast_hash_state_s {
+typedef struct RSPAMD_ALIGNED(16) rspamd_cryptobox_fast_hash_state_s {
unsigned char opaque[64 + sizeof (size_t) + sizeof (uint64_t)];
} rspamd_cryptobox_fast_hash_state_t;
diff --git a/src/libcryptobox/keypair.c b/src/libcryptobox/keypair.c
index 98401e10f..50e3614d9 100644
--- a/src/libcryptobox/keypair.c
+++ b/src/libcryptobox/keypair.c
@@ -19,6 +19,9 @@
#include "libcryptobox/keypair_private.h"
#include "libutil/str_util.h"
#include "libutil/printf.h"
+#include "contrib/libottery/ottery.h"
+
+const guchar encrypted_magic[7] = {'r', 'u', 'c', 'l', 'e', 'v', '1'};
static GQuark
rspamd_keypair_quark (void)
@@ -908,3 +911,119 @@ rspamd_pubkey_equal (const struct rspamd_cryptobox_pubkey *k1,
return FALSE;
}
+
+gboolean
+rspamd_keypair_decrypt (struct rspamd_cryptobox_keypair *kp,
+ const guchar *in, gsize inlen,
+ guchar **out, gsize *outlen,
+ GError **err)
+{
+ const guchar *nonce, *mac, *data, *pubkey;
+
+ g_assert (kp != NULL);
+ g_assert (in != NULL);
+
+ if (kp->type != RSPAMD_KEYPAIR_KEX) {
+ g_set_error (err, rspamd_keypair_quark (), EINVAL,
+ "invalid keypair type");
+
+ return FALSE;
+ }
+
+ if (inlen < sizeof (encrypted_magic) + rspamd_cryptobox_pk_bytes (kp->alg) +
+ rspamd_cryptobox_mac_bytes (kp->alg) +
+ rspamd_cryptobox_nonce_bytes (kp->alg)) {
+ g_set_error (err, rspamd_keypair_quark (), E2BIG, "invalid size: too small");
+
+ return FALSE;
+ }
+
+ if (memcmp (in, encrypted_magic, sizeof (encrypted_magic)) != 0) {
+ g_set_error (err, rspamd_keypair_quark (), EINVAL,
+ "invalid magic");
+
+ return FALSE;
+ }
+
+ /* Set pointers */
+ pubkey = in + sizeof (encrypted_magic);
+ mac = pubkey + rspamd_cryptobox_pk_bytes (kp->alg);
+ nonce = mac + rspamd_cryptobox_mac_bytes (kp->alg);
+ data = nonce + rspamd_cryptobox_nonce_bytes (kp->alg);
+
+ if (data - in >= inlen) {
+ g_set_error (err, rspamd_keypair_quark (), E2BIG, "invalid size: too small");
+
+ return FALSE;
+ }
+
+ inlen -= data - in;
+
+ /* Allocate memory for output */
+ *out = g_malloc (inlen);
+ memcpy (*out, data, inlen);
+
+ if (!rspamd_cryptobox_decrypt_inplace (*out, inlen, nonce, pubkey,
+ rspamd_keypair_component (kp, RSPAMD_KEYPAIR_COMPONENT_SK, NULL),
+ mac, kp->alg)) {
+ g_set_error (err, rspamd_keypair_quark (), EPERM, "verification failed");
+ g_free (*out);
+
+ return FALSE;
+ }
+
+ if (outlen) {
+ *outlen = inlen;
+ }
+
+ return TRUE;
+}
+gboolean
+rspamd_keypair_encrypt (struct rspamd_cryptobox_keypair *kp,
+ const guchar *in, gsize inlen,
+ guchar **out, gsize *outlen,
+ GError **err)
+{
+ guchar *nonce, *mac, *data, *pubkey;
+ struct rspamd_cryptobox_keypair *local;
+ gsize olen;
+
+ g_assert (kp != NULL);
+ g_assert (in != NULL);
+
+ if (kp->type != RSPAMD_KEYPAIR_KEX) {
+ g_set_error (err, rspamd_keypair_quark (), EINVAL,
+ "invalid keypair type");
+
+ return FALSE;
+ }
+
+ local = rspamd_keypair_new (kp->type, kp->alg);
+
+ olen = inlen + sizeof (encrypted_magic) +
+ rspamd_cryptobox_pk_bytes (kp->alg) +
+ rspamd_cryptobox_mac_bytes (kp->alg) +
+ rspamd_cryptobox_nonce_bytes (kp->alg);
+ *out = g_malloc (olen);
+ memcpy (*out, encrypted_magic, sizeof (encrypted_magic));
+ pubkey = *out + sizeof (encrypted_magic);
+ mac = pubkey + rspamd_cryptobox_pk_bytes (kp->alg);
+ nonce = mac + rspamd_cryptobox_mac_bytes (kp->alg);
+ data = nonce + rspamd_cryptobox_nonce_bytes (kp->alg);
+
+ ottery_rand_bytes (nonce, rspamd_cryptobox_nonce_bytes (kp->alg));
+ memcpy (data, in, inlen);
+ memcpy (pubkey, rspamd_keypair_component (kp,
+ RSPAMD_KEYPAIR_COMPONENT_PK, NULL),
+ rspamd_cryptobox_pk_bytes (kp->alg));
+ rspamd_cryptobox_encrypt_inplace (data, inlen, nonce, pubkey,
+ rspamd_keypair_component (local, RSPAMD_KEYPAIR_COMPONENT_SK, NULL),
+ mac, kp->alg);
+ rspamd_keypair_unref (local);
+
+ if (outlen) {
+ *outlen = olen;
+ }
+
+ return TRUE;
+} \ No newline at end of file
diff --git a/src/libcryptobox/keypair.h b/src/libcryptobox/keypair.h
index b24ecc9aa..3e78e7cbb 100644
--- a/src/libcryptobox/keypair.h
+++ b/src/libcryptobox/keypair.h
@@ -28,6 +28,8 @@ enum rspamd_cryptobox_keypair_type {
RSPAMD_KEYPAIR_SIGN
};
+extern const guchar encrypted_magic[7];
+
/**
* Opaque structure for the full (public + private) keypair
*/
@@ -270,5 +272,37 @@ gboolean rspamd_keypair_verify (struct rspamd_cryptobox_pubkey *pk,
gboolean rspamd_pubkey_equal (const struct rspamd_cryptobox_pubkey *k1,
const struct rspamd_cryptobox_pubkey *k2);
+/**
+ * Decrypts data using keypair and a pubkey stored in in, in must start from
+ * `encrypted_magic` constant
+ * @param kp keypair
+ * @param in raw input
+ * @param inlen input length
+ * @param out output (allocated internally using g_malloc)
+ * @param outlen output size
+ * @return TRUE if decryption is completed, out must be freed in this case
+ */
+gboolean rspamd_keypair_decrypt (struct rspamd_cryptobox_keypair *kp,
+ const guchar *in, gsize inlen,
+ guchar **out, gsize *outlen,
+ GError **err);
+
+/**
+ * Encrypts data usign specific keypair.
+ * This method actually generates ephemeral local keypair, use public key from
+ * the remote keypair and encrypts data
+ * @param kp keypair
+ * @param in raw input
+ * @param inlen input length
+ * @param out output (allocated internally using g_malloc)
+ * @param outlen output size
+ * @param err pointer to error
+ * @return TRUE if encryption has been completed, out must be freed in this case
+ */
+gboolean rspamd_keypair_encrypt (struct rspamd_cryptobox_keypair *kp,
+ const guchar *in, gsize inlen,
+ guchar **out, gsize *outlen,
+ GError **err);
+
#endif /* SRC_LIBCRYPTOBOX_KEYPAIR_H_ */
diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c
index a35bbf56c..deb4a8ed3 100644
--- a/src/libserver/cfg_rcl.c
+++ b/src/libserver/cfg_rcl.c
@@ -3624,6 +3624,32 @@ rspamd_rcl_maybe_apply_lua_transform (struct rspamd_config *cfg)
lua_settop (L, 0);
}
+static bool
+rspamd_rcl_decrypt_handler(struct ucl_parser *parser,
+ const unsigned char *source, size_t source_len,
+ unsigned char **destination, size_t *dest_len,
+ void *user_data)
+{
+ GError *err = NULL;
+ struct rspamd_cryptobox_keypair *kp = (struct rspamd_cryptobox_keypair *)user_data;
+
+ if (!rspamd_keypair_decrypt (kp, source, source_len,
+ destination, dest_len, &err)) {
+ msg_err ("cannot decrypt file: %e", err);
+ g_error_free (err);
+
+ return false;
+ }
+
+ return true;
+}
+
+static void
+rspamd_rcl_decrypt_free (unsigned char *data, size_t len, void *user_data)
+{
+ g_free (data);
+}
+
gboolean
rspamd_config_read (struct rspamd_config *cfg, const gchar *filename,
const gchar *convert_to, rspamd_rcl_section_fin_t logger_fin,
@@ -3638,6 +3664,8 @@ rspamd_config_read (struct rspamd_config *cfg, const gchar *filename,
const ucl_object_t *logger_obj;
rspamd_cryptobox_hash_state_t hs;
unsigned char cksumbuf[rspamd_cryptobox_HASHBYTES];
+ gchar keypair_path[PATH_MAX];
+ struct rspamd_cryptobox_keypair *decrypt_keypair = NULL;
struct ucl_emitter_functions f;
if (stat (filename, &st) == -1) {
@@ -3659,11 +3687,63 @@ rspamd_config_read (struct rspamd_config *cfg, const gchar *filename,
close (fd);
+ /* Try to load keyfile if available */
+ rspamd_snprintf (keypair_path, sizeof (keypair_path), "%s.key",
+ filename);
+ if (stat (keypair_path, &st) == -1 &&
+ (fd = open (keypair_path, O_RDONLY)) != -1) {
+ struct ucl_parser *kp_parser;
+
+ kp_parser = ucl_parser_new (0);
+
+ if (ucl_parser_add_fd (kp_parser, fd)) {
+ ucl_object_t *kp_obj;
+
+ kp_obj = ucl_parser_get_object (kp_parser);
+
+ g_assert (kp_obj != NULL);
+ decrypt_keypair = rspamd_keypair_from_ucl (kp_obj);
+
+ if (decrypt_keypair == NULL) {
+ msg_err_config_forced ("cannot load keypair from %s: invalid keypair",
+ keypair_path);
+ }
+ else {
+ /* Add decryption support to UCL */
+ rspamd_mempool_add_destructor (cfg->cfg_pool,
+ (rspamd_mempool_destruct_t)rspamd_keypair_unref,
+ decrypt_keypair);
+ }
+
+ ucl_object_unref (kp_obj);
+ }
+ else {
+ msg_err_config_forced ("cannot load keypair from %s: %s",
+ keypair_path, ucl_parser_get_error (kp_parser));
+ }
+
+ ucl_parser_free (kp_parser);
+ }
+
rspamd_cryptobox_hash_init (&hs, NULL, 0);
parser = ucl_parser_new (UCL_PARSER_SAVE_COMMENTS);
rspamd_ucl_add_conf_variables (parser, vars);
rspamd_ucl_add_conf_macros (parser, cfg);
+ if (decrypt_keypair) {
+ struct ucl_parser_special_handler *decrypt_handler;
+
+ decrypt_handler = rspamd_mempool_alloc0 (cfg->cfg_pool,
+ sizeof (*decrypt_handler));
+ decrypt_handler->user_data = decrypt_keypair;
+ decrypt_handler->magic = encrypted_magic;
+ decrypt_handler->magic_len = sizeof (encrypted_magic);
+ decrypt_handler->handler = rspamd_rcl_decrypt_handler;
+ decrypt_handler->free_function = rspamd_rcl_decrypt_free;
+
+ ucl_parser_add_special_handler (parser, decrypt_handler);
+ }
+
if (!ucl_parser_add_chunk (parser, data, st.st_size)) {
msg_err_config_forced ("ucl parser error: %s", ucl_parser_get_error (parser));
ucl_parser_free (parser);
diff --git a/src/libserver/symbols_cache.c b/src/libserver/symbols_cache.c
index eac6c8d0c..abed26fc3 100644
--- a/src/libserver/symbols_cache.c
+++ b/src/libserver/symbols_cache.c
@@ -415,7 +415,7 @@ rspamd_symbols_cache_post_init (struct symbols_cache *cache)
struct delayed_cache_dependency *ddep;
struct delayed_cache_condition *dcond;
GList *cur;
- guint i, j;
+ gint i, j;
gint id;
rspamd_symbols_cache_resort (cache);
@@ -506,6 +506,16 @@ rspamd_symbols_cache_post_init (struct symbols_cache *cache)
msg_err_cache ("cannot find dependency on symbol %s", dep->sym);
}
}
+
+ /* Reversed loop to make removal safe */
+ for (j = it->deps->len - 1; j >= 0; j --) {
+ dep = g_ptr_array_index (it->deps, j);
+
+ if (dep->item == NULL) {
+ /* Remove useless dep */
+ g_ptr_array_remove_index (it->deps, j);
+ }
+ }
}
g_ptr_array_sort_with_data (cache->prefilters, prefilters_cmp, cache);
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c
index aed9d383d..070b9d6ca 100644
--- a/src/libstat/stat_process.c
+++ b/src/libstat/stat_process.c
@@ -81,6 +81,7 @@ rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx,
gchar tmpbuf[128];
lua_State *L = task->cfg->lua_state;
const gchar *headers_hash;
+ struct rspamd_mime_header *hdr;
ar = g_array_sized_new (FALSE, FALSE, sizeof (elt), 16);
elt.flags = RSPAMD_STAT_TOKEN_FLAG_META;
@@ -183,6 +184,24 @@ rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx,
g_array_append_val (ar, elt);
}
+ /* Use more precise headers order */
+ cur = g_list_first (task->headers_order->head);
+ while (cur) {
+ hdr = cur->data;
+
+ if (hdr->name && hdr->type != RSPAMD_HEADER_RECEIVED) {
+ /* We assume that headers count is not more than 10^10 */
+ gsize nlen = strlen (hdr->name) + 1 + 10;
+ gchar *hdrbuf = rspamd_mempool_alloc (task->task_pool, nlen);
+ nlen = rspamd_snprintf (hdrbuf, nlen, "%s:%d", hdr->name, hdr->order);
+ rspamd_str_lc (hdrbuf, nlen);
+ elt.begin = hdrbuf;
+ elt.len = nlen;
+ g_array_append_val (ar, elt);
+ }
+
+ cur = g_list_next (cur);
+ }
/* Use metatokens plugin from Lua */
lua_getglobal (L, "rspamd_plugins");
diff --git a/src/libutil/map.c b/src/libutil/map.c
index 1e8f70f7e..577933c25 100644
--- a/src/libutil/map.c
+++ b/src/libutil/map.c
@@ -346,11 +346,11 @@ rspamd_map_cache_cb (gint fd, short what, gpointer ud)
/* We have another update, so this cache element is obviously expired */
/* Important: we do not set cache availability to zero here */
MAP_RELEASE (cache_cbd->shm, "rspamd_http_map_cached_cbdata");
- msg_debug_map ("cached data is now expired (gen mismatch) for %s", map->name);
+ msg_info_map ("cached data is now expired (gen mismatch) for %s", map->name);
event_del (&cache_cbd->timeout);
g_free (cache_cbd);
}
- else if (cache_cbd->data->last_checked > cache_cbd->last_checked) {
+ else if (cache_cbd->data->last_checked >= cache_cbd->last_checked) {
/*
* We checked map but we have not found anything more recent,
* reschedule cache check
@@ -363,7 +363,7 @@ rspamd_map_cache_cb (gint fd, short what, gpointer ud)
else {
g_atomic_int_set (&map->cache->available, 0);
MAP_RELEASE (cache_cbd->shm, "rspamd_http_map_cached_cbdata");
- msg_debug_map ("cached data is now expired for %s", map->name);
+ msg_info_map ("cached data is now expired for %s", map->name);
event_del (&cache_cbd->timeout);
g_free (cache_cbd);
}
diff --git a/src/libutil/mem_pool.h b/src/libutil/mem_pool.h
index 7423895c4..27d4c8ebf 100644
--- a/src/libutil/mem_pool.h
+++ b/src/libutil/mem_pool.h
@@ -21,7 +21,7 @@ struct f_str_s;
#define MEMPOOL_TAG_LEN 20
#define MEMPOOL_UID_LEN 20
-#define MEM_ALIGNMENT 8
+#define MEM_ALIGNMENT 16
#define align_ptr(p, a) \
(guint8 *) (((uintptr_t) (p) + ((uintptr_t) a - 1)) & ~((uintptr_t) a - 1))
diff --git a/src/libutil/util.c b/src/libutil/util.c
index 29c12ca2f..d55a8e8ac 100644
--- a/src/libutil/util.c
+++ b/src/libutil/util.c
@@ -83,6 +83,7 @@
#include "cryptobox.h"
#include "zlib.h"
+#include "contrib/uthash/utlist.h"
/* Check log messages intensity once per minute */
#define CHECK_TIME 60
@@ -418,6 +419,12 @@ out:
return (-1);
}
+static int
+rspamd_prefer_v4_hack (const struct addrinfo *a1, const struct addrinfo *a2)
+{
+ return a1->ai_addr->sa_family - a2->ai_addr->sa_family;
+}
+
/**
* Make a universal socket
* @param credits host, ip or path to unix socket
@@ -480,6 +487,7 @@ rspamd_socket (const gchar *credits, guint16 port,
rspamd_snprintf (portbuf, sizeof (portbuf), "%d", (int)port);
if ((r = getaddrinfo (credits, portbuf, &hints, &res)) == 0) {
+ LL_SORT2 (res, rspamd_prefer_v4_hack, ai_next);
r = rspamd_inet_socket_create (type, res, is_server, async, NULL);
freeaddrinfo (res);
return r;
@@ -572,7 +580,9 @@ rspamd_sockets_list (const gchar *credits, guint16 port,
rspamd_snprintf (portbuf, sizeof (portbuf), "%d", (int)port);
if ((r = getaddrinfo (credits, portbuf, &hints, &res)) == 0) {
- fd = rspamd_inet_socket_create (type, res, is_server, async, &result);
+ LL_SORT2 (res, rspamd_prefer_v4_hack, ai_next);
+ fd = rspamd_inet_socket_create (type, res, is_server, async,
+ &result);
freeaddrinfo (res);
if (result == NULL) {
diff --git a/src/plugins/lua/asn.lua b/src/plugins/lua/asn.lua
index 3e84a3824..61572a600 100644
--- a/src/plugins/lua/asn.lua
+++ b/src/plugins/lua/asn.lua
@@ -60,18 +60,19 @@ local function asn_check(task)
local asn_check_func = {}
function asn_check_func.rspamd(ip)
+ local dnsbl = options['provider_info']['ip' .. ip:get_version()]
+ local req_name = rspamd_logger.slog("%1.%2",
+ table.concat(ip:inversed_str_octets(), '.'), dnsbl)
local function rspamd_dns_cb(_, _, results, dns_err)
if dns_err and (dns_err ~= 'requested record is not found' and dns_err ~= 'no records with this name') then
- rspamd_logger.errx(task, 'error querying dns: %s', dns_err)
+ rspamd_logger.errx(task, 'error querying dns (%s): %s', req_name, dns_err)
end
if not (results and results[1]) then return end
local parts = rspamd_re:split(results[1])
-- "15169 | 8.8.8.0/24 | US | arin |" for 8.8.8.8
asn_set(parts[1], parts[2], parts[3])
end
- local dnsbl = options['provider_info']['ip' .. ip:get_version()]
- local req_name = rspamd_logger.slog("%1.%2",
- table.concat(ip:inversed_str_octets(), '.'), dnsbl)
+
task:get_resolver():resolve_txt(task:get_session(), task:get_mempool(),
req_name, rspamd_dns_cb)
end
diff --git a/src/plugins/lua/emails.lua b/src/plugins/lua/emails.lua
index 51df0e959..5f076e1d5 100644
--- a/src/plugins/lua/emails.lua
+++ b/src/plugins/lua/emails.lua
@@ -52,7 +52,7 @@ local function check_email_rule(task, rule, addr)
local function emails_dns_cb(_, _, results, err)
if err and (err ~= 'requested record is not found'
and err ~= 'no records with this name') then
- logger.errx(task, 'Error querying DNS: %1', err)
+ logger.errx(task, 'Error querying DNS(%s): %s', to_resolve, err)
elseif results then
local expected_found = false
local symbol = rule['symbol']
diff --git a/src/plugins/lua/forged_recipients.lua b/src/plugins/lua/forged_recipients.lua
index d9cab67c7..782a408bd 100644
--- a/src/plugins/lua/forged_recipients.lua
+++ b/src/plugins/lua/forged_recipients.lua
@@ -39,7 +39,8 @@ local function check_forged_headers(task)
if not mime_rcpt then
return
elseif #mime_rcpt == 0 then
- task:insert_result(symbol_rcpt, score)
+ local sra = smtp_rcpt[1].addr .. (#smtp_rcpt > 1 and ' ...' or '')
+ task:insert_result(symbol_rcpt, score, '', sra)
return
end
-- Find pair for each smtp recipient recipient in To or Cc headers
@@ -67,7 +68,9 @@ local function check_forged_headers(task)
end
end
if not res then
- task:insert_result(symbol_rcpt, score)
+ local mra = mime_rcpt[1].addr .. (#mime_rcpt > 1 and ' ..' or '')
+ local sra = smtp_rcpt[1].addr .. (#smtp_rcpt > 1 and ' ...' or '')
+ task:insert_result(symbol_rcpt, score, mra, sra)
break
end
end
@@ -76,7 +79,7 @@ local function check_forged_headers(task)
local mime_from = task:get_from(2)
if not mime_from or not mime_from[1] or
not (string.lower(mime_from[1]['addr']) == string.lower(smtp_from[1]['addr'])) then
- task:insert_result(symbol_sender, 1)
+ task:insert_result(symbol_sender, 1, ((mime_from or E)[1] or E).addr or '', smtp_from[1].addr)
end
end
end
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 04b732472..9d0bbb446 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -65,10 +65,6 @@ local settings = {
rules = {}
}
--- ANNs indexed by settings id
-local anns = {
-}
-
local opts = rspamd_config:get_all_opt("neural")
if not opts then
-- Legacy
@@ -278,7 +274,7 @@ local function ann_scores_filter(task)
id = id .. r
end
- if anns[id] and anns[id].ann then
+ if rule.anns[id] and rule.anns[id].ann then
local ann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
@@ -286,15 +282,15 @@ local function ann_scores_filter(task)
local score
if use_torch then
- local out = anns[id].ann:forward(torch.Tensor(ann_data))
+ local out = rule.anns[id].ann:forward(torch.Tensor(ann_data))
score = out[1]
else
- local out = anns[id].ann:test(ann_data)
+ local out = rule.anns[id].ann:test(ann_data)
score = out[1]
end
local symscore = string.format('%.3f', score)
- rspamd_logger.infox(task, 'ann score: %s', symscore)
+ rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore)
if score > 0 then
local result = score
@@ -339,28 +335,29 @@ end
local function create_train_ann(rule, n, id)
local prefix = gen_ann_prefix(rule, id)
- if not anns[id] then
- anns[id] = {}
+ if not rule.anns[id] then
+ rule.anns[id] = {}
end
-- Fix that for flexibe layers number
- if anns[id].ann then
- if not is_ann_valid(rule, prefix, anns[id].ann) then
- anns[id].ann_train = create_ann(n, rule.nlayers)
- anns[id].ann = nil
+ if rule.anns[id].ann then
+ if not is_ann_valid(rule, prefix, rule.anns[id].ann) then
+ rule.anns[id].ann_train = create_ann(n, rule.nlayers)
+ rule.anns[id].ann = nil
rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix)
- elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then
+ elseif rule.train.max_usages > 0 and
+ rule.anns[id].version % rule.train.max_usages == 0 then
-- Forget last ann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
- anns[id].version)
- anns[id].ann_train = create_ann(n, rule.nlayers)
+ rule.anns[id].version)
+ rule.anns[id].ann_train = create_ann(n, rule.nlayers)
else
- anns[id].ann_train = anns[id].ann
+ rule.anns[id].ann_train = rule.anns[id].ann
rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix)
end
else
- anns[id].ann_train = create_ann(n, rule.nlayers)
+ rule.anns[id].ann_train = create_ann(n, rule.nlayers)
rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix)
- anns[id].version = 0
+ rule.anns[id].version = 0
end
end
@@ -388,18 +385,18 @@ local function load_or_invalidate_ann(rule, data, id, ev_base)
end
if is_ann_valid(rule, prefix, ann) then
- if not anns[id] then anns[id] = {} end
- anns[id].ann = ann
+ if not rule.anns[id] then rule.anns[id] = {} end
+ rule.anns[id].ann = ann
rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis',
prefix, ver)
- anns[id].version = tonumber(ver)
+ rule.anns[id].version = tonumber(ver)
else
local function redis_invalidate_cb(_err, _data)
if _err then
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- anns[id].version = 0
+ rule.anns[id].version = 0
end
end
-- Invalidate ANN
@@ -553,15 +550,15 @@ local function train_ann(rule, _, ev_base, elt, worker)
local ann_data
if use_torch then
local f = torch.MemoryFile()
- f:writeObject(anns[elt].ann_train)
+ f:writeObject(rule.anns[elt].ann_train)
ann_data = rspamd_util.zstd_compress(f:storage():string())
else
- ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data())
+ ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data())
end
- anns[elt].version = anns[elt].version + 1
- anns[elt].ann = anns[elt].ann_train
- anns[elt].ann_train = nil
+ rule.anns[elt].version = rule.anns[elt].version + 1
+ rule.anns[elt].ann = rule.anns[elt].ann_train
+ rule.anns[elt].ann_train = nil
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
@@ -589,11 +586,11 @@ local function train_ann(rule, _, ev_base, elt, worker)
local ann_data
local f = torch.MemoryFile(torch.CharStorage():string(tostring(data)))
ann_data = rspamd_util.zstd_compress(f:storage():string())
- anns[elt].ann_train = f:readObject()
+ rule.anns[elt].ann_train = f:readObject()
- anns[elt].version = anns[elt].version + 1
- anns[elt].ann = anns[elt].ann_train
- anns[elt].ann_train = nil
+ rule.anns[elt].version = rule.anns[elt].version + 1
+ rule.anns[elt].ann = rule.anns[elt].ann_train
+ rule.anns[elt].ann_train = nil
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},
redis_save_cb,
@@ -629,7 +626,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
end
-- Now we can train ann
- if not anns[elt] or not anns[elt].ann_train then
+ if not rule.anns[elt] or not rule.anns[elt].ann_train then
-- Create ann if it does not exist
create_train_ann(rule, n, elt)
end
@@ -641,7 +638,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err)
elseif type(_data) == 'string' then
rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err)
- anns[elt].version = 0
+ rule.anns[elt].version = 0
end
end
-- Invalidate ANN
@@ -668,7 +665,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
torch.setnumthreads(rule.train.learn_threads)
end
local criterion = nn.MSECriterion()
- local trainer = nn.StochasticGradient(anns[elt].ann_train,
+ local trainer = nn.StochasticGradient(rule.anns[elt].ann_train,
criterion)
trainer.learning_rate = rule.train.learning_rate
trainer.verbose = false
@@ -680,7 +677,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
trainer:train(dataset)
local out = torch.MemoryFile()
- out:writeObject(anns[elt].ann_train)
+ out:writeObject(rule.anns[elt].ann_train)
local st = out:storage():string()
return st
end
@@ -701,7 +698,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
rule.learning_spawned = true
rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix)
- anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
+ rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained,
ev_base, {
max_epochs = rule.train.max_epoch,
desired_mse = rule.train.mse
@@ -880,9 +877,9 @@ local function check_anns(rule, _, ev_base)
end
local local_ver = 0
- if anns[elt] then
- if anns[elt].version then
- local_ver = anns[elt].version
+ if rule.anns[elt] then
+ if rule.anns[elt].version then
+ local_ver = rule.anns[elt].version
end
end
lua_redis.exec_redis_script(redis_maybe_load_id,
@@ -963,6 +960,7 @@ else
for k,r in pairs(rules) do
local def_rules = lua_util.override_defaults(default_options, r)
def_rules['redis'] = redis_params
+ def_rules['anns'] = {} -- Store ANNs here
if not def_rules.prefix then
def_rules.prefix = k
diff --git a/src/plugins/surbl.c b/src/plugins/surbl.c
index 17bc26688..d0bdd3b8c 100644
--- a/src/plugins/surbl.c
+++ b/src/plugins/surbl.c
@@ -696,7 +696,6 @@ surbl_module_parse_rule (const ucl_object_t* value, struct rspamd_config* cfg)
/* Mutually exclusive options */
msg_err_config ("options noip and resolve_ip are "
"mutually exclusive for suffix %s", new_suffix->suffix);
- ucl_object_unref (ropts);
continue;
}
@@ -845,8 +844,6 @@ surbl_module_parse_rule (const ucl_object_t* value, struct rspamd_config* cfg)
RSPAMD_MONITORED_DEFAULT, ropts);
surbl_module_ctx->suffixes = g_list_prepend (surbl_module_ctx->suffixes,
new_suffix);
-
- ucl_object_unref (ropts);
}
return nrules;