aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/plugins/lua/neural.lua36
-rw-r--r--test/functional/lua/neural.lua2
2 files changed, 4 insertions, 34 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
index 3d1c387a5..894d42e30 100644
--- a/src/plugins/lua/neural.lua
+++ b/src/plugins/lua/neural.lua
@@ -35,36 +35,6 @@ local N = "neural"
local settings = neural_common.settings
--- Module vars
-local default_options = {
- train = {
- max_trains = 1000,
- max_epoch = 1000,
- max_usages = 10,
- max_iterations = 25, -- Torch style
- mse = 0.001,
- autotrain = true,
- train_prob = 1.0,
- learn_threads = 1,
- learn_mode = 'balanced', -- Possible values: balanced, proportional
- learning_rate = 0.01,
- classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
- spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
- ham_skip_prob = 0.0, -- proportional mode: ham skip probability
- store_pool_only = false, -- store tokens in cache only (disables autotrain);
- -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
- },
- watch_interval = 60.0,
- lock_expire = 600,
- learning_spawned = false,
- ann_expire = 60 * 60 * 24 * 2, -- 2 days
- hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
- symbol_spam = 'NEURAL_SPAM',
- symbol_ham = 'NEURAL_HAM',
- max_inputs = nil, -- when PCA is used
- blacklisted_symbols = {}, -- list of symbols skipped in neural processing
-}
-
local redis_profile_schema = ts.shape{
digest = ts.string,
symbols = ts.array_of(ts.string),
@@ -213,8 +183,8 @@ local function ann_push_task_result(rule, task, verdict, score, set)
-- Explicitly store tokens in cache
local vec = neural_common.result_to_vector(task, set)
- task:cache_set('neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
- task:cache_set('neural_profile_digest', set.digest)
+ task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+ task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
skip_reason = 'store_pool_only has been set'
end
end
@@ -881,7 +851,7 @@ end
-- Check all rules
for k,r in pairs(rules) do
- local rule_elt = lua_util.override_defaults(default_options, r)
+ local rule_elt = lua_util.override_defaults(neural_common.default_options, r)
rule_elt['redis'] = neural_common.redis_params
rule_elt['anns'] = {} -- Store ANNs here
diff --git a/test/functional/lua/neural.lua b/test/functional/lua/neural.lua
index ccdad1b68..7ea29a252 100644
--- a/test/functional/lua/neural.lua
+++ b/test/functional/lua/neural.lua
@@ -49,7 +49,7 @@ rspamd_config.SAVE_NN_ROW_IDEMPOTENT = {
logger.errx(task, err)
return
end
- f:write(tohex(task:cache_get('neural_vec_mpack') or ''))
+ f:write(tohex(task:cache_get('SHORT_neural_vec_mpack') or ''))
f:close()
return
end,