mirror of
https://github.com/rspamd/rspamd.git
synced 2024-09-13 15:45:47 +02:00
160 lines
5.7 KiB
Lua
160 lines
5.7 KiB
Lua
local dt = require "decisiontree._env"
|
|
|
|
local RandomForestTrainer = torch.class("dt.RandomForestTrainer", dt)
|
|
|
|
function RandomForestTrainer:__init(opt)
|
|
assert(torch.type(opt.nTree) == 'number')
|
|
assert(opt.nTree > 0)
|
|
self.nTree = opt.nTree
|
|
-- max number of leaf nodes per tree
|
|
assert(torch.type(opt.maxLeafNodes) == 'number')
|
|
assert(opt.maxLeafNodes > 0)
|
|
self.maxLeafNodes = opt.maxLeafNodes
|
|
-- min number of examples per leaf
|
|
assert(torch.type(opt.minLeafSize) == 'number')
|
|
assert(opt.minLeafSize > 0)
|
|
self.minLeafSize = opt.minLeafSize
|
|
|
|
-- when non-positive, defaults to sqrt(#feature)
|
|
assert(torch.type(opt.featureBaggingSize) == 'number')
|
|
self.featureBaggingSize = opt.featureBaggingSize
|
|
|
|
assert(torch.type(opt.activeRatio) == 'number')
|
|
assert(opt.activeRatio > 0)
|
|
self.activeRatio = opt.activeRatio
|
|
|
|
-- default parallelization is singlethread
|
|
self.parallelMode = 'singlethread'
|
|
end
|
|
|
|
-- Train a DecisionForest
|
|
function RandomForestTrainer:train(trainSet, featureIds, verbose)
|
|
assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
|
|
assert(torch.type(featureIds) == 'torch.LongTensor')
|
|
|
|
if verbose then print(string.format("Begin training Decision Forest with %d trees", self.nTree)) end
|
|
|
|
local weight = torch.Tensor(self.nTree):fill(1 / self.nTree) -- RF uses uniform weights
|
|
|
|
local trees
|
|
if self.parallelMode == 'singlethread' then
|
|
trees = self:trainTrees(trainSet, featureIds, verbose)
|
|
elseif self.parallelMode == 'treeparallel' then
|
|
trainSet:deleteIndex() -- prevents serialization bottleneck
|
|
trees = self:trainTreesTP(trainSet, featureIds, verbose)
|
|
else
|
|
error("Unrecognized parallel mode: " .. self.parallelMode)
|
|
end
|
|
|
|
if verbose then print(string.format("Successfully trained %d trees", #trees)) end
|
|
|
|
-- set bias
|
|
local bias = 0;
|
|
for i, tree in ipairs(trees) do
|
|
bias = bias + tree.root.score * weight[i]
|
|
end
|
|
|
|
return dt.DecisionForest(trees, weight, bias)
|
|
end
|
|
|
|
function RandomForestTrainer:trainTrees(trainSet, featureIds, verbose)
|
|
|
|
-- the same CartTrainer will be used for each tree
|
|
local cartTrainer = dt.CartTrainer(trainSet, self.minLeafSize, self.maxLeafNodes)
|
|
|
|
local trees = {}
|
|
for treeId=1,self.nTree do
|
|
-- Train a CartTree
|
|
local tree = self.trainTree(cartTrainer, featureIds, self.featureBaggingSize, self.activeRatio, treeId, verbose)
|
|
table.insert(trees, tree)
|
|
end
|
|
return trees
|
|
end
|
|
|
|
-- static function that returns a cartTree
|
|
function RandomForestTrainer.trainTree(cartTrainer, featureIds, baggingSize, activeRatio, treeId, verbose)
|
|
assert(torch.isTypeOf(cartTrainer, 'dt.CartTrainer'))
|
|
assert(torch.type(featureIds) == 'torch.LongTensor')
|
|
local baggingSize = baggingSize > 0 and baggingSize or torch.round(math.sqrt(featureIds:size(1)))
|
|
|
|
if verbose then
|
|
print(string.format("Tree %d: Creating features bootstrap sample with baggingSize %d, nFeatures %d", treeId, baggingSize, featureIds:size(1)))
|
|
end
|
|
|
|
local trainSet = cartTrainer.dataset
|
|
|
|
-- sample boot strap features
|
|
local baggingIndices = torch.LongTensor(baggingSize):random(1,featureIds:size(1))
|
|
local activeFeatures = featureIds:index(1, baggingIndices)
|
|
|
|
-- sample boot strap examples
|
|
local sampleSize = torch.round(trainSet:size() * activeRatio)
|
|
if verbose then print(string.format("Creating bootstrap sample created of size %d", sampleSize)) end
|
|
|
|
baggingIndices:resize(sampleSize):random(1,trainSet:size())
|
|
local bootStrapExampleIds = torch.LongTensor()
|
|
bootStrapExampleIds:index(trainSet:getExampleIds(), 1, baggingIndices)
|
|
|
|
local cartTree = cartTrainer:train(dt.GiniState(bootStrapExampleIds), activeFeatures)
|
|
|
|
if verbose then print(string.format("Complete processing tree number %d", treeId)) end
|
|
|
|
return cartTree
|
|
end
|
|
|
|
-- call before training to enable tree-level parallelization
|
|
function RandomForestTrainer:treeParallel(workPool)
|
|
assert(self.parallelMode == 'singlethread', self.parallelMode)
|
|
self.parallelMode = 'treeparallel'
|
|
self.workPool = torch.type(workPool) == 'number' and dt.WorkPool(workPool) or workPool
|
|
assert(torch.isTypeOf(self.workPool, 'dt.WorkPool'))
|
|
|
|
-- require the dt package
|
|
self.workPool:update('require', {libname='decisiontree',varname='dt'})
|
|
end
|
|
|
|
-- TP is for tree parallel (not toilet paper)
|
|
function RandomForestTrainer:trainTreesTP(trainSet, featureIds, verbose)
|
|
assert(torch.isTypeOf(trainSet, 'dt.DataSet'))
|
|
assert(torch.type(featureIds) == 'torch.LongTensor')
|
|
local minLeafSize = self.minLeafSize
|
|
local maxLeafNodes = self.maxLeafNodes
|
|
|
|
-- setup worker store (each worker will have its own cartTrainer)
|
|
self.workPool:updateup('execute', function(store)
|
|
local dt = require 'decisiontree'
|
|
|
|
store.cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNodes)
|
|
store.featureIds = featureIds
|
|
end)
|
|
|
|
for treeId=1,self.nTree do
|
|
-- upvalues
|
|
local baggingSize = self.featureBaggingSize
|
|
local activeRatio = self.activeRatio
|
|
-- task closure that will be executed in worker-thread
|
|
local function trainTreeTask(store)
|
|
local dt = require 'decisiontree'
|
|
return dt.RandomForestTrainer.trainTree(store.cartTrainer, store.featureIds, baggingSize, activeRatio, treeId, verbose)
|
|
end
|
|
self.workPool:writeup('execute', trainTreeTask)
|
|
end
|
|
|
|
local trees = {}
|
|
for treeId=1,self.nTree do
|
|
local taskname, tree = self.workPool:read()
|
|
assert(taskname=='execute')
|
|
assert(torch.isTypeOf(tree, 'dt.CartTree'))
|
|
table.insert(trees, tree)
|
|
end
|
|
return trees
|
|
end
|
|
|
|
function RandomForestTrainer:getName()
|
|
return string.format(
|
|
"randomforest-aRatio-%4.2f-maxLeaf-%d-minExample-%d-nTree-%d",
|
|
self.activeRatio, self.maxLeafNodes, self.minLeafSize, self.nTree
|
|
)
|
|
end
|
|
|