]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Pass params via table instead of long arglist. Added esld_limit param to...
authorMikhail Galanin <mgalanin@mimecast.com>
Tue, 31 Jul 2018 08:31:50 +0000 (09:31 +0100)
committerMikhail Galanin <mgalanin@mimecast.com>
Tue, 31 Jul 2018 08:31:50 +0000 (09:31 +0100)
lualib/lua_util.lua
test/lua/unit/lua_util.extract_specific_urls.lua [new file with mode: 0644]

index 520318fe650f52054a00cdccdabc118f0540195c..c43286d8bd680b0014daf9b1bc2e6caab1b47ab3 100644 (file)
@@ -490,36 +490,71 @@ end
 exports.override_defaults = override_defaults
 
 --[[[
--- @function lua_util.extract_specific_urls(task, limit, [need_emails[, filter[, prefix])
+-- @function lua_util.extract_specific_urls(params)
+-- params: {
+- - task
+- - limit <int> (default = 9999)
+- - esld_limit <int> (default = 9999) n domains per eSLD (effective second level domain)
+                                      works only if number of unique eSLD less than `limit`
+- - need_emails <bool> (default = false)
+- - filter <callback> (default = nil)
+- - prefix <string> cache prefix (default = nil)
+-- }
 -- Apply heuristic in extracting of urls from task, this function
 -- tries its best to extract specific number of urls from a task based on
 -- their characteristics
 --]]
-exports.extract_specific_urls = function(task, lim, need_emails, filter, prefix)
+-- exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
+exports.extract_specific_urls = function(params_or_task, limit, need_emails, filter, prefix)
+  local default_params = {
+    limit = 9999,
+    esld_limit = 9999,
+    need_emails = false,
+    filter = nil,
+    prefix = nil
+  }
+
+  local params
+  if type(params_or_task) == 'table' and type(limit) == 'nil' then
+    params = params_or_task
+  else
+    -- Deprecated call
+    params = {
+      task = params_or_task,
+      limit = limit,
+      need_emails = need_emails,
+      filter = filter,
+      prefix = prefix
+    }
+  end
+  for k,v in pairs(default_params) do
+    if not params[k] then params[k] = default_params[k] end
+  end
+
+
   local cache_key
 
-  if prefix then
-    cache_key = prefix
+  if params.prefix then
+    cache_key = params.prefix
   else
-    cache_key = string.format('sp_urls_%d%s', lim, need_emails)
+    cache_key = string.format('sp_urls_%d%s', params.limit, params.need_emails)
   end
 
 
-  local cached = task:cache_get(cache_key)
+  local cached = params.task:cache_get(cache_key)
 
   if cached then
     return cached
   end
 
-  local urls = task:get_urls(need_emails)
+  local urls = params.task:get_urls(params.need_emails)
 
   if not urls then return {} end
 
-  if filter then urls = fun.totable(fun.filter(filter, urls)) end
-
-  if #urls <= lim then
-    task:cache_set(cache_key, urls)
+  if params.filter then urls = fun.totable(fun.filter(params.filter, urls)) end
 
+  if #urls <= params.limit and #urls <= params.esld_limit then
+    params.task:cache_set(cache_key, urls)
     return urls
   end
 
@@ -538,7 +573,9 @@ exports.extract_specific_urls = function(task, lim, need_emails, filter, prefix)
         eslds[esld] = {u}
         neslds = neslds + 1
       else
-        table.insert(eslds[esld], u)
+        if #eslds[esld] < params.esld_limit then
+          table.insert(eslds[esld], u)
+        end
       end
 
       local parts = rspamd_str_split(esld, '.')
@@ -566,35 +603,40 @@ exports.extract_specific_urls = function(task, lim, need_emails, filter, prefix)
     end
   end
 
-  lim = lim - #res
-  if lim <= 0 then lim = 1 end
+  local limit = params.limit
+  limit = limit - #res
+  if limit <= 0 then limit = 1 end
 
-  if neslds <= lim then
+  if neslds <= limit then
     -- We can get urls based on their eslds
-    while lim > 0 do
+    repeat
+      local item_found = false
+
       for _,lurls in pairs(eslds) do
         if #lurls > 0 then
           table.insert(res, table.remove(lurls))
-          lim = lim - 1
+          limit = limit - 1
+          item_found = true
         end
       end
-    end
 
-    task:cache_set(cache_key, urls)
+    until limit <= 0 or not item_found
+
+    params.task:cache_set(cache_key, urls)
     return res
   end
 
-  if ntlds <= lim then
-    while lim > 0 do
+  if ntlds <= limit then
+    while limit > 0 do
       for _,lurls in pairs(tlds) do
         if #lurls > 0 then
           table.insert(res, table.remove(lurls))
-          lim = lim - 1
+          limit = limit - 1
         end
       end
     end
 
-    task:cache_set(cache_key, urls)
+    params.task:cache_set(cache_key, urls)
     return res
   end
 
@@ -611,14 +653,14 @@ exports.extract_specific_urls = function(task, lim, need_emails, filter, prefix)
     local tld2 = tlds[tlds_keys[ntlds - i]]
     table.insert(res, table.remove(tld1))
     table.insert(res, table.remove(tld2))
-    lim = lim - 2
+    limit = limit - 2
 
-    if lim <= 0 then
+    if limit <= 0 then
       break
     end
   end
 
-  task:cache_set(cache_key, urls)
+  params.task:cache_set(cache_key, urls)
   return res
 end
 
diff --git a/test/lua/unit/lua_util.extract_specific_urls.lua b/test/lua/unit/lua_util.extract_specific_urls.lua
new file mode 100644 (file)
index 0000000..424cca5
--- /dev/null
@@ -0,0 +1,127 @@
+context("Lua util - extract_specific_urls", function()
+  local util  = require 'lua_util'
+  local mpool = require "rspamd_mempool"
+  local fun   = require "fun"
+  local url   = require "rspamd_url"
+  local logger = require "rspamd_logger"
+  local ffi = require "ffi"
+
+  ffi.cdef[[
+  void rspamd_url_init (const char *tld_file);
+  unsigned ottery_rand_range(unsigned top);
+  void rspamd_http_normalize_path_inplace(char *path, size_t len, size_t *nlen);
+  ]]
+
+  local test_dir = string.gsub(debug.getinfo(1).source, "^@(.+/)[^/]+$", "%1")
+
+  ffi.C.rspamd_url_init(string.format('%s/%s', test_dir, "test_tld.dat"))
+
+  local task_object = {
+    urls      = {},
+    cache_set = function(self, ...) end,
+    cache_get = function(self, ...) end,
+    get_urls  = function(self, need_emails) return self.urls end
+  }
+
+  local url_list = {
+    "google.com",
+    "mail.com",
+    "bizz.com",
+    "bing.com",
+    "example.com",
+    "gov.co.net",
+    "tesco.co.net",
+    "domain1.co.net",
+    "domain2.co.net",
+    "domain3.co.net",
+    "domain4.co.net",
+    "abc.org",
+    "icq.org",
+    "meet.org",
+    "domain1.org",
+    "domain2.org",
+    "domain3.org",
+    "domain3.org",
+    "test.com",
+  }
+
+  local cases = {
+    {expect = url_list, filter = nil, limit = 9999, need_emails = true, prefix = 'p'},
+    {expect = {}, filter = (function() return false end), limit = 9999, need_emails = true, prefix = 'p'},
+    {expect = {"domain4.co.net", "test.com"}, filter = nil, limit = 2, need_emails = true, prefix = 'p'},
+    {
+      expect = {"gov.co.net", "tesco.co.net", "domain1.co.net", "domain2.co.net", "domain3.co.net", "domain4.co.net"},
+      filter = (function(s) return s:get_host():sub(-4) == ".net" end),
+      limit = 9999,
+      need_emails = true,
+      prefix = 'p'
+    },
+    {
+      input  = {"a.google.com", "b.google.com", "c.google.com", "a.net", "bb.net", "a.bb.net", "b.bb.net"},
+      expect = {"a.bb.net", "b.google.com", "a.net", "bb.net", "a.google.com"},
+      filter = nil,
+      limit = 9999,
+      esld_limit = 2,
+      need_emails = true,
+      prefix = 'p'
+    }
+  }
+
+  local pool = mpool.create()
+
+  for i,c in ipairs(cases) do
+
+    local function prepare_url_list(c)
+      return fun.totable(fun.map(
+        function (u) return url.create(pool, u) end,
+        c.input or url_list
+      ))
+    end
+
+    test("extract_specific_urls, backward compatibility case #" .. i, function()
+      task_object.urls = prepare_url_list(c)
+      if (c.esld_limit) then
+        -- not awailable in deprecated version
+        return
+      end
+      local actual = util.extract_specific_urls(task_object, c.limit, c.need_emails, c.filter, c.prefix)
+
+      local actual_result = fun.totable(fun.map(
+        function(u) return u:get_host() end,
+        actual
+      ))
+
+      --[[
+        local s = logger.slog("%1 =?= %2", c.expect, actual_result)
+        print(s) --]]
+
+      assert_equal(true, util.table_cmp(c.expect, actual_result), "checking that we got the same tables")
+
+    end)
+
+    test("extract_specific_urls " .. i, function()
+      task_object.urls = prepare_url_list(c)
+
+      local actual = util.extract_specific_urls({
+        task = task_object,
+        limit = c.limit,
+        esld_limit = c.esld_limit,
+        need_emails = c.need_emails,
+        filter = c.filter,
+        prefix = c.prefix,
+      })
+
+      local actual_result = fun.totable(fun.map(
+        function(u) return u:get_host() end,
+        actual
+      ))
+
+      --[[
+        local s = logger.slog("case[%1] %2 =?= %3", i, c.expect, actual_result)
+        print(s) --]]
+
+      assert_equal(true, util.table_cmp(c.expect, actual_result), "checking that we got the same tables")
+
+    end)
+  end
+end)
\ No newline at end of file