From: Vsevolod Stakhov Date: Sat, 11 Nov 2017 16:32:54 +0000 (+0000) Subject: [Fix] Fix random forests module X-Git-Tag: 1.7.0~448 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=80b59a48c8d35026425ef2db28e281e955c8e08c;p=rspamd.git [Fix] Fix random forests module --- diff --git a/contrib/torch/decisiontree/DataSet.lua b/contrib/torch/decisiontree/DataSet.lua index 505ec86d2..15058a7c6 100644 --- a/contrib/torch/decisiontree/DataSet.lua +++ b/contrib/torch/decisiontree/DataSet.lua @@ -1,5 +1,4 @@ local dt = require "decisiontree._env" -local ipc = require 'libipc' local DataSet = torch.class("dt.DataSet", dt) @@ -77,31 +76,10 @@ function DataSet:sortFeatureValues(inputs) assert(torch.isTensor(inputs)) featureIds:range(1,inputs:size(2)) - local wq = ipc.workqueue() for i=1,inputs:size(2) do - wq:write({i, inputs:select(2, i)}) - end - for i=1,self.nThreads do - wq:write(nil) - end - - ipc.map(self.nThreads, function(wq) - while true do - local data = wq:read() - if data == nil then break end - local featureId = data[1] - local values = data[2] - local sortFeatureValues, sortExampleIds = values:sort(1, false) - sortFeatureValues = nil - wq:write({featureId, sortExampleIds}) - collectgarbage() - end - end, wq) - - for _=1,inputs:size(2) do - local data = wq:read() - local featureId = data[1] - local sortedFeatureExampleIds = data[2] + local featureId = i + local values = inputs:select(2, i) + local _, sortedFeatureExampleIds = values:sort(1, false) dataset[featureId] = sortedFeatureExampleIds end end diff --git a/contrib/torch/decisiontree/GradientBoostTrainer.lua b/contrib/torch/decisiontree/GradientBoostTrainer.lua index 54d1ba8ac..51299b109 100644 --- a/contrib/torch/decisiontree/GradientBoostTrainer.lua +++ b/contrib/torch/decisiontree/GradientBoostTrainer.lua @@ -177,7 +177,7 @@ function GradientBoostTrainer:train(trainSet, featureIds, validSet, verbose) timer:reset() local stop, validLoss, bestDecisionForest = self:validate(trainSet, validSet, decisionForest, bestDecisionForest) if dt.PROFILE then print("validate tree time: "..timer:time().real) end - if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", trainLoss, validLoss)) end + if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", 0, validLoss)) end if stop then if verbose then print(string.format("GBDT early stopped on tree %d", treeId)) end break diff --git a/contrib/torch/decisiontree/init.lua b/contrib/torch/decisiontree/init.lua index fdaf1b56a..26f790b60 100644 --- a/contrib/torch/decisiontree/init.lua +++ b/contrib/torch/decisiontree/init.lua @@ -1,10 +1,8 @@ require 'paths' -require 'xlua' +--require 'xlua' require 'string' require 'os' -require 'sys' -require 'image' -require 'lfs' +--require 'sys' require 'nn' -- these actually return local variables but we will re-require them @@ -31,7 +29,7 @@ require 'decisiontree.math' require 'decisiontree.utils' -- for multi-threading -require 'decisiontree.WorkPool' +--require 'decisiontree.WorkPool' -- abstract classes require 'decisiontree.DecisionTree' @@ -62,8 +60,8 @@ require 'decisiontree.GradientBoostTrainer' require 'decisiontree.GradientBoostState' -- TreeState subclass -- unit tests and benchmarks -require 'decisiontree.test' -require 'decisiontree.benchmark' +--require 'decisiontree.test' +--require 'decisiontree.benchmark' -- nn.Module require 'decisiontree.DFD'