Browse Source

[Feature] Rework fann plugin to be a normal post-filter

tags/1.7.0
Vsevolod Stakhov 7 years ago
parent
commit
99caa38084
1 changed files with 44 additions and 77 deletions
  1. 44
    77
      src/plugins/lua/fann_redis.lua

+ 44
- 77
src/plugins/lua/fann_redis.lua View File

@@ -28,6 +28,7 @@ local fann_symbol_spam = 'FANNR_SPAM'
local fann_symbol_ham = 'FANNR_HAM'
local rspamd_redis = require "lua_redis"
local fun = require "fun"
local meta_functions = require "meta_functions"

local module_log_id = 0x200
-- Module vars
@@ -286,26 +287,6 @@ local function load_scripts(cfg, ev_base, on_load_cb)
)
end

local function symbols_to_fann_vector(syms, scores)
local learn_data = {}
local matched_symbols = {}
local n = rspamd_config:get_symbols_count()

fun.each(function(s, score)
matched_symbols[s + 1] = rspamd_util.tanh(score)
end, fun.zip(syms, scores))

for i=1,n do
if matched_symbols[i] then
learn_data[i] = matched_symbols[i]
else
learn_data[i] = 0
end
end

return learn_data
end

local function gen_fann_prefix(id)
if id then
return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id
@@ -345,13 +326,10 @@ local function fann_scores_filter(task)
end

if fanns[id].fann then
local symbols,scores = task:get_symbols_numeric()
local fann_data = symbols_to_fann_vector(symbols, scores)
local mt = rspamd_gen_metatokens(task)

for _,tok in ipairs(mt) do
table.insert(fann_data, tok)
end
local fann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
fun.each(function(e) table.insert(fann_data, e) end, mt)

local out = fanns[id].fann:test(fann_data)
local symscore = string.format('%.3f', out[1])
@@ -445,7 +423,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
end
end

local function fann_train_callback(score, required_score, results, _, id, opts, extra, ev_base)
local function fann_train_callback(task, score, required_score, id, opts)
local fname,suffix = gen_fann_prefix(id)

local learn_spam, learn_ham
@@ -473,13 +451,11 @@ local function fann_train_callback(score, required_score, results, _, id, opts,

local function can_train_cb(err, data)
if not err and tonumber(data) > 0 then
local learn_data = symbols_to_fann_vector(
fun.map(function(r) return r[1] end, results),
fun.map(function(r) return r[2] end, results)
)
local fann_data = task:get_symbols_tokens()
local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
fun.each(function(e) table.insert(learn_data, e) end, extra)
local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
fun.each(function(e) table.insert(fann_data, e) end, mt)
local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))

rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
@@ -500,14 +476,14 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
end
end

rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
rspamd_redis.rspamd_redis_make_request(task,
redis_params,
nil,
true, -- is write
can_train_cb, --callback
'EVALSHA', -- command
{redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments
{redis_can_train_sha, '4', gen_fann_prefix(nil),
suffix, k, tostring(max_trains)} -- arguments
)
end
end
@@ -857,6 +833,18 @@ local function check_fanns(_, ev_base)
return watch_interval
end

local function ann_push_vector(task)
local scores = task:get_metric_score()
local sid = task:get_settings_id()
if use_settings then
fann_train_callback(task, scores[1], scores[2],
tostring(sid), opts['train'])
else
fann_train_callback(task, scores[1], scores[2],
tostring(sid), opts['train'])
end
end

-- Initialization part

local opts = rspamd_config:get_all_opt("fann_redis")
@@ -892,7 +880,7 @@ else
local id = rspamd_config:register_symbol({
name = fann_symbol_spam,
type = 'postfilter',
priority = 5,
priority = 6,
callback = fann_scores_filter
})
rspamd_config:set_metric_symbol({
@@ -907,45 +895,24 @@ else
parent = id
})
if opts['train'] then
rspamd_config:add_on_load(function(cfg)
if opts['train']['max_train'] then
max_trains = opts['train']['max_train']
end
if opts['train']['max_epoch'] then
max_epoch = opts['train']['max_epoch']
end
if opts['train']['max_usages'] then
max_usages = opts['train']['max_usages']
end
if opts['train']['mse'] then
mse = opts['train']['mse']
end
local ret = cfg:register_worker_script("log_helper",
function(score, req_score, results, cf, _id, extra, ev_base)
-- fun.map (snd x) (fun.filter (fst x == module_id) extra)
local extra_fann = fun.map(function(e) return e[2] end,
fun.filter(function(e) return e[1] == module_log_id end, extra))
if use_settings then
fann_train_callback(score, req_score, results, cf,
tostring(_id), opts['train'], extra_fann, ev_base)
else
fann_train_callback(score, req_score, results, cf, '0',
opts['train'], extra_fann, ev_base)
end
end)

if not ret then
rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
end
end)
-- This is needed to pass extra tokens from worker to log_helper
rspamd_plugins["fann_redis"] = {
log_callback = function(task)
return fun.totable(fun.map(
function(tok) return {module_log_id, tok} end,
rspamd_gen_metatokens(task)))
end
}
if opts['train']['max_train'] then
max_trains = opts['train']['max_train']
end
if opts['train']['max_epoch'] then
max_epoch = opts['train']['max_epoch']
end
if opts['train']['max_usages'] then
max_usages = opts['train']['max_usages']
end
if opts['train']['mse'] then
mse = opts['train']['mse']
end
rspamd_config:register_symbol({
name = 'FANN_VECTOR_PUSH',
type = 'postfilter',
priority = 5,
callback = ann_push_vector
})
end
-- Add training scripts
rspamd_config:add_on_load(function(cfg, ev_base, worker)

Loading…
Cancel
Save