From ef9839e676596c71b40c3c0dcf77a63111d8ea49 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Mon, 10 Oct 2016 20:36:48 +0100 Subject: [PATCH] [Minor] Multiple fixes to neural net classifier --- src/plugins/lua/fann_scores.lua | 62 ++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 8123a92bb..c1c3d80c0 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -597,9 +597,20 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) current_classify_ann.loaded = true current_classify_ann.version = version current_classify_ann.ann = ann - current_classify_ann.spam_learned = tonumber(data[3]) - current_classify_ann.ham_learned = tonumber(data[4]) - rspamd_logger.infox(task, "loaded fann classifier version %s", version) + if type(data[3]) == 'string' then + current_classify_ann.spam_learned = tonumber(data[3]) + else + current_classify_ann.spam_learned = 0 + end + if type(data[4]) == 'string' then + current_classify_ann.ham_learned = tonumber(data[4]) + else + current_classify_ann.ham_learned = 0 + end + rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE", + version, current_classify_ann.spam_learned, + current_classify_ann.ham_learned, + ann:get_mse()) continue_cb(task, true) elseif call_if_fail then continue_cb(task, false) @@ -625,7 +636,7 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) if not err and type(data) == 'string' then local version = tonumber(data) - if version == current_classify_ann.version then + if version <= current_classify_ann.version then continue_cb(task, true) else load_fann() @@ -652,8 +663,9 @@ local function maybe_load_fann(task, continue_cb, call_if_fail) end local function tokens_to_vector(tokens) - local vec = map(function(tok) return tok[1] end, tokens) + local vec = totable(map(function(tok) return tok[1] end, tokens)) local ret = {} + local ntok = #vec local neurons = classifier_config.neurons for i = 1,neurons do ret[i] = 0 @@ -662,9 +674,15 @@ local function tokens_to_vector(tokens) local n = (e % neurons) + 1 ret[n] = ret[n] + 1 end, vec) + local norm = 0 + for i = 1,neurons do + if ret[i] > norm then + norm = ret[i] + end + end for i = 1,neurons do - if ret[i] ~= 0 then - ret[i] = 1.0 / ret[i] + if ret[i] ~= 0 and norm > 0 then + ret[i] = ret[i] / norm end end @@ -713,7 +731,7 @@ local function save_fann(task, is_spam) else current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1 end - local ret,_,_ = rspamd_redis_make_request(task, + local ret,conn,_ = rspamd_redis_make_request(task, redis_params, -- connect params key, -- hash key true, -- is write @@ -721,12 +739,19 @@ local function save_fann(task, is_spam) 'HMSET', -- command { key, - 'version', tostring(current_classify_ann.version), 'data', rspamd_util.zstd_compress(data), - 'spam', tostring(current_classify_ann.spam_learned), - 'ham', tostring(current_classify_ann.ham_learned), - } -- arguments - ) + }) -- arguments + + if conn then + conn:add_cmd('HINCRBY', {key, 'version', 1}) + if is_spam then + conn:add_cmd('HINCRBY', {key, 'spam', 1}) + rspamd_logger.errx(task, 'hui') + else + conn:add_cmd('HINCRBY', {key, 'ham', 1}) + rspamd_logger.errx(task, 'pezda') + end + end end if redis_params then @@ -754,7 +779,7 @@ if redis_params then local vec = tokens_to_vector(tokens) add_metatokens(task, vec) local out = current_classify_ann.ann:test(vec) - local result = rspamd_util.tanh(2 * (out[1] - 0.5)) + local result = rspamd_util.tanh(2 * (out[1])) local symscore = string.format('%.3f', out[1]) rspamd_logger.infox(task, 'fann classifier score: %s', symscore) @@ -786,12 +811,17 @@ if redis_params then end local vec = tokens_to_vector(tokens) add_metatokens(task, vec) - rspamd_logger.infox(task, "vector: %s", vec) + if is_spam then current_classify_ann.ann:train(vec, {1.0}) + rspamd_logger.infox(task, "learned ANN spam, MSE: %s", + current_classify_ann.ann:get_mse()) else - current_classify_ann.ann:train(vec, {0.0}) + current_classify_ann.ann:train(vec, {-1.0}) + rspamd_logger.infox(task, "learned ANN ham, MSE: %s", + current_classify_ann.ann:get_mse()) end + save_fann(task, is_spam) end maybe_load_fann(task, learn_cb, true) -- 2.39.5