From: Vsevolod Stakhov Date: Sat, 25 Mar 2023 13:49:16 +0000 (+0000) Subject: [Minor] Neural: Extract lua scripts X-Git-Tag: 3.6~186 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=266daff34bfd4c7bd6548098ca98ed0e6289488a;p=rspamd.git [Minor] Neural: Extract lua scripts --- diff --git a/.luacheckrc b/.luacheckrc index d5a18cc3b..c242bddd8 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -63,6 +63,7 @@ files['/**/lualib/lua_redis.lua'].globals = { files['/**/lualib/redis_scripts/**'].globals = { 'redis', 'KEYS', + 'cjson', } files['/**/src/rspamadm/*'].globals = { diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua index 3400f8d27..05dace489 100644 --- a/lualib/plugins/neural.lua +++ b/lualib/plugins/neural.lua @@ -96,113 +96,22 @@ local module_config = rspamd_config:get_all_opt(N) settings = lua_util.override_defaults(settings, module_config) local redis_params = lua_redis.parse_redis_server('neural') --- Lua script that checks if we can store a new training vector --- Uses the following keys: --- key1 - ann key --- returns nspam,nham (or nil if locked) -local redis_lua_script_vectors_len = [[ - local prefix = KEYS[1] - local locked = redis.call('HGET', prefix, 'lock') - if locked then - local host = redis.call('HGET', prefix, 'hostname') or 'unknown' - return string.format('%s:%s', host, locked) - end - local nspam = 0 - local nham = 0 - - local ret = redis.call('SCARD', prefix .. '_spam_set') - if ret then nspam = tonumber(ret) end - ret = redis.call('SCARD', prefix .. 
'_ham_set') - if ret then nham = tonumber(ret) end - - return {nspam,nham} -]] - --- Lua script to invalidate ANNs by rank --- Uses the following keys --- key1 - prefix for keys --- key2 - number of elements to leave -local redis_lua_script_maybe_invalidate = [[ - local card = redis.call('ZCARD', KEYS[1]) - local lim = tonumber(KEYS[2]) - if card > lim then - local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1) - if to_delete then - for _,k in ipairs(to_delete) do - local tb = cjson.decode(k) - if type(tb) == 'table' and type(tb.redis_key) == 'string' then - redis.call('DEL', tb.redis_key) - -- Also train vectors - redis.call('DEL', tb.redis_key .. '_spam_set') - redis.call('DEL', tb.redis_key .. '_ham_set') - end - end - end - redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) - return to_delete - else - return {} - end -]] - --- Lua script to invalidate ANN from redis --- Uses the following keys --- key1 - prefix for keys --- key2 - current time --- key3 - key expire --- key4 - hostname -local redis_lua_script_maybe_lock = [[ - local locked = redis.call('HGET', KEYS[1], 'lock') - local now = tonumber(KEYS[2]) - if locked then - locked = tonumber(locked) - local expire = tonumber(KEYS[3]) - if now > locked and (now - locked) < expire then - return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'} - end - end - redis.call('HSET', KEYS[1], 'lock', tostring(now)) - redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) - return 1 -]] - --- Lua script to save and unlock ANN in redis --- Uses the following keys --- key1 - prefix for ANN --- key2 - prefix for profile --- key3 - compressed ANN --- key4 - profile as JSON --- key5 - expire in seconds --- key6 - current time --- key7 - old key --- key8 - ROC Thresholds --- key9 - optional PCA -local redis_lua_script_save_unlock = [[ - local now = tonumber(KEYS[6]) - redis.call('ZADD', KEYS[2], now, KEYS[4]) - redis.call('HSET', KEYS[1], 'ann', KEYS[3]) - redis.call('DEL', KEYS[1] .. 
'_spam_set') - redis.call('DEL', KEYS[1] .. '_ham_set') - redis.call('HDEL', KEYS[1], 'lock') - redis.call('HDEL', KEYS[7], 'lock') - redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) - redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8]) - if KEYS[9] then - redis.call('HSET', KEYS[1], 'pca', KEYS[9]) - end - return 1 -]] + +local redis_lua_script_vectors_len = "neural_train_size.lua" +local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua" +local redis_lua_script_maybe_lock = "neural_maybe_lock.lua" +local redis_lua_script_save_unlock = "neural_save_unlock.lua" local redis_script_id = {} local function load_scripts() - redis_script_id.vectors_len = lua_redis.add_redis_script(redis_lua_script_vectors_len, + redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len, redis_params) - redis_script_id.maybe_invalidate = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate, + redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate, redis_params) - redis_script_id.maybe_lock = lua_redis.add_redis_script(redis_lua_script_maybe_lock, + redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock, redis_params) - redis_script_id.save_unlock = lua_redis.add_redis_script(redis_lua_script_save_unlock, + redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock, redis_params) end diff --git a/lualib/redis_scripts/neural_maybe_invalidate.lua b/lualib/redis_scripts/neural_maybe_invalidate.lua new file mode 100644 index 000000000..c54871717 --- /dev/null +++ b/lualib/redis_scripts/neural_maybe_invalidate.lua @@ -0,0 +1,25 @@ +-- Lua script to invalidate ANNs by rank +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - number of elements to leave + +local card = redis.call('ZCARD', KEYS[1]) +local lim = tonumber(KEYS[2]) +if card > lim then + local to_delete = redis.call('ZRANGE', 
KEYS[1], 0, card - lim - 1) + if to_delete then + for _,k in ipairs(to_delete) do + local tb = cjson.decode(k) + if type(tb) == 'table' and type(tb.redis_key) == 'string' then + redis.call('DEL', tb.redis_key) + -- Also train vectors + redis.call('DEL', tb.redis_key .. '_spam_set') + redis.call('DEL', tb.redis_key .. '_ham_set') + end + end + end + redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1) + return to_delete +else + return {} +end \ No newline at end of file diff --git a/lualib/redis_scripts/neural_maybe_lock.lua b/lualib/redis_scripts/neural_maybe_lock.lua new file mode 100644 index 000000000..7b5c6a60f --- /dev/null +++ b/lualib/redis_scripts/neural_maybe_lock.lua @@ -0,0 +1,19 @@ +-- Lua script to lock ANN for learning +-- Uses the following keys +-- key1 - prefix for keys +-- key2 - current time +-- key3 - key expire +-- key4 - hostname + +local locked = redis.call('HGET', KEYS[1], 'lock') +local now = tonumber(KEYS[2]) +if locked then + locked = tonumber(locked) + local expire = tonumber(KEYS[3]) + if now > locked and (now - locked) < expire then + return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'} + end +end +redis.call('HSET', KEYS[1], 'lock', tostring(now)) +redis.call('HSET', KEYS[1], 'hostname', KEYS[4]) +return 1 \ No newline at end of file diff --git a/lualib/redis_scripts/neural_save_unlock.lua b/lualib/redis_scripts/neural_save_unlock.lua new file mode 100644 index 000000000..5af1ddcde --- /dev/null +++ b/lualib/redis_scripts/neural_save_unlock.lua @@ -0,0 +1,24 @@ +-- Lua script to save and unlock ANN in redis +-- Uses the following keys +-- key1 - prefix for ANN +-- key2 - prefix for profile +-- key3 - compressed ANN +-- key4 - profile as JSON +-- key5 - expire in seconds +-- key6 - current time +-- key7 - old key +-- key8 - ROC Thresholds +-- key9 - optional PCA +local now = tonumber(KEYS[6]) +redis.call('ZADD', KEYS[2], now, KEYS[4]) +redis.call('HSET', KEYS[1], 'ann', KEYS[3]) +redis.call('DEL', KEYS[1] 
.. '_spam_set') +redis.call('DEL', KEYS[1] .. '_ham_set') +redis.call('HDEL', KEYS[1], 'lock') +redis.call('HDEL', KEYS[7], 'lock') +redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5])) +redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8]) +if KEYS[9] then + redis.call('HSET', KEYS[1], 'pca', KEYS[9]) +end +return 1 \ No newline at end of file diff --git a/lualib/redis_scripts/neural_train_size.lua b/lualib/redis_scripts/neural_train_size.lua new file mode 100644 index 000000000..5a00ae3fc --- /dev/null +++ b/lualib/redis_scripts/neural_train_size.lua @@ -0,0 +1,20 @@ +-- Lua script that checks if we can store a new training vector +-- Uses the following keys: +-- key1 - ann key +-- returns {nspam,nham} (or a 'hostname:lock' string if locked) + +local prefix = KEYS[1] +local locked = redis.call('HGET', prefix, 'lock') +if locked then + local host = redis.call('HGET', prefix, 'hostname') or 'unknown' + return string.format('%s:%s', host, locked) +end +local nspam = 0 +local nham = 0 + +local ret = redis.call('SCARD', prefix .. '_spam_set') +if ret then nspam = tonumber(ret) end +ret = redis.call('SCARD', prefix .. '_ham_set') +if ret then nham = tonumber(ret) end + +return {nspam,nham} \ No newline at end of file