]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Allow lua_http module to accept upstreams
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 2 Jul 2022 12:32:56 +0000 (13:32 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 2 Jul 2022 12:32:56 +0000 (13:32 +0100)
src/lua/lua_common.h
src/lua/lua_http.c
src/lua/lua_upstream.c

index c961d37ddfe96b191a58f50509d8633d44f96972..a6e98a4bacd835e3f153f658cad0f59e501c7662 100644 (file)
@@ -165,6 +165,11 @@ struct rspamd_lua_cached_entry {
        guint id;
 };
 
+struct rspamd_lua_upstream {
+       struct upstream *up;
+       gint upref;
+};
+
 /* Common utility functions */
 
 /**
@@ -284,7 +289,9 @@ struct rspamd_lua_text *lua_new_text (lua_State *L, const gchar *start,
  */
 bool lua_is_text_binary(struct rspamd_lua_text *t);
 
-struct rspamd_lua_regexp *lua_check_regexp (lua_State *L, gint pos);
+struct rspamd_lua_regexp* lua_check_regexp (lua_State *L, gint pos);
+
+struct rspamd_lua_upstream* lua_check_upstream(lua_State *L, int pos);
 
 enum rspamd_lua_task_header_type {
        RSPAMD_TASK_HEADER_PUSH_SIMPLE = 0,
index 1fb5732f84cf4ad3f1227908a63e9dff466c1b15..5a05f7058e97ca0d2e876f2709e0f1e8fddfb61f 100644 (file)
@@ -16,6 +16,7 @@
 #include "lua_common.h"
 #include "lua_thread_pool.h"
 #include "libserver/http/http_private.h"
+#include "libutil/upstream.h"
 #include "ref.h"
 #include "unix-std.h"
 #include "zlib.h"
@@ -77,6 +78,7 @@ struct lua_http_cbdata {
        gchar *mime_type;
        gchar *host;
        gchar *auth;
+       struct upstream *up;
        const gchar *url;
        gsize max_size;
        gint flags;
@@ -201,6 +203,11 @@ static void
 lua_http_error_handler (struct rspamd_http_connection *conn, GError *err)
 {
        struct lua_http_cbdata *cbd = (struct lua_http_cbdata *)conn->ud;
+
+       if (cbd->up) {
+               rspamd_upstream_fail(cbd->up, false, err ? err->message : "unknown error");
+       }
+
        if (cbd->cbref == -1) {
                if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_YIELDED) {
                        cbd->flags &= ~RSPAMD_LUA_HTTP_FLAG_YIELDED;
@@ -599,6 +606,7 @@ lua_http_request (lua_State *L)
        struct rspamd_config *cfg = NULL;
        struct rspamd_cryptobox_pubkey *peer_key = NULL;
        struct rspamd_cryptobox_keypair *local_kp = NULL;
+       struct upstream *up = NULL;
        const gchar *url, *lua_body;
        rspamd_fstring_t *body = NULL;
        gint cbref = -1;
@@ -940,6 +948,20 @@ lua_http_request (lua_State *L)
 
                lua_pop (L, 1);
 
+               lua_pushstring (L, "upstream");
+               lua_gettable (L, 1);
+
+               if (lua_type (L, -1) == LUA_TUSERDATA) {
+                       struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1);
+
+                       if (lup) {
+                               /* Preserve pointer in case if lup is destructed */
+                               up = lup->up;
+                       }
+               }
+
+               lua_pop (L, 1);
+
                lua_pushstring (L, "user");
                lua_gettable (L, 1);
 
@@ -1036,6 +1058,7 @@ lua_http_request (lua_State *L)
        cbd->url = url;
        cbd->auth = auth;
        cbd->task = task;
+       cbd->up = up;
 
        if (cbd->cbref == -1) {
                cbd->thread = lua_thread_pool_get_running_entry (cfg->lua_thread_pool);
@@ -1065,6 +1088,7 @@ lua_http_request (lua_State *L)
        bool numeric_ip = false;
 
        /* Check if we can skip resolving */
+
        gsize hostlen = 0;
        const gchar *host = rspamd_http_message_get_http_host (msg, &hostlen);
 
@@ -1072,6 +1096,7 @@ lua_http_request (lua_State *L)
                cbd->host = g_malloc (hostlen + 1);
                rspamd_strlcpy (cbd->host, host, hostlen + 1);
 
+               /* Keep-alive entry is available */
                if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_KEEP_ALIVE) {
                        const rspamd_inet_addr_t *ka_addr = rspamd_http_context_has_keepalive(NULL,
                                        cbd->host,
@@ -1084,16 +1109,30 @@ lua_http_request (lua_State *L)
                        }
                }
 
+               /*
+                * No keep-alive stuff, check if we have upstream or if we can parse host as
+                * a numeric address
+                */
                if (!cbd->addr) {
-                       /* We use msg->host here, not cbd->host ! */
-                       if (rspamd_parse_inet_address (&cbd->addr,
-                                       msg->host->str, msg->host->len,
-                                       RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) {
+                       if (cbd->up) {
                                numeric_ip = true;
+                               cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL);
+                       }
+                       else {
+                               /* We use msg->host here, not cbd->host ! */
+                               if (rspamd_parse_inet_address(&cbd->addr,
+                                               msg->host->str, msg->host->len,
+                                               RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) {
+                                       numeric_ip = true;
+                               }
                        }
                }
        }
        else {
+               if (cbd->up) {
+                       numeric_ip = true;
+                       cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL);
+               }
                cbd->host = NULL;
        }
 
@@ -1105,6 +1144,9 @@ lua_http_request (lua_State *L)
                ret = lua_http_make_connection (cbd);
 
                if (!ret) {
+                       if (cbd->up) {
+                               rspamd_upstream_fail(cbd->up, true, "HTTP connection failed");
+                       }
                        if (cbd->ref.refcount > 1) {
                                /* Not released by make_connection */
                                REF_RELEASE (cbd);
index 5019f28d349a6e33cd4a7bb4c5594bfc9e512744..94e251a954b5bf928b565f7d0654992257df86c0 100644 (file)
@@ -95,15 +95,10 @@ static const struct luaL_reg upstream_m[] = {
 
 /* Upstream class */
 
-struct rspamd_lua_upstream {
-       struct upstream *up;
-       gint upref;
-};
-
-static struct rspamd_lua_upstream *
-lua_check_upstream (lua_State * L)
+struct rspamd_lua_upstream *
+lua_check_upstream(lua_State *L, int pos)
 {
-       void *ud = rspamd_lua_check_udata (L, 1, "rspamd{upstream}");
+       void *ud = rspamd_lua_check_udata (L, pos, "rspamd{upstream}");
 
        luaL_argcheck (L, ud != NULL, 1, "'upstream' expected");
        return ud ? (struct rspamd_lua_upstream *)ud : NULL;
@@ -118,7 +113,7 @@ static gint
 lua_upstream_get_addr (lua_State *L)
 {
        LUA_TRACE_POINT;
-       struct rspamd_lua_upstream *up = lua_check_upstream (L);
+       struct rspamd_lua_upstream *up = lua_check_upstream(L, 1);
 
        if (up) {
                rspamd_lua_ip_push (L, rspamd_upstream_addr_next (up->up));
@@ -139,7 +134,7 @@ static gint
 lua_upstream_get_name (lua_State *L)
 {
        LUA_TRACE_POINT;
-       struct rspamd_lua_upstream *up = lua_check_upstream (L);
+       struct rspamd_lua_upstream *up = lua_check_upstream(L, 1);
 
        if (up) {
                lua_pushstring (L, rspamd_upstream_name (up->up));
@@ -160,7 +155,7 @@ static gint
 lua_upstream_get_port (lua_State *L)
 {
        LUA_TRACE_POINT;
-       struct rspamd_lua_upstream *up = lua_check_upstream (L);
+       struct rspamd_lua_upstream *up = lua_check_upstream(L, 1);
 
        if (up) {
                lua_pushinteger (L, rspamd_upstream_port (up->up));
@@ -180,7 +175,7 @@ static gint
 lua_upstream_fail (lua_State *L)
 {
        LUA_TRACE_POINT;
-       struct rspamd_lua_upstream *up = lua_check_upstream (L);
+       struct rspamd_lua_upstream *up = lua_check_upstream(L, 1);
        gboolean fail_addr = FALSE;
        const gchar *reason = "unknown";
 
@@ -211,7 +206,7 @@ static gint
 lua_upstream_ok (lua_State *L)
 {
        LUA_TRACE_POINT;
-       struct rspamd_lua_upstream *up = lua_check_upstream (L);
+       struct rspamd_lua_upstream *up = lua_check_upstream(L, 1);
 
        if (up) {
                rspamd_upstream_ok (up->up);
@@ -224,7 +219,7 @@ static gint
 lua_upstream_destroy (lua_State *L)
 {
        LUA_TRACE_POINT;
-       struct rspamd_lua_upstream *up = lua_check_upstream (L);
+       struct rspamd_lua_upstream *up = lua_check_upstream(L, 1);
 
        if (up) {
                /* Remove reference to the parent */