From c17368ed729788433d28b7116764e29dd84e2f86 Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Thu, 8 Mar 2018 16:02:45 +0000 Subject: [PATCH] [Minor] Some adjustments to neural module --- conf/modules.d/neural.conf | 2 ++ src/plugins/neural.lua | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/conf/modules.d/neural.conf b/conf/modules.d/neural.conf index 296ee2f2a..1c27403bf 100644 --- a/conf/modules.d/neural.conf +++ b/conf/modules.d/neural.conf @@ -21,6 +21,8 @@ neural { max_usages = 20; # Number of learn iterations while ANN data is valid spam_score = 8; # Score to learn spam ham_score = -2; # Score to learn ham + learning_rate = 0.01; # Rate of learning (Torch only) + max_iterations = 25; # Maximum iterations of learning (Torch only) } timeout = 20; # Increase redis timeout diff --git a/src/plugins/neural.lua b/src/plugins/neural.lua index b2c7adcfa..e0bab70f3 100644 --- a/src/plugins/neural.lua +++ b/src/plugins/neural.lua @@ -48,6 +48,7 @@ local default_options = { autotrain = true, train_prob = 1.0, learn_threads = 1, + learning_rate = 0.01, }, use_settings = false, per_user = false, @@ -92,7 +93,7 @@ local redis_lua_script_can_train = [[ lim = lim + lim * 0.1 local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2]) - if not exists or exists == 0 then + if not exists or tonumber(exists) == 0 then redis.call('SADD', KEYS[1], KEYS[2]) end @@ -669,7 +670,7 @@ local function train_ann(rule, _, ev_base, elt, worker) local criterion = nn.MSECriterion() local trainer = nn.StochasticGradient(anns[elt].ann_train, criterion) - trainer.learning_rate = 0.01 + trainer.learning_rate = rule.train.learning_rate trainer.verbose = false trainer.maxIteration = rule.train.max_iterations trainer.hookIteration = function(self, iteration, currentError) -- 2.39.5