]> source.dussan.org Git - rspamd.git/commitdiff
[Fix] Fix enabling/disabling squeezed rules
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 16 Mar 2018 17:14:32 +0000 (17:14 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 16 Mar 2018 17:14:32 +0000 (17:14 +0000)
lualib/lua_squeeze_rules.lua
src/plugins/lua/settings.lua

index 9bcccf6f0045844d484dc47e4f16e465a67e5b07..e310986d609901e717d3009f711b17ae4b29ad4d 100644 (file)
@@ -25,37 +25,46 @@ local squeezed_rdeps = {} -- squeezed reverse deps
 local SN = 'lua_squeeze'
 local squeeze_sym = 'LUA_SQUEEZE'
 local squeeze_function_ids = {}
+local squeezed_groups = {}
 
 local function gen_lua_squeeze_function(order)
   return function(task)
+    local symbols_disabled = task:cache_get('squeezed_disable')
     for _,data in ipairs(squeezed_rules[order]) do
-      local ret = {data[1](task)}
-
-      if #ret ~= 0 then
-        local first = ret[1]
-        local sym = data[2]
-        -- Function has returned something, so it is rule, not a plugin
-        if type(first) == 'boolean' then
-          if first then
-            table.remove(ret, 1)
 
-            if type(ret[1]) == 'table' then
-              task:insert_result(sym, 1.0, ret[1])
-            else
-              task:insert_result(sym, 1.0, ret)
+      if not symbols_disabled or not symbols_disabled[data[2]] then
+        local ret = {data[1](task)}
+
+        if #ret ~= 0 then
+          local first = ret[1]
+          local sym = data[2]
+          -- Function has returned something, so it is rule, not a plugin
+          if type(first) == 'boolean' then
+            if first then
+              table.remove(ret, 1)
+
+              if type(ret[1]) == 'table' then
+                task:insert_result(sym, 1.0, ret[1])
+              else
+                task:insert_result(sym, 1.0, ret)
+              end
             end
-          end
-        elseif type(first) == 'number' then
-          table.remove(ret, 1)
+          elseif type(first) == 'number' then
+            table.remove(ret, 1)
 
-          if first ~= 0 then
-            task:insert_result(sym, first, ret)
+            if first ~= 0 then
+              task:insert_result(sym, first, ret)
+            end
+          else
+            task:insert_result(sym, 1.0, ret)
           end
-        else
-          task:insert_result(sym, 1.0, ret)
         end
+      else
+        logger.debugm(SN, task, 'skip symbol due to settings: %s', data[2])
+      end
+
+
       end
-    end
   end
 end
 
@@ -224,10 +233,87 @@ exports.squeeze_init = function()
       parent = squeeze_function_ids[v.order],
       no_squeeze = true, -- to avoid infinite recursion
     }
+    local metric_sym = rspamd_config:get_metric_symbol(k)
+
+    if metric_sym then
+      v.group = metric_sym.group
+      v.score = metric_sym.score
+      v.description = description
+
+      if not squeezed_groups[v.group] then
+        logger.debugm(SN, rspamd_config, 'added squeezed group: %s', v.group)
+        squeezed_groups[v.group] = {}
+      end
+
+      table.insert(squeezed_groups[v.group], v)
+    end
     if not squeezed_rules[v.order] then
       squeezed_rules[v.order] = {}
     end
-    table.insert(squeezed_rules[v.order], {v.cb,k})
+    table.insert(squeezed_rules[v.order], {v.cb,k,v})
+  end
+end
+
+exports.handle_settings = function(task, settings)
+  local symbols_disabled = {}
+  local symbols_enabled = {}
+  local found = false
+
+  if settings.default then settings = settings.default end
+
+  if settings.symbols_enabled then
+    for k,v in squeezed_symbols do
+      if not settings.symbols_enabled[k] then
+        symbols_disabled[k] = true
+        found = true
+      else
+        symbols_enabled[k] = true
+      end
+    end
+  end
+
+  if settings.groups_enabled then
+    for k,syms in pairs(squeezed_groups) do
+      if not settings.groups_enabled[k] then
+        for _,sym in ipairs(syms) do
+          if not symbols_enabled[sym] then
+            symbols_disabled[sym] = true
+            found = true
+          end
+        end
+      else
+        for _,sym in ipairs(syms) do
+          if symbols_disabled[sym] then
+            symbols_disabled[sym] = nil
+          end
+          symbols_enabled[sym] = true
+        end
+      end
+    end
+  end
+
+  if settings.symbols_disabled then
+    for k,v in squeezed_symbols do
+      if settings.symbols_disabled[k] then
+        symbols_disabled[k] = true
+        found = true
+      end
+    end
+  end
+
+  if settings.groups_disabled then
+    for k,syms in pairs(squeezed_groups) do
+      if settings.groups_disabled[k] then
+        for _,sym in ipairs(syms) do
+          symbols_disabled[sym] = true
+          found = true
+        end
+      end
+    end
+  end
+
+  if found then
+    task:cache_set('squeezed_disable', symbols_disabled)
   end
 end
 
index 13647c4803f65946c6599dce41014e2a53ac0ec2..3ebf83f99d1ac0fff1077c499d749632e494709a 100644 (file)
@@ -24,6 +24,7 @@ end
 
 local rspamd_logger = require "rspamd_logger"
 local rspamd_maps = require "lua_maps"
+local lua_squeeze = require "lua_squeeze_rules"
 local redis_params
 
 local settings = {}
@@ -47,6 +48,7 @@ local function check_query_settings(task)
       local settings_obj = parser:get_object()
       task:set_settings(settings_obj)
       task:cache_set('settings', settings_obj)
+      lua_squeeze.handle_settings(task, settings_obj)
 
       return true
     else
@@ -75,6 +77,7 @@ local function check_query_settings(task)
 
       task:set_settings(nset)
       task:cache_set('settings', nset)
+      lua_squeeze.handle_settings(task, nset)
 
       return true
     end
@@ -87,6 +90,7 @@ local function check_query_settings(task)
     local elt = settings_ids[id_str]
     if elt and elt['apply'] then
       task:set_settings(elt['apply'])
+      lua_squeeze.handle_settings(task, elt['apply'])
       task:cache_set('settings', elt['apply'])
 
       if elt.apply['add_headers'] or elt.apply['remove_headers'] then
@@ -336,6 +340,7 @@ local function check_settings(task)
             task:get_message_id(), s.name)
           if rule['apply'] then
             task:set_settings(rule['apply'])
+            lua_squeeze.handle_settings(task, rule['apply'])
             task:cache_set('settings', rule['apply'])
             applied = true
           end
@@ -648,6 +653,7 @@ local function gen_redis_callback(handler, id)
               rspamd_logger.infox(task, "<%1> apply settings according to redis rule %2",
                 task:get_message_id(), id)
               task:set_settings(obj)
+              lua_squeeze.handle_settings(task, obj)
               task:cache_set('settings', obj)
               break
             end