diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100 |
commit | 714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch) | |
tree | 84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/nn/DontCast.lua | |
parent | 220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff) | |
download | rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip |
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/nn/DontCast.lua')
-rw-r--r-- | contrib/lua-torch/nn/DontCast.lua | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/DontCast.lua b/contrib/lua-torch/nn/DontCast.lua new file mode 100644 index 000000000..b89f5436b --- /dev/null +++ b/contrib/lua-torch/nn/DontCast.lua @@ -0,0 +1,124 @@ +local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator") + +-- utility functions + +local function recursiveTypeCopy(dst, src, type_str) + if torch.type(src) == 'table' then + dst = (torch.type(dst) == 'table') and dst or {} + for k, v in pairs(src) do + dst[k] = recursiveTypeCopy(dst[k], v, type_str) + end + elseif torch.isTensor(src) then + dst = (torch.type(dst) == type_str) and dst or torch.getmetatable(type_str).new() + dst:resize(src:size()) + if src:nElement() > 0 then + dst:copy(src) + end + end + return dst +end + +local function tableTensorType(src) + if type(src) == 'table' then + local type_str, found + for k,v in pairs(src) do + type_str, found = tableTensorType(v) + if found then + return type_str, true + end + end + return type_str, found + else + return torch.type(src), torch.isTensor(src) + end +end + +-- DontCast methods and constructor + +function DontCast:__init(module, castin, castout, moduleType) + parent.__init(self, module) + self.castin = castin + self.castout = (castout == nil) and castin or castout + self.moduleType = moduleType + if (self.castin or self.castout) and not self.moduleType then + local moduleType, found = tableTensorType(module.output) + if found then + self.moduleType = moduleType + else + moduleType, found = tableTensorType(module:parameters()) + if found then + self.moduleType = moduleType + else + error"Cannot extrapolate moduleType. Provide constructor argument 4" + end + end + end +end + +function DontCast:updateOutput(input) + if self.castin and tableTensorType(input) ~= self.moduleType then + self._input = recursiveTypeCopy(self._input, input, self.moduleType) + input = self._input + end + + local output = self.modules[1]:updateOutput(input) + + if self.castout then + self.output = recursiveTypeCopy(self.output, output, tableTensorType(self.output)) + else + self.output = output + end + return self.output +end + +function DontCast:updateGradInput(input, gradOutput) + if self.castin and tableTensorType(input) ~= self.moduleType then + input = self._input + end + if self.castout and tableTensorType(gradOutput) ~= self.moduleType then + self._gradOutput = recursiveTypeCopy(self._gradOutput, gradOutput, self.moduleType) + gradOutput = self._gradOutput + end + + local gradInput = self.modules[1]:updateGradInput(input, gradOutput) + + if self.castin then + self.gradInput = recursiveTypeCopy(self.gradInput, gradInput, tableTensorType(self.gradInput)) + else + self.gradInput = gradInput + end + return self.gradInput +end + +function DontCast:accGradParameters(input, gradOutput, scale) + if self.castin and tableTensorType(input) ~= self.moduleType then + input = self._input + end + if self.castout and tableTensorType(gradOutput) ~= self.moduleType then + gradOutput = self._gradOutput + end + + self.modules[1]:accGradParameters(input, gradOutput, scale) +end + +function DontCast:accUpdateGradParameters(input, gradOutput, lr) + if self.castin and tableTensorType(input) ~= self.moduleType then + input = self._input + end + if self.castout and tableTensorType(gradOutput) ~= self.moduleType then + gradOutput = self._gradOutput + end + + self.modules[1]:accUpdateGradParameters(input, gradOutput, lr) +end + +-- dont cast (the essence thereof) +function DontCast:type(type) + if self.castout and tableTensorType(self.output) ~= type then + self.output = recursiveTypeCopy(nil, self.output, type) + end + if self.castin and tableTensorType(self.gradInput) ~= type then + self.gradInput = recursiveTypeCopy(nil, self.gradInput, type) + end + return self +end |