-- key1 - ann key
-- key2 - spam or ham
-- key3 - maximum trains
--- returns 1 or 0: 1 - allow learn, 0 - not allow learn
+-- returns a {count, reason} pair: count is the nspam/nham counter of the
+-- requested class as a string, negated when the sample must not be stored
+-- (sampled out, or the counter is above lim); reason is a human-readable
+-- explanation suitable for logging
+-- NOTE(review): some Redis versions seed the script PRNG identically on every
+-- invocation, which would make the math.random() based sampling below
+-- deterministic - confirm against the Redis versions being supported
local redis_lua_script_can_store_train_vec = [[
local prefix = KEYS[1]
local locked = redis.call('HGET', prefix, 'lock')
if KEYS[2] == 'spam' then
if nspam <= lim then
- return tostring(nspam)
- else
- return tostring(-(nspam))
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if math.random() < skip_rate then
+ return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
+ end
+ end
+ return {tostring(nspam),'can learn'}
+ else -- Enough learns
+ return {tostring(-(nspam)),'too many spam samples'}
end
else
if nham <= lim then
- return tostring(nham)
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if math.random() < skip_rate then
+ return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
+ end
+ end
+ return {tostring(nham),'can learn'}
else
- return tostring(-(nham))
+ return {tostring(-(nham)),'too many ham samples'}
end
end
- return tostring(0)
+ return {tostring(0),'bad input'}
]]
local redis_can_store_train_vec_id = nil
if learn_spam then learn_type = 'spam' else learn_type = 'ham' end
+-- Callback for the Redis script that decides whether a train vector may be
+-- stored.
+-- err: error reported for the script call, if any
+-- data: script reply; expected to be a table whose first element is a numeric
+-- string (negative when the sample is refused) and whose second element is
+-- the reason text
local function can_train_cb(err, data)
- if not err and tonumber(data) >= 0 then
- local coin = math.random()
- if coin < 1.0 - train_opts.train_prob then
- rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
- return
- end
- local vec = result_to_vector(task, set)
+ if not err and type(data) == 'table' then
+ local nsamples,reason = tonumber(data[1]),data[2]
- local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
- local target_key = set.ann.redis_key .. '_' .. learn_type
+ -- Zero is a valid "can learn" reply (nothing stored yet); only a negative
+ -- count means that the script refused the sample
+ if nsamples >= 0 then
+ local coin = math.random()
- local function learn_vec_cb(_err)
- if _err then
- rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
- rule.prefix, set.name, _err)
- else
- lua_util.debugm(N, task,
- "add train data for ANN rule " ..
- "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
- rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ if coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return
end
- end
- lua_redis.redis_make_request(task,
- rule.redis,
- nil,
- true, -- is write
- learn_vec_cb, --callback
- 'LPUSH', -- command
- { target_key, str } -- arguments
- )
- else
- if err then
- rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
- rule.prefix, set.name, err)
- elseif tonumber(data) < 0 then
- rspamd_logger.infox(task, "cannot learn ANN %s:%s: too many %s samples: %s",
- rule.prefix, set.name, learn_type, -tonumber(data))
+ local vec = result_to_vector(task, set)
+
+ local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local target_key = set.ann.redis_key .. '_' .. learn_type
+
+ local function learn_vec_cb(_err)
+ if _err then
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ rule.prefix, set.name, _err)
+ else
+ lua_util.debugm(N, task,
+ "add train data for ANN rule " ..
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ end
+ end
+
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ true, -- is write
+ learn_vec_cb, --callback
+ 'LPUSH', -- command
+ { target_key, str } -- arguments
+ )
+ else
+ -- Negative result returned: the count is negated back for the log, and
+ -- every argument has its own placeholder
+ rspamd_logger.infox(task, "cannot learn %s ANN %s:%s: %s (%s vectors stored)",
+ learn_type, rule.prefix, set.name, reason, -nsamples)
end
+ else
+ rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+ rule.prefix, set.name, err)
end
end
lua_redis.exec_redis_script(redis_can_store_train_vec_id,
{task = task, is_write = true},
can_train_cb,
- { set.ann.redis_key, learn_type, tostring(train_opts.max_trains)})
+ {
+ set.ann.redis_key,
+ learn_type,
+ tostring(train_opts.max_trains),
+ })
else
lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',
skip_reason)
return
end
+ if score ~= score then
+ lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
+ verdict)
+
+ return
+ end
+
for _,rule in pairs(settings.rules) do
local set = get_rule_settings(task, rule)