aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree/DataSet.lua
blob: 15058a7c6a62e8697aab0ca6827abab09a6ffe79 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
local dt = require "decisiontree._env"

-- Declares the class under the name "dt.DataSet" through torch.class, handing it
-- the `dt` table obtained from decisiontree._env. The methods below are attached
-- to the returned class table.
local DataSet = torch.class("dt.DataSet", dt)

-- Builds a DataSet from example inputs and their targets.
-- input    : a tensor, or a table whose first element is a torch.SparseTensor
-- target   : a tensor
-- nThreads : optional; stored as self.nThreads, 1 when omitted
-- The per-feature sorted example ids and the list of feature ids are computed
-- immediately and stored in self.sortedFeatureValues / self.featureIds.
function DataSet:__init(input, target, nThreads)
   local inputIsTable = torch.type(input) == 'table'
   if not inputIsTable then
      assert(torch.isTensor(input))
   else
      -- only the first element of the list is type-checked
      assert(torch.isTypeOf(input[1], 'torch.SparseTensor'))
   end
   self.input = input

   assert(torch.isTensor(target))
   self.target = target

   if nThreads then
      self.nThreads = nThreads
   else
      self.nThreads = 1
   end

   local sortedFeatureValues, featureIds = self:sortFeatureValues(input)
   self.sortedFeatureValues = sortedFeatureValues
   self.featureIds = featureIds
end

-- Groups examples by featureId and, for each featureId, orders the examples by
-- featureValue (ascending order).
-- inputs : either a dense 2D tensor (dimension 1 indexes examples, dimension 2
--          indexes features), or a list of sparse examples that each expose a
--          `keys` tensor (featureIds) and a parallel `values` tensor.
-- Returns two values:
--   1. a table mapping each featureId to a LongTensor of exampleIds sorted by
--      that feature's value, e.g. {featureId={example1,example2,example3}}
--   2. a LongTensor of the featureIds. For dense inputs this is 1..inputs:size(2);
--      for sparse inputs the order follows table traversal and is unspecified.
function DataSet:sortFeatureValues(inputs)
   -- Only probe inputs[1] when inputs is not itself a tensor: an element of a
   -- dense tensor is never sparse. torch.typename yields nil for non-torch
   -- values (e.g. the nil first element of an empty list), so guard the
   -- pattern match and let the assertion below report the bad input.
   local isSparse = false
   if not torch.isTensor(inputs) then
      local firstType = torch.typename(inputs[1])
      isSparse = firstType ~= nil and firstType:match('torch.*SparseTensor') ~= nil
   end
   assert(isSparse or torch.isTensor(inputs),
      'inputs must be a dense tensor or a list of torch.SparseTensor')

   local featureIds = torch.LongTensor()
   local dataset = {} -- TODO use tds.Hash (will require SparseTensor to be userdata)
   if isSparse then
      -- values tensor of the first example; used to allocate same-typed buffers
      local proto = inputs[1].values

      -- count, for every featureId, how many entries reference it
      local featureMap = {}
      for _, input in ipairs(inputs) do
         input.keys:apply(function(key)
            featureMap[key] = (featureMap[key] or 0) + 1
         end)
      end
      local moses = require "moses"
      featureIds = featureIds.new(moses.keys(featureMap))
      -- nElement() instead of size(1): remains usable if no featureId was seen
      local nFeature = featureIds:nElement()

      -- pre-allocate, per featureId, the value and exampleId buffers;
      -- `i` is the fill cursor used by the next pass
      for k = 1, nFeature do
         local featureId = featureIds[k]
         local featureCount = featureMap[featureId]
         dataset[featureId] = {
            values=proto.new(featureCount),
            examples=torch.LongTensor(featureCount),
            i=0
         }
      end

      -- scatter every (value, exampleId) pair into the buffer of its featureId
      for exampleId, input in ipairs(inputs) do
         local sparseIdx = 0
         input.keys:apply(function(key)
            sparseIdx = sparseIdx + 1
            local f = dataset[key]
            f.i = f.i + 1
            f.values[f.i] = input.values[sparseIdx]
            f.examples[f.i] = exampleId
         end)
      end

      -- replace each buffer by its exampleIds ordered by ascending value;
      -- sortVal/sortIdx are scratch tensors reused across features
      local sortVal, sortIdx = proto.new(), torch.LongTensor()
      for featureId, f in pairs(dataset) do
         assert(f.values:size(1) == f.i)
         sortVal:sort(sortIdx, f.values, 1, false)

         local sortedExampleIds = torch.LongTensor(f.i)
         sortedExampleIds:index(f.examples, 1, sortIdx)

         -- overwriting an existing key during pairs() traversal is permitted
         dataset[featureId] = sortedExampleIds
      end
   else
      -- dense case: the assertion above already guarantees inputs is a tensor
      local nFeature = inputs:size(2)
      featureIds:range(1, nFeature)

      for featureId = 1, nFeature do
         local values = inputs:select(2, featureId)
         -- the sort indices along dimension 1 are the exampleIds
         local _, sortedFeatureExampleIds = values:sort(1, false)
         dataset[featureId] = sortedFeatureExampleIds
      end
   end

   return dataset, featureIds
end

-- Returns the entry stored for featureId in self.sortedFeatureValues
-- (nil when there is none). Fails an assertion if that table is not set.
function DataSet:getSortedFeature(featureId)
   local sortedByFeature = assert(self.sortedFeatureValues)
   return sortedByFeature[featureId]
end

-- Number of examples, read as the length of the target tensor's first dimension.
function DataSet:size()
   local target = self.target
   return target:size(1)
end

-- Returns a LongTensor filled with range(1, self:size()).
-- The tensor is created on the first call and cached in self.exampleIds.
function DataSet:getExampleIds()
   local ids = self.exampleIds
   if ids then
      return ids
   end

   ids = torch.LongTensor():range(1, self:size())
   self.exampleIds = ids
   return ids
end

-- Counts how many of the selected examples have a target strictly greater than 0.
-- exampleIds : torch.LongTensor of indices into dimension 1 of self.target
-- Returns the count as a Lua number.
function DataSet:countPositive(exampleIds)
   assert(torch.type(exampleIds) == 'torch.LongTensor')

   -- the gathered targets are written into a tensor kept in the buffer table
   -- obtained from getBufferTable('DataSet'), so it is reused between calls
   local buffers = require('decisiontree').getBufferTable('DataSet')
   if not buffers.tensor then
      buffers.tensor = self.target.new()
   end
   local selected = buffers.tensor
   selected:index(self.target, 1, exampleIds)

   local count = 0
   selected:apply(function(targetValue)
      if targetValue > 0 then
         count = count + 1
      end
   end)
   return count
end

-- Resets self.score to self:size() zeros, creating the tensor on first use
-- and reusing it afterwards.
function DataSet:initScore()
   if not self.score then
      self.score = torch.Tensor()
   end
   local score = self.score
   score:resize(self:size())
   score:fill(0)
end

-- Calls buildIndex() on every torch.SparseTensor found in self.input.
-- Does nothing when self.input is not a table.
function DataSet:buildIndex()
   if torch.type(self.input) ~= 'table' then
      return
   end
   for _, example in ipairs(self.input) do
      if torch.isTypeOf(example, 'torch.SparseTensor') then
         example:buildIndex()
      end
   end
end

-- Calls deleteIndex() on every torch.SparseTensor found in self.input.
-- Does nothing when self.input is not a table.
function DataSet:deleteIndex()
   if torch.type(self.input) ~= 'table' then
      return
   end
   for _, example in ipairs(self.input) do
      if torch.isTypeOf(example, 'torch.SparseTensor') then
         example:deleteIndex()
      end
   end
end