]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Another try to fix threading in torch
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Sep 2017 09:00:20 +0000 (10:00 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sun, 17 Sep 2017 09:00:20 +0000 (10:00 +0100)
contrib/torch/torch7/lib/TH/CMakeLists.txt
contrib/torch/torch7/lib/TH/THGeneral.c
contrib/torch/torch7/lib/TH/cmake/FindBLAS.cmake

index fdaec267ccb3c2896b8123ee289751156f5a2ba4..2b71bc3d085eb2da8cc1e5e6e485d5649ff547a6 100644 (file)
@@ -244,6 +244,8 @@ IF(BLAS_FOUND)
   TARGET_LINK_LIBRARIES(TH ${BLAS_LIBRARIES})
   IF(BLAS_INFO STREQUAL "mkl")
     ADD_DEFINITIONS(-DTH_BLAS_MKL)
+  ELSEIF(BLAS_INFO STREQUAL "open")
+    ADD_DEFINITIONS(-DTH_BLAS_OPEN)
   ENDIF()
 ENDIF(BLAS_FOUND)
 
index ac032b992e3bf30f57d0fcc414eac8ec1cba5f24..50dba205ff20fe35f6211ba92fef37bbe176efb0 100644 (file)
@@ -324,15 +324,26 @@ void THSetNumThreads(int num_threads)
 #ifdef _OPENMP
   omp_set_num_threads(num_threads);
 #endif
+#ifdef TH_BLAS_OPEN
+  extern void openblas_set_num_threads(int);
+  openblas_set_num_threads(num_threads);
+#endif
 }
 
 int THGetNumThreads(void)
 {
+  int nthreads = 1;
 #ifdef _OPENMP
-  return omp_get_max_threads();
-#else
-  return 1;
+  nthreads = omp_get_max_threads();
 #endif
+#ifdef TH_BLAS_OPEN
+  int bl_threads = 1;
+  extern int openblas_get_num_threads(void);
+  bl_threads = openblas_get_num_threads();
+  nthreads = nthreads > bl_threads ? bl_threads : nthreads;
+#endif
+
+  return nthreads;
 }
 
 int THGetNumCores(void)
index b7835a1f68558b3a46ae3e533bac7e45a6c3e218..1f254d231c5afa72af1399598cf273d032120612 100644 (file)
@@ -235,7 +235,17 @@ if((NOT BLAS_LIBRARIES)
   ""
   "blas")
   if (BLAS_LIBRARIES)
-    set(BLAS_INFO "generic")
+    check_fortran_libraries(
+            TMP_BLAS_LIBRARIES
+            TMP_BLAS
+            openblas_get_num_threads
+            ""
+            "blas")
+    if (TMP_BLAS_LIBRARIES)
+      set(BLAS_INFO "open")
+    else()
+      set(BLAS_INFO "generic")
+    endif()
   endif (BLAS_LIBRARIES)
 endif()