]> source.dussan.org Git - rspamd.git/commitdiff
[Rework] Lua_util: Another rework for extract_specific_urls
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 20 Aug 2019 09:13:52 +0000 (10:13 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 20 Aug 2019 09:13:52 +0000 (10:13 +0100)
lualib/lua_util.lua
test/lua/unit/lua_util.extract_specific_urls.lua

index cde09ad6a008c7f1a7e7caf8c722791a51c4ca91..4dddb979f47843768a78194efb8abfdbbbd65b12 100644 (file)
@@ -682,6 +682,7 @@ exports.filter_specific_urls = function (urls, params)
   end
 
   local function process_single_url(u)
+    local priority = 1 -- Normal priority
     local esld = u:get_tld()
 
     if params.ignore_redirected and u:is_redirected() then
@@ -697,36 +698,40 @@ exports.filter_specific_urls = function (urls, params)
     local str_hash = tostring(u)
 
     if esld then
+      -- Special cases
+      if (u:get_protocol() ~= 'mailto') and (not u:is_html_displayed()) then
+        if u:is_obscured() then
+          priority = 2
+        else
+          if u:get_user() then
+            priority = 2
+          elseif u:is_subject() or u:is_phished() then
+            priority = 2
+          end
+        end
+      elseif u:is_html_displayed() then
+        priority = 0
+      end
+
       if not eslds[esld] then
-        eslds[esld] = {{str_hash, u}}
+        eslds[esld] = {{str_hash, u, priority}}
         neslds = neslds + 1
       else
         if #eslds[esld] < params.esld_limit then
-          table.insert(eslds[esld], {str_hash, u})
+          table.insert(eslds[esld], {str_hash, u, priority})
         end
       end
 
+
+      -- eSLD - 1 part => tld
       local parts = rspamd_str_split(esld, '.')
       local tld = table.concat(fun.totable(fun.tail(parts)), '.')
 
       if not tlds[tld] then
-        tlds[tld] = {{str_hash, u}}
+        tlds[tld] = {{str_hash, u, priority}}
         ntlds = ntlds + 1
       else
-        table.insert(tlds[tld], {str_hash, u})
-      end
-
-      -- Special cases
-      if not u:get_protocol() == 'mailto' and not u:is_html_displayed() then
-        if u:is_obscured() then
-          insert_url(str_hash, u)
-        else
-          if u:get_user() then
-            insert_url(str_hash, u)
-          elseif u:is_subject() or u:is_phished() then
-            insert_url(str_hash, u)
-          end
-        end
+        table.insert(tlds[tld], {str_hash, u, priority})
       end
     end
   end
@@ -737,24 +742,9 @@ exports.filter_specific_urls = function (urls, params)
 
   local limit = params.limit
   limit = limit - nres
-  if limit <= 0 then limit = 1 end
-
-  if neslds <= limit then
-    -- We can get urls based on their eslds
-    repeat
-      local item_found = false
-
-      for _,lurls in pairs(eslds) do
-        if #lurls > 0 then
-          local last = table.remove(lurls)
-          insert_url(last[1], last[2])
-          limit = limit - 1
-          item_found = true
-        end
-      end
-
-    until limit <= 0 or not item_found
+  if limit < 0 then limit = 0 end
 
+  if limit == 0 then
     res = exports.values(res)
     if params.task and not params.no_cache then
       params.task:cache_set(cache_key, res)
@@ -762,16 +752,49 @@ exports.filter_specific_urls = function (urls, params)
     return res
   end
 
-  if ntlds <= limit then
-    while limit > 0 do
-      for _,lurls in pairs(tlds) do
+  -- Sort eSLDs and tlds
+  local function sort_stuff(tbl)
+    -- Sort according to max priority
+    table.sort(tbl, function(e1, e2)
+      -- Sort by priority so max priority is at the end
+      table.sort(e1, function(tr1, tr2)
+        return tr1[3] < tr2[3]
+      end)
+      table.sort(e2, function(tr1, tr2)
+        return tr1[3] < tr2[3]
+      end)
+
+      if e1[#e1][3] ~= e2[#e2][3] then
+        -- Sort by priority so max priority is at the beginning
+        return e1[#e1][3] > e2[#e2][3]
+      else
+        -- Prefer less urls to more urls per esld
+        return #e1 < #e2
+      end
+
+    end)
+
+    return tbl
+  end
+
+  eslds = sort_stuff(exports.values(eslds))
+  neslds = #eslds
+
+  if neslds <= limit then
+    -- Number of eslds < limit
+    repeat
+      local item_found = false
+
+      for _,lurls in ipairs(eslds) do
         if #lurls > 0 then
           local last = table.remove(lurls)
           insert_url(last[1], last[2])
           limit = limit - 1
+          item_found = true
         end
       end
-    end
+
+    until limit <= 0 or not item_found
 
     res = exports.values(res)
     if params.task and not params.no_cache then
@@ -780,30 +803,18 @@ exports.filter_specific_urls = function (urls, params)
     return res
   end
 
-  -- We need to sort tlds table first
-  local tlds_keys = {}
-  for k,_ in pairs(tlds) do table.insert(tlds_keys, k) end
-  table.sort(tlds_keys, function (t1, t2)
-    return #tlds[t1] < #tlds[t2]
-  end)
-
-  ntlds = #tlds_keys
-  for i=1,ntlds / 2 do
-    local tld1 = tlds[tlds_keys[i]]
-    local tld2 = tlds[tlds_keys[ntlds - i]]
-    if #tld1 > 0 then
-      local last = table.remove(tld1)
-      insert_url(last[1], last[2])
-      limit = limit - 1
-    end
-    if #tld2 > 0 then
-      local last = table.remove(tld2)
-      insert_url(last[1], last[2])
-      limit = limit - 1
-    end
+  tlds = sort_stuff(exports.values(tlds))
+  ntlds = #tlds
 
-    if limit <= 0 then
-      break
+  -- Number of tlds < limit
+  while limit > 0 do
+    for _,lurls in ipairs(tlds) do
+      if #lurls > 0 then
+        local last = table.remove(lurls)
+        insert_url(last[1], last[2])
+        limit = limit - 1
+      end
+      if limit == 0 then break end
     end
   end
 
@@ -811,7 +822,6 @@ exports.filter_specific_urls = function (urls, params)
   if params.task and not params.no_cache then
     params.task:cache_set(cache_key, res)
   end
-
   return res
 end
 
index 93816745e79e43e29d9e79f411c13af8069f447d..c84a7ca8d72a7cbfd61179352e0073c5a7be9128 100644 (file)
@@ -192,8 +192,8 @@ context("Lua util - extract_specific_urls", function()
 
     local actual = util.extract_specific_urls({
       task = task,
-      limit = 2,
-      esld_limit = 2,
+      limit = 1,
+      esld_limit = 1,
     })
 
     local actual_result = prepare_actual_result(actual)
@@ -202,7 +202,7 @@ context("Lua util - extract_specific_urls", function()
       local s = logger.slog("case[%1] %2 =?= %3", i, expect, actual_result)
       print(s) --]]
 
-    assert_equal("domain.com", actual_result[1], "checking that first url is the one with highest suspiciousness level")
+    assert_rspamd_table_eq({actual = actual_result, expect = {"domain.com"}})
 
   end)
 end)