1
0
şunun yansıması https://github.com/rspamd/rspamd.git eşitlendi 2024-07-29 20:17:47 +02:00
rspamd/contrib/lua-torch/decisiontree/RandomForestTrainer.lua
2018-05-23 18:14:15 +01:00

160 satır
5.7 KiB
Lua

-- `dt` is the table returned by the decisiontree._env module; the rest of this
-- file reaches sibling classes (CartTrainer, DecisionForest, GiniState,
-- WorkPool) through it.
local dt = require "decisiontree._env"
-- Declare the class under the name "dt.RandomForestTrainer".
-- NOTE(review): passing `dt` as second argument presumably stores the class in
-- that table (so it is reachable as dt.RandomForestTrainer) -- confirm against
-- the torch.class documentation.
local RandomForestTrainer = torch.class("dt.RandomForestTrainer", dt)
--- Build a random-forest trainer from an options table.
-- Every field below is a required number; a missing field is an error, not a
-- default.
-- @param opt table with:
--   nTree              number of trees to train (> 0)
--   maxLeafNodes       max number of leaf nodes per tree (> 0)
--   minLeafSize        min number of examples per leaf (> 0)
--   featureBaggingSize features sampled per tree; when non-positive,
--                      defaults to sqrt(#feature)
--   activeRatio        (> 0) multiplied by the training-set size to obtain
--                      the size of each tree's bootstrap sample
-- Raises an error naming the offending option when validation fails.
function RandomForestTrainer:__init(opt)
   -- fail with a readable message instead of "attempt to index a nil value"
   assert(opt ~= nil, "RandomForestTrainer: missing options table")

   assert(torch.type(opt.nTree) == 'number', "opt.nTree must be a number")
   assert(opt.nTree > 0, "opt.nTree must be positive")
   self.nTree = opt.nTree

   -- max number of leaf nodes per tree
   assert(torch.type(opt.maxLeafNodes) == 'number', "opt.maxLeafNodes must be a number")
   assert(opt.maxLeafNodes > 0, "opt.maxLeafNodes must be positive")
   self.maxLeafNodes = opt.maxLeafNodes

   -- min number of examples per leaf
   assert(torch.type(opt.minLeafSize) == 'number', "opt.minLeafSize must be a number")
   assert(opt.minLeafSize > 0, "opt.minLeafSize must be positive")
   self.minLeafSize = opt.minLeafSize

   -- when non-positive, defaults to sqrt(#feature); hence no sign check here
   assert(torch.type(opt.featureBaggingSize) == 'number', "opt.featureBaggingSize must be a number")
   self.featureBaggingSize = opt.featureBaggingSize

   assert(torch.type(opt.activeRatio) == 'number', "opt.activeRatio must be a number")
   assert(opt.activeRatio > 0, "opt.activeRatio must be positive")
   self.activeRatio = opt.activeRatio

   -- default parallelization is singlethread; treeParallel() switches it
   self.parallelMode = 'singlethread'
end
--- Train a DecisionForest.
-- @param trainSet   dt.DataSet providing the training examples
-- @param featureIds torch.LongTensor of candidate feature ids
-- @param verbose    when truthy, progress messages are printed
-- @return dt.DecisionForest built from the trained trees, their (uniform)
--         weights and the weighted sum of the root scores as bias
function RandomForestTrainer:train(trainSet, featureIds, verbose)
   assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')

   local nTree = self.nTree
   if verbose then
      print(string.format("Begin training Decision Forest with %d trees", nTree))
   end

   -- RF uses uniform weights
   local weight = torch.Tensor(nTree):fill(1 / nTree)

   -- dispatch on the parallelization mode chosen for this trainer
   local trees
   local mode = self.parallelMode
   if mode == 'treeparallel' then
      trainSet:deleteIndex() -- prevents serialization bottleneck
      trees = self:trainTreesTP(trainSet, featureIds, verbose)
   elseif mode == 'singlethread' then
      trees = self:trainTrees(trainSet, featureIds, verbose)
   else
      error("Unrecognized parallel mode: " .. self.parallelMode)
   end

   if verbose then
      print(string.format("Successfully trained %d trees", #trees))
   end

   -- the forest bias is the weighted sum of every tree's root score
   local bias = 0
   for i = 1, #trees do
      bias = bias + trees[i].root.score * weight[i]
   end

   return dt.DecisionForest(trees, weight, bias)
end
--- Train all trees one after another in the calling thread.
-- @param trainSet   data set handed to the dt.CartTrainer constructor
-- @param featureIds forwarded unchanged to trainTree
-- @param verbose    forwarded unchanged to trainTree
-- @return array (Lua table) with one entry per trained tree
function RandomForestTrainer:trainTrees(trainSet, featureIds, verbose)
   -- a single CartTrainer is shared by every tree
   local cartTrainer = dt.CartTrainer(trainSet, self.minLeafSize, self.maxLeafNodes)

   local forest = {}
   for treeId = 1, self.nTree do
      -- trainTree is static, hence the '.' call without self
      forest[#forest + 1] = self.trainTree(
         cartTrainer, featureIds, self.featureBaggingSize, self.activeRatio, treeId, verbose
      )
   end
   return forest
end
--- Static function that trains and returns a single cartTree.
-- Features and examples are both drawn with torch's `random`, i.e. indices
-- may repeat.
-- @param cartTrainer dt.CartTrainer; its `dataset` field supplies the examples
-- @param featureIds  torch.LongTensor of candidate feature ids
-- @param baggingSize number of features to draw; when not positive,
--                    round(sqrt(number of features)) is used
-- @param activeRatio multiplied by the data set size to get the number of
--                    examples to draw
-- @param treeId      tree number, only used in log messages
-- @param verbose     when truthy, progress messages are printed
-- @return whatever cartTrainer:train returns for the drawn sample
function RandomForestTrainer.trainTree(cartTrainer, featureIds, baggingSize, activeRatio, treeId, verbose)
   assert(torch.isTypeOf(cartTrainer, 'dt.CartTrainer'))
   assert(torch.type(featureIds) == 'torch.LongTensor')

   local nFeatures = featureIds:size(1)

   -- resolve the effective number of features to draw
   local nBagged = baggingSize
   if not (nBagged > 0) then
      nBagged = torch.round(math.sqrt(nFeatures))
   end

   if verbose then
      print(string.format("Tree %d: Creating features bootstrap sample with baggingSize %d, nFeatures %d", treeId, nBagged, nFeatures))
   end

   local trainSet = cartTrainer.dataset

   -- sample boot strap features: random positions into featureIds
   local featurePicks = torch.LongTensor(nBagged):random(1, nFeatures)
   local activeFeatures = featureIds:index(1, featurePicks)

   -- sample boot strap examples: random positions into the example id list
   local sampleSize = torch.round(trainSet:size() * activeRatio)
   if verbose then
      print(string.format("Creating bootstrap sample created of size %d", sampleSize))
   end
   local examplePicks = torch.LongTensor(sampleSize):random(1, trainSet:size())
   local bootStrapExampleIds = torch.LongTensor()
   bootStrapExampleIds:index(trainSet:getExampleIds(), 1, examplePicks)

   local cartTree = cartTrainer:train(dt.GiniState(bootStrapExampleIds), activeFeatures)

   if verbose then
      print(string.format("Complete processing tree number %d", treeId))
   end
   return cartTree
end
--- Call before training to enable tree-level parallelization.
-- May only be called while the trainer is still in 'singlethread' mode.
-- @param workPool either a number, which is passed to the dt.WorkPool
--                 constructor, or an existing dt.WorkPool instance
function RandomForestTrainer:treeParallel(workPool)
   assert(self.parallelMode == 'singlethread', self.parallelMode)
   self.parallelMode = 'treeparallel'

   -- a bare number is turned into a fresh pool
   if torch.type(workPool) == 'number' then
      workPool = dt.WorkPool(workPool)
   end
   self.workPool = workPool
   assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))

   -- require the dt package
   self.workPool:update('require', {libname='decisiontree',varname='dt'})
end
--- Train all trees through the work pool (TP stands for tree parallel).
-- One task per tree is submitted, then exactly nTree results are read back.
-- @param trainSet   dt.DataSet captured by the worker setup closure
-- @param featureIds torch.LongTensor captured by the worker setup closure
-- @param verbose    forwarded to trainTree inside each task
-- @return array (Lua table) of dt.CartTree, in the order results were read
function RandomForestTrainer:trainTreesTP(trainSet, featureIds, verbose)
   assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
   assert(torch.type(featureIds) == 'torch.LongTensor')

   -- plain locals so the closures below capture values rather than self
   local minLeafSize = self.minLeafSize
   local maxLeafNodes = self.maxLeafNodes
   local baggingSize = self.featureBaggingSize
   local activeRatio = self.activeRatio
   local pool = self.workPool

   -- setup worker store (each worker will have its own cartTrainer)
   local function setupStore(store)
      local dt = require 'decisiontree'
      store.cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes)
      store.featureIds = featureIds
   end
   pool:updateup('execute', setupStore)

   -- submit one training task per tree
   for treeId = 1, self.nTree do
      -- task closure that will be executed in worker-thread
      local function trainTreeTask(store)
         local dt = require 'decisiontree'
         return dt.RandomForestTrainer.trainTree(store.cartTrainer, store.featureIds, baggingSize, activeRatio, treeId, verbose)
      end
      pool:writeup('execute', trainTreeTask)
   end

   -- collect exactly as many results as tasks were submitted
   local trees = {}
   for _ = 1, self.nTree do
      local taskname, tree = pool:read()
      assert(taskname=='execute')
      assert(torch.isTypeOf(tree, 'dt.CartTree'))
      trees[#trees + 1] = tree
   end
   return trees
end
--- Build a name string that encodes this trainer's settings.
-- @return string containing activeRatio (two decimals), maxLeafNodes,
--         minLeafSize and nTree
function RandomForestTrainer:getName()
   local template = "randomforest-aRatio-%4.2f-maxLeaf-%d-minExample-%d-nTree-%d"
   return template:format(self.activeRatio, self.maxLeafNodes, self.minLeafSize, self.nTree)
end