diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/fann_scores.lua | 768 |
1 files changed, 299 insertions, 469 deletions
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 0d9e00435..64566e102 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -1,5 +1,5 @@ --[[ -Copyright (c) 2015, Vsevolod Stakhov <vsevolod@highsecure.ru> +Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru> Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,11 +36,111 @@ local data = { } } -local fann_file + +-- Lua script to train a row +-- Uses the following keys: +-- key1 - prefix for keys +-- key2 - max count of learns +-- key3 - spam or ham +-- returns 1 or 0: 1 - allow learn, 0 - not allow learn +local redis_lua_script_can_train = [[ + local locked = redis.call('GET', KEYS[1] .. '_locked') + if locked then return 0 end + local nspam = 0 + local nham = 0 + + local ret = redis.call('LLEN', KEYS[1] .. '_spam') + if ret then nspam = tonumber(ret) end + ret = redis.call('LLEN', KEYS[1] .. '_ham') + if ret then nham = tonumber(ret) end + + if KEYS[3] == 'spam' then + if nham + 1 >= nspam then return tostring(nspam) end + else + if nspam + 1 >= nham then return tostring(nham) end + end + + return tostring(0) +]] +local redis_can_train_sha = nil + +-- Lua script to load ANN from redis +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - local version +-- returns nil or bulk string if new ANN can be loaded +local redis_lua_script_maybe_load = [[ + local locked = redis.call('GET', KEYS[1] .. '_locked') + if locked then return false end + + local ver = 0 + local ret = redis.call('GET', KEYS[1] .. '_version') + if ret then ver = tonumber(ret) end + if ver > KEYS[2] then return redis.call('GET', KEYS[1] .. '_ann') end + + return false +]] +local redis_fann_maybe_load_sha = nil + +local redis_params +redis_params = rspamd_parse_redis_server('fann_scores') + +local fann_prefix = 'RF' local max_trains = 1000 local max_epoch = 100 local use_settings = false +local watch_interval = 60.0 + +local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args) + if not ev_base or not redis_params or not callback or not command then + return false,nil,nil + end + local addr + local rspamd_redis = require "rspamd_redis" + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(task, 'cannot select server to make redis request') + end + + local options = { + ev_base = ev_base, + config = cfg, + callback = callback, + host = addr:get_addr(), + timeout = redis_params['timeout'], + cmd = command, + args = args + } + + if redis_params['password'] then + options['password'] = redis_params['password'] + end + + if redis_params['db'] then + options['dbname'] = redis_params['db'] + end + + local ret,conn = rspamd_redis.make_request(options) + if not ret then + rspamd_logger.errx('cannot execute redis request') + end + return ret,conn,addr +end -- Metafunctions local function fann_size_function(task) @@ -277,69 +377,15 @@ local function symbols_to_fann_vector(syms, scores) return learn_data end -local function gen_fann_file(id) +local function gen_fann_prefix(id) if use_settings then - return fann_file .. id - else - return fann_file - end -end - -local function load_fann(id) - local fname = gen_fann_file(id) - local err,st = rspamd_util.stat(fname) - - if err then - return false - end - - local fd = rspamd_util.lock_file(fname) - data[id].fann = rspamd_fann.load(fname) - rspamd_util.unlock_file(fd) -- closes fd - - if data[id].fann then - local n = rspamd_config:get_symbols_count() + count_metatokens() - - if n ~= data[id].fann:get_inputs() then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache; removing', data[id].fann:get_inputs(), n) - data[id].fann = nil - - local ret,err = rspamd_util.unlink(fname) - if not ret then - rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, err) - end - else - local layers = data[id].fann:get_layers() - - if not layers or #layers ~= 5 then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing', - #layers) - data[id].fann = nil - local ret,err = rspamd_util.unlink(fname) - if not ret then - rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, err) - end - else - rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname) - return true - end - end + return fann_prefix .. id else - rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname) - local ret,err = rspamd_util.unlink(fname) - if not ret then - rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, err) - end + return fann_prefix end - - return false end -local function check_fann(id) +local function is_fann_valid(id) if data[id].fann then local n = rspamd_config:get_symbols_count() + count_metatokens() @@ -357,19 +403,11 @@ local function check_fann(id) end end - local fname = gen_fann_file(id) - local err,st = rspamd_util.stat(fname) - - if not err then - local mtime = st['mtime'] - - if mtime > data[id].fann_mtime then - rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' .. - 'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname) - data[id].fann_mtime = mtime - data[id].fann = nil - end + if data[id].fann then + return true end + + return false end local function fann_scores_filter(task) @@ -381,8 +419,6 @@ local function fann_scores_filter(task) end end - check_fann(id) - if data[id].fann then local symbols,scores = task:get_symbols_numeric() local fann_data = symbols_to_fann_vector(symbols, scores) @@ -403,10 +439,6 @@ local function fann_scores_filter(task) local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) task:insert_result(fann_symbol_ham, result, symscore, id) end - else - if load_fann(id) then - fann_scores_filter(task) - end end end @@ -416,69 +448,11 @@ local function create_train_fann(n, id) data[id].epoch = 0 end -local function fann_train_callback(score, required_score, results, cf, id, opts, extra) - local n = cf:get_symbols_count() + count_metatokens() - local fname = gen_fann_file(id) - - if not data[id].fann_train then - create_train_fann(n, id) - end - - if data[id].fann_train:get_inputs() ~= n then - rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', data[id].fann_train:get_inputs(), n) - create_train_fann(n, id) - end - - if data[id].ntrains > max_trains then - -- Store fann on disk - local res = false - - local err,st = rspamd_util.stat(fname) - if err then - local fd,err = rspamd_util.create_file(fname) - if not fd then - rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err) - else - rspamd_util.lock_file(fname, fd) - res = data[id].fann_train:save(fname) - rspamd_util.unlock_file(fd) -- Closes fd as well - end - else - local fd = rspamd_util.lock_file(fname) - res = data[id].fann_train:save(fname) - rspamd_util.unlock_file(fd) -- Closes fd as well - end - - if not res then - rspamd_logger.errx(cf, 'cannot save fann in %s', fname) - else - data[id].exist = true - data[id].ntrains = 0 - data[id].epoch = data[id].epoch + 1 - end - else - if not data[id].checked then - data[id].checked = true - local err,st = rspamd_util.stat(fname) - if err then - data[id].exist = false - end - end - if not data[id].exist then - rspamd_logger.infox(cf, 'not enough trains for fann %s, %s left', fname, - max_trains - data[id].ntrains) - end - end - - if data[id].epoch > max_epoch then - -- Re-create fann - rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname, - max_epoch) - create_train_fann(n, id) - end +local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base) + local fname = gen_fann_prefix(id) local learn_spam, learn_ham = false, false + if opts['spam_score'] then learn_spam = score >= opts['spam_score'] else @@ -491,21 +465,110 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, end if learn_spam or learn_ham then - local learn_data = symbols_to_fann_vector( - map(function(r) return r[1] end, results), - map(function(r) return r[2] end, results) + local k + if learn_spam then k = 'spam' else k = 'ham' end + + local function can_train_cb(err, data) + rspamd_logger.errx('hui') + if not err and tonumber(data) > 0 then + local learn_data = symbols_to_fann_vector( + map(function(r) return r[1] end, results), + map(function(r) return r[2] end, results) + ) + -- Add filtered meta tokens + each(function(e) table.insert(learn_data, e) end, extra) + local str = table.concat(learn_data, ';') + + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + learn_cb, --callback + 'LPUSH', -- command + {fname .. '_' .. k, str} -- arguments + ) + else + if err then + rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err) + end + end + end + + rspamd_logger.errx('pizda: %s %s %s %s', redis_can_train_sha, fname, tostring(max_trains), k) + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + can_train_cb, --callback + 'EVALSHA', -- command + {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments ) - -- Add filtered meta tokens - each(function(e) table.insert(learn_data, e) end, extra) + end +end - if learn_spam then - data[id].fann_train:train(learn_data, {1.0}) - else - data[id].fann_train:train(learn_data, {-1.0}) +local function check_fanns(cfg, ev_base) + local function members_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) + elseif type(data) == 'table' then + each(function(i, elt) + local redis_load_cb = function(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err) + elseif type(data) == 'string' then + --load_fann(data, elt) + end + end + local redis_update_cb = function(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err) + elseif data then + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_load_cb, --callback + 'GET', -- command + {fann_prefix, fann_prefix .. elt .. '_data'} -- arguments + ) + end + end + + local local_ver = 0 + local numelt = tonumber(elt) + if data[numelt] then + if data[numelt].version then + local_ver = data[numelt].version + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_update_cb, --callback + 'EVALSHA', -- command + {redis_fann_maybe_load_sha, 2, fann_prefix .. elt, tostring(local_ver)} + ) + end, + data) end + end - data[id].ntrains = data[id].ntrains + 1 + if not redis_fann_maybe_load_sha then + -- Plan new event early + return 1.0 end + -- First we need to get all fanns stored in our Redis + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + members_cb, --callback + 'SMEMBERS', -- command + {fann_prefix} -- arguments + ) + + return watch_interval end -- Initialization part @@ -519,337 +582,104 @@ end if not rspamd_fann.is_enabled() then rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. 'module is eventually disabled') - return else - if not opts['fann_file'] then - rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' .. - '`fann_file` to be specified') - else - fann_file = opts['fann_file'] - use_settings = opts['use_settings'] - rspamd_config:set_metric_symbol({ - name = fann_symbol_spam, - score = 3.0, - description = 'Neural network SPAM', - group = 'fann' - }) - local id = rspamd_config:register_symbol({ - name = fann_symbol_spam, - type = 'postfilter', - priority = 5, - callback = fann_scores_filter - }) - rspamd_config:set_metric_symbol({ - name = fann_symbol_ham, - score = -2.0, - description = 'Neural network HAM', - group = 'fann' - }) - rspamd_config:register_symbol({ - name = fann_symbol_ham, - type = 'virtual', - parent = id - }) - if opts['train'] then - rspamd_config:add_on_load(function(cfg) - if opts['train']['max_train'] then - max_trains = opts['train']['max_train'] - end - if opts['train']['max_epoch'] then - max_epoch = opts['train']['max_epoch'] - end - local ret = cfg:register_worker_script("log_helper", - function(score, req_score, results, cf, id, extra) - -- map (snd x) (filter (fst x == module_id) extra) - local extra_fann = map(function(e) return e[2] end, - filter(function(e) return e[1] == module_log_id end, extra)) - if use_settings then - fann_train_callback(score, req_score, results, cf, - tostring(id), opts['train'], extra_fann) - else - fann_train_callback(score, req_score, results, cf, '0', - opts['train'], extra_fann) - end + use_settings = opts['use_settings'] + rspamd_config:set_metric_symbol({ + name = fann_symbol_spam, + score = 3.0, + description = 'Neural network SPAM', + group = 'fann' + }) + local id = rspamd_config:register_symbol({ + name = fann_symbol_spam, + type = 'postfilter', + priority = 5, + callback = fann_scores_filter + }) + rspamd_config:set_metric_symbol({ + name = fann_symbol_ham, + score = -2.0, + description = 'Neural network HAM', + group = 'fann' + }) + rspamd_config:register_symbol({ + name = fann_symbol_ham, + type = 'virtual', + parent = id + }) + if opts['train'] then + rspamd_config:add_on_load(function(cfg) + if opts['train']['max_train'] then + max_trains = opts['train']['max_train'] + end + if opts['train']['max_epoch'] then + max_epoch = opts['train']['max_epoch'] + end + local ret = cfg:register_worker_script("log_helper", + function(score, req_score, results, cf, id, extra, ev_base) + -- map (snd x) (filter (fst x == module_id) extra) + local extra_fann = map(function(e) return e[2] end, + filter(function(e) return e[1] == module_log_id end, extra)) + if use_settings then + fann_train_callback(score, req_score, results, cf, + tostring(id), opts['train'], extra_fann, ev_base) + else + fann_train_callback(score, req_score, results, cf, '0', + opts['train'], extra_fann, ev_base) + end end) - if not ret then - rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') - end - end) - rspamd_plugins["fann_score"] = { - log_callback = function(task) - return totable(map( - function(tok) return {module_log_id, tok} end, - gen_metatokens(task))) - end - } - end + if not ret then + rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') + end + end) + -- This is needed to pass extra tokens from worker to log_helper + rspamd_plugins["fann_score"] = { + log_callback = function(task) + return totable(map( + function(tok) return {module_log_id, tok} end, + gen_metatokens(task))) + end + } end -end - -local redis_params -local classifier_config = { - key = 'neural_net', - neurons = 200, - layers = 3, -} - -local current_classify_ann = { - loaded = false, - version = 0, - spam_learned = 0, - ham_learned = 0 -} - -redis_params = rspamd_parse_redis_server('fann_scores') - -local function maybe_load_fann(task, continue_cb, call_if_fail) - local function load_fann() - local function redis_fann_load_cb(err, data) - if not err and type(data) == 'table' and type(data[2]) == 'string' then - local version = tonumber(data[1]) - local err,ann_data = rspamd_util.zstd_decompress(data[2]) - local ann - - if err or not ann_data then - rspamd_logger.errx(task, 'cannot decompress ann: %s', err) - else - ann = rspamd_fann.load_data(ann_data) - end - - if ann then - current_classify_ann.loaded = true - current_classify_ann.version = version - current_classify_ann.ann = ann - if type(data[3]) == 'string' then - current_classify_ann.spam_learned = tonumber(data[3]) - else - current_classify_ann.spam_learned = 0 - end - if type(data[4]) == 'string' then - current_classify_ann.ham_learned = tonumber(data[4]) - else - current_classify_ann.ham_learned = 0 - end - rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE", - version, current_classify_ann.spam_learned, - current_classify_ann.ham_learned, - ann:get_mse()) - continue_cb(task, true) - elseif call_if_fail then - continue_cb(task, false) - end - elseif call_if_fail then - continue_cb(task, false) + -- Add training scripts + rspamd_config:add_on_load(function(cfg, ev_base) + local function can_train_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err) + else + redis_can_train_sha = tostring(data) end end - - local key = classifier_config.key - local ret,_,_ = rspamd_redis_make_request(task, - redis_params, -- connect params - key, -- hash key - false, -- is write - redis_fann_load_cb, --callback - 'HMGET', -- command - {key, 'version', 'data', 'spam', 'ham'} -- arguments + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + can_train_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_can_train} -- arguments ) - end - local function check_fann() - local function redis_fann_check_cb(err, data) - if not err and type(data) == 'string' then - local version = tonumber(data) + local function maybe_load_sha_cb(err, data) + if err or not data or type(data) ~= 'string' then + rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err) + else + redis_fann_maybe_load_sha = tostring(data) - if version <= current_classify_ann.version then - continue_cb(task, true) - else - load_fann() - end + rspamd_config:add_periodic(ev_base, 0.0, + function(cfg, ev_base) + return check_fanns(cfg, ev_base) + end) end end - - local key = classifier_config.key - local ret,_,_ = rspamd_redis_make_request(task, - redis_params, -- connect params - key, -- hash key - false, -- is write - redis_fann_check_cb, --callback - 'HGET', -- command - {key, 'version'} -- arguments + redis_make_request(ev_base, + rspamd_config, + nil, + true, -- is write + maybe_load_sha_cb, --callback + 'SCRIPT', -- command + {'LOAD', redis_lua_script_maybe_load} -- arguments ) - end - - if not current_classify_ann.loaded then - load_fann() - else - check_fann() - end -end - -local function tokens_to_vector(tokens) - local vec = totable(map(function(tok) return tok[1] end, tokens)) - local ret = {} - local ntok = #vec - local neurons = classifier_config.neurons - for i = 1,neurons do - ret[i] = 0 - end - each(function(e) - local n = (e % neurons) + 1 - ret[n] = ret[n] + 1 - end, vec) - local norm = 0 - for i = 1,neurons do - if ret[i] > norm then - norm = ret[i] - end - end - for i = 1,neurons do - if ret[i] ~= 0 and norm > 0 then - ret[i] = ret[i] / norm - end - end - - return ret -end - -local function add_metatokens(task, vec) - local mt = gen_metatokens(task) - for _,tok in ipairs(mt) do - table.insert(vec, tok) - end -end - -local function create_fann() - local layers = {} - local mt_size = count_metatokens() - local neurons = classifier_config.neurons + mt_size - - for i = 1,classifier_config.layers - 1 do - layers[i] = math.floor(neurons / i) - end - - table.insert(layers, 1) - - local ann = rspamd_fann.create(classifier_config.layers, layers) - current_classify_ann.loaded = true - current_classify_ann.version = 0 - current_classify_ann.ann = ann - current_classify_ann.spam_learned = 0 - current_classify_ann.ham_learned = 0 -end - -local function save_fann(task, is_spam) - local function redis_fann_save_cb(err, data) - if err then - rspamd_logger.errx(task, "cannot save neural net to redis: %s", err) - end - end - - local data = current_classify_ann.ann:data() - local key = classifier_config.key - current_classify_ann.version = current_classify_ann.version + 1 - - if is_spam then - current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1 - else - current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1 - end - local ret,conn,_ = rspamd_redis_make_request(task, - redis_params, -- connect params - key, -- hash key - true, -- is write - redis_fann_save_cb, --callback - 'HMSET', -- command - { - key, - 'data', rspamd_util.zstd_compress(data), - }) -- arguments - - if conn then - conn:add_cmd('HINCRBY', {key, 'version', 1}) - if is_spam then - conn:add_cmd('HINCRBY', {key, 'spam', 1}) - rspamd_logger.errx(task, 'hui') - else - conn:add_cmd('HINCRBY', {key, 'ham', 1}) - rspamd_logger.errx(task, 'pezda') - end - end -end - -if redis_params then - rspamd_classifiers['neural'] = { - classify = function(task, classifier, tokens) - local function classify_cb(task) - local min_learns = classifier:get_param('min_learns') - - if min_learns then - min_learns = tonumber(min_learns) - end - - if min_learns and min_learns > 0 then - if current_classify_ann.ham_learned < min_learns or - current_classify_ann.spam_learned < min_learns then - - rspamd_logger.infox(task, 'fann classifier has not enough learns: (%s spam, %s ham), %s required', - current_classify_ann.spam_learned, current_classify_ann.ham_learned, - min_learns) - return - end - end - - -- Perform classification - local vec = tokens_to_vector(tokens) - add_metatokens(task, vec) - local out = current_classify_ann.ann:test(vec) - local result = rspamd_util.tanh(2 * (out[1])) - local symscore = string.format('%.3f', out[1]) - rspamd_logger.infox(task, 'fann classifier score: %s', symscore) - - if result > 0 then - each(function(st) - task:insert_result(st:get_symbol(), result, symscore) - end, - filter(function(st) - return st:is_spam() - end, classifier:get_statfiles()) - ) - else - each(function(st) - task:insert_result(st:get_symbol(), -result, symscore) - end, - filter(function(st) - return not st:is_spam() - end, classifier:get_statfiles()) - ) - end - end - maybe_load_fann(task, classify_cb, false) - end, - - learn = function(task, classifier, tokens, is_spam, is_unlearn) - local function learn_cb(task, is_loaded) - if not is_loaded then - create_fann() - end - local vec = tokens_to_vector(tokens) - add_metatokens(task, vec) - - if is_spam then - current_classify_ann.ann:train(vec, {1.0}) - rspamd_logger.infox(task, "learned ANN spam, MSE: %s", - current_classify_ann.ann:get_mse()) - else - current_classify_ann.ann:train(vec, {-1.0}) - rspamd_logger.infox(task, "learned ANN ham, MSE: %s", - current_classify_ann.ann:get_mse()) - end - - save_fann(task, is_spam) - end - maybe_load_fann(task, learn_cb, true) - end, - } + end) end |