]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Accept upstream in lua_tcp
authorVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 2 Jul 2022 12:43:57 +0000 (13:43 +0100)
committerVsevolod Stakhov <vsevolod@rspamd.com>
Sat, 2 Jul 2022 12:43:57 +0000 (13:43 +0100)
src/lua/lua_tcp.c

index e1024ee36a2366dae7abe89ff6949af0ce4cfa7e..e839aba24f6e83b4665d8fd6d86a9e6dc3929aa1 100644 (file)
@@ -344,6 +344,7 @@ struct lua_tcp_cbdata {
        struct rspamd_config *cfg;
        struct rspamd_ssl_connection *ssl_conn;
        gchar *hostname;
+       struct upstream *up;
        gboolean eof;
 };
 
@@ -461,6 +462,10 @@ lua_tcp_fin (gpointer arg)
                rspamd_inet_address_free (cbd->addr);
        }
 
+       if (cbd->up) {
+               rspamd_upstream_unref(cbd->up);
+       }
+
        while (lua_tcp_shift_handler (cbd)) {}
        g_queue_free (cbd->handlers);
 
@@ -537,6 +542,10 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal,
        lua_State *L;
        gboolean callback_called = FALSE;
 
+       if (is_fatal && cbd->up) {
+               rspamd_upstream_fail(cbd->up, false, err);
+       }
+
        if (cbd->thread) {
                va_start (ap, err);
                lua_tcp_resume_thread_error_argp (cbd, err, ap);
@@ -907,6 +916,10 @@ call_finish_handler:
                cbd->flags &= ~LUA_TCP_FLAG_SHUTDOWN;
        }
 
+       if (cbd->up) {
+               rspamd_upstream_ok(cbd->up);
+       }
+
        lua_tcp_push_data (cbd, NULL, 0);
        if (!IS_SYNC (cbd)) {
                lua_tcp_shift_handler (cbd);
@@ -1455,6 +1468,7 @@ lua_tcp_arg_toiovec (lua_State *L, gint pos, struct lua_tcp_cbdata *cbd,
  * - `stop_pattern`: stop reading on finding a certain pattern (e.g. \r\n.\r\n for smtp)
  * - `shutdown`: half-close socket after writing (boolean: default false)
  * - `read`: read response after sending request (boolean: default true)
+ * - `upstream`: optional upstream object that would be used to get an address
  * @return {boolean} true if request has been sent
  */
 static gint
@@ -1473,6 +1487,7 @@ lua_tcp_request (lua_State *L)
        struct rspamd_task *task = NULL;
        struct rspamd_config *cfg = NULL;
        struct iovec *iov = NULL;
+       struct upstream *up;
        guint niov = 0, total_out;
        guint64 h;
        gdouble timeout = default_tcp_timeout;
@@ -1642,6 +1657,20 @@ lua_tcp_request (lua_State *L)
                        lua_pop (L, 1);
                }
 
+               lua_pushstring (L, "upstream");
+               lua_gettable (L, 1);
+
+               if (lua_type (L, -1) == LUA_TUSERDATA) {
+                       struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1);
+
+                       if (lup) {
+                               /* Preserve pointer in case if lup is destructed */
+                               up = lup->up;
+                       }
+               }
+
+               lua_pop (L, 1);
+
                lua_pushstring (L, "data");
                lua_gettable (L, -2);
                total_out = 0;
@@ -1786,6 +1815,10 @@ lua_tcp_request (lua_State *L)
        cbd->connect_cb = conn_cbref;
        REF_INIT_RETAIN (cbd, lua_tcp_maybe_free);
 
+       if (up) {
+               cbd->up = rspamd_upstream_ref(up);
+       }
+
        if (session) {
                cbd->session = session;
 
@@ -1799,7 +1832,26 @@ lua_tcp_request (lua_State *L)
                }
        }
 
-       if (rspamd_parse_inet_address (&cbd->addr,
+       if (cbd->up) {
+               /* Use upstream to get addr */
+               cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL);
+
+               /* Host is numeric IP, no need to resolve */
+               lua_tcp_register_watcher (cbd);
+
+               if (!lua_tcp_make_connection (cbd)) {
+                       lua_tcp_push_error (cbd, TRUE, "cannot connect to the host: %s", host);
+                       lua_pushboolean (L, FALSE);
+
+                       rspamd_upstream_fail(cbd->up, true, "failed to connect");
+
+                       /* No reset of the item as watcher has been registered */
+                       TCP_RELEASE (cbd);
+
+                       return 1;
+               }
+       }
+       else if (rspamd_parse_inet_address (&cbd->addr,
                        host, strlen (host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) {
                rspamd_inet_address_set_port (cbd->addr, port);
                /* Host is numeric IP, no need to resolve */