aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/plugins/lua/fann_scores.lua768
1 files changed, 299 insertions, 469 deletions
diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua
index 0d9e00435..64566e102 100644
--- a/src/plugins/lua/fann_scores.lua
+++ b/src/plugins/lua/fann_scores.lua
@@ -1,5 +1,5 @@
--[[
-Copyright (c) 2015, Vsevolod Stakhov <vsevolod@highsecure.ru>
+Copyright (c) 2016, Vsevolod Stakhov <vsevolod@highsecure.ru>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,11 +36,111 @@ local data = {
}
}
-local fann_file
+
+-- Lua script (executed inside Redis) that decides whether one more training
+-- row of the given class may be stored
+-- Uses the following keys:
+-- key1 - prefix for keys
+-- key2 - max count of learns
+-- key3 - spam or ham
+-- returns 0 when learning is not allowed (integer 0 if the ANN is locked,
+-- string '0' otherwise); when learning is allowed returns, as a string, the
+-- number of rows the class will hold after this learn (always >= 1), so the
+-- caller can test `tonumber(reply) > 0`
+local redis_lua_script_can_train = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return 0 end
+  local nspam = 0
+  local nham = 0
+  -- KEYS are always strings; no limit is applied if key2 is not a number
+  local lim = tonumber(KEYS[2])
+
+  local ret = redis.call('LLEN', KEYS[1] .. '_spam')
+  if ret then nspam = tonumber(ret) end
+  ret = redis.call('LLEN', KEYS[1] .. '_ham')
+  if ret then nham = tonumber(ret) end
+
+  -- Keep classes balanced: a class may lead the other one by one row at most.
+  -- Return count + 1, as returning the bare count would yield '0' for an
+  -- empty list and the very first learn would never be allowed
+  if KEYS[3] == 'spam' then
+    if (not lim or nspam < lim) and nham + 1 >= nspam then
+      return tostring(nspam + 1)
+    end
+  else
+    if (not lim or nham < lim) and nspam + 1 >= nham then
+      return tostring(nham + 1)
+    end
+  end
+
+  return tostring(0)
+]]
+local redis_can_train_sha = nil
+
+-- Lua script (executed inside Redis) to load ANN from redis
+-- Uses the following keys
+-- key1 - prefix for keys
+-- key2 - local version
+-- returns nil (when locked or when the stored version is not newer than the
+-- local one) or bulk string with the content of `<key1>_ann` otherwise
+local redis_lua_script_maybe_load = [[
+  local locked = redis.call('GET', KEYS[1] .. '_locked')
+  if locked then return false end
+
+  local ver = 0
+  local ret = redis.call('GET', KEYS[1] .. '_version')
+  if ret then ver = tonumber(ret) end
+  -- KEYS are always strings: convert before comparing, as comparing a number
+  -- with a string raises an error in Lua
+  local local_ver = tonumber(KEYS[2]) or 0
+  if ver > local_ver then return redis.call('GET', KEYS[1] .. '_ann') end
+
+  return false
+]]
+local redis_fann_maybe_load_sha = nil
+
+-- Result of rspamd_parse_redis_server('fann_scores'); supplies the servers,
+-- timeout, password and db used when making redis requests
+local redis_params = rspamd_parse_redis_server('fann_scores')
+
+-- Base prefix for the redis keys of this module
+local fann_prefix = 'RF'
local max_trains = 1000
local max_epoch = 100
local use_settings = false
+-- Delay returned by check_fanns before its next run
+local watch_interval = 60.0
+
+-- Makes an asynchronous redis request that is not bound to a task
+-- ev_base - event base passed to rspamd_redis
+-- cfg - config object passed to rspamd_redis as `config`
+-- key - optional key; when set, the upstream is selected by its hash
+-- is_write - use `write_servers` instead of `read_servers`
+-- callback - function that receives the reply (callers in this file use
+--            the `(err, data)` form)
+-- command, args - redis command and the table of its arguments
+-- returns ret,conn,addr where ret and conn come from
+-- rspamd_redis.make_request; returns false,nil,nil if a mandatory argument
+-- is missing or no upstream can be selected
+local function redis_make_request(ev_base, cfg, key, is_write, callback, command, args)
+  if not ev_base or not redis_params or not callback or not command then
+    return false,nil,nil
+  end
+  local addr
+  local rspamd_redis = require "rspamd_redis"
+
+  if key then
+    if is_write then
+      addr = redis_params['write_servers']:get_upstream_by_hash(key)
+    else
+      addr = redis_params['read_servers']:get_upstream_by_hash(key)
+    end
+  else
+    -- No key to hash: use the keyless selection methods
+    if is_write then
+      addr = redis_params['write_servers']:get_upstream_master_slave()
+    else
+      addr = redis_params['read_servers']:get_upstream_round_robin()
+    end
+  end
+
+  if not addr then
+    -- There is no task in this context, so log against the config and stop:
+    -- going on would index a nil upstream below
+    rspamd_logger.errx(rspamd_config, 'cannot select server to make redis request')
+    return false,nil,nil
+  end
+
+  local options = {
+    ev_base = ev_base,
+    config = cfg,
+    callback = callback,
+    host = addr:get_addr(),
+    timeout = redis_params['timeout'],
+    cmd = command,
+    args = args
+  }
+
+  if redis_params['password'] then
+    options['password'] = redis_params['password']
+  end
+
+  if redis_params['db'] then
+    options['dbname'] = redis_params['db']
+  end
+
+  local ret,conn = rspamd_redis.make_request(options)
+  if not ret then
+    rspamd_logger.errx(rspamd_config, 'cannot execute redis request')
+  end
+  return ret,conn,addr
+end
-- Metafunctions
local function fann_size_function(task)
@@ -277,69 +377,15 @@ local function symbols_to_fann_vector(syms, scores)
return learn_data
end
-local function gen_fann_file(id)
+local function gen_fann_prefix(id)
if use_settings then
- return fann_file .. id
- else
- return fann_file
- end
-end
-
-local function load_fann(id)
- local fname = gen_fann_file(id)
- local err,st = rspamd_util.stat(fname)
-
- if err then
- return false
- end
-
- local fd = rspamd_util.lock_file(fname)
- data[id].fann = rspamd_fann.load(fname)
- rspamd_util.unlock_file(fd) -- closes fd
-
- if data[id].fann then
- local n = rspamd_config:get_symbols_count() + count_metatokens()
-
- if n ~= data[id].fann:get_inputs() then
- rspamd_logger.infox(rspamd_config, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache; removing', data[id].fann:get_inputs(), n)
- data[id].fann = nil
-
- local ret,err = rspamd_util.unlink(fname)
- if not ret then
- rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fname, err)
- end
- else
- local layers = data[id].fann:get_layers()
-
- if not layers or #layers ~= 5 then
- rspamd_logger.infox(rspamd_config, 'fann has incorrect number of layers: %s, removing',
- #layers)
- data[id].fann = nil
- local ret,err = rspamd_util.unlink(fname)
- if not ret then
- rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fname, err)
- end
- else
- rspamd_logger.infox(rspamd_config, 'loaded fann from %s', fname)
- return true
- end
- end
+ return fann_prefix .. id
else
- rspamd_logger.infox(rspamd_config, 'fann is invalid: "%s"; removing', fname)
- local ret,err = rspamd_util.unlink(fname)
- if not ret then
- rspamd_logger.errx(rspamd_config, 'cannot remove invalid fann from %s: %s',
- fname, err)
- end
+ return fann_prefix
end
-
- return false
end
-local function check_fann(id)
+local function is_fann_valid(id)
if data[id].fann then
local n = rspamd_config:get_symbols_count() + count_metatokens()
@@ -357,19 +403,11 @@ local function check_fann(id)
end
end
- local fname = gen_fann_file(id)
- local err,st = rspamd_util.stat(fname)
-
- if not err then
- local mtime = st['mtime']
-
- if mtime > data[id].fann_mtime then
- rspamd_logger.infox(rspamd_config, 'have more fresh version of fann ' ..
- 'file: %s -> %s, need to reload %s', data[id].fann_mtime, mtime, fname)
- data[id].fann_mtime = mtime
- data[id].fann = nil
- end
+ if data[id].fann then
+ return true
end
+
+ return false
end
local function fann_scores_filter(task)
@@ -381,8 +419,6 @@ local function fann_scores_filter(task)
end
end
- check_fann(id)
-
if data[id].fann then
local symbols,scores = task:get_symbols_numeric()
local fann_data = symbols_to_fann_vector(symbols, scores)
@@ -403,10 +439,6 @@ local function fann_scores_filter(task)
local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
task:insert_result(fann_symbol_ham, result, symscore, id)
end
- else
- if load_fann(id) then
- fann_scores_filter(task)
- end
end
end
@@ -416,69 +448,11 @@ local function create_train_fann(n, id)
data[id].epoch = 0
end
-local function fann_train_callback(score, required_score, results, cf, id, opts, extra)
- local n = cf:get_symbols_count() + count_metatokens()
- local fname = gen_fann_file(id)
-
- if not data[id].fann_train then
- create_train_fann(n, id)
- end
-
- if data[id].fann_train:get_inputs() ~= n then
- rspamd_logger.infox(cf, 'fann has incorrect number of inputs: %s, %s symbols' ..
- ' is found in the cache', data[id].fann_train:get_inputs(), n)
- create_train_fann(n, id)
- end
-
- if data[id].ntrains > max_trains then
- -- Store fann on disk
- local res = false
-
- local err,st = rspamd_util.stat(fname)
- if err then
- local fd,err = rspamd_util.create_file(fname)
- if not fd then
- rspamd_logger.errx(cf, 'cannot save fann in %s: %s', fname, err)
- else
- rspamd_util.lock_file(fname, fd)
- res = data[id].fann_train:save(fname)
- rspamd_util.unlock_file(fd) -- Closes fd as well
- end
- else
- local fd = rspamd_util.lock_file(fname)
- res = data[id].fann_train:save(fname)
- rspamd_util.unlock_file(fd) -- Closes fd as well
- end
-
- if not res then
- rspamd_logger.errx(cf, 'cannot save fann in %s', fname)
- else
- data[id].exist = true
- data[id].ntrains = 0
- data[id].epoch = data[id].epoch + 1
- end
- else
- if not data[id].checked then
- data[id].checked = true
- local err,st = rspamd_util.stat(fname)
- if err then
- data[id].exist = false
- end
- end
- if not data[id].exist then
- rspamd_logger.infox(cf, 'not enough trains for fann %s, %s left', fname,
- max_trains - data[id].ntrains)
- end
- end
-
- if data[id].epoch > max_epoch then
- -- Re-create fann
- rspamd_logger.infox(cf, 'create new fann in %s after %s epoches', fname,
- max_epoch)
- create_train_fann(n, id)
- end
+local function fann_train_callback(score, required_score, results, cf, id, opts, extra, ev_base)
+ local fname = gen_fann_prefix(id)
local learn_spam, learn_ham = false, false
+
if opts['spam_score'] then
learn_spam = score >= opts['spam_score']
else
@@ -491,21 +465,110 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
end
if learn_spam or learn_ham then
- local learn_data = symbols_to_fann_vector(
- map(function(r) return r[1] end, results),
- map(function(r) return r[2] end, results)
+ local k
+ if learn_spam then k = 'spam' else k = 'ham' end
+
+    -- Receives the reply of the LPUSH that stores a training row.
+    -- redis_make_request refuses to run without a callback, so this handler
+    -- must exist for the row to be stored at all
+    local function learn_cb(err, _)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot store train vector in %s: %s',
+          fname .. '_' .. k, err)
+      end
+    end
+
+    -- Receives the reply of the can_train script; a positive number means
+    -- that the row may be stored
+    local function can_train_cb(err, reply)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot check if we can train: %s', err)
+        return
+      end
+
+      -- A non-numeric reply is treated as "not allowed" instead of raising
+      local allowed = tonumber(reply)
+      if allowed and allowed > 0 then
+        local learn_data = symbols_to_fann_vector(
+          map(function(r) return r[1] end, results),
+          map(function(r) return r[2] end, results)
+        )
+        -- Add filtered meta tokens
+        each(function(e) table.insert(learn_data, e) end, extra)
+        local str = table.concat(learn_data, ';')
+
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          true, -- is write
+          learn_cb, --callback
+          'LPUSH', -- command
+          {fname .. '_' .. k, str} -- arguments
+        )
+      end
+    end
+
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ false, -- is write
+ can_train_cb, --callback
+ 'EVALSHA', -- command
+ {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments
)
- -- Add filtered meta tokens
- each(function(e) table.insert(learn_data, e) end, extra)
+ end
+end
- if learn_spam then
- data[id].fann_train:train(learn_data, {1.0})
- else
- data[id].fann_train:train(learn_data, {-1.0})
+-- Periodic callback: requests the members of the `fann_prefix` set from redis
+-- and, for each member, runs the maybe_load script to find out whether redis
+-- holds a version newer than the local one
+-- cfg - config object (not used by this function)
+-- ev_base - event base used for the redis requests
+-- returns the delay before the next run: 1.0 while the SHA of the maybe_load
+-- script is not yet known, `watch_interval` afterwards
+local function check_fanns(cfg, ev_base)
+  -- Starts the version check for a single ANN identified by `elt`
+  local function check_one(elt)
+    local function redis_load_cb(err, ann_data)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
+      elseif type(ann_data) == 'string' then
+        -- TODO: loading is not implemented yet
+        --load_fann(ann_data, elt)
+      end
+    end
+    local function redis_update_cb(err, reply)
+      if err then
+        rspamd_logger.errx(rspamd_config, 'cannot get FANN version %s from redis: %s', elt, err)
+      elseif reply then
+        -- GET accepts exactly one key. The `_ann` suffix is the one read by
+        -- the maybe_load script - NOTE(review): confirm against the writer
+        redis_make_request(ev_base,
+          rspamd_config,
+          nil,
+          false, -- is write
+          redis_load_cb, --callback
+          'GET', -- command
+          {fann_prefix .. elt .. '_ann'} -- arguments
+        )
+      end
+    end
+
+    -- `data` is the module level table here: the callbacks above and below
+    -- deliberately use other parameter names to avoid shadowing it.
+    -- Try both the string and the numeric form of the id
+    local local_ver = 0
+    local known = data[elt] or data[tonumber(elt)]
+    if known and known.version then
+      local_ver = known.version
+    end
+    redis_make_request(ev_base,
+      rspamd_config,
+      nil,
+      false, -- is write
+      redis_update_cb, --callback
+      'EVALSHA', -- command
+      {redis_fann_maybe_load_sha, '2', fann_prefix .. elt, tostring(local_ver)}
+    )
+  end
+
+  local function members_cb(err, members)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err)
+    elseif type(members) == 'table' then
+      -- The reply is handled as an array of member names
+      for _, elt in ipairs(members) do
+        check_one(elt)
+      end
+    end
+  end
- data[id].ntrains = data[id].ntrains + 1
+  if not redis_fann_maybe_load_sha then
+    -- Plan new event early
+    return 1.0
 end
+  -- First we need to get all fanns stored in our Redis
+  redis_make_request(ev_base,
+    rspamd_config,
+    nil,
+    false, -- is write
+    members_cb, --callback
+    'SMEMBERS', -- command
+    {fann_prefix} -- arguments
+  )
+
+  return watch_interval
end
-- Initialization part
@@ -519,337 +582,104 @@ end
if not rspamd_fann.is_enabled() then
rspamd_logger.errx(rspamd_config, 'fann is not compiled in rspamd, this ' ..
'module is eventually disabled')
-
return
else
- if not opts['fann_file'] then
- rspamd_logger.warnx(rspamd_config, 'fann_scores module requires ' ..
- '`fann_file` to be specified')
- else
- fann_file = opts['fann_file']
- use_settings = opts['use_settings']
- rspamd_config:set_metric_symbol({
- name = fann_symbol_spam,
- score = 3.0,
- description = 'Neural network SPAM',
- group = 'fann'
- })
- local id = rspamd_config:register_symbol({
- name = fann_symbol_spam,
- type = 'postfilter',
- priority = 5,
- callback = fann_scores_filter
- })
- rspamd_config:set_metric_symbol({
- name = fann_symbol_ham,
- score = -2.0,
- description = 'Neural network HAM',
- group = 'fann'
- })
- rspamd_config:register_symbol({
- name = fann_symbol_ham,
- type = 'virtual',
- parent = id
- })
- if opts['train'] then
- rspamd_config:add_on_load(function(cfg)
- if opts['train']['max_train'] then
- max_trains = opts['train']['max_train']
- end
- if opts['train']['max_epoch'] then
- max_epoch = opts['train']['max_epoch']
- end
- local ret = cfg:register_worker_script("log_helper",
- function(score, req_score, results, cf, id, extra)
- -- map (snd x) (filter (fst x == module_id) extra)
- local extra_fann = map(function(e) return e[2] end,
- filter(function(e) return e[1] == module_log_id end, extra))
- if use_settings then
- fann_train_callback(score, req_score, results, cf,
- tostring(id), opts['train'], extra_fann)
- else
- fann_train_callback(score, req_score, results, cf, '0',
- opts['train'], extra_fann)
- end
+ use_settings = opts['use_settings']
+ rspamd_config:set_metric_symbol({
+ name = fann_symbol_spam,
+ score = 3.0,
+ description = 'Neural network SPAM',
+ group = 'fann'
+ })
+ local id = rspamd_config:register_symbol({
+ name = fann_symbol_spam,
+ type = 'postfilter',
+ priority = 5,
+ callback = fann_scores_filter
+ })
+ rspamd_config:set_metric_symbol({
+ name = fann_symbol_ham,
+ score = -2.0,
+ description = 'Neural network HAM',
+ group = 'fann'
+ })
+ rspamd_config:register_symbol({
+ name = fann_symbol_ham,
+ type = 'virtual',
+ parent = id
+ })
+ if opts['train'] then
+ rspamd_config:add_on_load(function(cfg)
+ if opts['train']['max_train'] then
+ max_trains = opts['train']['max_train']
+ end
+ if opts['train']['max_epoch'] then
+ max_epoch = opts['train']['max_epoch']
+ end
+ local ret = cfg:register_worker_script("log_helper",
+ function(score, req_score, results, cf, id, extra, ev_base)
+ -- map (snd x) (filter (fst x == module_id) extra)
+ local extra_fann = map(function(e) return e[2] end,
+ filter(function(e) return e[1] == module_log_id end, extra))
+ if use_settings then
+ fann_train_callback(score, req_score, results, cf,
+ tostring(id), opts['train'], extra_fann, ev_base)
+ else
+ fann_train_callback(score, req_score, results, cf, '0',
+ opts['train'], extra_fann, ev_base)
+ end
end)
- if not ret then
- rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
- end
- end)
- rspamd_plugins["fann_score"] = {
- log_callback = function(task)
- return totable(map(
- function(tok) return {module_log_id, tok} end,
- gen_metatokens(task)))
- end
- }
- end
+ if not ret then
+ rspamd_logger.errx(cfg, 'cannot find worker "log_helper"')
+ end
+ end)
+ -- This is needed to pass extra tokens from worker to log_helper
+ rspamd_plugins["fann_score"] = {
+ log_callback = function(task)
+ return totable(map(
+ function(tok) return {module_log_id, tok} end,
+ gen_metatokens(task)))
+ end
+ }
end
-end
-
-local redis_params
-local classifier_config = {
- key = 'neural_net',
- neurons = 200,
- layers = 3,
-}
-
-local current_classify_ann = {
- loaded = false,
- version = 0,
- spam_learned = 0,
- ham_learned = 0
-}
-
-redis_params = rspamd_parse_redis_server('fann_scores')
-
-local function maybe_load_fann(task, continue_cb, call_if_fail)
- local function load_fann()
- local function redis_fann_load_cb(err, data)
- if not err and type(data) == 'table' and type(data[2]) == 'string' then
- local version = tonumber(data[1])
- local err,ann_data = rspamd_util.zstd_decompress(data[2])
- local ann
-
- if err or not ann_data then
- rspamd_logger.errx(task, 'cannot decompress ann: %s', err)
- else
- ann = rspamd_fann.load_data(ann_data)
- end
-
- if ann then
- current_classify_ann.loaded = true
- current_classify_ann.version = version
- current_classify_ann.ann = ann
- if type(data[3]) == 'string' then
- current_classify_ann.spam_learned = tonumber(data[3])
- else
- current_classify_ann.spam_learned = 0
- end
- if type(data[4]) == 'string' then
- current_classify_ann.ham_learned = tonumber(data[4])
- else
- current_classify_ann.ham_learned = 0
- end
- rspamd_logger.infox(task, "loaded fann classifier version %s (%s spam, %s ham), %s MSE",
- version, current_classify_ann.spam_learned,
- current_classify_ann.ham_learned,
- ann:get_mse())
- continue_cb(task, true)
- elseif call_if_fail then
- continue_cb(task, false)
- end
- elseif call_if_fail then
- continue_cb(task, false)
+ -- Add training scripts
+ rspamd_config:add_on_load(function(cfg, ev_base)
+ local function can_train_sha_cb(err, data)
+ if err or not data or type(data) ~= 'string' then
+ rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err)
+ else
+ redis_can_train_sha = tostring(data)
end
end
-
- local key = classifier_config.key
- local ret,_,_ = rspamd_redis_make_request(task,
- redis_params, -- connect params
- key, -- hash key
- false, -- is write
- redis_fann_load_cb, --callback
- 'HMGET', -- command
- {key, 'version', 'data', 'spam', 'ham'} -- arguments
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ can_train_sha_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', redis_lua_script_can_train} -- arguments
)
- end
- local function check_fann()
- local function redis_fann_check_cb(err, data)
- if not err and type(data) == 'string' then
- local version = tonumber(data)
+ local function maybe_load_sha_cb(err, data)
+ if err or not data or type(data) ~= 'string' then
+ rspamd_logger.errx(cfg, 'cannot save redis load script: %s', err)
+ else
+ redis_fann_maybe_load_sha = tostring(data)
- if version <= current_classify_ann.version then
- continue_cb(task, true)
- else
- load_fann()
- end
+ rspamd_config:add_periodic(ev_base, 0.0,
+ function(cfg, ev_base)
+ return check_fanns(cfg, ev_base)
+ end)
end
end
-
- local key = classifier_config.key
- local ret,_,_ = rspamd_redis_make_request(task,
- redis_params, -- connect params
- key, -- hash key
- false, -- is write
- redis_fann_check_cb, --callback
- 'HGET', -- command
- {key, 'version'} -- arguments
+ redis_make_request(ev_base,
+ rspamd_config,
+ nil,
+ true, -- is write
+ maybe_load_sha_cb, --callback
+ 'SCRIPT', -- command
+ {'LOAD', redis_lua_script_maybe_load} -- arguments
)
- end
-
- if not current_classify_ann.loaded then
- load_fann()
- else
- check_fann()
- end
-end
-
-local function tokens_to_vector(tokens)
- local vec = totable(map(function(tok) return tok[1] end, tokens))
- local ret = {}
- local ntok = #vec
- local neurons = classifier_config.neurons
- for i = 1,neurons do
- ret[i] = 0
- end
- each(function(e)
- local n = (e % neurons) + 1
- ret[n] = ret[n] + 1
- end, vec)
- local norm = 0
- for i = 1,neurons do
- if ret[i] > norm then
- norm = ret[i]
- end
- end
- for i = 1,neurons do
- if ret[i] ~= 0 and norm > 0 then
- ret[i] = ret[i] / norm
- end
- end
-
- return ret
-end
-
-local function add_metatokens(task, vec)
- local mt = gen_metatokens(task)
- for _,tok in ipairs(mt) do
- table.insert(vec, tok)
- end
-end
-
-local function create_fann()
- local layers = {}
- local mt_size = count_metatokens()
- local neurons = classifier_config.neurons + mt_size
-
- for i = 1,classifier_config.layers - 1 do
- layers[i] = math.floor(neurons / i)
- end
-
- table.insert(layers, 1)
-
- local ann = rspamd_fann.create(classifier_config.layers, layers)
- current_classify_ann.loaded = true
- current_classify_ann.version = 0
- current_classify_ann.ann = ann
- current_classify_ann.spam_learned = 0
- current_classify_ann.ham_learned = 0
-end
-
-local function save_fann(task, is_spam)
- local function redis_fann_save_cb(err, data)
- if err then
- rspamd_logger.errx(task, "cannot save neural net to redis: %s", err)
- end
- end
-
- local data = current_classify_ann.ann:data()
- local key = classifier_config.key
- current_classify_ann.version = current_classify_ann.version + 1
-
- if is_spam then
- current_classify_ann.spam_learned = current_classify_ann.spam_learned + 1
- else
- current_classify_ann.ham_learned = current_classify_ann.ham_learned + 1
- end
- local ret,conn,_ = rspamd_redis_make_request(task,
- redis_params, -- connect params
- key, -- hash key
- true, -- is write
- redis_fann_save_cb, --callback
- 'HMSET', -- command
- {
- key,
- 'data', rspamd_util.zstd_compress(data),
- }) -- arguments
-
- if conn then
- conn:add_cmd('HINCRBY', {key, 'version', 1})
- if is_spam then
- conn:add_cmd('HINCRBY', {key, 'spam', 1})
- rspamd_logger.errx(task, 'hui')
- else
- conn:add_cmd('HINCRBY', {key, 'ham', 1})
- rspamd_logger.errx(task, 'pezda')
- end
- end
-end
-
-if redis_params then
- rspamd_classifiers['neural'] = {
- classify = function(task, classifier, tokens)
- local function classify_cb(task)
- local min_learns = classifier:get_param('min_learns')
-
- if min_learns then
- min_learns = tonumber(min_learns)
- end
-
- if min_learns and min_learns > 0 then
- if current_classify_ann.ham_learned < min_learns or
- current_classify_ann.spam_learned < min_learns then
-
- rspamd_logger.infox(task, 'fann classifier has not enough learns: (%s spam, %s ham), %s required',
- current_classify_ann.spam_learned, current_classify_ann.ham_learned,
- min_learns)
- return
- end
- end
-
- -- Perform classification
- local vec = tokens_to_vector(tokens)
- add_metatokens(task, vec)
- local out = current_classify_ann.ann:test(vec)
- local result = rspamd_util.tanh(2 * (out[1]))
- local symscore = string.format('%.3f', out[1])
- rspamd_logger.infox(task, 'fann classifier score: %s', symscore)
-
- if result > 0 then
- each(function(st)
- task:insert_result(st:get_symbol(), result, symscore)
- end,
- filter(function(st)
- return st:is_spam()
- end, classifier:get_statfiles())
- )
- else
- each(function(st)
- task:insert_result(st:get_symbol(), -result, symscore)
- end,
- filter(function(st)
- return not st:is_spam()
- end, classifier:get_statfiles())
- )
- end
- end
- maybe_load_fann(task, classify_cb, false)
- end,
-
- learn = function(task, classifier, tokens, is_spam, is_unlearn)
- local function learn_cb(task, is_loaded)
- if not is_loaded then
- create_fann()
- end
- local vec = tokens_to_vector(tokens)
- add_metatokens(task, vec)
-
- if is_spam then
- current_classify_ann.ann:train(vec, {1.0})
- rspamd_logger.infox(task, "learned ANN spam, MSE: %s",
- current_classify_ann.ann:get_mse())
- else
- current_classify_ann.ann:train(vec, {-1.0})
- rspamd_logger.infox(task, "learned ANN ham, MSE: %s",
- current_classify_ann.ann:get_mse())
- end
-
- save_fann(task, is_spam)
- end
- maybe_load_fann(task, learn_cb, true)
- end,
- }
+ end)
end