aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
author Vsevolod Stakhov <vsevolod@highsecure.ru> 2017-07-29 15:23:39 +0100
committer Vsevolod Stakhov <vsevolod@highsecure.ru> 2017-07-29 15:23:39 +0100
commit 99caa38084aa41938b98918b9119a288a0aefcf5 (patch)
tree e21c1a10a2e11737c53bf875501149d3d97c05e4 /src/plugins
parent c2f9222a2438fa9d604bf92e63010a4ae5041f63 (diff)
download rspamd-99caa38084aa41938b98918b9119a288a0aefcf5.tar.gz
rspamd-99caa38084aa41938b98918b9119a288a0aefcf5.zip
[Feature] Rework fann plugin to be a normal post-filter
Diffstat (limited to 'src/plugins')
-rw-r--r-- src/plugins/lua/fann_redis.lua 121
1 files changed, 44 insertions, 77 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 2c3de9ddc..dbb4955ef 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -28,6 +28,7 @@ local fann_symbol_spam = 'FANNR_SPAM'
local fann_symbol_ham = 'FANNR_HAM'
local rspamd_redis = require "lua_redis"
local fun = require "fun"
+local meta_functions = require "meta_functions"
local module_log_id = 0x200
-- Module vars
@@ -286,26 +287,6 @@ local function load_scripts(cfg, ev_base, on_load_cb)
)
end
-local function symbols_to_fann_vector(syms, scores)
- local learn_data = {}
- local matched_symbols = {}
- local n = rspamd_config:get_symbols_count()
-
- fun.each(function(s, score)
- matched_symbols[s + 1] = rspamd_util.tanh(score)
- end, fun.zip(syms, scores))
-
- for i=1,n do
- if matched_symbols[i] then
- learn_data[i] = matched_symbols[i]
- else
- learn_data[i] = 0
- end
- end
-
- return learn_data
-end
-
local function gen_fann_prefix(id)
if id then
return fann_prefix .. rspamd_config:get_symbols_cksum():hex() .. id,id
@@ -345,13 +326,10 @@ local function fann_scores_filter(task)
end
if fanns[id].fann then
- local symbols,scores = task:get_symbols_numeric()
- local fann_data = symbols_to_fann_vector(symbols, scores)
- local mt = rspamd_gen_metatokens(task)
-
- for _,tok in ipairs(mt) do
- table.insert(fann_data, tok)
- end
+ local fann_data = task:get_symbols_tokens()
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+ -- Add filtered meta tokens
+ fun.each(function(e) table.insert(fann_data, e) end, mt)
local out = fanns[id].fann:test(fann_data)
local symscore = string.format('%.3f', out[1])
@@ -445,7 +423,7 @@ local function load_or_invalidate_fann(data, id, ev_base)
end
end
-local function fann_train_callback(score, required_score, results, _, id, opts, extra, ev_base)
+local function fann_train_callback(task, score, required_score, id, opts)
local fname,suffix = gen_fann_prefix(id)
local learn_spam, learn_ham
@@ -473,13 +451,11 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
local function can_train_cb(err, data)
if not err and tonumber(data) > 0 then
- local learn_data = symbols_to_fann_vector(
- fun.map(function(r) return r[1] end, results),
- fun.map(function(r) return r[2] end, results)
- )
+ local fann_data = task:get_symbols_tokens()
+ local mt = meta_functions.rspamd_gen_metatokens(task)
-- Add filtered meta tokens
- fun.each(function(e) table.insert(learn_data, e) end, extra)
- local str = rspamd_util.zstd_compress(table.concat(learn_data, ';'))
+ fun.each(function(e) table.insert(fann_data, e) end, mt)
+ local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
rspamd_redis.redis_make_request_taskless(ev_base,
rspamd_config,
@@ -500,14 +476,14 @@ local function fann_train_callback(score, required_score, results, _, id, opts,
end
end
- rspamd_redis.redis_make_request_taskless(ev_base,
- rspamd_config,
+ rspamd_redis.rspamd_redis_make_request(task,
redis_params,
nil,
true, -- is write
can_train_cb, --callback
'EVALSHA', -- command
- {redis_can_train_sha, '4', gen_fann_prefix(nil), suffix, k, tostring(max_trains)} -- arguments
+ {redis_can_train_sha, '4', gen_fann_prefix(nil),
+ suffix, k, tostring(max_trains)} -- arguments
)
end
end
@@ -857,6 +833,18 @@ local function check_fanns(_, ev_base)
return watch_interval
end
+-- Postfilter callback: passes the task's metric score and required score to
+-- fann_train_callback. The settings id picks the network only when
+-- `use_settings` is set; otherwise the shared id '0' is used.
+local function ann_push_vector(task)
+  local scores = task:get_metric_score()
+  local id = use_settings and tostring(task:get_settings_id()) or '0'
+  -- The file-level `local opts` is declared below this function, so a bare
+  -- `opts` here would resolve to a global; read the module options directly.
+  local train_opts = rspamd_config:get_all_opt("fann_redis")['train']
+  fann_train_callback(task, scores[1], scores[2], id, train_opts)
+end
+
-- Initialization part
local opts = rspamd_config:get_all_opt("fann_redis")
@@ -892,7 +880,7 @@ else
local id = rspamd_config:register_symbol({
name = fann_symbol_spam,
type = 'postfilter',
- priority = 5,
+ priority = 6,
callback = fann_scores_filter
})
rspamd_config:set_metric_symbol({
@@ -907,45 +895,24 @@ else
parent = id
})
if opts['train'] then
- rspamd_config:add_on_load(function(cfg)
- if opts['train']['max_train'] then
- max_trains = opts['train']['max_train']
- end
- if opts['train']['max_epoch'] then
- max_epoch = opts['train']['max_epoch']
- end
- if opts['train']['max_usages'] then
- max_usages = opts['train']['max_usages']
- end
- if opts['train']['mse'] then
- mse = opts['train']['mse']
- end
- local ret = cfg:register_worker_script("log_helper",
- function(score, req_score, results, cf, _id, extra, ev_base)
- -- fun.map (snd x) (fun.filter (fst x == module_id) extra)
- local extra_fann = fun.map(function(e) return e[2] end,
- fun.filter(function(e) return e[1] == module_log_id end, extra))
- if use_settings then
- fann_train_callback(score, req_score, results, cf,
- tostring(_id), opts['train'], extra_fann, ev_base)
- else
- fann_train_callback(score, req_score, results, cf, '0',
- opts['train'], extra_fann, ev_base)
- end
- end)
-
- if not ret then
- rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
- end
- end)
- -- This is needed to pass extra tokens from worker to log_helper
- rspamd_plugins["fann_redis"] = {
- log_callback = function(task)
- return fun.totable(fun.map(
- function(tok) return {module_log_id, tok} end,
- rspamd_gen_metatokens(task)))
- end
- }
+ if opts['train']['max_train'] then
+ max_trains = opts['train']['max_train']
+ end
+ if opts['train']['max_epoch'] then
+ max_epoch = opts['train']['max_epoch']
+ end
+ if opts['train']['max_usages'] then
+ max_usages = opts['train']['max_usages']
+ end
+ if opts['train']['mse'] then
+ mse = opts['train']['mse']
+ end
+ rspamd_config:register_symbol({
+ name = 'FANN_VECTOR_PUSH',
+ type = 'postfilter',
+ priority = 5,
+ callback = ann_push_vector
+ })
end
-- Add training scripts
rspamd_config:add_on_load(function(cfg, ev_base, worker)