From b35772a2873d610c563d208073ea4b29135d24b7 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sat, 8 Oct 2016 16:35:23 +0100 Subject: [PATCH] [Feature] Add neural net classifier to fann_scores module --- src/plugins/lua/fann_scores.lua | 216 +++++++++++++++++++++++++++++++- 1 file changed, 215 insertions(+), 1 deletion(-) diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 9ddb79fc3..c67eb597d 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -498,9 +498,11 @@ end if not rspamd_fann.is_enabled() then rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' .. 'module is eventually disabled') + + return else if not opts['fann_file'] then - rspamd_logger.errx(rspamd_config, 'fann_scores module requires ' .. + rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' .. '`fann_file` to be specified') else fann_file = opts['fann_file'] @@ -560,3 +562,215 @@ else end end end + +local redis_params +local classifier_config = { + key = 'neural_net', + neurons = 200, + layers = 3, +} + +local current_classify_ann = { + loaded = false, + version = 0, + spam_learned = 0, + ham_learned = 0 +} + +redis_params = rspamd_parse_redis_server('fann_scores') + +local function maybe_load_fann(task, continue_cb, call_if_fail) + local function load_fann() + local function redis_fann_load_cb(task, err, data) + if not err and type(data) == 'table' and type(data[2]) == 'string' then + local version = tonumber(data[1]) + local ann_data = data[2] + local ann = rspamd_fann.load_data(ann_data) + + if ann then + current_classify_ann.loaded = true + current_classify_ann.version = version + current_classify_ann.ann = ann + current_classify_ann.spam_learned = tonumber(data[3]) + current_classify_ann.ham_learned = tonumber(data[4]) + rspamd_logger.infox(task, "loaded fann classifier version %s", version) + continue_cb(task, true) + elseif call_if_fail then + continue_cb(task, false) + end + elseif call_if_fail then + continue_cb(task, false) + end + end + + local key = classifier_config.key + local ret,_,_ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_fann_load_cb, --callback + 'HMGET', -- command + {key, 'version', 'data', 'spam', 'ham'} -- arguments + ) + end + + local function check_fann() + local function redis_fann_check_cb(task, err, data) + if not err and type(data) == 'string' then + local version = tonumber(data) + + if version == current_classify_ann.version then + continue_cb(task, true) + else + load_fann() + end + end + end + + local key = classifier_config.key + local ret,_,_ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_fann_check_cb, --callback + 'HGET', -- command + {key, 'version'} -- arguments + ) + end + + if not current_classify_ann.loaded then + load_fann() + else + check_fann() + end +end + +local function tokens_to_vector(tokens) + local vec = map(function(tok) return tok[1] end, tokens) + local ret = {} + local neurons = classifier_config.neurons + for i = 1,neurons do + ret[i] = 0 + end + each(function(e) + local n = (e % neurons) + 1 + ret[n] = ret[n] + 1 + end, vec) + for i = 1,neurons do + if ret[i] ~= 0 then + ret[i] = 1.0 / ret[i] + end + end + + return ret +end + +local function add_metatokens(task, vec) + local mt = gen_metatokens(task) + for _,tok in ipairs(mt) do + table.insert(vec, tok) + end +end + +local function create_fann() + local layers = {} + local mt_size = count_metatokens() + local neurons = classifier_config.neurons + mt_size + + for i = 1,classifier_config.layers - 1 do + layers[i] = math.floor(neurons / i) + end + + table.insert(layers, 1) + + local ann = rspamd_fann.create(classifier_config.layers, layers) + current_classify_ann.loaded = true + current_classify_ann.version = 0 + current_classify_ann.ann = ann + current_classify_ann.spam_learned = 0 + current_classify_ann.ham_learned = 0 +end + +local function save_fann(task, is_spam) + local function redis_fann_save_cb(task, err, data) + if err then + rspamd_logger.errx(task, "cannot save neural net to redis: %s", err) + end + end + + local data = current_classify_ann.ann:data() + local key = classifier_config.key + current_classify_ann.version = current_classify_ann.version + 1 + + if is_spam then + current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1 + else + current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1 + end + local ret,_,_ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_fann_save_cb, --callback + 'HMSET', -- command + { + key, + 'version', tostring(current_classify_ann.version), + 'data', tostring(data), + 'spam', tostring(current_classify_ann.spam_learned), + 'ham', tostring(current_classify_ann.ham_learned), + } -- arguments + ) +end + +if redis_params then + rspamd_classifiers['neural'] = { + classify = function(task, classifier, tokens) + local function classify_cb(task) + local vec = tokens_to_vector(tokens) + add_metatokens(task, vec) + local out = current_classify_ann.ann:test(vec) + local result = rspamd_util.tanh(2 * (out[1] - 0.5)) + local symscore = string.format('%.3f', out[1]) + rspamd_logger.infox(task, 'fann classifier score: %s', symscore) + + if result > 0 then + each(function(st) + task:insert_result(st:get_symbol(), result, symscore) + end, + filter(function(st) + return st:is_spam() + end, classifier:get_statfiles()) + ) + else + each(function(st) + task:insert_result(st:get_symbol(), -result, symscore) + end, + filter(function(st) + return not st:is_spam() + end, classifier:get_statfiles()) + ) + end + end + maybe_load_fann(task, classify_cb, false) + end, + + learn = function(task, classifier, tokens, is_spam, is_unlearn) + local function learn_cb(task, is_loaded) + if not is_loaded then + create_fann() + end + local vec = tokens_to_vector(tokens) + add_metatokens(task, vec) + rspamd_logger.infox(task, "vector: %s", vec) + if is_spam then + current_classify_ann.ann:train(vec, {1.0}) + else + current_classify_ann.ann:train(vec, {0.0}) + end + save_fann(task, is_spam) + end + maybe_load_fann(task, learn_cb, true) + end, + } +end -- 2.39.5