diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-11-03 16:30:48 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2016-11-04 14:13:07 +0000 |
commit | 854f8bd4291a5402408caaba20cbb5064521a0b2 (patch) | |
tree | fd0db44ec575d64081640e49ba2550047e147467 /src | |
parent | cd0f7a9d7c311f215760c5ae13e4e0be06fed433 (diff) | |
download | rspamd-854f8bd4291a5402408caaba20cbb5064521a0b2.tar.gz rspamd-854f8bd4291a5402408caaba20cbb5064521a0b2.zip |
[Feature] Move fann_classifier to a separate plugin
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/fann_classifier.lua | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/src/plugins/lua/fann_classifier.lua b/src/plugins/lua/fann_classifier.lua new file mode 100644 index 000000000..af7acece8 --- /dev/null +++ b/src/plugins/lua/fann_classifier.lua @@ -0,0 +1,289 @@ +--[[ +Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- This plugin is a concept of FANN scores adjustment +-- NOT FOR PRODUCTION USE so far +local rspamd_logger = require "rspamd_logger" +local rspamd_fann = require "rspamd_fann" +local rspamd_util = require "rspamd_util" +require "fun" () +local ucl = require "ucl" + +local redis_params +local classifier_config = { + key = 'neural_net', + neurons = 200, + layers = 3, +} + +local current_classify_ann = { + loaded = false, + version = 0, + spam_learned = 0, + ham_learned = 0 +} + +redis_params = rspamd_parse_redis_server('fann_classifier') + +local function maybe_load_fann(task, continue_cb, call_if_fail) + local function load_fann() + local function redis_fann_load_cb(err, data) + if not err and type(data) == 'table' and type(data[2]) == 'string' then + local version = tonumber(data[1]) + local err,ann_data = rspamd_util.zstd_decompress(data[2]) + local ann + + if err or not ann_data then + rspamd_logger.errx(task, 'cannot decompress ann: %s', err) + else + ann = rspamd_fann.load_data(ann_data) + end + + if ann then + current_classify_ann.loaded = true + current_classify_ann.version = version + current_classify_ann.ann = ann + if type(data[3]) == 'string' then + current_classify_ann.spam_learned = tonumber(data[3]) + else + current_classify_ann.spam_learned = 0 + end + if type(data[4]) == 'string' then + current_classify_ann.ham_learned = tonumber(data[4]) + else + current_classify_ann.ham_learned = 0 + end + rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE", + version, current_classify_ann.spam_learned, + current_classify_ann.ham_learned, + ann:get_mse()) + continue_cb(task, true) + elseif call_if_fail then + continue_cb(task, false) + end + elseif call_if_fail then + continue_cb(task, false) + end + end + + local key = classifier_config.key + local ret,_,_ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_fann_load_cb, --callback + 'HMGET', -- command + {key, 'version', 'data', 'spam', 'ham'} -- arguments + ) + end + + local function check_fann() + local function redis_fann_check_cb(err, data) + if not err and type(data) == 'string' then + local version = tonumber(data) + + if version <= current_classify_ann.version then + continue_cb(task, true) + else + load_fann() + end + end + end + + local key = classifier_config.key + local ret,_,_ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + false, -- is write + redis_fann_check_cb, --callback + 'HGET', -- command + {key, 'version'} -- arguments + ) + end + + if not current_classify_ann.loaded then + load_fann() + else + check_fann() + end +end + +local function tokens_to_vector(tokens) + local vec = totable(map(function(tok) return tok[1] end, tokens)) + local ret = {} + local ntok = #vec + local neurons = classifier_config.neurons + for i = 1,neurons do + ret[i] = 0 + end + each(function(e) + local n = (e % neurons) + 1 + ret[n] = ret[n] + 1 + end, vec) + local norm = 0 + for i = 1,neurons do + if ret[i] > norm then + norm = ret[i] + end + end + for i = 1,neurons do + if ret[i] ~= 0 and norm > 0 then + ret[i] = ret[i] / norm + end + end + + return ret +end + +local function add_metatokens(task, vec) + local mt = gen_metatokens(task) + for _,tok in ipairs(mt) do + table.insert(vec, tok) + end +end + +local function create_fann() + local layers = {} + local mt_size = count_metatokens() + local neurons = classifier_config.neurons + mt_size + + for i = 1,classifier_config.layers - 1 do + layers[i] = math.floor(neurons / i) + end + + table.insert(layers, 1) + + local ann = rspamd_fann.create(classifier_config.layers, layers) + current_classify_ann.loaded = true + current_classify_ann.version = 0 + current_classify_ann.ann = ann + current_classify_ann.spam_learned = 0 + current_classify_ann.ham_learned = 0 +end + +local function save_fann(task, is_spam) + local function redis_fann_save_cb(err, data) + if err then + rspamd_logger.errx(task, "cannot save neural net to redis: %s", err) + end + end + + local data = current_classify_ann.ann:data() + local key = classifier_config.key + current_classify_ann.version = current_classify_ann.version + 1 + + if is_spam then + current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1 + else + current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1 + end + local ret,conn,_ = rspamd_redis_make_request(task, + redis_params, -- connect params + key, -- hash key + true, -- is write + redis_fann_save_cb, --callback + 'HMSET', -- command + { + key, + 'data', rspamd_util.zstd_compress(data), + }) -- arguments + + if conn then + conn:add_cmd('HINCRBY', {key, 'version', 1}) + if is_spam then + conn:add_cmd('HINCRBY', {key, 'spam', 1}) + rspamd_logger.errx(task, 'hui') + else + conn:add_cmd('HINCRBY', {key, 'ham', 1}) + rspamd_logger.errx(task, 'pezda') + end + end +end + +if redis_params then + rspamd_classifiers['neural'] = { + classify = function(task, classifier, tokens) + local function classify_cb(task) + local min_learns = classifier:get_param('min_learns') + + if min_learns then + min_learns = tonumber(min_learns) + end + + if min_learns and min_learns > 0 then + if current_classify_ann.ham_learned < min_learns or + current_classify_ann.spam_learned < min_learns then + + rspamd_logger.infox(task, 'fann classifier has not enough learns: (%s spam, %s ham), %s required', + current_classify_ann.spam_learned, current_classify_ann.ham_learned, + min_learns) + return + end + end + + -- Perform classification + local vec = tokens_to_vector(tokens) + add_metatokens(task, vec) + local out = current_classify_ann.ann:test(vec) + local result = rspamd_util.tanh(2 * (out[1])) + local symscore = string.format('%.3f', out[1]) + rspamd_logger.infox(task, 'fann classifier score: %s', symscore) + + if result > 0 then + each(function(st) + task:insert_result(st:get_symbol(), result, symscore) + end, + filter(function(st) + return st:is_spam() + end, classifier:get_statfiles()) + ) + else + each(function(st) + task:insert_result(st:get_symbol(), -result, symscore) + end, + filter(function(st) + return not st:is_spam() + end, classifier:get_statfiles()) + ) + end + end + maybe_load_fann(task, classify_cb, false) + end, + + learn = function(task, classifier, tokens, is_spam, is_unlearn) + local function learn_cb(task, is_loaded) + if not is_loaded then + create_fann() + end + local vec = tokens_to_vector(tokens) + add_metatokens(task, vec) + + if is_spam then + current_classify_ann.ann:train(vec, {1.0}) + rspamd_logger.infox(task, "learned ANN spam, MSE: %s", + current_classify_ann.ann:get_mse()) + else + current_classify_ann.ann:train(vec, {-1.0}) + rspamd_logger.infox(task, "learned ANN ham, MSE: %s", + current_classify_ann.ann:get_mse()) + end + + save_fann(task, is_spam) + end + maybe_load_fann(task, learn_cb, true) + end, + } +end
\ No newline at end of file |