diff options
Diffstat (limited to 'contrib/lua-torch/nn/lib/THNN/generic/ClassNLLCriterion.c')
-rw-r--r-- | contrib/lua-torch/nn/lib/THNN/generic/ClassNLLCriterion.c | 163 |
1 files changed, 163 insertions, 0 deletions
diff --git a/contrib/lua-torch/nn/lib/THNN/generic/ClassNLLCriterion.c b/contrib/lua-torch/nn/lib/THNN/generic/ClassNLLCriterion.c new file mode 100644 index 000000000..4cf37aeaf --- /dev/null +++ b/contrib/lua-torch/nn/lib/THNN/generic/ClassNLLCriterion.c @@ -0,0 +1,163 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/ClassNLLCriterion.c" +#else + +void THNN_(ClassNLLCriterion_updateOutput)( + THNNState *state, + THTensor *input, + THIndexTensor *target, + THTensor *output, + bool sizeAverage, + THTensor *weights, + THTensor *total_weight, + long ignore_index) +{ + THNN_CHECK_DIM_SIZE(output, 1, 0, 1); + THNN_CHECK_DIM_SIZE(total_weight, 1, 0, 1); + int n_dims = THTensor_(nDimension)(input); + int n_classes = THTensor_(size)(input, n_dims - 1); + ignore_index -= TH_INDEX_BASE; + + if (THIndexTensor_(nDimension)(target) > 1) { + THError("multi-target not supported"); + } + if (THTensor_(nDimension)(input) > 2) { + THError("input tensor should be 1D or 2D"); + } + if (weights && THTensor_(nElement)(weights) != n_classes) { + THDescBuff s1 = THTensor_(sizeDesc)(weights); + THError("weight tensor should be defined either for all %d classes or no classes" + " but got weight tensor of shape: %s", n_classes, s1.str); + } + + input = THTensor_(newContiguous)(input); + target = THIndexTensor_(newContiguous)(target); + weights = weights ? THTensor_(newContiguous)(weights) : NULL; + + real *input_data = THTensor_(data)(input); + THIndex_t *target_data = THIndexTensor_(data)(target); + real *weights_data = weights ? THTensor_(data)(weights) : NULL; + real *output_data = THTensor_(data)(output); + real *total_weight_data = THTensor_(data)(total_weight); + + output_data[0] = total_weight_data[0] = 0.0; + + if (THTensor_(nDimension)(input) == 1) { + int cur_target = target_data[0] - TH_INDEX_BASE; + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); + total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f; + output_data[0] = -input_data[cur_target] * total_weight_data[0]; + } + } else if (THTensor_(nDimension)(input) == 2) { + int batch_size = THTensor_(size)(input, 0); + THAssert(THIndexTensor_(size)(target, 0) == batch_size); + + int n_target = THTensor_(size)(input, 1); + + int i; + for (i = 0; i < batch_size; i++) { + int cur_target = target_data[i] - TH_INDEX_BASE; + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); + + real cur_weight = weights ? weights_data[cur_target] : 1.0f; + total_weight_data[0] += cur_weight; + output_data[0] -= input_data[i * n_target + cur_target] * cur_weight; + } + } + } + + if (sizeAverage && total_weight_data[0]) { + output_data[0] /= total_weight_data[0]; + } + + if (weights) { + THTensor_(free)(weights); + } + THTensor_(free)(input); + THIndexTensor_(free)(target); +} + +void THNN_(ClassNLLCriterion_updateGradInput)( + THNNState *state, + THTensor *input, + THIndexTensor *target, + THTensor *gradInput, + bool sizeAverage, + THTensor *weights, + THTensor *total_weight, + long ignore_index) +{ + int n_dims = THTensor_(nDimension)(input); + int n_classes = THTensor_(size)(input, n_dims - 1); + ignore_index -= TH_INDEX_BASE; + + if (!THTensor_(isContiguous)(gradInput)) { + THError("gradInput must be contiguous"); + } + + real *total_weight_data = THTensor_(data)(total_weight); + + if (!(*total_weight_data > 0)) { + return; + } + + if (THIndexTensor_(nDimension)(target) > 1) { + THError("multi-target not supported"); + } + + if (THTensor_(nDimension)(input) > 2) { + THError("input tensor should be 1D or 2D"); + } + + if (weights && THTensor_(nElement)(weights) != n_classes) { + THError("weight tensor should be defined either for all or no classes"); + } + + target = THIndexTensor_(newContiguous)(target); + weights = weights ? THTensor_(newContiguous)(weights) : NULL; + + THIndex_t *target_data = THIndexTensor_(data)(target); + real *weights_data = weights ? THTensor_(data)(weights) : NULL; + real *gradInput_data = THTensor_(data)(gradInput); + + if (THTensor_(nDimension)(input) == 1) { + int cur_target = target_data[0] - TH_INDEX_BASE; + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); + + gradInput_data[cur_target] = + (!sizeAverage && weights) ? -weights_data[cur_target] : -1; + } + + } else if (THTensor_(nDimension)(input) == 2) { + int batch_size = THTensor_(size)(input, 0); + THAssert(THIndexTensor_(size)(target, 0) == batch_size); + + int n_target = THTensor_(size)(input, 1); + + int i; + for (i = 0; i < batch_size; i++){ + int cur_target = target_data[i] - TH_INDEX_BASE; + + if (cur_target != ignore_index) { + THAssert(cur_target >= 0 && cur_target < n_classes); + + gradInput_data[i * n_target + cur_target] = + -(weights ? weights_data[cur_target] : 1.0f); + + if (sizeAverage && *total_weight_data) { + gradInput_data[i * n_target + cur_target] /= *total_weight_data; + } + } + } + } + + THIndexTensor_(free)(target); + if (weights) { + THTensor_(free)(weights); + } +} + +#endif |