diff options
-rw-r--r-- | src/lua/lua_common.h | 9 | ||||
-rw-r--r-- | src/lua/lua_http.c | 50 | ||||
-rw-r--r-- | src/lua/lua_upstream.c | 23 |
3 files changed, 63 insertions, 19 deletions
diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index c961d37dd..a6e98a4ba 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -165,6 +165,11 @@ struct rspamd_lua_cached_entry { guint id; }; +struct rspamd_lua_upstream { + struct upstream *up; + gint upref; +}; + /* Common utility functions */ /** @@ -284,7 +289,9 @@ struct rspamd_lua_text *lua_new_text (lua_State *L, const gchar *start, */ bool lua_is_text_binary(struct rspamd_lua_text *t); -struct rspamd_lua_regexp *lua_check_regexp (lua_State *L, gint pos); +struct rspamd_lua_regexp* lua_check_regexp (lua_State *L, gint pos); + +struct rspamd_lua_upstream* lua_check_upstream(lua_State *L, int pos); enum rspamd_lua_task_header_type { RSPAMD_TASK_HEADER_PUSH_SIMPLE = 0, diff --git a/src/lua/lua_http.c b/src/lua/lua_http.c index 1fb5732f8..5a05f7058 100644 --- a/src/lua/lua_http.c +++ b/src/lua/lua_http.c @@ -16,6 +16,7 @@ #include "lua_common.h" #include "lua_thread_pool.h" #include "libserver/http/http_private.h" +#include "libutil/upstream.h" #include "ref.h" #include "unix-std.h" #include "zlib.h" @@ -77,6 +78,7 @@ struct lua_http_cbdata { gchar *mime_type; gchar *host; gchar *auth; + struct upstream *up; const gchar *url; gsize max_size; gint flags; @@ -201,6 +203,11 @@ static void lua_http_error_handler (struct rspamd_http_connection *conn, GError *err) { struct lua_http_cbdata *cbd = (struct lua_http_cbdata *)conn->ud; + + if (cbd->up) { + rspamd_upstream_fail(cbd->up, false, err ? err->message : "unknown error"); + } + if (cbd->cbref == -1) { if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_YIELDED) { cbd->flags &= ~RSPAMD_LUA_HTTP_FLAG_YIELDED; @@ -599,6 +606,7 @@ lua_http_request (lua_State *L) struct rspamd_config *cfg = NULL; struct rspamd_cryptobox_pubkey *peer_key = NULL; struct rspamd_cryptobox_keypair *local_kp = NULL; + struct upstream *up = NULL; const gchar *url, *lua_body; rspamd_fstring_t *body = NULL; gint cbref = -1; @@ -940,6 +948,20 @@ lua_http_request (lua_State *L) lua_pop (L, 1); + lua_pushstring (L, "upstream"); + lua_gettable (L, 1); + + if (lua_type (L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1); + + if (lup) { + /* Preserve pointer in case if lup is destructed */ + up = lup->up; + } + } + + lua_pop (L, 1); + lua_pushstring (L, "user"); lua_gettable (L, 1); @@ -1036,6 +1058,7 @@ lua_http_request (lua_State *L) cbd->url = url; cbd->auth = auth; cbd->task = task; + cbd->up = up; if (cbd->cbref == -1) { cbd->thread = lua_thread_pool_get_running_entry (cfg->lua_thread_pool); @@ -1065,6 +1088,7 @@ lua_http_request (lua_State *L) bool numeric_ip = false; /* Check if we can skip resolving */ + gsize hostlen = 0; const gchar *host = rspamd_http_message_get_http_host (msg, &hostlen); @@ -1072,6 +1096,7 @@ lua_http_request (lua_State *L) cbd->host = g_malloc (hostlen + 1); rspamd_strlcpy (cbd->host, host, hostlen + 1); + /* Keep-alive entry is available */ if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_KEEP_ALIVE) { const rspamd_inet_addr_t *ka_addr = rspamd_http_context_has_keepalive(NULL, cbd->host, @@ -1084,16 +1109,30 @@ lua_http_request (lua_State *L) } } + /* + * No keep-alive stuff, check if we have upstream or if we can parse host as + * a numeric address + */ if (!cbd->addr) { - /* We use msg->host here, not cbd->host ! */ - if (rspamd_parse_inet_address (&cbd->addr, - msg->host->str, msg->host->len, - RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + if (cbd->up) { numeric_ip = true; + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + } + else { + /* We use msg->host here, not cbd->host ! */ + if (rspamd_parse_inet_address(&cbd->addr, + msg->host->str, msg->host->len, + RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + numeric_ip = true; + } } } } else { + if (cbd->up) { + numeric_ip = true; + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + } cbd->host = NULL; } @@ -1105,6 +1144,9 @@ lua_http_request (lua_State *L) ret = lua_http_make_connection (cbd); if (!ret) { + if (cbd->up) { + rspamd_upstream_fail(cbd->up, true, "HTTP connection failed"); + } if (cbd->ref.refcount > 1) { /* Not released by make_connection */ REF_RELEASE (cbd); diff --git a/src/lua/lua_upstream.c b/src/lua/lua_upstream.c index 5019f28d3..94e251a95 100644 --- a/src/lua/lua_upstream.c +++ b/src/lua/lua_upstream.c @@ -95,15 +95,10 @@ static const struct luaL_reg upstream_m[] = { /* Upstream class */ -struct rspamd_lua_upstream { - struct upstream *up; - gint upref; -}; - -static struct rspamd_lua_upstream * -lua_check_upstream (lua_State * L) +struct rspamd_lua_upstream * +lua_check_upstream(lua_State *L, int pos) { - void *ud = rspamd_lua_check_udata (L, 1, "rspamd{upstream}"); + void *ud = rspamd_lua_check_udata (L, pos, "rspamd{upstream}"); luaL_argcheck (L, ud != NULL, 1, "'upstream' expected"); return ud ? (struct rspamd_lua_upstream *)ud : NULL; @@ -118,7 +113,7 @@ static gint lua_upstream_get_addr (lua_State *L) { LUA_TRACE_POINT; - struct rspamd_lua_upstream *up = lua_check_upstream (L); + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); if (up) { rspamd_lua_ip_push (L, rspamd_upstream_addr_next (up->up)); @@ -139,7 +134,7 @@ static gint lua_upstream_get_name (lua_State *L) { LUA_TRACE_POINT; - struct rspamd_lua_upstream *up = lua_check_upstream (L); + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); if (up) { lua_pushstring (L, rspamd_upstream_name (up->up)); @@ -160,7 +155,7 @@ static gint lua_upstream_get_port (lua_State *L) { LUA_TRACE_POINT; - struct rspamd_lua_upstream *up = lua_check_upstream (L); + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); if (up) { lua_pushinteger (L, rspamd_upstream_port (up->up)); @@ -180,7 +175,7 @@ static gint lua_upstream_fail (lua_State *L) { LUA_TRACE_POINT; - struct rspamd_lua_upstream *up = lua_check_upstream (L); + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); gboolean fail_addr = FALSE; const gchar *reason = "unknown"; @@ -211,7 +206,7 @@ static gint lua_upstream_ok (lua_State *L) { LUA_TRACE_POINT; - struct rspamd_lua_upstream *up = lua_check_upstream (L); + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); if (up) { rspamd_upstream_ok (up->up); @@ -224,7 +219,7 @@ static gint lua_upstream_destroy (lua_State *L) { LUA_TRACE_POINT; - struct rspamd_lua_upstream *up = lua_check_upstream (L); + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); if (up) { /* Remove reference to the parent */ |