summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/lib/THNN/generic/HardShrink.c
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/lib/THNN/generic/HardShrink.c')
-rw-r--r--contrib/lua-torch/nn/lib/THNN/generic/HardShrink.c42
1 files changed, 42 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/lib/THNN/generic/HardShrink.c b/contrib/lua-torch/nn/lib/THNN/generic/HardShrink.c
new file mode 100644
index 000000000..aaae85bac
--- /dev/null
+++ b/contrib/lua-torch/nn/lib/THNN/generic/HardShrink.c
@@ -0,0 +1,42 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/HardShrink.c"
+#else
+
+void THNN_(HardShrink_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ accreal lambda_)
+{
+ real lambda = TH_CONVERT_ACCREAL_TO_REAL(lambda_);
+ THTensor_(resizeAs)(output, input);
+
+ TH_TENSOR_APPLY2(real, output, real, input,
+ if (*input_data > lambda)
+ *output_data = *input_data;
+ else if (*input_data < -lambda)
+ *output_data = *input_data;
+ else
+ *output_data = 0;
+ );
+}
+
+void THNN_(HardShrink_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ accreal lambda_)
+{
+ real lambda = TH_CONVERT_ACCREAL_TO_REAL(lambda_);
+ THNN_CHECK_NELEMENT(input, gradOutput);
+ THTensor_(resizeAs)(gradInput, input);
+ TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input,
+ if (*input_data > lambda || *input_data < -lambda)
+ *gradInput_data = *gradOutput_data;
+ else
+ *gradInput_data = 0;
+ );
+}
+
+#endif