diff options
-rw-r--r-- | src/plugins/lua/fann_redis.lua | 106 |
1 files changed, 75 insertions, 31 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua index ac6f78772..09c20ebf9 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/lua/fann_redis.lua @@ -43,12 +43,12 @@ local default_options = { max_trains = 1000, max_epoch = 1000, max_usages = 10, - use_settings = false, - per_user = false, - watch_interval = 60.0, mse = 0.001, autotrain = true, }, + use_settings = false, + per_user = false, + watch_interval = 60.0, nlayers = 4, lock_expire = 600, learning_spawned = false, @@ -58,8 +58,7 @@ local default_options = { } local settings = { - rules = { - } + rules = {} } -- ANNs indexed by settings id @@ -96,9 +95,17 @@ local redis_lua_script_can_train = [[ if ret then nham = tonumber(ret) end if KEYS[3] == 'spam' then - if nham <= lim and nham + 1 >= nspam then return tostring(nspam + 1) end + if nham <= lim and nham + 1 >= nspam then + return tostring(nspam + 1) + else + return tostring(-(nham + 1)) + end else - if nspam <= lim and nspam + 1 >= nham then return tostring(nham + 1) end + if nspam <= lim and nspam + 1 >= nham then + return tostring(nham + 1) + else + return tostring(-(nspam + 1)) + end end return tostring(0) @@ -312,8 +319,7 @@ local function gen_fann_prefix(rule, id) tprefix = 't'; end if id then - return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), - rule.prefix .. id + return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id), id else return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil end @@ -433,7 +439,7 @@ local function create_train_fann(rule, n, id) if not is_fann_valid(rule, prefix, fanns[id].fann) then fanns[id].fann_train = create_fann(n, rule.nlayers) fanns[id].fann = nil - elseif fanns[id].version % rule.max_usages == 0 then + elseif fanns[id].version % rule.train.max_usages == 0 then -- Forget last fann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, fanns[id].version) @@ -507,7 +513,7 @@ local function fann_train_callback(rule, task, score, required_score, id) local learn_spam, learn_ham - if rule.autotrain then + if train_opts.autotrain then if train_opts['spam_score'] then learn_spam = score >= train_opts['spam_score'] else @@ -531,13 +537,18 @@ local function fann_train_callback(rule, task, score, required_score, id) end end + if learn_spam or learn_ham then local k + local vec_len = 0 if learn_spam then k = 'spam' else k = 'ham' end local function learn_vec_cb(err) if err then rspamd_logger.errx(rspamd_config, 'cannot store train vector for %s: %s', fname, err) + else + rspamd_logger.infox(task, "trained ANN rule %s, save %s vector, %s bytes", + rule['name'], k, vec_len) end end @@ -548,6 +559,7 @@ local function fann_train_callback(rule, task, score, required_score, id) -- Add filtered meta tokens fun.each(function(e) table.insert(fann_data, e) end, mt) local str = rspamd_util.zstd_compress(table.concat(fann_data, ';')) + vec_len = #str rspamd_redis.redis_make_request(task, rule.redis, @@ -559,10 +571,13 @@ local function fann_train_callback(rule, task, score, required_score, id) ) else if err then - rspamd_logger.errx(rspamd_config, 'cannot check if we can train %s: %s', fname, err) + rspamd_logger.errx(task, 'cannot check if we can train %s: %s', fname, err) if string.match(err, 'NOSCRIPT') then load_scripts(rspamd_config, task:get_ev_base(), nil) end + elseif tonumber(data) < 0 then + rspamd_logger.infox(task, "cannot learn ANN %s: too many %s samples: %s", + fname, k, -tonumber(data)) end end end @@ -574,7 +589,7 @@ local function fann_train_callback(rule, task, score, required_score, id) can_train_cb, --callback 'EVALSHA', -- command {redis_can_train_sha, '4', gen_fann_prefix(rule, nil), - suffix, k, tostring(rule.max_trains)} -- arguments + suffix, k, tostring(train_opts.max_trains)} -- arguments ) end end @@ -722,7 +737,7 @@ local function train_fann(rule, _, ev_base, elt, worker) create_train_fann(rule, n, elt) end - if #spam_elts + #ham_elts < rule.max_trains / 2 then + if #spam_elts + #ham_elts < rule.train.max_trains / 2 then -- Invalidate ANN as it is definitely invalid local function redis_invalidate_cb(_err, _data) if _err then @@ -912,10 +927,10 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', prefix, _err) elseif _data and type(_data) == 'number' or type(_data) == 'string' then - if tonumber(_data) and tonumber(_data) >= rule.max_trains then + if tonumber(_data) and tonumber(_data) >= rule.train.max_trains then rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)', - prefix, tonumber(_data), rule.max_trains) + prefix, tonumber(_data), rule.train.max_trains) train_fann(rule, cfg, ev_base, elt, worker) end end @@ -1011,8 +1026,7 @@ end local function ann_push_vector(task) local scores = task:get_metric_score() - - for _,rule in ipairs(settings.rules) do + for _,rule in pairs(settings.rules) do local sid = "0" if rule.use_settings then sid = tostring(task:get_settings_id()) @@ -1053,32 +1067,64 @@ else callback = fann_scores_filter }) + local function deepcopy(orig) + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[deepcopy(orig_key)] = deepcopy(orig_value) + end + setmetatable(copy, deepcopy(getmetatable(orig))) + else -- number, string, boolean, etc + copy = orig + end + return copy + end + local function override_defaults(def, override) + for k,v in pairs(def) do + if override[k] then + if def[k] then + if type(override[k]) == 'table' then + override_defaults(def[k], override[k]) + else + def[k] = override[k] + end + else + def[k] = override[k] + end + end + end + end for k,r in pairs(rules) do - rules[k] = default_options - rules[k]['redis'] = redis_params - local cur = rules[k] + local def_rules = deepcopy(default_options) + def_rules['redis'] = redis_params -- Override defaults - for sk,v in pairs(r) do - cur[sk] = v + override_defaults(def_rules, r) + + if not def_rules.prefix then + def_rules.prefix = k end - if not cur.prefix then - cur.prefix = k + if not def_rules.name then + def_rules.name = k end + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) + settings.rules[k] = def_rules rspamd_config:set_metric_symbol({ - name = cur.symbol_spam, + name = def_rules.symbol_spam, score = 3.0, description = 'Neural network SPAM', group = 'fann' }) rspamd_config:set_metric_symbol({ - name = cur.symbol_ham, + name = def_rules.symbol_ham, score = -2.0, description = 'Neural network HAM', group = 'fann' }) rspamd_config:register_symbol({ - name = cur.symbol_ham, + name = def_rules.symbol_ham, type = 'virtual,nostat', parent = id }) @@ -1086,13 +1132,11 @@ else rspamd_config:register_symbol({ name = 'FANN_VECTOR_PUSH', - type = 'postfilter,nostat', + type = 'idempotent,nostat', priority = 5, callback = ann_push_vector }) - settings.rules = rules - -- Add training scripts for _,rule in pairs(settings.rules) do rspamd_config:add_on_load(function(cfg, ev_base, worker) |