aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/gpt.lua
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2024-06-28 10:53:10 +0100
committerVsevolod Stakhov <vsevolod@rspamd.com>2024-06-28 10:53:10 +0100
commitc052ad8f8cf8b15c75f59d6ca82e51f640006257 (patch)
treef57e4419a700603118e1a2ec1ec99985bd6afbb6 /src/plugins/lua/gpt.lua
parent109198a958ac45031ede5316af4a85a59f0cfad8 (diff)
downloadrspamd-c052ad8f8cf8b15c75f59d6ca82e51f640006257.tar.gz
rspamd-c052ad8f8cf8b15c75f59d6ca82e51f640006257.zip
[Minor] Use keep-alive and upstreams logic
Diffstat (limited to 'src/plugins/lua/gpt.lua')
-rw-r--r--src/plugins/lua/gpt.lua14
1 files changed, 13 insertions, 1 deletions
diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua
index d7d1c496e..ec2e05a33 100644
--- a/src/plugins/lua/gpt.lua
+++ b/src/plugins/lua/gpt.lua
@@ -43,6 +43,8 @@ gpt {
autolearn = true;
# Reply conversion (lua code)
reply_conversion = "xxx";
+ # URL for the API
+ url = "https://api.openai.com/v1/chat/completions";ß
}
]])
return
@@ -64,6 +66,7 @@ local settings = {
prompt = nil,
condition = nil,
autolearn = false,
+ url = 'https://api.openai.com/v1/chat/completions',
}
local function default_condition(task)
@@ -113,14 +116,17 @@ local function openai_gpt_check(task)
lua_util.debugm(N, task, "skip checking gpt as the condition is not met")
return
end
+ local upstream
local function on_reply(err, code, body)
if err then
rspamd_logger.errx(task, 'request failed: %s', err)
+ upstream:fail()
return
end
+ upstream:ok()
lua_util.debugm(N, task, "got reply: %s", body)
if code ~= 200 then
rspamd_logger.errx(task, 'bad reply: %s', body)
@@ -177,8 +183,10 @@ local function openai_gpt_check(task)
}
}
}
+
+ upstream = settings.upstreams:get_upstream_round_robin()
local http_params = {
- url = 'https://api.openai.com/v1/chat/completions',
+ url = settings.url,
mime_type = 'application/json',
timeout = settings.timeout,
log_obj = task,
@@ -186,8 +194,10 @@ local function openai_gpt_check(task)
headers = {
['Authorization'] = 'Bearer ' .. settings.api_key,
},
+ keepalive = true,
body = ucl.to_format(body, 'json-compact', true),
task = task,
+ upstream = upstream,
}
rspamd_http.request(http_params)
@@ -233,6 +243,8 @@ if opts then
return
end
+ settings.upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), settings.url)
+
local id = rspamd_config:register_symbol({
name = 'GPT_CHECK',
type = 'postfilter',