aboutsummaryrefslogtreecommitdiffstats
path: root/src/lua/lua_tcp.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua/lua_tcp.c')
-rw-r--r--src/lua/lua_tcp.c139
1 files changed, 125 insertions, 14 deletions
diff --git a/src/lua/lua_tcp.c b/src/lua/lua_tcp.c
index 1c789a9d3..1e19efd97 100644
--- a/src/lua/lua_tcp.c
+++ b/src/lua/lua_tcp.c
@@ -15,6 +15,7 @@
*/
#include "lua_common.h"
#include "lua_thread_pool.h"
+#include "libutil/ssl_util.h"
#include "utlist.h"
#include "unix-std.h"
#include <math.h>
@@ -117,6 +118,8 @@ local function http_simple_tcp_symbol(task)
host = '127.0.0.1',
timeout = 20,
port = 18080,
+ ssl = false, -- If SSL connection is needed
+ ssl_verify = true, -- set to false if verify is not needed
}
is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n')
@@ -189,6 +192,14 @@ LUA_FUNCTION_DEF (tcp, add_write);
*/
LUA_FUNCTION_DEF (tcp, shift_callback);
+/***
+ * @method tcp:starttls([no_verify])
+ *
+ * Starts tls connection
+ * @param {boolean} no_verify used to skip ssl verification
+ */
+LUA_FUNCTION_DEF (tcp, starttls);
+
static const struct luaL_reg tcp_libf[] = {
LUA_INTERFACE_DEF (tcp, request),
{"new", lua_tcp_request},
@@ -203,6 +214,7 @@ static const struct luaL_reg tcp_libm[] = {
LUA_INTERFACE_DEF (tcp, add_read),
LUA_INTERFACE_DEF (tcp, add_write),
LUA_INTERFACE_DEF (tcp, shift_callback),
+ LUA_INTERFACE_DEF (tcp, starttls),
{"__tostring", rspamd_lua_class_tostring},
{NULL, NULL}
};
@@ -302,12 +314,14 @@ struct lua_tcp_dtor {
struct lua_tcp_dtor *next;
};
-#define LUA_TCP_FLAG_PARTIAL (1 << 0)
-#define LUA_TCP_FLAG_SHUTDOWN (1 << 2)
-#define LUA_TCP_FLAG_CONNECTED (1 << 3)
-#define LUA_TCP_FLAG_FINISHED (1 << 4)
-#define LUA_TCP_FLAG_SYNC (1 << 5)
-#define LUA_TCP_FLAG_RESOLVED (1 << 6)
+#define LUA_TCP_FLAG_PARTIAL (1u << 0u)
+#define LUA_TCP_FLAG_SHUTDOWN (1u << 2u)
+#define LUA_TCP_FLAG_CONNECTED (1u << 3u)
+#define LUA_TCP_FLAG_FINISHED (1u << 4u)
+#define LUA_TCP_FLAG_SYNC (1u << 5u)
+#define LUA_TCP_FLAG_RESOLVED (1u << 6u)
+#define LUA_TCP_FLAG_SSL (1u << 7u)
+#define LUA_TCP_FLAG_SSL_NOVERIFY (1u << 8u)
#undef TCP_DEBUG_REFS
#ifdef TCP_DEBUG_REFS
@@ -345,6 +359,8 @@ struct lua_tcp_cbdata {
struct rspamd_symcache_item *item;
struct thread_entry *thread;
struct rspamd_config *cfg;
+ struct rspamd_ssl_connection *ssl_conn;
+ gchar *hostname;
gboolean eof;
};
@@ -445,6 +461,11 @@ lua_tcp_fin (gpointer arg)
luaL_unref (cbd->cfg->lua_state, LUA_REGISTRYINDEX, cbd->connect_cb);
}
+ if (cbd->ssl_conn) {
+ /* TODO: postpone close in case ssl is used ! */
+ rspamd_ssl_connection_free (cbd->ssl_conn);
+ }
+
if (cbd->fd != -1) {
event_del (&cbd->ev);
close (cbd->fd);
@@ -464,6 +485,7 @@ lua_tcp_fin (gpointer arg)
}
g_byte_array_unref (cbd->in);
+ g_free (cbd->hostname);
g_free (cbd);
}
@@ -817,7 +839,13 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd)
#ifdef MSG_NOSIGNAL
flags = MSG_NOSIGNAL;
#endif
- r = sendmsg (cbd->fd, &msg, flags);
+
+ if (cbd->ssl_conn) {
+ r = rspamd_ssl_writev (cbd->ssl_conn, msg.msg_iov, msg.msg_iovlen);
+ }
+ else {
+ r = sendmsg (cbd->fd, &msg, flags);
+ }
if (r == -1) {
lua_tcp_push_error (cbd, FALSE, "IO write error while trying to write %d "
@@ -1006,7 +1034,13 @@ lua_tcp_handler (int fd, short what, gpointer ud)
event_type = rh->type;
if (what == EV_READ) {
- r = read (cbd->fd, inbuf, sizeof (inbuf));
+ if (cbd->ssl_conn) {
+ r = rspamd_ssl_read (cbd->ssl_conn, inbuf, sizeof (inbuf));
+ }
+ else {
+ r = read (cbd->fd, inbuf, sizeof (inbuf));
+ }
+
lua_tcp_process_read (cbd, inbuf, r);
}
else if (what == EV_WRITE) {
@@ -1189,6 +1223,21 @@ lua_tcp_register_watcher (struct lua_tcp_cbdata *cbd)
}
}
+static void
+lua_tcp_ssl_on_error (gpointer ud, GError *err)
+{
+ struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *)ud;
+
+ if (err) {
+ lua_tcp_push_error (cbd, TRUE, "ssl error: %s", err->message);
+ }
+ else {
+ lua_tcp_push_error (cbd, TRUE, "ssl error: unknown error");
+ }
+
+ TCP_RELEASE (cbd);
+}
+
static gboolean
lua_tcp_make_connection (struct lua_tcp_cbdata *cbd)
{
@@ -1200,19 +1249,23 @@ lua_tcp_make_connection (struct lua_tcp_cbdata *cbd)
if (fd == -1) {
if (cbd->session) {
rspamd_mempool_t *pool = rspamd_session_mempool (cbd->session);
- msg_info_pool ("cannot connect to %s: %s",
+ msg_info_pool ("cannot connect to %s (%s): %s",
rspamd_inet_address_to_string (cbd->addr),
+ cbd->hostname,
strerror (errno));
}
else {
- msg_info ("cannot connect to %s: %s",
+ msg_info ("cannot connect to %s (%s): %s",
rspamd_inet_address_to_string (cbd->addr),
+ cbd->hostname,
strerror (errno));
}
return FALSE;
}
+ cbd->fd = fd;
+
#if 0
if (!(cbd->flags & LUA_TCP_FLAG_RESOLVED)) {
/* We come here without resolving, so we need to add a watcher */
@@ -1223,10 +1276,39 @@ lua_tcp_make_connection (struct lua_tcp_cbdata *cbd)
}
#endif
- lua_tcp_register_event (cbd);
+ if (cbd->flags & LUA_TCP_FLAG_SSL) {
+ gpointer ssl_ctx;
+ gboolean verify_peer;
+
+ if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) {
+ ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify;
+ verify_peer = FALSE;
+ }
+ else {
+ ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx;
+ verify_peer = TRUE;
+ }
+
+ event_base_set (cbd->ev_base, &cbd->ev);
+ cbd->ssl_conn =
+ rspamd_ssl_connection_new (ssl_ctx, cbd->ev_base, verify_peer);
+
+ if (!rspamd_ssl_connect_fd (cbd->ssl_conn, fd, cbd->hostname, &cbd->ev,
+ &cbd->tv, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) {
+ lua_tcp_push_error (cbd, TRUE, "ssl connection failed: %s",
+ strerror (errno));
+
+ return FALSE;
+ }
+ else {
+ lua_tcp_register_event (cbd);
+ }
+ }
+ else {
+ lua_tcp_register_event (cbd);
+ lua_tcp_plan_handler_event (cbd, TRUE, TRUE);
+ }
- cbd->fd = fd;
- lua_tcp_plan_handler_event (cbd, TRUE, TRUE);
return TRUE;
}
@@ -1359,7 +1441,8 @@ lua_tcp_request (lua_State *L)
guint niov = 0, total_out;
guint64 h;
gdouble timeout = default_tcp_timeout;
- gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE;
+ gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE,
+ ssl = FALSE, ssl_noverify = FALSE;
if (lua_type (L, 1) == LUA_TTABLE) {
lua_pushstring (L, "host");
@@ -1486,6 +1569,20 @@ lua_tcp_request (lua_State *L)
}
lua_pop (L, 1);
+ lua_pushstring (L, "ssl");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TBOOLEAN) {
+ ssl = lua_toboolean (L, -1);
+ }
+ lua_pop (L, 1);
+
+ lua_pushstring (L, "ssl_noverify");
+ lua_gettable (L, -2);
+ if (lua_type (L, -1) == LUA_TBOOLEAN) {
+ ssl_noverify = lua_toboolean (L, -1);
+ }
+ lua_pop (L, 1);
+
lua_pushstring (L, "on_connect");
lua_gettable (L, -2);
@@ -1568,6 +1665,7 @@ lua_tcp_request (lua_State *L)
h = rspamd_random_uint64_fast ();
rspamd_snprintf (cbd->tag, sizeof (cbd->tag), "%uxL", h);
cbd->handlers = g_queue_new ();
+ cbd->hostname = g_strdup (host);
if (total_out > 0) {
struct lua_tcp_handler *wh;
@@ -1598,6 +1696,14 @@ lua_tcp_request (lua_State *L)
cbd->fd = -1;
cbd->port = port;
+ if (ssl) {
+ cbd->flags |= LUA_TCP_FLAG_SSL;
+
+ if (ssl_noverify) {
+ cbd->flags |= LUA_TCP_FLAG_SSL_NOVERIFY;
+ }
+ }
+
if (do_read) {
cbd->in = g_byte_array_sized_new (8192);
}
@@ -2236,6 +2342,11 @@ lua_tcp_sync_shutdown (lua_State *L)
return 0;
}
+static gint
+lua_tcp_starttls (lua_State * L)
+{
+ return 0;
+}
static gint
lua_tcp_sync_gc (lua_State * L)