diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-06 09:39:48 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2019-07-06 09:39:48 +0100 |
commit | ff024667584fa62f342507647ef0beab5917cf58 (patch) | |
tree | b0ec177b779b1bc34cb469a1a95e5247557fe5a6 | |
parent | 178cb973ce7de1738fecdf5ca3e6ebf43bc51020 (diff) | |
download | rspamd-ff024667584fa62f342507647ef0beab5917cf58.tar.gz rspamd-ff024667584fa62f342507647ef0beab5917cf58.zip |
[Project] Add ANN load function
-rw-r--r-- | src/plugins/lua/neural.lua | 45 |
1 files changed, 44 insertions, 1 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua index ff53249c5..cca6f647c 100644 --- a/src/plugins/lua/neural.lua +++ b/src/plugins/lua/neural.lua @@ -725,7 +725,50 @@ end -- serialize profile one more time and set its rank to the current time -- set.ann fields are set according to Redis data received local function load_new_ann(rule, ev_base, set, profile, min_diff) + local ann_key = profile.ann_key + local function data_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', + ann_key, err) + else + local _err,ann_data = rspamd_util.zstd_decompress(data[1]) + local ann + + if _err or not ann_data then + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', + rule.prefix .. ':' .. set.name, ann_key, _err) + return + else + ann = rspamd_kann.load(ann_data) + + if ann then + set.ann = { + ann = ann, + version = profile.version, + symbols = profile.symbols, + distance = min_diff + } + + rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s', + rule.prefix .. ':' .. set.name, ann_key, #ann_data, profile.version) + else + rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s', + rule.prefix .. ':' .. set.name, ann_key) + end + end + end + end + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + data_cb, --callback + 'HGET', -- command + {ann_key, 'ann'}, -- arguments + {opaque_data = true} + ) end -- Used to check an element in Redis serialized as JSON @@ -740,7 +783,7 @@ local function process_existing_ann(rule, ev_base, set, profiles) for _,elt in fun.iter(profiles) do if elt and elt.symbols then local dist = lua_util.distance_sorted(elt.symbols, my_symbols) - + -- Check distance if dist < #my_symbols * .3 then if dist < min_diff then min_diff = dist |