aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/BCECriterion.lua
blob: 8bb5f81787a978d352246c87f7681b859242eea5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
local THNN = require 'nn.THNN'
local BCECriterion, parent = torch.class('nn.BCECriterion', 'nn.Criterion')

--- Construct a binary cross-entropy criterion.
-- @param weights optional 1-D Tensor of weights; it is stored as
--   `self.weights` and forwarded to the THNN kernels.
-- @param sizeAverage optional boolean stored as `self.sizeAverage`;
--   `nil` selects the default of `true` (an explicit `false` is kept).
function BCECriterion:__init(weights, sizeAverage)
   parent.__init(self)

   -- Compare against nil (not truthiness) so that `false` survives.
   if sizeAverage == nil then
      sizeAverage = true
   end
   self.sizeAverage = sizeAverage

   -- No weights supplied: leave self.weights unset.
   if weights == nil then
      return
   end

   assert(weights:dim() == 1, "weights input should be 1-D Tensor")
   self.weights = weights
end


--- Length operator for the criterion.
-- @return `#self.weights` when weights were supplied, otherwise 0.
function BCECriterion:__len()
   local w = self.weights
   if w then
      return #w
   end
   return 0
end

--- Forward pass: binary cross-entropy between `input` and `target`.
--   loss = - log(input) * target - log(1 - input) * (1 - target)
-- The reduction itself is done in C by `BCECriterion_updateOutput`;
-- `self.sizeAverage` is passed through to it (presumably selects mean vs.
-- sum over elements -- NOTE(review): confirm against the THNN kernel).
-- @param input Tensor of predictions (same element count as `target`)
-- @param target Tensor of labels
-- @return the scalar loss, also stored in `self.output`
function BCECriterion:updateOutput(input, target)
   assert(input:nElement() == target:nElement(),
      "input and target size mismatch")

   -- Lazily create a 1-element buffer of input's tensor type that the
   -- C kernel writes the loss into.
   if not self.output_tensor then
      self.output_tensor = input.new(1)
   end
   local buffer = self.output_tensor

   -- For targets that are not 1-D, view the 1-D weights as a single row of
   -- length target:size(2) and expand it to target's shape.
   local w = self.weights
   if w ~= nil and target:dim() ~= 1 then
      w = w:view(1, target:size(2)):expandAs(target)
   end

   input.THNN.BCECriterion_updateOutput(
      input:cdata(), target:cdata(), buffer:cdata(),
      self.sizeAverage, THNN.optionalTensor(w))

   self.output = buffer[1]
   return self.output
end

--- Backward pass: gradient of the BCE loss with respect to `input`.
--   d loss / d input = - (target - input) / ( input (1 - input) )
-- The gradient is written by `BCECriterion_updateGradInput` into
-- `self.gradInput` (assumed to be allocated by the parent Criterion --
-- NOTE(review): confirm).
-- @param input Tensor of predictions (same element count as `target`)
-- @param target Tensor of labels
-- @return `self.gradInput`
function BCECriterion:updateGradInput(input, target)
   assert(input:nElement() == target:nElement(),
      "input and target size mismatch")

   -- Same weight broadcasting as the forward pass: for non 1-D targets,
   -- view weights as one row of length target:size(2), expanded to target.
   local w = self.weights
   if w ~= nil and target:dim() ~= 1 then
      w = w:view(1, target:size(2)):expandAs(target)
   end

   local grad = self.gradInput
   input.THNN.BCECriterion_updateGradInput(
      input:cdata(), target:cdata(), grad:cdata(),
      self.sizeAverage, THNN.optionalTensor(w))

   return self.gradInput
end