aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins')
-rw-r--r--src/plugins/lua/elastic.lua172
-rw-r--r--src/plugins/lua/gpt.lua244
-rw-r--r--src/plugins/lua/rbl.lua6
3 files changed, 229 insertions, 193 deletions
diff --git a/src/plugins/lua/elastic.lua b/src/plugins/lua/elastic.lua
index 8bed9fcf4..b26fbd8e8 100644
--- a/src/plugins/lua/elastic.lua
+++ b/src/plugins/lua/elastic.lua
@@ -19,6 +19,7 @@ local rspamd_logger = require 'rspamd_logger'
local rspamd_http = require "rspamd_http"
local lua_util = require "lua_util"
local rspamd_util = require "rspamd_util"
+local rspamd_text = require "rspamd_text"
local ucl = require "ucl"
local upstream_list = require "rspamd_upstream_list"
@@ -287,7 +288,7 @@ end
local function handle_error(action, component, limit)
if states[component]['errors'] >= limit then
rspamd_logger.errx(rspamd_config, 'cannot %s elastic %s, failed attempts: %s/%s, stop trying',
- action, component:gsub('_', ' '), states[component]['errors'], limit)
+ action, component:gsub('_', ' '), states[component]['errors'], limit)
states[component]['configured'] = true
else
states[component]['errors'] = states[component]['errors'] + 1
@@ -315,54 +316,6 @@ local function get_received_delay(received_headers)
return delay
end
-local function is_empty(str)
- -- define a pattern that includes invisible unicode characters
- local str_cleared = str:gsub('[' ..
- '\xC2\xA0' .. -- U+00A0 non-breaking space
- '\xE2\x80\x8B' .. -- U+200B zero width space
- '\xEF\xBB\xBF' .. -- U+FEFF byte order mark (zero width no-break space)
- '\xE2\x80\x8C' .. -- U+200C zero width non-joiner
- '\xE2\x80\x8D' .. -- U+200D zero width joiner
- '\xE2\x80\x8E' .. -- U+200E left-to-right mark
- '\xE2\x80\x8F' .. -- U+200F right-to-left mark
- '\xE2\x81\xA0' .. -- U+2060 word joiner
- '\xE2\x80\xAA' .. -- U+202A left-to-right embedding
- '\xE2\x80\xAB' .. -- U+202B right-to-left embedding
- '\xE2\x80\xAC' .. -- U+202C pop directional formatting
- '\xE2\x80\xAD' .. -- U+202D left-to-right override
- '\xE2\x80\xAE' .. -- U+202E right-to-left override
- '\xE2\x81\x9F' .. -- U+2061 function application
- '\xE2\x81\xA1' .. -- U+2061 invisible separator
- '\xE2\x81\xA2' .. -- U+2062 invisible times
- '\xE2\x81\xA3' .. -- U+2063 invisible separator
- '\xE2\x81\xA4' .. -- U+2064 invisible plus
- ']', '') -- gsub replaces all matched characters with an empty string
- if str_cleared:match('[%S]') then
- return false
- else
- return true
- end
-end
-
-local function fill_empty_strings(tbl, empty_value)
- local filled_tbl = {}
- for key, value in pairs(tbl) do
- if value and type(value) == 'table' then
- local nested_filtered = fill_empty_strings(value, empty_value)
- if next(nested_filtered) ~= nil then
- filled_tbl[key] = nested_filtered
- end
- elseif type(value) == 'boolean' then
- filled_tbl[key] = value
- elseif value and type(value) == 'string' and is_empty(value) then
- filled_tbl[key] = empty_value
- elseif value then
- filled_tbl[key] = value
- end
- end
- return filled_tbl
-end
-
local function create_bulk_json(es_index, logs_to_send)
local tbl = {}
for _, row in pairs(logs_to_send) do
@@ -407,15 +360,17 @@ local function elastic_send_data(flush_all, task, cfg, ev_base)
local function http_callback(err, code, body, _)
local push_done = false
if err then
- rspamd_logger.errx(log_object, 'cannot send logs to elastic (%s): %s; failed attempts: %s/%s',
- push_url, err, buffer['errors'], settings['limits']['max_fail'])
+ rspamd_logger.errx(log_object,
+ 'cannot send logs to elastic (%s): %s; failed attempts: %s/%s',
+ push_url, err, buffer['errors'], settings['limits']['max_fail'])
elseif code == 200 then
local parser = ucl.parser()
local res, ucl_err = parser:parse_string(body)
if not ucl_err and res then
local obj = parser:get_object()
push_done = true
- lua_util.debugm(N, log_object, 'successfully sent payload with %s logs', nlogs_to_send)
+ lua_util.debugm(N, log_object,
+ 'successfully sent payload with %s logs', nlogs_to_send)
if obj['errors'] then
for _, value in pairs(obj['items']) do
if value['index'] and value['index']['status'] >= 400 then
@@ -424,15 +379,15 @@ local function elastic_send_data(flush_all, task, cfg, ev_base)
local error_type = safe_get(value, 'index', 'error', 'type') or ''
local error_reason = safe_get(value, 'index', 'error', 'reason') or ''
rspamd_logger.warnx(log_object,
- 'error while pushing logs to elastic, status: %s, index: %s, type: %s, reason: %s',
- status, index, error_type, error_reason)
+ 'error while pushing logs to elastic, status: %s, index: %s, type: %s, reason: %s',
+ status, index, error_type, error_reason)
end
end
end
else
rspamd_logger.errx(log_object,
- 'cannot parse response from elastic (%s): %s; failed attempts: %s/%s',
- push_url, ucl_err, buffer['errors'], settings['limits']['max_fail'])
+ 'cannot parse response from elastic (%s): %s; failed attempts: %s/%s',
+ push_url, ucl_err, buffer['errors'], settings['limits']['max_fail'])
end
else
rspamd_logger.errx(log_object,
@@ -448,8 +403,8 @@ local function elastic_send_data(flush_all, task, cfg, ev_base)
upstream:fail()
if buffer['errors'] >= settings['limits']['max_fail'] then
rspamd_logger.errx(log_object,
- 'failed to send %s log lines, failed attempts: %s/%s, removing failed logs from bugger',
- nlogs_to_send, buffer['errors'], settings['limits']['max_fail'])
+ 'failed to send %s log lines, failed attempts: %s/%s, removing failed logs from bugger',
+ nlogs_to_send, buffer['errors'], settings['limits']['max_fail'])
buffer['logs']:pop_first(nlogs_to_send)
buffer['errors'] = 0
else
@@ -494,6 +449,7 @@ local function get_general_metadata(task)
local empty = settings['index_template']['empty_value']
local user = task:get_user()
r.rspamd_server = rspamd_hostname or empty
+ r.digest = task:get_digest() or empty
r.action = task:get_metric_action() or empty
r.score = task:get_metric_score()[1] or 0
@@ -565,8 +521,8 @@ local function get_general_metadata(task)
if task:has_from('smtp') then
local from = task:get_from({ 'smtp', 'orig' })[1]
if from and
- from['user'] and #from['user'] > 0 and
- from['domain'] and #from['domain'] > 0
+ from['user'] and #from['user'] > 0 and
+ from['domain'] and #from['domain'] > 0
then
r.from_user = from['user']
r.from_domain = from['domain']:lower()
@@ -578,8 +534,8 @@ local function get_general_metadata(task)
if task:has_from('mime') then
local mime_from = task:get_from({ 'mime', 'orig' })[1]
if mime_from and
- mime_from['user'] and #mime_from['user'] > 0 and
- mime_from['domain'] and #mime_from['domain'] > 0
+ mime_from['user'] and #mime_from['user'] > 0 and
+ mime_from['domain'] and #mime_from['domain'] > 0
then
r.mime_from_user = mime_from['user']
r.mime_from_domain = mime_from['domain']:lower()
@@ -608,25 +564,34 @@ local function get_general_metadata(task)
local function process_header(name)
local hdr = task:get_header_full(name)
- local headers_text_ignore_above = settings['index_template']['headers_text_ignore_above'] - 3
if hdr and #hdr > 0 then
local l = {}
for _, h in ipairs(hdr) do
- if settings['index_template']['headers_count_ignore_above'] ~= 0 and
- #l >= settings['index_template']['headers_count_ignore_above']
+ if settings['index_template']['headers_count_ignore_above'] > 0 and
+ #l >= settings['index_template']['headers_count_ignore_above']
then
table.insert(l, 'ignored above...')
break
end
local header
- if settings['index_template']['headers_text_ignore_above'] ~= 0 and
- h.decoded and #h.decoded >= headers_text_ignore_above
- then
- header = h.decoded:sub(1, headers_text_ignore_above) .. '...'
- elseif h.decoded and #h.decoded > 0 then
- header = h.decoded
+ local header_len
+ if h.decoded then
+ header = rspamd_text.fromstring(h.decoded)
+ header_len = header:len_utf8()
else
- header = empty
+ table.insert(l, empty)
+ break
+ end
+ if not header_len or header_len == 0 then
+ table.insert(l, empty)
+ break
+ end
+ if settings['index_template']['headers_text_ignore_above'] > 0 and
+ header_len >= settings['index_template']['headers_text_ignore_above']
+ then
+ header = header:sub_utf8(1, settings['index_template']['headers_text_ignore_above'])
+ table.insert(l, header .. rspamd_text.fromstring('...'))
+ break
end
table.insert(l, header)
end
@@ -686,7 +651,7 @@ local function get_general_metadata(task)
r.received_delay = get_received_delay(task:get_received_headers())
- return fill_empty_strings(r, empty)
+ return r
end
local function elastic_collect(task)
@@ -773,8 +738,8 @@ local function configure_geoip_pipeline(cfg, ev_base)
upstream:ok()
else
rspamd_logger.errx(rspamd_config,
- 'cannot configure elastic geoip pipeline (%s), status code: %s, response: %s',
- geoip_url, code, body)
+ 'cannot configure elastic geoip pipeline (%s), status code: %s, response: %s',
+ geoip_url, code, body)
upstream:fail()
handle_error('configure', 'geoip_pipeline', settings['limits']['max_fail'])
end
@@ -810,8 +775,9 @@ local function put_index_policy(cfg, ev_base, upstream, host, policy_url, index_
states['index_policy']['configured'] = true
upstream:ok()
else
- rspamd_logger.errx(rspamd_config, 'cannot configure elastic index policy (%s), status code: %s, response: %s',
- policy_url, code, body)
+ rspamd_logger.errx(rspamd_config,
+ 'cannot configure elastic index policy (%s), status code: %s, response: %s',
+ policy_url, code, body)
upstream:fail()
handle_error('configure', 'index_policy', settings['limits']['max_fail'])
end
@@ -867,7 +833,7 @@ local function get_index_policy(cfg, ev_base, upstream, host, policy_url, index_
if not lua_util.table_cmp(our_policy['policy']['default_state'], current_default_state) then
update_needed = true
elseif not lua_util.table_cmp(our_policy['policy']['ism_template'][1]['index_patterns'],
- current_ism_index_patterns) then
+ current_ism_index_patterns) then
update_needed = true
elseif not lua_util.table_cmp(our_policy['policy']['states'], current_states) then
update_needed = true
@@ -890,8 +856,8 @@ local function get_index_policy(cfg, ev_base, upstream, host, policy_url, index_
put_index_policy(cfg, ev_base, upstream, host, policy_url, index_policy_json)
else
rspamd_logger.errx(rspamd_config,
- 'current elastic index policy (%s) not returned correct seq_no/primary_term, policy will not be updated, response: %s',
- policy_url, body)
+ 'current elastic index policy (%s) not returned correct seq_no/primary_term, policy will not be updated, response: %s',
+ policy_url, body)
upstream:fail()
handle_error('validate current', 'index_policy', settings['limits']['max_fail'])
end
@@ -909,8 +875,8 @@ local function get_index_policy(cfg, ev_base, upstream, host, policy_url, index_
end
else
rspamd_logger.errx(rspamd_config,
- 'cannot get current elastic index policy (%s), status code: %s, response: %s',
- policy_url, code, body)
+ 'cannot get current elastic index policy (%s), status code: %s, response: %s',
+ policy_url, code, body)
handle_error('get current', 'index_policy', settings['limits']['max_fail'])
upstream:fail()
end
@@ -1037,7 +1003,7 @@ local function configure_index_policy(cfg, ev_base)
}
index_policy['policy']['phases']['delete'] = delete_obj
end
- -- opensearch state policy with hot state
+ -- opensearch state policy with hot state
elseif detected_distro['name'] == 'opensearch' then
local retry = {
count = 3,
@@ -1313,6 +1279,7 @@ local function configure_index_template(cfg, ev_base)
type = 'object',
properties = {
rspamd_server = t_keyword,
+ digest = t_keyword,
action = t_keyword,
score = t_double,
symbols = symbols_obj,
@@ -1381,7 +1348,7 @@ local function configure_index_template(cfg, ev_base)
upstream:ok()
else
rspamd_logger.errx(rspamd_config, 'cannot configure elastic index template (%s), status code: %s, response: %s',
- template_url, code, body)
+ template_url, code, body)
upstream:fail()
handle_error('configure', 'index_template', settings['limits']['max_fail'])
end
@@ -1424,8 +1391,9 @@ local function verify_distro(manual)
local supported_distro_info = supported_distro[detected_distro_name]
-- check that detected_distro_version is valid
if not detected_distro_version or type(detected_distro_version) ~= 'string' then
- rspamd_logger.errx(rspamd_config, 'elastic version should be a string, but we received: %s',
- type(detected_distro_version))
+ rspamd_logger.errx(rspamd_config,
+ 'elastic version should be a string, but we received: %s',
+ type(detected_distro_version))
valid = false
elseif detected_distro_version == '' then
rspamd_logger.errx(rspamd_config, 'unsupported elastic version: empty string')
@@ -1434,21 +1402,22 @@ local function verify_distro(manual)
-- compare versions using compare_versions
local cmp_from = compare_versions(detected_distro_version, supported_distro_info['from'])
if cmp_from == -1 then
- rspamd_logger.errx(rspamd_config, 'unsupported elastic version: %s, minimal supported version of %s is %s',
- detected_distro_version, detected_distro_name, supported_distro_info['from'])
+ rspamd_logger.errx(rspamd_config,
+ 'unsupported elastic version: %s, minimal supported version of %s is %s',
+ detected_distro_version, detected_distro_name, supported_distro_info['from'])
valid = false
else
local cmp_till = compare_versions(detected_distro_version, supported_distro_info['till'])
if (cmp_till >= 0) and not supported_distro_info['till_unknown'] then
rspamd_logger.errx(rspamd_config,
- 'unsupported elastic version: %s, maximum supported version of %s is less than %s',
- detected_distro_version, detected_distro_name, supported_distro_info['till'])
+ 'unsupported elastic version: %s, maximum supported version of %s is less than %s',
+ detected_distro_version, detected_distro_name, supported_distro_info['till'])
valid = false
elseif (cmp_till >= 0) and supported_distro_info['till_unknown'] then
rspamd_logger.warnx(rspamd_config,
- 'compatibility of elastic version: %s is unknown, maximum known supported version of %s is less than %s,' ..
- 'use at your own risk',
- detected_distro_version, detected_distro_name, supported_distro_info['till'])
+ 'compatibility of elastic version: %s is unknown, maximum known ' ..
+ 'supported version of %s is less than %s, use at your own risk',
+ detected_distro_version, detected_distro_name, supported_distro_info['till'])
valid_unknown = true
end
end
@@ -1460,11 +1429,12 @@ local function verify_distro(manual)
else
if valid and manual then
rspamd_logger.infox(
- rspamd_config, 'assuming elastic distro: %s, version: %s', detected_distro_name, detected_distro_version)
+ rspamd_config, 'assuming elastic distro: %s, version: %s', detected_distro_name, detected_distro_version)
detected_distro['supported'] = true
elseif valid and not manual then
- rspamd_logger.infox(rspamd_config, 'successfully connected to elastic distro: %s, version: %s',
- detected_distro_name, detected_distro_version)
+ rspamd_logger.infox(rspamd_config,
+ 'successfully connected to elastic distro: %s, version: %s',
+ detected_distro_name, detected_distro_version)
detected_distro['supported'] = true
else
handle_error('configure', 'distro', settings['version']['autodetect_max_fail'])
@@ -1477,7 +1447,7 @@ local function configure_distro(cfg, ev_base)
detected_distro['name'] = settings['version']['override']['name']
detected_distro['version'] = settings['version']['override']['version']
rspamd_logger.infox(rspamd_config,
- 'automatic detection of elastic distro and version is disabled, taking configuration from settings')
+ 'automatic detection of elastic distro and version is disabled, taking configuration from settings')
verify_distro(true)
end
@@ -1490,14 +1460,16 @@ local function configure_distro(cfg, ev_base)
rspamd_logger.errx(rspamd_config, 'cannot connect to elastic (%s): %s', root_url, err)
upstream:fail()
elseif code ~= 200 then
- rspamd_logger.errx(rspamd_config, 'cannot connect to elastic (%s), status code: %s, response: %s', root_url, code,
- body)
+ rspamd_logger.errx(rspamd_config,
+ 'cannot connect to elastic (%s), status code: %s, response: %s',
+ root_url, code, body)
upstream:fail()
else
local parser = ucl.parser()
local res, ucl_err = parser:parse_string(body)
if not res then
- rspamd_logger.errx(rspamd_config, 'failed to parse reply from elastic (%s): %s', root_url, ucl_err)
+ rspamd_logger.errx(rspamd_config, 'failed to parse reply from elastic (%s): %s',
+ root_url, ucl_err)
upstream:fail()
else
local obj = parser:get_object()
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index feccae73f..4888eaa19 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -48,6 +48,8 @@ gpt {
allow_passthrough = false;
# Check messages that are apparent ham (no action and negative score)
allow_ham = false;
+ # default send response_format field { type = "json_object" }
+ include_response_format = true,
}
]])
return
@@ -317,6 +319,47 @@ local function ollama_conversion(task, input)
return
end
+local function check_consensus(task, results)
+ for _, result in ipairs(results) do
+ if not result.checked then
+ return
+ end
+ end
+
+ local nspam, nham = 0, 0
+ local max_spam_prob, max_ham_prob = 0, 0
+
+ for _, result in ipairs(results) do
+ if result.success then
+ if result.probability > 0.5 then
+ nspam = nspam + 1
+ max_spam_prob = math.max(max_spam_prob, result.probability)
+ lua_util.debugm(N, task, "model: %s; spam: %s", result.model, result.probability)
+ else
+ nham = nham + 1
+ max_ham_prob = math.min(max_ham_prob, result.probability)
+ lua_util.debugm(N, task, "model: %s; ham: %s", result.model, result.probability)
+ end
+ end
+ end
+
+ if nspam > nham and max_spam_prob > 0.75 then
+ task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
+ if settings.autolearn then
+ task:set_flag("learn_spam")
+ end
+ elseif nham > nspam and max_ham_prob < 0.25 then
+ task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
+ if settings.autolearn then
+ task:set_flag("learn_ham")
+ end
+ else
+ -- No consensus
+ lua_util.debugm(N, task, "no consensus")
+ end
+
+end
+
local function get_meta_llm_content(task)
local url_content = "Url domains: no urls found"
if task:has_urls() then
@@ -351,40 +394,36 @@ local function default_llm_check(task)
local upstream
- local function on_reply(err, code, body)
+ local results = {}
- if err then
- rspamd_logger.errx(task, 'request failed: %s', err)
- upstream:fail()
- return
- end
+ local function gen_reply_closure(model, idx)
+ return function(err, code, body)
+ results[idx].checked = true
+ if err then
+ rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+ upstream:fail()
+ check_consensus(task, results)
+ return
+ end
- upstream:ok()
- lua_util.debugm(N, task, "got reply: %s", body)
- if code ~= 200 then
- rspamd_logger.errx(task, 'bad reply: %s', body)
- return
- end
+ upstream:ok()
+ lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'bad reply: %s', body)
+ return
+ end
- local reply = settings.reply_conversion(task, body)
- if not reply then
- return
- end
+ local reply = settings.reply_conversion(task, body)
- if reply > 0.75 then
- task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
- if settings.autolearn then
- task:set_flag("learn_spam")
- end
- elseif reply < 0.25 then
- task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
- if settings.autolearn then
- task:set_flag("learn_ham")
+ results[idx].model = model
+
+ if reply then
+ results[idx].success = true
+ results[idx].probability = reply
end
- else
- lua_util.debugm(N, task, "uncertain result: %s", reply)
- end
+ check_consensus(task, results)
+ end
end
local from_content, url_content = get_meta_llm_content(task)
@@ -393,7 +432,6 @@ local function default_llm_check(task)
model = settings.model,
max_tokens = settings.max_tokens,
temperature = settings.temperature,
- response_format = { type = "json_object" },
messages = {
{
role = 'system',
@@ -418,24 +456,43 @@ local function default_llm_check(task)
}
}
+ -- Conditionally add response_format
+ if settings.include_response_format then
+ body.response_format = { type = "json_object" }
+ end
+
+ if type(settings.model) == 'string' then
+ settings.model = { settings.model }
+ end
+
upstream = settings.upstreams:get_upstream_round_robin()
- local http_params = {
- url = settings.url,
- mime_type = 'application/json',
- timeout = settings.timeout,
- log_obj = task,
- callback = on_reply,
- headers = {
- ['Authorization'] = 'Bearer ' .. settings.api_key,
- },
- keepalive = true,
- body = ucl.to_format(body, 'json-compact', true),
- task = task,
- upstream = upstream,
- use_gzip = true,
- }
+ for idx, model in ipairs(settings.model) do
+ results[idx] = {
+ success = false,
+ checked = false
+ }
+ body.model = model
+ local http_params = {
+ url = settings.url,
+ mime_type = 'application/json',
+ timeout = settings.timeout,
+ log_obj = task,
+ callback = gen_reply_closure(model, idx),
+ headers = {
+ ['Authorization'] = 'Bearer ' .. settings.api_key,
+ },
+ keepalive = true,
+ body = ucl.to_format(body, 'json-compact', true),
+ task = task,
+ upstream = upstream,
+ use_gzip = true,
+ }
+
+ if not rspamd_http.request(http_params) then
+ results[idx].checked = true
+ end
- rspamd_http.request(http_params)
+ end
end
local function ollama_check(task)
@@ -454,51 +511,49 @@ local function ollama_check(task)
lua_util.debugm(N, task, "sending content to gpt: %s", content)
local upstream
+ local results = {}
+
+ local function gen_reply_closure(model, idx)
+ return function(err, code, body)
+ results[idx].checked = true
+ if err then
+ rspamd_logger.errx(task, '%s: request failed: %s', model, err)
+ upstream:fail()
+ check_consensus(task, results)
+ return
+ end
- local function on_reply(err, code, body)
-
- if err then
- rspamd_logger.errx(task, 'request failed: %s', err)
- upstream:fail()
- return
- end
+ upstream:ok()
+ lua_util.debugm(N, task, "%s: got reply: %s", model, body)
+ if code ~= 200 then
+ rspamd_logger.errx(task, 'bad reply: %s', body)
+ return
+ end
- upstream:ok()
- lua_util.debugm(N, task, "got reply: %s", body)
- if code ~= 200 then
- rspamd_logger.errx(task, 'bad reply: %s', body)
- return
- end
+ local reply = settings.reply_conversion(task, body)
- local reply = settings.reply_conversion(task, body)
- if not reply then
- return
- end
+ results[idx].model = model
- if reply > 0.75 then
- task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply))
- if settings.autolearn then
- task:set_flag("learn_spam")
- end
- elseif reply < 0.25 then
- task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply))
- if settings.autolearn then
- task:set_flag("learn_ham")
+ if reply then
+ results[idx].success = true
+ results[idx].probability = reply
end
- else
- lua_util.debugm(N, task, "uncertain result: %s", reply)
- end
+ check_consensus(task, results)
+ end
end
local from_content, url_content = get_meta_llm_content(task)
+ if type(settings.model) == 'string' then
+ settings.model = { settings.model }
+ end
+
local body = {
stream = false,
model = settings.model,
max_tokens = settings.max_tokens,
temperature = settings.temperature,
- response_format = { type = "json_object" },
messages = {
{
role = 'system',
@@ -523,21 +578,30 @@ local function ollama_check(task)
}
}
- upstream = settings.upstreams:get_upstream_round_robin()
- local http_params = {
- url = settings.url,
- mime_type = 'application/json',
- timeout = settings.timeout,
- log_obj = task,
- callback = on_reply,
- keepalive = true,
- body = ucl.to_format(body, 'json-compact', true),
- task = task,
- upstream = upstream,
- use_gzip = true,
- }
+ for i, model in ipairs(settings.model) do
+ -- Conditionally add response_format
+ if settings.include_response_format then
+ body.response_format = { type = "json_object" }
+ end
- rspamd_http.request(http_params)
+ body.model = model
+
+ upstream = settings.upstreams:get_upstream_round_robin()
+ local http_params = {
+ url = settings.url,
+ mime_type = 'application/json',
+ timeout = settings.timeout,
+ log_obj = task,
+ callback = gen_reply_closure(model, i),
+ keepalive = true,
+ body = ucl.to_format(body, 'json-compact', true),
+ task = task,
+ upstream = upstream,
+ use_gzip = true,
+ }
+
+ rspamd_http.request(http_params)
+ end
end
local function gpt_check(task)
@@ -618,4 +682,4 @@ if opts then
parent = id,
score = -2.0,
})
-end \ No newline at end of file
+end
diff --git a/src/plugins/lua/rbl.lua b/src/plugins/lua/rbl.lua
index 76c84f85d..bcdef5ce1 100644
--- a/src/plugins/lua/rbl.lua
+++ b/src/plugins/lua/rbl.lua
@@ -596,7 +596,7 @@ local function gen_rbl_callback(rule)
ignore_ip = rule.no_ip,
need_images = rule.images,
need_emails = false,
- need_content = rule.content_urls or false,
+ need_content = rule.content_urls,
esld_limit = esld_lim,
no_cache = true,
}
@@ -983,7 +983,7 @@ local function gen_rbl_callback(rule)
if req.resolve_ip then
-- Deal with both ipv4 and ipv6
-- Resolve names first
- if r:resolve_a({
+ if (rule.ipv4 == nil or rule.ipv4) and r:resolve_a({
task = task,
name = req.n,
callback = gen_rbl_ip_dns_callback(req),
@@ -991,7 +991,7 @@ local function gen_rbl_callback(rule)
}) then
nresolved = nresolved + 1
end
- if r:resolve('aaaa', {
+ if (rule.ipv6 == nil or rule.ipv6) and r:resolve('aaaa', {
task = task,
name = req.n,
callback = gen_rbl_ip_dns_callback(req),