aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/DontCast.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-23 18:14:15 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-23 18:14:15 +0100
commit714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch)
tree84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/nn/DontCast.lua
parent220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff)
downloadrspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz
rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/nn/DontCast.lua')
-rw-r--r--contrib/lua-torch/nn/DontCast.lua124
1 files changed, 124 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/DontCast.lua b/contrib/lua-torch/nn/DontCast.lua
new file mode 100644
index 000000000..b89f5436b
--- /dev/null
+++ b/contrib/lua-torch/nn/DontCast.lua
@@ -0,0 +1,124 @@
+-- nn.DontCast decorates a wrapped module so that casting the outer network
+-- (e.g. net:cuda() or net:float()) does NOT cast the wrapped module itself;
+-- instead, inputs/gradients are converted at the decorator boundary.
+local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator")
+
+-- utility functions
+
+-- Recursively copy `src` into `dst`, converting every tensor encountered to
+-- the tensor type named by `type_str` (e.g. 'torch.FloatTensor').
+-- Table structure is mirrored; existing `dst` tensors/tables are reused when
+-- they already have the right type, otherwise fresh ones are allocated.
+-- Non-tensor, non-table leaves are left untouched (dst returned as-is).
+-- Returns the (possibly newly created) destination.
+local function recursiveTypeCopy(dst, src, type_str)
+ if torch.type(src) == 'table' then
+ dst = (torch.type(dst) == 'table') and dst or {}
+ for k, v in pairs(src) do
+ dst[k] = recursiveTypeCopy(dst[k], v, type_str)
+ end
+ elseif torch.isTensor(src) then
+ -- reuse dst only if it is already a tensor of the target type
+ dst = (torch.type(dst) == type_str) and dst or torch.getmetatable(type_str).new()
+ dst:resize(src:size())
+ -- guard: copying from an empty tensor would error
+ if src:nElement() > 0 then
+ dst:copy(src)
+ end
+ end
+ return dst
+end
+
+-- Find the torch type of the first tensor contained in `src`, which may be a
+-- tensor or an arbitrarily nested table of tensors.
+-- Returns two values: type_str (torch.type of the located tensor; for a
+-- non-table, non-tensor src this is just its torch.type) and found (true iff
+-- a tensor was located).
+-- NOTE(review): when src is a table with no tensors, the post-loop return
+-- yields the type_str/found of the last recursive call (possibly nil) —
+-- callers only rely on `found` being falsy in that case.
+local function tableTensorType(src)
+ if type(src) == 'table' then
+ local type_str, found
+ for k,v in pairs(src) do
+ type_str, found = tableTensorType(v)
+ if found then
+ -- stop at the first tensor discovered (pairs order is unspecified,
+ -- but a homogeneous table makes this deterministic in practice)
+ return type_str, true
+ end
+ end
+ return type_str, found
+ else
+ return torch.type(src), torch.isTensor(src)
+ end
+end
+
+-- DontCast methods and constructor
+
+-- Constructor.
+-- module     : the wrapped module, shielded from outer :type() casts.
+-- castin     : when true, cast incoming input to the wrapped module's type.
+-- castout    : when true, cast the wrapped module's output back to the outer
+--              type; defaults to `castin` when nil (note: `castout == nil`
+--              check, so an explicit false is honored).
+-- moduleType : torch tensor type string of the wrapped module; when omitted
+--              and casting is requested, it is inferred from module.output,
+--              then from module:parameters(), else an error is raised.
+function DontCast:__init(module, castin, castout, moduleType)
+ parent.__init(self, module)
+ self.castin = castin
+ self.castout = (castout == nil) and castin or castout
+ self.moduleType = moduleType
+ if (self.castin or self.castout) and not self.moduleType then
+ -- try to infer the type from the module's output tensor(s)
+ local moduleType, found = tableTensorType(module.output)
+ if found then
+ self.moduleType = moduleType
+ else
+ -- fall back to the module's parameter tensors
+ moduleType, found = tableTensorType(module:parameters())
+ if found then
+ self.moduleType = moduleType
+ else
+ error"Cannot extrapolate moduleType. Provide constructor argument 4"
+ end
+ end
+ end
+end
+
+-- Forward pass: optionally convert `input` to the wrapped module's type
+-- (castin), run the module, then optionally convert the result back to the
+-- decorator's own output type (castout). The converted input is cached in
+-- self._input for reuse by the backward methods.
+function DontCast:updateOutput(input)
+ if self.castin and tableTensorType(input) ~= self.moduleType then
+ self._input = recursiveTypeCopy(self._input, input, self.moduleType)
+ input = self._input
+ end
+
+ local output = self.modules[1]:updateOutput(input)
+
+ if self.castout then
+ -- cast back to whatever type self.output currently has (the outer type)
+ self.output = recursiveTypeCopy(self.output, output, tableTensorType(self.output))
+ else
+ self.output = output
+ end
+ return self.output
+end
+
+-- Backward pass: mirror updateOutput's conversions. Uses the input cached in
+-- self._input (assumes updateOutput ran first with the same input), converts
+-- gradOutput to the wrapped module's type when needed, and converts the
+-- resulting gradInput back to the decorator's gradInput type.
+function DontCast:updateGradInput(input, gradOutput)
+ if self.castin and tableTensorType(input) ~= self.moduleType then
+ -- reuse the converted input cached during the forward pass
+ input = self._input
+ end
+ if self.castout and tableTensorType(gradOutput) ~= self.moduleType then
+ self._gradOutput = recursiveTypeCopy(self._gradOutput, gradOutput, self.moduleType)
+ gradOutput = self._gradOutput
+ end
+
+ local gradInput = self.modules[1]:updateGradInput(input, gradOutput)
+
+ if self.castin then
+ -- gradInput flows back toward the caller, so castin governs this cast
+ self.gradInput = recursiveTypeCopy(self.gradInput, gradInput, tableTensorType(self.gradInput))
+ else
+ self.gradInput = gradInput
+ end
+ return self.gradInput
+end
+
+-- Accumulate parameter gradients, substituting the cached converted input and
+-- gradOutput (from the forward/backward passes) when the raw arguments are of
+-- a different type than the wrapped module.
+function DontCast:accGradParameters(input, gradOutput, scale)
+ if self.castin and tableTensorType(input) ~= self.moduleType then
+ input = self._input
+ end
+ if self.castout and tableTensorType(gradOutput) ~= self.moduleType then
+ gradOutput = self._gradOutput
+ end
+
+ self.modules[1]:accGradParameters(input, gradOutput, scale)
+end
+
+-- Combined accumulate-and-update variant; same cached-argument substitution
+-- as accGradParameters, forwarding the learning rate `lr`.
+function DontCast:accUpdateGradParameters(input, gradOutput, lr)
+ if self.castin and tableTensorType(input) ~= self.moduleType then
+ input = self._input
+ end
+ if self.castout and tableTensorType(gradOutput) ~= self.moduleType then
+ gradOutput = self._gradOutput
+ end
+
+ self.modules[1]:accUpdateGradParameters(input, gradOutput, lr)
+end
+
+-- dont cast (the essence thereof)
+-- Override of Module:type(): deliberately does NOT cast the wrapped module.
+-- Only the decorator's own output/gradInput buffers are converted to `type`
+-- so they interface correctly with the (re-typed) surrounding network.
+function DontCast:type(type)
+ if self.castout and tableTensorType(self.output) ~= type then
+ self.output = recursiveTypeCopy(nil, self.output, type)
+ end
+ if self.castin and tableTensorType(self.gradInput) ~= type then
+ self.gradInput = recursiveTypeCopy(nil, self.gradInput, type)
+ end
+ return self
+end