From 6fe3489d5bdad985adcfd4e478d91cac2d437448 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Fri, 4 Nov 2016 15:56:16 +0000 Subject: [PATCH] [Rework] Add preliminary train tests --- src/plugins/lua/fann_scores.lua | 68 ++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/src/plugins/lua/fann_scores.lua b/src/plugins/lua/fann_scores.lua index 3c46cda2f..7bc55117d 100644 --- a/src/plugins/lua/fann_scores.lua +++ b/src/plugins/lua/fann_scores.lua @@ -520,7 +520,6 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, end local function can_train_cb(err, data) - rspamd_logger.errx('data: %s, err: %s', data, err) if not err and tonumber(data) > 0 then local learn_data = symbols_to_fann_vector( map(function(r) return r[1] end, results), @@ -556,6 +555,63 @@ local function fann_train_callback(score, required_score, results, cf, id, opts, end end +local function train_fann(cfg, ev_base, elt) + +end + +local function maybe_train_fanns(cfg, ev_base) + local function members_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANNS list from redis: %s', err) + elseif type(data) == 'table' then + each(function(i, elt) + local redis_len_cb = function(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get FANN trains %s from redis: %s', elt, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + if tonumber(data) and tonumber(data) > max_trains then + train_fann(cfg, ev_base, elt) + end + end + end + + local local_ver = 0 + local numelt = tonumber(elt) + if data[numelt] then + if data[numelt].version then + local_ver = data[numelt].version + end + end + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + redis_len_cb, --callback + 'LLEN', -- command + {fann_prefix .. elt .. '_spam'} + ) + end, + data) + end + end + + if not redis_maybe_load_sha then + -- Plan new event early + return 1.0 + end + -- First we need to get all fanns stored in our Redis + redis_make_request(ev_base, + rspamd_config, + nil, + false, -- is write + members_cb, --callback + 'SMEMBERS', -- command + {fann_prefix} -- arguments + ) + + return watch_interval +end + local function check_fanns(cfg, ev_base) local function members_cb(err, data) if err then @@ -680,7 +736,7 @@ else } end -- Add training scripts - rspamd_config:add_on_load(function(cfg, ev_base) + rspamd_config:add_on_load(function(cfg, ev_base, worker) local function can_train_sha_cb(err, data) if err or not data or type(data) ~= 'string' then rspamd_logger.errx(cfg, 'cannot save redis train script: %s', err) @@ -733,5 +789,13 @@ else 'SCRIPT', -- command {'LOAD', redis_lua_script_maybe_invalidate} -- arguments ) + + if worker:get_name() == 'normal' then + -- We also want to train neural nets when they have enough data + rspamd_config:add_periodic(ev_base, 0.0, + function(cfg, ev_base) + return maybe_train_fanns(cfg, ev_base) + end) + end end) end -- 2.39.5