summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c
blob: 274a27e3b04d0c73c840e64a0a1510c7e6115287 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/GatedLinearUnit.c"
#else

/*
 * Forward pass of the gated linear unit.
 *
 * `input` is cut into two equally sized halves along `dim`; `output`
 * receives firstHalf * sigmoid(secondHalf), computed element-wise.
 *
 *   state  - not referenced by this function.
 *   input  - tensor to gate; its size along `dim` must be even, otherwise
 *            THArgCheck reports an argument error.
 *   output - destination; resized here to the sizes of `input` with the
 *            `dim` entry halved.
 *   dim    - dimension to split, expressed with the TH_INDEX_BASE offset.
 */
void THNN_(GatedLinear_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          int dim)
{
  // strip the TH_INDEX_BASE offset from the caller-supplied dimension
  dim -= TH_INDEX_BASE;

  const long fullSize = THTensor_(size)(input, dim);
  THArgCheck(fullSize % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
      dim + TH_INDEX_BASE, fullSize);
  const long halfSize = fullSize / 2;

  // output takes input's sizes, except that `dim` shrinks to half
  THLongStorage *outSizes = THTensor_(newSizeOf)(input);
  THLongStorage_set(outSizes, dim, halfSize);
  THTensor_(resize)(output, outSizes, NULL);

  // narrowed tensors over the two halves of input along `dim`
  THTensor *values = THTensor_(newNarrow)(input, dim, 0, halfSize);
  THTensor *gates = THTensor_(newNarrow)(input, dim, halfSize, halfSize);

  // output = sigmoid(gates), then output = output * values
  THTensor_(sigmoid)(output, gates);
  THTensor_(cmul)(output, output, values);

  // release the temporaries created above
  THTensor_(free)(gates);
  THTensor_(free)(values);
  THLongStorage_free(outSizes);
}

/*
 * Backward pass of the gated linear unit.
 *
 * With `input` split along `dim` into firstHalf and secondHalf, and
 * s = sigmoid(secondHalf), this fills `gradInput` so that
 *   first half of gradInput  = s * gradOutput
 *   second half of gradInput = (1 - s) * s * gradOutput * firstHalf
 * (all products element-wise).
 *
 *   state      - not referenced by this function.
 *   input      - tensor given to the forward pass; its size along `dim`
 *                must be even, otherwise THArgCheck reports an argument error.
 *   gradOutput - gradient with respect to the forward output; it is
 *                multiplied element-wise into each half of gradInput.
 *   gradInput  - destination; resized here to match `input`.
 *   dim        - dimension that was split, expressed with the TH_INDEX_BASE
 *                offset.
 */
void THNN_(GatedLinear_updateGradInput)(
          THNNState *state,
          THTensor *input,
          THTensor *gradOutput,
          THTensor *gradInput,
          int dim)
{
  // strip the TH_INDEX_BASE offset from the caller-supplied dimension
  dim -= TH_INDEX_BASE;

  const long fullSize = THTensor_(size)(input, dim);
  THArgCheck(fullSize % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
      dim + TH_INDEX_BASE, fullSize);
  const long halfSize = fullSize / 2;

  THTensor_(resizeAs)(gradInput, input);

  // narrowed tensors over the two halves of input and of gradInput
  THTensor *values = THTensor_(newNarrow)(input, dim, 0, halfSize);
  THTensor *gates = THTensor_(newNarrow)(input, dim, halfSize, halfSize);
  THTensor *gradValues = THTensor_(newNarrow)(gradInput, dim, 0, halfSize);
  THTensor *gradGates = THTensor_(newNarrow)(gradInput, dim, halfSize, halfSize);

  // gradValues temporarily holds s = sigmoid(gates)
  THTensor_(sigmoid)(gradValues, gates);

  // gradGates = (1 - s) * s, read from gradValues before it is scaled below
  TH_TENSOR_APPLY2(real, gradGates, real, gradValues,
    const real sig = *gradValues_data;
    *gradGates_data = (1. - sig) * sig;
  );

  // gradValues = s * gradOutput
  THTensor_(cmul)(gradValues, gradValues, gradOutput);

  // gradGates = ((1 - s) * s * gradOutput) * values
  THTensor_(cmul)(gradGates, gradGates, gradOutput);
  THTensor_(cmul)(gradGates, gradGates, values);

  // release the narrowed tensors created above
  THTensor_(free)(gradGates);
  THTensor_(free)(gradValues);
  THTensor_(free)(gates);
  THTensor_(free)(values);
}

#endif