diff options
Diffstat (limited to 'contrib/lua-torch/nn/DotProduct.lua')
-rw-r--r-- | contrib/lua-torch/nn/DotProduct.lua | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/DotProduct.lua b/contrib/lua-torch/nn/DotProduct.lua new file mode 100644 index 000000000..ccd347e6b --- /dev/null +++ b/contrib/lua-torch/nn/DotProduct.lua @@ -0,0 +1,61 @@ +local DotProduct, parent = torch.class('nn.DotProduct', 'nn.Module') + +function DotProduct:__init() + parent.__init(self) + self.gradInput = {torch.Tensor(), torch.Tensor()} +end + +function DotProduct:updateOutput(input) + local input1, input2 = input[1], input[2] + if input1:dim() == 1 then + -- convert non batch input to batch input + input1 = input1:view(1,-1) + input2 = input2:view(1,-1) + end + if not self.buffer then + self.buffer = input1.new() + end + self.buffer:cmul(input1, input2) + self.output:sum(self.buffer, 2) + self.output:resize(input1:size(1)) + return self.output +end + +function DotProduct:updateGradInput(input, gradOutput) + local v1 = input[1] + local v2 = input[2] + local not_batch = false + + if #self.gradInput ~= 2 then + self.gradInput[1] = self.gradInput[1] or input[1].new() + self.gradInput[2] = self.gradInput[2] or input[2].new() + end + + if v1:dim() == 1 then + v1 = v1:view(1,-1) + v2 = v2:view(1,-1) + not_batch = true + end + + local gw1 = self.gradInput[1] + local gw2 = self.gradInput[2] + gw1:resizeAs(v1):copy(v2) + gw2:resizeAs(v2):copy(v1) + + local go = gradOutput:view(-1,1):expandAs(v1) + gw1:cmul(go) + gw2:cmul(go) + + if not_batch then + -- unbatch gradInput + self.gradInput[1]:set(gw1:select(1,1)) + self.gradInput[2]:set(gw2:select(1,1)) + end + + return self.gradInput +end + +function DotProduct:clearState() + if self.buffer then self.buffer:set() end + return parent.clearState(self) +end |