You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

DataSet.lua 4.2KB

Lines 1–142
  -- Package environment table; the DataSet class is registered inside it.
  local dt = require("decisiontree._env")

  -- Torch class named "dt.DataSet", created under the `dt` table.
  local DataSet = torch.class("dt.DataSet", dt)
  -- Build a DataSet and precompute, for every feature, the example ids ordered
  -- by ascending feature value (see DataSet:sortFeatureValues).
  -- @param input    either a dense torch tensor whose first dimension indexes
  --                 examples and whose second dimension indexes features, or a
  --                 Lua table of torch.SparseTensor, one entry per example
  -- @param target   torch tensor with one entry per example along dimension 1
  -- @param nThreads optional thread count stored on the instance (default 1)
  function DataSet:__init(input, target, nThreads)
     local nExample
     if torch.type(input) == 'table' then
        assert(torch.isTypeOf(input[1], 'torch.SparseTensor'),
           "DataSet: table input must contain torch.SparseTensor examples")
        nExample = #input
     else
        assert(torch.isTensor(input),
           "DataSet: input must be a tensor or a table of torch.SparseTensor")
        nExample = input:size(1)
     end
     assert(torch.isTensor(target), "DataSet: target must be a tensor")
     -- Example ids index both the inputs and self.target (see size() and
     -- countPositive), so both must describe the same number of examples.
     assert(target:size(1) == nExample, string.format(
        "DataSet: input has %d examples but target has %d", nExample, target:size(1)))
     self.input = input
     self.target = target
     self.nThreads = nThreads or 1
     self.sortedFeatureValues, self.featureIds = self:sortFeatureValues(input)
  end
  -- Sparse case of DataSet:sortFeatureValues.
  -- `inputs` is a non-empty table of sparse examples exposing `keys` (feature
  -- ids) and `values` tensors. Returns (dataset, featureIds), where dataset maps
  -- each featureId to a LongTensor of example ids sorted by ascending value, and
  -- featureIds is a LongTensor of the observed feature ids in ascending order.
  local function sortSparseFeatureValues(inputs)
     local proto = inputs[1].values

     -- count, per feature id, how many entries carry it
     local featureCounts = {}
     for _, input in ipairs(inputs) do
        input.keys:apply(function(key)
           featureCounts[key] = (featureCounts[key] or 0) + 1
        end)
     end

     -- pairs() order is unspecified, so sort ids for a deterministic result
     local idList = {}
     for key in pairs(featureCounts) do
        idList[#idList + 1] = key
     end
     if #idList == 0 then
        -- no example has any key: torch.LongTensor({}) would be 0-D and
        -- featureIds:size(1) would raise, so return empty results directly
        return {}, torch.LongTensor()
     end
     table.sort(idList)
     local featureIds = torch.LongTensor(idList)

     -- preallocate one (values, examples) bucket per feature
     local buckets = {}
     for featureId, count in pairs(featureCounts) do
        buckets[featureId] = {
           values = proto.new(count),
           examples = torch.LongTensor(count),
           n = 0,
        }
     end

     -- scatter every (exampleId, value) pair into its feature's bucket;
     -- keys and values are walked in lockstep via sparseIdx
     for exampleId, input in ipairs(inputs) do
        local sparseIdx = 0
        input.keys:apply(function(key)
           sparseIdx = sparseIdx + 1
           local bucket = buckets[key]
           bucket.n = bucket.n + 1
           bucket.values[bucket.n] = input.values[sparseIdx]
           bucket.examples[bucket.n] = exampleId
        end)
     end

     -- order each feature's example ids by ascending value
     local dataset = {}
     local sortVal, sortIdx = proto.new(), torch.LongTensor()
     for featureId, bucket in pairs(buckets) do
        assert(bucket.values:size(1) == bucket.n)
        sortVal:sort(sortIdx, bucket.values, 1, false)
        local sortedExampleIds = torch.LongTensor(bucket.n)
        sortedExampleIds:index(bucket.examples, 1, sortIdx)
        dataset[featureId] = sortedExampleIds
     end
     return dataset, featureIds
  end

  -- Dense case of DataSet:sortFeatureValues.
  -- Column i of `inputs` holds feature i for every example (rows). Returns
  -- (dataset, featureIds) with featureIds = 1..inputs:size(2).
  local function sortDenseFeatureValues(inputs)
     local nFeature = inputs:size(2)
     local featureIds = torch.LongTensor():range(1, nFeature)
     local dataset = {}
     for featureId = 1, nFeature do
        local _, sortedFeatureExampleIds = inputs:select(2, featureId):sort(1, false)
        dataset[featureId] = sortedFeatureExampleIds
     end
     return dataset, featureIds
  end

  -- Group examples by featureId. For each featureId, sort its examples by
  -- featureValue in ascending order.
  -- Returns two values:
  --   1. a table mapping featureIds to LongTensors of example ids, sorted by
  --      that feature's value, e.g. {[featureId]={example1,example2,example3}}
  --   2. a LongTensor of all feature ids, in ascending order
  -- An empty table input yields ({}, empty LongTensor).
  function DataSet:sortFeatureValues(inputs)
     if torch.type(inputs) == 'table' then
        if #inputs == 0 then
           return {}, torch.LongTensor()
        end
        -- guard against non-torch entries, for which torch.typename returns nil
        local typeName = torch.typename(inputs[1])
        assert(typeName and typeName:match('torch.*SparseTensor'),
           "DataSet:sortFeatureValues: table input must contain SparseTensor examples")
        return sortSparseFeatureValues(inputs)
     end
     assert(torch.isTensor(inputs),
        "DataSet:sortFeatureValues: input must be a tensor or a table of SparseTensor")
     return sortDenseFeatureValues(inputs)
  end
  -- Look up the example ids precomputed for `featureId`, sorted by that
  -- feature's value. Returns nil for a feature id absent from the table.
  function DataSet:getSortedFeature(featureId)
     local sortedByFeature = self.sortedFeatureValues
     assert(sortedByFeature)
     return sortedByFeature[featureId]
  end
  -- Number of examples: the length of the target tensor along dimension 1.
  function DataSet:size()
     local target = self.target
     return target:size(1)
  end
  -- Return a LongTensor holding 1..self:size(). It is built on first use and
  -- cached in self.exampleIds afterwards.
  function DataSet:getExampleIds()
     local ids = self.exampleIds
     if not ids then
        ids = torch.LongTensor():range(1, self:size())
        self.exampleIds = ids
     end
     return ids
  end
  -- Count how many of the selected examples have a strictly positive target.
  -- @param exampleIds torch.LongTensor of example ids (row indices into self.target)
  -- @return number of selected target entries with value > 0 (0 for no ids)
  function DataSet:countPositive(exampleIds)
     assert(torch.type(exampleIds) == 'torch.LongTensor')
     -- an empty (0-D) index tensor is rejected by Tensor:index, so short-circuit
     if exampleIds:nElement() == 0 then
        return 0
     end
     local dt = require 'decisiontree'
     local buffer = dt.getBufferTable('DataSet')
     -- The buffer table is fetched under a constant name, so presumably it is
     -- shared with other DataSets whose targets may have another tensor type.
     -- Tensor:index needs matching types, so recreate the buffer on mismatch.
     if torch.type(buffer.tensor) ~= torch.type(self.target) then
        buffer.tensor = self.target.new()
     end
     buffer.tensor:index(self.target, 1, exampleIds)
     local nPositive = 0
     buffer.tensor:apply(function(x)
        if x > 0 then nPositive = nPositive + 1 end
     end)
     return nPositive
  end
  -- Reset self.score to zeros with one entry per example. An existing score
  -- tensor is reused; otherwise a new default torch.Tensor is created.
  function DataSet:initScore()
     if not self.score then
        self.score = torch.Tensor()
     end
     local score = self.score
     score:resize(self:size())
     score:fill(0)
  end
  -- For table input, call buildIndex() on every torch.SparseTensor example.
  -- Dense tensor input is left untouched.
  function DataSet:buildIndex()
     if torch.type(self.input) ~= 'table' then
        return
     end
     for _, example in ipairs(self.input) do
        if torch.isTypeOf(example, 'torch.SparseTensor') then
           example:buildIndex()
        end
     end
  end
  -- For table input, call deleteIndex() on every torch.SparseTensor example.
  -- Dense tensor input is left untouched.
  function DataSet:deleteIndex()
     if torch.type(self.input) ~= 'table' then
        return
     end
     for _, example in ipairs(self.input) do
        if torch.isTypeOf(example, 'torch.SparseTensor') then
           example:deleteIndex()
        end
     end
  end