From: Vsevolod Stakhov
Date: Sat, 6 Jul 2019 08:39:48 +0000 (+0100)
Subject: [Project] Add ANN load function
X-Git-Tag: 2.0~650
X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=ff024667584fa62f342507647ef0beab5917cf58;p=rspamd.git

[Project] Add ANN load function
---

Review notes on the load_new_ann hunk (amended relative to the original commit):
 - HGET replies with a single value, so the callback now reads `data` itself
   instead of `data[1]`; a table reply is still unwrapped.
 - A missing reply is logged instead of being passed to zstd_decompress.
 - The "bytes compressed" log now reports the size of the compressed payload.
 - Hunk line counts are updated accordingly; the post-image blob hash on the
   `index` line still refers to the original commit.

diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index ff53249c5..cca6f647c 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -725,7 +725,68 @@ end
 -- serialize profile one more time and set its rank to the current time
 -- set.ann fields are set according to Redis data received
 local function load_new_ann(rule, ev_base, set, profile, min_diff)
+  local ann_key = profile.ann_key
 
+  -- Redis callback: unwrap, decompress and deserialize the stored ANN
+  local function data_cb(err, data)
+    if err then
+      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
+        ann_key, err)
+      return
+    end
+
+    -- HGET yields one bulk value rather than an array, so use `data` itself;
+    -- a table reply is still unwrapped in case the Redis layer returns one
+    local payload = data
+    if type(data) == 'table' then
+      payload = data[1]
+    end
+
+    if not payload then
+      rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
+        ann_key, 'no data')
+      return
+    end
+
+    local _err,ann_data = rspamd_util.zstd_decompress(payload)
+
+    if _err or not ann_data then
+      rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+        rule.prefix .. ':' .. set.name, ann_key, _err)
+      return
+    end
+
+    local ann = rspamd_kann.load(ann_data)
+
+    if not ann then
+      rspamd_logger.errx(rspamd_config, 'cannot deserialize ANN for %s from Redis key %s',
+        rule.prefix .. ':' .. set.name, ann_key)
+      return
+    end
+
+    set.ann = {
+      ann = ann,
+      version = profile.version,
+      symbols = profile.symbols,
+      distance = min_diff
+    }
+
+    -- Report the size of the compressed payload, as the message states
+    rspamd_logger.infox(rspamd_config, 'loaded ANN for %s from %s; %s bytes compressed; version=%s',
+      rule.prefix .. ':' .. set.name, ann_key, #payload, profile.version)
+  end
+
+  -- Fetch the `ann` field of the hash stored at ann_key
+  lua_redis.redis_make_request_taskless(ev_base,
+    rspamd_config,
+    rule.redis,
+    nil,
+    false, -- is write
+    data_cb, --callback
+    'HGET', -- command
+    {ann_key, 'ann'}, -- arguments
+    {opaque_data = true}
+  )
 end
 
 -- Used to check an element in Redis serialized as JSON
@@ -740,7 +801,7 @@ local function process_existing_ann(rule, ev_base, set, profiles)
   for _,elt in fun.iter(profiles) do
     if elt and elt.symbols then
       local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
-
+      -- Check distance
       if dist < #my_symbols * .3 then
         if dist < min_diff then
           min_diff = dist