diff options
-rw-r--r-- | .luacheckrc | 2 | ||||
-rw-r--r-- | src/plugins/lua/fann_classifier.lua | 65 | ||||
-rw-r--r-- | src/plugins/lua/fann_scores.lua | 46 |
3 files changed, 65 insertions, 48 deletions
diff --git a/.luacheckrc b/.luacheckrc index 419d9e301..2eb10ce1a 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -2,8 +2,6 @@ codes = true std = 'min' exclude_files = { - '/**/src/plugins/lua/fann_classifier.lua', - '/**/src/plugins/lua/fann_scores.lua', } globals = { diff --git a/src/plugins/lua/fann_classifier.lua b/src/plugins/lua/fann_classifier.lua index 9c35d0bfa..770f244d8 100644 --- a/src/plugins/lua/fann_classifier.lua +++ b/src/plugins/lua/fann_classifier.lua @@ -19,8 +19,7 @@ limitations under the License. local rspamd_logger = require "rspamd_logger" local rspamd_fann = require "rspamd_fann" local rspamd_util = require "rspamd_util" -require "fun" () -local ucl = require "ucl" +local fun = require "fun" local redis_params local classifier_config = { @@ -41,13 +40,14 @@ redis_params = rspamd_parse_redis_server('fann_classifier') local function maybe_load_fann(task, continue_cb, call_if_fail) local function load_fann() local function redis_fann_load_cb(err, data) + -- XXX: upstreams if not err and type(data) == 'table' and type(data[2]) == 'string' then local version = tonumber(data[1]) - local err,ann_data = rspamd_util.zstd_decompress(data[2]) + local _err,ann_data = rspamd_util.zstd_decompress(data[2]) local ann - if err or not ann_data then - rspamd_logger.errx(task, 'cannot decompress ann: %s', err) + if _err or not ann_data then + rspamd_logger.errx(task, 'cannot decompress ann: %s', _err) else ann = rspamd_fann.load_data(ann_data) end @@ -80,7 +80,7 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) end local key = classifier_config.key - local ret,_,_ = rspamd_redis_make_request(task, + local ret,_,upstream = rspamd_redis_make_request(task, redis_params, -- connect params key, -- hash key false, -- is write @@ -88,10 +88,21 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) 'HMGET', -- command {key, 'version', 'data', 'spam', 'ham'} -- arguments ) + if not ret then + rspamd_logger.errx(task, 'redis error on host %s', upstream:get_addr()) + upstream:fail() + end end local function check_fann() + local _, ret, upstream local function redis_fann_check_cb(err, data) + if err then + rspamd_logger.errx(task, 'redis error on host %s: %s', upstream:get_addr(), err) + upstream:fail() + else + upstream:ok() + end if not err and type(data) == 'string' then local version = tonumber(data) @@ -104,7 +115,7 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) end local key = classifier_config.key - local ret,_,_ = rspamd_redis_make_request(task, + ret,_,upstream = rspamd_redis_make_request(task, redis_params, -- connect params key, -- hash key false, -- is write @@ -112,6 +123,10 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) 'HGET', -- command {key, 'version'} -- arguments ) + if not ret then + rspamd_logger.errx(task, 'redis error on host %s', upstream:get_addr()) + upstream:fail() + end end if not current_classify_ann.loaded then @@ -122,14 +137,13 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) end local function tokens_to_vector(tokens) - local vec = totable(map(function(tok) return tok[1] end, tokens)) + local vec = fun.totable(fun.map(function(tok) return tok[1] end, tokens)) local ret = {} - local ntok = #vec local neurons = classifier_config.neurons for i = 1,neurons do ret[i] = 0 end - each(function(e) + fun.each(function(e) local n = (e % neurons) + 1 ret[n] = ret[n] + 1 end, vec) @@ -175,9 +189,13 @@ local function create_fann() end local function save_fann(task, is_spam) - local function redis_fann_save_cb(err, data) + local ret, conn, upstream + local function redis_fann_save_cb(err) if err then rspamd_logger.errx(task, "cannot save neural net to redis: %s", err) + upstream:fail() + else + upstream:ok() end end @@ -190,7 +208,7 @@ local function save_fann(task, is_spam) else current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1 end - local ret,conn,_ = rspamd_redis_make_request(task, + ret,conn,upstream = rspamd_redis_make_request(task, redis_params, -- connect params key, -- hash key true, -- is write @@ -201,22 +219,23 @@ local function save_fann(task, is_spam) 'data', rspamd_util.zstd_compress(data), }) -- arguments - if conn then + if ret then conn:add_cmd('HINCRBY', {key, 'version', 1}) if is_spam then conn:add_cmd('HINCRBY', {key, 'spam', 1}) - rspamd_logger.errx(task, 'hui') else conn:add_cmd('HINCRBY', {key, 'ham', 1}) - rspamd_logger.errx(task, 'pezda') end + else + rspamd_logger.errx(task, 'redis error on host %s: %s', upstream:get_addr()) + upstream:fail() end end if redis_params then rspamd_classifiers['neural'] = { classify = function(task, classifier, tokens) - local function classify_cb(task) + local function classify_cb() local min_learns = classifier:get_param('min_learns') if min_learns then @@ -243,18 +262,18 @@ if redis_params then rspamd_logger.infox(task, 'fann classifier score: %s', symscore) if result > 0 then - each(function(st) + fun.each(function(st) task:insert_result(st:get_symbol(), result, symscore) end, - filter(function(st) + fun.filter(function(st) return st:is_spam() end, classifier:get_statfiles()) ) else - each(function(st) + fun.each(function(st) task:insert_result(st:get_symbol(), -result, symscore) end, - filter(function(st) + fun.filter(function(st) return not st:is_spam() end, classifier:get_statfiles()) ) @@ -263,8 +282,8 @@ if redis_params then maybe_load_fann(task, classify_cb, false) end, - learn = function(task, classifier, tokens, is_spam, is_unlearn) - local function learn_cb(task, is_loaded) + learn = function(task, _, tokens, is_spam, _) + local function learn_cb(_, is_loaded) if not is_loaded then create_fann() end @@ -286,4 +305,4 @@ if redis_params then maybe_load_fann(task, learn_cb, true) end, } -end
\ No newline at end of file +end diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 32169ee46..a96d27701 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -23,7 +23,6 @@ local rspamd_util = require "rspamd_util" local fann_symbol_spam = 'FANN_SPAM' local fann_symbol_ham = 'FANN_HAM' local fun = require "fun" -local ucl = require "ucl" local module_log_id = 0x100 -- Module vars @@ -46,9 +45,9 @@ local function symbols_to_fann_vector(syms, scores) local matched_symbols = {} local n = rspamd_config:get_symbols_count() - each(function(s, score) + fun.each(function(s, score) matched_symbols[s + 1] = rspamd_util.tanh(score) - end, zip(syms, scores)) + end, fun.zip(syms, scores)) for i=1,n do if matched_symbols[i] then @@ -71,7 +70,7 @@ end local function load_fann(id) local fname = gen_fann_file(id) - local err,st = rspamd_util.stat(fname) + local err = rspamd_util.stat(fname) if err then return false @@ -89,10 +88,10 @@ local function load_fann(id) ' is found in the cache; removing', data[id].fann:get_inputs(), n) data[id].fann = nil - local ret,err = rspamd_util.unlink(fname) + local ret,_err = rspamd_util.unlink(fname) if not ret then rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, err) + fname, _err) end else local layers = data[id].fann:get_layers() @@ -101,10 +100,10 @@ local function load_fann(id) rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing', #layers) data[id].fann = nil - local ret,err = rspamd_util.unlink(fname) + local ret,_err = rspamd_util.unlink(fname) if not ret then rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, err) + fname, _err) end else rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname) @@ -113,10 +112,10 @@ local function load_fann(id) end else rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname) - local ret,err = rspamd_util.unlink(fname) + local ret,_err = rspamd_util.unlink(fname) if not ret then rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, err) + fname, _err) end end @@ -218,9 +217,10 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, -- Store fann on disk local res = false - local err,st = rspamd_util.stat(fname) + local err = rspamd_util.stat(fname) + local fd if err then - local fd,err = rspamd_util.create_file(fname) + fd,err = rspamd_util.create_file(fname) if not fd then rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err) else @@ -229,7 +229,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, rspamd_util.unlock_file(fd) -- Closes fd as well end else - local fd = rspamd_util.lock_file(fname) + fd = rspamd_util.lock_file(fname) res = data[id].fann_train:save(fname) rspamd_util.unlock_file(fd) -- Closes fd as well end @@ -244,7 +244,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, else if not data[id].checked then data[id].checked = true - local err,st = rspamd_util.stat(fname) + local err = rspamd_util.stat(fname) if err then data[id].exist = false end @@ -262,7 +262,7 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, create_train_fann(n, id) end - local learn_spam, learn_ham = false, false + local learn_spam, learn_ham if opts['spam_score'] then learn_spam = score >= opts['spam_score'] else @@ -276,11 +276,11 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, if learn_spam or learn_ham then local learn_data = symbols_to_fann_vector( - map(function(r) return r[1] end, results), - map(function(r) return r[2] end, results) + fun.map(function(r) return r[1] end, results), + fun.map(function(r) return r[2] end, results) ) -- Add filtered meta tokens - each(function(e) table.insert(learn_data, e) end, extra) + fun.each(function(e) table.insert(learn_data, e) end, extra) if learn_spam then data[id].fann_train:train(learn_data, {1.0}) @@ -344,13 +344,13 @@ else max_epoch = opts['train']['max_epoch'] end local ret = cfg:register_worker_script("log_helper", - function(score, req_score, results, cf, id, extra) + function(score, req_score, results, cf, _id, extra) -- map (snd x) (filter (fst x == module_id) extra) - local extra_fann = map(function(e) return e[2] end, - filter(function(e) return e[1] == module_log_id end, extra)) + local extra_fann = fun.map(function(e) return e[2] end, + fun.filter(function(e) return e[1] == module_log_id end, extra)) if use_settings then fann_train_callback(score, req_score, results, cf, - tostring(id), opts['train'], extra_fann) + tostring(_id), opts['train'], extra_fann) else fann_train_callback(score, req_score, results, cf, '0', opts['train'], extra_fann) @@ -363,7 +363,7 @@ else end) rspamd_plugins["fann_score"] = { log_callback = function(task) - return totable(map( + return fun.totable(fun.map( function(tok) return {module_log_id, tok} end, rspamd_gen_metatokens(task))) end |