]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Use keep-alive and upstreams logic
authorVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 28 Jun 2024 09:53:10 +0000 (10:53 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Fri, 28 Jun 2024 09:53:10 +0000 (10:53 +0100)
src/plugins/lua/gpt.lua

index d7d1c496e04028bf23a1babf615e527630c277c8..ec2e05a333b5b5d739b4bc611323fcffe711a612 100644 (file)
@@ -43,6 +43,8 @@ gpt {
   autolearn = true;
   # Reply conversion (lua code)
   reply_conversion = "xxx";
+  # URL for the API
+  url = "https://api.openai.com/v1/chat/completions";ß
 }
   ]])
   return
@@ -64,6 +66,7 @@ local settings = {
   prompt = nil,
   condition = nil,
   autolearn = false,
+  url = 'https://api.openai.com/v1/chat/completions',
 }
 
 local function default_condition(task)
@@ -113,14 +116,17 @@ local function openai_gpt_check(task)
     lua_util.debugm(N, task, "skip checking gpt as the condition is not met")
     return
   end
+  local upstream
 
   local function on_reply(err, code, body)
 
     if err then
       rspamd_logger.errx(task, 'request failed: %s', err)
+      upstream:fail()
       return
     end
 
+    upstream:ok()
     lua_util.debugm(N, task, "got reply: %s", body)
     if code ~= 200 then
       rspamd_logger.errx(task, 'bad reply: %s', body)
@@ -177,8 +183,10 @@ local function openai_gpt_check(task)
       }
     }
   }
+
+  upstream = settings.upstreams:get_upstream_round_robin()
   local http_params = {
-    url = 'https://api.openai.com/v1/chat/completions',
+    url = settings.url,
     mime_type = 'application/json',
     timeout = settings.timeout,
     log_obj = task,
@@ -186,8 +194,10 @@ local function openai_gpt_check(task)
     headers = {
       ['Authorization'] = 'Bearer ' .. settings.api_key,
     },
+    keepalive = true,
     body = ucl.to_format(body, 'json-compact', true),
     task = task,
+    upstream = upstream,
   }
 
   rspamd_http.request(http_params)
@@ -233,6 +243,8 @@ if opts then
     return
   end
 
+  settings.upstreams = lua_util.http_upstreams_by_url(rspamd_config:get_mempool(), settings.url)
+
   local id = rspamd_config:register_symbol({
     name = 'GPT_CHECK',
     type = 'postfilter',