aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/DontCast.lua
blob: b89f5436b9299ff98173d0a19cf49fb135c614bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
-- nn.DontCast: declared as a subclass of nn.Decorator. The methods below wrap
-- a single module (self.modules[1]); when castin/castout are set they copy
-- tensors to and from self.moduleType around that module, and the type()
-- override only converts this decorator's own output/gradInput buffers.
local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator")

-- utility functions

-- Recursively copies `src` into buffers whose tensors have type `type_str`.
--
-- * When `src` is a table, `dst` is reused if it is already a table (a new
--   table is created otherwise), each entry of `src` is copied recursively,
--   and entries of `dst` whose key is absent from `src` are removed.
-- * When `src` is a tensor, `dst` is reused if it already has type
--   `type_str` (a new tensor of that type is created otherwise), resized to
--   the size of `src`, and filled with a copy of `src` unless `src` is empty.
-- * Any other `src` is not copied; `dst` is returned as it was given.
--
-- dst      : previous destination buffer (tensor, table or nil)
-- src      : tensor or (nested) table of tensors to copy from
-- type_str : torch type name for destination tensors
-- returns  : the destination buffer (reused or newly allocated)
local function recursiveTypeCopy(dst, src, type_str)
   if torch.type(src) == 'table' then
      dst = (torch.type(dst) == 'table') and dst or {}
      for k, v in pairs(src) do
         dst[k] = recursiveTypeCopy(dst[k], v, type_str)
      end
      -- A reused buffer may be left over from a larger `src`; drop the keys
      -- that no longer exist so stale tensors are not passed along.
      -- (Clearing existing fields while traversing with pairs is allowed.)
      for k in pairs(dst) do
         if src[k] == nil then
            dst[k] = nil
         end
      end
   elseif torch.isTensor(src) then
      dst = (torch.type(dst) == type_str) and dst or torch.getmetatable(type_str).new()
      dst:resize(src:size())
      if src:nElement() > 0 then
         dst:copy(src)
      end
   end
   return dst
end

-- Looks for a tensor inside `src` and reports its torch type.
--
-- For a non-table `src`, returns torch.type(src) and torch.isTensor(src).
-- For a table, entries are visited with pairs (order unspecified) and the
-- search recurses into nested tables; the type of the first tensor found is
-- returned together with true. When no tensor is found, the pair produced
-- for the last visited entry is returned (nil, nil for an empty table).
local function tableTensorType(src)
   if type(src) ~= 'table' then
      return torch.type(src), torch.isTensor(src)
   end

   local lastType, isTensor
   for _, item in pairs(src) do
      lastType, isTensor = tableTensorType(item)
      if isTensor then
         return lastType, true
      end
   end
   return lastType, isTensor
end

-- DontCast methods and constructor

-- Constructor.
-- module     : the module to decorate (forwarded to the parent constructor)
-- castin     : when truthy, inputs are copied to moduleType before reaching
--              the module and gradInput is copied back
-- castout    : same for output / gradOutput; falls back to castin when nil
-- moduleType : torch type name the module works in; when omitted (and
--              casting is requested) it is inferred from module.output,
--              then from module:parameters()
-- Raises an error when casting is requested and no type can be inferred.
function DontCast:__init(module, castin, castout, moduleType)
   parent.__init(self, module)
   self.castin = castin
   if castout == nil then
      -- default to castin; a falsy castin leaves castout as nil
      self.castout = castin or nil
   else
      self.castout = castout
   end
   self.moduleType = moduleType

   -- nothing to infer when the type was given or no casting is requested
   if self.moduleType or not (self.castin or self.castout) then
      return
   end

   local inferred, isTensor = tableTensorType(module.output)
   if not isTensor then
      inferred, isTensor = tableTensorType(module:parameters())
   end
   if not isTensor then
      error"Cannot extrapolate moduleType. Provide constructor argument 4"
   end
   self.moduleType = inferred
end

-- Forward pass through the decorated module.
-- With castin, an input whose tensor type differs from self.moduleType is
-- first copied into self._input (kept for the backward methods).
-- With castout, the module's output is copied into self.output using the
-- tensor type self.output currently has; otherwise the module's output is
-- exposed directly.
function DontCast:updateOutput(input)
   local moduleInput = input
   if self.castin then
      local inputType = tableTensorType(input)
      if inputType ~= self.moduleType then
         self._input = recursiveTypeCopy(self._input, input, self.moduleType)
         moduleInput = self._input
      end
   end

   local rawOutput = self.modules[1]:updateOutput(moduleInput)

   if not self.castout then
      self.output = rawOutput
      return self.output
   end

   local outerType = tableTensorType(self.output)
   self.output = recursiveTypeCopy(self.output, rawOutput, outerType)
   return self.output
end

-- Backward pass (gradient w.r.t. input) through the decorated module.
-- With castin, a foreign-typed input is replaced by self._input, the copy
-- stored by updateOutput. With castout, a foreign-typed gradOutput is copied
-- into self._gradOutput first. With castin, the resulting gradInput is copied
-- into self.gradInput using the tensor type self.gradInput currently has;
-- otherwise the module's gradInput is exposed directly.
function DontCast:updateGradInput(input, gradOutput)
   local nativeType = self.moduleType

   local moduleInput = input
   if self.castin then
      local inputType = tableTensorType(input)
      if inputType ~= nativeType then
         moduleInput = self._input
      end
   end

   local moduleGradOutput = gradOutput
   if self.castout then
      local gradType = tableTensorType(gradOutput)
      if gradType ~= nativeType then
         self._gradOutput = recursiveTypeCopy(self._gradOutput, gradOutput, nativeType)
         moduleGradOutput = self._gradOutput
      end
   end

   local rawGradInput = self.modules[1]:updateGradInput(moduleInput, moduleGradOutput)

   if not self.castin then
      self.gradInput = rawGradInput
      return self.gradInput
   end

   local outerType = tableTensorType(self.gradInput)
   self.gradInput = recursiveTypeCopy(self.gradInput, rawGradInput, outerType)
   return self.gradInput
end

-- Forwards accGradParameters to the decorated module.
-- No copying happens here: foreign-typed arguments are swapped for the
-- copies already stored in self._input (by updateOutput) and
-- self._gradOutput (by updateGradInput).
function DontCast:accGradParameters(input, gradOutput, scale)
   local nativeType = self.moduleType

   local moduleInput = input
   if self.castin then
      local inputType = tableTensorType(input)
      if inputType ~= nativeType then
         moduleInput = self._input
      end
   end

   local moduleGradOutput = gradOutput
   if self.castout then
      local gradType = tableTensorType(gradOutput)
      if gradType ~= nativeType then
         moduleGradOutput = self._gradOutput
      end
   end

   self.modules[1]:accGradParameters(moduleInput, moduleGradOutput, scale)
end

-- Forwards accUpdateGradParameters to the decorated module.
-- No copying happens here: foreign-typed arguments are swapped for the
-- copies already stored in self._input (by updateOutput) and
-- self._gradOutput (by updateGradInput).
function DontCast:accUpdateGradParameters(input, gradOutput, lr)
   local nativeType = self.moduleType

   local moduleInput = input
   if self.castin then
      local inputType = tableTensorType(input)
      if inputType ~= nativeType then
         moduleInput = self._input
      end
   end

   local moduleGradOutput = gradOutput
   if self.castout then
      local gradType = tableTensorType(gradOutput)
      if gradType ~= nativeType then
         moduleGradOutput = self._gradOutput
      end
   end

   self.modules[1]:accUpdateGradParameters(moduleInput, moduleGradOutput, lr)
end

-- dont cast (the essence thereof)
-- Type conversion override: the decorated module and the cached
-- _input/_gradOutput copies are left untouched. Only the buffers this
-- decorator exposes are re-created in the requested type:
--   * self.output    when castout is set and its tensor type differs
--   * self.gradInput when castin is set and its tensor type differs
--
-- type    : torch type name to convert to; when omitted the call acts as a
--           getter and returns self._type
--           NOTE(review): this mirrors the getter convention of
--           nn.Module:type and assumes the parent constructor initialises
--           self._type -- confirm against the installed nn package.
-- returns : self when a type is given, otherwise self._type
function DontCast:type(type)
   if not type then
      -- Getter form. Without this guard a nil type name would be handed to
      -- recursiveTypeCopy, or self would be returned in place of a type.
      return self._type
   end
   if self.castout and tableTensorType(self.output) ~= type then
      self.output = recursiveTypeCopy(nil, self.output, type)
   end
   if self.castin and tableTensorType(self.gradInput) ~= type then
      self.gradInput = recursiveTypeCopy(nil, self.gradInput, type)
   end
   -- remember the requested type so the getter form stays in sync
   self._type = type
   return self
end