From: Vsevolod Stakhov Date: Sat, 5 Nov 2016 12:03:20 +0000 (+0000) Subject: [Rework] Restore old fann_scores, move common parts X-Git-Tag: 1.4.0~122 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=c38238123244947b0effaa8b3637d80e90106f89;p=rspamd.git [Rework] Restore old fann_scores, move common parts --- diff --git a/src/lua/global_functions.lua b/src/lua/global_functions.lua index 0eb461496..b8b840043 100644 --- a/src/lua/global_functions.lua +++ b/src/lua/global_functions.lua @@ -164,3 +164,220 @@ function rspamd_str_split(s, sep) local p = lpeg.Ct(elem * (sep * elem)^0) -- make a table capture return lpeg.match(p, s) end + +-- Metafunctions +local function meta_size_function(task) + local sizes = { + 100, + 200, + 500, + 1000, + 2000, + 4000, + 10000, + 20000, + 30000, + 100000, + 200000, + 400000, + 800000, + 1000000, + 2000000, + 8000000, + } + + local size = task:get_size() + for i = 1,#sizes do + if sizes[i] >= size then + return {i / #sizes} + end + end + + return {0} +end + +local function meta_images_function(task) + local images = task:get_images() + local ntotal = 0 + local njpg = 0 + local npng = 0 + local nlarge = 0 + local nsmall = 0 + + if images then + for _,img in ipairs(images) do + if img:get_type() == 'png' then + npng = npng + 1 + elseif img:get_type() == 'jpeg' then + njpg = njpg + 1 + end + + local w = img:get_width() + local h = img:get_height() + + if w > 0 and h > 0 then + if w + h > 256 then + nlarge = nlarge + 1 + else + nsmall = nsmall + 1 + end + end + + ntotal = ntotal + 1 + end + end + if ntotal > 0 then + njpg = njpg / ntotal + npng = npng / ntotal + nlarge = nlarge / ntotal + nsmall = nsmall / ntotal + end + return {ntotal,njpg,npng,nlarge,nsmall} +end + +local function meta_nparts_function(task) + local nattachments = 0 + local ntextparts = 0 + local totalparts = 1 + + local tp = task:get_text_parts() + if tp then + ntextparts = #tp + end + + local parts = task:get_parts() + + if parts then + for _,p in ipairs(parts) do + if p:get_filename() then + nattachments = nattachments + 1 + end + totalparts = totalparts + 1 + end + end + + return {ntextparts/totalparts, nattachments/totalparts} +end + +local function meta_encoding_function(task) + local nutf = 0 + local nother = 0 + + local tp = task:get_text_parts() + if tp then + for _,p in ipairs(tp) do + if p:is_utf() then + nutf = nutf + 1 + else + nother = nother + 1 + end + end + end + + return {nutf, nother} +end + +local function meta_recipients_function(task) + local nmime = 0 + local nsmtp = 0 + + if task:has_recipients('mime') then + nmime = #(task:get_recipients('mime')) + end + if task:has_recipients('smtp') then + nsmtp = #(task:get_recipients('smtp')) + end + + if nmime > 0 then nmime = 1.0 / nmime end + if nsmtp > 0 then nsmtp = 1.0 / nsmtp end + + return {nmime,nsmtp} +end + +local function meta_received_function(task) + local ret = 0 + local rh = task:get_received_headers() + + if rh and #rh > 0 then + ret = 1 / #rh + end + + return {ret} +end + +local function meta_urls_function(task) + if task:has_urls() then + return {1.0 / #(task:get_urls())} + end + + return {0} +end + +local function meta_attachments_function(task) +end + +local metafunctions = { + { + cb = meta_size_function, + ninputs = 1, + }, + { + cb = meta_images_function, + ninputs = 5, + -- 1 - number of images, + -- 2 - number of png images, + -- 3 - number of jpeg images + -- 4 - number of large images (> 128 x 128) + -- 5 - number of small images (< 128 x 128) + }, + { + cb = meta_nparts_function, + ninputs = 2, + -- 1 - number of text parts + -- 2 - number of attachments + }, + { + cb = meta_encoding_function, + ninputs = 2, + -- 1 - number of utf parts + -- 2 - number of non-utf parts + }, + { + cb = meta_recipients_function, + ninputs = 2, + -- 1 - number of mime rcpt + -- 2 - number of smtp rcpt + }, + { + cb = meta_received_function, + ninputs = 1, + }, + { + cb = meta_urls_function, + ninputs = 1, + }, +} + +function rspamd_gen_metatokens(task) + local ipairs = ipairs + local metatokens = {} + for _,mt in ipairs(metafunctions) do + local ct = mt.cb(task) + + for _,tok in ipairs(ct) do + table.insert(metatokens, tok) + end + end + + return metatokens +end + +function rspamd_count_metatokens() + local ipairs = ipairs + local total = 0 + for _,mt in ipairs(metafunctions) do + total = total + mt.ninputs + end + + return total +end \ No newline at end of file diff --git a/src/plugins/lua/fann_classifier.lua b/src/plugins/lua/fann_classifier.lua index af7acece8..9c35d0bfa 100644 --- a/src/plugins/lua/fann_classifier.lua +++ b/src/plugins/lua/fann_classifier.lua @@ -149,7 +149,7 @@ local function tokens_to_vector(tokens) end local function add_metatokens(task, vec) - local mt = gen_metatokens(task) + local mt = rspamd_gen_metatokens(task) for _,tok in ipairs(mt) do table.insert(vec, tok) end @@ -157,7 +157,7 @@ end local function create_fann() local layers = {} - local mt_size = count_metatokens() + local mt_size = rspamd_count_metatokens() local neurons = classifier_config.neurons + mt_size for i = 1,classifier_config.layers - 1 do diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua new file mode 100644 index 000000000..e81af4762 --- /dev/null +++ b/src/plugins/lua/fann_redis.lua @@ -0,0 +1,586 @@ +--[[ +Copyright (c) 2016, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- This plugin is a concept of FANN scores adjustment +-- NOT FOR PRODUCTION USE so far + +local rspamd_logger = require "rspamd_logger" +local rspamd_fann = require "rspamd_fann" +local rspamd_util = require "rspamd_util" +local fann_symbol_spam = 'FANN_SPAM' +local fann_symbol_ham = 'FANN_HAM' +require "fun" () +local ucl = require "ucl" + +local module_log_id = 0x100 +-- Module vars +-- ANNs indexed by settings id +local data = { + ['0'] = { + fann_mtime = 0, + ntrains = 0, + epoch = 0, + } +} + + +-- Lua script to train a row +-- Uses the following keys: +-- key1 - prefix for keys +-- key2 - max count of learns +-- key3 - spam or ham +-- returns 1 or 0: 1 - allow learn, 0 - not allow learn +local redis_lua_script_can_train = [[ + local locked = redis.call('GET', KEYS[1] .. '_locked') + if locked then return 0 end + local nspam = 0 + local nham = 0 + + local ret = redis.call('LLEN', KEYS[1] .. '_spam') + if ret then nspam = tonumber(ret) end + ret = redis.call('LLEN', KEYS[1] .. '_ham') + if ret then nham = tonumber(ret) end + + if KEYS[3] == 'spam' then + if nham + 1 >= nspam then return tostring(nspam + 1) end + else + if nspam + 1 >= nham then return tostring(nham + 1) end + end + + return tostring(0) +]] +local redis_can_train_sha = nil + +-- Lua script to load ANN from redis +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - local version +-- returns nil or bulk string if new ANN can be loaded +local redis_lua_script_maybe_load = [[ + local locked = redis.call('GET', KEYS[1] .. '_locked') + if locked then return false end + + local ver = 0 + local ret = redis.call('GET', KEYS[1] .. '_version') + if ret then ver = tonumber(ret) end + if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end + + return false +]] +local redis_maybe_load_sha = nil + +-- Lua script to invalidate ANN from redis +-- Uses the following keys +-- key1 - prefix for keys +local redis_lua_script_maybe_invalidate = [[ + local locked = redis.call('GET', KEYS[1] .. '_locked') + if locked then return false end + redis.call('SET', KEYS[1] .. '_locked', '1') + redis.call('SET', KEYS[1] .. '_version', '0') + redis.call('DEL', KEYS[1] .. '_spam') + redis.call('DEL', KEYS[1] .. '_ham') + redis.call('DEL', KEYS[1] .. '_data') + redis.call('DEL', KEYS[1] .. '_locked') + return 1 +]] +local redis_maybe_invalidate_sha = nil + +local redis_params +redis_params = rspamd_parse_redis_server('fann_redis') + +local fann_prefix = 'RFANN' +local max_trains = 1000 +local max_epoch = 100 +local use_settings = false +local watch_interval = 60.0 + +local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) + if not ev_base or not redis_params or not callback or not command then + return false,nil,nil + end + + local addr + local rspamd_redis = require "rspamd_redis" + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(task, 'cannot select server to make redis request') + end + + local options = { + ev_base = ev_base, + config = cfg, + callback = callback, + host = addr:get_addr(), + timeout = redis_params['timeout'], + cmd = command, + args = args + } + + if redis_params['password'] then + options['password'] = redis_params['password'] + end + + if redis_params['db'] then + options['dbname'] = redis_params['db'] + end + + local ret,conn = rspamd_redis.make_request(options) + if not ret then + rspamd_logger.errx('cannot execute redis request') + end + return ret,conn,addr +end + +local function symbols_to_fann_vector(syms, scores) + local learn_data = {} + local matched_symbols = {} + local n = rspamd_config:get_symbols_count() + + each(function(s, score) + matched_symbols[s + 1] = rspamd_util.tanh(score) + end, zip(syms, scores)) + + for i=1,n do + if matched_symbols[i] then + learn_data[i] = matched_symbols[i] + else + learn_data[i] = 0 + end + end + + return learn_data +end + +local function gen_fann_prefix(id) + if use_settings then + return fann_prefix .. id + else + return fann_prefix + end +end + +local function is_fann_valid(ann) + if ann then + local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() + + if n ~= ann:get_inputs() then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', ann:get_inputs(), n) + return false + end + local layers = ann:get_layers() + + if not layers or #layers ~= 5 then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s', + #layers) + return false + end + + return true + end +end + +local function fann_scores_filter(task) + local id = '0' + if use_settings then + local sid = task:get_settings_id() + if sid then + id = tostring(sid) + end + end + + if data[id].fann then + local symbols,scores = task:get_symbols_numeric() + local fann_data = symbols_to_fann_vector(symbols, scores) + local mt = rspamd_gen_metatokens(task) + + for _,tok in ipairs(mt) do + table.insert(fann_data, tok) + end + + local out = data[id].fann:test(fann_data) + local symscore = string.format('%.3f', out[1]) + rspamd_logger.infox(task, 'fann score: %s', symscore) + + if out[1] > 0 then + local result = rspamd_util.normalize_prob(out[1] / 2.0, 0) + task:insert_result(fann_symbol_spam, result, symscore, id) + else + local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) + task:insert_result(fann_symbol_ham, result, symscore, id) + end + end +end + +local function create_train_fann(n, id) + data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1) + data[id].ntrains = 0 + data[id].epoch = 0 +end + +local function load_or_invalidate_fann(data, id, ev_base) + local err,ann_data = rspamd_util.zstd_decompress(data) + local ann + + if err or not ann_data then + rspamd_logger.errx('cannot decompress ann: %s', err) + else + ann = rspamd_fann.load_data(ann_data) + end + + if is_fann_valid(ann) then + data[id].fann = ann + else + local function redis_invalidate_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err) + elseif type(data) == 'string' then + rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, err) + end + end + -- Invalidate ANN + rspamd_logger.infox('invalidate ANN %s') + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + redis_invalidate_cb, --callback + 'EVALSHA', -- command + {redis_maybe_invalidate_sha, 1, fann_prefix .. id} + ) + end +end + +local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base) + local fname = gen_fann_prefix(id) + + local learn_spam, learn_ham = false, false + + if opts['spam_score'] then + learn_spam = score >= opts['spam_score'] + else + learn_spam = score >= required_score + end + if opts['ham_score'] then + learn_ham = score <= opts['ham_score'] + else + learn_ham = score < 0 + end + + if learn_spam or learn_ham then + local k + if learn_spam then k = 'spam' else k = 'ham' end + + local function learn_vec_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err) + end + end + + local function can_train_cb(err, data) + if not err and tonumber(data) > 0 then + local learn_data = symbols_to_fann_vector( + map(function(r) return r[1] end, results), + map(function(r) return r[2] end, results) + ) + -- Add filtered meta tokens + each(function(e) table.insert(learn_data, e) end, extra) + local str = rspamd_util.zstd_compress(table.concat(learn_data, ';')) + + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + learn_vec_cb, --callback + 'LPUSH', -- command + {fname .. '_' .. k, str} -- arguments + ) + else + if err then + rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err) + end + end + end + + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + can_train_cb, --callback + 'EVALSHA', -- command + {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments + ) + end +end + +local function train_fann(cfg, ev_base, elt) + +end + +local function maybe_train_fanns(cfg, ev_base) + local function members_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) + elseif type(data) == 'table' then + each(function(i, elt) + local redis_len_cb = function(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + if tonumber(data) and tonumber(data) > max_trains then + train_fann(cfg, ev_base, elt) + end + end + end + + local local_ver = 0 + local numelt = tonumber(elt) + if data[numelt] then + if data[numelt].version then + local_ver = data[numelt].version + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_len_cb, --callback + 'LLEN', -- command + {fann_prefix .. elt .. '_spam'} + ) + end, + data) + end + end + + if not redis_maybe_load_sha then + -- Plan new event early + return 1.0 + end + -- First we need to get all fanns stored in our Redis + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + members_cb, --callback + 'SMEMBERS', -- command + {fann_prefix} -- arguments + ) + + return watch_interval +end + +local function check_fanns(cfg, ev_base) + local function members_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) + elseif type(data) == 'table' then + each(function(i, elt) + local redis_update_cb = function(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err) + elseif data and type(data) == 'string' then + load_or_invalidate_fann(data, elt, ev_base) + end + end + + local local_ver = 0 + local numelt = tonumber(elt) + if data[numelt] then + if data[numelt].version then + local_ver = data[numelt].version + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_update_cb, --callback + 'EVALSHA', -- command + {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)} + ) + end, + data) + end + end + + if not redis_maybe_load_sha then + -- Plan new event early + return 1.0 + end + -- First we need to get all fanns stored in our Redis + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + members_cb, --callback + 'SMEMBERS', -- command + {fann_prefix} -- arguments + ) + + return watch_interval +end + +-- Initialization part + +local opts = rspamd_config:get_all_opt("fann_redis") +if not (opts and type(opts) == 'table') or not redis_params then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + return +end + +if not rspamd_fann.is_enabled() then + rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. + 'module is eventually disabled') + return +else + use_settings = opts['use_settings'] + rspamd_config:set_metric_symbol({ + name = fann_symbol_spam, + score = 3.0, + description = 'Neural network SPAM', + group = 'fann' + }) + local id = rspamd_config:register_symbol({ + name = fann_symbol_spam, + type = 'postfilter', + priority = 5, + callback = fann_scores_filter + }) + rspamd_config:set_metric_symbol({ + name = fann_symbol_ham, + score = -2.0, + description = 'Neural network HAM', + group = 'fann' + }) + rspamd_config:register_symbol({ + name = fann_symbol_ham, + type = 'virtual', + parent = id + }) + if opts['train'] then + rspamd_config:add_on_load(function(cfg) + if opts['train']['max_train'] then + max_trains = opts['train']['max_train'] + end + if opts['train']['max_epoch'] then + max_epoch = opts['train']['max_epoch'] + end + local ret = cfg:register_worker_script("log_helper", + function(score, req_score, results, cf, id, extra, ev_base) + -- map (snd x) (filter (fst x == module_id) extra) + local extra_fann = map(function(e) return e[2] end, + filter(function(e) return e[1] == module_log_id end, extra)) + if use_settings then + fann_train_callback(score, req_score, results, cf, + tostring(id), opts['train'], extra_fann, ev_base) + else + fann_train_callback(score, req_score, results, cf, '0', + opts['train'], extra_fann, ev_base) + end + end) + + if not ret then + rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') + end + end) + -- This is needed to pass extra tokens from worker to log_helper + rspamd_plugins["fann_score"] = { + log_callback = function(task) + return totable(map( + function(tok) return {module_log_id, tok} end, + rspamd_gen_metatokens(task))) + end + } + end + -- Add training scripts + rspamd_config:add_on_load(function(cfg, ev_base, worker) + local function can_train_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err) + else + redis_can_train_sha = tostring(data) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + can_train_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_can_train} -- arguments + ) + + local function maybe_load_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err) + else + redis_maybe_load_sha = tostring(data) + + rspamd_config:add_periodic(ev_base, 0.0, + function(cfg, ev_base) + return check_fanns(cfg, ev_base) + end) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + maybe_load_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_maybe_load} -- arguments + ) + + local function maybe_invalidate_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err) + else + redis_maybe_invalidate_sha = tostring(data) + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + maybe_invalidate_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_maybe_invalidate} -- arguments + ) + + if worker:get_name() == 'normal' then + -- We also want to train neural nets when they have enough data + rspamd_config:add_periodic(ev_base, 0.0, + function(cfg, ev_base) + return maybe_train_fanns(cfg, ev_base) + end) + end + end) +end diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 7bc55117d..f533a34d8 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -36,389 +36,123 @@ local data = { } } - --- Lua script to train a row --- Uses the following keys: --- key1 - prefix for keys --- key2 - max count of learns --- key3 - spam or ham --- returns 1 or 0: 1 - allow learn, 0 - not allow learn -local redis_lua_script_can_train = [[ - local locked = redis.call('GET', KEYS[1] .. '_locked') - if locked then return 0 end - local nspam = 0 - local nham = 0 - - local ret = redis.call('LLEN', KEYS[1] .. '_spam') - if ret then nspam = tonumber(ret) end - ret = redis.call('LLEN', KEYS[1] .. '_ham') - if ret then nham = tonumber(ret) end - - if KEYS[3] == 'spam' then - if nham + 1 >= nspam then return tostring(nspam + 1) end - else - if nspam + 1 >= nham then return tostring(nham + 1) end - end - - return tostring(0) -]] -local redis_can_train_sha = nil - --- Lua script to load ANN from redis --- Uses the following keys --- key1 - prefix for keys --- key2 - local version --- returns nil or bulk string if new ANN can be loaded -local redis_lua_script_maybe_load = [[ - local locked = redis.call('GET', KEYS[1] .. '_locked') - if locked then return false end - - local ver = 0 - local ret = redis.call('GET', KEYS[1] .. '_version') - if ret then ver = tonumber(ret) end - if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end - - return false -]] -local redis_maybe_load_sha = nil - --- Lua script to invalidate ANN from redis --- Uses the following keys --- key1 - prefix for keys -local redis_lua_script_maybe_invalidate = [[ - local locked = redis.call('GET', KEYS[1] .. '_locked') - if locked then return false end - redis.call('SET', KEYS[1] .. '_locked', '1') - redis.call('SET', KEYS[1] .. '_version', '0') - redis.call('DEL', KEYS[1] .. '_spam') - redis.call('DEL', KEYS[1] .. '_ham') - redis.call('DEL', KEYS[1] .. '_data') - redis.call('DEL', KEYS[1] .. '_locked') - return 1 -]] -local redis_maybe_invalidate_sha = nil - -local redis_params -redis_params = rspamd_parse_redis_server('fann_scores') - -local fann_prefix = 'RFANN' +local fann_file local max_trains = 1000 local max_epoch = 100 local use_settings = false -local watch_interval = 60.0 -local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) - if not ev_base or not redis_params or not callback or not command then - return false,nil,nil - end +local function symbols_to_fann_vector(syms, scores) + local learn_data = {} + local matched_symbols = {} + local n = rspamd_config:get_symbols_count() - local addr - local rspamd_redis = require "rspamd_redis" + each(function(s, score) + matched_symbols[s + 1] = rspamd_util.tanh(score) + end, zip(syms, scores)) - if key then - if is_write then - addr = redis_params['write_servers']:get_upstream_by_hash(key) - else - addr = redis_params['read_servers']:get_upstream_by_hash(key) - end - else - if is_write then - addr = redis_params['write_servers']:get_upstream_master_slave(key) + for i=1,n do + if matched_symbols[i] then + learn_data[i] = matched_symbols[i] else - addr = redis_params['read_servers']:get_upstream_round_robin(key) + learn_data[i] = 0 end end - if not addr then - logger.errx(task, 'cannot select server to make redis request') - end - - local options = { - ev_base = ev_base, - config = cfg, - callback = callback, - host = addr:get_addr(), - timeout = redis_params['timeout'], - cmd = command, - args = args - } - - if redis_params['password'] then - options['password'] = redis_params['password'] - end - - if redis_params['db'] then - options['dbname'] = redis_params['db'] - end - - local ret,conn = rspamd_redis.make_request(options) - if not ret then - rspamd_logger.errx('cannot execute redis request') - end - return ret,conn,addr + return learn_data end --- Metafunctions -local function fann_size_function(task) - local sizes = { - 100, - 200, - 500, - 1000, - 2000, - 4000, - 10000, - 20000, - 30000, - 100000, - 200000, - 400000, - 800000, - 1000000, - 2000000, - 8000000, - } - - local size = task:get_size() - for i = 1,#sizes do - if sizes[i] >= size then - return {i / #sizes} - end +local function gen_fann_file(id) + if use_settings then + return fann_file .. id + else + return fann_file end - - return {0} end -local function fann_images_function(task) - local images = task:get_images() - local ntotal = 0 - local njpg = 0 - local npng = 0 - local nlarge = 0 - local nsmall = 0 - - if images then - for _,img in ipairs(images) do - if img:get_type() == 'png' then - npng = npng + 1 - elseif img:get_type() == 'jpeg' then - njpg = njpg + 1 - end - - local w = img:get_width() - local h = img:get_height() - - if w > 0 and h > 0 then - if w + h > 256 then - nlarge = nlarge + 1 - else - nsmall = nsmall + 1 - end - end +local function load_fann(id) + local fname = gen_fann_file(id) + local err,st = rspamd_util.stat(fname) - ntotal = ntotal + 1 - end + if err then + return false end - if ntotal > 0 then - njpg = njpg / ntotal - npng = npng / ntotal - nlarge = nlarge / ntotal - nsmall = nsmall / ntotal - end - return {ntotal,njpg,npng,nlarge,nsmall} -end -local function fann_nparts_function(task) - local nattachments = 0 - local ntextparts = 0 - local totalparts = 1 + local fd = rspamd_util.lock_file(fname) + data[id].fann = rspamd_fann.load(fname) + rspamd_util.unlock_file(fd) -- closes fd - local tp = task:get_text_parts() - if tp then - ntextparts = #tp - end + if data[id].fann then + local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() - local parts = task:get_parts() + if n ~= data[id].fann:get_inputs() then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache; removing', data[id].fann:get_inputs(), n) + data[id].fann = nil - if parts then - for _,p in ipairs(parts) do - if p:get_filename() then - nattachments = nattachments + 1 + local ret,err = rspamd_util.unlink(fname) + if not ret then + rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', + fname, err) end - totalparts = totalparts + 1 - end - end - - return {ntextparts/totalparts, nattachments/totalparts} -end - -local function fann_encoding_function(task) - local nutf = 0 - local nother = 0 - - local tp = task:get_text_parts() - if tp then - for _,p in ipairs(tp) do - if p:is_utf() then - nutf = nutf + 1 + else + local layers = data[id].fann:get_layers() + + if not layers or #layers ~= 5 then + rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing', + #layers) + data[id].fann = nil + local ret,err = rspamd_util.unlink(fname) + if not ret then + rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', + fname, err) + end else - nother = nother + 1 + rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname) + return true end end - end - - return {nutf, nother} -end - -local function fann_recipients_function(task) - local nmime = 0 - local nsmtp = 0 - - if task:has_recipients('mime') then - nmime = #(task:get_recipients('mime')) - end - if task:has_recipients('smtp') then - nsmtp = #(task:get_recipients('smtp')) - end - - if nmime > 0 then nmime = 1.0 / nmime end - if nsmtp > 0 then nsmtp = 1.0 / nsmtp end - - return {nmime,nsmtp} -end - -local function fann_received_function(task) - local ret = 0 - local rh = task:get_received_headers() - - if rh and #rh > 0 then - ret = 1 / #rh - end - - return {ret} -end - -local function fann_urls_function(task) - if task:has_urls() then - return {1.0 / #(task:get_urls())} - end - - return {0} -end - -local function fann_attachments_function(task) -end - -local metafunctions = { - { - cb = fann_size_function, - ninputs = 1, - }, - { - cb = fann_images_function, - ninputs = 5, - -- 1 - number of images, - -- 2 - number of png images, - -- 3 - number of jpeg images - -- 4 - number of large images (> 128 x 128) - -- 5 - number of small images (< 128 x 128) - }, - { - cb = fann_nparts_function, - ninputs = 2, - -- 1 - number of text parts - -- 2 - number of attachments - }, - { - cb = fann_encoding_function, - ninputs = 2, - -- 1 - number of utf parts - -- 2 - number of non-utf parts - }, - { - cb = fann_recipients_function, - ninputs = 2, - -- 1 - number of mime rcpt - -- 2 - number of smtp rcpt - }, - { - cb = fann_received_function, - ninputs = 1, - }, - { - cb = fann_urls_function, - ninputs = 1, - }, -} - -local function gen_metatokens(task) - local metatokens = {} - for _,mt in ipairs(metafunctions) do - local ct = mt.cb(task) - - for _,tok in ipairs(ct) do - table.insert(metatokens, tok) - end - end - - return metatokens -end - -local function count_metatokens() - local total = 0 - for _,mt in ipairs(metafunctions) do - total = total + mt.ninputs - end - - return total -end - -local function symbols_to_fann_vector(syms, scores) - local learn_data = {} - local matched_symbols = {} - local n = rspamd_config:get_symbols_count() - - each(function(s, score) - matched_symbols[s + 1] = rspamd_util.tanh(score) - end, zip(syms, scores)) - - for i=1,n do - if matched_symbols[i] then - learn_data[i] = matched_symbols[i] - else - learn_data[i] = 0 + else + rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname) + local ret,err = rspamd_util.unlink(fname) + if not ret then + rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', + fname, err) end end - return learn_data -end - -local function gen_fann_prefix(id) - if use_settings then - return fann_prefix .. id - else - return fann_prefix - end + return false end -local function is_fann_valid(ann) - if ann then - local n = rspamd_config:get_symbols_count() + count_metatokens() +local function check_fann(id) + if data[id].fann then + local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() - if n ~= ann:get_inputs() then + if n ~= data[id].fann:get_inputs() then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', ann:get_inputs(), n) - return false + ' is found in the cache', data[id].fann:get_inputs(), n) + data[id].fann = nil end - local layers = ann:get_layers() + local layers = data[id].fann:get_layers() if not layers or #layers ~= 5 then rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s', #layers) - return false + data[id].fann = nil end + end + + local fname = gen_fann_file(id) + local err,st = rspamd_util.stat(fname) + + if not err then + local mtime = st['mtime'] - return true + if mtime > data[id].fann_mtime then + rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' .. + 'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname) + data[id].fann_mtime = mtime + data[id].fann = nil + end end end @@ -431,10 +165,12 @@ local function fann_scores_filter(task) end end + check_fann(id) + if data[id].fann then local symbols,scores = task:get_symbols_numeric() local fann_data = symbols_to_fann_vector(symbols, scores) - local mt = gen_metatokens(task) + local mt = rspamd_gen_metatokens(task) for _,tok in ipairs(mt) do table.insert(fann_data, tok) @@ -451,6 +187,10 @@ local function fann_scores_filter(task) local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) task:insert_result(fann_symbol_ham, result, symscore, id) end + else + if load_fann(id) then + fann_scores_filter(task) + end end end @@ -460,44 +200,69 @@ local function create_train_fann(n, id) data[id].epoch = 0 end -local function load_or_invalidate_fann(data, id, ev_base) - local err,ann_data = rspamd_util.zstd_decompress(data) - local ann +local function fann_train_callback(score, required_score, results, cf, id, opts, extra) + local n = cf:get_symbols_count() + rspamd_count_metatokens() + local fname = gen_fann_file(id) - if err or not ann_data then - rspamd_logger.errx('cannot decompress ann: %s', err) - else - ann = rspamd_fann.load_data(ann_data) + if not data[id].fann_train then + create_train_fann(n, id) + end + + if data[id].fann_train:get_inputs() ~= n then + rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' .. + ' is found in the cache', data[id].fann_train:get_inputs(), n) + create_train_fann(n, id) end - if is_fann_valid(ann) then - data[id].fann = ann + if data[id].ntrains > max_trains then + -- Store fann on disk + local res = false + + local err,st = rspamd_util.stat(fname) + if err then + local fd,err = rspamd_util.create_file(fname) + if not fd then + rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err) + else + rspamd_util.lock_file(fname, fd) + res = data[id].fann_train:save(fname) + rspamd_util.unlock_file(fd) -- Closes fd as well + end + else + local fd = rspamd_util.lock_file(fname) + res = data[id].fann_train:save(fname) + rspamd_util.unlock_file(fd) -- Closes fd as well + end + + if not res then + rspamd_logger.errx(cf, 'cannot save fann in %s', fname) + else + data[id].exist = true + data[id].ntrains = 0 + data[id].epoch = data[id].epoch + 1 + end else - local function redis_invalidate_cb(err, data) + if not data[id].checked then + data[id].checked = true + local err,st = rspamd_util.stat(fname) if err then - rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', id, err) - elseif type(data) == 'string' then - rspamd_logger.info(rspamd_config, 'invalidated ANN %s from redis: %s', id, err) + data[id].exist = false end end - -- Invalidate ANN - rspamd_logger.infox('invalidate ANN %s') - redis_make_request(ev_base, - rspamd_config, - nil, - true, -- is write - redis_invalidate_cb, --callback - 'EVALSHA', -- command - {redis_maybe_invalidate_sha, 1, fann_prefix .. id} - ) + if not data[id].exist then + rspamd_logger.infox(cf, 'not enough trains for fann %s, %s left', fname, + max_trains - data[id].ntrains) + end end -end -local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base) - local fname = gen_fann_prefix(id) + if data[id].epoch > max_epoch then + -- Re-create fann + rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname, + max_epoch) + create_train_fann(n, id) + end local learn_spam, learn_ham = false, false - if opts['spam_score'] then learn_spam = score >= opts['spam_score'] else @@ -510,157 +275,21 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, end if learn_spam or learn_ham then - local k - if learn_spam then k = 'spam' else k = 'ham' end - - local function learn_vec_cb(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot store train vector: %s', err) - end - end - - local function can_train_cb(err, data) - if not err and tonumber(data) > 0 then - local learn_data = symbols_to_fann_vector( - map(function(r) return r[1] end, results), - map(function(r) return r[2] end, results) - ) - -- Add filtered meta tokens - each(function(e) table.insert(learn_data, e) end, extra) - local str = rspamd_util.zstd_compress(table.concat(learn_data, ';')) - - redis_make_request(ev_base, - rspamd_config, - nil, - true, -- is write - learn_vec_cb, --callback - 'LPUSH', -- command - {fname .. '_' .. k, str} -- arguments - ) - else - if err then - rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err) - end - end - end - - redis_make_request(ev_base, - rspamd_config, - nil, - false, -- is write - can_train_cb, --callback - 'EVALSHA', -- command - {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments + local learn_data = symbols_to_fann_vector( + map(function(r) return r[1] end, results), + map(function(r) return r[2] end, results) ) - end -end - -local function train_fann(cfg, ev_base, elt) + -- Add filtered meta tokens + each(function(e) table.insert(learn_data, e) end, extra) -end - -local function maybe_train_fanns(cfg, ev_base) - local function members_cb(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) - elseif type(data) == 'table' then - each(function(i, elt) - local redis_len_cb = function(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err) - elseif data and type(data) == 'number' or type(data) == 'string' then - if tonumber(data) and tonumber(data) > max_trains then - train_fann(cfg, ev_base, elt) - end - end - end - - local local_ver = 0 - local numelt = tonumber(elt) - if data[numelt] then - if data[numelt].version then - local_ver = data[numelt].version - end - end - redis_make_request(ev_base, - rspamd_config, - nil, - false, -- is write - redis_len_cb, --callback - 'LLEN', -- command - {fann_prefix .. elt .. '_spam'} - ) - end, - data) - end - end - - if not redis_maybe_load_sha then - -- Plan new event early - return 1.0 - end - -- First we need to get all fanns stored in our Redis - redis_make_request(ev_base, - rspamd_config, - nil, - false, -- is write - members_cb, --callback - 'SMEMBERS', -- command - {fann_prefix} -- arguments - ) - - return watch_interval -end - -local function check_fanns(cfg, ev_base) - local function members_cb(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) - elseif type(data) == 'table' then - each(function(i, elt) - local redis_update_cb = function(err, data) - if err then - rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err) - elseif data and type(data) == 'string' then - load_or_invalidate_fann(data, elt, ev_base) - end - end - - local local_ver = 0 - local numelt = tonumber(elt) - if data[numelt] then - if data[numelt].version then - local_ver = data[numelt].version - end - end - redis_make_request(ev_base, - rspamd_config, - nil, - false, -- is write - redis_update_cb, --callback - 'EVALSHA', -- command - {redis_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)} - ) - end, - data) + if learn_spam then + data[id].fann_train:train(learn_data, {1.0}) + else + data[id].fann_train:train(learn_data, {-1.0}) end - end - if not redis_maybe_load_sha then - -- Plan new event early - return 1.0 + data[id].ntrains = data[id].ntrains + 1 end - -- First we need to get all fanns stored in our Redis - redis_make_request(ev_base, - rspamd_config, - nil, - false, -- is write - members_cb, --callback - 'SMEMBERS', -- command - {fann_prefix} -- arguments - ) - - return watch_interval end -- Initialization part @@ -674,128 +303,71 @@ end if not rspamd_fann.is_enabled() then rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. 'module is eventually disabled') + return else - use_settings = opts['use_settings'] - rspamd_config:set_metric_symbol({ - name = fann_symbol_spam, - score = 3.0, - description = 'Neural network SPAM', - group = 'fann' - }) - local id = rspamd_config:register_symbol({ - name = fann_symbol_spam, - type = 'postfilter', - priority = 5, - callback = fann_scores_filter - }) - rspamd_config:set_metric_symbol({ - name = fann_symbol_ham, - score = -2.0, - description = 'Neural network HAM', - group = 'fann' - }) - rspamd_config:register_symbol({ - name = fann_symbol_ham, - type = 'virtual', - parent = id - }) - if opts['train'] then - rspamd_config:add_on_load(function(cfg) - if opts['train']['max_train'] then - max_trains = opts['train']['max_train'] - end - if opts['train']['max_epoch'] then - max_epoch = opts['train']['max_epoch'] - end - local ret = cfg:register_worker_script("log_helper", - function(score, req_score, results, cf, id, extra, ev_base) - -- map (snd x) (filter (fst x == module_id) extra) - local extra_fann = map(function(e) return e[2] end, - filter(function(e) return e[1] == module_log_id end, extra)) - if use_settings then - fann_train_callback(score, req_score, results, cf, - tostring(id), opts['train'], extra_fann, ev_base) - else - fann_train_callback(score, req_score, results, cf, '0', - opts['train'], extra_fann, ev_base) - end + if not opts['fann_file'] then + rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' .. + '`fann_file` to be specified') + else + fann_file = opts['fann_file'] + use_settings = opts['use_settings'] + rspamd_config:set_metric_symbol({ + name = fann_symbol_spam, + score = 3.0, + description = 'Neural network SPAM', + group = 'fann' + }) + local id = rspamd_config:register_symbol({ + name = fann_symbol_spam, + type = 'postfilter', + priority = 5, + callback = fann_scores_filter + }) + rspamd_config:set_metric_symbol({ + name = fann_symbol_ham, + score = -2.0, + description = 'Neural network HAM', + group = 'fann' + }) + rspamd_config:register_symbol({ + name = fann_symbol_ham, + type = 'virtual', + parent = id + }) + if opts['train'] then + rspamd_config:add_on_load(function(cfg) + if opts['train']['max_train'] then + max_trains = opts['train']['max_train'] + end + if opts['train']['max_epoch'] then + max_epoch = opts['train']['max_epoch'] + end + local ret = cfg:register_worker_script("log_helper", + function(score, req_score, results, cf, id, extra) + -- map (snd x) (filter (fst x == module_id) extra) + local extra_fann = map(function(e) return e[2] end, + filter(function(e) return e[1] == module_log_id end, extra)) + if use_settings then + fann_train_callback(score, req_score, results, cf, + tostring(id), opts['train'], extra_fann) + else + fann_train_callback(score, req_score, results, cf, '0', + opts['train'], extra_fann) + end end) - if not ret then - rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') - end - end) - -- This is needed to pass extra tokens from worker to log_helper - rspamd_plugins["fann_score"] = { - log_callback = function(task) - return totable(map( - function(tok) return {module_log_id, tok} end, - gen_metatokens(task))) - end - } - end - -- Add training scripts - rspamd_config:add_on_load(function(cfg, ev_base, worker) - local function can_train_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err) - else - redis_can_train_sha = tostring(data) - end - end - redis_make_request(ev_base, - rspamd_config, - nil, - true, -- is write - can_train_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_can_train} -- arguments - ) - - local function maybe_load_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err) - else - redis_maybe_load_sha = tostring(data) - - rspamd_config:add_periodic(ev_base, 0.0, - function(cfg, ev_base) - return check_fanns(cfg, ev_base) - end) - end - end - redis_make_request(ev_base, - rspamd_config, - nil, - true, -- is write - maybe_load_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_maybe_load} -- arguments - ) - - local function maybe_invalidate_sha_cb(err, data) - if err or not data or type(data) ~= 'string' then - rspamd_logger.errx(cfg, 'cannot save redis invalidate script: %s', err) - else - redis_maybe_invalidate_sha = tostring(data) - end - end - redis_make_request(ev_base, - rspamd_config, - nil, - true, -- is write - maybe_invalidate_sha_cb, --callback - 'SCRIPT', -- command - {'LOAD', redis_lua_script_maybe_invalidate} -- arguments - ) - - if worker:get_name() == 'normal' then - -- We also want to train neural nets when they have enough data - rspamd_config:add_periodic(ev_base, 0.0, - function(cfg, ev_base) - return maybe_train_fanns(cfg, ev_base) - end) + if not ret then + rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') + end + end) + rspamd_plugins["fann_score"] = { + log_callback = function(task) + return totable(map( + function(tok) return {module_log_id, tok} end, + rspamd_gen_metatokens(task))) + end + } end - end) + end end