summaryrefslogtreecommitdiffstats
path: root/contrib/lua-torch/nn/lib/THNN/generic/DistKLDivCriterion.c
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/lua-torch/nn/lib/THNN/generic/DistKLDivCriterion.c')
-rw-r--r--contrib/lua-torch/nn/lib/THNN/generic/DistKLDivCriterion.c44
1 files changed, 44 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/lib/THNN/generic/DistKLDivCriterion.c b/contrib/lua-torch/nn/lib/THNN/generic/DistKLDivCriterion.c
new file mode 100644
index 000000000..6bd6aa067
--- /dev/null
+++ b/contrib/lua-torch/nn/lib/THNN/generic/DistKLDivCriterion.c
@@ -0,0 +1,44 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/DistKLDivCriterion.c"
+#else
+
+void THNN_(DistKLDivCriterion_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *output,
+ bool sizeAverage)
+{
+ THNN_CHECK_NELEMENT(input, target);
+ THNN_CHECK_DIM_SIZE(output, 1, 0, 1);
+
+ real sum = 0;
+
+ TH_TENSOR_APPLY2(real, input, real, target,
+ sum += *target_data > 0 ? *target_data * (log(*target_data) - *input_data) : 0;
+ );
+
+ if (sizeAverage)
+ sum /= THTensor_(nElement)(input);
+
+ THTensor_(set1d)(output, 0, sum);
+}
+
+void THNN_(DistKLDivCriterion_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *target,
+ THTensor *gradInput,
+ bool sizeAverage)
+{
+ THNN_CHECK_NELEMENT(input, target);
+
+ real norm = (sizeAverage ? 1./((real)THTensor_(nElement)(input)) : 1.);
+
+ THTensor_(resizeAs)(gradInput, input);
+ TH_TENSOR_APPLY3(real, gradInput, real, input, real, target,
+ *gradInput_data = *target_data > 0 ? norm * (-*target_data) : 0;
+ );
+}
+
+#endif