author    Vsevolod Stakhov <vsevolod@rspamd.com>    2025-02-25 11:38:28 +0000
committer Vsevolod Stakhov <vsevolod@rspamd.com>    2025-02-25 11:38:28 +0000
commit    67d4edcf92e3c13a6b253ea25a7a8ad2db3ea701 (patch)
tree      ccd12afde922ecdc29fd9a6147e9b6151d85061a
parent    20ab502c23bb64d70cd56a10fae3345ad6d28d98 (diff)
download  rspamd-67d4edcf92e3c13a6b253ea25a7a8ad2db3ea701.tar.gz
          rspamd-67d4edcf92e3c13a6b253ea25a7a8ad2db3ea701.zip
[Feature] Cache LLM replies
-rw-r--r--    src/plugins/lua/gpt.lua    173
1 file changed, 124 insertions(+), 49 deletions(-)
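With this patch the plugin caches each LLM verdict in Redis, keyed by the digest of the scanned text part ("rsllm_" .. digest) and expired after redis_cache_expire seconds (one day by default). A minimal configuration sketch, assuming the usual local.d override layout; the server address is illustrative, and gpt-specific keys beyond those shown are omitted:

# local.d/gpt.conf (sketch)
servers = "127.0.0.1:6379";   # consumed by lua_redis.parse_redis_server('gpt', opts);
                              # a global redis {} section works as well
redis_cache_expire = 86400;   # cache TTL in seconds; matches the new default of 3600 * 24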
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 270d0fdfc..6c6d8d685 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -15,6 +15,7 @@ limitations under the License.
]] --
local N = "gpt"
+local REDIS_PREFIX = "rsllm_"
local E = {}
if confighelp then
@@ -61,6 +62,7 @@ local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local lua_mime = require "lua_mime"
+local lua_redis = require "lua_redis"
local ucl = require "ucl"
local fun = require "fun"
@@ -92,7 +94,9 @@ local settings = {
  allow_passthrough = false,
  allow_ham = false,
  json = false,
+  redis_cache_expire = 3600 * 24,
}
+local redis_params
local function default_condition(task)
  -- Check result
@@ -176,10 +180,10 @@ local function default_condition(task)
    local words = sel_part:get_words('norm')
    nwords = #words
    if nwords > settings.max_tokens then
-      return true, table.concat(words, ' ', 1, settings.max_tokens)
+      return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part
    end
  end
-  return true, sel_part:get_content_oneline()
+  return true, sel_part:get_content_oneline(), sel_part
end
local function maybe_extract_json(str)
@@ -431,7 +435,48 @@ local function default_ollama_json_conversion(task, input)
  return
end
-local function check_consensus(task, results)
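+-- Cache writer: serialize the verdict to compact JSON and store it with SETEX
+-- under REDIS_PREFIX plus the part digest, honoring redis_cache_expire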
+local function maybe_save_cache(task, result, sel_part)
+  if not sel_part or not redis_params then
+    lua_util.debugm(N, task, 'cannot save cache: no part or no redis')
+    return -- cannot save
+  end
+
+  local digest = sel_part:get_mimepart():get_digest()
+  local cache_key = REDIS_PREFIX .. digest
+  lua_util.debugm(N, task, 'saving cache for %s', cache_key)
+  local result_json = ucl.to_format(result, 'json-compact')
+  lua_redis.redis_make_request(task, redis_params, cache_key, true, function(err, _)
+    if err then
+      rspamd_logger.errx(task, 'cannot save cache: %s', err)
+    end
+  end,
+  'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
+end
+
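+-- Apply a single {probability, reason} verdict: insert GPT_SPAM/GPT_HAM scaled
+-- from the probability, optionally autolearn, then persist via maybe_save_cache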
+local function insert_results(task, result, sel_part)
+  if not result.probability then
+    rspamd_logger.errx(task, 'no probability in result')
+    return
+  end
+  if result.probability > 0.5 then
+    task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+  else
+    if result.reason and settings.reason_header then
+      lua_mime.modify_headers(task,
+          { add = { [settings.reason_header] = { value = result.reason, order = 1 } } })
+    end
+    task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+  end
+  maybe_save_cache(task, result, sel_part)
+end
+
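+-- Consensus over all model replies; the winning verdict is routed through
+-- insert_results together with the originating part (and thus cached)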
+local function check_consensus_and_insert_results(task, results, sel_part)
  for _, result in ipairs(results) do
    if not result.checked then
      return
@@ -466,24 +511,17 @@ local function check_consensus(task, results)
  local reason = reasons[1] or nil
  if nspam > nham and max_spam_prob > 0.75 then
-    task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
-    if settings.autolearn then
-      task:set_flag("learn_spam")
-    end
-
-    if reason and settings.reason_header then
-      lua_mime.modify_headers(task,
-          { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
-    end
+    insert_results(task, {
+      probability = max_spam_prob,
+      reason = reason,
+    },
+        sel_part)
  elseif nham > nspam and max_ham_prob < 0.25 then
-    task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
-    if settings.autolearn then
-      task:set_flag("learn_ham")
-    end
-    if reason and settings.reason_header then
-      lua_mime.modify_headers(task,
-          { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
-    end
+    insert_results(task, {
+      probability = max_ham_prob,
+      reason = reason,
+    },
+        sel_part)
  else
    -- No consensus
    lua_util.debugm(N, task, "no consensus")
@@ -508,19 +546,43 @@ local function get_meta_llm_content(task)
  return url_content, from_content
end
-local function default_llm_check(task)
-  local ret, content = settings.condition(task)
+local function check_llm_uncached(task, content, sel_part)
+  return settings.specific_check(task, content, sel_part)
+end
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
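+-- Cache reader: GET the verdict for this part digest; on a hit replay the stored
+-- result, otherwise (miss, parse failure or Redis error) query the LLM directly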
+local function check_llm_cached(task, content, sel_part)
+  local digest = sel_part:get_mimepart():get_digest()
+  local cache_key = REDIS_PREFIX .. digest
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
+  local ret = lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, data)
+    if err then
+      rspamd_logger.errx(task, 'cannot check cache: %s', err)
+      check_llm_uncached(task, content, sel_part)
+      return -- do not fall through and query the LLM twice
+    end
+
+    if data then
+      local parser = ucl.parser()
+      local res, parse_err = parser:parse_string(data)
+      if not res then
+        rspamd_logger.errx(task, 'cannot parse cached response: %s', parse_err)
+        check_llm_uncached(task, content, sel_part)
+      else
+        rspamd_logger.infox(task, 'found cached response')
+        insert_results(task, parser:get_object())
+      end
+    else
+      check_llm_uncached(task, content, sel_part)
+    end
+  end,
+  'GET', { cache_key })
+
+  if not ret then
+    rspamd_logger.errx(task, 'cannot query cache for request')
+    check_llm_uncached(task, content, sel_part)
  end
+end
+local function openai_check(task, content, sel_part)
  lua_util.debugm(N, task, "sending content to gpt: %s", content)
  local upstream
@@ -533,7 +595,7 @@ local function default_llm_check(task)
    if err then
      rspamd_logger.errx(task, '%s: request failed: %s', model, err)
      upstream:fail()
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
      return
    end
@@ -554,7 +616,7 @@ local function default_llm_check(task)
      results[idx].reason = reason
    end
-    check_consensus(task, results)
+    check_consensus_and_insert_results(task, results, sel_part)
  end
end
@@ -627,19 +689,7 @@ local function default_llm_check(task)
  end
end
-local function ollama_check(task)
-  local ret, content = settings.condition(task)
-
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
-
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
-  end
-
+local function ollama_check(task, content, sel_part)
  lua_util.debugm(N, task, "sending content to gpt: %s", content)
  local upstream
@@ -651,7 +701,7 @@ local function ollama_check(task)
    if err then
      rspamd_logger.errx(task, '%s: request failed: %s', model, err)
      upstream:fail()
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
      return
    end
@@ -672,7 +722,7 @@ local function ollama_check(task)
      results[idx].reason = reason
    end
-    check_consensus(task, results)
+    check_consensus_and_insert_results(task, results, sel_part)
  end
end
@@ -738,12 +788,29 @@ local function ollama_check(task)
end
local function gpt_check(task)
-  return settings.specific_check(task)
+  local ret, content, sel_part = settings.condition(task)
+
+  if not ret then
+    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+    return
+  end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+    return
+  end
+
+  if sel_part then
+    -- Check the digest-keyed cache first
+    check_llm_cached(task, content, sel_part)
+  else
+    check_llm_uncached(task, content)
+  end
end
local types_map = {
  openai = {
-    check = default_llm_check,
+    check = openai_check,
    condition = default_condition,
    conversion = function(is_json)
      return is_json and default_openai_json_conversion or default_openai_plain_conversion
@@ -760,10 +827,18 @@ local types_map = {
  },
}
-local opts = rspamd_config:get_all_opt('gpt')
+local opts = rspamd_config:get_all_opt(N)
if opts then
+  redis_params = lua_redis.parse_redis_server(N, opts)
  settings = lua_util.override_defaults(settings, opts)
+  if redis_params then
+    lua_redis.register_prefix(REDIS_PREFIX .. '*', N,
+        'Cache of LLM requests', {
+          type = 'string',
+        })
+  end
+
  if not settings.prompt then
    settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
        "FROM and url domains. Evaluate spam probability (0-1). " ..