aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/LookupTable.lua
blob: 6cffc6c3e4d3f9785d92f1fea6a48189aab62ff5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
local THNN = require 'nn.THNN'
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')

-- Class version number. NOTE(review): presumably read by torch serialization to
-- tell older saved modules apart (cf. backCompatibility below) — confirm.
LookupTable.__version = 4

-- Build an embedding table with nIndex rows of nOutput values each.
--   nIndex       : number of rows in weight (valid lookup indices)
--   nOutput      : size of each embedding row
--   paddingValue : index forwarded to the THNN gradient kernel (default 0)
--   maxNorm      : optional; when set, looked-up rows are renormalized in
--                  updateOutput (see renorm)
--   normType     : optional norm order used by renorm (renorm falls back to 2)
-- weight is initialised by reset(); gradWeight starts at zero.
function LookupTable:__init(nIndex, nOutput, paddingValue, maxNorm, normType)
   parent.__init(self)

   self.weight = torch.Tensor(nIndex, nOutput)
   self.gradWeight = torch.Tensor(nIndex, nOutput):zero()
   self.paddingValue = paddingValue or 0
   -- Left nil when omitted: renorm() is a no-op without maxNorm and
   -- substitutes 2 for a missing normType, so no explicit default is needed.
   self.maxNorm = maxNorm
   self.normType = normType

   self:reset()
end

-- Ensure the lazily-created fields exist. A module may lack them, e.g. when
-- deserialized from an older version or never forwarded yet.
function LookupTable:backCompatibility()
   if not self._count then
      self._count = torch.IntTensor()
   end
   if not self._input then
      self._input = torch.LongTensor()
   end
   -- Normalize an absent flag to an explicit boolean false.
   self.shouldScaleGradByFreq = self.shouldScaleGradByFreq or false
end

-- Discard the gradWeight buffer and return the module for chaining.
-- NOTE(review): accGradParameters indexes self.gradWeight, so it cannot be
-- used afterwards; presumably only the accUpdate path is intended — confirm.
function LookupTable.accUpdateOnly(self)
   local module = self
   module.gradWeight = nil
   return module
end

-- Chainable setter for the padding index handed to the gradient kernel.
function LookupTable.setPadding(self, paddingValue)
   local module = self
   module.paddingValue = paddingValue
   return module
end

-- Chainable setter for the row-norm cap used by renorm (nil disables it).
function LookupTable.setMaxNorm(self, maxNorm)
   local module = self
   module.maxNorm = maxNorm
   return module
end

-- Chainable setter for the norm order used by renorm (nil means 2 there).
function LookupTable.setNormType(self, normType)
   local module = self
   module.normType = normType
   return module
end

-- Turn on the kernel's scale-gradient-by-frequency flag; chainable.
function LookupTable.scaleGradByFreq(self)
   local module = self
   module.shouldScaleGradByFreq = true
   return module
end

-- Re-initialise every weight entry from a normal distribution with mean 0
-- and standard deviation stdv (default 1).
function LookupTable:reset(stdv)
   local sigma = stdv or 1
   self.weight:normal(0, sigma)
end

-- Return an index tensor that is contiguous and has the same tensor type as
-- self._input. If `input` already qualifies it is returned as-is; otherwise
-- it is copied into self._input. self.copiedInput records which case applied,
-- so accGradParameters can reuse the same indices.
function LookupTable:makeInputContiguous(input)
   local usable = input:isContiguous() and torch.type(input) == torch.type(self._input)
   self.copiedInput = not usable
   if usable then
      return input
   end
   self._input:resize(input:size()):copy(input)
   return self._input
end

-- Gather rows of self.weight selected by `input`.
--   1-D input of length n    -> output is n x nOutput
--   2-D input of size b x n  -> output is b x n x nOutput
-- Any other rank raises an error. When maxNorm is set, the referenced rows are
-- renormalized before the lookup.
function LookupTable:updateOutput(input)
   self:backCompatibility()
   self:renorm(input)
   local indices = self:makeInputContiguous(input)
   local nDim = indices:dim()
   if nDim ~= 1 and nDim ~= 2 then
      error("input must be a vector or matrix")
   end

   if nDim == 1 then
      self.output:index(self.weight, 1, indices)
   else
      -- index() needs a flat index vector; reshape the result back afterwards.
      self.output:index(self.weight, 1, indices:view(-1))
      self.output = self.output:view(indices:size(1), indices:size(2), self.weight:size(2))
   end
   return self.output
end

-- The gradient w.r.t. the index input is reported as a tensor shaped like
-- `input`, zero-filled whenever it is (re)sized. The input type may change
-- between calls (forward converts whatever it gets), so gradInput is
-- reallocated with input's type whenever the types differ.
function LookupTable:updateGradInput(input, gradOutput)
   local sameType = torch.type(self.gradInput) == torch.type(input)
   if not sameType then
      self.gradInput = input.new()
   end
   if self.gradInput:isSameSizeAs(input) then
      return self.gradInput
   end
   self.gradInput:resizeAs(input):zero()
   return self.gradInput
end

-- Hand the gradient for the looked-up rows to THNN's
-- LookupTable_accGradParameters, targeting self.gradWeight.
--   scale defaults to 1. It is passed with the padding index and the
--   scale-by-frequency flag.
-- If updateOutput had to copy the indices (self.copiedInput), that contiguous
-- copy in self._input is used instead of the raw input.
function LookupTable:accGradParameters(input, gradOutput, scale)
   self:backCompatibility()

   local indices = input
   if self.copiedInput then
      indices = self._input
   end

   local nDim = indices:dim()
   if nDim == 2 then
      indices = indices:view(-1)
   elseif nDim ~= 1 then
      error("input must be a vector or matrix")
   end

   local scaleByFreq = self.shouldScaleGradByFreq or false
   local padding = self.paddingValue or 0
   -- _sorted/_indices are optional scratch tensors (allocated for CUDA in type()).
   local sorted = THNN.optionalTensor(self._sorted)
   local origIndices = THNN.optionalTensor(self._indices)

   self.gradWeight.THNN.LookupTable_accGradParameters(
      indices:cdata(),
      gradOutput:cdata(),
      self.gradWeight:cdata(),
      self._count:cdata(),
      sorted,
      origIndices,
      scaleByFreq,
      padding,
      scale or 1
   )
end

-- When maxNorm is set, ask THNN's LookupTable_renorm to renormalize the rows of
-- self.weight referenced by `input`. The norm order is normType (default 2).
-- The indices are first copied into self._input. The kernel may modify that
-- copy in place, and the caller's tensor is left untouched.
function LookupTable:renorm(input)
   if self.maxNorm then
      self._input:resize(input:size()):copy(input)
      local rowIdx = self._input
      local nDim = rowIdx:dim()
      if nDim == 2 then
         rowIdx = rowIdx:view(-1)
      elseif nDim ~= 1 then
         error("input must be a vector or matrix")
      end
      -- Both rowIdx and self.weight may be written by the C kernel.
      self.weight.THNN.LookupTable_renorm(
         rowIdx:cdata(),
         self.weight:cdata(),
         self.maxNorm,
         self.normType or 2
      )
   end
end

-- Allocate an empty CUDA index tensor: a CudaLongTensor when cutorch exposes
-- torch.CudaLongTensor, otherwise a CudaTensor.
local function newCudaIndexTensor()
   if torch.CudaLongTensor then
      return torch.CudaLongTensor.new()
   end
   return torch.CudaTensor.new()
end

-- Convert the module to tensor type `type` and return self.
-- Called without a type (`m:type()`), this only returns the current type
-- (as reported by nn.Module). It no longer falls through into the buffer
-- reset below, which used to replace the CUDA buffers with CPU ones.
-- Index buffers (_count, _input, and the CUDA-only _sorted/_indices) are not
-- simply cast to `type`. They are recreated with index tensor types
-- appropriate for the target device.
function LookupTable:type(type, tensorCache)
   if not type then
      return parent.type(self)
   end

   parent.type(self, type, tensorCache)

   if type:find('torch%.Cuda.*Tensor') then
      -- CUDA uses _sorted and _indices temporary tensors
      self._sorted = newCudaIndexTensor()
      self._indices = newCudaIndexTensor()
      self._count = newCudaIndexTensor()
      self._input = newCudaIndexTensor()
   else
      -- _sorted/_indices are only ever allocated in the CUDA branch and are
      -- passed as optional tensors by accGradParameters. After converting
      -- away from CUDA they would linger, recast by parent.type to `type`
      -- (NOTE(review): confirm). Drop them so the module matches a freshly
      -- constructed non-CUDA one.
      self._sorted = nil
      self._indices = nil
      -- self._count and self._input should only be converted if using Cuda
      self._count = torch.IntTensor()
      self._input = torch.LongTensor()
   end

   return self
end

-- Release intermediate buffers before delegating to nn.Module's clearState.
-- Besides _count and _input, this also clears the _sorted/_indices scratch
-- tensors that type() allocates for CUDA. Absent fields (the CPU case) are
-- presumably skipped by nn.utils.clear — confirm.
function LookupTable:clearState()
   nn.utils.clear(self, '_count', '_input', '_sorted', '_indices')
   return parent.clearState(self)
end

-- With shared parameters there is nothing to accumulate separately, so the
-- update is delegated straight to the module's default accUpdate routine.
function LookupTable:sharedAccUpdateGradParameters(input, gradOutput, lr)
   local defaultUpdate = self.defaultAccUpdateGradParameters
   defaultUpdate(self, input, gradOutput, lr)
end