From 99caa38084aa41938b98918b9119a288a0aefcf5 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 29 Jul 2017 15:23:39 +0100 Subject: [PATCH] [Feature] Rework fann plugin to be a normal post-filter --- src/plugins/lua/fann_redis.lua | 121 ++++++++++++--------------------- 1 file changed, 44 insertions(+), 77 deletions(-) diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index 2c3de9ddc..dbb4955ef 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -28,6 +28,7 @@ local fann_symbol_spam = 'FANNR_SPAM' local fann_symbol_ham = 'FANNR_HAM' local rspamd_redis = require "lua_redis" local fun = require "fun" +local meta_functions = require "meta_functions" local module_log_id = 0x200 -- Module vars @@ -286,26 +287,6 @@ local function load_scripts(cfg, ev_base, on_load_cb) ) end -local function symbols_to_fann_vector(syms, scores) - local learn_data = {} - local matched_symbols = {} - local n = rspamd_config:get_symbols_count() - - fun.each(function(s, score) - matched_symbols[s + 1] = rspamd_util.tanh(score) - end, fun.zip(syms, scores)) - - for i=1,n do - if matched_symbols[i] then - learn_data[i] = matched_symbols[i] - else - learn_data[i] = 0 - end - end - - return learn_data -end - local function gen_fann_prefix(id) if id then return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id @@ -345,13 +326,10 @@ local function fann_scores_filter(task) end if fanns[id].fann then - local symbols,scores = task:get_symbols_numeric() - local fann_data = symbols_to_fann_vector(symbols, scores) - local mt = rspamd_gen_metatokens(task) - - for _,tok in ipairs(mt) do - table.insert(fann_data, tok) - end + local fann_data = task:get_symbols_tokens() + local mt = meta_functions.rspamd_gen_metatokens(task) + -- Add filtered meta tokens + fun.each(function(e) table.insert(fann_data, e) end, mt) local out = fanns[id].fann:test(fann_data) local symscore = string.format('%.3f', out[1]) @@ -445,7 +423,7 @@ local function load_or_invalidate_fann(data, id, ev_base) end end -local function fann_train_callback(score, required_score, results, _, id, opts, extra, ev_base) +local function fann_train_callback(task, score, required_score, id, opts) local fname,suffix = gen_fann_prefix(id) local learn_spam, learn_ham @@ -473,13 +451,11 @@ local function fann_train_callback(score, required_score, results, _, id, opts, local function can_train_cb(err, data) if not err and tonumber(data) > 0 then - local learn_data = symbols_to_fann_vector( - fun.map(function(r) return r[1] end, results), - fun.map(function(r) return r[2] end, results) - ) + local fann_data = task:get_symbols_tokens() + local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens - fun.each(function(e) table.insert(learn_data, e) end, extra) - local str = rspamd_util.zstd_compress(table.concat(learn_data, ';')) + fun.each(function(e) table.insert(fann_data, e) end, mt) + local str = rspamd_util.zstd_compress(table.concat(fann_data, ';')) rspamd_redis.redis_make_request_taskless(ev_base, rspamd_config, @@ -500,14 +476,14 @@ local function fann_train_callback(score, required_score, results, _, id, opts, end end - rspamd_redis.redis_make_request_taskless(ev_base, - rspamd_config, + rspamd_redis.rspamd_redis_make_request(task, redis_params, nil, true, -- is write can_train_cb, --callback 'EVALSHA', -- command - {redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments + {redis_can_train_sha, '4', gen_fann_prefix(nil), + suffix, k, tostring(max_trains)} -- arguments ) end end @@ -857,6 +833,18 @@ local function check_fanns(_, ev_base) return watch_interval end +local function ann_push_vector(task) + local scores = task:get_metric_score() + local sid = task:get_settings_id() + if use_settings then + fann_train_callback(task, scores[1], scores[2], + tostring(sid), opts['train']) + else + fann_train_callback(task, scores[1], scores[2], + tostring(sid), opts['train']) + end +end + -- Initialization part local opts = rspamd_config:get_all_opt("fann_redis") @@ -892,7 +880,7 @@ else local id = rspamd_config:register_symbol({ name = fann_symbol_spam, type = 'postfilter', - priority = 5, + priority = 6, callback = fann_scores_filter }) rspamd_config:set_metric_symbol({ @@ -907,45 +895,24 @@ else parent = id }) if opts['train'] then - rspamd_config:add_on_load(function(cfg) - if opts['train']['max_train'] then - max_trains = opts['train']['max_train'] - end - if opts['train']['max_epoch'] then - max_epoch = opts['train']['max_epoch'] - end - if opts['train']['max_usages'] then - max_usages = opts['train']['max_usages'] - end - if opts['train']['mse'] then - mse = opts['train']['mse'] - end - local ret = cfg:register_worker_script("log_helper", - function(score, req_score, results, cf, _id, extra, ev_base) - -- fun.map (snd x) (fun.filter (fst x == module_id) extra) - local extra_fann = fun.map(function(e) return e[2] end, - fun.filter(function(e) return e[1] == module_log_id end, extra)) - if use_settings then - fann_train_callback(score, req_score, results, cf, - tostring(_id), opts['train'], extra_fann, ev_base) - else - fann_train_callback(score, req_score, results, cf, '0', - opts['train'], extra_fann, ev_base) - end - end) - - if not ret then - rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') - end - end) - -- This is needed to pass extra tokens from worker to log_helper - rspamd_plugins["fann_redis"] = { - log_callback = function(task) - return fun.totable(fun.map( - function(tok) return {module_log_id, tok} end, - rspamd_gen_metatokens(task))) - end - } + if opts['train']['max_train'] then + max_trains = opts['train']['max_train'] + end + if opts['train']['max_epoch'] then + max_epoch = opts['train']['max_epoch'] + end + if opts['train']['max_usages'] then + max_usages = opts['train']['max_usages'] + end + if opts['train']['mse'] then + mse = opts['train']['mse'] + end + rspamd_config:register_symbol({ + name = 'FANN_VECTOR_PUSH', + type = 'postfilter', + priority = 5, + callback = ann_push_vector + }) end -- Add training scripts rspamd_config:add_on_load(function(cfg, ev_base, worker) -- 2.39.5