aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/fann_redis.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/fann_redis.lua')
-rw-r--r--src/plugins/lua/fann_redis.lua106
1 files changed, 75 insertions, 31 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index ac6f78772..09c20ebf9 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -43,12 +43,12 @@ local default_options = {
max_trains = 1000,
max_epoch = 1000,
max_usages = 10,
- use_settings = false,
- per_user = false,
- watch_interval = 60.0,
mse = 0.001,
autotrain = true,
},
+ use_settings = false,
+ per_user = false,
+ watch_interval = 60.0,
nlayers = 4,
lock_expire = 600,
learning_spawned = false,
@@ -58,8 +58,7 @@ local default_options = {
}
local settings = {
- rules = {
- }
+ rules = {}
}
-- ANNs indexed by settings id
@@ -96,9 +95,17 @@ local redis_lua_script_can_train = [[
if ret then nham = tonumber(ret) end
if KEYS[3] == 'spam' then
- if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end
+ if nham <= lim and nham + 1 >= nspam then
+ return tostring(nspam + 1)
+ else
+ return tostring(-(nham + 1))
+ end
else
- if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end
+ if nspam <= lim and nspam + 1 >= nham then
+ return tostring(nham + 1)
+ else
+ return tostring(-(nspam + 1))
+ end
end
return tostring(0)
@@ -312,8 +319,7 @@ local function gen_fann_prefix(rule, id)
tprefix = 't';
end
if id then
- return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id),
- rule.prefix .. id
+ return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id
else
return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
end
@@ -433,7 +439,7 @@ local function create_train_fann(rule, n, id)
if not is_fann_valid(rule, prefix, fanns[id].fann) then
fanns[id].fann_train = create_fann(n, rule.nlayers)
fanns[id].fann = nil
- elseif fanns[id].version % rule.max_usages == 0 then
+ elseif fanns[id].version % rule.train.max_usages == 0 then
-- Forget last fann
rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix,
fanns[id].version)
@@ -507,7 +513,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
local learn_spam, learn_ham
- if rule.autotrain then
+ if train_opts.autotrain then
if train_opts['spam_score'] then
learn_spam = score >= train_opts['spam_score']
else
@@ -531,13 +537,18 @@ local function fann_train_callback(rule, task, score, required_score, id)
end
end
+
if learn_spam or learn_ham then
local k
+ local vec_len = 0
if learn_spam then k = 'spam' else k = 'ham' end
local function learn_vec_cb(err)
if err then
rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err)
+ else
+ rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes",
+ rule['name'], k, vec_len)
end
end
@@ -548,6 +559,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
-- Add filtered meta tokens
fun.each(function(e) table.insert(fann_data, e) end, mt)
local str = rspamd_util.zstd_compress(table.concat(fann_data, ';'))
+ vec_len = #str
rspamd_redis.redis_make_request(task,
rule.redis,
@@ -559,10 +571,13 @@ local function fann_train_callback(rule, task, score, required_score, id)
)
else
if err then
- rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err)
+ rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err)
if string.match(err, 'NOSCRIPT') then
load_scripts(rspamd_config, task:get_ev_base(), nil)
end
+ elseif tonumber(data) < 0 then
+ rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s",
+ fname, k, -tonumber(data))
end
end
end
@@ -574,7 +589,7 @@ local function fann_train_callback(rule, task, score, required_score, id)
can_train_cb, --callback
'EVALSHA', -- command
{redis_can_train_sha, '4', gen_fann_prefix(rule, nil),
- suffix, k, tostring(rule.max_trains)} -- arguments
+ suffix, k, tostring(train_opts.max_trains)} -- arguments
)
end
end
@@ -722,7 +737,7 @@ local function train_fann(rule, _, ev_base, elt, worker)
create_train_fann(rule, n, elt)
end
- if #spam_elts + #ham_elts < rule.max_trains / 2 then
+ if #spam_elts + #ham_elts < rule.train.max_trains / 2 then
-- Invalidate ANN as it is definitely invalid
local function redis_invalidate_cb(_err, _data)
if _err then
@@ -912,10 +927,10 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker)
rspamd_logger.errx(rspamd_config,
'cannot get FANN trains %s from redis: %s', prefix, _err)
elseif _data and type(_data) == 'number' or type(_data) == 'string' then
- if tonumber(_data) and tonumber(_data) >= rule.max_trains then
+ if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then
rspamd_logger.infox(rspamd_config,
'need to learn ANN %s after %s learn vectors (%s required)',
- prefix, tonumber(_data), rule.max_trains)
+ prefix, tonumber(_data), rule.train.max_trains)
train_fann(rule, cfg, ev_base, elt, worker)
end
end
@@ -1011,8 +1026,7 @@ end
local function ann_push_vector(task)
local scores = task:get_metric_score()
-
- for _,rule in ipairs(settings.rules) do
+ for _,rule in pairs(settings.rules) do
local sid = "0"
if rule.use_settings then
sid = tostring(task:get_settings_id())
@@ -1053,32 +1067,64 @@ else
callback = fann_scores_filter
})
+ local function deepcopy(orig)
+ local orig_type = type(orig)
+ local copy
+ if orig_type == 'table' then
+ copy = {}
+ for orig_key, orig_value in next, orig, nil do
+ copy[deepcopy(orig_key)] = deepcopy(orig_value)
+ end
+ setmetatable(copy, deepcopy(getmetatable(orig)))
+ else -- number, string, boolean, etc
+ copy = orig
+ end
+ return copy
+ end
+ local function override_defaults(def, override)
+ for k,v in pairs(def) do
+ if override[k] then
+ if def[k] then
+ if type(override[k]) == 'table' then
+ override_defaults(def[k], override[k])
+ else
+ def[k] = override[k]
+ end
+ else
+ def[k] = override[k]
+ end
+ end
+ end
+ end
for k,r in pairs(rules) do
- rules[k] = default_options
- rules[k]['redis'] = redis_params
- local cur = rules[k]
+ local def_rules = deepcopy(default_options)
+ def_rules['redis'] = redis_params
-- Override defaults
- for sk,v in pairs(r) do
- cur[sk] = v
+ override_defaults(def_rules, r)
+
+ if not def_rules.prefix then
+ def_rules.prefix = k
end
- if not cur.prefix then
- cur.prefix = k
+ if not def_rules.name then
+ def_rules.name = k
end
+ rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
+ settings.rules[k] = def_rules
rspamd_config:set_metric_symbol({
- name = cur.symbol_spam,
+ name = def_rules.symbol_spam,
score = 3.0,
description = 'Neural network SPAM',
group = 'fann'
})
rspamd_config:set_metric_symbol({
- name = cur.symbol_ham,
+ name = def_rules.symbol_ham,
score = -2.0,
description = 'Neural network HAM',
group = 'fann'
})
rspamd_config:register_symbol({
- name = cur.symbol_ham,
+ name = def_rules.symbol_ham,
type = 'virtual,nostat',
parent = id
})
@@ -1086,13 +1132,11 @@ else
rspamd_config:register_symbol({
name = 'FANN_VECTOR_PUSH',
- type = 'postfilter,nostat',
+ type = 'idempotent,nostat',
priority = 5,
callback = ann_push_vector
})
- settings.rules = rules
-
-- Add training scripts
for _,rule in pairs(settings.rules) do
rspamd_config:add_on_load(function(cfg, ev_base, worker)