aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-11-11 16:32:54 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-11-11 16:32:54 +0000
commit80b59a48c8d35026425ef2db28e281e955c8e08c (patch)
treece99b72bbed68d3480dd2a7a0a98f232c236971b /contrib/torch/decisiontree
parenta55f3e05b6462e063e0501f6251bd05c39f4eaab (diff)
downloadrspamd-80b59a48c8d35026425ef2db28e281e955c8e08c.tar.gz
rspamd-80b59a48c8d35026425ef2db28e281e955c8e08c.zip
[Fix] Fix random forests module
Diffstat (limited to 'contrib/torch/decisiontree')
-rw-r--r--contrib/torch/decisiontree/DataSet.lua28
-rw-r--r--contrib/torch/decisiontree/GradientBoostTrainer.lua2
-rw-r--r--contrib/torch/decisiontree/init.lua12
3 files changed, 9 insertions, 33 deletions
diff --git a/contrib/torch/decisiontree/DataSet.lua b/contrib/torch/decisiontree/DataSet.lua
index 505ec86d2..15058a7c6 100644
--- a/contrib/torch/decisiontree/DataSet.lua
+++ b/contrib/torch/decisiontree/DataSet.lua
@@ -1,5 +1,4 @@
local dt = require "decisiontree._env"
-local ipc = require 'libipc'
local DataSet = torch.class("dt.DataSet", dt)
@@ -77,31 +76,10 @@ function DataSet:sortFeatureValues(inputs)
assert(torch.isTensor(inputs))
featureIds:range(1,inputs:size(2))
- local wq = ipc.workqueue()
for i=1,inputs:size(2) do
- wq:write({i, inputs:select(2, i)})
- end
- for i=1,self.nThreads do
- wq:write(nil)
- end
-
- ipc.map(self.nThreads, function(wq)
- while true do
- local data = wq:read()
- if data == nil then break end
- local featureId = data[1]
- local values = data[2]
- local sortFeatureValues, sortExampleIds = values:sort(1, false)
- sortFeatureValues = nil
- wq:write({featureId, sortExampleIds})
- collectgarbage()
- end
- end, wq)
-
- for _=1,inputs:size(2) do
- local data = wq:read()
- local featureId = data[1]
- local sortedFeatureExampleIds = data[2]
+ local featureId = i
+ local values = inputs:select(2, i)
+ local _, sortedFeatureExampleIds = values:sort(1, false)
dataset[featureId] = sortedFeatureExampleIds
end
end
diff --git a/contrib/torch/decisiontree/GradientBoostTrainer.lua b/contrib/torch/decisiontree/GradientBoostTrainer.lua
index 54d1ba8ac..51299b109 100644
--- a/contrib/torch/decisiontree/GradientBoostTrainer.lua
+++ b/contrib/torch/decisiontree/GradientBoostTrainer.lua
@@ -177,7 +177,7 @@ function GradientBoostTrainer:train(trainSet, featureIds, validSet, verbose)
timer:reset()
local stop, validLoss, bestDecisionForest = self:validate(trainSet, validSet, decisionForest, bestDecisionForest)
if dt.PROFILE then print("validate tree time: "..timer:time().real) end
- if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", trainLoss, validLoss)) end
+ if verbose then print(string.format("Loss: train=%7.4f, valid=%7.4f", 0, validLoss)) end
if stop then
if verbose then print(string.format("GBDT early stopped on tree %d", treeId)) end
break
diff --git a/contrib/torch/decisiontree/init.lua b/contrib/torch/decisiontree/init.lua
index fdaf1b56a..26f790b60 100644
--- a/contrib/torch/decisiontree/init.lua
+++ b/contrib/torch/decisiontree/init.lua
@@ -1,10 +1,8 @@
require 'paths'
-require 'xlua'
+--require 'xlua'
require 'string'
require 'os'
-require 'sys'
-require 'image'
-require 'lfs'
+--require 'sys'
require 'nn'
-- these actually return local variables but we will re-require them
@@ -31,7 +29,7 @@ require 'decisiontree.math'
require 'decisiontree.utils'
-- for multi-threading
-require 'decisiontree.WorkPool'
+--require 'decisiontree.WorkPool'
-- abstract classes
require 'decisiontree.DecisionTree'
@@ -62,8 +60,8 @@ require 'decisiontree.GradientBoostTrainer'
require 'decisiontree.GradientBoostState' -- TreeState subclass
-- unit tests and benchmarks
-require 'decisiontree.test'
-require 'decisiontree.benchmark'
+--require 'decisiontree.test'
+--require 'decisiontree.benchmark'
-- nn.Module
require 'decisiontree.DFD'