]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Start moving of fann redis to torch
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 2 Sep 2017 19:02:50 +0000 (20:02 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 2 Sep 2017 19:02:50 +0000 (20:02 +0100)
src/lua/lua_config.c
src/lua/lua_task.c
src/plugins/lua/fann_redis.lua

index 21cc4bbe7c74c92e49776111fbd83fdcd4769f5f..6dfd45e12697c0f3813b6c5397be6c38e6e24696 100644 (file)
@@ -694,7 +694,7 @@ static const struct luaL_reg configlib_m[] = {
        LUA_INTERFACE_DEF (config, add_example),
        LUA_INTERFACE_DEF (config, set_peak_cb),
        LUA_INTERFACE_DEF (config, get_cpu_flags),
-       LUA_INTERFACE_DEF (config, get_cpu_flags),
+       LUA_INTERFACE_DEF (config, has_torch),
        {"__tostring", rspamd_lua_class_tostring},
        {"__newindex", lua_config_newindex},
        {NULL, NULL}
index 61ef6d075c6f06e0ab4b72725387df30d3992111..9d139044fc8a6c6d3b90ea6bc48dc96a0cc297f9 100644 (file)
@@ -3134,6 +3134,7 @@ struct tokens_foreach_cbdata {
        struct rspamd_task *task;
        lua_State *L;
        gint idx;
+       gboolean normalize;
 };
 
 static void
@@ -3150,7 +3151,12 @@ tokens_foreach_cb (gint id, const gchar *sym, gint flags, gpointer ud)
        mres = cbd->task->result;
 
        if (mres && (s = g_hash_table_lookup (mres->symbols, sym)) != NULL) {
-               lua_pushnumber (cbd->L, tanh (s->score));
+               if (cbd->normalize) {
+                       lua_pushnumber (cbd->L, tanh (s->score));
+               }
+               else {
+                       lua_pushnumber (cbd->L, s->score);
+               }
        }
        else {
                lua_pushnumber (cbd->L, 0.0);
@@ -3168,6 +3174,15 @@ lua_task_get_symbols_tokens (lua_State *L)
        cbd.task = task;
        cbd.L = L;
        cbd.idx = 1;
+       cbd.normalize = TRUE;
+
+       if (lua_type (L, 2) == LUA_TBOOLEAN) {
+               cbd.normalize = lua_toboolean (L, 2);
+       }
+       else {
+               cbd.normalize = TRUE;
+       }
+
        lua_createtable (L, rspamd_symbols_cache_symbols_count (task->cfg->cache), 0);
        rspamd_symbols_cache_foreach (task->cfg->cache, tokens_foreach_cb, &cbd);
 
index 531a740d904a234e502e2c08a21e2c2532a49078..b0cbdefab42c1ef7aa2d4161c582c60bf9e7b66a 100644 (file)
@@ -27,6 +27,15 @@ local rspamd_util = require "rspamd_util"
 local rspamd_redis = require "lua_redis"
 local fun = require "fun"
 local meta_functions = require "meta_functions"
+local use_torch = false
+local torch
+local nn
+
+if rspamd_config:has_torch() then
+  use_torch = true
+  torch = require "torch"
+  nn = require "nn"
+end
 
 -- Module vars
 local default_options = {
@@ -298,11 +307,15 @@ local function gen_fann_prefix(rule, id)
   local cksum = rspamd_config:get_symbols_cksum():hex()
   -- We also need to count metatokens:
   local n = meta_functions.rspamd_count_metatokens()
+  local tprefix = ''
+  if use_torch then
+    tprefix = 't';
+  end
   if id then
-    return string.format('%s%s%d%s', rule.prefix, cksum, n, id),
+    return string.format('%s%s%s%d%s', tprefix, rule.prefix, cksum, n, id),
       rule.prefix .. id
   else
-    return string.format('%s%s%d', rule.prefix, cksum, n), nil
+    return string.format('%s%s%s%d', tprefix, rule.prefix, cksum, n), nil
   end
 end
 
@@ -311,20 +324,36 @@ local function is_fann_valid(rule, prefix, ann)
     local n = rspamd_config:get_symbols_count() +
         meta_functions.rspamd_count_metatokens()
 
-    if n ~= ann:get_inputs() then
-      rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
-      ' is found in the cache', prefix, ann:get_inputs(), n)
-      return false
-    end
-    local layers = ann:get_layers()
+    if torch then
+      local nlayers = #ann
+      if nlayers ~= rule.nlayers then
+        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+          prefix, nlayers)
+        return false
+      end
 
-    if not layers or #layers ~= rule.nlayers then
-      rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
-        prefix, #layers)
-      return false
-    end
+      local inp = ann:get(1):nElement()
+      if n ~= inp then
+        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
+            ' is found in the cache', prefix, inp, n)
+        return false
+      end
+    else
+      if n ~= ann:get_inputs() then
+        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of inputs: %s, %s symbols' ..
+            ' is found in the cache', prefix, ann:get_inputs(), n)
+        return false
+      end
+      local layers = ann:get_layers()
 
-    return true
+      if not layers or #layers ~= rule.nlayers then
+        rspamd_logger.infox(rspamd_config, 'ANN %s has incorrect number of layers: %s',
+          prefix, #layers)
+        return false
+      end
+
+      return true
+    end
   end
 end
 
@@ -348,15 +377,23 @@ local function fann_scores_filter(task)
       -- Add filtered meta tokens
       fun.each(function(e) table.insert(fann_data, e) end, mt)
 
-      local out = fanns[id].fann:test(fann_data)
-      local symscore = string.format('%.3f', out[1])
+      local score
+      if torch then
+        local out = fanns[id].fann:forward(torch.Tensor(fann_data))
+        score = out[1]
+      else
+        local out = fanns[id].fann:test(fann_data)
+        score = out[1]
+      end
+
+      local symscore = string.format('%.3f', score)
       rspamd_logger.infox(task, 'fann score: %s', symscore)
 
-      if out[1] > 0 then
-        local result = rspamd_util.normalize_prob(out[1] / 2.0, 0)
+      if score > 0 then
+        local result = rspamd_util.normalize_prob(score / 2.0, 0)
         task:insert_result(rule.symbol_spam, result, symscore, id)
       else
-        local result = rspamd_util.normalize_prob((-out[1]) / 2.0, 0)
+        local result = rspamd_util.normalize_prob((-score) / 2.0, 0)
         task:insert_result(rule.symbol_ham, result, symscore, id)
       end
     end
@@ -364,14 +401,25 @@ local function fann_scores_filter(task)
 end
 
 local function create_fann(n, nlayers)
-  local layers = {}
-  local div = 1.0
-  for _ = 1, nlayers - 1 do
-    table.insert(layers, math.floor(n / div))
-    div = div * 2
-  end
-  table.insert(layers, 1)
-  return rspamd_fann.create(nlayers, layers)
+  if torch then
+    -- We ignore number of layers so far when using torch
+    local ann = nn.Sequential()
+    local nhidden = math.floor((n + 1) / 2)
+    ann:add(nn.Linear(n, nhidden))
+    ann:add(nn.PReLU())
+    ann:add(nn.Linear(nhidden, 1))
+
+    return ann
+  else
+    local layers = {}
+    local div = 1.0
+    for _ = 1, nlayers - 1 do
+      table.insert(layers, math.floor(n / div))
+      div = div * 2
+    end
+    table.insert(layers, 1)
+    return rspamd_fann.create(nlayers, layers)
+  end
 end
 
 local function create_train_fann(rule, n, id)
@@ -382,13 +430,7 @@ local function create_train_fann(rule, n, id)
   end
   -- Fix that for flexibe layers number
   if fanns[id].fann then
-    if n ~= fanns[id].fann:get_inputs() or --
-      (fanns[id].fann_train and n ~= fanns[id].fann_train:get_inputs()) then
-      rspamd_logger.infox(rspamd_config,
-        'recreate ANN %s as it has a wrong number of inputs, version %s',
-        prefix,
-        fanns[id].version)
-
+    if not is_fann_valid(rule, prefix, fanns[id].fann) then
       fanns[id].fann_train = create_fann(n, rule.nlayers)
       fanns[id].fann = nil
     elseif fanns[id].version % rule.max_usages == 0 then
@@ -421,7 +463,11 @@ local function load_or_invalidate_fann(rule, data, id, ev_base)
     rspamd_logger.errx(rspamd_config, 'cannot decompress ANN %s: %s', prefix, err)
     return
   else
-    ann = rspamd_fann.load_data(ann_data)
+    if torch then
+      ann = torch.MemoryFile(torch.CharStorage():string(tostring(ann_data))):readObject()
+    else
+      ann = rspamd_fann.load_data(ann_data)
+    end
   end
 
   if is_fann_valid(rule, prefix, ann) then
@@ -582,7 +628,15 @@ local function train_fann(rule, _, ev_base, elt)
     else
       rspamd_logger.infox(rspamd_config, 'trained ANN %s: MSE: %s',
         prefix, train_mse)
-      local ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+      local ann_data
+      if torch then
+        local f = torch.MemoryFile()
+        f:writeObject(fanns[elt].fann_train)
+        ann_data = rspamd_util.zstd_compress(f:storage():string())
+      else
+        ann_data = rspamd_util.zstd_compress(fanns[elt].fann_train:data())
+      end
+
       fanns[elt].version = fanns[elt].version + 1
       fanns[elt].fann = fanns[elt].fann_train
       fanns[elt].fann_train = nil
@@ -636,8 +690,7 @@ local function train_fann(rule, _, ev_base, elt)
       end, fun.zip(fun.filter(filt, spam_elts), fun.filter(filt, ham_elts)))
 
       -- Now we can train fann
-      if not fanns[elt] or not fanns[elt].fann_train
-        or n ~= fanns[elt].fann_train:get_inputs() then
+      if not fanns[elt] or not fanns[elt].fann_train then
         -- Create fann if it does not exist
         create_train_fann(rule, n, elt)
       end