aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/decisiontree/SparseTensor.lua
blob: 4c620e618ca4a860c2bf95b24bad7aaccd74efe4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
-- Register the torch class "torch.SparseTensor". `SparseTensor` is the class
-- table returned by torch.class; the methods below are defined on it.
local SparseTensor = torch.class("torch.SparseTensor")

--- Construct a sparse vector from parallel `keys` / `values` tensors.
-- Pass both tensors, or neither (empty tensors are then allocated).
-- @param keys   a LongTensor (any torch type name matching 'torch.*LongTensor')
-- @param values a Tensor with the same number of elements as `keys`
-- @raise on a type/size mismatch, or when exactly one argument is given
function SparseTensor:__init(keys, values)
   if keys and values then
      -- torch.typename returns nil for non-torch values, so guard it before
      -- calling :find; otherwise the caller gets an opaque "index a nil" error.
      local keysType = torch.typename(keys)
      assert(keysType and keysType:find('torch%..*LongTensor'),
         "Expecting keys to be a LongTensor")
      assert(torch.isTensor(values), "Expecting values to be a Tensor")
      assert(keys:nElement() == values:nElement(), "Expecting key and value tensors of same size")
      self.keys = keys
      self.values = values
   elseif not (keys or values) then
      self.keys = torch.LongTensor()
      self.values = torch.Tensor()
   else
      error"Expecting zero or two args"
   end
end

--- Build a hash index mapping each key to its position in `self.keys`, so
-- numeric lookups through __index avoid a linear scan.
-- Does nothing if an index already exists, unless `overwrite` is truthy.
-- @param overwrite if truthy, rebuild an existing index
function SparseTensor:buildIndex(overwrite)
   if self._map and not overwrite then return end
   -- In torch7 an empty tensor reports dim() == 0, so accept empty tensors
   -- alongside 1D ones; an empty SparseTensor simply gets an empty index.
   assert(self.keys and (self.keys:nElement() == 0 or self.keys:dim() == 1),
      "Expecting keys to be a 1D tensor")
   assert(self.values and (self.values:nElement() == 0 or self.values:dim() == 1),
      "Expecting values to be a 1D tensor")
   local keys = self.keys
   local map = {}
   for i = 1, keys:nElement() do
      local key = keys[i]
      -- Keep the first occurrence of a duplicated key, so indexed lookups
      -- agree with the first-match linear scan used when no index exists.
      if map[key] == nil then
         map[key] = i
      end
   end
   -- Publish only when the index is complete; a failure mid-loop must not
   -- leave a partial index that later calls would refuse to rebuild.
   self._map = map
end

--- Discard the key index created by buildIndex.
-- Later numeric lookups go back to scanning `keys` linearly until
-- buildIndex is called again.
function SparseTensor:deleteIndex()
   self['_map'] = nil
end

-- Capture the class's __index before it is overridden below; string keys
-- (methods, fields such as 'keys'/'values') are still delegated to it.
local __index = SparseTensor.__index

--- Look up the value stored under a numeric key.
-- Non-numeric keys fall through to the captured class lookup.
-- @param key the number to look up (a nil key raises an error)
-- @return the matching entry of `values`, or nil when the key is absent
function SparseTensor:__index(key)
   if key == nil then
      error"Attempt to index using a nil key"
   end
   if torch.type(key) ~= 'number' then
      return __index(self, key)
   end

   -- Fast path: use the hash index produced by buildIndex, if any.
   local map = self._map
   if map then
      assert(torch.type(map) == 'table')
      local pos = map[key]
      if not pos then
         return nil
      end
      return self.values[pos]
   end

   -- Slow path: scan the keys in order; the first match wins.
   local keys = self.keys
   if keys:nElement() == 0 then
      return nil
   end
   for pos = 1, keys:size(1) do
      if keys[pos] == key then
         return self.values[pos]
      end
   end
   return nil
end