summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-23 18:14:15 +0100
committerVsevolod Stakhov <vsevolod@highsecure.ru>2018-05-23 18:14:15 +0100
commit714eb56e1760fdfb26afccde92664d3a2f1e8435 (patch)
tree84d1399acbb92f852b4bd64f9ea5412680b0c6ab /contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c
parent220a51ff68013dd668a45b78c60a7b8bfc10f074 (diff)
downloadrspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.tar.gz
rspamd-714eb56e1760fdfb26afccde92664d3a2f1e8435.zip
[Minor] Move lua contrib libraries to lua- prefix
Diffstat (limited to 'contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c')
-rw-r--r--contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c73
1 files changed, 73 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c b/contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c
new file mode 100644
index 000000000..274a27e3b
--- /dev/null
+++ b/contrib/lua-torch/nn/lib/THNN/generic/GatedLinearUnit.c
@@ -0,0 +1,73 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/GatedLinearUnit.c"
+#else
+
+void THNN_(GatedLinear_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int dim)
+{
+ // size output to half of input
+ dim = dim - TH_INDEX_BASE;
+ const long nIn = THTensor_(size)(input, dim);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
+ dim + TH_INDEX_BASE, nIn);
+
+ const long inputSize = THTensor_(size)(input, dim) / 2;
+ THLongStorage *newSizes = THTensor_(newSizeOf)(input);
+ THLongStorage_set(newSizes, dim, inputSize);
+ THTensor_(resize)(output, newSizes, NULL);
+
+ // halve tensor
+ THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize);
+ THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize);
+
+ // x = x1:cmul( sigmoid(x2) )
+ THTensor_(sigmoid)(output, secondHalf);
+ THTensor_(cmul)(output, output, firstHalf);
+
+ THLongStorage_free(newSizes);
+ THTensor_(free)(firstHalf);
+ THTensor_(free)(secondHalf);
+}
+
+void THNN_(GatedLinear_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int dim)
+{
+ // set up tensors
+ dim = dim - TH_INDEX_BASE;
+ const long nIn = THTensor_(size)(input, dim);
+ THArgCheck(nIn % 2 == 0, 2, "Halving dimension must be even. Dim %d is size %ld",
+ dim + TH_INDEX_BASE, nIn);
+
+ THTensor_(resizeAs)(gradInput, input);
+ const long inputSize = THTensor_(size)(input, dim) / 2;
+ THTensor *firstHalf = THTensor_(newNarrow)(input, dim, 0, inputSize);
+ THTensor *secondHalf = THTensor_(newNarrow)(input, dim, inputSize, inputSize);
+ THTensor *gradInputfirstHalf = THTensor_(newNarrow)(gradInput, dim, 0, inputSize);
+ THTensor *gradInputsecondHalf = THTensor_(newNarrow)(gradInput, dim, inputSize, inputSize);
+
+ THTensor_(sigmoid)(gradInputfirstHalf, secondHalf);
+
+ TH_TENSOR_APPLY2(real, gradInputsecondHalf, real, gradInputfirstHalf,
+ real z = *gradInputfirstHalf_data;
+ *gradInputsecondHalf_data = (1. - z) * z;
+ );
+
+ THTensor_(cmul)(gradInputfirstHalf, gradInputfirstHalf, gradOutput);
+
+ THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, gradOutput);
+ THTensor_(cmul)(gradInputsecondHalf, gradInputsecondHalf, firstHalf);
+
+ THTensor_(free)(firstHalf);
+ THTensor_(free)(secondHalf);
+ THTensor_(free)(gradInputfirstHalf);
+ THTensor_(free)(gradInputsecondHalf);
+}
+
+#endif