]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Check attachments only on AV scanners in attachments_only mode
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 14 Jul 2018 18:35:35 +0000 (19:35 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 14 Jul 2018 18:35:35 +0000 (19:35 +0100)
src/plugins/lua/antivirus.lua

index 8e2ddbbbd054b1753ff366368d896c049319d930..6adabff3c0fa1cd509dcf91ed8da6f3e661eaabd 100644 (file)
@@ -277,36 +277,23 @@ local function savapi_config(opts)
   return nil
 end
 
-local function message_not_too_large(task, rule)
+local function message_not_too_large(task, content, rule)
   local max_size = tonumber(rule['max_size'])
   if not max_size then return true end
-  if task:get_size() > max_size then
+  if #content > max_size then
     rspamd_logger.infox("skip %s AV check as it is too large: %s (%s is allowed)",
-      rule.type, task:get_size(), max_size)
+      rule.type, #content, max_size)
     return false
   end
   return true
 end
 
-local function need_av_check(task, rule)
-  if rule['attachments_only'] then
-    for _,p in ipairs(task:get_parts()) do
-      if p:get_filename() and not p:is_image() then
-        return message_not_too_large(task, rule)
-      end
-    end
-
-    rspamd_logger.infox("skip %s AV check as there are no attachments in a message",
-      rule.type)
-
-    return false
-  else
-    return message_not_too_large(task, rule)
-  end
+local function need_av_check(task, content, rule)
+  return message_not_too_large(task, content, rule)
 end
 
-local function check_av_cache(task, rule, fn)
-  local key = task:get_digest()
+local function check_av_cache(task, digest, rule, fn)
+  local key = digest
 
   local function redis_av_cb(err, data)
     if data and type(data) == 'string' then
@@ -345,8 +332,8 @@ local function check_av_cache(task, rule, fn)
   return false
 end
 
-local function save_av_cache(task, rule, to_save)
-  local key = task:get_digest()
+local function save_av_cache(task, digest, rule, to_save)
+  local key = digest
 
   local function redis_set_cb(err)
     -- Do nothing
@@ -378,14 +365,15 @@ local function save_av_cache(task, rule, to_save)
   return false
 end
 
-local function fprot_check(task, rule)
+local function fprot_check(task, content, digest, rule)
   local function fprot_check_uncached ()
     local upstream = rule.upstreams:get_upstream_round_robin()
     local addr = upstream:get_addr()
     local retransmits = rule.retransmits
     local scan_id = task:get_queue_id()
     if not scan_id then scan_id = task:get_uid() end
-    local header = string.format('SCAN STREAM %s SIZE %d\n', scan_id, task:get_size())
+    local header = string.format('SCAN STREAM %s SIZE %d\n', scan_id,
+        #content)
     local footer = '\n'
 
     local function fprot_callback(err, data)
@@ -402,7 +390,7 @@ local function fprot_check(task, rule)
               port = addr:get_port(),
               timeout = rule['timeout'],
               callback = fprot_callback,
-              data = { header, task:get_content(), footer },
+              data = { header, content, footer },
               stop_pattern = '\n'
             })
           else
@@ -437,7 +425,7 @@ local function fprot_check(task, rule)
           end
         end
         if cached then
-          save_av_cache(task, rule, cached)
+          save_av_cache(task, digest, rule, cached)
         end
       end
     end
@@ -448,13 +436,13 @@ local function fprot_check(task, rule)
       port = addr:get_port(),
       timeout = rule['timeout'],
       callback = fprot_callback,
-      data = { header, task:get_content(), footer },
+      data = { header, content, footer },
       stop_pattern = '\n'
     })
   end
 
-  if need_av_check(task, rule) then
-    if check_av_cache(task, rule, fprot_check_uncached) then
+  if need_av_check(task, content, rule) then
+    if check_av_cache(task, digest, rule, fprot_check_uncached) then
       return
     else
       fprot_check_uncached()
@@ -462,13 +450,13 @@ local function fprot_check(task, rule)
   end
 end
 
-local function clamav_check(task, rule)
+local function clamav_check(task, content, digest, rule)
   local function clamav_check_uncached ()
     local upstream = rule.upstreams:get_upstream_round_robin()
     local addr = upstream:get_addr()
     local retransmits = rule.retransmits
     local header = rspamd_util.pack("c9 c1 >I4", "zINSTREAM", "\0",
-      task:get_size())
+      #content)
     local footer = rspamd_util.pack(">I4", 0)
 
     local function clamav_callback(err, data)
@@ -486,7 +474,7 @@ local function clamav_check(task, rule)
               port = addr:get_port(),
               timeout = rule['timeout'],
               callback = clamav_callback,
-              data = { header, task:get_content(), footer },
+              data = { header, content, footer },
               stop_pattern = '\0'
             })
           else
@@ -522,7 +510,7 @@ local function clamav_check(task, rule)
           end
         end
         if cached then
-          save_av_cache(task, rule, cached)
+          save_av_cache(task, digest, rule, cached)
         end
       end
     end
@@ -533,13 +521,13 @@ local function clamav_check(task, rule)
       port = addr:get_port(),
       timeout = rule['timeout'],
       callback = clamav_callback,
-      data = { header, task:get_content(), footer },
+      data = { header, content, footer },
       stop_pattern = '\0'
     })
   end
 
-  if need_av_check(task, rule) then
-    if check_av_cache(task, rule, clamav_check_uncached) then
+  if need_av_check(task, content, rule) then
+    if check_av_cache(task, digest, rule, clamav_check_uncached) then
       return
     else
       clamav_check_uncached()
@@ -547,13 +535,13 @@ local function clamav_check(task, rule)
   end
 end
 
-local function sophos_check(task, rule)
+local function sophos_check(task, content, digest, rule)
   local function sophos_check_uncached ()
     local upstream = rule.upstreams:get_upstream_round_robin()
     local addr = upstream:get_addr()
     local retransmits = rule.retransmits
     local protocol = 'SSSP/1.0\n'
-    local streamsize = string.format('SCANDATA %d\n', task:get_size())
+    local streamsize = string.format('SCANDATA %d\n', #content)
     local bye = 'BYE\n'
 
     local function sophos_callback(err, data, conn)
@@ -571,7 +559,7 @@ local function sophos_check(task, rule)
               port = addr:get_port(),
               timeout = rule['timeout'],
               callback = sophos_callback,
-              data = { protocol, streamsize, task:get_content(), bye }
+              data = { protocol, streamsize, content, bye }
             })
           else
             rspamd_logger.errx(task, 'failed to scan, maximum retransmits exceed')
@@ -589,24 +577,24 @@ local function sophos_check(task, rule)
         local vname = string.match(data, 'VIRUS (%S+) ')
         if vname then
           yield_result(task, rule, vname)
-          save_av_cache(task, rule, vname)
+          save_av_cache(task, digest, rule, vname)
         else
           if string.find(data, 'DONE OK') then
             if rule['log_clean'] then
               rspamd_logger.infox(task, '%s [%s]: message is clean', rule['symbol'], rule['type'])
             end
-            save_av_cache(task, rule, 'OK')
+            save_av_cache(task, digest, rule, 'OK')
           elseif string.find(data, 'FAIL 0212') then
             if rule['savdi_report_encrypted'] then
               rspamd_logger.infox(task, 'Message is ENCRYPTED (0212 SOPHOS_SAVI_ERROR_FILE_ENCRYPTED): %s', data)
               yield_result(task, rule, "SAVDI_FILE_ENCRYPTED")
-              save_av_cache(task, rule, "SAVDI_FILE_ENCRYPTED")
+              save_av_cache(task, digest, rule, "SAVDI_FILE_ENCRYPTED")
             end
           elseif string.find(data, 'REJ 4') then
             if rule['savdi_report_oversize'] then
               rspamd_logger.infox(task, 'Message is OVERSIZED (SSSP reject code 4): %s', data)
-              yield_result(task, rule, "SAVDI_FILE_OVERSIZED")
-              save_av_cache(task, rule, "SAVDI_FILE_OVERSIZED")
+              yield_result(task, digest, rule, "SAVDI_FILE_OVERSIZED")
+              save_av_cache(task, digest, rule, "SAVDI_FILE_OVERSIZED")
             end
           elseif string.find(data, 'REJ 1') then
             rspamd_logger.errx(task, 'SAVDI (Protocol error (REJ 1)): %s', data)
@@ -626,12 +614,12 @@ local function sophos_check(task, rule)
       port = addr:get_port(),
       timeout = rule['timeout'],
       callback = sophos_callback,
-      data = { protocol, streamsize, task:get_content(), bye }
+      data = { protocol, streamsize, content, bye }
     })
   end
 
-  if need_av_check(task, rule) then
-    if check_av_cache(task, rule, sophos_check_uncached) then
+  if need_av_check(task, content, rule) then
+    if check_av_cache(task, digest, rule, sophos_check_uncached) then
       return
     else
       sophos_check_uncached()
@@ -639,7 +627,7 @@ local function sophos_check(task, rule)
   end
 end
 
-local function savapi_check(task, rule)
+local function savapi_check(task, content, digest, rule)
   local function savapi_check_uncached ()
     local upstream = rule.upstreams:get_upstream_round_robin()
     local addr = upstream:get_addr()
@@ -664,7 +652,7 @@ local function savapi_check(task, rule)
         end
 
         yield_result(task, rule, vname)
-        save_av_cache(task, rule, vname)
+        save_av_cache(task, digest, rule, vname)
       end
       if conn then
         conn:close()
@@ -680,7 +668,7 @@ local function savapi_check(task, rule)
         if rule['log_clean'] then
           rspamd_logger.infox(task, '%s: message is clean', rule['type'])
         end
-        save_av_cache(task, rule, 'OK')
+        save_av_cache(task, digest, rule, 'OK')
         conn:add_write(savapi_fin_cb, 'QUIT\n')
 
       -- Terminal response - infected
@@ -772,8 +760,8 @@ local function savapi_check(task, rule)
     })
   end
 
-  if need_av_check(task, rule) then
-    if check_av_cache(task, rule, savapi_check_uncached) then
+  if need_av_check(task, content, rule) then
+    if check_av_cache(task, digest, rule, savapi_check_uncached) then
       return
     else
       savapi_check_uncached()
@@ -855,7 +843,21 @@ local function add_antivirus_rule(sym, opts)
   end
 
   return function(task)
-    return cfg.check(task, rule)
+    if rule.attachments_only then
+      local parts = task:get_parts() or {}
+
+      for _,p in ipairs(parts) do
+        if not p:is_image() and not p:is_text() and not p:is_multipart() then
+          local content = p:get_content()
+
+          if content and #content > 0 then
+            cfg.check(task, content, p:get_digest(), rule)
+          end
+        end
+      end
+    else
+      cfg.check(task, task:get_content(), task:get_digest(), rule)
+    end
   end
 end