diff options
-rw-r--r-- | src/lua/lua_tcp.c | 54 |
1 files changed, 53 insertions, 1 deletions
diff --git a/src/lua/lua_tcp.c b/src/lua/lua_tcp.c index e1024ee36..e839aba24 100644 --- a/src/lua/lua_tcp.c +++ b/src/lua/lua_tcp.c @@ -344,6 +344,7 @@ struct lua_tcp_cbdata { struct rspamd_config *cfg; struct rspamd_ssl_connection *ssl_conn; gchar *hostname; + struct upstream *up; gboolean eof; }; @@ -461,6 +462,10 @@ lua_tcp_fin (gpointer arg) rspamd_inet_address_free (cbd->addr); } + if (cbd->up) { + rspamd_upstream_unref(cbd->up); + } + while (lua_tcp_shift_handler (cbd)) {} g_queue_free (cbd->handlers); @@ -537,6 +542,10 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, lua_State *L; gboolean callback_called = FALSE; + if (is_fatal && cbd->up) { + rspamd_upstream_fail(cbd->up, false, err); + } + if (cbd->thread) { va_start (ap, err); lua_tcp_resume_thread_error_argp (cbd, err, ap); @@ -907,6 +916,10 @@ call_finish_handler: cbd->flags &= ~LUA_TCP_FLAG_SHUTDOWN; } + if (cbd->up) { + rspamd_upstream_ok(cbd->up); + } + lua_tcp_push_data (cbd, NULL, 0); if (!IS_SYNC (cbd)) { lua_tcp_shift_handler (cbd); @@ -1455,6 +1468,7 @@ lua_tcp_arg_toiovec (lua_State *L, gint pos, struct lua_tcp_cbdata *cbd, * - `stop_pattern`: stop reading on finding a certain pattern (e.g. \r\n.\r\n for smtp) * - `shutdown`: half-close socket after writing (boolean: default false) * - `read`: read response after sending request (boolean: default true) + * - `upstream`: optional upstream object that would be used to get an address * @return {boolean} true if request has been sent */ static gint @@ -1473,6 +1487,7 @@ lua_tcp_request (lua_State *L) struct rspamd_task *task = NULL; struct rspamd_config *cfg = NULL; struct iovec *iov = NULL; + struct upstream *up; guint niov = 0, total_out; guint64 h; gdouble timeout = default_tcp_timeout; @@ -1642,6 +1657,20 @@ lua_tcp_request (lua_State *L) lua_pop (L, 1); } + lua_pushstring (L, "upstream"); + lua_gettable (L, 1); + + if (lua_type (L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1); + + if (lup) { + /* Preserve pointer in case if lup is destructed */ + up = lup->up; + } + } + + lua_pop (L, 1); + lua_pushstring (L, "data"); lua_gettable (L, -2); total_out = 0; @@ -1786,6 +1815,10 @@ lua_tcp_request (lua_State *L) cbd->connect_cb = conn_cbref; REF_INIT_RETAIN (cbd, lua_tcp_maybe_free); + if (up) { + cbd->up = rspamd_upstream_ref(up); + } + if (session) { cbd->session = session; @@ -1799,7 +1832,26 @@ lua_tcp_request (lua_State *L) } } - if (rspamd_parse_inet_address (&cbd->addr, + if (cbd->up) { + /* Use upstream to get addr */ + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + + /* Host is numeric IP, no need to resolve */ + lua_tcp_register_watcher (cbd); + + if (!lua_tcp_make_connection (cbd)) { + lua_tcp_push_error (cbd, TRUE, "cannot connect to the host: %s", host); + lua_pushboolean (L, FALSE); + + rspamd_upstream_fail(cbd->up, true, "failed to connect"); + + /* No reset of the item as watcher has been registered */ + TCP_RELEASE (cbd); + + return 1; + } + } + else if (rspamd_parse_inet_address (&cbd->addr, host, strlen (host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { rspamd_inet_address_set_port (cbd->addr, port); /* Host is numeric IP, no need to resolve */ |