diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-07 19:45:08 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-07 19:45:08 +0100 |
commit | a1af120934a908292544b7848b3e62da0e8b9030 (patch) | |
tree | 273fdf6bf47d6cc522ec6a127f045deef1f151c1 /src | |
parent | c4bf8f28cfc9d32be4eb60bbd5a2820b5bcb4c1b (diff) | |
download | rspamd-a1af120934a908292544b7848b3e62da0e8b9030.tar.gz rspamd-a1af120934a908292544b7848b3e62da0e8b9030.zip |
[Minor] Neural: Further fixes
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/neural.lua | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index d2c0191e7..b0f307803 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -280,7 +280,6 @@ local function new_ann_profile(task, rule, set, version) end lua_redis.redis_make_request(task, - rspamd_config, rule.redis, nil, true, -- is write @@ -347,21 +346,28 @@ local function create_ann(n, nlayers) end -local function ann_train_callback(rule, task, score, required_score, set) +local function ann_push_task_result(rule, task, verdict, score, set) local train_opts = rule.train + local learn_spam, learn_ham if train_opts.autotrain then + if verdict == 'passthrough' or verdict == 'uncertain' then + lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', + verdict, score) + end + if train_opts['spam_score'] then learn_spam = score >= train_opts['spam_score'] else - learn_spam = score >= required_score + learn_spam = verdict == 'spam' or verdict == 'junk' end + if train_opts['ham_score'] then learn_ham = score <= train_opts['ham_score'] else - learn_ham = score < 0 + learn_ham = verdict == 'ham' end else -- Train by request header @@ -408,7 +414,7 @@ local function ann_train_callback(rule, task, score, required_score, set) true, -- is write learn_vec_cb, --callback 'LPUSH', -- command - { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments + { set.ann.redis_key .. '_' .. learn_type, str} -- arguments ) else if err then @@ -948,6 +954,7 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback) rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', err) elseif type(data) == 'table' then + lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name) process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) end end @@ -1000,12 +1007,12 @@ end local function ann_push_vector(task) if task:has_flag('skip') then return end if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end - local scores = task:get_metric_score() + local verdict,score = lua_util.get_task_verdict(task) for _,rule in pairs(settings.rules) do local sid = task:get_settings_id() or -1 if rule.settings[sid] then - ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid]) + ann_push_task_result(rule, task, verdict, score, rule.settings[sid]) end end @@ -1124,6 +1131,10 @@ local id = rspamd_config:register_symbol({ callback = ann_scores_filter }) +settings = lua_util.override_defaults(settings, module_config) +settings.rules = {} -- Reset unless validated further in the cycle + +-- Check all rules for k,r in pairs(rules) do local rule_elt = lua_util.override_defaults(default_options, r) rule_elt['redis'] = redis_params |