aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-17 10:04:59 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-09-17 10:04:59 +0100
commit609d4f862202c43241fdab06b0795db6281ca3f7 (patch)
tree22b838fc9f19772242d919a668f97b2b068d02c8
parent8308d6a6776ad73b8c0e61fcfdb55007434621cf (diff)
downloadrspamd-609d4f862202c43241fdab06b0795db6281ca3f7.tar.gz
rspamd-609d4f862202c43241fdab06b0795db6281ca3f7.zip
[Feature] Allow to specify number of threads for ANN learning
-rw-r--r--lualib/lua_nn.lua1
-rw-r--r--src/plugins/lua/fann_redis.lua4
2 files changed, 5 insertions, 0 deletions
diff --git a/lualib/lua_nn.lua b/lualib/lua_nn.lua
index 0e8977d37..d0d2d5265 100644
--- a/lualib/lua_nn.lua
+++ b/lualib/lua_nn.lua
@@ -22,6 +22,7 @@ local lua_nn_models = {}
if rspamd_config:has_torch() then
torch = require "torch"
+ torch.setnumthreads(1)
end
if torch then
diff --git a/src/plugins/lua/fann_redis.lua b/src/plugins/lua/fann_redis.lua
index 2751b5d79..ab3da0035 100644
--- a/src/plugins/lua/fann_redis.lua
+++ b/src/plugins/lua/fann_redis.lua
@@ -47,6 +47,7 @@ local default_options = {
mse = 0.001,
autotrain = true,
train_prob = 1.0,
+ learn_threads = 1,
},
use_settings = false,
per_user = false,
@@ -781,6 +782,9 @@ local function train_fann(rule, _, ev_base, elt, worker)
dataset.size = function() return #dataset end
local function train_torch()
+ if rule.train.learn_threads > 1 then
+ torch.setnumthreads(rule.train.learn_threads)
+ end
local criterion = nn.MSECriterion()
local trainer = nn.StochasticGradient(fanns[elt].fann_train,
criterion)