/*
 * NOTE(review): this text does not look compilable as it stands. Line breaks
 * appear to have been collapsed and, judging by fragments such as
 * `for (i = 0; isize[0]);`, `for (i=0; i= numw + TH_INDEX_BASE)` and
 * `for (i=0; i maxNorm)`, spans of source appear to be missing (loop bounds,
 * loop bodies and at least one function header). Restore the file from its
 * original before building or modifying it.
 *
 * What is still readable here:
 *  - a static helper THNN_(LookupTable_resetCount) taking count_data and input;
 *  - checks that gradWeight and input are contiguous and that input has 1 or 2
 *    dimensions, each raising THError on failure;
 *  - an index range check whose error message states the accepted range as
 *    TH_INDEX_BASE <= input < numw + TH_INDEX_BASE;
 *  - an OpenMP path (taken when numel > 1000) in which each thread derives a
 *    start/end row range from its thread id and calls THBlas_(axpy) into
 *    gw + k*stride, dividing the scale by count_data[k] when count_data is set;
 *  - a later fragment that computes new_norm = maxNorm / (norm + 1e-7) and
 *    sorts row_idx with qsort using THNN_(compare_THIndex).
 * Presumably these are pieces of the gradient-accumulation and max-norm
 * renormalisation routines of a lookup table -- TODO confirm against the
 * original file.
 */
#ifndef TH_GENERIC_FILE #define TH_GENERIC_FILE "generic/LookupTable.c" #else static void THNN_(LookupTable_resetCount)( THInteger_t *count_data, THIndexTensor *input) { ptrdiff_t i; THIndex_t *input_data = THIndexTensor_(data)(input); ptrdiff_t numel = THIndexTensor_(nElement)(input); for (i = 0; isize[0]); count_data = THIntegerTensor_(data)(count); } if (!THTensor_(isContiguous)(gradWeight)) THError("gradWeight must be contiguous"); if (!THIndexTensor_(isContiguous)(input)) THError("input must be contiguous"); if (THIndexTensor_(nDimension)(input) != 1 && THIndexTensor_(nDimension)(input) != 2) { THDescBuff s1 = THIndexTensor_(sizeDesc)(input); THError("input must be a vector or matrix, but is of shape: %s", s1.str); } THIndex_t *input_data = THIndexTensor_(data)(input); ptrdiff_t numel = THIndexTensor_(nElement)(input); long numw = THTensor_(size)(gradWeight, 0); // check that inputs are all within range for (i=0; i= numw + TH_INDEX_BASE) { THError("inputs need to be in the range %ld <= input < %ld, " "but got input of value: %ld", TH_INDEX_BASE, (numw + TH_INDEX_BASE), input_data[i]); } gradOutput = THTensor_(newContiguous)(gradOutput); real *gw = THTensor_(data)(gradWeight); real *go = THTensor_(data)(gradOutput); long stride = THTensor_(stride)(gradWeight, 0); if (count_data) THNN_(LookupTable_resetCount)(count_data, input); #ifdef _OPENMP if (numel > 1000) { // The strategy is to parallelize over sections of the vocabulary, so that // thread 1 handles updates to gradWeight[0..nVocab/nThreads]. Every thread // has to traverse the entire input, but the dominating factor is the axpy // BLAS call. 
#pragma omp parallel private(i) { int tid = omp_get_thread_num(); int nthreads = omp_get_num_threads(); long start = tid * (numw/nthreads + 1); long end = start + (numw/nthreads + 1); for (i=0; i= start && k < end) { real scale_ = scale; if (count_data) scale_ /= count_data[k]; THBlas_(axpy)(stride, scale_, go + i*stride, 1, gw + k*stride, 1); } } } } THTensor_(free)(gradOutput); return; } #endif for (i=0; i maxNorm) { new_norm = maxNorm / (norm + 1e-7); for (j=0; j= numw + TH_INDEX_BASE) { THError("input need to be in the range %ld <= input < %ld, " "but got input of value: %ld", TH_INDEX_BASE, (numw + TH_INDEX_BASE), row_idx[i]); } } // get unique indices qsort(row_idx, numel, sizeof(THIndex_t), THNN_(compare_THIndex)); ptrdiff_t ptr = 0; for (i=0; i 1000) { // The strategy is to parallelize over the rows that appear in // row_idx, so that thread 1 handles the rows in row_idx[0..numel/nThreads]. // This distributes the work evenly to each thread. #pragma omp parallel for private(i) for (i=0; i