]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Neural: Fix random sampling
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 24 Oct 2019 17:22:51 +0000 (18:22 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 24 Oct 2019 17:22:51 +0000 (18:22 +0100)
Issue: #3119

src/plugins/lua/neural.lua

index 87df493252eaf5bfe8a3a044478f61ecf5bfe063..faeb664126cabe2e645d13a74a589bf39c999814 100644 (file)
@@ -97,6 +97,7 @@ end
 -- key1 - ann key
 -- key2 - spam or ham
 -- key3 - maximum trains
+-- key4 - sampling coin (as Redis scripts do not allow math.random calls)
 -- returns 1 or 0 + reason: 1 - allow learn, 0 - not allow learn
 local redis_lua_script_can_store_train_vec = [[
   local prefix = KEYS[1]
@@ -105,6 +106,7 @@ local redis_lua_script_can_store_train_vec = [[
   local nspam = 0
   local nham = 0
   local lim = tonumber(KEYS[3])
+  local coin = tonumber(KEYS[4])
 
   local ret = redis.call('LLEN', prefix .. '_spam')
   if ret then nspam = tonumber(ret) end
@@ -116,7 +118,7 @@ local redis_lua_script_can_store_train_vec = [[
       if nspam > nham then
         -- Apply sampling
         local skip_rate = 1.0 - nham / (nspam + 1)
-        if math.random() < skip_rate then
+        if coun < skip_rate then
           return {tostring(-(nspam)),'sampled out with probability ' .. tostring(skip_rate)}
         end
       end
@@ -129,7 +131,7 @@ local redis_lua_script_can_store_train_vec = [[
       if nham > nspam then
         -- Apply sampling
         local skip_rate = 1.0 - nspam / (nham + 1)
-        if math.random() < skip_rate then
+        if coin < skip_rate then
           return {tostring(-(nham)),'sampled out with probability ' .. tostring(skip_rate)}
         end
       end
@@ -488,6 +490,7 @@ local function ann_push_task_result(rule, task, verdict, score, set)
           set.ann.redis_key,
           learn_type,
           tostring(train_opts.max_trains),
+          tostring(math.random()),
         })
   else
     lua_util.debugm(N, task, 'do not push data: train condition not satisfied; reason: %s',