aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/plugins
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2023-03-25 13:49:16 +0000
committerVsevolod Stakhov <vsevolod@rspamd.com>2023-03-25 13:49:16 +0000
commit266daff34bfd4c7bd6548098ca98ed0e6289488a (patch)
treee778a1037b5fa407a236d6aa933cd511d2fda198 /lualib/plugins
parentdc9b70f7c3ce549c62288f025991765d0294facd (diff)
downloadrspamd-266daff34bfd4c7bd6548098ca98ed0e6289488a.tar.gz
rspamd-266daff34bfd4c7bd6548098ca98ed0e6289488a.zip
[Minor] Neural: Extract lua scripts
Diffstat (limited to 'lualib/plugins')
-rw-r--r--lualib/plugins/neural.lua109
1 files changed, 9 insertions, 100 deletions
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
index 3400f8d27..05dace489 100644
--- a/lualib/plugins/neural.lua
+++ b/lualib/plugins/neural.lua
@@ -96,113 +96,22 @@ local module_config = rspamd_config:get_all_opt(N)
settings = lua_util.override_defaults(settings, module_config)
local redis_params = lua_redis.parse_redis_server('neural')
--- Lua script that checks if we can store a new training vector
--- Uses the following keys:
--- key1 - ann key
--- returns nspam,nham (or nil if locked)
-local redis_lua_script_vectors_len = [[
- local prefix = KEYS[1]
- local locked = redis.call('HGET', prefix, 'lock')
- if locked then
- local host = redis.call('HGET', prefix, 'hostname') or 'unknown'
- return string.format('%s:%s', host, locked)
- end
- local nspam = 0
- local nham = 0
-
- local ret = redis.call('SCARD', prefix .. '_spam_set')
- if ret then nspam = tonumber(ret) end
- ret = redis.call('SCARD', prefix .. '_ham_set')
- if ret then nham = tonumber(ret) end
-
- return {nspam,nham}
-]]
-
--- Lua script to invalidate ANNs by rank
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - number of elements to leave
-local redis_lua_script_maybe_invalidate = [[
- local card = redis.call('ZCARD', KEYS[1])
- local lim = tonumber(KEYS[2])
- if card > lim then
- local to_delete = redis.call('ZRANGE', KEYS[1], 0, card - lim - 1)
- if to_delete then
- for _,k in ipairs(to_delete) do
- local tb = cjson.decode(k)
- if type(tb) == 'table' and type(tb.redis_key) == 'string' then
- redis.call('DEL', tb.redis_key)
- -- Also train vectors
- redis.call('DEL', tb.redis_key .. '_spam_set')
- redis.call('DEL', tb.redis_key .. '_ham_set')
- end
- end
- end
- redis.call('ZREMRANGEBYRANK', KEYS[1], 0, card - lim - 1)
- return to_delete
- else
- return {}
- end
-]]
-
--- Lua script to invalidate ANN from redis
--- Uses the following keys
--- key1 - prefix for keys
--- key2 - current time
--- key3 - key expire
--- key4 - hostname
-local redis_lua_script_maybe_lock = [[
- local locked = redis.call('HGET', KEYS[1], 'lock')
- local now = tonumber(KEYS[2])
- if locked then
- locked = tonumber(locked)
- local expire = tonumber(KEYS[3])
- if now > locked and (now - locked) < expire then
- return {tostring(locked), redis.call('HGET', KEYS[1], 'hostname') or 'unknown'}
- end
- end
- redis.call('HSET', KEYS[1], 'lock', tostring(now))
- redis.call('HSET', KEYS[1], 'hostname', KEYS[4])
- return 1
-]]
-
--- Lua script to save and unlock ANN in redis
--- Uses the following keys
--- key1 - prefix for ANN
--- key2 - prefix for profile
--- key3 - compressed ANN
--- key4 - profile as JSON
--- key5 - expire in seconds
--- key6 - current time
--- key7 - old key
--- key8 - ROC Thresholds
--- key9 - optional PCA
-local redis_lua_script_save_unlock = [[
- local now = tonumber(KEYS[6])
- redis.call('ZADD', KEYS[2], now, KEYS[4])
- redis.call('HSET', KEYS[1], 'ann', KEYS[3])
- redis.call('DEL', KEYS[1] .. '_spam_set')
- redis.call('DEL', KEYS[1] .. '_ham_set')
- redis.call('HDEL', KEYS[1], 'lock')
- redis.call('HDEL', KEYS[7], 'lock')
- redis.call('EXPIRE', KEYS[1], tonumber(KEYS[5]))
- redis.call('HSET', KEYS[1], 'roc_thresholds', KEYS[8])
- if KEYS[9] then
- redis.call('HSET', KEYS[1], 'pca', KEYS[9])
- end
- return 1
-]]
+
+local redis_lua_script_vectors_len = "neural_train_size.lua"
+local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua"
+local redis_lua_script_maybe_lock = "neural_maybe_lock.lua"
+local redis_lua_script_save_unlock = "neural_save_unlock.lua"
local redis_script_id = {}
local function load_scripts()
- redis_script_id.vectors_len = lua_redis.add_redis_script(redis_lua_script_vectors_len,
+ redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
redis_params)
- redis_script_id.maybe_invalidate = lua_redis.add_redis_script(redis_lua_script_maybe_invalidate,
+ redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate,
redis_params)
- redis_script_id.maybe_lock = lua_redis.add_redis_script(redis_lua_script_maybe_lock,
+ redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock,
redis_params)
- redis_script_id.save_unlock = lua_redis.add_redis_script(redis_lua_script_save_unlock,
+ redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock,
redis_params)
end