aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/fann_redis.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/fann_redis.lua')
-rw-r--r--src/plugins/lua/fann_redis.lua20
1 files changed, 13 insertions, 7 deletions
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 7f475b5eb..830a02f3f 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -37,19 +37,25 @@ local data = {
-- Lua script to train a row
-- Uses the following keys:
--- key1 - prefix for keys
--- key2 - max count of learns
+-- key1 - prefix for fann
+-- key2 - fann suffix (settings id)
-- key3 - spam or ham
-- returns 1 or 0: 1 - allow learn, 0 - not allow learn
local redis_lua_script_can_train = [[
- local locked = redis.call('GET', KEYS[1] .. '_locked')
+ local prefix = KEYS[1] .. KEYS[2]
+ local locked = redis.call('GET', prefix .. '_locked')
if locked then return 0 end
local nspam = 0
local nham = 0
- local ret = redis.call('LLEN', KEYS[1] .. '_spam')
+ local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2])
+ if not exists then
+ redis.call('SADD', KEYS[1], KEYS[2])
+ end
+
+ local ret = redis.call('LLEN', prefix .. '_spam')
if ret then nspam = tonumber(ret) end
- ret = redis.call('LLEN', KEYS[1] .. '_ham')
+ ret = redis.call('LLEN', prefix .. '_ham')
if ret then nham = tonumber(ret) end
if KEYS[3] == 'spam' then
@@ -354,10 +360,10 @@ local function fann_train_callback(score, required_score, results, cf, id, opts,
redis_make_request(ev_base,
rspamd_config,
nil,
- false, -- is write
+ true, -- is write
can_train_cb, --callback
'EVALSHA', -- command
- {redis_can_train_sha, '3', fname, tostring(max_trains), k} -- arguments
+ {redis_can_train_sha, '3', fann_prefix, id, k} -- arguments
)
end
end