]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Lua_tcp: Add preliminary support of SSL connections
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 29 May 2019 21:15:45 +0000 (22:15 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 29 May 2019 21:15:45 +0000 (22:15 +0100)
src/lua/lua_tcp.c

index 1c789a9d3527af6d9722077a5bff9b32c2da490d..1e19efd97f78e6ea62a0c90d6aee451a79f96db4 100644 (file)
@@ -15,6 +15,7 @@
  */
 #include "lua_common.h"
 #include "lua_thread_pool.h"
+#include "libutil/ssl_util.h"
 #include "utlist.h"
 #include "unix-std.h"
 #include <math.h>
@@ -117,6 +118,8 @@ local function http_simple_tcp_symbol(task)
       host = '127.0.0.1',
       timeout = 20,
       port = 18080,
+      ssl = false, -- If SSL connection is needed
+      ssl_verify = true, -- set to false if verify is not needed
     }
 
     is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n')
@@ -189,6 +192,14 @@ LUA_FUNCTION_DEF (tcp, add_write);
  */
 LUA_FUNCTION_DEF (tcp, shift_callback);
 
+/***
+ * @method tcp:starttls([no_verify])
+ *
+ * Starts tls connection
+ * @param {boolean} no_verify used to skip ssl verification
+ */
+LUA_FUNCTION_DEF (tcp, starttls);
+
 static const struct luaL_reg tcp_libf[] = {
        LUA_INTERFACE_DEF (tcp, request),
        {"new", lua_tcp_request},
@@ -203,6 +214,7 @@ static const struct luaL_reg tcp_libm[] = {
        LUA_INTERFACE_DEF (tcp, add_read),
        LUA_INTERFACE_DEF (tcp, add_write),
        LUA_INTERFACE_DEF (tcp, shift_callback),
+       LUA_INTERFACE_DEF (tcp, starttls),
        {"__tostring", rspamd_lua_class_tostring},
        {NULL, NULL}
 };
@@ -302,12 +314,14 @@ struct lua_tcp_dtor {
        struct lua_tcp_dtor *next;
 };
 
-#define LUA_TCP_FLAG_PARTIAL (1 << 0)
-#define LUA_TCP_FLAG_SHUTDOWN (1 << 2)
-#define LUA_TCP_FLAG_CONNECTED (1 << 3)
-#define LUA_TCP_FLAG_FINISHED (1 << 4)
-#define LUA_TCP_FLAG_SYNC (1 << 5)
-#define LUA_TCP_FLAG_RESOLVED (1 << 6)
+#define LUA_TCP_FLAG_PARTIAL (1u << 0u)
+#define LUA_TCP_FLAG_SHUTDOWN (1u << 2u)
+#define LUA_TCP_FLAG_CONNECTED (1u << 3u)
+#define LUA_TCP_FLAG_FINISHED (1u << 4u)
+#define LUA_TCP_FLAG_SYNC (1u << 5u)
+#define LUA_TCP_FLAG_RESOLVED (1u << 6u)
+#define LUA_TCP_FLAG_SSL (1u << 7u)
+#define LUA_TCP_FLAG_SSL_NOVERIFY (1u << 8u)
 
 #undef TCP_DEBUG_REFS
 #ifdef TCP_DEBUG_REFS
@@ -345,6 +359,8 @@ struct lua_tcp_cbdata {
        struct rspamd_symcache_item *item;
        struct thread_entry *thread;
        struct rspamd_config *cfg;
+       struct rspamd_ssl_connection *ssl_conn;
+       gchar *hostname;
        gboolean eof;
 };
 
@@ -445,6 +461,11 @@ lua_tcp_fin (gpointer arg)
                luaL_unref (cbd->cfg->lua_state, LUA_REGISTRYINDEX, cbd->connect_cb);
        }
 
+       if (cbd->ssl_conn) {
+               /* TODO: postpone close in case ssl is used ! */
+               rspamd_ssl_connection_free (cbd->ssl_conn);
+       }
+
        if (cbd->fd != -1) {
                event_del (&cbd->ev);
                close (cbd->fd);
@@ -464,6 +485,7 @@ lua_tcp_fin (gpointer arg)
        }
 
        g_byte_array_unref (cbd->in);
+       g_free (cbd->hostname);
        g_free (cbd);
 }
 
@@ -817,7 +839,13 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd)
 #ifdef MSG_NOSIGNAL
        flags = MSG_NOSIGNAL;
 #endif
-       r = sendmsg (cbd->fd, &msg, flags);
+
+       if (cbd->ssl_conn) {
+               r = rspamd_ssl_writev (cbd->ssl_conn, msg.msg_iov, msg.msg_iovlen);
+       }
+       else {
+               r = sendmsg (cbd->fd, &msg, flags);
+       }
 
        if (r == -1) {
                lua_tcp_push_error (cbd, FALSE, "IO write error while trying to write %d "
@@ -1006,7 +1034,13 @@ lua_tcp_handler (int fd, short what, gpointer ud)
        event_type = rh->type;
 
        if (what == EV_READ) {
-               r = read (cbd->fd, inbuf, sizeof (inbuf));
+               if (cbd->ssl_conn) {
+                       r = rspamd_ssl_read (cbd->ssl_conn, inbuf, sizeof (inbuf));
+               }
+               else {
+                       r = read (cbd->fd, inbuf, sizeof (inbuf));
+               }
+
                lua_tcp_process_read (cbd, inbuf, r);
        }
        else if (what == EV_WRITE) {
@@ -1189,6 +1223,21 @@ lua_tcp_register_watcher (struct lua_tcp_cbdata *cbd)
        }
 }
 
+static void
+lua_tcp_ssl_on_error (gpointer ud, GError *err)
+{
+       struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *)ud;
+
+       if (err) {
+               lua_tcp_push_error (cbd, TRUE, "ssl error: %s", err->message);
+       }
+       else {
+               lua_tcp_push_error (cbd, TRUE, "ssl error: unknown error");
+       }
+
+       TCP_RELEASE (cbd);
+}
+
 static gboolean
 lua_tcp_make_connection (struct lua_tcp_cbdata *cbd)
 {
@@ -1200,19 +1249,23 @@ lua_tcp_make_connection (struct lua_tcp_cbdata *cbd)
        if (fd == -1) {
                if (cbd->session) {
                        rspamd_mempool_t *pool = rspamd_session_mempool (cbd->session);
-                       msg_info_pool ("cannot connect to %s: %s",
+                       msg_info_pool ("cannot connect to %s (%s): %s",
                                        rspamd_inet_address_to_string (cbd->addr),
+                                       cbd->hostname,
                                        strerror (errno));
                }
                else {
-                       msg_info ("cannot connect to %s: %s",
+                       msg_info ("cannot connect to %s (%s): %s",
                                        rspamd_inet_address_to_string (cbd->addr),
+                                       cbd->hostname,
                                        strerror (errno));
                }
 
                return FALSE;
        }
 
+       cbd->fd = fd;
+
 #if 0
        if (!(cbd->flags & LUA_TCP_FLAG_RESOLVED)) {
                /* We come here without resolving, so we need to add a watcher */
@@ -1223,10 +1276,39 @@ lua_tcp_make_connection (struct lua_tcp_cbdata *cbd)
        }
 #endif
 
-       lua_tcp_register_event (cbd);
+       if (cbd->flags & LUA_TCP_FLAG_SSL) {
+               gpointer ssl_ctx;
+               gboolean verify_peer;
+
+               if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) {
+                       ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify;
+                       verify_peer = FALSE;
+               }
+               else {
+                       ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx;
+                       verify_peer = TRUE;
+               }
+
+               event_base_set (cbd->ev_base, &cbd->ev);
+               cbd->ssl_conn =
+                               rspamd_ssl_connection_new (ssl_ctx, cbd->ev_base, verify_peer);
+
+               if (!rspamd_ssl_connect_fd (cbd->ssl_conn, fd, cbd->hostname, &cbd->ev,
+                               &cbd->tv, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) {
+                       lua_tcp_push_error (cbd, TRUE, "ssl connection failed: %s",
+                                       strerror (errno));
+
+                       return FALSE;
+               }
+               else {
+                       lua_tcp_register_event (cbd);
+               }
+       }
+       else {
+               lua_tcp_register_event (cbd);
+               lua_tcp_plan_handler_event (cbd, TRUE, TRUE);
+       }
 
-       cbd->fd = fd;
-       lua_tcp_plan_handler_event (cbd, TRUE, TRUE);
 
        return TRUE;
 }
@@ -1359,7 +1441,8 @@ lua_tcp_request (lua_State *L)
        guint niov = 0, total_out;
        guint64 h;
        gdouble timeout = default_tcp_timeout;
-       gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE;
+       gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE,
+               ssl = FALSE, ssl_noverify = FALSE;
 
        if (lua_type (L, 1) == LUA_TTABLE) {
                lua_pushstring (L, "host");
@@ -1486,6 +1569,20 @@ lua_tcp_request (lua_State *L)
                }
                lua_pop (L, 1);
 
+               lua_pushstring (L, "ssl");
+               lua_gettable (L, -2);
+               if (lua_type (L, -1) == LUA_TBOOLEAN) {
+                       ssl = lua_toboolean (L, -1);
+               }
+               lua_pop (L, 1);
+
+               lua_pushstring (L, "ssl_noverify");
+               lua_gettable (L, -2);
+               if (lua_type (L, -1) == LUA_TBOOLEAN) {
+                       ssl_noverify = lua_toboolean (L, -1);
+               }
+               lua_pop (L, 1);
+
                lua_pushstring (L, "on_connect");
                lua_gettable (L, -2);
 
@@ -1568,6 +1665,7 @@ lua_tcp_request (lua_State *L)
        h = rspamd_random_uint64_fast ();
        rspamd_snprintf (cbd->tag, sizeof (cbd->tag), "%uxL", h);
        cbd->handlers = g_queue_new ();
+       cbd->hostname = g_strdup (host);
 
        if (total_out > 0) {
                struct lua_tcp_handler *wh;
@@ -1598,6 +1696,14 @@ lua_tcp_request (lua_State *L)
        cbd->fd = -1;
        cbd->port = port;
 
+       if (ssl) {
+               cbd->flags |= LUA_TCP_FLAG_SSL;
+
+               if (ssl_noverify) {
+                       cbd->flags |= LUA_TCP_FLAG_SSL_NOVERIFY;
+               }
+       }
+
        if (do_read) {
                cbd->in = g_byte_array_sized_new (8192);
        }
@@ -2236,6 +2342,11 @@ lua_tcp_sync_shutdown (lua_State *L)
        return 0;
 }
 
+static gint
+lua_tcp_starttls (lua_State * L)
+{
+       return 0;
+}
 
 static gint
 lua_tcp_sync_gc (lua_State * L)