diff options
Diffstat (limited to 'src/lua')
-rw-r--r-- | src/lua/lua_tcp.c | 139 |
1 files changed, 125 insertions, 14 deletions
diff --git a/src/lua/lua_tcp.c b/src/lua/lua_tcp.c index 1c789a9d3..1e19efd97 100644 --- a/src/lua/lua_tcp.c +++ b/src/lua/lua_tcp.c @@ -15,6 +15,7 @@ */ #include "lua_common.h" #include "lua_thread_pool.h" +#include "libutil/ssl_util.h" #include "utlist.h" #include "unix-std.h" #include <math.h> @@ -117,6 +118,8 @@ local function http_simple_tcp_symbol(task) host = '127.0.0.1', timeout = 20, port = 18080, + ssl = false, -- If SSL connection is needed + ssl_verify = true, -- set to false if verify is not needed } is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n') @@ -189,6 +192,14 @@ LUA_FUNCTION_DEF (tcp, add_write); */ LUA_FUNCTION_DEF (tcp, shift_callback); +/*** + * @method tcp:starttls([no_verify]) + * + * Starts tls connection + * @param {boolean} no_verify used to skip ssl verification + */ +LUA_FUNCTION_DEF (tcp, starttls); + static const struct luaL_reg tcp_libf[] = { LUA_INTERFACE_DEF (tcp, request), {"new", lua_tcp_request}, @@ -203,6 +214,7 @@ static const struct luaL_reg tcp_libm[] = { LUA_INTERFACE_DEF (tcp, add_read), LUA_INTERFACE_DEF (tcp, add_write), LUA_INTERFACE_DEF (tcp, shift_callback), + LUA_INTERFACE_DEF (tcp, starttls), {"__tostring", rspamd_lua_class_tostring}, {NULL, NULL} }; @@ -302,12 +314,14 @@ struct lua_tcp_dtor { struct lua_tcp_dtor *next; }; -#define LUA_TCP_FLAG_PARTIAL (1 << 0) -#define LUA_TCP_FLAG_SHUTDOWN (1 << 2) -#define LUA_TCP_FLAG_CONNECTED (1 << 3) -#define LUA_TCP_FLAG_FINISHED (1 << 4) -#define LUA_TCP_FLAG_SYNC (1 << 5) -#define LUA_TCP_FLAG_RESOLVED (1 << 6) +#define LUA_TCP_FLAG_PARTIAL (1u << 0u) +#define LUA_TCP_FLAG_SHUTDOWN (1u << 2u) +#define LUA_TCP_FLAG_CONNECTED (1u << 3u) +#define LUA_TCP_FLAG_FINISHED (1u << 4u) +#define LUA_TCP_FLAG_SYNC (1u << 5u) +#define LUA_TCP_FLAG_RESOLVED (1u << 6u) +#define LUA_TCP_FLAG_SSL (1u << 7u) +#define LUA_TCP_FLAG_SSL_NOVERIFY (1u << 8u) #undef TCP_DEBUG_REFS #ifdef TCP_DEBUG_REFS @@ -345,6 +359,8 @@ struct lua_tcp_cbdata { struct rspamd_symcache_item *item; struct thread_entry *thread; struct rspamd_config *cfg; + struct rspamd_ssl_connection *ssl_conn; + gchar *hostname; gboolean eof; }; @@ -445,6 +461,11 @@ lua_tcp_fin (gpointer arg) luaL_unref (cbd->cfg->lua_state, LUA_REGISTRYINDEX, cbd->connect_cb); } + if (cbd->ssl_conn) { + /* TODO: postpone close in case ssl is used ! */ + rspamd_ssl_connection_free (cbd->ssl_conn); + } + if (cbd->fd != -1) { event_del (&cbd->ev); close (cbd->fd); @@ -464,6 +485,7 @@ lua_tcp_fin (gpointer arg) } g_byte_array_unref (cbd->in); + g_free (cbd->hostname); g_free (cbd); } @@ -817,7 +839,13 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd) #ifdef MSG_NOSIGNAL flags = MSG_NOSIGNAL; #endif - r = sendmsg (cbd->fd, &msg, flags); + + if (cbd->ssl_conn) { + r = rspamd_ssl_writev (cbd->ssl_conn, msg.msg_iov, msg.msg_iovlen); + } + else { + r = sendmsg (cbd->fd, &msg, flags); + } if (r == -1) { lua_tcp_push_error (cbd, FALSE, "IO write error while trying to write %d " @@ -1006,7 +1034,13 @@ lua_tcp_handler (int fd, short what, gpointer ud) event_type = rh->type; if (what == EV_READ) { - r = read (cbd->fd, inbuf, sizeof (inbuf)); + if (cbd->ssl_conn) { + r = rspamd_ssl_read (cbd->ssl_conn, inbuf, sizeof (inbuf)); + } + else { + r = read (cbd->fd, inbuf, sizeof (inbuf)); + } + lua_tcp_process_read (cbd, inbuf, r); } else if (what == EV_WRITE) { @@ -1189,6 +1223,21 @@ lua_tcp_register_watcher (struct lua_tcp_cbdata *cbd) } } +static void +lua_tcp_ssl_on_error (gpointer ud, GError *err) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *)ud; + + if (err) { + lua_tcp_push_error (cbd, TRUE, "ssl error: %s", err->message); + } + else { + lua_tcp_push_error (cbd, TRUE, "ssl error: unknown error"); + } + + TCP_RELEASE (cbd); +} + static gboolean lua_tcp_make_connection (struct lua_tcp_cbdata *cbd) { @@ -1200,19 +1249,23 @@ lua_tcp_make_connection (struct lua_tcp_cbdata *cbd) if (fd == -1) { if (cbd->session) { rspamd_mempool_t *pool = rspamd_session_mempool (cbd->session); - msg_info_pool ("cannot connect to %s: %s", + msg_info_pool ("cannot connect to %s (%s): %s", rspamd_inet_address_to_string (cbd->addr), + cbd->hostname, strerror (errno)); } else { - msg_info ("cannot connect to %s: %s", + msg_info ("cannot connect to %s (%s): %s", rspamd_inet_address_to_string (cbd->addr), + cbd->hostname, strerror (errno)); } return FALSE; } + cbd->fd = fd; + #if 0 if (!(cbd->flags & LUA_TCP_FLAG_RESOLVED)) { /* We come here without resolving, so we need to add a watcher */ @@ -1223,10 +1276,39 @@ lua_tcp_make_connection (struct lua_tcp_cbdata *cbd) } #endif - lua_tcp_register_event (cbd); + if (cbd->flags & LUA_TCP_FLAG_SSL) { + gpointer ssl_ctx; + gboolean verify_peer; + + if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify; + verify_peer = FALSE; + } + else { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx; + verify_peer = TRUE; + } + + event_base_set (cbd->ev_base, &cbd->ev); + cbd->ssl_conn = + rspamd_ssl_connection_new (ssl_ctx, cbd->ev_base, verify_peer); + + if (!rspamd_ssl_connect_fd (cbd->ssl_conn, fd, cbd->hostname, &cbd->ev, + &cbd->tv, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) { + lua_tcp_push_error (cbd, TRUE, "ssl connection failed: %s", + strerror (errno)); + + return FALSE; + } + else { + lua_tcp_register_event (cbd); + } + } + else { + lua_tcp_register_event (cbd); + lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + } - cbd->fd = fd; - lua_tcp_plan_handler_event (cbd, TRUE, TRUE); return TRUE; } @@ -1359,7 +1441,8 @@ lua_tcp_request (lua_State *L) guint niov = 0, total_out; guint64 h; gdouble timeout = default_tcp_timeout; - gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE; + gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE, + ssl = FALSE, ssl_noverify = FALSE; if (lua_type (L, 1) == LUA_TTABLE) { lua_pushstring (L, "host"); @@ -1486,6 +1569,20 @@ lua_tcp_request (lua_State *L) } lua_pop (L, 1); + lua_pushstring (L, "ssl"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TBOOLEAN) { + ssl = lua_toboolean (L, -1); + } + lua_pop (L, 1); + + lua_pushstring (L, "ssl_noverify"); + lua_gettable (L, -2); + if (lua_type (L, -1) == LUA_TBOOLEAN) { + ssl_noverify = lua_toboolean (L, -1); + } + lua_pop (L, 1); + lua_pushstring (L, "on_connect"); lua_gettable (L, -2); @@ -1568,6 +1665,7 @@ lua_tcp_request (lua_State *L) h = rspamd_random_uint64_fast (); rspamd_snprintf (cbd->tag, sizeof (cbd->tag), "%uxL", h); cbd->handlers = g_queue_new (); + cbd->hostname = g_strdup (host); if (total_out > 0) { struct lua_tcp_handler *wh; @@ -1598,6 +1696,14 @@ lua_tcp_request (lua_State *L) cbd->fd = -1; cbd->port = port; + if (ssl) { + cbd->flags |= LUA_TCP_FLAG_SSL; + + if (ssl_noverify) { + cbd->flags |= LUA_TCP_FLAG_SSL_NOVERIFY; + } + } + if (do_read) { cbd->in = g_byte_array_sized_new (8192); } @@ -2236,6 +2342,11 @@ lua_tcp_sync_shutdown (lua_State *L) return 0; } +static gint +lua_tcp_starttls (lua_State * L) +{ + return 0; +} static gint lua_tcp_sync_gc (lua_State * L) |