aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/CosineEmbeddingCriterion.lua
blob: d55e03130940afe6778b3071595c3f25fef9cdb6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
-- nn.CosineEmbeddingCriterion: criterion over a pair of inputs {x1, x2} and a
-- label y, built on the cosine similarity cos(x1, x2) of each row pair.
-- As implemented below, the per-sample loss is
--    1 - cos(x1, x2)               where y ==  1
--    max(0, cos(x1, x2) - margin)  where y == -1
-- and the summed loss is divided by y:size(1) when self.sizeAverage is true.
local CosineEmbeddingCriterion, parent = torch.class('nn.CosineEmbeddingCriterion', 'nn.Criterion')

-- Constructor.
-- margin: optional margin used for the y == -1 branch of the loss
--         (falls back to 0 when omitted).
-- Sets up a two-entry gradInput table (one tensor per input) and enables
-- size averaging by default.
function CosineEmbeddingCriterion:__init(margin)
   parent.__init(self)
   self.margin = margin or 0
   self.sizeAverage = true
   self.gradInput = {torch.Tensor(), torch.Tensor()}
end

-- Forward pass.
-- input: table {x1, x2}; each entry is either a 1D vector or a 2D batch
--        with one sample per row (1D inputs are viewed as a 1-row batch).
-- y:     label tensor with one entry per sample, compared against 1 and -1;
--        a plain number is also accepted and wrapped into a cached
--        1-element tensor.
-- Returns the summed per-sample loss, divided by y:size(1) when
-- self.sizeAverage is set. The result is also stored in self.output.
-- Intermediate quantities (self.w1, self.w22, self.w32, self.w,
-- self._outputs) are kept on self; updateGradInput reads them.
function CosineEmbeddingCriterion:updateOutput(input,y)
   local x1 = input[1]
   local x2 = input[2]

   -- keep backward compatibility: wrap a scalar label into a tensor
   if type(y) == 'number' then
      if not self._y then
         self._y = x1.new(1)
      end
      self._y[1] = y
      y = self._y
   end

   -- treat a single vector as a batch of one row
   if x1:dim() == 1 then
      x1 = x1:view(1,-1)
      x2 = x2:view(1,-1)
   end

   -- work buffers are allocated once, on first use, with the input's type
   if not self.buffer then
      for _, field in ipairs({'buffer', 'w1', 'w22', 'w', 'w32', '_outputs'}) do
         self[field] = x1.new()
      end
      -- comparison operators behave differently from cuda/c implementations
      if x1:type() == 'torch.CudaTensor' then
         self._idx = x1.new()
      else
         self._idx = torch.ByteTensor()
      end
   end

   local buf = self.buffer
   local dot, inv_sq1, inv_sq2, inv_norms = self.w1, self.w22, self.w32, self.w
   local epsilon = 1e-12

   -- dot = row-wise <x1, x2>
   buf:cmul(x1, x2)
   dot:sum(buf, 2)

   -- inv_sq1 = 1 / (|x1|^2 + epsilon)
   buf:cmul(x1, x1)
   inv_sq1:sum(buf, 2)
   inv_sq1:add(epsilon)
   -- self._outputs is also used as a temporary buffer (filled with ones here)
   self._outputs:resizeAs(inv_sq1):fill(1)
   inv_sq1:cdiv(self._outputs, inv_sq1)
   inv_norms:resizeAs(inv_sq1):copy(inv_sq1)

   -- inv_sq2 = 1 / (|x2|^2 + epsilon); inv_norms = sqrt(inv_sq1 * inv_sq2)
   buf:cmul(x2, x2)
   inv_sq2:sum(buf, 2)
   inv_sq2:add(epsilon)
   inv_sq2:cdiv(self._outputs, inv_sq2)
   inv_norms:cmul(inv_sq2)
   inv_norms:sqrt()

   -- cosine similarity per row, then drop the singleton second dimension
   self._outputs:cmul(dot, inv_norms)
   self._outputs = self._outputs:select(2,1)

   -- y == -1: max(0, cos - margin)
   y.eq(self._idx, y, -1)
   local dissimilar = self._outputs[self._idx]
   dissimilar:add(-self.margin):cmax(0)
   self._outputs[self._idx] = dissimilar

   -- y == 1: 1 - cos
   y.eq(self._idx, y, 1)
   local similar = self._outputs[self._idx]
   similar:mul(-1):add(1)
   self._outputs[self._idx] = similar

   local loss = self._outputs:sum()
   if self.sizeAverage then
      loss = loss / y:size(1)
   end
   self.output = loss

   return self.output
end

-- Backward pass: gradient of the loss with respect to both inputs.
-- input: table {x1, x2}, same layout as given to updateOutput (1D vectors
--        or 2D batches with one sample per row).
-- y:     label tensor (entries compared against 1), or a plain number that
--        is wrapped into a cached 1-element tensor.
-- Returns self.gradInput, a table {grad_x1, grad_x2}. For 1D inputs the
-- gradients are resized back to 1D before returning.
-- Reads self.buffer, self.w1, self.w22, self.w32, self.w, self._outputs and
-- self._idx, which this file only creates/fills in updateOutput, so
-- updateOutput must have been run on the same input beforehand.
function CosineEmbeddingCriterion:updateGradInput(input, y)

   local v1  = input[1]
   local v2  = input[2]
   local not_batch = false

   -- keep backward compatibility
   if type(y) == 'number' then
     -- allocate from v1: there is no `input1` in this function, so the
     -- previous `input1.new(1)` indexed a nil global whenever self._y had
     -- not been created yet
     self._y = self._y or v1.new(1)
     self._y[1] = y
     y = self._y
   end

   if v1:dim() == 1 then
      v1 = v1:view(1,-1)
      v2 = v2:view(1,-1)
      not_batch = true
   end

   local gw1 = self.gradInput[1]
   local gw2 = self.gradInput[2]
   gw1:resizeAs(v1):copy(v2)
   gw2:resizeAs(v1):copy(v1)

   -- gw1 = (v2 - (w1 * w22) * v1) * w, i.e. d cos / d v1 using the
   -- dot product (w1) and inverse-norm terms (w22, w) from updateOutput
   self.buffer:cmul(self.w1,self.w22)
   gw1:addcmul(-1,self.buffer:expandAs(v1),v1)
   gw1:cmul(self.w:expandAs(v1))

   -- gw2 = (v1 - (w1 * w32) * v2) * w, the symmetric expression for v2
   self.buffer:cmul(self.w1,self.w32)
   gw2:addcmul(-1,self.buffer:expandAs(v1),v2)
   gw2:cmul(self.w:expandAs(v1))

   -- rows whose loss is <= 0 contribute no gradient
   -- self._idx = self._outputs <= 0
   y.le(self._idx,self._outputs,0)
   self._idx = self._idx:view(-1,1):expand(gw1:size())
   gw1[self._idx] = 0
   gw2[self._idx] = 0

   -- for y == 1 the loss is 1 - cos, so the gradient sign flips
   y.eq(self._idx,y,1)
   self._idx = self._idx:view(-1,1):expand(gw2:size())
   gw1[self._idx] = gw1[self._idx]:mul(-1)
   gw2[self._idx] = gw2[self._idx]:mul(-1)

   if self.sizeAverage then
      gw1:div(y:size(1))
      gw2:div(y:size(1))
   end

   -- restore the 1D shape for non-batch input
   if not_batch then
      self.gradInput[1]:resize(gw1:size(2))
      self.gradInput[2]:resize(gw2:size(2))
   end

   return self.gradInput
end

-- Converts the criterion to the given tensor type and returns self.
-- type: tensor type name, e.g. 'torch.CudaTensor'.
-- The mask tensor self._idx is cleared before delegating to the parent and
-- rebuilt afterwards (presumably so the parent conversion does not touch
-- it -- confirm against nn.Criterion.type).
function CosineEmbeddingCriterion:type(type)
   self._idx = nil
   parent.type(self, type)
   -- comparison operators behave differently from cuda/c implementations:
   -- the mask is a CudaTensor on CUDA and a ByteTensor otherwise
   local on_cuda = (type == 'torch.CudaTensor')
   self._idx = on_cuda and torch.CudaTensor() or torch.ByteTensor()
   return self
end