author     Vsevolod Stakhov <vsevolod@highsecure.ru>   2018-03-08 16:02:45 +0000
committer  Vsevolod Stakhov <vsevolod@highsecure.ru>   2018-03-08 16:02:45 +0000
commit     c17368ed729788433d28b7116764e29dd84e2f86 (patch)
tree       d25ab25aaaf121a3c4747dbbd70b5f8c8c260f79
parent     42952fac4d300b78334ea81c5387aa4955e8fd96 (diff)
[Minor] Some adjustments to neural module
-rw-r--r--  conf/modules.d/neural.conf | 2
-rw-r--r--  src/plugins/neural.lua     | 5
2 files changed, 5 insertions, 2 deletions
diff --git a/conf/modules.d/neural.conf b/conf/modules.d/neural.conf
index 296ee2f2a..1c27403bf 100644
--- a/conf/modules.d/neural.conf
+++ b/conf/modules.d/neural.conf
@@ -21,6 +21,8 @@ neural {
     max_usages = 20; # Number of learn iterations while ANN data is valid
     spam_score = 8; # Score to learn spam
     ham_score = -2; # Score to learn ham
+    learning_rate = 0.01; # Rate of learning (Torch only)
+    max_iterations = 25; # Maximum iterations of learning (Torch only)
   }
   timeout = 20; # Increase redis timeout
diff --git a/src/plugins/neural.lua b/src/plugins/neural.lua
index b2c7adcfa..e0bab70f3 100644
--- a/src/plugins/neural.lua
+++ b/src/plugins/neural.lua
@@ -48,6 +48,7 @@ local default_options = {
     autotrain = true,
     train_prob = 1.0,
     learn_threads = 1,
+    learning_rate = 0.01,
   },
   use_settings = false,
   per_user = false,
@@ -92,7 +93,7 @@ local redis_lua_script_can_train = [[
   lim = lim + lim * 0.1
   local exists = redis.call('SISMEMBER', KEYS[1], KEYS[2])
-  if not exists or exists == 0 then
+  if not exists or tonumber(exists) == 0 then
     redis.call('SADD', KEYS[1], KEYS[2])
   end
@@ -669,7 +670,7 @@ local function train_ann(rule, _, ev_base, elt, worker)
       local criterion = nn.MSECriterion()
       local trainer = nn.StochasticGradient(anns[elt].ann_train, criterion)
-      trainer.learning_rate = 0.01
+      trainer.learning_rate = rule.train.learning_rate
       trainer.verbose = false
       trainer.maxIteration = rule.train.max_iterations
       trainer.hookIteration = function(self, iteration, currentError)
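The two new options expose training parameters that were previously hard-coded in the plugin. As a minimal sketch of how they could be tuned per site, assuming the usual rspamd convention of overriding module defaults in local.d/neural.conf and that the options live in the train {} block (as the default_options table above suggests); neither the path nor the values below are part of this commit:

# local.d/neural.conf -- illustrative override, not part of this commit
train {
  learning_rate = 0.005;  # passed to the Torch StochasticGradient trainer (Torch backend only)
  max_iterations = 50;    # cap on training iterations (Torch backend only)
}

Since train_ann() now reads rule.train.learning_rate instead of the literal 0.01, such an override takes effect without patching src/plugins/neural.lua.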