diff options
author | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@highsecure.ru> | 2018-05-23 18:14:15 +0100 |
commit | 714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch) | |
tree | 84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/nn/ModuleCriterion.lua | |
parent | 220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff) | |
download | rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip |
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/nn/ModuleCriterion.lua')
-rw-r--r-- | contrib/lua-torch/nn/ModuleCriterion.lua | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/ModuleCriterion.lua b/contrib/lua-torch/nn/ModuleCriterion.lua new file mode 100644 index 000000000..bfc79ef55 --- /dev/null +++ b/contrib/lua-torch/nn/ModuleCriterion.lua @@ -0,0 +1,44 @@ +local ModuleCriterion, parent = torch.class("nn.ModuleCriterion", "nn.Criterion") + +function ModuleCriterion:__init(criterion, inputModule, targetModule, castTarget) + self.inputModule = inputModule + self.targetModule = targetModule + self.castTarget = (castTarget == nil) and true or castTarget + if self.inputModule then + local params = self.inputModule:parameters() + if params and #params > 0 then + print"Warning: nn.ModuleCriterion doesn't support parameter updates" + end + end + self.criterion = criterion +end + +function ModuleCriterion:updateOutput(input, target) + if self.inputModule then + self.input = self.inputModule:forward(input) + end + if self.targetModule then + self.target = self.targetModule:forward(target) + end + self.output = self.criterion:forward(self.input or input, self.target or target) + return self.output +end + +function ModuleCriterion:updateGradInput(input, target) + self.gradInput = self.criterion:backward(self.input or input, self.target or target) + if self.inputModule then + self.gradInput = self.inputModule:backward(input, self.gradInput) + end + return self.gradInput +end + +function ModuleCriterion:type(type, typecache) + if self.inputModule then + self.inputModule:type(type, typecache) + end + if self.castTarget and self.targetModule then + self.targetModule:type(type, typecache) + end + self.criterion:type(type, typecache) + return parent.type(self, type, typecache) +end |