lua_kann_apply1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
+ struct rspamd_lua_tensor *pca = NULL;
if (k) {
if (lua_istable (L, 2)) {
gsize vec_len = rspamd_lua_table_size (L, 2);
- float *vec = (float *) g_malloc (sizeof (float) * vec_len);
+ float *vec = (float *) g_malloc (sizeof (float) * vec_len),
+ *pca_out = NULL;
int i_out;
int n_in = kann_dim_in (k);
if (n_in <= 0) {
+ g_free (vec);
return luaL_error (L, "invalid inputs count: %d", n_in);
}
- if (n_in != vec_len) {
- return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
- (int) vec_len, n_in);
+ if (lua_isuserdata (L, 3)) {
+ pca = lua_check_tensor (L, 3);
+
+ if (pca) {
+ if (pca->ndims != 2) {
+ g_free (vec);
+ return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+ }
+
+ if (pca->dim[0] != n_in) {
+ g_free (vec);
+ return luaL_error (L, "invalid pca tensor: "
+ "matrix must have %d rows and it has %d rows instead",
+ n_in, pca->dim[0]);
+ }
+ }
+ else {
+ g_free (vec);
+ return luaL_error (L, "invalid params: pca matrix expected");
+ }
+ }
+ else {
+ if (n_in != vec_len) {
+ g_free (vec);
+ return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
+ (int) vec_len, n_in);
+ }
}
for (gsize i = 0; i < vec_len; i++) {
}
kann_set_batch_size (k, 1);
- kann_feed_bind (k, KANN_F_IN, 0, &vec);
+ if (pca) {
+ pca_out = g_malloc (sizeof (float) * n_in);
+
+ kad_sgemm_simple (0, 0, pca->dim[0], 1,
+ pca->dim[1], pca->data,
+ vec, pca_out);
+
+ kann_feed_bind (k, KANN_F_IN, 0, &pca_out);
+ }
+ else {
+ kann_feed_bind (k, KANN_F_IN, 0, &vec);
+ }
+
kad_eval_at (k->n, k->v, i_out);
gsize outlen = kad_len (k->v[i_out]);
}
g_free (vec);
+ g_free (pca_out);
}
else if (lua_isuserdata (L, 2)) {
struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);