diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/neural.lua | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/src/plugins/neural.lua b/src/plugins/neural.lua index b2c7adcfa..e0bab70f3 100644 --- a/src/plugins/neural.lua +++ b/src/plugins/neural.lua @@ -48,6 +48,7 @@ local default_options = { autotrain = true, train_prob = 1.0, learn_threads = 1, + learning_rate = 0.01, }, use_settings = false, per_user = false, @@ -92,7 +93,7 @@ local redis_lua_script_can_train = [[ lim = lim + lim * 0.1 local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2]) - if not exists or exists == 0 then + if not exists or tonumber(exists) == 0 then redis.call('SADD', KEYS[1], KEYS[2]) end @@ -669,7 +670,7 @@ local function train_ann(rule, _, ev_base, elt, worker) local criterion = nn.MSECriterion() local trainer = nn.StochasticGradient(anns[elt].ann_train, criterion) - trainer.learning_rate = 0.01 + trainer.learning_rate = rule.train.learning_rate trainer.verbose = false trainer.maxIteration = rule.train.max_iterations trainer.hookIteration = function(self, iteration, currentError) |