]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Rework lua tcp module
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 17 Sep 2016 12:22:20 +0000 (13:22 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Sat, 17 Sep 2016 12:22:20 +0000 (13:22 +0100)
src/lua/lua_tcp.c

index 4c202787fffd47ba896e5ee6863e7d4e7885812d..38d14adfc9e45e716cf40ff2be1fa70cef013d30 100644 (file)
@@ -17,6 +17,7 @@
 #include "buffer.h"
 #include "dns.h"
 #include "utlist.h"
+#include "ref.h"
 #include "unix-std.h"
 
 static void lua_tcp_handler (int fd, short what, gpointer ud);
@@ -49,8 +50,22 @@ end
 
 LUA_FUNCTION_DEF (tcp, request);
 
+/***
+ * @method tcp:close()
+ *
+ * Closes TCP connection
+ */
+LUA_FUNCTION_DEF (tcp, close);
+LUA_FUNCTION_DEF (tcp, gc);
+
 static const struct luaL_reg tcp_libf[] = {
        LUA_INTERFACE_DEF (tcp, request),
+       {NULL, NULL}
+};
+
+static const struct luaL_reg tcp_libm[] = {
+       LUA_INTERFACE_DEF (tcp, close),
+       {"__gc", lua_tcp_gc},
        {"__tostring", rspamd_lua_class_tostring},
        {NULL, NULL}
 };
@@ -67,14 +82,17 @@ struct lua_tcp_cbdata {
        gchar *stop_pattern;
        struct rspamd_async_watcher *w;
        struct event ev;
+       ref_entry_t ref;
        gint fd;
        gint cbref;
+       gint connect_cb;
        guint iovlen;
        guint pos;
        guint total;
+       guint16 port;
        gboolean partial;
        gboolean do_shutdown;
-       guint16 port;
+       gboolean connected;
 };
 
 static const int default_tcp_timeout = 5000;
@@ -110,6 +128,14 @@ lua_tcp_fin (gpointer arg)
        g_slice_free1 (sizeof (struct lua_tcp_cbdata), cbd);
 }
 
+static struct lua_tcp_cbdata *
+lua_check_tcp (lua_State *L, gint pos)
+{
+       void *ud = rspamd_lua_check_udata (L, pos, "rspamd{tcp}");
+       luaL_argcheck (L, ud != NULL, pos, "'tcp' expected");
+       return ud ? *((struct lua_tcp_cbdata **)ud) : NULL;
+}
+
 static void
 lua_tcp_maybe_free (struct lua_tcp_cbdata *cbd)
 {
@@ -126,13 +152,24 @@ static void
 lua_tcp_push_error (struct lua_tcp_cbdata *cbd, const char *err, ...)
 {
        va_list ap;
+       struct lua_tcp_cbdata **pcbd;
 
-       va_start (ap, err);
        lua_rawgeti (cbd->L, LUA_REGISTRYINDEX, cbd->cbref);
+
+       /* Error message */
+       va_start (ap, err);
        lua_pushvfstring (cbd->L, err, ap);
        va_end (ap);
 
-       if (lua_pcall (cbd->L, 1, 0, 0) != 0) {
+       /* Body */
+       lua_pushnil (cbd->L);
+       /* Connection */
+       pcbd = lua_newuserdata (cbd->L, sizeof (*pcbd));
+       *pcbd = cbd;
+       REF_RETAIN (cbd);
+       rspamd_lua_setclass (cbd->L, "rspamd{tcp}", -1);
+
+       if (lua_pcall (cbd->L, 3, 0, 0) != 0) {
                msg_info ("callback call failed: %s", lua_tostring (cbd->L, -1));
                lua_pop (cbd->L, 1);
        }
@@ -142,6 +179,7 @@ static void
 lua_tcp_push_data (struct lua_tcp_cbdata *cbd, const gchar *str, gsize len)
 {
        struct rspamd_lua_text *t;
+       struct lua_tcp_cbdata **pcbd;
 
        lua_rawgeti (cbd->L, LUA_REGISTRYINDEX, cbd->cbref);
        /* Error */
@@ -152,8 +190,13 @@ lua_tcp_push_data (struct lua_tcp_cbdata *cbd, const gchar *str, gsize len)
        t->start = str;
        t->len = len;
        t->own = FALSE;
+       /* Connection */
+       pcbd = lua_newuserdata (cbd->L, sizeof (*pcbd));
+       *pcbd = cbd;
+       rspamd_lua_setclass (cbd->L, "rspamd{tcp}", -1);
+       REF_RETAIN (cbd);
 
-       if (lua_pcall (cbd->L, 2, 0, 0) != 0) {
+       if (lua_pcall (cbd->L, 3, 0, 0) != 0) {
                msg_info ("callback call failed: %s", lua_tostring (cbd->L, -1));
                lua_pop (cbd->L, 1);
        }
@@ -208,7 +251,8 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd)
        if (r == -1) {
                lua_tcp_push_error (cbd, "IO write error while trying to write %d "
                                "bytes: %s", (gint)remain, strerror (errno));
-               lua_tcp_maybe_free (cbd);
+               REF_RELEASE (cbd);
+
                return;
        }
        else {
@@ -257,6 +301,8 @@ lua_tcp_handler (int fd, short what, gpointer ud)
        gssize r;
        guint slen;
 
+       REF_RETAIN (cbd);
+
        if (what == EV_READ) {
                g_assert (cbd->partial || cbd->in != NULL);
 
@@ -285,7 +331,7 @@ lua_tcp_handler (int fd, short what, gpointer ud)
                                }
                        }
 
-                       lua_tcp_maybe_free (cbd);
+                       REF_RELEASE (cbd);
                }
                else {
                        if (cbd->partial) {
@@ -301,7 +347,7 @@ lua_tcp_handler (int fd, short what, gpointer ud)
                                                if (memcmp (cbd->stop_pattern, cbd->in->str +
                                                                (cbd->in->len - slen), slen) == 0) {
                                                        lua_tcp_push_data (cbd, cbd->in->str, cbd->in->len);
-                                                       lua_tcp_maybe_free (cbd);
+                                                       REF_RELEASE (cbd);
                                                }
                                        }
                                }
@@ -309,18 +355,39 @@ lua_tcp_handler (int fd, short what, gpointer ud)
                }
        }
        else if (what == EV_WRITE) {
+               if (!cbd->connected) {
+                       cbd->connected = TRUE;
+
+                       if (cbd->connect_cb != -1) {
+                               struct lua_tcp_cbdata **pcbd;
+
+                               lua_rawgeti (cbd->L, LUA_REGISTRYINDEX, cbd->connect_cb);
+                               pcbd = lua_newuserdata (cbd->L, sizeof (*pcbd));
+                               *pcbd = cbd;
+                               REF_RETAIN (cbd);
+                               rspamd_lua_setclass (cbd->L, "rspamd{tcp}", -1);
+
+                               if (lua_pcall (cbd->L, 1, 0, 0) != 0) {
+                                       msg_info ("callback call failed: %s", lua_tostring (cbd->L, -1));
+                                       lua_pop (cbd->L, 1);
+                               }
+                       }
+               }
+
                lua_tcp_write_helper (cbd);
        }
 #ifdef EV_CLOSED
        else if (what == EV_CLOSED) {
                lua_tcp_push_error (cbd, "Remote peer has closed the connection");
-               lua_tcp_maybe_free (cbd);
+               REF_RELEASE (cbd);
        }
 #endif
        else {
                lua_tcp_push_error (cbd, "IO timeout");
-               lua_tcp_maybe_free (cbd);
+               REF_RELEASE (cbd);
        }
+
+       REF_RELEASE (cbd);
 }
 
 static gboolean
@@ -354,7 +421,7 @@ lua_tcp_dns_handler (struct rdns_reply *reply, gpointer ud)
                rn = rdns_request_get_name (reply->request, NULL);
                lua_tcp_push_error (cbd, "unable to resolve host: %s",
                                rn->name);
-               lua_tcp_maybe_free (cbd);
+               REF_RETAIN (cbd);
        }
        else {
                if (reply->entries->type == RDNS_REQUEST_A) {
@@ -371,7 +438,7 @@ lua_tcp_dns_handler (struct rdns_reply *reply, gpointer ud)
                if (!lua_tcp_make_connection (cbd)) {
                        lua_tcp_push_error (cbd, "unable to make connection to the host %s",
                                        rspamd_inet_address_to_string (cbd->addr));
-                       lua_tcp_maybe_free (cbd);
+                       REF_RETAIN (cbd);
                }
        }
 }
@@ -443,7 +510,7 @@ lua_tcp_request (lua_State *L)
        const gchar *host;
        gchar *stop_pattern = NULL;
        guint port;
-       gint cbref, tp;
+       gint cbref, tp, conn_cbref = -1;
        struct event_base *ev_base;
        struct lua_tcp_cbdata *cbd;
        struct rspamd_dns_resolver *resolver;
@@ -529,6 +596,10 @@ lua_tcp_request (lua_State *L)
                        lua_pop (L, 1);
                }
 
+               if (pool == NULL) {
+                       return luaL_error (L, "tcp request has no memory pool associated");
+               }
+
                lua_pushstring (L, "timeout");
                lua_gettable (L, -2);
                if (lua_type (L, -1) == LUA_TNUMBER) {
@@ -557,11 +628,14 @@ lua_tcp_request (lua_State *L)
                }
                lua_pop (L, 1);
 
-               if (pool == NULL) {
+               lua_pushstring (L, "on_connect");
+               lua_gettable (L, -2);
+
+               if (lua_type (L, -1) == LUA_TFUNCTION) {
+                       conn_cbref = luaL_ref (L, LUA_REGISTRYINDEX);
+               }
+               else {
                        lua_pop (L, 1);
-                       msg_err ("tcp request has no memory pool associated");
-                       lua_pushboolean (L, FALSE);
-                       return 1;
                }
 
                lua_pushstring (L, "data");
@@ -633,6 +707,8 @@ lua_tcp_request (lua_State *L)
        cbd->pos = 0;
        cbd->port = port;
        cbd->stop_pattern = stop_pattern;
+       cbd->connect_cb = conn_cbref;
+       REF_INIT_RETAIN (cbd, lua_tcp_maybe_free);
 
        if (session) {
                cbd->session = session;
@@ -648,7 +724,7 @@ lua_tcp_request (lua_State *L)
                rspamd_inet_address_set_port (cbd->addr, port);
                /* Host is numeric IP, no need to resolve */
                if (!lua_tcp_make_connection (cbd)) {
-                       lua_tcp_maybe_free (cbd);
+                       REF_RELEASE (cbd);
                        lua_pushboolean (L, FALSE);
 
                        return 1;
@@ -659,14 +735,14 @@ lua_tcp_request (lua_State *L)
                        if (!make_dns_request (resolver, session, NULL, lua_tcp_dns_handler, cbd,
                                        RDNS_REQUEST_A, host)) {
                                lua_tcp_push_error (cbd, "cannot resolve host: %s", host);
-                               lua_tcp_maybe_free (cbd);
+                               REF_RETAIN (cbd);
                        }
                }
                else {
                        if (!make_dns_request_task (task, lua_tcp_dns_handler, cbd,
                                        RDNS_REQUEST_A, host)) {
                                lua_tcp_push_error (cbd, "cannot resolve host: %s", host);
-                               lua_tcp_maybe_free (cbd);
+                               REF_RELEASE (cbd);
                        }
                }
        }
@@ -675,6 +751,30 @@ lua_tcp_request (lua_State *L)
        return 1;
 }
 
+static gint
+lua_tcp_close (lua_State *L)
+{
+       struct lua_tcp_cbdata *cbd = lua_check_tcp (L, 1);
+
+       if (cbd == NULL) {
+               return luaL_error (L, "invalid arguments");
+       }
+
+       REF_RELEASE (cbd);
+
+       return 0;
+}
+
+static gint
+lua_tcp_gc (lua_State *L)
+{
+       struct lua_tcp_cbdata *cbd = lua_check_tcp (L, 1);
+
+       REF_RELEASE (cbd);
+
+       return 0;
+}
+
 static gint
 lua_load_tcp (lua_State * L)
 {
@@ -688,4 +788,6 @@ void
 luaopen_tcp (lua_State * L)
 {
        rspamd_lua_add_preload (L, "rspamd_tcp", lua_load_tcp);
+       rspamd_lua_new_class (L, "rspamd{tcp}", tcp_libm);
+       lua_pop (L, 1);
 }