aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/kann
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/kann')
-rw-r--r--contrib/kann/kautodiff.c29
-rw-r--r--contrib/kann/kautodiff.h9
2 files changed, 37 insertions, 1 deletions
diff --git a/contrib/kann/kautodiff.c b/contrib/kann/kautodiff.c
index 47a86a71e..7b0bf8e93 100644
--- a/contrib/kann/kautodiff.c
+++ b/contrib/kann/kautodiff.c
@@ -900,6 +900,7 @@ void kad_vec_mul_sum(int n, float *a, const float *b, const float *c)
void kad_saxpy(int n, float a, const float *x, float *y) { kad_saxpy_inlined(n, a, x, y); }
#ifdef HAVE_CBLAS
+extern void ssyev(const char* jobz, const char* uplo, int* n, float* a, int* lda, float* w, float* work, int* lwork, int* info);
#ifdef HAVE_CBLAS_H
#include "cblas.h"
#else
@@ -947,6 +948,34 @@ void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float
}
#endif
+bool kad_ssyev_simple(int N, float *A, float *eugenvals)
+{
+#ifndef HAVE_CBLAS
+ return false;
+#else
+ int n = N, lda = N, info, lwork;
+ float wkopt;
+ float *work;
+
+ /* Query and allocate the optimal workspace */
+ lwork = -1;
+ ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, &wkopt, &lwork, &info);
+ lwork = wkopt;
+ work = (float*) g_malloc(lwork * sizeof(double));
+ ssyev ("Vectors", "Upper", &n, A, &lda, eugenvals, work, &lwork, &info);
+ /* Check for convergence */
+ if (info > 0) {
+ g_free (work);
+
+ return false;
+ }
+
+ g_free (work);
+
+ return true;
+#endif
+}
+
/***************************
* Random number generator *
***************************/
diff --git a/contrib/kann/kautodiff.h b/contrib/kann/kautodiff.h
index e51176c84..8c797205c 100644
--- a/contrib/kann/kautodiff.h
+++ b/contrib/kann/kautodiff.h
@@ -244,6 +244,13 @@ static inline int kad_len(const kad_node_t *p) /* calculate the size of p->x */
}
/* Additions by Rspamd */
-void kad_sgemm_simple(int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+void kad_sgemm_simple (int trans_A, int trans_B, int M, int N, int K, const float *A, const float *B, float *C);
+/**
+ * Calculate eugenvectors and eugenvalues
+ * @param N dimensions of A (must be NxN)
+ * @param A input matrix (part of it will be destroyed, so copy if needed), on finish the first `nwork` columns will have eugenvectors
+ * @param eugenvals eugenvalues, must be N elements vector
+ */
+bool kad_ssyev_simple (int N, float *A, float *eugenvals);
#endif