aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/lua_util.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/lua_util.lua')
-rw-r--r--lualib/lua_util.lua122
1 files changed, 122 insertions, 0 deletions
diff --git a/lualib/lua_util.lua b/lualib/lua_util.lua
index 5af09d316..e72d7d319 100644
--- a/lualib/lua_util.lua
+++ b/lualib/lua_util.lua
@@ -475,4 +475,126 @@ end
exports.override_defaults = override_defaults
+--[[[
+-- @function lua_util.extract_specific_urls(task, limit)
+-- Apply heuristic in extracting of urls from task, this function
+-- tries its best to extract specific number of urls from a task based on
+-- their characteristics
+--]]
+exports.extract_specific_urls = function(task, lim, need_emails)
+ local fun = require "fun"
+ local cache_key = string.format('sp_urls_%d%s', lim, need_emails)
+
+ local cached = task:cache_get(cache_key)
+
+ if cached then
+ return cached
+ end
+
+ local urls = task:get_urls(need_emails)
+
+
+ if not urls then return {} end
+
+ if #urls <= lim then
+ task:cache_set(cache_key, urls)
+
+ return urls
+ end
+
+ -- Filter by tld:
+ local tlds = {}
+ local eslds = {}
+ local ntlds, neslds = 0, 0
+
+ local res = {}
+
+ for _,u in ipairs(urls) do
+ local esld = u:get_tld()
+
+ if esld then
+ if not eslds[esld] then
+ eslds[esld] = {u}
+ neslds = neslds + 1
+ else
+ table.insert(eslds[esld], u)
+ end
+
+ local parts = rspamd_str_split(esld, '.')
+ local tld = table.concat(fun.totable(fun.tail(parts)), '.')
+
+ if not tlds[tld] then
+ tlds[tld] = {u}
+ ntlds = ntlds + 1
+ else
+ table.insert(tlds[tld], u)
+ end
+
+ -- Extract priority urls that are proven to be malicious
+ if not u:is_html_displayed() then
+ if u:is_obscured() then
+ table.insert(res, u)
+ else
+ if u:get_user() then
+ table.insert(res, u)
+ elseif u:is_subject() then
+ table.insert(res, u)
+ end
+ end
+ end
+ end
+ end
+
+ lim = lim - #res
+ if lim <= 0 then lim = 1 end
+
+ if neslds <= lim then
+ -- We can get urls based on their eslds
+ while lim > 0 do
+ for _,urls in pairs(eslds) do
+ table.insert(res, table.remove(urls))
+ lim = lim - 1
+ end
+ end
+
+ task:cache_set(cache_key, urls)
+ return res
+ end
+
+ if ntlds <= lim then
+ while lim > 0 do
+ for _,urls in pairs(tlds) do
+ table.insert(res, table.remove(urls))
+ lim = lim - 1
+ end
+ end
+
+ task:cache_set(cache_key, urls)
+ return res
+ end
+
+ -- We need to sort tlds table first
+ local tlds_keys = {}
+ for k,_ in pairs(tlds) do table.insert(tlds_keys, k) end
+ table.sort(tlds_keys, function (t1, t2)
+ return #tlds[t1] < #tlds[t2]
+ end)
+
+ local ntlds = #tlds_keys
+ for i=1,ntlds / 2 do
+ local tld1 = tlds[tlds_keys[i]]
+ local tld2 = tlds[tlds_keys[ntlds - i]]
+ table.insert(res, table.remove(tld1))
+ table.insert(res, table.remove(tld2))
+ lim = lim - 2
+
+ if lim <= 0 then
+ break
+ end
+ end
+
+ task:cache_set(cache_key, urls)
+ return res
+end
+
return exports