Diffstat (limited to 'contrib/lua-torch/nn/DotProduct.lua')
-rw-r--r--  contrib/lua-torch/nn/DotProduct.lua  61
1 file changed, 61 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/DotProduct.lua b/contrib/lua-torch/nn/DotProduct.lua
new file mode 100644
index 000000000..ccd347e6b
--- /dev/null
+++ b/contrib/lua-torch/nn/DotProduct.lua
@@ -0,0 +1,61 @@
+local DotProduct, parent = torch.class('nn.DotProduct', 'nn.Module')
+
+function DotProduct:__init()
+ parent.__init(self)
+ self.gradInput = {torch.Tensor(), torch.Tensor()}
+end
+
+function DotProduct:updateOutput(input)
+ local input1, input2 = input[1], input[2]
+ if input1:dim() == 1 then
+ -- convert non-batch input to batch input
+ input1 = input1:view(1,-1)
+ input2 = input2:view(1,-1)
+ end
+ if not self.buffer then
+ self.buffer = input1.new()
+ end
+ self.buffer:cmul(input1, input2)
+ self.output:sum(self.buffer, 2)
+ self.output:resize(input1:size(1))
+ return self.output
+end
+
+function DotProduct:updateGradInput(input, gradOutput)
+ local v1 = input[1]
+ local v2 = input[2]
+ local not_batch = false
+
+ if #self.gradInput ~= 2 then
+ self.gradInput[1] = self.gradInput[1] or input[1].new()
+ self.gradInput[2] = self.gradInput[2] or input[2].new()
+ end
+
+ if v1:dim() == 1 then
+ v1 = v1:view(1,-1)
+ v2 = v2:view(1,-1)
+ not_batch = true
+ end
+
+ local gw1 = self.gradInput[1]
+ local gw2 = self.gradInput[2]
+ gw1:resizeAs(v1):copy(v2)
+ gw2:resizeAs(v2):copy(v1)
+
+ local go = gradOutput:view(-1,1):expandAs(v1)
+ gw1:cmul(go)
+ gw2:cmul(go)
+
+ if not_batch then
+ -- unbatch gradInput
+ self.gradInput[1]:set(gw1:select(1,1))
+ self.gradInput[2]:set(gw2:select(1,1))
+ end
+
+ return self.gradInput
+end
+
+function DotProduct:clearState()
+ if self.buffer then self.buffer:set() end
+ return parent.clearState(self)
+end
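
For context, nn.DotProduct takes a table of two equally sized tensors and returns their row-wise dot product; as updateOutput above shows, 1-D inputs are internally promoted to a one-row batch. A minimal usage sketch follows, assuming a working Torch7 installation with the nn package; the tensor sizes are illustrative assumptions, not taken from the commit:

local nn = require 'nn'

local m = nn.DotProduct()

-- Batch mode: two 4x5 inputs yield a 4-element output,
-- one dot product per row.
local a = torch.randn(4, 5)
local b = torch.randn(4, 5)
local out = m:forward({a, b})               -- tensor of size 4

-- Backward pass: one gradOutput entry per row. As in
-- updateGradInput above, gradIn[1] is b scaled row-wise by
-- gradOut, and gradIn[2] is a scaled the same way.
local gradOut = torch.randn(4)
local gradIn = m:backward({a, b}, gradOut)

-- Non-batch mode: 1-D inputs yield a single-element output,
-- and gradInput comes back unbatched as 1-D tensors.
local s = m:forward({torch.randn(5), torch.randn(5)})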