diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-27 15:35:42 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2020-08-27 15:35:42 +0100 |
commit | 844acefdab6da55af0b371419ca8039f2bd78d29 (patch) | |
tree | 73870f631b3418ebbffdf2e0d6bd2ce28fdb46c5 /src/plugins | |
parent | 0641629ce8f6ce9f348b4353c3fc8ce667d15566 (diff) | |
download | rspamd-844acefdab6da55af0b371419ca8039f2bd78d29.tar.gz rspamd-844acefdab6da55af0b371419ca8039f2bd78d29.zip |
[Project] Neural: Add PCA loading logic
Diffstat (limited to 'src/plugins')
-rw-r--r-- | src/plugins/lua/neural.lua | 118 |
1 files changed, 75 insertions, 43 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index a3027662c..352d397d5 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -22,6 +22,7 @@ end local rspamd_logger = require "rspamd_logger" local rspamd_util = require "rspamd_util" local rspamd_kann = require "rspamd_kann" +local rspamd_text = require "rspamd_text" local lua_redis = require "lua_redis" local lua_util = require "lua_util" local rspamd_tensor = require "rspamd_tensor" @@ -71,6 +72,7 @@ local redis_profile_schema = ts.shape{ } local has_blas = rspamd_tensor.has_blas() +local text_cookie = rspamd_text.cookie -- Rule structure: -- * static config fields (see `default_options`) @@ -327,7 +329,7 @@ local function ann_scores_filter(task) local vec = result_to_vector(task, profile) local score - local out = ann:apply1(vec) + local out = ann:apply1(vec, set.ann.pca) score = out[1] local symscore = string.format('%.3f', score) @@ -940,52 +942,81 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', ann_key, err) else - if type(data) == 'string' then - local _err,ann_data = rspamd_util.zstd_decompress(data) - local ann - - if _err or not ann_data then - rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', - rule.prefix .. ':' .. set.name, ann_key, _err) - return + if type(data) == 'table' then + if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then + local _err,ann_data = rspamd_util.zstd_decompress(data[1]) + local ann + + if _err or not ann_data then + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', + rule.prefix .. ':' .. set.name, ann_key, _err) + return + else + ann = rspamd_kann.load(ann_data) + + if ann then + set.ann = { + digest = profile.digest, + version = profile.version, + symbols = profile.symbols, + distance = min_diff, + redis_key = profile.redis_key + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + set.ann.ann = ann -- To avoid serialization + + local function rank_cb(_, _) + -- TODO: maybe add some logging + end + -- Also update rank for the loaded ANN to avoid removal + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + {set.prefix, tostring(rspamd_util.get_time()), profile_serialized} + ) + rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #ann_data, profile.version) + else + rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) + end + end else - ann = rspamd_kann.load(ann_data) - - if ann then - set.ann = { - digest = profile.digest, - version = profile.version, - symbols = profile.symbols, - distance = min_diff, - redis_key = profile.redis_key - } - - local ucl = require "ucl" - local profile_serialized = ucl.to_format(profile, 'json-compact', true) - set.ann.ann = ann -- To avoid serialization - - local function rank_cb(_, _) - -- TODO: maybe add some logging + lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) + end + if set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then + -- PCA table + local _err,pca_data = rspamd_util.zstd_decompress(data[2]) + if pca_data then + if rule.max_inputs then + -- We can use PCA + set.ann.pca = rspamd_tensor.load(pca_data) + else + -- no need in pca, why is it there? + rspamd_logger.warnx(rspamd_config, 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', + rule.prefix, set.name, ann_key) end - -- Also update rank for the loaded ANN to avoid removal - lua_redis.redis_make_request_taskless(ev_base, - rspamd_config, - rule.redis, - nil, - true, -- is write - rank_cb, --callback - 'ZADD', -- command - {set.prefix, tostring(rspamd_util.get_time()), profile_serialized} - ) - rspamd_logger.infox(rspamd_config, 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', - rule.prefix, set.name, ann_key, #ann_data, profile.version) else - rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s:%s from Redis key %s', - rule.prefix, set.name, ann_key) + -- pca can be missing merely if we have no max_inputs + if rule.max_inputs then + rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s', + rule.prefix, set.name, ann_key, _err) + set.ann.ann = nil + else + -- It is okay + set.ann.pca = nil + end end end else - lua_util.debugm(N, rspamd_config, 'no ANN for %s:%s in Redis key %s', + lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s', rule.prefix, set.name, ann_key) end end @@ -996,8 +1027,9 @@ local function load_new_ann(rule, ev_base, set, profile, min_diff) nil, false, -- is write data_cb, --callback - 'HGET', -- command - {ann_key, 'ann'} -- arguments + 'HMGET', -- command + {ann_key, 'ann', 'pca'}, -- arguments + {opaque_data = true} ) end |