Browse Source

[Fix] Another try to fix threading in torch

tags/1.7.0
Vsevolod Stakhov 6 years ago
parent
commit
8308d6a677

+ 2
- 0
contrib/torch/torch7/lib/TH/CMakeLists.txt View File

@@ -244,6 +244,8 @@ IF(BLAS_FOUND)
TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES})
IF(BLAS_INFO STREQUAL "mkl")
ADD_DEFINITIONS(-DTH_BLAS_MKL)
ELSEIF(BLAS_INFO STREQUAL "open")
ADD_DEFINITIONS(-DTH_BLAS_OPEN)
ENDIF()
ENDIF(BLAS_FOUND)


+ 14
- 3
contrib/torch/torch7/lib/TH/THGeneral.c View File

@@ -324,15 +324,26 @@ void THSetNumThreads(int num_threads)
#ifdef _OPENMP
omp_set_num_threads(num_threads);
#endif
#ifdef TH_BLAS_OPEN
extern void openblas_set_num_threads(int);
openblas_set_num_threads(num_threads);
#endif
}

int THGetNumThreads(void)
{
int nthreads = 1;
#ifdef _OPENMP
return omp_get_max_threads();
#else
return 1;
nthreads = omp_get_max_threads();
#endif
#ifdef TH_BLAS_OPEN
int bl_threads = 1;
extern int openblas_get_num_threads(void);
bl_threads = openblas_get_num_threads();
nthreads = nthreads > bl_threads ? bl_threads : nthreads;
#endif

return nthreads;
}

int THGetNumCores(void)

+ 11
- 1
contrib/torch/torch7/lib/TH/cmake/FindBLAS.cmake View File

@@ -235,7 +235,17 @@ if((NOT BLAS_LIBRARIES)
""
"blas")
if (BLAS_LIBRARIES)
set(BLAS_INFO "generic")
check_fortran_libraries(
TMP_BLAS_LIBRARIES
TMP_BLAS
openblas_get_num_threads
""
"blas")
if (TMP_BLAS_LIBRARIES)
set(BLAS_INFO "open")
else()
set(BLAS_INFO "generic")
endif()
endif (BLAS_LIBRARIES)
endif()


Loading…
Cancel
Save