Diffstat (limited to 'contrib/lua-torch/nn/CosineDistance.lua')
-rw-r--r--  contrib/lua-torch/nn/CosineDistance.lua  116
1 file changed, 116 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/CosineDistance.lua b/contrib/lua-torch/nn/CosineDistance.lua
new file mode 100644
index 000000000..fe4e4b9f5
--- /dev/null
+++ b/contrib/lua-torch/nn/CosineDistance.lua
@@ -0,0 +1,116 @@
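+-- nn.CosineDistance takes a table {a, b} of two tensors and returns the
+-- cosine similarity dot(a, b) / (|a| * |b|) for each pair of rows (or for
+-- a single pair of 1D vectors). Despite the name, the output is a
+-- similarity in [-1, 1], not a distance.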
+local CosineDistance, parent = torch.class('nn.CosineDistance', 'nn.Module')
+
+function CosineDistance:__init()
+ parent.__init(self)
+ self.gradInput = {torch.Tensor(), torch.Tensor()}
+end
+
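+-- Return contiguous copies of the inputs when necessary, cached in
+-- self._input1/self._input2; the view() calls below require contiguous
+-- tensors.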
+local function makeContiguous(self, input1, input2)
+ if not input1:isContiguous() then
+ self._input1 = self._input1 or input1.new()
+ self._input1:resizeAs(input1):copy(input1)
+ input1 = self._input1
+ end
+ if not input2:isContiguous() then
+ self._input2 = self._input2 or input2.new()
+ self._input2:resizeAs(input2):copy(input2)
+ input2 = self._input2
+ end
+ return input1, input2
+end
+
+function CosineDistance:updateOutput(input)
+ local input1, input2 = input[1], input[2]
+
+ input1, input2 = makeContiguous(self, input1, input2)
+
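+   -- treat a single vector pair as a batch of one row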
+ if input1:dim() == 1 then
+ input1 = input1:view(1,-1)
+ input2 = input2:view(1,-1)
+ end
+
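+   -- lazily allocate work buffers of the same tensor type as the input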
+ if not self.buffer then
+ self.buffer = input1.new()
+ self.w1 = input1.new()
+ self.w22 = input1.new()
+ self.w = input1.new()
+ self.w32 = input1.new()
+ self.ones = input1.new()
+ end
+
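+   -- w1 = row-wise dot product <a, b>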
+ self.buffer:cmul(input1,input2)
+ self.w1:sum(self.buffer,2)
+
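+   -- w22 = 1 / (|a|^2 + eps); epsilon guards against division by zero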
+ local epsilon = 1e-12
+ self.buffer:cmul(input1,input1)
+ self.w22:sum(self.buffer,2):add(epsilon)
+ self.ones:resizeAs(self.w22):fill(1)
+ self.w22:cdiv(self.ones, self.w22)
+ self.w:resizeAs(self.w22):copy(self.w22)
+
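+   -- w32 = 1 / (|b|^2 + eps); after cmul and sqrt, w = 1 / (|a| * |b|)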
+ self.buffer:cmul(input2,input2)
+ self.w32:sum(self.buffer,2):add(epsilon)
+ self.w32:cdiv(self.ones, self.w32)
+ self.w:cmul(self.w32)
+ self.w:sqrt()
+
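+   -- output = <a, b> / (|a| * |b|), one scalar per batch row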
+ self.output:cmul(self.w1,self.w)
+ self.output:resize(input1:size(1))
+
+ return self.output
+end
+
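+-- Backward pass. With w = <a, b> / (|a| * |b|):
+--   dw/da = b / (|a| * |b|) - w * a / |a|^2   (and symmetrically for b),
+-- computed below from the buffers cached by updateOutput.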
+function CosineDistance:updateGradInput(input, gradOutput)
+ local v1 = input[1]
+ local v2 = input[2]
+ local not_batch = false
+
+ v1, v2 = makeContiguous(self, v1, v2)
+
+ if v1:dim() == 1 then
+ v1 = v1:view(1,-1)
+ v2 = v2:view(1,-1)
+ not_batch = true
+ end
+
+ if #self.gradInput ~= 2 then
+ self.gradInput[1] = self.gradInput[1] or v1.new()
+ self.gradInput[2] = self.gradInput[2] or v1.new()
+ end
+
+ local gw1 = self.gradInput[1]
+ local gw2 = self.gradInput[2]
+ gw1:resizeAs(v1):copy(v2)
+ gw2:resizeAs(v1):copy(v1)
+
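+   -- gw1 = (v2 - (<a, b> / |a|^2) * v1) / (|a| * |b|)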
+ self.buffer:cmul(self.w1,self.w22)
+ gw1:addcmul(-1,self.buffer:expandAs(v1),v1)
+ gw1:cmul(self.w:expandAs(v1))
+
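+   -- gw2 = (v1 - (<a, b> / |b|^2) * v2) / (|a| * |b|)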
+ self.buffer:cmul(self.w1,self.w32)
+ gw2:addcmul(-1,self.buffer:expandAs(v1),v2)
+ gw2:cmul(self.w:expandAs(v1))
+
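+   -- chain rule: scale both gradients by gradOutput, broadcast over columns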
+ local go = gradOutput:view(-1,1):expandAs(v1)
+ gw1:cmul(go)
+ gw2:cmul(go)
+
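+   -- drop the singleton batch dimension that was added for 1D inputs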
+ if not_batch then
+ self.gradInput[1]:resize(gw1:size(2))
+ self.gradInput[2]:resize(gw2:size(2))
+ end
+
+ return self.gradInput
+end
+
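+-- Release the cached intermediate buffers (e.g. before serializing).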
+function CosineDistance:clearState()
+ nn.utils.clear(self, {
+ 'buffer',
+ 'w1',
+ 'w22',
+ 'w',
+ 'w32',
+ 'ones',
+ })
+ return parent.clearState(self)
+end
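
A minimal usage sketch of the module added above (tensor sizes are
illustrative, not taken from the diff):

   require 'nn'

   local dist = nn.CosineDistance()
   local a = torch.randn(8, 25)   -- batch of 8 vectors of dimension 25
   local b = torch.randn(8, 25)
   local sim = dist:forward({a, b})                    -- 8 values in [-1, 1]
   local grads = dist:backward({a, b}, torch.ones(8))  -- {dL/da, dL/db}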