--[[ Copyright (c) 2018, Vsevolod Stakhov Copyright (c) 2019, Carsten Rosenberg Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ]]-- --[[[ -- @module lua_scanners_common -- This module contains common external scanners functions --]] local rspamd_logger = require "rspamd_logger" local rspamd_regexp = require "rspamd_regexp" local lua_util = require "lua_util" local lua_redis = require "lua_redis" local fun = require "fun" local exports = {} local function log_clean(task, rule, msg) msg = msg or 'message or mime_part is clean' if rule.log_clean then rspamd_logger.infox(task, '%s: %s', rule.log_prefix, msg) else lua_util.debugm(rule.module_name, task, '%s: %s', rule.log_prefix, msg) end end local function match_patterns(default_sym, found, patterns, dyn_weight) if type(patterns) ~= 'table' then return default_sym, dyn_weight end if not patterns[1] then for sym, pat in pairs(patterns) do if pat:match(found) then return sym, '1' end end return default_sym, dyn_weight else for _, p in ipairs(patterns) do for sym, pat in pairs(p) do if pat:match(found) then return sym, '1' end end end return default_sym, dyn_weight end end local function yield_result(task, rule, vname, N, dyn_weight) local all_whitelisted = true if not dyn_weight then dyn_weight = 1.0 end if type(vname) == 'string' then local symname, symscore = match_patterns(rule.symbol, vname, rule.patterns, dyn_weight) if rule.whitelist and rule.whitelist:get_key(vname) then rspamd_logger.infox(task, '%s: "%s" is in whitelist', rule.log_prefix, vname) return end task:insert_result(symname, symscore, vname) rspamd_logger.infox(task, '%s: %s found: "%s"', rule.log_prefix, rule.detection_category, vname) elseif type(vname) == 'table' then for _, vn in ipairs(vname) do local symname, symscore = match_patterns(rule.symbol, vn, rule.patterns, dyn_weight) if rule.whitelist and rule.whitelist:get_key(vn) then rspamd_logger.infox(task, '%s: "%s" is in whitelist', rule.log_prefix, vn) else all_whitelisted = false task:insert_result(symname, symscore, vn) rspamd_logger.infox(task, '%s: %s found: "%s"', rule.log_prefix, rule.detection_category, vn) end end end if rule.action then if type(vname) == 'table' then if all_whitelisted then return end vname = table.concat(vname, '; ') end task:set_pre_result(rule.action, lua_util.template(rule.message or 'Rejected', { SCANNER = rule.name, VIRUS = vname, }), N) end end local function message_not_too_large(task, content, rule) local max_size = tonumber(rule.max_size) if not max_size then return true end if #content > max_size then rspamd_logger.infox(task, "skip %s check as it is too large: %s (%s is allowed)", rule.log_prefix, #content, max_size) return false end return true end local function need_av_check(task, content, rule) return message_not_too_large(task, content, rule) end local function check_av_cache(task, digest, rule, fn) local key = digest local function redis_av_cb(err, data) if data and type(data) == 'string' then -- Cached data = rspamd_str_split(data, '\t') local threat_string = rspamd_str_split(data[1], '\v') local score = data[2] or rule.default_score if threat_string[1] ~= 'OK' then lua_util.debugm(rule.module_name, task, '%s: got cached threat result for %s: %s', rule.log_prefix, key, threat_string[1]) yield_result(task, rule, threat_string, score) else lua_util.debugm(rule.module_name, task, '%s: got cached negative result for %s: %s', rule.log_prefix, key, threat_string[1]) end else if err then rspamd_logger.errx(task, 'got error checking cache: %s', err) end fn() end end if rule.redis_params then key = rule.prefix .. key if lua_redis.redis_make_request(task, rule.redis_params, -- connect params key, -- hash key false, -- is write redis_av_cb, --callback 'GET', -- command {key} -- arguments) ) then return true end end return false end local function save_av_cache(task, digest, rule, to_save, dyn_weight) local key = digest local function redis_set_cb(err) -- Do nothing if err then rspamd_logger.errx(task, 'failed to save %s cache for %s -> "%s": %s', rule.detection_category, to_save, key, err) else lua_util.debugm(rule.module_name, task, '%s: saved cached result for %s: %s', rule.log_prefix, key, to_save) end end if type(to_save) == 'table' then to_save = table.concat(to_save, '\v') end local value = table.concat({to_save, dyn_weight}, '\t') if rule.redis_params and rule.prefix then key = rule.prefix .. key lua_redis.redis_make_request(task, rule.redis_params, -- connect params key, -- hash key true, -- is write redis_set_cb, --callback 'SETEX', -- command { key, rule.cache_expire or 0, value } ) end return false end local function create_regex_table(patterns) local regex_table = {} if patterns[1] then for i, p in ipairs(patterns) do if type(p) == 'table' then local new_set = {} for k, v in pairs(p) do new_set[k] = rspamd_regexp.create_cached(v) end regex_table[i] = new_set else regex_table[i] = {} end end else for k, v in pairs(patterns) do regex_table[k] = rspamd_regexp.create_cached(v) end end return regex_table end local function match_filter(task, found, patterns) if type(patterns) ~= 'table' then return false end if not patterns[1] then for _, pat in pairs(patterns) do if pat:match(found) then return true end end return false else for _, p in ipairs(patterns) do for _, pat in ipairs(p) do if pat:match(found) then return true end end end return false end end -- borrowed from mime_types.lua -- ext is the last extension, LOWERCASED -- ext2 is the one before last extension LOWERCASED local function gen_extension(fname) local filename_parts = rspamd_str_split(fname, '.') local ext = {} for n = 1, 2 do ext[n] = #filename_parts > n and string.lower(filename_parts[#filename_parts + 1 - n]) or nil end return ext[1],ext[2],filename_parts end local function check_parts_match(task, rule) local filter_func = function(p) local content_type,content_subtype = p:get_type() local fname = p:get_filename() local ext, ext2, part_table local extension_check = false local content_type_check = false local text_part_min_words_check = true if rule.scan_all_mime_parts == false then -- check file extension and filename regex matching if fname ~= nil then ext,ext2,part_table = gen_extension(fname) lua_util.debugm(rule.module_name, task, '%s: extension found: %s - 2.ext: %s - parts: %s', rule.log_prefix, ext, ext2, part_table) if match_filter(task, ext, rule.mime_parts_filter_ext) or match_filter(task, ext2, rule.mime_parts_filter_ext) then lua_util.debugm(rule.module_name, task, '%s: extension matched: %s', rule.log_prefix, ext) extension_check = true end if match_filter(task, fname, rule.mime_parts_filter_regex) then content_type_check = true end end -- check content type regex matching if content_type ~= nil and content_subtype ~= nil then if match_filter(task, content_type..'/'..content_subtype, rule.mime_parts_filter_regex) then lua_util.debugm(rule.module_name, task, '%s: regex ct: %s', rule.log_prefix, content_type..'/'..content_subtype) content_type_check = true end end end -- check text_part has more words than text_part_min_words_check if rule.text_part_min_words and p:is_text() then text_part_min_words_check = p:get_words_count() >= tonumber(rule.text_part_min_words) end return (rule.scan_image_mime and p:is_image()) or (rule.scan_text_mime and text_part_min_words_check) or (p:is_attachment() and rule.scan_all_mime_parts ~= false) or extension_check or content_type_check end return fun.filter(filter_func, task:get_parts()) end exports.log_clean = log_clean exports.yield_result = yield_result exports.match_patterns = match_patterns exports.need_av_check = need_av_check exports.check_av_cache = check_av_cache exports.save_av_cache = save_av_cache exports.create_regex_table = create_regex_table exports.check_parts_match = check_parts_match setmetatable(exports, { __call = function(t, override) for k, v in pairs(t) do if _G[k] ~= nil then local msg = 'function ' .. k .. ' already exists in global scope.' if override then _G[k] = v print('WARNING: ' .. msg .. ' Overwritten.') else print('NOTICE: ' .. msg .. ' Skipped.') end else _G[k] = v end end end, }) return exports