diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-08 11:10:11 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-03-08 11:10:11 +0000 |
commit | c8cb5343e03defb62830647d5646b2b66614f533 (patch) | |
tree | eff3ce6dc935d9373f6a6228ca1b57d2040953b7 | |
parent | 01094d437bd253101cd8e1cc794b7f3658079317 (diff) | |
download | rspamd-c8cb5343e03defb62830647d5646b2b66614f533.tar.gz rspamd-c8cb5343e03defb62830647d5646b2b66614f533.zip |
[Rework] Rename fann_redis to neural plugin
-rw-r--r-- | conf/modules.d/neural.conf (renamed from conf/modules.d/fann_redis.conf) | 0 | ||||
-rw-r--r-- | src/plugins/lua/fann_classifier.lua | 308 | ||||
-rw-r--r-- | src/plugins/lua/fann_scores.lua | 378 | ||||
-rw-r--r-- | src/plugins/neural.lua (renamed from src/plugins/lua/fann_redis.lua) | 0 |
4 files changed, 0 insertions, 686 deletions
diff --git a/conf/modules.d/fann_redis.conf b/conf/modules.d/neural.conf index f12224d48..f12224d48 100644 --- a/conf/modules.d/fann_redis.conf +++ b/conf/modules.d/neural.conf diff --git a/src/plugins/lua/fann_classifier.lua b/src/plugins/lua/fann_classifier.lua deleted file mode 100644 index f1e93ca2a..000000000 --- a/src/plugins/lua/fann_classifier.lua +++ /dev/null @@ -1,308 +0,0 @@ ---[[ -Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru> - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -]]-- - -if confighelp then - return -end - --- This plugin is a concept of FANN scores adjustment --- NOT FOR PRODUCTION USE so far -local rspamd_logger = require "rspamd_logger" -local rspamd_fann = require "rspamd_fann" -local rspamd_util = require "rspamd_util" -local lua_util = require "lua_util" -local fun = require "fun" -local N = 'fann_classifier' - -local redis_params -local classifier_config = { - key = 'neural_net', - neurons = 200, - layers = 3, -} - -local current_classify_ann = { - loaded = false, - version = 0, - spam_learned = 0, - ham_learned = 0 -} - -if not lua_util.check_experimental(N) then - return -end - -redis_params = rspamd_parse_redis_server('fann_classifier') - -local function maybe_load_fann(task, continue_cb, call_if_fail) - local function load_fann() - local function redis_fann_load_cb(err, data) - -- XXX: upstreams - if not err and type(data) == 'table' and type(data[2]) == 'string' then - local version = tonumber(data[1]) - local _err,ann_data = rspamd_util.zstd_decompress(data[2]) - local ann - - if _err or not ann_data then - rspamd_logger.errx(task, 'cannot decompress ann: %s', _err) - else - ann = rspamd_fann.load_data(ann_data) - end - - if ann then - current_classify_ann.loaded = true - current_classify_ann.version = version - current_classify_ann.ann = ann - if type(data[3]) == 'string' then - current_classify_ann.spam_learned = tonumber(data[3]) - else - current_classify_ann.spam_learned = 0 - end - if type(data[4]) == 'string' then - current_classify_ann.ham_learned = tonumber(data[4]) - else - current_classify_ann.ham_learned = 0 - end - rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE", - version, current_classify_ann.spam_learned, - current_classify_ann.ham_learned, - ann:get_mse()) - continue_cb(task, true) - elseif call_if_fail then - continue_cb(task, false) - end - elseif call_if_fail then - continue_cb(task, false) - end - end - - local key = classifier_config.key - local ret = rspamd_redis_make_request(task, - redis_params, -- connect params - key, -- hash key - false, -- is write - redis_fann_load_cb, --callback - 'HMGET', -- command - {key, 'version', 'data', 'spam', 'ham'} -- arguments - ) - if not ret then - rspamd_logger.errx(task, 'got error connecting to redis') - end - end - - local function check_fann() - local _, ret, upstream - local function redis_fann_check_cb(err, data) - if err then - rspamd_logger.errx(task, 'redis error on host %s: %s', upstream:get_addr(), err) - end - if not err and type(data) == 'string' then - local version = tonumber(data) - - if version <= current_classify_ann.version then - continue_cb(task, true) - else - load_fann() - end - end - end - - local key = classifier_config.key - ret,_,upstream = rspamd_redis_make_request(task, - redis_params, -- connect params - key, -- hash key - false, -- is write - redis_fann_check_cb, --callback - 'HGET', -- command - {key, 'version'} -- arguments - ) - if not ret then - rspamd_logger.errx(task, 'got error connecting to redis') - end - end - - if not current_classify_ann.loaded then - load_fann() - else - check_fann() - end -end - -local function tokens_to_vector(tokens) - local vec = fun.totable(fun.map(function(tok) return tok[1] end, tokens)) - local ret = {} - local neurons = classifier_config.neurons - for i = 1,neurons do - ret[i] = 0 - end - fun.each(function(e) - local n = (e % neurons) + 1 - ret[n] = ret[n] + 1 - end, vec) - local norm = 0 - for i = 1,neurons do - if ret[i] > norm then - norm = ret[i] - end - end - for i = 1,neurons do - if ret[i] ~= 0 and norm > 0 then - ret[i] = ret[i] / norm - end - end - - return ret -end - -local function add_metatokens(task, vec) - local mt = rspamd_gen_metatokens(task) - for _,tok in ipairs(mt) do - table.insert(vec, tok) - end -end - -local function create_fann() - local layers = {} - local mt_size = rspamd_count_metatokens() - local neurons = classifier_config.neurons + mt_size - - for i = 1,classifier_config.layers - 1 do - layers[i] = math.floor(neurons / i) - end - - table.insert(layers, 1) - - local ann = rspamd_fann.create(classifier_config.layers, layers) - current_classify_ann.loaded = true - current_classify_ann.version = 0 - current_classify_ann.ann = ann - current_classify_ann.spam_learned = 0 - current_classify_ann.ham_learned = 0 -end - -local function save_fann(task, is_spam) - local function redis_fann_save_cb(err) - if err then - rspamd_logger.errx(task, "cannot save neural net to redis: %s", err) - end - end - - local data = current_classify_ann.ann:data() - local key = classifier_config.key - current_classify_ann.version = current_classify_ann.version + 1 - - if is_spam then - current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1 - else - current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1 - end - local ret,conn = rspamd_redis_make_request(task, - redis_params, -- connect params - key, -- hash key - true, -- is write - redis_fann_save_cb, --callback - 'HMSET', -- command - { - key, - 'data', rspamd_util.zstd_compress(data), - }) -- arguments - - if ret then - conn:add_cmd('HINCRBY', {key, 'version', 1}) - if is_spam then - conn:add_cmd('HINCRBY', {key, 'spam', 1}) - else - conn:add_cmd('HINCRBY', {key, 'ham', 1}) - end - else - rspamd_logger.errx(task, 'got error connecting to redis') - end -end - -if redis_params then - rspamd_classifiers['neural'] = { - classify = function(task, classifier, tokens) - local function classify_cb() - local min_learns = classifier:get_param('min_learns') - - if min_learns then - min_learns = tonumber(min_learns) - end - - if min_learns and min_learns > 0 then - if current_classify_ann.ham_learned < min_learns or - current_classify_ann.spam_learned < min_learns then - - rspamd_logger.infox(task, 'fann classifier has not enough learns: (%s spam, %s ham), %s required', - current_classify_ann.spam_learned, current_classify_ann.ham_learned, - min_learns) - return - end - end - - -- Perform classification - local vec = tokens_to_vector(tokens) - add_metatokens(task, vec) - local out = current_classify_ann.ann:test(vec) - local result = rspamd_util.tanh(2 * (out[1])) - local symscore = string.format('%.3f', out[1]) - rspamd_logger.infox(task, 'fann classifier score: %s', symscore) - - if result > 0 then - fun.each(function(st) - task:insert_result(st:get_symbol(), result, symscore) - end, - fun.filter(function(st) - return st:is_spam() - end, classifier:get_statfiles()) - ) - else - fun.each(function(st) - task:insert_result(st:get_symbol(), -result, symscore) - end, - fun.filter(function(st) - return not st:is_spam() - end, classifier:get_statfiles()) - ) - end - end - maybe_load_fann(task, classify_cb, false) - end, - - learn = function(task, _, tokens, is_spam) - local function learn_cb(_, is_loaded) - if not is_loaded then - create_fann() - end - local vec = tokens_to_vector(tokens) - add_metatokens(task, vec) - - if is_spam then - current_classify_ann.ann:train(vec, {1.0}) - rspamd_logger.infox(task, "learned ANN spam, MSE: %s", - current_classify_ann.ann:get_mse()) - else - current_classify_ann.ann:train(vec, {-1.0}) - rspamd_logger.infox(task, "learned ANN ham, MSE: %s", - current_classify_ann.ann:get_mse()) - end - - save_fann(task, is_spam) - end - maybe_load_fann(task, learn_cb, true) - end, - } -end diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua deleted file mode 100644 index e6c7eade4..000000000 --- a/src/plugins/lua/fann_scores.lua +++ /dev/null @@ -1,378 +0,0 @@ ---[[ -Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru> - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -]]-- - --- This plugin is a concept of FANN scores adjustment --- NOT FOR PRODUCTION USE so far - -if confighelp then - return -end - -local rspamd_logger = require "rspamd_logger" -local rspamd_fann = require "rspamd_fann" -local rspamd_util = require "rspamd_util" -local fann_symbol_spam = 'FANN_SPAM' -local fann_symbol_ham = 'FANN_HAM' -local fun = require "fun" - -local module_log_id = 0x100 --- Module vars --- ANNs indexed by settings id -local data = { - ['0'] = { - fann_mtime = 0, - ntrains = 0, - epoch = 0, - } -} - -local fann_file -local max_trains = 1000 -local max_epoch = 100 -local use_settings = false - -local function symbols_to_fann_vector(syms, scores) - local learn_data = {} - local matched_symbols = {} - local n = rspamd_config:get_symbols_count() - - fun.each(function(s, score) - if score ~= score then score = 0.0 end -- nan sanity - matched_symbols[s + 1] = rspamd_util.tanh(score) - end, fun.zip(syms, scores)) - - for i=1,n do - if matched_symbols[i] then - learn_data[i] = matched_symbols[i] - else - learn_data[i] = 0 - end - end - - return learn_data -end - -local function gen_fann_file(id) - if use_settings then - return fann_file .. id - else - return fann_file - end -end - -local function load_fann(id) - local fname = gen_fann_file(id) - local err = rspamd_util.stat(fname) - - if err then - return false - end - - local fd = rspamd_util.lock_file(fname) - data[id].fann = rspamd_fann.load(fname) - rspamd_util.unlock_file(fd) -- closes fd - - if data[id].fann then - local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() - - if n ~= data[id].fann:get_inputs() then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache; removing', data[id].fann:get_inputs(), n) - data[id].fann = nil - - local ret,_err = rspamd_util.unlink(fname) - if not ret then - rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, _err) - end - else - local layers = data[id].fann:get_layers() - - if not layers or #layers ~= 5 then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing', - #layers) - data[id].fann = nil - local ret,_err = rspamd_util.unlink(fname) - if not ret then - rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, _err) - end - else - rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname) - return true - end - end - else - rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname) - local ret,_err = rspamd_util.unlink(fname) - if not ret then - rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s', - fname, _err) - end - end - - return false -end - -local function check_fann(id) - if data[id].fann then - local n = rspamd_config:get_symbols_count() + rspamd_count_metatokens() - - if n ~= data[id].fann:get_inputs() then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', data[id].fann:get_inputs(), n) - data[id].fann = nil - end - local layers = data[id].fann:get_layers() - - if not layers or #layers ~= 5 then - rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s', - #layers) - data[id].fann = nil - end - end - - local fname = gen_fann_file(id) - local err,st = rspamd_util.stat(fname) - - if not err then - local mtime = st['mtime'] - - if mtime > data[id].fann_mtime then - rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' .. - 'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname) - data[id].fann_mtime = mtime - data[id].fann = nil - end - end -end - -local function fann_scores_filter(task) - local id = '0' - if use_settings then - local sid = task:get_settings_id() - if sid then - id = tostring(sid) - end - end - - check_fann(id) - - if data[id].fann then - local symbols,scores = task:get_symbols_numeric() - local fann_data = symbols_to_fann_vector(symbols, scores) - local mt = rspamd_gen_metatokens(task) - - for _,tok in ipairs(mt) do - table.insert(fann_data, tok) - end - - local out = data[id].fann:test(fann_data) - local symscore = string.format('%.3f', out[1]) - rspamd_logger.infox(task, 'fann score: %s', symscore) - - if out[1] > 0 then - local result = rspamd_util.normalize_prob(out[1] / 2.0, 0) - task:insert_result(fann_symbol_spam, result, symscore, id) - else - local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0) - task:insert_result(fann_symbol_ham, result, symscore, id) - end - else - if load_fann(id) then - fann_scores_filter(task) - end - end -end - -local function create_train_fann(n, id) - data[id].fann_train = rspamd_fann.create(5, n, n, n / 2, n / 4, 1) - data[id].ntrains = 0 - data[id].epoch = 0 -end - -local function fann_train_callback(score, required_score, results, cf, id, opts, extra) - local n = cf:get_symbols_count() + rspamd_count_metatokens() - local fname = gen_fann_file(id) - - if not data[id].fann_train then - create_train_fann(n, id) - end - - if data[id].fann_train:get_inputs() ~= n then - rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' .. - ' is found in the cache', data[id].fann_train:get_inputs(), n) - create_train_fann(n, id) - end - - if data[id].ntrains > max_trains then - -- Store fann on disk - local res = false - - local err = rspamd_util.stat(fname) - local fd - if err then - fd,err = rspamd_util.create_file(fname) - if not fd then - rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err) - else - rspamd_util.lock_file(fname, fd) - res = data[id].fann_train:save(fname) - rspamd_util.unlock_file(fd) -- Closes fd as well - end - else - fd = rspamd_util.lock_file(fname) - res = data[id].fann_train:save(fname) - rspamd_util.unlock_file(fd) -- Closes fd as well - end - - if not res then - rspamd_logger.errx(cf, 'cannot save fann in %s', fname) - else - data[id].exist = true - data[id].ntrains = 0 - data[id].epoch = data[id].epoch + 1 - end - else - if not data[id].checked then - data[id].checked = true - local err = rspamd_util.stat(fname) - if err then - data[id].exist = false - end - end - if not data[id].exist then - rspamd_logger.infox(cf, 'not enough trains for fann %s, %s left', fname, - max_trains - data[id].ntrains) - end - end - - if data[id].epoch > max_epoch then - -- Re-create fann - rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname, - max_epoch) - create_train_fann(n, id) - end - - local learn_spam, learn_ham - if opts['spam_score'] then - learn_spam = score >= opts['spam_score'] - else - learn_spam = score >= required_score - end - if opts['ham_score'] then - learn_ham = score <= opts['ham_score'] - else - learn_ham = score < 0 - end - - if learn_spam or learn_ham then - local learn_data = symbols_to_fann_vector( - fun.map(function(r) return r[1] end, results), - fun.map(function(r) return r[2] end, results) - ) - -- Add filtered meta tokens - fun.each(function(e) table.insert(learn_data, e) end, extra) - - if learn_spam then - data[id].fann_train:train(learn_data, {1.0}) - else - data[id].fann_train:train(learn_data, {-1.0}) - end - - data[id].ntrains = data[id].ntrains + 1 - end -end - --- Initialization part - -local opts = rspamd_config:get_all_opt("fann_scores") -if not (opts and type(opts) == 'table') then - rspamd_logger.infox(rspamd_config, 'Module is unconfigured') - return -end - -if not rspamd_fann.is_enabled() then - rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. - 'module is eventually disabled') - - return -else - if not opts['fann_file'] then - rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' .. - '`fann_file` to be specified') - else - fann_file = opts['fann_file'] - use_settings = opts['use_settings'] - rspamd_config:set_metric_symbol({ - name = fann_symbol_spam, - score = 3.0, - description = 'Neural network SPAM', - group = 'fann' - }) - local id = rspamd_config:register_symbol({ - name = fann_symbol_spam, - type = 'postfilter', - priority = 5, - callback = fann_scores_filter - }) - rspamd_config:set_metric_symbol({ - name = fann_symbol_ham, - score = -2.0, - description = 'Neural network HAM', - group = 'fann' - }) - rspamd_config:register_symbol({ - name = fann_symbol_ham, - type = 'virtual', - parent = id - }) - if opts['train'] then - rspamd_config:add_on_load(function(cfg) - if opts['train']['max_train'] then - max_trains = opts['train']['max_train'] - end - if opts['train']['max_epoch'] then - max_epoch = opts['train']['max_epoch'] - end - local ret = cfg:register_worker_script("log_helper", - function(score, req_score, results, cf, _id, extra) - -- map (snd x) (filter (fst x == module_id) extra) - local extra_fann = fun.map(function(e) return e[2] end, - fun.filter(function(e) return e[1] == module_log_id end, extra)) - if use_settings then - fann_train_callback(score, req_score, results, cf, - tostring(_id), opts['train'], extra_fann) - else - fann_train_callback(score, req_score, results, cf, '0', - opts['train'], extra_fann) - end - end) - - if not ret then - rspamd_logger.errx(cfg, 'cannot find worker "log_helper"') - end - end) - rspamd_plugins["fann_score"] = { - log_callback = function(task) - return fun.totable(fun.map( - function(tok) return {module_log_id, tok} end, - rspamd_gen_metatokens(task))) - end - } - end - end -end diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/neural.lua index 117881b31..117881b31 100644 --- a/src/plugins/lua/fann_redis.lua +++ b/src/plugins/neural.lua |