]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Neural: Fix PCA based learning
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 22:51:38 +0000 (23:51 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 27 Aug 2020 22:51:38 +0000 (23:51 +0100)
src/lua/lua_kann.c
src/plugins/lua/neural.lua

index 1827fe1acfa2d528a2cd83baa368e92a4539442d..db12e1f871f91151ea206d83e1ebe58e3727bacb 100644 (file)
@@ -1023,6 +1023,7 @@ static int
 lua_kann_train1 (lua_State *L)
 {
        kann_t *k = lua_check_kann (L, 1);
+       struct rspamd_lua_tensor *pca = NULL;
 
        /* Default train params */
        double lr = 0.001;
@@ -1055,8 +1056,8 @@ lua_kann_train1 (lua_State *L)
 
                        if (!rspamd_lua_parse_table_arguments (L, 4, &err,
                                        RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
-                                       "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
-                                       &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+                                       "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
+                                       &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
                                n = luaL_error (L, "invalid params: %s",
                                                err ? err->message : "unknown error");
                                g_error_free (err);
@@ -1065,36 +1066,83 @@ lua_kann_train1 (lua_State *L)
                        }
                }
 
-               float **x, **y;
+               if (pca) {
+                       /* Check pca matrix validity */
+                       if (pca->ndims != 2) {
+                               return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+                       }
+
+                       if (pca->dim[0] != n_in) {
+                               return luaL_error (L, "invalid pca tensor: "
+                                                 "matrix must have %d rows and it has %d rows instead",
+                                                 n_in, pca->dim[0]);
+                       }
+               }
 
-               /* Fill vectors */
+               float **x, **y, *tmp_row = NULL;
+
+               /* Fill vectors row by row */
                x = (float **)g_malloc0 (sizeof (float *) * n);
                y = (float **)g_malloc0 (sizeof (float *) * n);
 
+               if (pca) {
+                       tmp_row = g_malloc (sizeof (float) * pca->dim[1]);
+               }
+
                for (int s = 0; s < n; s ++) {
                        /* Inputs */
                        lua_rawgeti (L, 2, s + 1);
                        x[s] = (float *)g_malloc (sizeof (float) * n_in);
 
-                       if (rspamd_lua_table_size (L, -1) != n_in) {
-                               FREE_VEC (x, n);
-                               FREE_VEC (y, n);
+                       if (pca == NULL) {
+                               if (rspamd_lua_table_size (L, -1) != n_in) {
+                                       FREE_VEC (x, n);
+                                       FREE_VEC (y, n);
 
-                               lua_pop (L, 1);
-                               n = luaL_error (L, "invalid params at pos %d: "
-                                          "bad input dimension %d; %d expected",
-                                               s + 1,
-                                               (int)rspamd_lua_table_size (L, -1),
-                                               n_in);
+                                       n = luaL_error (L, "invalid params at pos %d: "
+                                                                          "bad input dimension %d; %d expected",
+                                                       s + 1,
+                                                       (int) rspamd_lua_table_size (L, -1),
+                                                       n_in);
+                                       lua_pop (L, 1);
 
-                               return n;
+                                       return n;
+                               }
+
+                               for (int i = 0; i < n_in; i++) {
+                                       lua_rawgeti (L, -1, i + 1);
+                                       x[s][i] = lua_tonumber (L, -1);
+
+                                       lua_pop (L, 1);
+                               }
                        }
+                       else {
+                               if (rspamd_lua_table_size (L, -1) != pca->dim[1]) {
+                                       FREE_VEC (x, n);
+                                       FREE_VEC (y, n);
+                                       g_free (tmp_row);
+
+                                       n = luaL_error (L, "(pca on) invalid params at pos %d: "
+                                                                          "bad input dimension %d; %d expected",
+                                                       s + 1,
+                                                       (int) rspamd_lua_table_size (L, -1),
+                                                       pca->dim[1]);
+                                       lua_pop (L, 1);
 
-                       for (int i = 0; i < n_in; i ++) {
-                               lua_rawgeti (L, -1, i + 1);
-                               x[s][i] = lua_tonumber (L, -1);
+                                       return n;
+                               }
 
-                               lua_pop (L, 1);
+
+                               for (int i = 0; i < pca->dim[1]; i++) {
+                                       lua_rawgeti (L, -1, i + 1);
+                                       tmp_row[i] = lua_tonumber (L, -1);
+
+                                       lua_pop (L, 1);
+                               }
+
+                               kad_sgemm_simple (0, 0, pca->dim[0], 1,
+                                               pca->dim[1], pca->data,
+                                               tmp_row, x[s]);
                        }
 
                        lua_pop (L, 1);
@@ -1104,9 +1152,9 @@ lua_kann_train1 (lua_State *L)
                        lua_rawgeti (L, 3, s + 1);
 
                        if (rspamd_lua_table_size (L, -1) != n_out) {
-                               lua_pop (L, 1);
                                FREE_VEC (x, n);
                                FREE_VEC (y, n);
+                               g_free (tmp_row);
 
                                n = luaL_error (L, "invalid params at pos %d: "
                                           "bad output dimension %d; "
@@ -1114,6 +1162,7 @@ lua_kann_train1 (lua_State *L)
                                                s + 1,
                                                (int)rspamd_lua_table_size (L, -1),
                                                n_out);
+                               lua_pop (L, 1);
 
                                return n;
                        }
@@ -1142,6 +1191,7 @@ lua_kann_train1 (lua_State *L)
 
                FREE_VEC (x, n);
                FREE_VEC (y, n);
+               g_free (tmp_row);
        }
        else {
                return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
index 5b4ff8b3be04bf8b39eddafdb4aba9051a0d33b3..0258fb0b0675ed9977d8942b68c5b1b1bff2941b 100644 (file)
@@ -640,17 +640,15 @@ end
 
 -- This is an utility function for PCA training
 local function fill_scatter(inputs)
-  local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
-  local row_len = #inputs[1]
+  local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1])
+  local nsamples = #inputs
 
-  if type(inputs) == 'table' then
-    -- Convert to a tensor
-    inputs = rspamd_tensor.fromtable(inputs)
-  end
+  -- Convert to a tensor where each row is an input dimension
+  inputs = rspamd_tensor.fromtable(inputs):transpose()
 
   local meanv = inputs:mean()
 
-  for i=1,row_len do
+  for i=1,nsamples do
     local col = rspamd_tensor.new(1, #inputs)
     for j=1,#inputs do
       local x = inputs[j][i] - meanv[j]
@@ -679,6 +677,8 @@ local function learn_pca(inputs, max_inputs)
     w[i] = scatter_matrix[#scatter_matrix - i + 1]
   end
 
+  lua_util.debugm(N, 'pca matrix: %s', w)
+
   return w
 end
 
@@ -856,8 +856,11 @@ local function spawn_train(worker, ev_base, rule, set, ann_key, ham_vec, spam_ve
         local profile_serialized = ucl.to_format(profile, 'json-compact', true)
 
         rspamd_logger.infox(rspamd_config,
-            'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
-            rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
+            'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+            rule.prefix, set.name,
+            #data, #ann_data,
+            #(set.ann.pca or {}), #(pca_data or {}),
+            set.ann.redis_key, ann_key)
 
         lua_redis.exec_redis_script(redis_save_unlock_id,
             {ev_base = ev_base, is_write = true},