summaryrefslogtreecommitdiffstats
path: root/contrib/torch/nn/lib/THNN/generic/Sqrt.c
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/torch/nn/lib/THNN/generic/Sqrt.c')
-rw-r--r--contrib/torch/nn/lib/THNN/generic/Sqrt.c52
1 files changed, 52 insertions, 0 deletions
diff --git a/contrib/torch/nn/lib/THNN/generic/Sqrt.c b/contrib/torch/nn/lib/THNN/generic/Sqrt.c
new file mode 100644
index 000000000..174884e34
--- /dev/null
+++ b/contrib/torch/nn/lib/THNN/generic/Sqrt.c
@@ -0,0 +1,52 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/Sqrt.c"
+#else
+
+void THNN_(Sqrt_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ accreal eps_)
+{
+ real eps = TH_CONVERT_ACCREAL_TO_REAL(eps_);
+ THTensor_(resizeAs)(output, input);
+ THTensor_(sqrt)(output, input);
+}
+
+void THNN_(Sqrt_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ THTensor *output)
+{
+ THNN_CHECK_SHAPE(output, gradOutput);
+ THTensor_(resizeAs)(gradInput, input);
+
+ if (output->nDimension == 1 ||
+ !THTensor_(isContiguous)(output) ||
+ !THTensor_(isContiguous)(gradOutput) ||
+ !THTensor_(isContiguous)(gradInput))
+ {
+ TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, output,
+ *gradInput_data = (*output_data == 0.0) ? 0.0 : (0.5 * (*gradOutput_data / *output_data));
+ );
+ }
+ else
+ {
+ real *gradOutput_data = THTensor_(data)(gradOutput);
+ real *gradInput_data = THTensor_(data)(gradInput);
+ real *output_data = THTensor_(data)(output);
+ long i;
+#pragma omp parallel for private(i)
+ for(i = 0; i < THTensor_(nElement)(output); i++)
+ {
+ if (output_data[i] == 0.0)
+ gradInput_data[i] = 0.0;
+ else
+ gradInput_data[i] = 0.5 * (gradOutput_data[i] / output_data[i]);
+ }
+ }
+}
+
+#endif