-rw-r--r--  lualib/rspamadm/rescore.lua | 182
1 file changed, 10 insertions(+), 172 deletions(-)
diff --git a/lualib/rspamadm/rescore.lua b/lualib/rspamadm/rescore.lua
index 9b4d3a4ce..cffba5927 100644
--- a/lualib/rspamadm/rescore.lua
+++ b/lualib/rspamadm/rescore.lua
@@ -14,10 +14,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ]]--
 
+--[[
 local lua_util = require "lua_util"
 local ucl = require "ucl"
 local logger = require "rspamd_logger"
-local optim = require "optim"
 local rspamd_util = require "rspamd_util"
 local argparse = require "argparse"
 local rescore_utility = require "rescore_utility"
@@ -128,7 +128,6 @@ parser:option "--l2"
       :default(0.0)
 
 local function make_dataset_from_logs(logs, all_symbols, spam_score)
-  -- Returns a list of {input, output} for torch SGD train
   local inputs = {}
   local outputs = {}
 
@@ -146,8 +145,8 @@ local function make_dataset_from_logs(logs, all_symbols, spam_score)
     local symbols_set = {}
 
     for i=4,#log do
-      if not ignore_symbols[log[i]] then
-        symbols_set[log[i]] = true
+      if not ignore_symbols[ log[i] ] then
+        symbols_set[log[i] ] = true
       end
     end
 
@@ -167,16 +166,6 @@ local function make_dataset_from_logs(logs, all_symbols, spam_score)
 end
 
 local function init_weights(all_symbols, original_symbol_scores)
-
-  local weights = torch.Tensor(#all_symbols)
-
-  for i, symbol in pairs(all_symbols) do
-    local score = original_symbol_scores[symbol]
-    if not score then score = 0 end
-    weights[i] = score
-  end
-
-  return weights
 end
 
 local function shuffle(logs, messages)
@@ -238,7 +227,7 @@ local function update_logs(logs, symbol_scores)
 
     for j=4,#log do
       log[j] = log[j]:gsub("%s+", "")
-      score = score + (symbol_scores[log[j]] or 0)
+      score = score + (symbol_scores[log[j] ] or 0)
    end
 
     log[2] = lua_util.round(score, 2)
@@ -315,7 +304,7 @@ False positive rate: %.2f %%
 False negative rate: %.2f %%
 Overall accuracy: %.2f %%
 Slowest message: %.2f (%s)
-]]
+] ]
 
   logger.message("\nStatistics at threshold: " .. threshold)
 
@@ -332,154 +321,6 @@ end
 
 -- training function
 local function train(dataset, opt, model, criterion, epoch, all_symbols, spam_threshold, initial_weights)
-  -- epoch tracker
-  epoch = epoch or 1
-
-  -- local vars
-  local time = rspamd_util.get_ticks()
-  local confusion = optim.ConfusionMatrix({'ham', 'spam'})
-
-  -- do one epoch
-
-  local lbfgsState
-  local sgdState
-
-  local batch_size = opt.batch
-
-  logger.messagex("trainer epoch #%s, %s batch", epoch, batch_size)
-
-  for t = 1,dataset:size(),batch_size do
-    -- create mini batch
-    local k = 1
-    local last = math.min(t + batch_size - 1, dataset:size())
-    local inputs = torch.Tensor(last - t + 1, #all_symbols)
-    local targets = torch.Tensor(last - t + 1)
-    for i = t,last do
-      -- load new sample
-      local sample = dataset[i]
-      local input = sample[1]:clone()
-      local target = sample[2]:clone()
-      --target = target:squeeze()
-      inputs[k] = input
-      targets[k] = target
-      k = k + 1
-    end
-
-    local parameters,gradParameters = model:getParameters()
-
-    -- create closure to evaluate f(X) and df/dX
-    local feval = function(x)
-      -- just in case:
-      collectgarbage()
-
-      -- get new parameters
-      if x ~= parameters then
-        parameters:copy(x)
-      end
-
-      -- reset gradients
-      gradParameters:zero()
-
-      -- evaluate function for complete mini batch
-      local outputs = model:forward(inputs)
-      local f = criterion:forward(outputs, targets)
-
-      -- estimate df/dW
-      local df_do = criterion:backward(outputs, targets)
-      model:backward(inputs, df_do)
-
-      -- penalties (L1 and L2):
-      local l1 = tonumber(opt.l1) or 0
-      local l2 = tonumber(opt.l2) or 0
-
-      if l1 ~= 0 or l2 ~= 0 then
-        -- locals:
-        local norm,sign= torch.norm,torch.sign
-
-        local diff = parameters - initial_weights
-        -- Loss:
-        f = f + l1 * norm(diff,1)
-        f = f + l2 * norm(diff,2)^2/2
-
-        -- Gradients:
-        gradParameters:add( sign(diff):mul(l1) + diff:clone():mul(l2) )
-      end
-
-      -- update confusion
-      for i = 1,(last - t + 1) do
-        local class_predicted, target_class = 1, 1
-        if outputs[i][1] > 0.5 then class_predicted = 2 end
-        if targets[i] > 0.5 then target_class = 2 end
-        confusion:add(class_predicted, target_class)
-      end
-
-      -- return f and df/dX
-      return f,gradParameters
-    end
-
-    -- optimize on current mini-batch
-    if opt.optim == 'LBFGS' then
-
-      -- Perform LBFGS step:
-      lbfgsState = lbfgsState or {
-        maxIter = opt.iters,
-        lineSearch = optim.lswolfe
-      }
-      optim.lbfgs(feval, parameters, lbfgsState)
-
-      -- disp report:
-      logger.messagex('LBFGS step')
-      logger.messagex(' - progress in batch: ' .. t .. '/' .. dataset:size())
-      logger.messagex(' - nb of iterations: ' .. lbfgsState.nIter)
-      logger.messagex(' - nb of function evalutions: ' .. lbfgsState.funcEval)
-
-    elseif opt.optim == 'ADAM' then
-      sgdState = sgdState or {
-        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
-        momentum = tonumber(opts.momentum), -- opt.momentum,
-        learningRateDecay = tonumber(opts.learning_rate_decay),
-        weightDecay = tonumber(opts.weight_decay),
-      }
-      optim.adam(feval, parameters, sgdState)
-    elseif opt.optim == 'ADAGRAD' then
-      sgdState = sgdState or {
-        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
-        momentum = tonumber(opts.momentum), -- opt.momentum,
-        learningRateDecay = tonumber(opts.learning_rate_decay),
-        weightDecay = tonumber(opts.weight_decay),
-      }
-      optim.adagrad(feval, parameters, sgdState)
-    elseif opt.optim == 'SGD' then
-      sgdState = sgdState or {
-        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
-        momentum = tonumber(opts.momentum), -- opt.momentum,
-        learningRateDecay = tonumber(opts.learning_rate_decay),
-        weightDecay = tonumber(opts.weight_decay),
-      }
-      optim.sgd(feval, parameters, sgdState)
-    elseif opt.optim == 'NAG' then
-      sgdState = sgdState or {
-        learningRate = tonumber(opts.learning_rate),-- opt.learningRate,
-        momentum = tonumber(opts.momentum), -- opt.momentum,
-        learningRateDecay = tonumber(opts.learning_rate_decay),
-        weightDecay = tonumber(opts.weight_decay),
-      }
-      optim.nag(feval, parameters, sgdState)
-    else
-      logger.errx('unknown optimization method: %s', opt.optim)
-      os.exit(1)
-    end
-  end
-
-  -- time taken
-  time = rspamd_util.get_ticks() - time
-  time = time / dataset:size()
-  logger.messagex("time to learn 1 sample = " .. (time*1000) .. 'ms')
-
-  -- logger.messagex confusion matrix
-  logger.messagex('confusion: %s', tostring(confusion))
-  logger.messagex('%s mean class accuracy (train set)', confusion.totalValid * 100)
-  confusion:zero()
 end
 
 local learning_rates = {
@@ -493,15 +334,13 @@ local function get_threshold()
   local actions = rspamd_config:get_all_actions()
 
   if opts['spam-action'] then
-    return (actions[opts['spam-action']] or 0),actions['reject']
+    return (actions[opts['spam-action'] ] or 0),actions['reject']
   end
 
   return (actions['add header'] or actions['rewrite subject'] or actions['reject']), actions['reject']
 end
 
 local function handler(args)
-  torch = require "torch"
-  nn = require "nn"
   opts = parser:parse(args)
   if not opts['log'] then
     parser:error('no log specified')
@@ -640,16 +479,12 @@ local function handler(args)
   end
 
   shuffle(logs, messages)
-  torch.setdefaulttensortype('torch.FloatTensor')
-
   local train_logs, validation_logs = split_logs(logs, messages,70)
   local cv_logs, test_logs = split_logs(validation_logs[1], validation_logs[2], 50)
 
   local dataset = make_dataset_from_logs(train_logs[1], all_symbols, reject_score)
 
-  -- Start of perceptron training
   local input_size = #all_symbols
-  torch.setnumthreads(opts['threads'])
 
   local linear_module = nn.Linear(input_size, 1, false)
   local activation = nn.Sigmoid()
@@ -747,4 +582,7 @@ return {
   handler = handler,
   description = parser._description,
   name = 'rescore'
-}
\ No newline at end of file
+}
+--]]
+
+return nil
\ No newline at end of file
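
Note on the "] ]" hunks above (an editor's aside, not part of the commit): the change disables the module by wrapping its entire body in a level-0 Lua long comment, "--[[ ... --]]", and a comment of that level ends at the first "]]" it meets. Any pre-existing "]]" inside the commented-out code, such as "symbols_set[log[i]]" or the "]]" that closed the statistics template string, would therefore terminate the comment early; that is why those sequences are rewritten with a space between the brackets. A minimal runnable sketch of the idiom, with hypothetical names:

-- disable_sketch.lua: illustrates the commenting-out idiom this commit
-- applies. Everything between the opening and closing long brackets is
-- one comment, so no two adjacent closing square brackets may appear
-- anywhere inside it.
--[[
local scores = { 1, 2, 3 }
print(scores[#scores])   -- safe: no adjacent closing brackets
-- scores[idx[1] ] = 0   -- the space keeps the brackets apart, as in the diff
--]]

return nil               -- the stub the commit leaves in place

The "--]]" closer also makes the block easy to re-enable later: prepending one more dash turns "--[[" into the line comment "---[[", after which "--]]" comments itself out. An alternative would have been level-1 long brackets, "--[=[ ... ]=]", which ignore inner "]]" and would have avoided the "] ]" rewrites entirely.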