diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/libcryptobox/cryptobox.h | 4 | ||||
-rw-r--r-- | src/libcryptobox/keypair.c | 119 | ||||
-rw-r--r-- | src/libcryptobox/keypair.h | 34 | ||||
-rw-r--r-- | src/libserver/cfg_rcl.c | 80 | ||||
-rw-r--r-- | src/libserver/symbols_cache.c | 12 | ||||
-rw-r--r-- | src/libstat/stat_process.c | 19 | ||||
-rw-r--r-- | src/libutil/map.c | 6 | ||||
-rw-r--r-- | src/libutil/mem_pool.h | 2 | ||||
-rw-r--r-- | src/libutil/util.c | 12 | ||||
-rw-r--r-- | src/plugins/lua/asn.lua | 9 | ||||
-rw-r--r-- | src/plugins/lua/emails.lua | 2 | ||||
-rw-r--r-- | src/plugins/lua/forged_recipients.lua | 9 | ||||
-rw-r--r-- | src/plugins/lua/neural.lua | 80 | ||||
-rw-r--r-- | src/plugins/surbl.c | 3 |
14 files changed, 331 insertions, 60 deletions
diff --git a/src/libcryptobox/cryptobox.h b/src/libcryptobox/cryptobox.h index 9fc5c7fe9..1045547a2 100644 --- a/src/libcryptobox/cryptobox.h +++ b/src/libcryptobox/cryptobox.h @@ -312,7 +312,7 @@ guint rspamd_cryptobox_mac_bytes (enum rspamd_cryptobox_mode mode); guint rspamd_cryptobox_signature_bytes (enum rspamd_cryptobox_mode mode); /* Hash IUF interface */ -typedef struct RSPAMD_ALIGNED(32) rspamd_cryptobox_hash_state_s { +typedef struct RSPAMD_ALIGNED(16) rspamd_cryptobox_hash_state_s { unsigned char opaque[256]; } rspamd_cryptobox_hash_state_t; @@ -343,7 +343,7 @@ void rspamd_cryptobox_hash (guchar *out, gsize keylen); /* Non crypto hash IUF interface */ -typedef struct RSPAMD_ALIGNED(32) rspamd_cryptobox_fast_hash_state_s { +typedef struct RSPAMD_ALIGNED(16) rspamd_cryptobox_fast_hash_state_s { unsigned char opaque[64 + sizeof (size_t) + sizeof (uint64_t)]; } rspamd_cryptobox_fast_hash_state_t; diff --git a/src/libcryptobox/keypair.c b/src/libcryptobox/keypair.c index 98401e10f..50e3614d9 100644 --- a/src/libcryptobox/keypair.c +++ b/src/libcryptobox/keypair.c @@ -19,6 +19,9 @@ #include "libcryptobox/keypair_private.h" #include "libutil/str_util.h" #include "libutil/printf.h" +#include "contrib/libottery/ottery.h" + +const guchar encrypted_magic[7] = {'r', 'u', 'c', 'l', 'e', 'v', '1'}; static GQuark rspamd_keypair_quark (void) @@ -908,3 +911,119 @@ rspamd_pubkey_equal (const struct rspamd_cryptobox_pubkey *k1, return FALSE; } + +gboolean +rspamd_keypair_decrypt (struct rspamd_cryptobox_keypair *kp, + const guchar *in, gsize inlen, + guchar **out, gsize *outlen, + GError **err) +{ + const guchar *nonce, *mac, *data, *pubkey; + + g_assert (kp != NULL); + g_assert (in != NULL); + + if (kp->type != RSPAMD_KEYPAIR_KEX) { + g_set_error (err, rspamd_keypair_quark (), EINVAL, + "invalid keypair type"); + + return FALSE; + } + + if (inlen < sizeof (encrypted_magic) + rspamd_cryptobox_pk_bytes (kp->alg) + + rspamd_cryptobox_mac_bytes (kp->alg) + + rspamd_cryptobox_nonce_bytes (kp->alg)) { + g_set_error (err, rspamd_keypair_quark (), E2BIG, "invalid size: too small"); + + return FALSE; + } + + if (memcmp (in, encrypted_magic, sizeof (encrypted_magic)) != 0) { + g_set_error (err, rspamd_keypair_quark (), EINVAL, + "invalid magic"); + + return FALSE; + } + + /* Set pointers */ + pubkey = in + sizeof (encrypted_magic); + mac = pubkey + rspamd_cryptobox_pk_bytes (kp->alg); + nonce = mac + rspamd_cryptobox_mac_bytes (kp->alg); + data = nonce + rspamd_cryptobox_nonce_bytes (kp->alg); + + if (data - in >= inlen) { + g_set_error (err, rspamd_keypair_quark (), E2BIG, "invalid size: too small"); + + return FALSE; + } + + inlen -= data - in; + + /* Allocate memory for output */ + *out = g_malloc (inlen); + memcpy (*out, data, inlen); + + if (!rspamd_cryptobox_decrypt_inplace (*out, inlen, nonce, pubkey, + rspamd_keypair_component (kp, RSPAMD_KEYPAIR_COMPONENT_SK, NULL), + mac, kp->alg)) { + g_set_error (err, rspamd_keypair_quark (), EPERM, "verification failed"); + g_free (*out); + + return FALSE; + } + + if (outlen) { + *outlen = inlen; + } + + return TRUE; +} +gboolean +rspamd_keypair_encrypt (struct rspamd_cryptobox_keypair *kp, + const guchar *in, gsize inlen, + guchar **out, gsize *outlen, + GError **err) +{ + guchar *nonce, *mac, *data, *pubkey; + struct rspamd_cryptobox_keypair *local; + gsize olen; + + g_assert (kp != NULL); + g_assert (in != NULL); + + if (kp->type != RSPAMD_KEYPAIR_KEX) { + g_set_error (err, rspamd_keypair_quark (), EINVAL, + "invalid keypair type"); + + return FALSE; + } + + local = rspamd_keypair_new (kp->type, kp->alg); + + olen = inlen + sizeof (encrypted_magic) + + rspamd_cryptobox_pk_bytes (kp->alg) + + rspamd_cryptobox_mac_bytes (kp->alg) + + rspamd_cryptobox_nonce_bytes (kp->alg); + *out = g_malloc (olen); + memcpy (*out, encrypted_magic, sizeof (encrypted_magic)); + pubkey = *out + sizeof (encrypted_magic); + mac = pubkey + rspamd_cryptobox_pk_bytes (kp->alg); + nonce = mac + rspamd_cryptobox_mac_bytes (kp->alg); + data = nonce + rspamd_cryptobox_nonce_bytes (kp->alg); + + ottery_rand_bytes (nonce, rspamd_cryptobox_nonce_bytes (kp->alg)); + memcpy (data, in, inlen); + memcpy (pubkey, rspamd_keypair_component (kp, + RSPAMD_KEYPAIR_COMPONENT_PK, NULL), + rspamd_cryptobox_pk_bytes (kp->alg)); + rspamd_cryptobox_encrypt_inplace (data, inlen, nonce, pubkey, + rspamd_keypair_component (local, RSPAMD_KEYPAIR_COMPONENT_SK, NULL), + mac, kp->alg); + rspamd_keypair_unref (local); + + if (outlen) { + *outlen = olen; + } + + return TRUE; +}
\ No newline at end of file diff --git a/src/libcryptobox/keypair.h b/src/libcryptobox/keypair.h index b24ecc9aa..3e78e7cbb 100644 --- a/src/libcryptobox/keypair.h +++ b/src/libcryptobox/keypair.h @@ -28,6 +28,8 @@ enum rspamd_cryptobox_keypair_type { RSPAMD_KEYPAIR_SIGN }; +extern const guchar encrypted_magic[7]; + /** * Opaque structure for the full (public + private) keypair */ @@ -270,5 +272,37 @@ gboolean rspamd_keypair_verify (struct rspamd_cryptobox_pubkey *pk, gboolean rspamd_pubkey_equal (const struct rspamd_cryptobox_pubkey *k1, const struct rspamd_cryptobox_pubkey *k2); +/** + * Decrypts data using keypair and a pubkey stored in in, in must start from + * `encrypted_magic` constant + * @param kp keypair + * @param in raw input + * @param inlen input length + * @param out output (allocated internally using g_malloc) + * @param outlen output size + * @return TRUE if decryption is completed, out must be freed in this case + */ +gboolean rspamd_keypair_decrypt (struct rspamd_cryptobox_keypair *kp, + const guchar *in, gsize inlen, + guchar **out, gsize *outlen, + GError **err); + +/** + * Encrypts data usign specific keypair. + * This method actually generates ephemeral local keypair, use public key from + * the remote keypair and encrypts data + * @param kp keypair + * @param in raw input + * @param inlen input length + * @param out output (allocated internally using g_malloc) + * @param outlen output size + * @param err pointer to error + * @return TRUE if encryption has been completed, out must be freed in this case + */ +gboolean rspamd_keypair_encrypt (struct rspamd_cryptobox_keypair *kp, + const guchar *in, gsize inlen, + guchar **out, gsize *outlen, + GError **err); + #endif /* SRC_LIBCRYPTOBOX_KEYPAIR_H_ */ diff --git a/src/libserver/cfg_rcl.c b/src/libserver/cfg_rcl.c index a35bbf56c..deb4a8ed3 100644 --- a/src/libserver/cfg_rcl.c +++ b/src/libserver/cfg_rcl.c @@ -3624,6 +3624,32 @@ rspamd_rcl_maybe_apply_lua_transform (struct rspamd_config *cfg) lua_settop (L, 0); } +static bool +rspamd_rcl_decrypt_handler(struct ucl_parser *parser, + const unsigned char *source, size_t source_len, + unsigned char **destination, size_t *dest_len, + void *user_data) +{ + GError *err = NULL; + struct rspamd_cryptobox_keypair *kp = (struct rspamd_cryptobox_keypair *)user_data; + + if (!rspamd_keypair_decrypt (kp, source, source_len, + destination, dest_len, &err)) { + msg_err ("cannot decrypt file: %e", err); + g_error_free (err); + + return false; + } + + return true; +} + +static void +rspamd_rcl_decrypt_free (unsigned char *data, size_t len, void *user_data) +{ + g_free (data); +} + gboolean rspamd_config_read (struct rspamd_config *cfg, const gchar *filename, const gchar *convert_to, rspamd_rcl_section_fin_t logger_fin, @@ -3638,6 +3664,8 @@ rspamd_config_read (struct rspamd_config *cfg, const gchar *filename, const ucl_object_t *logger_obj; rspamd_cryptobox_hash_state_t hs; unsigned char cksumbuf[rspamd_cryptobox_HASHBYTES]; + gchar keypair_path[PATH_MAX]; + struct rspamd_cryptobox_keypair *decrypt_keypair = NULL; struct ucl_emitter_functions f; if (stat (filename, &st) == -1) { @@ -3659,11 +3687,63 @@ rspamd_config_read (struct rspamd_config *cfg, const gchar *filename, close (fd); + /* Try to load keyfile if available */ + rspamd_snprintf (keypair_path, sizeof (keypair_path), "%s.key", + filename); + if (stat (keypair_path, &st) == -1 && + (fd = open (keypair_path, O_RDONLY)) != -1) { + struct ucl_parser *kp_parser; + + kp_parser = ucl_parser_new (0); + + if (ucl_parser_add_fd (kp_parser, fd)) { + ucl_object_t *kp_obj; + + kp_obj = ucl_parser_get_object (kp_parser); + + g_assert (kp_obj != NULL); + decrypt_keypair = rspamd_keypair_from_ucl (kp_obj); + + if (decrypt_keypair == NULL) { + msg_err_config_forced ("cannot load keypair from %s: invalid keypair", + keypair_path); + } + else { + /* Add decryption support to UCL */ + rspamd_mempool_add_destructor (cfg->cfg_pool, + (rspamd_mempool_destruct_t)rspamd_keypair_unref, + decrypt_keypair); + } + + ucl_object_unref (kp_obj); + } + else { + msg_err_config_forced ("cannot load keypair from %s: %s", + keypair_path, ucl_parser_get_error (kp_parser)); + } + + ucl_parser_free (kp_parser); + } + rspamd_cryptobox_hash_init (&hs, NULL, 0); parser = ucl_parser_new (UCL_PARSER_SAVE_COMMENTS); rspamd_ucl_add_conf_variables (parser, vars); rspamd_ucl_add_conf_macros (parser, cfg); + if (decrypt_keypair) { + struct ucl_parser_special_handler *decrypt_handler; + + decrypt_handler = rspamd_mempool_alloc0 (cfg->cfg_pool, + sizeof (*decrypt_handler)); + decrypt_handler->user_data = decrypt_keypair; + decrypt_handler->magic = encrypted_magic; + decrypt_handler->magic_len = sizeof (encrypted_magic); + decrypt_handler->handler = rspamd_rcl_decrypt_handler; + decrypt_handler->free_function = rspamd_rcl_decrypt_free; + + ucl_parser_add_special_handler (parser, decrypt_handler); + } + if (!ucl_parser_add_chunk (parser, data, st.st_size)) { msg_err_config_forced ("ucl parser error: %s", ucl_parser_get_error (parser)); ucl_parser_free (parser); diff --git a/src/libserver/symbols_cache.c b/src/libserver/symbols_cache.c index eac6c8d0c..abed26fc3 100644 --- a/src/libserver/symbols_cache.c +++ b/src/libserver/symbols_cache.c @@ -415,7 +415,7 @@ rspamd_symbols_cache_post_init (struct symbols_cache *cache) struct delayed_cache_dependency *ddep; struct delayed_cache_condition *dcond; GList *cur; - guint i, j; + gint i, j; gint id; rspamd_symbols_cache_resort (cache); @@ -506,6 +506,16 @@ rspamd_symbols_cache_post_init (struct symbols_cache *cache) msg_err_cache ("cannot find dependency on symbol %s", dep->sym); } } + + /* Reversed loop to make removal safe */ + for (j = it->deps->len - 1; j >= 0; j --) { + dep = g_ptr_array_index (it->deps, j); + + if (dep->item == NULL) { + /* Remove useless dep */ + g_ptr_array_remove_index (it->deps, j); + } + } } g_ptr_array_sort_with_data (cache->prefilters, prefilters_cmp, cache); diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index aed9d383d..070b9d6ca 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -81,6 +81,7 @@ rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx, gchar tmpbuf[128]; lua_State *L = task->cfg->lua_state; const gchar *headers_hash; + struct rspamd_mime_header *hdr; ar = g_array_sized_new (FALSE, FALSE, sizeof (elt), 16); elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; @@ -183,6 +184,24 @@ rspamd_stat_tokenize_parts_metadata (struct rspamd_stat_ctx *st_ctx, g_array_append_val (ar, elt); } + /* Use more precise headers order */ + cur = g_list_first (task->headers_order->head); + while (cur) { + hdr = cur->data; + + if (hdr->name && hdr->type != RSPAMD_HEADER_RECEIVED) { + /* We assume that headers count is not more than 10^10 */ + gsize nlen = strlen (hdr->name) + 1 + 10; + gchar *hdrbuf = rspamd_mempool_alloc (task->task_pool, nlen); + nlen = rspamd_snprintf (hdrbuf, nlen, "%s:%d", hdr->name, hdr->order); + rspamd_str_lc (hdrbuf, nlen); + elt.begin = hdrbuf; + elt.len = nlen; + g_array_append_val (ar, elt); + } + + cur = g_list_next (cur); + } /* Use metatokens plugin from Lua */ lua_getglobal (L, "rspamd_plugins"); diff --git a/src/libutil/map.c b/src/libutil/map.c index 1e8f70f7e..577933c25 100644 --- a/src/libutil/map.c +++ b/src/libutil/map.c @@ -346,11 +346,11 @@ rspamd_map_cache_cb (gint fd, short what, gpointer ud) /* We have another update, so this cache element is obviously expired */ /* Important: we do not set cache availability to zero here */ MAP_RELEASE (cache_cbd->shm, "rspamd_http_map_cached_cbdata"); - msg_debug_map ("cached data is now expired (gen mismatch) for %s", map->name); + msg_info_map ("cached data is now expired (gen mismatch) for %s", map->name); event_del (&cache_cbd->timeout); g_free (cache_cbd); } - else if (cache_cbd->data->last_checked > cache_cbd->last_checked) { + else if (cache_cbd->data->last_checked >= cache_cbd->last_checked) { /* * We checked map but we have not found anything more recent, * reschedule cache check @@ -363,7 +363,7 @@ rspamd_map_cache_cb (gint fd, short what, gpointer ud) else { g_atomic_int_set (&map->cache->available, 0); MAP_RELEASE (cache_cbd->shm, "rspamd_http_map_cached_cbdata"); - msg_debug_map ("cached data is now expired for %s", map->name); + msg_info_map ("cached data is now expired for %s", map->name); event_del (&cache_cbd->timeout); g_free (cache_cbd); } diff --git a/src/libutil/mem_pool.h b/src/libutil/mem_pool.h index 7423895c4..27d4c8ebf 100644 --- a/src/libutil/mem_pool.h +++ b/src/libutil/mem_pool.h @@ -21,7 +21,7 @@ struct f_str_s; #define MEMPOOL_TAG_LEN 20 #define MEMPOOL_UID_LEN 20 -#define MEM_ALIGNMENT 8 +#define MEM_ALIGNMENT 16 #define align_ptr(p, a) \ (guint8 *) (((uintptr_t) (p) + ((uintptr_t) a - 1)) & ~((uintptr_t) a - 1)) diff --git a/src/libutil/util.c b/src/libutil/util.c index 29c12ca2f..d55a8e8ac 100644 --- a/src/libutil/util.c +++ b/src/libutil/util.c @@ -83,6 +83,7 @@ #include "cryptobox.h" #include "zlib.h" +#include "contrib/uthash/utlist.h" /* Check log messages intensity once per minute */ #define CHECK_TIME 60 @@ -418,6 +419,12 @@ out: return (-1); } +static int +rspamd_prefer_v4_hack (const struct addrinfo *a1, const struct addrinfo *a2) +{ + return a1->ai_addr->sa_family - a2->ai_addr->sa_family; +} + /** * Make a universal socket * @param credits host, ip or path to unix socket @@ -480,6 +487,7 @@ rspamd_socket (const gchar *credits, guint16 port, rspamd_snprintf (portbuf, sizeof (portbuf), "%d", (int)port); if ((r = getaddrinfo (credits, portbuf, &hints, &res)) == 0) { + LL_SORT2 (res, rspamd_prefer_v4_hack, ai_next); r = rspamd_inet_socket_create (type, res, is_server, async, NULL); freeaddrinfo (res); return r; @@ -572,7 +580,9 @@ rspamd_sockets_list (const gchar *credits, guint16 port, rspamd_snprintf (portbuf, sizeof (portbuf), "%d", (int)port); if ((r = getaddrinfo (credits, portbuf, &hints, &res)) == 0) { - fd = rspamd_inet_socket_create (type, res, is_server, async, &result); + LL_SORT2 (res, rspamd_prefer_v4_hack, ai_next); + fd = rspamd_inet_socket_create (type, res, is_server, async, + &result); freeaddrinfo (res); if (result == NULL) { diff --git a/src/plugins/lua/asn.lua b/src/plugins/lua/asn.lua index 3e84a3824..61572a600 100644 --- a/src/plugins/lua/asn.lua +++ b/src/plugins/lua/asn.lua @@ -60,18 +60,19 @@ local function asn_check(task) local asn_check_func = {} function asn_check_func.rspamd(ip) + local dnsbl = options['provider_info']['ip' .. ip:get_version()] + local req_name = rspamd_logger.slog("%1.%2", + table.concat(ip:inversed_str_octets(), '.'), dnsbl) local function rspamd_dns_cb(_, _, results, dns_err) if dns_err and (dns_err ~= 'requested record is not found' and dns_err ~= 'no records with this name') then - rspamd_logger.errx(task, 'error querying dns: %s', dns_err) + rspamd_logger.errx(task, 'error querying dns (%s): %s', req_name, dns_err) end if not (results and results[1]) then return end local parts = rspamd_re:split(results[1]) -- "15169 | 8.8.8.0/24 | US | arin |" for 8.8.8.8 asn_set(parts[1], parts[2], parts[3]) end - local dnsbl = options['provider_info']['ip' .. ip:get_version()] - local req_name = rspamd_logger.slog("%1.%2", - table.concat(ip:inversed_str_octets(), '.'), dnsbl) + task:get_resolver():resolve_txt(task:get_session(), task:get_mempool(), req_name, rspamd_dns_cb) end diff --git a/src/plugins/lua/emails.lua b/src/plugins/lua/emails.lua index 51df0e959..5f076e1d5 100644 --- a/src/plugins/lua/emails.lua +++ b/src/plugins/lua/emails.lua @@ -52,7 +52,7 @@ local function check_email_rule(task, rule, addr) local function emails_dns_cb(_, _, results, err) if err and (err ~= 'requested record is not found' and err ~= 'no records with this name') then - logger.errx(task, 'Error querying DNS: %1', err) + logger.errx(task, 'Error querying DNS(%s): %s', to_resolve, err) elseif results then local expected_found = false local symbol = rule['symbol'] diff --git a/src/plugins/lua/forged_recipients.lua b/src/plugins/lua/forged_recipients.lua index d9cab67c7..782a408bd 100644 --- a/src/plugins/lua/forged_recipients.lua +++ b/src/plugins/lua/forged_recipients.lua @@ -39,7 +39,8 @@ local function check_forged_headers(task) if not mime_rcpt then return elseif #mime_rcpt == 0 then - task:insert_result(symbol_rcpt, score) + local sra = smtp_rcpt[1].addr .. (#smtp_rcpt > 1 and ' ...' or '') + task:insert_result(symbol_rcpt, score, '', sra) return end -- Find pair for each smtp recipient recipient in To or Cc headers @@ -67,7 +68,9 @@ local function check_forged_headers(task) end end if not res then - task:insert_result(symbol_rcpt, score) + local mra = mime_rcpt[1].addr .. (#mime_rcpt > 1 and ' ..' or '') + local sra = smtp_rcpt[1].addr .. (#smtp_rcpt > 1 and ' ...' or '') + task:insert_result(symbol_rcpt, score, mra, sra) break end end @@ -76,7 +79,7 @@ local function check_forged_headers(task) local mime_from = task:get_from(2) if not mime_from or not mime_from[1] or not (string.lower(mime_from[1]['addr']) == string.lower(smtp_from[1]['addr'])) then - task:insert_result(symbol_sender, 1) + task:insert_result(symbol_sender, 1, ((mime_from or E)[1] or E).addr or '', smtp_from[1].addr) end end end diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index 04b732472..9d0bbb446 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -65,10 +65,6 @@ local settings = { rules = {} } --- ANNs indexed by settings id -local anns = { -} - local opts = rspamd_config:get_all_opt("neural") if not opts then -- Legacy @@ -278,7 +274,7 @@ local function ann_scores_filter(task) id = id .. r end - if anns[id] and anns[id].ann then + if rule.anns[id] and rule.anns[id].ann then local ann_data = task:get_symbols_tokens() local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens @@ -286,15 +282,15 @@ local function ann_scores_filter(task) local score if use_torch then - local out = anns[id].ann:forward(torch.Tensor(ann_data)) + local out = rule.anns[id].ann:forward(torch.Tensor(ann_data)) score = out[1] else - local out = anns[id].ann:test(ann_data) + local out = rule.anns[id].ann:test(ann_data) score = out[1] end local symscore = string.format('%.3f', score) - rspamd_logger.infox(task, 'ann score: %s', symscore) + rspamd_logger.infox(task, '%s ann score: %s', rule.name, symscore) if score > 0 then local result = score @@ -339,28 +335,29 @@ end local function create_train_ann(rule, n, id) local prefix = gen_ann_prefix(rule, id) - if not anns[id] then - anns[id] = {} + if not rule.anns[id] then + rule.anns[id] = {} end -- Fix that for flexibe layers number - if anns[id].ann then - if not is_ann_valid(rule, prefix, anns[id].ann) then - anns[id].ann_train = create_ann(n, rule.nlayers) - anns[id].ann = nil + if rule.anns[id].ann then + if not is_ann_valid(rule, prefix, rule.anns[id].ann) then + rule.anns[id].ann_train = create_ann(n, rule.nlayers) + rule.anns[id].ann = nil rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) - elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then + elseif rule.train.max_usages > 0 and + rule.anns[id].version % rule.train.max_usages == 0 then -- Forget last ann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, - anns[id].version) - anns[id].ann_train = create_ann(n, rule.nlayers) + rule.anns[id].version) + rule.anns[id].ann_train = create_ann(n, rule.nlayers) else - anns[id].ann_train = anns[id].ann + rule.anns[id].ann_train = rule.anns[id].ann rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix) end else - anns[id].ann_train = create_ann(n, rule.nlayers) + rule.anns[id].ann_train = create_ann(n, rule.nlayers) rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix) - anns[id].version = 0 + rule.anns[id].version = 0 end end @@ -388,18 +385,18 @@ local function load_or_invalidate_ann(rule, data, id, ev_base) end if is_ann_valid(rule, prefix, ann) then - if not anns[id] then anns[id] = {} end - anns[id].ann = ann + if not rule.anns[id] then rule.anns[id] = {} end + rule.anns[id].ann = ann rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', prefix, ver) - anns[id].version = tonumber(ver) + rule.anns[id].version = tonumber(ver) else local function redis_invalidate_cb(_err, _data) if _err then rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - anns[id].version = 0 + rule.anns[id].version = 0 end end -- Invalidate ANN @@ -553,15 +550,15 @@ local function train_ann(rule, _, ev_base, elt, worker) local ann_data if use_torch then local f = torch.MemoryFile() - f:writeObject(anns[elt].ann_train) + f:writeObject(rule.anns[elt].ann_train) ann_data = rspamd_util.zstd_compress(f:storage():string()) else - ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data()) + ann_data = rspamd_util.zstd_compress(rule.anns[elt].ann_train:data()) end - anns[elt].version = anns[elt].version + 1 - anns[elt].ann = anns[elt].ann_train - anns[elt].ann_train = nil + rule.anns[elt].version = rule.anns[elt].version + 1 + rule.anns[elt].ann = rule.anns[elt].ann_train + rule.anns[elt].ann_train = nil lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, @@ -589,11 +586,11 @@ local function train_ann(rule, _, ev_base, elt, worker) local ann_data local f = torch.MemoryFile(torch.CharStorage():string(tostring(data))) ann_data = rspamd_util.zstd_compress(f:storage():string()) - anns[elt].ann_train = f:readObject() + rule.anns[elt].ann_train = f:readObject() - anns[elt].version = anns[elt].version + 1 - anns[elt].ann = anns[elt].ann_train - anns[elt].ann_train = nil + rule.anns[elt].version = rule.anns[elt].version + 1 + rule.anns[elt].ann = rule.anns[elt].ann_train + rule.anns[elt].ann_train = nil lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, @@ -629,7 +626,7 @@ local function train_ann(rule, _, ev_base, elt, worker) end -- Now we can train ann - if not anns[elt] or not anns[elt].ann_train then + if not rule.anns[elt] or not rule.anns[elt].ann_train then -- Create ann if it does not exist create_train_ann(rule, n, elt) end @@ -641,7 +638,7 @@ local function train_ann(rule, _, ev_base, elt, worker) rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - anns[elt].version = 0 + rule.anns[elt].version = 0 end end -- Invalidate ANN @@ -668,7 +665,7 @@ local function train_ann(rule, _, ev_base, elt, worker) torch.setnumthreads(rule.train.learn_threads) end local criterion = nn.MSECriterion() - local trainer = nn.StochasticGradient(anns[elt].ann_train, + local trainer = nn.StochasticGradient(rule.anns[elt].ann_train, criterion) trainer.learning_rate = rule.train.learning_rate trainer.verbose = false @@ -680,7 +677,7 @@ local function train_ann(rule, _, ev_base, elt, worker) trainer:train(dataset) local out = torch.MemoryFile() - out:writeObject(anns[elt].ann_train) + out:writeObject(rule.anns[elt].ann_train) local st = out:storage():string() return st end @@ -701,7 +698,7 @@ local function train_ann(rule, _, ev_base, elt, worker) end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts))) rule.learning_spawned = true rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix) - anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained, + rule.anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained, ev_base, { max_epochs = rule.train.max_epoch, desired_mse = rule.train.mse @@ -880,9 +877,9 @@ local function check_anns(rule, _, ev_base) end local local_ver = 0 - if anns[elt] then - if anns[elt].version then - local_ver = anns[elt].version + if rule.anns[elt] then + if rule.anns[elt].version then + local_ver = rule.anns[elt].version end end lua_redis.exec_redis_script(redis_maybe_load_id, @@ -963,6 +960,7 @@ else for k,r in pairs(rules) do local def_rules = lua_util.override_defaults(default_options, r) def_rules['redis'] = redis_params + def_rules['anns'] = {} -- Store ANNs here if not def_rules.prefix then def_rules.prefix = k diff --git a/src/plugins/surbl.c b/src/plugins/surbl.c index 17bc26688..d0bdd3b8c 100644 --- a/src/plugins/surbl.c +++ b/src/plugins/surbl.c @@ -696,7 +696,6 @@ surbl_module_parse_rule (const ucl_object_t* value, struct rspamd_config* cfg) /* Mutually exclusive options */ msg_err_config ("options noip and resolve_ip are " "mutually exclusive for suffix %s", new_suffix->suffix); - ucl_object_unref (ropts); continue; } @@ -845,8 +844,6 @@ surbl_module_parse_rule (const ucl_object_t* value, struct rspamd_config* cfg) RSPAMD_MONITORED_DEFAULT, ropts); surbl_module_ctx->suffixes = g_list_prepend (surbl_module_ctx->suffixes, new_suffix); - - ucl_object_unref (ropts); } return nrules; |