12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995 |
- --[[
- Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ]]--
-
-
-- Stop evaluating this file right away when the `confighelp` global is set
if confighelp then return end
-
- local fun = require "fun"
- local lua_redis = require "lua_redis"
- local lua_util = require "lua_util"
- local lua_verdict = require "lua_verdict"
- local neural_common = require "plugins/neural"
- local rspamd_kann = require "rspamd_kann"
- local rspamd_logger = require "rspamd_logger"
- local rspamd_tensor = require "rspamd_tensor"
- local rspamd_text = require "rspamd_text"
- local rspamd_util = require "rspamd_util"
- local ts = require("tableshape").types
-
-- Module name: passed to debug logging, verdict lookup and module disabling
local N = "neural"

-- Settings table owned by the common neural code (holds `rules`, `max_profiles`, ...)
local settings = neural_common.settings

-- Shape that a profile deserialised from Redis has to match
local redis_profile_schema = ts.shape({
  digest = ts.string,
  symbols = ts.array_of(ts.string),
  version = ts.number,
  redis_key = ts.string,
  distance = ts.number:is_optional(),
})

-- Result of the tensor library BLAS probe; `max_inputs` is dropped without it
local has_blas = rspamd_tensor.has_blas()
-- Cookie compared against `userdata.cookie` to recognise text objects in Redis replies
local text_cookie = rspamd_text.cookie
-
--- Creates an ANN profile for `set` and stores it in Redis.
-- The profile is serialised as compact JSON and added with `ZADD` to the sorted
-- set at `set.prefix`, scored with the current time. The outcome of the Redis
-- request is only logged from its callback; the profile is returned regardless.
-- @param task task used for the Redis request and logging
-- @param rule rule table (`redis` and `prefix` are read)
-- @param set settings element (`symbols`, `digest`, `prefix` and `name` are read)
-- @param version version number recorded in the profile
-- @return the newly built profile table
local function new_ann_profile(task, rule, set, version)
  local ucl = require "ucl"

  local profile = {
    symbols = set.symbols,
    redis_key = neural_common.new_ann_key(rule, set, version, settings),
    version = version,
    digest = set.digest,
    distance = 0 -- Since we are using our own profile
  }
  local serialized = ucl.to_format(profile, 'json-compact', true)

  -- Redis reply handler: just report what happened
  local function on_stored(err, _)
    if not err then
      rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
          rule.prefix, set.name, profile.redis_key)
      return
    end

    rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
        rule.prefix, set.name, profile.redis_key, err)
  end

  local zadd_args = { set.prefix, tostring(rspamd_util.get_time()), serialized }
  lua_redis.redis_make_request(task,
      rule.redis,
      nil,
      true, -- is write
      on_stored, -- callback
      'ZADD', -- command
      zadd_args
  )

  return profile
end
-
-
--- ANN filter function: scores the task with the loaded network of every rule
-- and inserts the rule's spam or ham symbol based on that score.
--
-- For every rule with a loaded ANN the raw score is cached in the task under
-- `<rule.prefix>_neural_score`. A positive score is handled as spam, any other
-- score as ham (using the negated value). The symbol is inserted only when the
-- value reaches the threshold, which is chosen as follows:
--   1. `rule.spam_score_threshold` / `rule.ham_score_threshold` when defined;
--   2. otherwise the ROC threshold stored with the ANN (index 1 for spam,
--      index 2 for ham) when `rule.roc_enabled` is set and thresholds are loaded;
--   3. otherwise 0.
-- With `rule.flat_threshold_curve` the symbol weight is 1.0 instead of the score.
-- @param task task being processed
local function ann_scores_filter(task)

  for _, rule in pairs(settings.rules) do
    local sid = task:get_settings_id() or -1
    local ann
    local profile

    local set = neural_common.get_rule_settings(task, rule)
    if not set then
      lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
          rule.prefix, sid)
    elseif set.ann then
      ann = set.ann.ann
      profile = set.ann
    else
      lua_util.debugm(N, task, 'no ann loaded for %s:%s',
          rule.prefix, set.name)
    end

    if ann then
      local vec = neural_common.result_to_vector(task, profile)

      local out = ann:apply1(vec, set.ann.pca)
      local score = out[1]

      local symscore = string.format('%.3f', score)
      task:cache_set(rule.prefix .. '_neural_score', score)
      lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
          rule.prefix, set.name, set.ann.version, symscore)

      -- ROC thresholds are usable only when the feature is enabled AND they
      -- have actually been loaded for this ANN. The previous check was
      -- inverted and indexed a nil table when thresholds were missing.
      local roc_thresholds = rule.roc_enabled and set.ann.roc_thresholds

      if score > 0 then
        local result = score

        -- If spam_score_threshold is defined, override all other thresholds.
        local spam_threshold = 0
        if rule.spam_score_threshold then
          spam_threshold = rule.spam_score_threshold
        elseif roc_thresholds then
          spam_threshold = roc_thresholds[1] or spam_threshold
        end

        if result >= spam_threshold then
          if rule.flat_threshold_curve then
            task:insert_result(rule.symbol_spam, 1.0, symscore)
          else
            task:insert_result(rule.symbol_spam, result, symscore)
          end
        else
          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
              rule.prefix, set.name, set.ann.version, symscore,
              spam_threshold)
        end
      else
        local result = -(score)

        -- If ham_score_threshold is defined, override all other thresholds.
        local ham_threshold = 0
        if rule.ham_score_threshold then
          ham_threshold = rule.ham_score_threshold
        elseif roc_thresholds then
          ham_threshold = roc_thresholds[2] or ham_threshold
        end

        if result >= ham_threshold then
          if rule.flat_threshold_curve then
            task:insert_result(rule.symbol_ham, 1.0, symscore)
          else
            task:insert_result(rule.symbol_ham, result, symscore)
          end
        else
          lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
              rule.prefix, set.name, set.ann.version, result,
              ham_threshold)
        end
      end
    end
  end
end
-
--- Decides whether the task should become a training sample for `rule` and,
-- if so, stores its vector in Redis.
--
-- The class (spam/ham) is chosen either automatically (`rule.train.autotrain`
-- without `store_pool_only`: by `spam_score`/`ham_score` limits when defined,
-- by verdict otherwise) or from the `ANN-Train` request header. With
-- `store_pool_only` and no header, the vector and the profile digest are only
-- saved into the task cache.
--
-- When a class is selected and `set.can_store_vectors` is set, the current
-- train set sizes are requested via the `vectors_len` Redis script and the
-- zstd-compressed vector is added with `SADD` to
-- `<set.ann.redis_key>_<spam|ham>_set` if `neural_common.can_push_train_vector`
-- allows it. A profile is created when `set.ann` is missing.
-- @param rule rule table
-- @param task task being processed
-- @param verdict verdict string for the task
-- @param score numeric score of the task
-- @param set settings element selected for this task
local function ann_push_task_result(rule, task, verdict, score, set)
  local train_opts = rule.train
  local learn_spam, learn_ham
  local skip_reason = 'unknown'

  if not train_opts.store_pool_only and train_opts.autotrain then
    if train_opts.spam_score then
      learn_spam = score >= train_opts.spam_score

      if not learn_spam then
        skip_reason = string.format('score < spam_score: %f < %f',
            score, train_opts.spam_score)
      end
    else
      learn_spam = verdict == 'spam' or verdict == 'junk'

      if not learn_spam then
        skip_reason = string.format('verdict: %s',
            verdict)
      end
    end

    if train_opts.ham_score then
      learn_ham = score <= train_opts.ham_score
      if not learn_ham then
        skip_reason = string.format('score > ham_score: %f > %f',
            score, train_opts.ham_score)
      end
    else
      learn_ham = verdict == 'ham'

      if not learn_ham then
        skip_reason = string.format('verdict: %s',
            verdict)
      end
    end
  else
    -- Train by request header
    local hdr = task:get_request_header('ANN-Train')

    if hdr then
      if hdr:lower() == 'spam' then
        learn_spam = true
      elseif hdr:lower() == 'ham' then
        learn_ham = true
      else
        skip_reason = 'no explicit header'
      end
    elseif train_opts.store_pool_only then
      local ucl = require "ucl"
      learn_ham = false
      learn_spam = false

      -- Explicitly store tokens in cache
      local vec = neural_common.result_to_vector(task, set)
      task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
      task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
      skip_reason = 'store_pool_only has been set'
    end
  end

  if not (learn_spam or learn_ham) then
    lua_util.debugm(N, task,
        'do not push data to key %s: train condition not satisfied; reason: %s',
        (set.ann or {}).redis_key,
        skip_reason)
    return
  end

  -- Check if we can learn
  if not set.can_store_vectors then
    lua_util.debugm(N, task,
        'do not push data: train condition not satisfied; reason: not checked existing ANNs')
    return
  end

  -- Spam wins if both classes matched
  local learn_type = learn_spam and 'spam' or 'ham'

  -- Receives the reply of the `vectors_len` script
  local function vectors_len_cb(err, data)
    if err then
      rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
          rule.prefix, set.name, err)
      return
    end

    if type(data) == 'string' then
      -- A string reply is reported as "locked for learning"
      rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
          learn_type, rule.prefix, set.name, set.ann.redis_key, data)
      return
    end

    if type(data) ~= 'table' then
      -- The two message halves used to be glued together without a separator
      rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table; ' ..
          'please remove this key from Redis manually if you perform upgrade from the previous version',
          rule.prefix, set.name, set.ann.redis_key, type(data))
      return
    end

    local nspam, nham = data[1], data[2]

    if not neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
      lua_util.debugm(N, task,
          "do not add %s train data for ANN rule " ..
              "%s:%s",
          learn_type, rule.prefix, set.name)
      return
    end

    local vec = neural_common.result_to_vector(task, set)

    local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
    local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'

    local function learn_vec_cb(redis_err)
      if redis_err then
        rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
            rule.prefix, set.name, redis_err)
      else
        lua_util.debugm(N, task,
            "add train data for ANN rule " ..
                "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
            rule.prefix, set.name, learn_type, #vec, target_key, #str)
      end
    end

    lua_redis.redis_make_request(task,
        rule.redis,
        nil,
        true, -- is write
        learn_vec_cb, --callback
        'SADD', -- command
        { target_key, str } -- arguments
    )
  end

  if not set.ann then
    -- Need to create or load a profile corresponding to the current configuration
    set.ann = new_ann_profile(task, rule, set, 0)
    lua_util.debugm(N, task,
        'requested new profile for %s, set.ann is missing',
        set.name)
  end

  lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
      {task = task, is_write = false},
      vectors_len_cb,
      {
        set.ann.redis_key,
      })
end
-
- --- Offline training logic
-
--- Converts stored training vectors into a table of numeric tables.
-- Every element of `data` is zstd-decompressed, split on ';' and each field is
-- passed through `tonumber`.
-- @param data list of compressed vectors as received from Redis
-- @return table with one table of numbers per input element
local function process_training_vectors(data)
  -- Decodes one compressed token into a list of numbers
  local function decode_vector(tok)
    local _, raw = rspamd_util.zstd_decompress(tok)
    local fields = lua_util.str_split(tostring(raw), ';')
    return fun.totable(fun.map(tonumber, fields))
  end

  return fun.totable(fun.map(decode_vector, data))
end
-
--- Starts training of the ANN stored under `ann_key`.
-- This function does the following:
-- * Tries to lock ANN (via the `maybe_lock` Redis script)
-- * Loads spam and then ham vectors (`SMEMBERS` on `<ann_key>_spam_set` and
--   `<ann_key>_ham_set`)
-- * Spawns the learning process through `neural_common.spawn_train`
-- If the vectors cannot be fetched, the lock is removed again (`HDEL`).
-- Nothing is done when `set.learning_spawned` is already set.
-- @param worker worker object passed on to `spawn_train`
-- @param ev_base event base used for the Redis requests
-- @param rule rule table
-- @param set settings element being trained
-- @param ann_key Redis key of the ANN
local function do_train_ann(worker, ev_base, rule, set, ann_key)
  local spam_elts = {}
  local ham_elts = {}

  -- Drops the lock field so that the ANN can be locked again later
  local function unlock_ann()
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        true, -- is write
        neural_common.gen_unlock_cb(rule, set, ann_key), --callback
        'HDEL', -- command
        {ann_key, 'lock'}
    )
  end

  -- Requests all members of a training set and hands them to `cb`
  local function fetch_vectors(set_suffix, cb)
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false, -- is write
        cb, --callback
        'SMEMBERS', -- command
        {ann_key .. set_suffix}
    )
  end

  -- Ham vectors received: everything is ready to spawn the training
  local function redis_ham_cb(err, data)
    if err or type(data) ~= 'table' then
      rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
          ann_key, err)
      -- Unlock on error
      unlock_ann()
    else
      -- Decompress and convert to numbers each training vector
      ham_elts = process_training_vectors(data)
      neural_common.spawn_train({worker = worker, ev_base = ev_base,
        rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
        spam_vec = spam_elts})
    end
  end

  -- Spam vectors received
  local function redis_spam_cb(err, data)
    if err or type(data) ~= 'table' then
      rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
          ann_key, err)
      -- Unlock ANN on error
      unlock_ann()
    else
      -- Decompress and convert to numbers each training vector
      spam_elts = process_training_vectors(data)
      -- Now get ham vectors...
      fetch_vectors('_ham_set', redis_ham_cb)
    end
  end

  local function redis_lock_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
          ann_key, err)
    elseif type(data) == 'number' and data == 1 then
      -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
      fetch_vectors('_spam_set', redis_spam_cb)

      rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
          rule.prefix, set.name, ann_key)
    elseif type(data) == 'table' then
      -- {lock_time, hostname} pair: somebody else holds the lock
      local lock_tm = tonumber(data[1])
      rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
          'locked by another host %s at %s', rule.prefix, set.name, ann_key,
          data[2], os.date('%c', lock_tm))
    else
      -- Neither a success marker nor a lock description: do not index it
      rspamd_logger.errx(rspamd_config,
          'unexpected reply of type %s from lock script for ANN %s: %s',
          type(data), ann_key, data)
    end
  end

  -- Check if we are already learning this network
  if set.learning_spawned then
    rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
        ann_key)
    return
  end

  -- Call Redis script that tries to acquire a lock
  -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
  -- ANN is locked by another host (or a process, meh)
  lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
      {ev_base = ev_base, is_write = true},
      redis_lock_cb,
      {
        ann_key,
        tostring(os.time()),
        tostring(math.max(10.0, rule.watch_interval * 2)),
        rspamd_util.get_hostname()
      })
end
-
--- Loads a new ANN from Redis.
-- This is based on the `profile` attribute: the ANN is loaded from
-- `profile.redis_key` with `HMGET` of the `ann`, `roc_thresholds` and `pca`
-- fields. On success `set.ann` is replaced by a table built from the profile
-- plus the loaded network (`set.ann.ann`), optional ROC thresholds
-- (`set.ann.roc_thresholds`, only when `rule.roc_enabled`) and optional PCA
-- tensor (`set.ann.pca`, only when `rule.max_inputs`).
-- Rank of the `profile` key is also increased, unfortunately, it means that we
-- need to serialize the profile one more time and set its rank to the current time.
-- @param rule rule table
-- @param ev_base event base used for Redis requests
-- @param set settings element that receives the ANN
-- @param profile profile describing the ANN to load
-- @param min_diff distance between the profile symbols and our symbols
local function load_new_ann(rule, ev_base, set, profile, min_diff)
  local ann_key = profile.redis_key

  -- True for text objects returned by Redis in opaque mode
  local function is_text(obj)
    return type(obj) == 'userdata' and obj.cookie == text_cookie
  end

  local function data_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
          ann_key, err)
      return
    end

    if type(data) ~= 'table' then
      lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
          rule.prefix, set.name, ann_key)
      return
    end

    -- Stage 1: the network itself
    if is_text(data[1]) then
      local _err, ann_data = rspamd_util.zstd_decompress(data[1])

      if _err or not ann_data then
        rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
            rule.prefix .. ':' .. set.name, ann_key, _err)
        return
      end

      local ann = rspamd_kann.load(ann_data)

      if ann then
        set.ann = {
          digest = profile.digest,
          version = profile.version,
          symbols = profile.symbols,
          distance = min_diff,
          redis_key = profile.redis_key
        }

        local ucl = require "ucl"
        local profile_serialized = ucl.to_format(profile, 'json-compact', true)
        set.ann.ann = ann -- To avoid serialization

        local function rank_cb(_, _)
          -- TODO: maybe add some logging
        end
        -- Also update rank for the loaded ANN to avoid removal
        lua_redis.redis_make_request_taskless(ev_base,
            rspamd_config,
            rule.redis,
            nil,
            true, -- is write
            rank_cb, --callback
            'ZADD', -- command
            {set.prefix, tostring(rspamd_util.get_time()), profile_serialized}
        )
        rspamd_logger.infox(rspamd_config,
            'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
            rule.prefix, set.name, ann_key, #data[1], profile.version)
      else
        rspamd_logger.errx(rspamd_config,
            'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
            rule.prefix, set.name, ann_key)
      end
    else
      lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
          rule.prefix, set.name, ann_key)
    end

    -- The remaining stages only make sense when some network is available
    local have_ann = set.ann and set.ann.ann

    -- Stage 2: ROC thresholds
    if have_ann and is_text(data[2]) and rule.roc_enabled then
      local ucl = require "ucl"
      local parser = ucl.parser()
      local ok, parse_err = parser:parse_text(data[2])

      if ok then
        local roc_thresholds = parser:get_object()
        set.ann.roc_thresholds = roc_thresholds
        rspamd_logger.infox(rspamd_config,
            'loaded ROC thresholds for %s:%s; version=%s',
            rule.prefix, set.name, profile.version)
        rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
      else
        -- Broken thresholds must not raise from the Redis callback nor prevent
        -- PCA from being loaded: report and go on without them
        rspamd_logger.errx(rspamd_config,
            'cannot parse ROC thresholds for %s:%s from Redis key %s: %s',
            rule.prefix, set.name, ann_key, parse_err)
      end
    end

    -- Stage 3: PCA table
    if have_ann and is_text(data[3]) then
      local _err, pca_data = rspamd_util.zstd_decompress(data[3])

      if pca_data then
        if rule.max_inputs then
          -- We can use PCA
          set.ann.pca = rspamd_tensor.load(pca_data)
          rspamd_logger.infox(rspamd_config,
              'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
              rule.prefix, set.name, ann_key, #data[3], profile.version)
        else
          -- no need in pca, why is it there?
          rspamd_logger.warnx(rspamd_config,
              'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
              rule.prefix, set.name, ann_key)
        end
      elseif rule.max_inputs then
        -- pca can be missing merely if we have no max_inputs
        rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
            rule.prefix, set.name, ann_key, _err)
        set.ann.ann = nil
      else
        -- It is okay
        set.ann.pca = nil
      end
    end
  end

  lua_redis.redis_make_request_taskless(ev_base,
      rspamd_config,
      rule.redis,
      nil,
      false, -- is write
      data_cb, --callback
      'HMGET', -- command
      {ann_key, 'ann', 'roc_thresholds', 'pca'}, -- arguments
      {opaque_data = true}
  )
end
-
--- Picks the best matching profile among `profiles` and loads its ANN when it
-- is fresher or more specific than the one we have.
-- A profile is a candidate when the distance between its symbols and
-- `set.symbols` is below 30% of the number of our symbols; the closest
-- candidate wins. Then:
-- * no ANN loaded yet: load the candidate;
-- * same digest: load only if the candidate has a higher version;
-- * different digest: load only if the candidate is closer than the current ANN.
-- Use this function to load ANNs as `callback` parameter for `check_anns` function
-- @param _ unused (worker)
-- @param ev_base event base used for Redis requests
-- @param rule rule table
-- @param set settings element
-- @param profiles iterator over profiles deserialised from Redis
local function process_existing_ann(_, ev_base, rule, set, profiles)
  local my_symbols = set.symbols
  local min_diff = math.huge
  local best

  for _, candidate in fun.iter(profiles) do
    if candidate and candidate.symbols then
      local dist = lua_util.distance_sorted(candidate.symbols, my_symbols)
      -- Close enough and closer than anything seen so far
      if dist < #my_symbols * .3 and dist < min_diff then
        min_diff = dist
        best = candidate
      end
    end
  end

  if not best then
    return
  end

  if not set.ann then
    -- We have no ANN, load new one
    load_new_ann(rule, ev_base, set, best, min_diff)
    return
  end

  local ann_name = rule.prefix .. ':' .. set.name

  if set.ann.digest == best.digest then
    -- Same ANN, the version decides
    if set.ann.version < best.version then
      rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
          'our version = %s, remote version = %s',
          ann_name, set.ann.version, best.version)
      load_new_ann(rule, ev_base, set, best, min_diff)
    else
      lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
          'our version = %s, remote version = %s',
          ann_name, set.ann.version, best.version)
    end
  elseif set.ann.distance > min_diff then
    -- Some different ANN that matches our symbols better
    rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
        'our distance = %s, remote distance = %s',
        ann_name, set.ann.distance, min_diff)
    load_new_ann(rule, ev_base, set, best, min_diff)
  else
    lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
        'our distance = %s, remote distance = %s',
        ann_name, set.ann.distance, min_diff)
  end
end
-
-
--- Checks all profiles and starts training of our ANN when enough vectors
-- have been collected.
-- By "our" we mean the profile that has exactly the same symbols as `set`
-- (distance 0). The sizes of `<redis_key>_spam_set` and `<redis_key>_ham_set`
-- are requested with `SCARD` one after another; once both are known
-- `do_train_ann` is called if the training condition holds:
-- * `rule.train.learn_type == 'balanced'`: the largest class has at least
--   `max_trains` vectors and every class has at least
--   `max_trains * (1 - classes_bias)` vectors;
-- * otherwise: the largest class has at least `max_trains` vectors and the
--   smallest one is not empty.
-- Use this function to train ANN as `callback` parameter for `check_anns` function
-- @param worker worker object passed on to `do_train_ann`
-- @param ev_base event base used for Redis requests
-- @param rule rule table
-- @param set settings element
-- @param profiles iterator over profiles deserialised from Redis
local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
  local my_symbols = set.symbols
  local sel_elt
  local lens = {
    spam = 0,
    ham = 0,
  }

  for _, elt in fun.iter(profiles) do
    if elt and elt.symbols then
      local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
      -- Only an exact match is ours
      if dist == 0 then
        sel_elt = elt
        break
      end
    end
  end

  if not sel_elt then
    return
  end

  -- We have our ANN and that's train vectors, check if we can learn
  local ann_key = sel_elt.redis_key

  lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
      ann_key)

  -- Create continuation closure: the returned callback records the length of
  -- class `what` and calls `cont_cb` either unconditionally (intermediate
  -- step) or when the training condition holds (final step)
  local function redis_len_cb_gen(cont_cb, what, is_final)
    return function(err, data)
      if err then
        -- Arguments follow the format string: ANN key first, class second
        rspamd_logger.errx(rspamd_config,
            'cannot get ANN %s trains %s from redis: %s', ann_key, what, err)
        return
      end

      if type(data) ~= 'number' and type(data) ~= 'string' then
        return
      end

      local ntrains = tonumber(data) or 0
      lens[what] = ntrains

      if not is_final then
        rspamd_logger.debugm(N, rspamd_config,
            'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
            what, ann_key, ntrains, rule.train.max_trains)
        cont_cb()
        return
      end

      -- Ensure that we have the following:
      -- one class has reached max_trains
      -- other class(es) are at least as full as classes_bias
      -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
      -- one class must have 10 or more trains whilst another should have
      -- at least (10 * (1 - 0.25)) = 8 trains

      local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
      local min_len = math.min(lua_util.unpack(lua_util.values(lens)))

      local can_learn
      if rule.train.learn_type == 'balanced' then
        local len_bias_check_pred = function(_, l)
          return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
        end
        can_learn = max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens)
      else
        -- Probabilistic mode, just ensure that at least one vector is okay
        can_learn = min_len > 0 and max_len >= rule.train.max_trains
      end

      if can_learn then
        rspamd_logger.debugm(N, rspamd_config,
            'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
            ann_key, lens, rule.train.max_trains, what)
        cont_cb()
      else
        rspamd_logger.debugm(N, rspamd_config,
            'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
            ann_key, what, lens, rule.train.max_trains)
      end
    end
  end

  local function initiate_train()
    rspamd_logger.infox(rspamd_config,
        'need to learn ANN %s after %s required learn vectors',
        ann_key, lens)
    do_train_ann(worker, ev_base, rule, set, ann_key)
  end

  -- Spam vector is OK, check ham vector length
  local function check_ham_len()
    lua_redis.redis_make_request_taskless(ev_base,
        rspamd_config,
        rule.redis,
        nil,
        false, -- is write
        redis_len_cb_gen(initiate_train, 'ham', true), --callback
        'SCARD', -- command
        {ann_key .. '_ham_set'}
    )
  end

  lua_redis.redis_make_request_taskless(ev_base,
      rspamd_config,
      rule.redis,
      nil,
      false, -- is write
      redis_len_cb_gen(check_ham_len, 'spam', false), --callback
      'SCARD', -- command
      {ann_key .. '_spam_set'}
  )
end
-
--- Deserialises one ANN profile element taken from the Redis list.
-- The element is parsed with UCL and then validated against
-- `redis_profile_schema`.
-- @param element serialised profile string
-- @return the validated profile, or nil when parsing or validation fails
-- (the failure is logged)
local function load_ann_profile(element)
  local ucl = require "ucl"
  local parser = ucl.parser()

  local parsed, ucl_err = parser:parse_string(element)
  if not parsed then
    rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
        ucl_err)
    return nil
  end

  local checked, schema_err = redis_profile_schema:transform(parser:get_object())
  if checked then
    return checked
  end

  rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)

  return nil
end
-
--- Checks or loads ANNs from Redis for every settings element of `rule`.
-- For each table element of `rule.settings` the profiles stored in the sorted
-- set `set.prefix` are requested with `ZREVRANGE` and handed to
-- `process_callback(worker, ev_base, rule, set, profiles)` as an iterator of
-- deserialised profiles. `set.can_store_vectors` is switched on once the reply
-- is an error or a table.
-- @param worker worker object forwarded to the callback
-- @param cfg config object used for Redis requests and logging
-- @param ev_base event base used for Redis requests
-- @param rule rule table
-- @param process_callback function receiving the profiles
-- @param what tag used in the debug log
-- @return `rule.watch_interval`
local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
  for _, set in pairs(rule.settings) do
    if type(set) == 'table' then
      local function members_cb(err, data)
        if not err and type(data) ~= 'table' then
          -- Unexpected reply: nothing to process
          return
        end

        if err then
          rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
              err)
        else
          lua_util.debugm(N, cfg, '%s: process element %s:%s',
              what, rule.prefix, set.name)
          process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
        end

        set.can_store_vectors = true
      end

      -- Extract all profiles for some specific settings id
      -- Get the last `max_profiles` recently used
      -- Select the most appropriate to our profile but it should not differ by more
      -- than 30% of symbols
      local range_args = { set.prefix, '0', tostring(settings.max_profiles) }
      lua_redis.redis_make_request_taskless(ev_base,
          cfg,
          rule.redis,
          nil,
          false, -- is write
          members_cb, -- callback
          'ZREVRANGE', -- command
          range_args
      )
    end
  end -- Cycle over all settings

  return rule.watch_interval
end
-
--- Cleans up old ANNs of `rule`.
-- Runs the `maybe_invalidate` Redis script for every table element of
-- `rule.settings` (with `set.prefix` and `settings.max_profiles` as arguments)
-- and logs every profile the script reports back.
-- @param rule rule table
-- @param cfg config object used for logging
-- @param ev_base event base used for the Redis script
local function cleanup_anns(rule, cfg, ev_base)
  for _, set in pairs(rule.settings) do
    local function invalidate_cb(err, data)
      if err then
        rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
            err)
      elseif type(data) == 'table' then
        for _, expired in ipairs(data) do
          local profile = load_ann_profile(expired)

          -- load_ann_profile returns nil (and logs why) for elements that
          -- cannot be parsed or validated; do not dereference it then
          if profile then
            rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
                rule.prefix .. ':' .. set.name,
                profile.redis_key,
                profile.version)
          else
            rspamd_logger.infox(cfg, 'invalidated ANN for %s; profile cannot be parsed',
                rule.prefix .. ':' .. set.name)
          end
        end
      end
    end

    if type(set) == 'table' then
      lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
          {ev_base = ev_base, is_write = true},
          invalidate_cb,
          {set.prefix, tostring(settings.max_profiles)})
    end
  end
end
-
--- Callback that offers the processed task as a training sample to every rule.
-- The task is ignored when it has the 'skip' flag, when it is a manual scan
-- (unless `settings.allow_local`), when its verdict is 'passthrough' or when
-- its score is NaN. Otherwise `ann_push_task_result` is called for each rule
-- that has settings matching the task.
-- @param task task being processed
local function ann_push_vector(task)
  if task:has_flag('skip') then
    lua_util.debugm(N, task, 'do not push data for skipped task')
    return
  end

  local manual_scan_denied = not settings.allow_local
      and lua_util.is_rspamc_or_controller(task)
  if manual_scan_denied then
    lua_util.debugm(N, task, 'do not push data for manual scan')
    return
  end

  local verdict, score = lua_verdict.get_specific_verdict(N, task)

  if verdict == 'passthrough' then
    lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
        verdict, score)
    return
  end

  -- NaN is the only value that differs from itself
  if score ~= score then
    lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
        verdict)
    return
  end

  for _, rule in pairs(settings.rules) do
    local set = neural_common.get_rule_settings(task, rule)

    if not set then
      lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
    else
      ann_push_task_result(rule, task, verdict, score, set)
    end
  end
end
-
-
-- Initialization part

-- Both a table-typed module config and Redis parameters are required
if not (type(neural_common.module_config) == 'table' and neural_common.redis_params) then
  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  lua_util.disable_module(N, "redis")
  return
end

-- Without an explicit `rules` section the whole module config is treated as a
-- single rule named `default` (legacy configuration)
local rules = neural_common.module_config.rules
    or { default = neural_common.module_config }

-- Parent callback symbol; the per-rule spam/ham symbols are attached to it
local id = rspamd_config:register_symbol({
  name = 'NEURAL_CHECK',
  type = 'postfilter,callback',
  flags = 'nostat',
  priority = lua_util.symbols_priorities.medium,
  callback = ann_scores_filter
})

neural_common.settings.rules = {} -- Reset unless validated further in the cycle

-- A list of blacklisted symbols is transformed to a hash for simplicity
local blacklisted = settings.blacklisted_symbols
if blacklisted and blacklisted[1] then
  settings.blacklisted_symbols = lua_util.list_to_hash(settings.blacklisted_symbols)
end
-
-- Check all rules: merge each configured rule with the defaults, fill in the
-- missing prefix/name, store it in `settings.rules` and register its spam and
-- ham symbols as virtual children of NEURAL_CHECK
for k, r in pairs(rules) do
  local rule_elt = lua_util.override_defaults(neural_common.default_options, r)
  rule_elt['redis'] = neural_common.redis_params
  rule_elt['anns'] = {} -- Store ANNs here

  if not rule_elt.prefix then
    rule_elt.prefix = k
  end
  if not rule_elt.name then
    rule_elt.name = k
  end
  -- `max_train` is accepted as an alias when `max_trains` is not set
  if rule_elt.train.max_train and not rule_elt.train.max_trains then
    rule_elt.train.max_trains = rule_elt.train.max_train
  end

  if not rule_elt.profile then
    rule_elt.profile = {}
  end

  if rule_elt.max_inputs and not has_blas then
    -- The old call had one placeholder for two arguments and no config
    -- object, so the rule name was printed in place of the max_inputs value
    rspamd_logger.errx(rspamd_config,
        'cannot set max inputs to %s for rule %s as BLAS is not compiled in',
        rule_elt.max_inputs, rule_elt.name)
    rule_elt.max_inputs = nil
  end

  rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
  settings.rules[k] = rule_elt

  -- Spam first, then ham: metric entry followed by the virtual symbol
  local class_symbols = {
    { name = rule_elt.symbol_spam, score = 0.0, description = 'Neural network SPAM' },
    { name = rule_elt.symbol_ham, score = -0.0, description = 'Neural network HAM' },
  }
  for _, sym in ipairs(class_symbols) do
    rspamd_config:set_metric_symbol({
      name = sym.name,
      score = sym.score,
      description = sym.description,
      group = 'neural'
    })
    rspamd_config:register_symbol({
      name = sym.name,
      type = 'virtual',
      flags = 'nostat',
      parent = id
    })
  end
end
-
-- Idempotent callback symbol that pushes training vectors (see ann_push_vector)
rspamd_config:register_symbol({
  name = 'NEURAL_LEARN',
  type = 'idempotent,callback',
  flags = 'nostat,explicit_disable,ignore_passthrough',
  callback = ann_push_vector
})

-- We also need to deal with settings
rspamd_config:add_post_init(neural_common.process_rules_settings)

-- Add training scripts
for _, rule in pairs(settings.rules) do
  neural_common.load_scripts(rule.redis)

  -- Runs when a worker is loaded and installs the periodic checks it needs
  local function on_worker_load(cfg, ev_base, worker)
    if worker:is_scanner() then
      -- Scanners check ANNs in Redis and load better ones
      local function load_tick(_, _)
        return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
            'try_load_ann')
      end
      rspamd_config:add_periodic(ev_base, 0.0, load_tick)
    end

    if worker:is_primary_controller() then
      -- We also want to train neural nets when they have enough data
      local function train_tick(_, _)
        -- Clean old ANNs
        cleanup_anns(rule, cfg, ev_base)
        return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
            'try_train_ann')
      end
      rspamd_config:add_periodic(ev_base, 0.0, train_tick)
    end
  end

  rspamd_config:add_on_load(on_worker_load)
end
|