aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-07 19:45:08 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-07 19:45:08 +0100
commita1af120934a908292544b7848b3e62da0e8b9030 (patch)
tree273fdf6bf47d6cc522ec6a127f045deef1f151c1 /src
parentc4bf8f28cfc9d32be4eb60bbd5a2820b5bcb4c1b (diff)
downloadrspamd-a1af120934a908292544b7848b3e62da0e8b9030.tar.gz
rspamd-a1af120934a908292544b7848b3e62da0e8b9030.zip
[Minor] Neural: Further fixes
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/neural.lua25
1 files changed, 18 insertions, 7 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index d2c0191e7..b0f307803 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -280,7 +280,6 @@ local function new_ann_profile(task, rule, set, version)
end
lua_redis.redis_make_request(task,
- rspamd_config,
rule.redis,
nil,
true, -- is write
@@ -347,21 +346,28 @@ local function create_ann(n, nlayers)
end
-local function ann_train_callback(rule, task, score, required_score, set)
+local function ann_push_task_result(rule, task, verdict, score, set)
local train_opts = rule.train
+
local learn_spam, learn_ham
if train_opts.autotrain then
+ if verdict == 'passthrough' or verdict == 'uncertain' then
+ lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
+ verdict, score)
+ end
+
if train_opts['spam_score'] then
learn_spam = score >= train_opts['spam_score']
else
- learn_spam = score >= required_score
+ learn_spam = verdict == 'spam' or verdict == 'junk'
end
+
if train_opts['ham_score'] then
learn_ham = score <= train_opts['ham_score']
else
- learn_ham = score < 0
+ learn_ham = verdict == 'ham'
end
else
-- Train by request header
@@ -408,7 +414,7 @@ local function ann_train_callback(rule, task, score, required_score, set)
true, -- is write
learn_vec_cb, --callback
'LPUSH', -- command
- { set.ann.redis_prefix .. '_' .. learn_type, str} -- arguments
+ { set.ann.redis_key .. '_' .. learn_type, str} -- arguments
)
else
if err then
@@ -948,6 +954,7 @@ local function check_anns(worker, cfg, ev_base, rule, process_callback)
rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
err)
elseif type(data) == 'table' then
+ lua_util.debugm(N, cfg, 'process element %s:%s', rule.prefix, set.name)
process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
end
end
@@ -1000,12 +1007,12 @@ end
local function ann_push_vector(task)
if task:has_flag('skip') then return end
if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then return end
- local scores = task:get_metric_score()
+ local verdict,score = lua_util.get_task_verdict(task)
for _,rule in pairs(settings.rules) do
local sid = task:get_settings_id() or -1
if rule.settings[sid] then
- ann_train_callback(rule, task, scores[1], scores[2], rule.settings[sid])
+ ann_push_task_result(rule, task, verdict, score, rule.settings[sid])
end
end
@@ -1124,6 +1131,10 @@ local id = rspamd_config:register_symbol({
callback = ann_scores_filter
})
+settings = lua_util.override_defaults(settings, module_config)
+settings.rules = {} -- Reset unless validated further in the cycle
+
+-- Check all rules
for k,r in pairs(rules) do
local rule_elt = lua_util.override_defaults(default_options, r)
rule_elt['redis'] = redis_params