summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.luacheckrc2
-rw-r--r--src/plugins/lua/fann_classifier.lua65
-rw-r--r--src/plugins/lua/fann_scores.lua46
3 files changed, 65 insertions, 48 deletions
diff --git a/.luacheckrc b/.luacheckrc
index 419d9e301..2eb10ce1a 100644
--- a/.luacheckrc
+++ b/.luacheckrc
@@ -2,8 +2,6 @@ codes = true
std = 'min'
exclude_files = {
- '/**/src/plugins/lua/fann_classifier.lua',
- '/**/src/plugins/lua/fann_scores.lua',
}
globals = {
diff --git a/src/plugins/lua/fann_classifier.lua b/src/plugins/lua/fann_classifier.lua
index 9c35d0bfa..770f244d8 100644
--- a/src/plugins/lua/fann_classifier.lua
+++ b/src/plugins/lua/fann_classifier.lua
@@ -19,8 +19,7 @@ limitations under the License.
local rspamd_logger = require "rspamd_logger"
local rspamd_fann = require "rspamd_fann"
local rspamd_util = require "rspamd_util"
-require "fun" ()
-local ucl = require "ucl"
+local fun = require "fun"
local redis_params
local classifier_config = {
@@ -41,13 +40,14 @@ redis_params = rspamd_parse_redis_server('fann_classifier')
local function maybe_load_fann(task, continue_cb, call_if_fail)
local function load_fann()
local function redis_fann_load_cb(err, data)
+ -- XXX: upstreams
if not err and type(data) == 'table' and type(data[2]) == 'string' then
local version = tonumber(data[1])
- local err,ann_data = rspamd_util.zstd_decompress(data[2])
+ local _err,ann_data = rspamd_util.zstd_decompress(data[2])
local ann
- if err or not ann_data then
- rspamd_logger.errx(task, 'cannot decompress ann: %s', err)
+ if _err or not ann_data then
+ rspamd_logger.errx(task, 'cannot decompress ann: %s', _err)
else
ann = rspamd_fann.load_data(ann_data)
end
@@ -80,7 +80,7 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
end
local key = classifier_config.key
- local ret,_,_ = rspamd_redis_make_request(task,
+ local ret,_,upstream = rspamd_redis_make_request(task,
redis_params, -- connect params
key, -- hash key
false, -- is write
@@ -88,10 +88,21 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
'HMGET', -- command
{key, 'version', 'data', 'spam', 'ham'} -- arguments
)
+ if not ret then
+ rspamd_logger.errx(task, 'redis error on host %s', upstream:get_addr())
+ upstream:fail()
+ end
end
local function check_fann()
+ local _, ret, upstream
local function redis_fann_check_cb(err, data)
+ if err then
+ rspamd_logger.errx(task, 'redis error on host %s: %s', upstream:get_addr(), err)
+ upstream:fail()
+ else
+ upstream:ok()
+ end
if not err and type(data) == 'string' then
local version = tonumber(data)
@@ -104,7 +115,7 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
end
local key = classifier_config.key
- local ret,_,_ = rspamd_redis_make_request(task,
+ ret,_,upstream = rspamd_redis_make_request(task,
redis_params, -- connect params
key, -- hash key
false, -- is write
@@ -112,6 +123,10 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
'HGET', -- command
{key, 'version'} -- arguments
)
+ if not ret then
+ rspamd_logger.errx(task, 'redis error on host %s', upstream:get_addr())
+ upstream:fail()
+ end
end
if not current_classify_ann.loaded then
@@ -122,14 +137,13 @@ local function maybe_load_fann(task, continue_cb, call_if_fail)
end
local function tokens_to_vector(tokens)
- local vec = totable(map(function(tok) return tok[1] end, tokens))
+ local vec = fun.totable(fun.map(function(tok) return tok[1] end, tokens))
local ret = {}
- local ntok = #vec
local neurons = classifier_config.neurons
for i = 1,neurons do
ret[i] = 0
end
- each(function(e)
+ fun.each(function(e)
local n = (e % neurons) + 1
ret[n] = ret[n] + 1
end, vec)
@@ -175,9 +189,13 @@ local function create_fann()
end
local function save_fann(task, is_spam)
- local function redis_fann_save_cb(err, data)
+ local ret, conn, upstream
+ local function redis_fann_save_cb(err)
if err then
rspamd_logger.errx(task, "cannot save neural net to redis: %s", err)
+ upstream:fail()
+ else
+ upstream:ok()
end
end
@@ -190,7 +208,7 @@ local function save_fann(task, is_spam)
else
current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1
end
- local ret,conn,_ = rspamd_redis_make_request(task,
+ ret,conn,upstream = rspamd_redis_make_request(task,
redis_params, -- connect params
key, -- hash key
true, -- is write
@@ -201,22 +219,23 @@ local function save_fann(task, is_spam)
'data', rspamd_util.zstd_compress(data),
}) -- arguments
- if conn then
+ if ret then
conn:add_cmd('HINCRBY', {key, 'version', 1})
if is_spam then
conn:add_cmd('HINCRBY', {key, 'spam', 1})
- rspamd_logger.errx(task, 'hui')
else
conn:add_cmd('HINCRBY', {key, 'ham', 1})
- rspamd_logger.errx(task, 'pezda')
end
+ else
+ rspamd_logger.errx(task, 'redis error on host %s: %s', upstream:get_addr())
+ upstream:fail()
end
end
if redis_params then
rspamd_classifiers['neural'] = {
classify = function(task, classifier, tokens)
- local function classify_cb(task)
+ local function classify_cb()
local min_learns = classifier:get_param('min_learns')
if min_learns then
@@ -243,18 +262,18 @@ if redis_params then
rspamd_logger.infox(task, 'fann classifier score: %s', symscore)
if result > 0 then
- each(function(st)
+ fun.each(function(st)
task:insert_result(st:get_symbol(), result, symscore)
end,
- filter(function(st)
+ fun.filter(function(st)
return st:is_spam()
end, classifier:get_statfiles())
)
else
- each(function(st)
+ fun.each(function(st)
task:insert_result(st:get_symbol(), -result, symscore)
end,
- filter(function(st)
+ fun.filter(function(st)
return not st:is_spam()
end, classifier:get_statfiles())
)
@@ -263,8 +282,8 @@ if redis_params then
maybe_load_fann(task, classify_cb, false)
end,
- learn = function(task, classifier, tokens, is_spam, is_unlearn)
- local function learn_cb(task, is_loaded)
+ learn = function(task, _, tokens, is_spam, _)
+ local function learn_cb(_, is_loaded)
if not is_loaded then
create_fann()
end
@@ -286,4 +305,4 @@ if redis_params then
maybe_load_fann(task, learn_cb, true)
end,
}
-end \ No newline at end of file
+end
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua
index 32169ee46..a96d27701 100644
--- a/src/plugins/lua/fann_scores.lua
+++ b/src/plugins/lua/fann_scores.lua
@@ -23,7 +23,6 @@ local rspamd_util = require "rspamd_util"
local fann_symbol_spam = 'FANN_SPAM'
local fann_symbol_ham = 'FANN_HAM'
local fun = require "fun"
-local ucl = require "ucl"
local module_log_id = 0x100
-- Module vars
@@ -46,9 +45,9 @@ local function symbols_to_fann_vector(syms, scores)
local matched_symbols = {}
local n = rspamd_config:get_symbols_count()
- each(function(s, score)
+ fun.each(function(s, score)
matched_symbols[s + 1] = rspamd_util.tanh(score)
- end, zip(syms, scores))
+ end, fun.zip(syms, scores))
for i=1,n do
if matched_symbols[i] then
@@ -71,7 +70,7 @@ end
local function load_fann(id)
local fname = gen_fann_file(id)
- local err,st = rspamd_util.stat(fname)
+ local err = rspamd_util.stat(fname)
if err then
return false
@@ -89,10 +88,10 @@ local function load_fann(id)
' is found in the cache; removing', data[id].fann:get_inputs(), n)
data[id].fann = nil
- local ret,err = rspamd_util.unlink(fname)
+ local ret,_err = rspamd_util.unlink(fname)
if not ret then
rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fname, err)
+ fname, _err)
end
else
local layers = data[id].fann:get_layers()
@@ -101,10 +100,10 @@ local function load_fann(id)
rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing',
#layers)
data[id].fann = nil
- local ret,err = rspamd_util.unlink(fname)
+ local ret,_err = rspamd_util.unlink(fname)
if not ret then
rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fname, err)
+ fname, _err)
end
else
rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname)
@@ -113,10 +112,10 @@ local function load_fann(id)
end
else
rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname)
- local ret,err = rspamd_util.unlink(fname)
+ local ret,_err = rspamd_util.unlink(fname)
if not ret then
rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fname, err)
+ fname, _err)
end
end
@@ -218,9 +217,10 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
-- Store fann on disk
local res = false
- local err,st = rspamd_util.stat(fname)
+ local err = rspamd_util.stat(fname)
+ local fd
if err then
- local fd,err = rspamd_util.create_file(fname)
+ fd,err = rspamd_util.create_file(fname)
if not fd then
rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err)
else
@@ -229,7 +229,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
rspamd_util.unlock_file(fd) -- Closes fd as well
end
else
- local fd = rspamd_util.lock_file(fname)
+ fd = rspamd_util.lock_file(fname)
res = data[id].fann_train:save(fname)
rspamd_util.unlock_file(fd) -- Closes fd as well
end
@@ -244,7 +244,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
else
if not data[id].checked then
data[id].checked = true
- local err,st = rspamd_util.stat(fname)
+ local err = rspamd_util.stat(fname)
if err then
data[id].exist = false
end
@@ -262,7 +262,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
create_train_fann(n, id)
end
- local learn_spam, learn_ham = false, false
+ local learn_spam, learn_ham
if opts['spam_score'] then
learn_spam = score >= opts['spam_score']
else
@@ -276,11 +276,11 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
if learn_spam or learn_ham then
local learn_data = symbols_to_fann_vector(
- map(function(r) return r[1] end, results),
- map(function(r) return r[2] end, results)
+ fun.map(function(r) return r[1] end, results),
+ fun.map(function(r) return r[2] end, results)
)
-- Add filtered meta tokens
- each(function(e) table.insert(learn_data, e) end, extra)
+ fun.each(function(e) table.insert(learn_data, e) end, extra)
if learn_spam then
data[id].fann_train:train(learn_data, {1.0})
@@ -344,13 +344,13 @@ else
max_epoch = opts['train']['max_epoch']
end
local ret = cfg:register_worker_script("log_helper",
- function(score, req_score, results, cf, id, extra)
+ function(score, req_score, results, cf, _id, extra)
-- map (snd x) (filter (fst x == module_id) extra)
- local extra_fann = map(function(e) return e[2] end,
- filter(function(e) return e[1] == module_log_id end, extra))
+ local extra_fann = fun.map(function(e) return e[2] end,
+ fun.filter(function(e) return e[1] == module_log_id end, extra))
if use_settings then
fann_train_callback(score, req_score, results, cf,
- tostring(id), opts['train'], extra_fann)
+ tostring(_id), opts['train'], extra_fann)
else
fann_train_callback(score, req_score, results, cf, '0',
opts['train'], extra_fann)
@@ -363,7 +363,7 @@ else
end)
rspamd_plugins["fann_score"] = {
log_callback = function(task)
- return totable(map(
+ return fun.totable(fun.map(
function(tok) return {module_log_id, tok} end,
rspamd_gen_metatokens(task)))
end