summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-06 09:39:48 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2019-07-06 09:39:48 +0100
commitff024667584fa62f342507647ef0beab5917cf58 (patch)
treeb0ec177b779b1bc34cb469a1a95e5247557fe5a6
parent178cb973ce7de1738fecdf5ca3e6ebf43bc51020 (diff)
downloadrspamd-ff024667584fa62f342507647ef0beab5917cf58.tar.gz
rspamd-ff024667584fa62f342507647ef0beab5917cf58.zip
[Project] Add ANN load function
-rw-r--r--src/plugins/lua/neural.lua45
1 files changed, 44 insertions, 1 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index ff53249c5..cca6f647c 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -725,7 +725,50 @@ end
-- serialize profile one more time and set its rank to the current time
-- set.ann fields are set according to Redis data received
local function load_new_ann(rule, ev_base, set, profile, min_diff)
+ local ann_key = profile.ann_key
+ local function data_cb(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
+ ann_key, err)
+ else
+ local _err,ann_data = rspamd_util.zstd_decompress(data[1])
+ local ann
+
+ if _err or not ann_data then
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+ rule.prefix .. ':' .. set.name, ann_key, _err)
+ return
+ else
+ ann = rspamd_kann.load(ann_data)
+
+ if ann then
+ set.ann = {
+ ann = ann,
+ version = profile.version,
+ symbols = profile.symbols,
+ distance = min_diff
+ }
+
+ rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
+ rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version)
+ else
+ rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
+ rule.prefix .. ':' .. set.name, ann_key)
+ end
+ end
+ end
+ end
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ data_cb, --callback
+ 'HGET', -- command
+ {ann_key, 'ann'}, -- arguments
+ {opaque_data = true}
+ )
end
-- Used to check an element in Redis serialized as JSON
@@ -740,7 +783,7 @@ local function process_existing_ann(rule, ev_base, set, profiles)
for _,elt in fun.iter(profiles) do
if elt and elt.symbols then
local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
-
+ -- Check distance
if dist < #my_symbols * .3 then
if dist < min_diff then
min_diff = dist