You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

DecisionForest.lua 2.5KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. local dt = require "decisiontree._env"
  2. -- DecisionForest: an ensemble that combines the scores of a bag of decision trees.
  -- The class is declared through torch.class under the name "dt.DecisionForest",
  -- naming "dt.DecisionTree" as its parent and passing the dt table as third argument.
  3. local DecisionForest = torch.class("dt.DecisionForest", "dt.DecisionTree", dt)
  -- Constructor.
  --   trees  : table (array) of trees. When non-empty, the first entry must be
  --            a dt.DecisionTree (only the first entry is type-checked).
  --   weight : optional tensor of per-tree weights. With trees present it must be
  --            1D, have one entry per tree and a minimum value >= 0; it defaults
  --            to a tensor of ones. With no trees it must be empty and defaults
  --            to an empty tensor.
  --   bias   : optional number used as the starting value of every score
  --            (defaults to 0).
  function DecisionForest:__init(trees, weight, bias)
    assert(torch.type(trees) == 'table')
    self.trees = trees

    local treeCount = #trees
    if treeCount > 0 then
      assert(torch.isTypeOf(trees[1], 'dt.DecisionTree'))
      -- default: every tree contributes with weight 1
      self.weight = weight or torch.Tensor(treeCount):fill(1)
      assert(torch.isTensor(self.weight))
      assert(self.weight:dim() == 1)
      -- note: the check accepts zero weights even though the message says "positive"
      assert(self.weight:min() >= 0, "Expecting positive weights")
      assert(treeCount == self.weight:size(1))
    else
      -- an empty forest carries an empty weight tensor
      self.weight = weight or torch.Tensor()
      assert(torch.isTensor(self.weight))
      assert(self.weight:nElement() == 0)
    end

    self.bias = bias or 0
    assert(torch.type(self.bias) == 'number')
  end
  -- Scores `input` with the forest: bias + sum_i weight[i] * trees[i]:score(...).
  --   input         : tensor. A contiguous 2D tensor is accumulated in place into
  --                   a 1D output tensor with input:size(1) entries; any other
  --                   tensor goes through the plain arithmetic path.
  --   incrementalId : optional key. When given, the running output and the number
  --                   of trees already scored are kept in self.buffers[incrementalId],
  --                   so a later call with the same id only scores trees added
  --                   since the previous call.
  -- Returns the accumulated output.
  function DecisionForest:score(input, incrementalId)
    assert(torch.isTensor(input))

    -- Select the scratch state: persistent per id, or a throw-away table.
    local state
    if incrementalId then
      self.buffers = self.buffers or {}
      state = self.buffers[incrementalId]
      if not state then
        state = {}
        self.buffers[incrementalId] = state
      end
    else
      state = {}
    end

    -- Number of leading trees whose contribution is already in state.output.
    local alreadyScored = state.initialCounter or 0
    state.initialCounter = alreadyScored

    -- TODO: score in parallel
    local result
    -- `input` is known to be a tensor here (asserted above).
    local batched = input.isContiguous and input:isContiguous() and input:nDimension() == 2
    if batched then
      result = state.output or input.new()
      state.output = result
      assert(result:nElement() == 0 or result:size(1) == input:size(1))
      if result:nElement() == 0 then
        -- first use of this buffer: start every row at the bias
        result:resize(input:size(1)):fill(self.bias)
      end
      for idx, tree in ipairs(self.trees) do
        if idx > alreadyScored then
          local treeScore = tree:score(input, nil, true)
          result:add(self.weight[idx], treeScore)
        end
      end
    else
      result = state.output or self.bias
      for idx, tree in ipairs(self.trees) do
        if idx > alreadyScored then
          result = result + tree:score(input) * self.weight[idx]
        end
      end
      state.output = result
    end

    -- Everything currently in the forest has now been accounted for.
    state.initialCounter = #self.trees
    return result
  end
  -- Appends `tree` to the forest with the given strictly positive numeric
  -- `weight`, growing the weight tensor by one entry. Returns self.
  function DecisionForest:add(tree, weight)
    assert(torch.type(weight) == 'number')
    assert(weight > 0)

    local members = self.trees
    table.insert(members, tree)

    -- the weight tensor must stay aligned with the list of trees
    local total = #members
    local weights = self.weight
    weights:resize(total)
    weights[total] = weight

    return self
  end
  -- Returns a new DecisionForest built from a clone of every tree, a clone of
  -- the weight tensor and the same bias.
  function DecisionForest:clone()
    local copies = {}
    for position, member in ipairs(self.trees) do
      copies[position] = member:clone()
    end
    local weightCopy = self.weight:clone()
    return DecisionForest(copies, weightCopy, self.bias)
  end