From 8308d6a6776ad73b8c0e61fcfdb55007434621cf Mon Sep 17 00:00:00 2001 From: Vsevolod Stakhov Date: Sun, 17 Sep 2017 10:00:20 +0100 Subject: [PATCH] [Fix] Another try to fix threading in torch --- contrib/torch/torch7/lib/TH/CMakeLists.txt | 2 ++ contrib/torch/torch7/lib/TH/THGeneral.c | 17 ++++++++++++++--- .../torch/torch7/lib/TH/cmake/FindBLAS.cmake | 12 +++++++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/contrib/torch/torch7/lib/TH/CMakeLists.txt b/contrib/torch/torch7/lib/TH/CMakeLists.txt index fdaec267c..2b71bc3d0 100644 --- a/contrib/torch/torch7/lib/TH/CMakeLists.txt +++ b/contrib/torch/torch7/lib/TH/CMakeLists.txt @@ -244,6 +244,8 @@ IF(BLAS_FOUND) TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES}) IF(BLAS_INFO STREQUAL "mkl") ADD_DEFINITIONS(-DTH_BLAS_MKL) + ELSEIF(BLAS_INFO STREQUAL "open") + ADD_DEFINITIONS(-DTH_BLAS_OPEN) ENDIF() ENDIF(BLAS_FOUND) diff --git a/contrib/torch/torch7/lib/TH/THGeneral.c b/contrib/torch/torch7/lib/TH/THGeneral.c index ac032b992..50dba205f 100644 --- a/contrib/torch/torch7/lib/TH/THGeneral.c +++ b/contrib/torch/torch7/lib/TH/THGeneral.c @@ -324,15 +324,26 @@ void THSetNumThreads(int num_threads) #ifdef _OPENMP omp_set_num_threads(num_threads); #endif +#ifdef TH_BLAS_OPEN + extern void openblas_set_num_threads(int); + openblas_set_num_threads(num_threads); +#endif } int THGetNumThreads(void) { + int nthreads = 1; #ifdef _OPENMP - return omp_get_max_threads(); -#else - return 1; + nthreads = omp_get_max_threads(); #endif +#ifdef TH_BLAS_OPEN + int bl_threads = 1; + extern int openblas_get_num_threads(void); + bl_threads = openblas_get_num_threads(); + nthreads = nthreads > bl_threads ? bl_threads : nthreads; +#endif + + return nthreads; } int THGetNumCores(void) diff --git a/contrib/torch/torch7/lib/TH/cmake/FindBLAS.cmake b/contrib/torch/torch7/lib/TH/cmake/FindBLAS.cmake index b7835a1f6..1f254d231 100644 --- a/contrib/torch/torch7/lib/TH/cmake/FindBLAS.cmake +++ b/contrib/torch/torch7/lib/TH/cmake/FindBLAS.cmake @@ -235,7 +235,17 @@ if((NOT BLAS_LIBRARIES) "" "blas") if (BLAS_LIBRARIES) - set(BLAS_INFO "generic") + check_fortran_libraries( + TMP_BLAS_LIBRARIES + TMP_BLAS + openblas_get_num_threads + "" + "blas") + if (TMP_BLAS_LIBRARIES) + set(BLAS_INFO "open") + else() + set(BLAS_INFO "generic") + endif() endif (BLAS_LIBRARIES) endif() -- 2.39.5