author     Vsevolod Stakhov <vsevolod@rspamd.com>    2025-02-25 11:38:28 +0000
committer  Vsevolod Stakhov <vsevolod@rspamd.com>    2025-02-25 11:38:28 +0000
commit     67d4edcf92e3c13a6b253ea25a7a8ad2db3ea701 (patch)
tree       ccd12afde922ecdc29fd9a6147e9b6151d85061a
parent     20ab502c23bb64d70cd56a10fae3345ad6d28d98 (diff)
[Feature] Cache LLM replies
-rw-r--r--   src/plugins/lua/gpt.lua   173
1 files changed, 124 insertions, 49 deletions
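In brief, this patch caches each LLM verdict in Redis under a key derived from the digest of the scanned MIME part: the verdict is serialized as compact JSON and stored with SETEX so entries expire after redis_cache_expire seconds (one day by default). The sketch below is a minimal, self-contained Lua model of that round-trip; the stubbed redis table, the helper names, and the example digest are illustrative stand-ins, not rspamd's async lua_redis API.

-- Minimal model of the caching scheme added by this commit.
-- The `redis` table is a stub standing in for the real 'SETEX'/'GET'
-- requests; only the key/TTL/JSON flow mirrors the patch.
local REDIS_PREFIX = "rsllm_"
local redis_cache_expire = 3600 * 24 -- default TTL introduced by the patch

local store, now = {}, os.time()
local redis = {
  setex = function(key, ttl, value)
    store[key] = { value = value, expires = now + ttl }
  end,
  get = function(key)
    local entry = store[key]
    if entry and entry.expires > now then
      return entry.value
    end
    return nil
  end,
}

-- Key derivation: prefix plus the part digest
-- (sel_part:get_mimepart():get_digest() in the plugin itself).
local function cache_key(digest)
  return REDIS_PREFIX .. digest
end

-- Save path (cf. maybe_save_cache): verdict as compact JSON, with a TTL.
local function save_verdict(digest, verdict_json)
  redis.setex(cache_key(digest), redis_cache_expire, verdict_json)
end

-- Lookup path (cf. check_llm_cached): a hit skips the LLM request,
-- a miss falls through to the uncached check.
local function lookup_verdict(digest)
  return redis.get(cache_key(digest))
end

save_verdict("0deadbeef", '{"probability":0.92,"reason":"lottery scam"}')
assert(lookup_verdict("0deadbeef") ~= nil) -- cache hit: no LLM call needed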
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index 270d0fdfc..6c6d8d685 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -15,6 +15,7 @@ limitations under the License.
 ]]--
 
 local N = "gpt"
+local REDIS_PREFIX = "rsllm_"
 local E = {}
 
 if confighelp then
@@ -61,6 +62,7 @@
 local lua_util = require "lua_util"
 local rspamd_http = require "rspamd_http"
 local rspamd_logger = require "rspamd_logger"
 local lua_mime = require "lua_mime"
+local lua_redis = require "lua_redis"
 local ucl = require "ucl"
 local fun = require "fun"
@@ -92,7 +94,9 @@ local settings = {
   allow_passthrough = false,
   allow_ham = false,
   json = false,
+  redis_cache_expire = 3600 * 24,
 }
+local redis_params
 
 local function default_condition(task)
   -- Check result
@@ -176,10 +180,10 @@ local function default_condition(task)
     local words = sel_part:get_words('norm')
     nwords = #words
     if nwords > settings.max_tokens then
-      return true, table.concat(words, ' ', 1, settings.max_tokens)
+      return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part
     end
   end
-  return true, sel_part:get_content_oneline()
+  return true, sel_part:get_content_oneline(), sel_part
 end
 
 local function maybe_extract_json(str)
@@ -431,7 +435,48 @@ local function default_ollama_json_conversion(task, input)
   return
 end
 
-local function check_consensus(task, results)
+local function maybe_save_cache(task, result, sel_part)
+  if not sel_part or not redis_params then
+    lua_util.debugm(N, task, 'cannot save cache: no part or no redis')
+    return -- cannot save
+  end
+
+  local digest = sel_part:get_mimepart():get_digest()
+  local cache_key = REDIS_PREFIX .. digest
+  lua_util.debugm(N, task, 'saving cache for %s', cache_key)
+  local result_json = ucl.to_format(result, 'json-compact')
+  lua_redis.redis_make_request(task, redis_params, cache_key, false, function(err, _)
+    if err then
+      rspamd_logger.errx(task, 'cannot save cache: %s', err)
+    end
+  end,
+    'SETEX', { cache_key, tostring(settings.redis_cache_expire), result_json })
+end
+
+local function insert_results(task, result, sel_part)
+  if not result.probability then
+    rspamd_logger.errx(task, 'no probability in result')
+    return
+  end
+  if result.probability > 0.5 then
+    task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_spam")
+    end
+  else
+    if result.reason and settings.reason_header then
+      lua_mime.modify_headers(task,
+        { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
+    end
+    task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability))
+    if settings.autolearn then
+      task:set_flag("learn_ham")
+    end
+  end
+  maybe_save_cache(task, result, sel_part)
+end
+
+local function check_consensus_and_insert_results(task, results, sel_part)
   for _, result in ipairs(results) do
     if not result.checked then
       return
@@ -466,24 +511,17 @@
   local reason = reasons[1] or nil
 
   if nspam > nham and max_spam_prob > 0.75 then
-    task:insert_result('GPT_SPAM', (max_spam_prob - 0.75) * 4, tostring(max_spam_prob))
-    if settings.autolearn then
-      task:set_flag("learn_spam")
-    end
-
-    if reason and settings.reason_header then
-      lua_mime.modify_headers(task,
-        { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
-    end
+    insert_results(task, {
+        probability = max_spam_prob,
+        reason = reason,
+      },
+      sel_part)
   elseif nham > nspam and max_ham_prob < 0.25 then
-    task:insert_result('GPT_HAM', (0.25 - max_ham_prob) * 4, tostring(max_ham_prob))
-    if settings.autolearn then
-      task:set_flag("learn_ham")
-    end
-    if reason and settings.reason_header then
-      lua_mime.modify_headers(task,
-        { add = { [settings.reason_header] = { value = 'value', order = 1 } } })
-    end
+    insert_results(task, {
+        probability = max_ham_prob,
+        reason = reason,
+      },
+      sel_part)
   else
     -- No consensus
     lua_util.debugm(N, task, "no consensus")
@@ -508,19 +546,43 @@ local function get_meta_llm_content(task)
   return url_content, from_content
 end
 
-local function default_llm_check(task)
-  local ret, content = settings.condition(task)
+local function check_llm_uncached(task, content, sel_part)
+  return settings.specific_check(task, content, sel_part)
+end
 
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
+local function check_llm_cached(task, content, sel_part)
+  local digest = sel_part:get_mimepart():get_digest()
+  local cache_key = REDIS_PREFIX .. digest
 
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
+  local ret = lua_redis.redis_make_request(task, redis_params, cache_key, false, function(_, err, data)
+    if err then
+      rspamd_logger.errx(task, 'cannot check cache: %s', err)
+      check_llm_uncached(task, content, sel_part)
+    end
+
+    if data then
+      local parser = ucl.parser()
+      local res, parse_err = parser:parse_string(data)
+      if not res then
+        rspamd_logger.errx(task, 'Cannot parse cached response: %s', parse_err)
+        check_llm_uncached(task, content, sel_part)
+      else
+        rspamd_logger.infox(task, 'found cached response')
+        insert_results(task, parser:get_object())
+      end
+    else
+      check_llm_uncached(task, content, sel_part)
+    end
+  end,
+    'GET', { cache_key })
+
+  if not ret then
+    rspamd_logger.errx(task, 'cannot query cache for request')
+    check_llm_uncached(task, content, sel_part)
   end
+end
 
+local function openai_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
@@ -533,7 +595,7 @@ local function default_llm_check(task)
     if err then
       rspamd_logger.errx(task, '%s: request failed: %s', model, err)
       upstream:fail()
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
       return
     end
 
@@ -554,7 +616,7 @@ local function default_llm_check(task)
        results[idx].reason = reason
      end
 
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
    end
  end
 
@@ -627,19 +689,7 @@ local function default_llm_check(task)
   end
 end
 
-local function ollama_check(task)
-  local ret, content = settings.condition(task)
-
-  if not ret then
-    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
-    return
-  end
-
-  if not content then
-    lua_util.debugm(N, task, "no content to send to gpt classification")
-    return
-  end
-
+local function ollama_check(task, content, sel_part)
   lua_util.debugm(N, task, "sending content to gpt: %s", content)
 
   local upstream
@@ -651,7 +701,7 @@ local function ollama_check(task)
     if err then
       rspamd_logger.errx(task, '%s: request failed: %s', model, err)
       upstream:fail()
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
      return
    end
 
@@ -672,7 +722,7 @@ local function ollama_check(task)
        results[idx].reason = reason
      end
 
-      check_consensus(task, results)
+      check_consensus_and_insert_results(task, results, sel_part)
    end
  end
 
@@ -738,12 +788,29 @@ local function ollama_check(task)
 end
 
 local function gpt_check(task)
-  return settings.specific_check(task)
+  local ret, content, sel_part = settings.condition(task)
+
+  if not ret then
+    rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content)
+    return
+  end
+
+  if not content then
+    lua_util.debugm(N, task, "no content to send to gpt classification")
+    return
+  end
+
+  if sel_part then
+    -- Check digest
+    check_llm_cached(task, content, sel_part)
+  else
+    check_llm_uncached(task, content)
+  end
 end
 
 local types_map = {
   openai = {
-    check = default_llm_check,
+    check = openai_check,
     condition = default_condition,
     conversion = function(is_json)
       return is_json and default_openai_json_conversion or default_openai_plain_conversion
@@ -760,10 +827,18 @@ local types_map = {
   },
 }
 
-local opts = rspamd_config:get_all_opt('gpt')
+local opts = rspamd_config:get_all_opt(N)
 if opts then
+  redis_params = lua_redis.parse_redis_server(N, opts)
   settings = lua_util.override_defaults(settings, opts)
 
+  if redis_params then
+    lua_redis.register_prefix(REDIS_PREFIX .. '*', N,
+      'Cache of LLM requests', {
+        type = 'string',
+      })
+  end
+
   if not settings.prompt then
     settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " ..
        "FROM and url domains. Evaluate spam probability (0-1). " ..
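Note the scoring change that rides along with the cache: insert_results() maps the (possibly cached) probability linearly onto symbol weight around the 0.5 midpoint, replacing the old (max_spam_prob - 0.75) * 4 and (0.25 - max_ham_prob) * 4 formulas. The consensus thresholds (max_spam_prob > 0.75, max_ham_prob < 0.25) still gate whether a fresh LLM verdict is inserted at all, and on a cache hit insert_results() is called without sel_part, so maybe_save_cache() skips re-saving an already cached verdict. The helper below is an illustrative sketch of the new mapping, not a function from the plugin:

-- Sketch of the weight mapping insert_results() now applies to a verdict.
local function verdict_to_symbol(probability)
  if probability > 0.5 then
    -- probabilities 0.5..1.0 map to weights 0..1
    return 'GPT_SPAM', (probability - 0.5) * 2
  else
    -- probabilities 0.0..0.5 map to weights 1..0
    return 'GPT_HAM', (0.5 - probability) * 2
  end
end

print(verdict_to_symbol(0.92)) -- GPT_SPAM  0.84
print(verdict_to_symbol(0.1))  -- GPT_HAM   0.8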