about summary refs log tree commit diff stats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r-- lualib/rspamadm/rescore.lua 182
1 files changed, 10 insertions, 172 deletions
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index 9b4d3a4ce..cffba5927 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
limitations under the License.
]]--
+--[[
local lua_util = require "lua_util"
local ucl = require "ucl"
local logger = require "rspamd_logger"
-local optim = require "optim"
local rspamd_util = require "rspamd_util"
local argparse = require "argparse"
local rescore_utility = require "rescore_utility"
@@ -128,7 +128,6 @@ parser:option "--l2"
:default(0.0)
local function make_dataset_from_logs(logs, all_symbols, spam_score)
- -- Returns a list of {input, output} for torch SGD train
local inputs = {}
local outputs = {}
@@ -146,8 +145,8 @@ local function make_dataset_from_logs(logs, all_symbols, spam_score)
local symbols_set = {}
for i=4,#log do
- if not ignore_symbols[log[i]] then
- symbols_set[log[i]] = true
+ if not ignore_symbols[ log[i] ] then
+ symbols_set[log[i] ] = true
end
end
@@ -167,16 +166,6 @@ local function make_dataset_from_logs(logs, all_symbols, spam_score)
end
local function init_weights(all_symbols, original_symbol_scores)
-
- local weights = torch.Tensor(#all_symbols)
-
- for i, symbol in pairs(all_symbols) do
- local score = original_symbol_scores[symbol]
- if not score then score = 0 end
- weights[i] = score
- end
-
- return weights
end
local function shuffle(logs, messages)
@@ -238,7 +227,7 @@ local function update_logs(logs, symbol_scores)
for j=4,#log do
log[j] = log[j]:gsub("%s+", "")
- score = score + (symbol_scores[log[j]] or 0)
+ score = score + (symbol_scores[log[j] ] or 0)
end
log[2] = lua_util.round(score, 2)
@@ -315,7 +304,7 @@ False positive rate: %.2f %%
False negative rate: %.2f %%
Overall accuracy: %.2f %%
Slowest message: %.2f (%s)
-]]
+] ]
logger.message("\nStatistics at threshold: " .. threshold)
@@ -332,154 +321,6 @@ end
-- training function
local function train(dataset, opt, model, criterion, epoch,
all_symbols, spam_threshold, initial_weights)
- -- epoch tracker
- epoch = epoch or 1
-
- -- local vars
- local time = rspamd_util.get_ticks()
- local confusion = optim.ConfusionMatrix({'ham', 'spam'})
-
- -- do one epoch
-
- local lbfgsState
- local sgdState
-
- local batch_size = opt.batch
-
- logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
-
- for t = 1,dataset:size(),batch_size do
- -- create mini batch
- local k = 1
- local last = math.min(t + batch_size - 1, dataset:size())
- local inputs = torch.Tensor(last - t + 1, #all_symbols)
- local targets = torch.Tensor(last - t + 1)
- for i = t,last do
- -- load new sample
- local sample = dataset[i]
- local input = sample[1]:clone()
- local target = sample[2]:clone()
- --target = target:squeeze()
- inputs[k] = input
- targets[k] = target
- k = k + 1
- end
-
- local parameters,gradParameters = model:getParameters()
-
- -- create closure to evaluate f(X) and df/dX
- local feval = function(x)
- -- just in case:
- collectgarbage()
-
- -- get new parameters
- if x ~= parameters then
- parameters:copy(x)
- end
-
- -- reset gradients
- gradParameters:zero()
-
- -- evaluate function for complete mini batch
- local outputs = model:forward(inputs)
- local f = criterion:forward(outputs, targets)
-
- -- estimate df/dW
- local df_do = criterion:backward(outputs, targets)
- model:backward(inputs, df_do)
-
- -- penalties (L1 and L2):
- local l1 = tonumber(opt.l1) or 0
- local l2 = tonumber(opt.l2) or 0
-
- if l1 ~= 0 or l2 ~= 0 then
- -- locals:
- local norm,sign= torch.norm,torch.sign
-
- local diff = parameters - initial_weights
- -- Loss:
- f = f + l1 * norm(diff,1)
- f = f + l2 * norm(diff,2)^2/2
-
- -- Gradients:
- gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) )
- end
-
- -- update confusion
- for i = 1,(last - t + 1) do
- local class_predicted, target_class = 1, 1
- if outputs[i][1] > 0.5 then class_predicted = 2 end
- if targets[i] > 0.5 then target_class = 2 end
- confusion:add(class_predicted, target_class)
- end
-
- -- return f and df/dX
- return f,gradParameters
- end
-
- -- optimize on current mini-batch
- if opt.optim == 'LBFGS' then
-
- -- Perform LBFGS step:
- lbfgsState = lbfgsState or {
- maxIter = opt.iters,
- lineSearch = optim.lswolfe
- }
- optim.lbfgs(feval, parameters, lbfgsState)
-
- -- disp report:
- logger.messagex('LBFGS step')
- logger.messagex(' - progress in batch: ' .. t .. '/' .. dataset:size())
- logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
- logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval)
-
- elseif opt.optim == 'ADAM' then
- sgdState = sgdState or {
- learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
- momentum = tonumber(opts.momentum), -- opt.momentum,
- learningRateDecay = tonumber(opts.learning_rate_decay),
- weightDecay = tonumber(opts.weight_decay),
- }
- optim.adam(feval, parameters, sgdState)
- elseif opt.optim == 'ADAGRAD' then
- sgdState = sgdState or {
- learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
- momentum = tonumber(opts.momentum), -- opt.momentum,
- learningRateDecay = tonumber(opts.learning_rate_decay),
- weightDecay = tonumber(opts.weight_decay),
- }
- optim.adagrad(feval, parameters, sgdState)
- elseif opt.optim == 'SGD' then
- sgdState = sgdState or {
- learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
- momentum = tonumber(opts.momentum), -- opt.momentum,
- learningRateDecay = tonumber(opts.learning_rate_decay),
- weightDecay = tonumber(opts.weight_decay),
- }
- optim.sgd(feval, parameters, sgdState)
- elseif opt.optim == 'NAG' then
- sgdState = sgdState or {
- learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
- momentum = tonumber(opts.momentum), -- opt.momentum,
- learningRateDecay = tonumber(opts.learning_rate_decay),
- weightDecay = tonumber(opts.weight_decay),
- }
- optim.nag(feval, parameters, sgdState)
- else
- logger.errx('unknown optimization method: %s', opt.optim)
- os.exit(1)
- end
- end
-
- -- time taken
- time = rspamd_util.get_ticks() - time
- time = time / dataset:size()
- logger.messagex("time to learn 1 sample = " .. (time*1000) .. 'ms')
-
- -- logger.messagex confusion matrix
- logger.messagex('confusion: %s', tostring(confusion))
- logger.messagex('%s mean class accuracy (train set)', confusion.totalValid * 100)
- confusion:zero()
end
local learning_rates = {
@@ -493,15 +334,13 @@ local function get_threshold()
local actions = rspamd_config:get_all_actions()
if opts['spam-action'] then
- return (actions[opts['spam-action']] or 0),actions['reject']
+ return (actions[opts['spam-action'] ] or 0),actions['reject']
end
return (actions['add header'] or actions['rewrite subject']
or actions['reject']), actions['reject']
end
local function handler(args)
- torch = require "torch"
- nn = require "nn"
opts = parser:parse(args)
if not opts['log'] then
parser:error('no log specified')
@@ -640,16 +479,12 @@ local function handler(args)
end
shuffle(logs, messages)
- torch.setdefaulttensortype('torch.FloatTensor')
-
local train_logs, validation_logs = split_logs(logs, messages,70)
local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)
local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)
-
-- Start of perceptron training
local input_size = #all_symbols
- torch.setnumthreads(opts['threads'])
local linear_module = nn.Linear(input_size, 1, false)
local activation = nn.Sigmoid()
@@ -747,4 +582,7 @@ return {
handler = handler,
description = parser._description,
name = 'rescore'
-}
\ No newline at end of file
+}
+--]]
+
+return nil
\ No newline at end of file