blob: 4c620e618ca4a860c2bf95b24bad7aaccd74efe4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
-- Register the torch.SparseTensor class; the returned table receives the
-- methods and metamethods defined below.
local SparseTensor = torch.class('torch.SparseTensor')
--- Construct a SparseTensor from parallel key and value tensors.
-- Call with no arguments for an empty tensor (empty LongTensor keys and
-- empty Tensor values), or with both arguments.
-- @param keys   a torch.*LongTensor holding the keys
-- @param values a tensor with the same number of elements as keys
function SparseTensor:__init(keys, values)
   if keys and values then
      -- torch.typename returns nil for non-torch objects, so guard before :find
      local keysType = torch.typename(keys)
      assert(keysType and keysType:find('torch%..*LongTensor'), "Expecting keys to be a LongTensor")
      assert(torch.isTensor(values), "Expecting values to be a Tensor")
      assert(keys:nElement() == values:nElement(), "Expecting key and value tensors of same size")
      self.keys = keys
      self.values = values
   elseif not (keys or values) then
      self.keys = torch.LongTensor()
      self.values = torch.Tensor()
   else
      error"Expecting zero or two args"
   end
end
--- Build a hash index (self._map) from key to position in self.keys.
-- Once built, numeric lookups through __index use a table lookup instead of
-- a linear scan. An empty SparseTensor gets an empty index.
-- @param overwrite rebuild the index even when one already exists
function SparseTensor:buildIndex(overwrite)
   if self._map and not overwrite then return end
   local keys, values = self.keys, self.values
   assert(keys and values, "Expecting keys and values")
   -- build into a local so a failed build never leaves a partial index installed
   local map = {}
   -- a default-constructed SparseTensor has 0-dim empty tensors; nothing to index
   if keys:nElement() > 0 then
      assert(keys:dim() == 1, "Expecting 1D keys")
      assert(values:dim() == 1, "Expecting 1D values")
      for i=1,keys:size(1) do
         local k = keys[i]
         -- first occurrence of a duplicate key wins, matching the linear scan in __index
         if map[k] == nil then
            map[k] = i
         end
      end
   end
   self._map = map
end
--- Drop the hash index built by buildIndex; subsequent numeric lookups
-- fall back to scanning self.keys.
function SparseTensor.deleteIndex(self)
   self._map = nil
end
-- Keep the inherited __index so non-numeric keys (method names, fields) still
-- resolve through the class.
local __index = SparseTensor.__index

--- Index metamethod.
-- A numeric key is looked up in self.keys, and the matching entry of
-- self.values is returned, or nil when the key is absent. Any other key
-- type is delegated to the inherited __index. A nil key raises an error.
function SparseTensor:__index(key)
   if key == nil then
      error"Attempt to index using a nil key"
   end
   if torch.type(key) ~= 'number' then
      return __index(self, key)
   end

   local map = self._map
   if map then
      -- fast path: hash index produced by buildIndex
      assert(torch.type(map) == 'table')
      local pos = map[key]
      if not pos then
         return nil
      end
      return self.values[pos] or nil
   end

   local keys = self.keys
   if keys:nElement() > 0 then
      -- slow path: linear scan, first matching key wins
      for pos=1,keys:size(1) do
         if keys[pos] == key then
            return self.values[pos]
         end
      end
   end
   return nil
end
|