From 585af99ecefae1d219858b0d9b4ac8d268aa19c9 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 8 Mar 2018 12:34:20 +0000 Subject: [PATCH] [Minor] Rename routines in neural plugin --- src/plugins/neural.lua | 230 +++++++++++++++++++++-------------------- 1 file changed, 118 insertions(+), 112 deletions(-) diff --git a/src/plugins/neural.lua b/src/plugins/neural.lua index 117881b31..b2c7adcfa 100644 --- a/src/plugins/neural.lua +++ b/src/plugins/neural.lua @@ -14,8 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. ]]-- --- This plugin is a concept of FANN scores adjustment --- NOT FOR PRODUCTION USE so far if confighelp then return @@ -24,14 +22,14 @@ end local rspamd_logger = require "rspamd_logger" local rspamd_fann = require "rspamd_fann" local rspamd_util = require "rspamd_util" -local rspamd_redis = require "lua_redis" +local lua_redis = require "lua_redis" local lua_util = require "lua_util" local fun = require "fun" local meta_functions = require "lua_meta" local use_torch = false local torch local nn -local N = "fann_redis" +local N = "neural" if rspamd_config:has_torch() then use_torch = true @@ -67,10 +65,14 @@ local settings = { } -- ANNs indexed by settings id -local fanns = { +local anns = { } -local opts = rspamd_config:get_all_opt("fann_redis") +local opts = rspamd_config:get_all_opt("neural") +if not opts then + -- Legacy + opts = rspamd_config:get_all_opt("fann_redis") +end -- Lua script to train a row @@ -205,21 +207,21 @@ local redis_save_unlock_id = nil local redis_params local function load_scripts(params) - redis_can_train_id = rspamd_redis.add_redis_script(redis_lua_script_can_train, + redis_can_train_id = lua_redis.add_redis_script(redis_lua_script_can_train, params) - redis_maybe_load_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_load, + redis_maybe_load_id = lua_redis.add_redis_script(redis_lua_script_maybe_load, params) - redis_maybe_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_invalidate, + redis_maybe_invalidate_id = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, params) - redis_locked_invalidate_id = rspamd_redis.add_redis_script(redis_lua_script_locked_invalidate, + redis_locked_invalidate_id = lua_redis.add_redis_script(redis_lua_script_locked_invalidate, params) - redis_maybe_lock_id = rspamd_redis.add_redis_script(redis_lua_script_maybe_lock, + redis_maybe_lock_id = lua_redis.add_redis_script(redis_lua_script_maybe_lock, params) - redis_save_unlock_id = rspamd_redis.add_redis_script(redis_lua_script_save_unlock, + redis_save_unlock_id = lua_redis.add_redis_script(redis_lua_script_save_unlock, params) end -local function gen_fann_prefix(rule, id) +local function gen_ann_prefix(rule, id) local cksum = rspamd_config:get_symbols_cksum():hex() -- We also need to count metatokens: local n = meta_functions.rspamd_count_metatokens() @@ -234,7 +236,7 @@ local function gen_fann_prefix(rule, id) end end -local function is_fann_valid(rule, prefix, ann) +local function is_ann_valid(rule, prefix, ann) if ann then local n = rspamd_config:get_symbols_count() + meta_functions.rspamd_count_metatokens() @@ -260,7 +262,7 @@ local function is_fann_valid(rule, prefix, ann) end end -local function fann_scores_filter(task) +local function ann_scores_filter(task) for _,rule in pairs(settings.rules) do local id = '0' @@ -275,23 +277,23 @@ local function fann_scores_filter(task) id = id .. r end - if fanns[id] and fanns[id].fann then - local fann_data = task:get_symbols_tokens() + if anns[id] and anns[id].ann then + local ann_data = task:get_symbols_tokens() local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens - fun.each(function(e) table.insert(fann_data, e) end, mt) + fun.each(function(e) table.insert(ann_data, e) end, mt) local score if use_torch then - local out = fanns[id].fann:forward(torch.Tensor(fann_data)) + local out = anns[id].ann:forward(torch.Tensor(ann_data)) score = out[1] else - local out = fanns[id].fann:test(fann_data) + local out = anns[id].ann:test(ann_data) score = out[1] end local symscore = string.format('%.3f', score) - rspamd_logger.infox(task, 'fann score: %s', symscore) + rspamd_logger.infox(task, 'ann score: %s', symscore) if score > 0 then local result = score @@ -310,7 +312,7 @@ local function fann_scores_filter(task) end end -local function create_fann(n, nlayers) +local function create_ann(n, nlayers) if use_torch then -- We ignore number of layers so far when using torch local ann = nn.Sequential() @@ -334,36 +336,36 @@ local function create_fann(n, nlayers) end end -local function create_train_fann(rule, n, id) - local prefix = gen_fann_prefix(rule, id) - if not fanns[id] then - fanns[id] = {} +local function create_train_ann(rule, n, id) + local prefix = gen_ann_prefix(rule, id) + if not anns[id] then + anns[id] = {} end -- Fix that for flexibe layers number - if fanns[id].fann then - if not is_fann_valid(rule, prefix, fanns[id].fann) then - fanns[id].fann_train = create_fann(n, rule.nlayers) - fanns[id].fann = nil + if anns[id].ann then + if not is_ann_valid(rule, prefix, anns[id].ann) then + anns[id].ann_train = create_ann(n, rule.nlayers) + anns[id].ann = nil rspamd_logger.infox(rspamd_config, 'invalidate existing ANN, create train ANN %s', prefix) - elseif rule.train.max_usages > 0 and fanns[id].version % rule.train.max_usages == 0 then - -- Forget last fann + elseif rule.train.max_usages > 0 and anns[id].version % rule.train.max_usages == 0 then + -- Forget last ann rspamd_logger.infox(rspamd_config, 'recreate ANN %s, version %s', prefix, - fanns[id].version) - fanns[id].fann_train = create_fann(n, rule.nlayers) + anns[id].version) + anns[id].ann_train = create_ann(n, rule.nlayers) else - fanns[id].fann_train = fanns[id].fann + anns[id].ann_train = anns[id].ann rspamd_logger.infox(rspamd_config, 'reuse ANN for training %s', prefix) end else - fanns[id].fann_train = create_fann(n, rule.nlayers) + anns[id].ann_train = create_ann(n, rule.nlayers) rspamd_logger.infox(rspamd_config, 'create train ANN %s', prefix) - fanns[id].version = 0 + anns[id].version = 0 end end -local function load_or_invalidate_fann(rule, data, id, ev_base) +local function load_or_invalidate_ann(rule, data, id, ev_base) local ver = data[2] - local prefix = gen_fann_prefix(rule, id) + local prefix = gen_ann_prefix(rule, id) if not ver or not tonumber(ver) then rspamd_logger.errx(rspamd_config, 'cannot get version for ANN: %s', prefix) @@ -384,33 +386,33 @@ local function load_or_invalidate_fann(rule, data, id, ev_base) end end - if is_fann_valid(rule, prefix, ann) then - if not fanns[id] then fanns[id] = {} end - fanns[id].fann = ann + if is_ann_valid(rule, prefix, ann) then + if not anns[id] then anns[id] = {} end + anns[id].ann = ann rspamd_logger.infox(rspamd_config, 'loaded ANN %s version %s from redis', prefix, ver) - fanns[id].version = tonumber(ver) + anns[id].version = tonumber(ver) else local function redis_invalidate_cb(_err, _data) if _err then rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - fanns[id].version = 0 + anns[id].version = 0 end end -- Invalidate ANN rspamd_logger.infox(rspamd_config, 'invalidate ANN %s', prefix) - rspamd_redis.exec_redis_script(redis_maybe_invalidate_id, + lua_redis.exec_redis_script(redis_maybe_invalidate_id, {ev_base = ev_base, is_write = true}, redis_invalidate_cb, {prefix}) end end -local function fann_train_callback(rule, task, score, required_score, id) +local function ann_train_callback(rule, task, score, required_score, id) local train_opts = rule['train'] - local fname,suffix = gen_fann_prefix(rule, id) + local fname,suffix = gen_ann_prefix(rule, id) local learn_spam, learn_ham @@ -460,16 +462,16 @@ local function fann_train_callback(rule, task, score, required_score, id) rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) return end - local fann_data = task:get_symbols_tokens() + local ann_data = task:get_symbols_tokens() local mt = meta_functions.rspamd_gen_metatokens(task) -- Add filtered meta tokens - fun.each(function(e) table.insert(fann_data, e) end, mt) + fun.each(function(e) table.insert(ann_data, e) end, mt) -- Check NaNs in train data - if fun.all(function(e) return e == e end, fann_data) then - local str = rspamd_util.zstd_compress(table.concat(fann_data, ';')) + if fun.all(function(e) return e == e end, ann_data) then + local str = rspamd_util.zstd_compress(table.concat(ann_data, ';')) vec_len = #str - rspamd_redis.redis_make_request(task, + lua_redis.redis_make_request(task, rule.redis, nil, true, -- is write @@ -479,7 +481,7 @@ local function fann_train_callback(rule, task, score, required_score, id) ) else rspamd_logger.errx(task, "do not store learn vector as it contains %s NaN values", - fun.length(fun.filter(function(e) return e ~= e end, fann_data))) + fun.length(fun.filter(function(e) return e ~= e end, ann_data))) end else @@ -492,18 +494,18 @@ local function fann_train_callback(rule, task, score, required_score, id) end end - rspamd_redis.exec_redis_script(redis_can_train_id, + lua_redis.exec_redis_script(redis_can_train_id, {task = task, is_write = true}, can_train_cb, - {gen_fann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)}) + {gen_ann_prefix(rule, nil), suffix, k, tostring(train_opts.max_trains)}) end end -local function train_fann(rule, _, ev_base, elt, worker) +local function train_ann(rule, _, ev_base, elt, worker) local spam_elts = {} local ham_elts = {} elt = tostring(elt) - local prefix = gen_fann_prefix(rule, elt) + local prefix = gen_ann_prefix(rule, elt) local function redis_unlock_cb(err) if err then @@ -516,7 +518,7 @@ local function train_fann(rule, _, ev_base, elt, worker) if err then rspamd_logger.errx(rspamd_config, 'cannot save ANN %s to redis: %s', prefix, err) - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -535,7 +537,7 @@ local function train_fann(rule, _, ev_base, elt, worker) if errcode ~= 0 then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', prefix, errmsg) - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -550,16 +552,16 @@ local function train_fann(rule, _, ev_base, elt, worker) local ann_data if use_torch then local f = torch.MemoryFile() - f:writeObject(fanns[elt].fann_train) + f:writeObject(anns[elt].ann_train) ann_data = rspamd_util.zstd_compress(f:storage():string()) else - ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data()) + ann_data = rspamd_util.zstd_compress(anns[elt].ann_train:data()) end - fanns[elt].version = fanns[elt].version + 1 - fanns[elt].fann = fanns[elt].fann_train - fanns[elt].fann_train = nil - rspamd_redis.exec_redis_script(redis_save_unlock_id, + anns[elt].version = anns[elt].version + 1 + anns[elt].ann = anns[elt].ann_train + anns[elt].ann_train = nil + lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, {prefix, tostring(ann_data), tostring(rule.ann_expire)}) @@ -571,7 +573,7 @@ local function train_fann(rule, _, ev_base, elt, worker) if err then rspamd_logger.errx(rspamd_config, 'cannot train ANN %s: %s', prefix, err) - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -586,12 +588,12 @@ local function train_fann(rule, _, ev_base, elt, worker) local ann_data local f = torch.MemoryFile(torch.CharStorage():string(tostring(data))) ann_data = rspamd_util.zstd_compress(f:storage():string()) - fanns[elt].fann_train = f:readObject() + anns[elt].ann_train = f:readObject() - fanns[elt].version = fanns[elt].version + 1 - fanns[elt].fann = fanns[elt].fann_train - fanns[elt].fann_train = nil - rspamd_redis.exec_redis_script(redis_save_unlock_id, + anns[elt].version = anns[elt].version + 1 + anns[elt].ann = anns[elt].ann_train + anns[elt].ann_train = nil + lua_redis.exec_redis_script(redis_save_unlock_id, {ev_base = ev_base, is_write = true}, redis_save_cb, {prefix, tostring(ann_data), tostring(rule.ann_expire)}) @@ -602,7 +604,7 @@ local function train_fann(rule, _, ev_base, elt, worker) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', prefix, err) - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -625,10 +627,10 @@ local function train_fann(rule, _, ev_base, elt, worker) return #elts == n end - -- Now we can train fann - if not fanns[elt] or not fanns[elt].fann_train then - -- Create fann if it does not exist - create_train_fann(rule, n, elt) + -- Now we can train ann + if not anns[elt] or not anns[elt].ann_train then + -- Create ann if it does not exist + create_train_ann(rule, n, elt) end if #spam_elts + #ham_elts < rule.train.max_trains / 2 then @@ -638,12 +640,12 @@ local function train_fann(rule, _, ev_base, elt, worker) rspamd_logger.errx(rspamd_config, 'cannot invalidate ANN %s from redis: %s', prefix, _err) elseif type(_data) == 'string' then rspamd_logger.infox(rspamd_config, 'invalidated ANN %s from redis: %s', prefix, _err) - fanns[elt].version = 0 + anns[elt].version = 0 end end -- Invalidate ANN rspamd_logger.infox(rspamd_config, 'invalidate ANN %s: training data is invalid', prefix) - rspamd_redis.exec_redis_script(redis_locked_invalidate_id, + lua_redis.exec_redis_script(redis_locked_invalidate_id, {ev_base = ev_base, is_write = true}, redis_invalidate_cb, {prefix}) @@ -665,7 +667,7 @@ local function train_fann(rule, _, ev_base, elt, worker) torch.setnumthreads(rule.train.learn_threads) end local criterion = nn.MSECriterion() - local trainer = nn.StochasticGradient(fanns[elt].fann_train, + local trainer = nn.StochasticGradient(anns[elt].ann_train, criterion) trainer.learning_rate = 0.01 trainer.verbose = false @@ -677,7 +679,7 @@ local function train_fann(rule, _, ev_base, elt, worker) trainer:train(dataset) local out = torch.MemoryFile() - out:writeObject(fanns[elt].fann_train) + out:writeObject(anns[elt].ann_train) local st = out:storage():string() return st end @@ -698,7 +700,7 @@ local function train_fann(rule, _, ev_base, elt, worker) end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts))) rule.learning_spawned = true rspamd_logger.infox(rspamd_config, 'start learning ANN %s', prefix) - fanns[elt].fann_train:train_threaded(inputs, outputs, ann_trained, + anns[elt].ann_train:train_threaded(inputs, outputs, ann_trained, ev_base, { max_epochs = rule.train.max_epoch, desired_mse = rule.train.mse @@ -713,7 +715,7 @@ local function train_fann(rule, _, ev_base, elt, worker) if err or type(data) ~= 'table' then rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', prefix, err) - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -728,7 +730,7 @@ local function train_fann(rule, _, ev_base, elt, worker) local _,str = rspamd_util.zstd_decompress(tok) return fun.totable(fun.map(tonumber, rspamd_str_split(tostring(str), ';'))) end, data)) - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -746,7 +748,7 @@ local function train_fann(rule, _, ev_base, elt, worker) prefix, err) elseif type(data) == 'number' then -- Can train ANN - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -768,7 +770,7 @@ local function train_fann(rule, _, ev_base, elt, worker) end end if rule.learning_spawned then - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -793,20 +795,20 @@ local function train_fann(rule, _, ev_base, elt, worker) rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', prefix) return end - rspamd_redis.exec_redis_script(redis_maybe_lock_id, + lua_redis.exec_redis_script(redis_maybe_lock_id, {ev_base = ev_base, is_write = true}, redis_lock_cb, {prefix, tostring(os.time()), tostring(rule.lock_expire), rspamd_util.get_hostname()}) end -local function maybe_train_fanns(rule, cfg, ev_base, worker) +local function maybe_train_anns(rule, cfg, ev_base, worker) local function members_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) elseif type(data) == 'table' then fun.each(function(elt) elt = tostring(elt) - local prefix = gen_fann_prefix(rule, elt) + local prefix = gen_ann_prefix(rule, elt) rspamd_logger.infox(cfg, "check ANN %s", prefix) local redis_len_cb = function(_err, _data) if _err then @@ -817,7 +819,7 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) rspamd_logger.infox(rspamd_config, 'need to learn ANN %s after %s learn vectors (%s required)', prefix, tonumber(_data), rule.train.max_trains) - train_fann(rule, cfg, ev_base, elt, worker) + train_ann(rule, cfg, ev_base, elt, worker) else rspamd_logger.infox(rspamd_config, 'no need to learn ANN %s %s learn vectors (%s required)', @@ -826,7 +828,7 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) end end - rspamd_redis.redis_make_request_taskless(ev_base, + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, @@ -840,21 +842,21 @@ local function maybe_train_fanns(rule, cfg, ev_base, worker) end end - -- First we need to get all fanns stored in our Redis - rspamd_redis.redis_make_request_taskless(ev_base, + -- First we need to get all anns stored in our Redis + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, false, -- is write members_cb, --callback 'SMEMBERS', -- command - {gen_fann_prefix(rule, nil)} -- arguments + {gen_ann_prefix(rule, nil)} -- arguments ) return rule.watch_interval end -local function check_fanns(rule, _, ev_base) +local function check_anns(rule, _, ev_base) local function members_cb(err, data) if err then rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', @@ -867,7 +869,7 @@ local function check_fanns(rule, _, ev_base) rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, _err) elseif _data and type(_data) == 'table' then - load_or_invalidate_fann(rule, _data, elt, ev_base) + load_or_invalidate_ann(rule, _data, elt, ev_base) else if type(_data) ~= 'number' then rspamd_logger.errx(rspamd_config, 'invalid ANN type returned from Redis: %s; prefix: %s', @@ -877,29 +879,29 @@ local function check_fanns(rule, _, ev_base) end local local_ver = 0 - if fanns[elt] then - if fanns[elt].version then - local_ver = fanns[elt].version + if anns[elt] then + if anns[elt].version then + local_ver = anns[elt].version end end - rspamd_redis.exec_redis_script(redis_maybe_load_id, + lua_redis.exec_redis_script(redis_maybe_load_id, {ev_base = ev_base, is_write = false}, redis_update_cb, - {gen_fann_prefix(rule, elt), tostring(local_ver)}) + {gen_ann_prefix(rule, elt), tostring(local_ver)}) end, data) end end - -- First we need to get all fanns stored in our Redis - rspamd_redis.redis_make_request_taskless(ev_base, + -- First we need to get all anns stored in our Redis + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, rule.redis, nil, false, -- is write members_cb, --callback 'SMEMBERS', -- command - {gen_fann_prefix(rule, nil)} -- arguments + {gen_ann_prefix(rule, nil)} -- arguments ) return rule.watch_interval @@ -916,11 +918,15 @@ local function ann_push_vector(task) local r = task:get_principal_recipient() sid = sid .. r end - fann_train_callback(rule, task, scores[1], scores[2], sid) + ann_train_callback(rule, task, scores[1], scores[2], sid) end end -redis_params = rspamd_parse_redis_server('fann_redis') +redis_params = lua_redis.parse_redis_server('neural') + +if not redis_params then + redis_params = lua_redis.parse_redis_server('fann_redis') +end -- Initialization part if not (opts and type(opts) == 'table') or not redis_params then @@ -929,8 +935,8 @@ if not (opts and type(opts) == 'table') or not redis_params then return end -if not rspamd_fann.is_enabled() then - rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. +if not rspamd_fann.is_enabled() and not use_torch then + rspamd_logger.errx(rspamd_config, 'neural networks support is not compiled in rspamd, this ' .. 'module is eventually disabled') lua_util.disable_module(N, "fail") return @@ -951,7 +957,7 @@ else name = 'FANN_CHECK', type = 'postfilter,nostat', priority = 6, - callback = fann_scores_filter + callback = ann_scores_filter }) local function deepcopy(orig) @@ -1002,7 +1008,7 @@ else name = def_rules.symbol_spam, score = 3.0, description = 'Neural network SPAM', - group = 'fann' + group = 'neural' }) rspamd_config:register_symbol({ name = def_rules.symbol_spam, @@ -1014,7 +1020,7 @@ else name = def_rules.symbol_ham, score = -2.0, description = 'Neural network HAM', - group = 'fann' + group = 'neural' }) rspamd_config:register_symbol({ name = def_rules.symbol_ham, @@ -1034,13 +1040,13 @@ else for _,rule in pairs(settings.rules) do load_scripts(rule.redis) rspamd_config:add_on_load(function(cfg, ev_base, worker) - check_fanns(rule, cfg, ev_base) + check_anns(rule, cfg, ev_base) if worker:is_primary_controller() then -- We also want to train neural nets when they have enough data rspamd_config:add_periodic(ev_base, 0.0, function(_, _) - return maybe_train_fanns(rule, cfg, ev_base, worker) + return maybe_train_anns(rule, cfg, ev_base, worker) end) end end) -- 2.39.5