From: Mikhail Galanin Date: Thu, 30 Aug 2018 15:50:55 +0000 (+0100) Subject: [Minor] Added coroutines support for TCP library X-Git-Tag: 1.8.0~187^2~8 X-Git-Url: https://source.dussan.org/?a=commitdiff_plain;h=a78803aeb558c0ebb9ada2a0f71f960ac31f373d;p=rspamd.git [Minor] Added coroutines support for TCP library --- diff --git a/lualib/lua_tcp_sync.lua b/lualib/lua_tcp_sync.lua new file mode 100644 index 000000000..b506338f1 --- /dev/null +++ b/lualib/lua_tcp_sync.lua @@ -0,0 +1,213 @@ +local rspamd_tcp = require "rspamd_tcp" +local lua_util = require "lua_util" + +local exports = {} +local N = 'tcp_sync' + +local tcp_sync = {_conn = nil, _data = '', _eof = false, _addr = ''} +local metatable = { + __tostring = function (self) + return "class {tcp_sync connect to: " .. self._addr .. "}" + end +} + +function tcp_sync.new(connection) + local self = {} + + for name, method in pairs(tcp_sync) do + if name ~= 'new' then + self[name] = method + end + end + + self._conn = connection + + setmetatable(self, metatable) + + return self +end + +--[[[ +-- @function tcp_sync.read_once() +-- +-- Acts exactly like low-level tcp_sync.read_once() +-- the only exception is that if there is some pending data, +-- it's returned immediately and no underlying call is performed +-- +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:read_once() + local is_ok, data + if self._data:len() > 0 then + data = self._data + self._data = nil + return true, data + end + + is_ok, data = self._conn:read_once() + + return is_ok, data +end + +--[[[ +-- @function tcp_sync.read_until(pattern) +-- +-- Reads data from the connection until pattern is found +-- returns all bytes before the pattern +-- +-- @param {pattern} Read data until pattern is found +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- @example +-- +--]] +function tcp_sync:read_until(pattern) + repeat + local pos_start, pos_end = self._data:find(pattern, 1, true) + if pos_start then + local data = self._data:sub(1, pos_start - 1) + self._data = self._data:sub(pos_end + 1) + return true, data + end + + local is_ok, more_data = self._conn:read_once() + if not is_ok then + return is_ok, more_data + end + + self._data = self._data .. more_data + until false +end + +--[[[ +-- @function tcp_sync.read_bytes(n) +-- +-- Reads {n} bytes from the stream +-- +-- @param {n} Number of bytes to read +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:read_bytes(n) + repeat + if self._data:len() >= n then + local data = self._data:sub(1, n) + self._data = self._data:sub(n + 1) + return true, data + end + + local is_ok, more_data = self._conn:read_once() + if not is_ok then + return is_ok, more_data + end + + self._data = self._data .. more_data + until false +end + +--[[[ +-- @function tcp_sync.read_until_eof(n) +-- +-- Reads stream until EOF is reached +-- +-- @return +-- true, {data} if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:read_until_eof() + while not self:eof() do + local is_ok, more_data = self._conn:read_once() + if not is_ok then + if self:eof() then + -- this error is EOF (connection terminated) + -- exactly what we were waiting for + break + end + return is_ok, more_data + end + self._data = self._data .. more_data + end + + local data = self._data + self._data = '' + return true, data +end + +--[[[ +-- @function tcp_sync.write(n) +-- +-- Writes data into the stream. +-- +-- @return +-- true if everything is fine +-- false, {error message} otherwise +-- +--]] +function tcp_sync:write(data) + return self._conn:write(data) +end + +--[[[ +-- @function tcp_sync.close() +-- +-- Closes the connection. If the connection was created with task, +-- this method is called automatically as soon as the task is done +-- Calling this method helps to prevent connections leak. +-- The object is finally destroyed by garbage collector. +-- +-- @return +-- +--]] +function tcp_sync:close() + return self._conn:close() +end + +--[[[ +-- @function tcp_sync.eof() +-- +-- @return +-- true if last "read" operation ended with EOF +-- false otherwise +-- +--]] +function tcp_sync:eof() + if not self._eof and self._conn:eof() then + self._eof = true + end + return self._eof +end + +--[[[ +-- @function tcp_sync.shutdown(n) +-- +-- half-close socket +-- +-- @return +-- +--]] +function tcp_sync:shutdown() + return self._conn:shutdown() +end + +exports.connect = function (args) + local is_ok, connection = rspamd_tcp.connect_sync(args) + if not is_ok then + return is_ok, connection + end + + local instance = tcp_sync.new(connection) + instance._addr = string.format("%s:%s", tostring(args.host), tostring(args.port)) + + lua_util.debugm(N, args.task, 'Connected to %s', instance._addr) + + return true, instance +end + +return exports \ No newline at end of file diff --git a/src/lua/lua_tcp.c b/src/lua/lua_tcp.c index 61a73acf8..851605be1 100644 --- a/src/lua/lua_tcp.c +++ b/src/lua/lua_tcp.c @@ -92,6 +92,61 @@ rspamd_config:register_symbol({ LUA_FUNCTION_DEF (tcp, request); +/*** + * @method connect_sync() + * + * Creates pseudo-synchronous TCP connection. + * Each method of the connection requiring IO, becames a yielding point, + * i.e. current thread Lua thread is get suspended and resumes as soon as IO is done + * + * This class represents low-level API, using of "lua_tcp_sync" module is recommended. + * + * @example + +local rspamd_tcp = require "rspamd_tcp" +local logger = require "rspamd_logger" + +local function http_simple_tcp_symbol(task) + + local err + local is_ok, connection = tcp_sync.connect { + task = task, + host = '127.0.0.1', + timeout = 20, + port = 18080, + } + + is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n') + + logger.errx(task, 'write %1, %2', is_ok, err) + if not is_ok then + logger.errx(task, 'write error: %1', err) + end + + local data + is_ok, data = connection:read_once(); + + logger.errx(task, 'read_once: is_ok: %1, data: %2', is_ok, data) + + is_ok, err = connection:write("POST /request2 HTTP/1.1\r\n\r\n") + logger.errx(task, 'write[2] %1, %2', is_ok, err) + + is_ok, data = connection:read_once(); + logger.errx(task, 'read_once[2]: is_ok %1, data: %2', is_ok, data) + + connection:close() +end + +rspamd_config:register_symbol({ + name = 'SIMPLE_TCP_TEST', + score = 1.0, + callback = http_simple_tcp_symbol, + no_squeeze = true +}) + * + */ +LUA_FUNCTION_DEF (tcp, connect_sync); + /*** * @method tcp:close() * @@ -135,6 +190,7 @@ static const struct luaL_reg tcp_libf[] = { LUA_INTERFACE_DEF (tcp, request), {"new", lua_tcp_request}, {"connect", lua_tcp_request}, + {"connect_sync", lua_tcp_connect_sync}, {NULL, NULL} }; @@ -148,6 +204,67 @@ static const struct luaL_reg tcp_libm[] = { {NULL, NULL} }; +/*** + * @method tcp:close() + * + * Closes TCP connection + */ +LUA_FUNCTION_DEF (tcp_sync, close); + +/*** + * @method set_timeout(timeout) + * + * Sets timeout for IO operations + */ +LUA_FUNCTION_DEF (tcp_sync, set_timeout); + +/*** + * @method read_once() + * + * Performs one read operation. If syscall returned with EAGAIN/EINT, + * restarts the operation, so it always returns either data or error. + */ +LUA_FUNCTION_DEF (tcp_sync, read_once); + +/*** + * @method eof() + * + * True if last IO operation ended with EOF, i.e. endpoint closed connection + */ +LUA_FUNCTION_DEF (tcp_sync, eof); + +/*** + * @method shutdown() + * + * Half-shutdown TCP connection + */ +LUA_FUNCTION_DEF (tcp_sync, shutdown); + +/*** + * @method write() + * + * Writes data into the stream. If syscall returned with EAGAIN/EINT + * restarts the operation. If performs write() until all the passed + * data is written completely. + */ +LUA_FUNCTION_DEF (tcp_sync, write); + +LUA_FUNCTION_DEF (tcp_sync, gc); + +static void lua_tcp_sync_session_dtor (gpointer ud); + +static const struct luaL_reg tcp_sync_libm[] = { + LUA_INTERFACE_DEF (tcp_sync, close), + LUA_INTERFACE_DEF (tcp_sync, set_timeout), + LUA_INTERFACE_DEF (tcp_sync, read_once), + LUA_INTERFACE_DEF (tcp_sync, write), + LUA_INTERFACE_DEF (tcp_sync, eof), + LUA_INTERFACE_DEF (tcp_sync, shutdown), + {"__gc", lua_tcp_sync_gc}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL} +}; + struct lua_tcp_read_handler { gchar *stop_pattern; guint plen; @@ -158,13 +275,14 @@ struct lua_tcp_write_handler { struct iovec *iov; guint iovlen; guint pos; - guint total; + gsize total_bytes; gint cbref; }; enum lua_tcp_handler_type { LUA_WANT_WRITE = 0, LUA_WANT_READ, + LUA_WANT_CONNECT // used only with sync connections }; struct lua_tcp_handler { @@ -185,6 +303,23 @@ struct lua_tcp_dtor { #define LUA_TCP_FLAG_SHUTDOWN (1 << 2) #define LUA_TCP_FLAG_CONNECTED (1 << 3) #define LUA_TCP_FLAG_FINISHED (1 << 4) +#define LUA_TCP_FLAG_SYNC (1 << 5) + +#undef TCP_DEBUG_REFS +#ifdef TCP_DEBUG_REFS +#define TCP_RETAIN(x) do { \ + msg_err ("retain ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RETAIN(x); \ +} while (0) + +#define TCP_RELEASE(x) do { \ + msg_err ("release ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RELEASE(x); \ +} while (0) +#else +#define TCP_RETAIN(x) REF_RETAIN(x) +#define TCP_RELEASE(x) REF_RELEASE(x) +#endif struct lua_tcp_cbdata { struct rspamd_async_session *session; @@ -204,8 +339,13 @@ struct lua_tcp_cbdata { struct lua_tcp_dtor *dtors; ref_entry_t ref; struct rspamd_task *task; + struct thread_entry *thread; + struct rspamd_config *cfg; + gboolean eof; }; +#define IS_SYNC(c) (((c)->flags & LUA_TCP_FLAG_SYNC) != 0) + #define msg_debug_tcp(...) rspamd_conditional_debug_fast (NULL, cbd->addr, \ rspamd_lua_tcp_log_id, "lua_tcp", cbd->tag, \ G_STRFUNC, \ @@ -216,6 +356,10 @@ INIT_LOG_MODULE(lua_tcp) static void lua_tcp_handler (int fd, short what, gpointer ud); static void lua_tcp_plan_handler_event (struct lua_tcp_cbdata *cbd, gboolean can_read, gboolean can_write); +static void lua_tcp_unregister_event (struct lua_tcp_cbdata *cbd); + +static void +lua_tcp_void_finalyser (gpointer arg) {} static const int default_tcp_timeout = 5000; @@ -257,7 +401,7 @@ lua_tcp_shift_handler (struct lua_tcp_cbdata *cbd) g_free (hdl->h.r.stop_pattern); } } - else { + else if (hdl->type == LUA_WANT_WRITE) { if (hdl->h.w.cbref) { luaL_unref (cbd->task->cfg->lua_state, LUA_REGISTRYINDEX, hdl->h.w.cbref); } @@ -266,6 +410,10 @@ lua_tcp_shift_handler (struct lua_tcp_cbdata *cbd) g_free (hdl->h.w.iov); } } + else { + msg_debug_tcp ("removing connect handler"); + /* LUA_WANT_CONNECT: it doesn't allocate anything, nothing to do here */ + } g_free (hdl); @@ -278,15 +426,25 @@ lua_tcp_fin (gpointer arg) struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *)arg; struct lua_tcp_dtor *dtor, *dttmp; - msg_debug_tcp ("finishing TCP connection"); + if (IS_SYNC (cbd) && cbd->task) { + /* + pointer is now becoming invalid, we should remove registered destructor, + all the necessary steps are done here + */ + rspamd_mempool_replace_destructor (cbd->task->task_pool, + lua_tcp_sync_session_dtor, cbd, NULL); + } - if (cbd->connect_cb) { + msg_debug_tcp ("finishing TCP %s connection", IS_SYNC (cbd) ? "sync" : "async"); + + if (cbd->connect_cb != -1) { luaL_unref (cbd->task->cfg->lua_state, LUA_REGISTRYINDEX, cbd->connect_cb); } if (cbd->fd != -1) { event_del (&cbd->ev); close (cbd->fd); + cbd->fd = -1; } if (cbd->addr) { @@ -316,12 +474,25 @@ lua_check_tcp (lua_State *L, gint pos) static void lua_tcp_maybe_free (struct lua_tcp_cbdata *cbd) { - if (cbd->async_ev) { - rspamd_session_watcher_pop (cbd->session, cbd->w); - rspamd_session_remove_event (cbd->session, lua_tcp_fin, cbd); + if (IS_SYNC (cbd)) { + /* + * in this mode, we don't remove object, we only remove the event + * Object is owned by lua and will be destroyed on __gc() + */ + if (cbd->async_ev) { + rspamd_session_watcher_pop (cbd->session, cbd->w); + rspamd_session_remove_event (cbd->session, lua_tcp_void_finalyser, cbd); + } + cbd->async_ev = NULL; } else { - lua_tcp_fin (cbd); + if (cbd->async_ev) { + rspamd_session_watcher_pop (cbd->session, cbd->w); + rspamd_session_remove_event (cbd->session, lua_tcp_fin, cbd); + } + else { + lua_tcp_fin (cbd); + } } } @@ -331,6 +502,8 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, const char *err, ...) __attribute__ ((format(printf, 3, 4))); #endif +static void lua_tcp_resume_thread_error_argp (struct lua_tcp_cbdata *cbd, const gchar *error, va_list argp); + static void lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, const char *err, ...) @@ -342,7 +515,17 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, struct lua_callback_state cbs; lua_State *L; - lua_thread_pool_prepare_callback (cbd->task->cfg->lua_thread_pool, &cbs); + if (cbd->thread) { + va_start (ap, err); + va_copy (ap_copy, ap); + lua_tcp_resume_thread_error_argp (cbd, err, ap_copy); + va_end (ap_copy); + va_end (ap); + + return; + } + + lua_thread_pool_prepare_callback (cbd->cfg->lua_thread_pool, &cbs); L = cbs.L; va_start (ap, err); @@ -376,7 +559,7 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, pcbd = lua_newuserdata (L, sizeof (*pcbd)); *pcbd = cbd; rspamd_lua_setclass (L, "rspamd{tcp}", -1); - REF_RETAIN (cbd); + TCP_RETAIN (cbd); if (lua_pcall (L, 3, 0, 0) != 0) { msg_info ("callback call failed: %s", lua_tostring (L, -1)); @@ -384,7 +567,7 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, lua_settop (L, top); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } if (!is_fatal) { @@ -401,6 +584,8 @@ lua_tcp_push_error (struct lua_tcp_cbdata *cbd, gboolean is_fatal, lua_thread_pool_restore_callback (&cbs); } +static void lua_tcp_resume_thread (struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len); + static void lua_tcp_push_data (struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) { @@ -411,6 +596,11 @@ lua_tcp_push_data (struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) struct lua_callback_state cbs; lua_State *L; + if (cbd->thread) { + lua_tcp_resume_thread (cbd, str, len); + return; + } + lua_thread_pool_prepare_callback (cbd->task->cfg->lua_thread_pool, &cbs); L = cbs.L; @@ -448,19 +638,71 @@ lua_tcp_push_data (struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) *pcbd = cbd; rspamd_lua_setclass (L, "rspamd{tcp}", -1); - REF_RETAIN (cbd); + TCP_RETAIN (cbd); if (lua_pcall (L, arg_cnt, 0, 0) != 0) { msg_info ("callback call failed: %s", lua_tostring (L, -1)); } lua_settop (L, top); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } lua_thread_pool_restore_callback (&cbs); } +static void +lua_tcp_resume_thread_error_argp (struct lua_tcp_cbdata *cbd, const gchar *error, va_list argp) +{ + struct thread_entry *thread = cbd->thread; + lua_State *L = thread->lua_state; + + lua_pushboolean (L, FALSE); + lua_pushvfstring (L, error, argp); + + lua_tcp_shift_handler (cbd); + // lua_tcp_unregister_event (cbd); + lua_thread_pool_set_running_entry (cbd->cfg->lua_thread_pool, cbd->thread); + lua_thread_resume (thread, 2); + TCP_RELEASE (cbd); +} + +static void +lua_tcp_resume_thread (struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) +{ +/* + * typical call returns: + * + * read: + * error: + * (nil, error message) + * got data: + * (true, data) + * write/connect: + * error: + * (nil, error message) + * wrote + * (true) + */ + + lua_State *L = cbd->thread->lua_state; + struct lua_tcp_handler *hdl; + hdl = g_queue_peek_head (cbd->handlers); + + lua_pushboolean (L, TRUE); + if (hdl->type == LUA_WANT_READ) { + lua_pushlstring (L, str, len); + } + else { + lua_pushnil (L); + } + lua_tcp_shift_handler (cbd); + lua_thread_pool_set_running_entry (cbd->cfg->lua_thread_pool, cbd->thread); + lua_thread_resume (cbd->thread, 2); + + TCP_RELEASE (cbd); +} + static void lua_tcp_plan_read (struct lua_tcp_cbdata *cbd) { @@ -475,6 +717,28 @@ lua_tcp_plan_read (struct lua_tcp_cbdata *cbd) event_add (&cbd->ev, &cbd->tv); } +static void +lua_tcp_connect_helper (struct lua_tcp_cbdata *cbd) +{ + /* This is used for sync mode only */ + lua_State *L = cbd->thread->lua_state; + + struct lua_tcp_cbdata **pcbd; + + lua_pushboolean (L, TRUE); + + lua_thread_pool_set_running_entry (cbd->cfg->lua_thread_pool, cbd->thread); + pcbd = lua_newuserdata (L, sizeof (*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass (L, "rspamd{tcp_sync}", -1); + + lua_tcp_shift_handler (cbd); + + // lua_tcp_unregister_event (cbd); + lua_thread_resume (cbd->thread, 2); + TCP_RELEASE (cbd); +} + static void lua_tcp_write_helper (struct lua_tcp_cbdata *cbd) { @@ -493,7 +757,7 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd) g_assert (hdl != NULL && hdl->type == LUA_WANT_WRITE); wh = &hdl->h.w; - if (wh->pos == wh->total) { + if (wh->pos == wh->total_bytes) { goto call_finish_handler; } @@ -531,8 +795,11 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd) if (r == -1) { lua_tcp_push_error (cbd, FALSE, "IO write error while trying to write %d " "bytes: %s", (gint)remain, strerror (errno)); - lua_tcp_shift_handler (cbd); - lua_tcp_plan_handler_event (cbd, TRUE, FALSE); + if (!IS_SYNC (cbd)) { + /* sync connection methods perform this inside */ + lua_tcp_shift_handler (cbd); + lua_tcp_plan_handler_event (cbd, TRUE, FALSE); + } return; } @@ -540,7 +807,7 @@ lua_tcp_write_helper (struct lua_tcp_cbdata *cbd) wh->pos += r; } - if (wh->pos >= wh->total) { + if (wh->pos >= wh->total_bytes) { goto call_finish_handler; } else { @@ -561,13 +828,15 @@ call_finish_handler: } lua_tcp_push_data (cbd, NULL, 0); - lua_tcp_shift_handler (cbd); - lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + if (!IS_SYNC (cbd)) { + lua_tcp_shift_handler (cbd); + lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + } } static gboolean lua_tcp_process_read_handler (struct lua_tcp_cbdata *cbd, - struct lua_tcp_read_handler *rh) + struct lua_tcp_read_handler *rh, gboolean eof) { guint slen; goffset pos; @@ -581,15 +850,16 @@ lua_tcp_process_read_handler (struct lua_tcp_cbdata *cbd, msg_debug_tcp ("found TCP stop pattern"); lua_tcp_push_data (cbd, cbd->in->data, pos); + if (!IS_SYNC (cbd)) { + lua_tcp_shift_handler (cbd); + } if (pos + slen < cbd->in->len) { /* We have a leftover */ memmove (cbd->in->data, cbd->in->data + pos + slen, cbd->in->len - (pos + slen)); - lua_tcp_shift_handler (cbd); cbd->in->len = cbd->in->len - (pos + slen); } else { - lua_tcp_shift_handler (cbd); cbd->in->len = 0; } @@ -604,9 +874,14 @@ lua_tcp_process_read_handler (struct lua_tcp_cbdata *cbd, } else { msg_debug_tcp ("read TCP partial data"); - lua_tcp_push_data (cbd, cbd->in->data, cbd->in->len); - lua_tcp_shift_handler (cbd); + slen = cbd->in->len; + + /* we have eaten all the data, handler should not know that there is something */ cbd->in->len = 0; + lua_tcp_push_data (cbd, cbd->in->data, slen); + if (!IS_SYNC (cbd)) { + lua_tcp_shift_handler (cbd); + } return TRUE; } @@ -635,21 +910,24 @@ lua_tcp_process_read (struct lua_tcp_cbdata *cbd, else { g_byte_array_append (cbd->in, in, r); - if (!lua_tcp_process_read_handler (cbd, rh)) { + if (!lua_tcp_process_read_handler (cbd, rh, FALSE)) { /* Plan more read */ lua_tcp_plan_read (cbd); } else { /* Go towards the next handler */ - lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + if (!IS_SYNC (cbd)) { + lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + } } } } else if (r == 0) { /* EOF */ + cbd->eof = TRUE; if (cbd->in->len > 0) { /* We have some data to process */ - lua_tcp_process_read_handler (cbd, rh); + lua_tcp_process_read_handler (cbd, rh, TRUE); } else { lua_tcp_push_error (cbd, FALSE, "IO read error: connection terminated"); @@ -670,7 +948,7 @@ lua_tcp_process_read (struct lua_tcp_cbdata *cbd, lua_tcp_push_error (cbd, TRUE, "IO read error while trying to read data: %s", strerror (errno)); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } } @@ -684,36 +962,40 @@ lua_tcp_handler (int fd, short what, gpointer ud) socklen_t so_len = sizeof (so_error); struct lua_callback_state cbs; lua_State *L; - - REF_RETAIN (cbd); + enum lua_tcp_handler_type event_type; + TCP_RETAIN (cbd); msg_debug_tcp ("processed TCP event: %d", what); + struct lua_tcp_handler *rh = g_queue_peek_head (cbd->handlers); + event_type = rh->type; + if (what == EV_READ) { r = read (cbd->fd, inbuf, sizeof (inbuf)); lua_tcp_process_read (cbd, inbuf, r); } else if (what == EV_WRITE) { + if (!(cbd->flags & LUA_TCP_FLAG_CONNECTED)) { if (getsockopt (fd, SOL_SOCKET, SO_ERROR, &so_error, &so_len) == -1) { lua_tcp_push_error (cbd, TRUE, "Cannot get socket error: %s", strerror (errno)); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); goto out; } else if (so_error != 0) { lua_tcp_push_error (cbd, TRUE, "Socket error detected: %s", strerror (so_error)); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); goto out; } else { cbd->flags |= LUA_TCP_FLAG_CONNECTED; - lua_thread_pool_prepare_callback (cbd->task->cfg->lua_thread_pool, &cbs); - L = cbs.L; - if (cbd->connect_cb != -1) { + lua_thread_pool_prepare_callback (cbd->task->cfg->lua_thread_pool, &cbs); + L = cbs.L; + struct lua_tcp_cbdata **pcbd; gint top; @@ -721,7 +1003,7 @@ lua_tcp_handler (int fd, short what, gpointer ud) lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->connect_cb); pcbd = lua_newuserdata (L, sizeof (*pcbd)); *pcbd = cbd; - REF_RETAIN (cbd); + TCP_RETAIN (cbd); rspamd_lua_setclass (L, "rspamd{tcp}", -1); if (lua_pcall (L, 1, 0, 0) != 0) { @@ -730,26 +1012,36 @@ lua_tcp_handler (int fd, short what, gpointer ud) lua_settop (L, top); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); + + lua_thread_pool_restore_callback (&cbs); } } } - lua_tcp_write_helper (cbd); + if (event_type == LUA_WANT_WRITE) { + lua_tcp_write_helper (cbd); + } + else if (event_type == LUA_WANT_CONNECT) { + lua_tcp_connect_helper (cbd); + } + else { + g_assert_not_reached (); + } } #ifdef EV_CLOSED else if (what == EV_CLOSED) { lua_tcp_push_error (cbd, TRUE, "Remote peer has closed the connection"); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } #endif else { lua_tcp_push_error (cbd, TRUE, "IO timeout"); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } out: - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } static void @@ -764,7 +1056,7 @@ lua_tcp_plan_handler_event (struct lua_tcp_cbdata *cbd, gboolean can_read, if (!(cbd->flags & LUA_TCP_FLAG_FINISHED)) { /* We are finished with a connection */ msg_debug_tcp ("no handlers left, finish session"); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); cbd->flags |= LUA_TCP_FLAG_FINISHED; } } @@ -774,10 +1066,12 @@ lua_tcp_plan_handler_event (struct lua_tcp_cbdata *cbd, gboolean can_read, /* We need to check if we have some leftover in the buffer */ if (cbd->in->len > 0) { msg_debug_tcp ("process read buffer leftover"); - if (lua_tcp_process_read_handler (cbd, &hdl->h.r)) { - /* We can go to the next handler */ - lua_tcp_shift_handler (cbd); - lua_tcp_plan_handler_event (cbd, can_read, can_write); + if (lua_tcp_process_read_handler (cbd, &hdl->h.r, FALSE)) { + if (!IS_SYNC(cbd)) { + /* We can go to the next handler */ + lua_tcp_shift_handler (cbd); + lua_tcp_plan_handler_event (cbd, can_read, can_write); + } } } else { @@ -791,18 +1085,20 @@ lua_tcp_plan_handler_event (struct lua_tcp_cbdata *cbd, gboolean can_read, else { /* Cannot read more */ lua_tcp_push_error (cbd, FALSE, "EOF, cannot read more data"); - lua_tcp_shift_handler (cbd); - lua_tcp_plan_handler_event (cbd, can_read, can_write); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler (cbd); + lua_tcp_plan_handler_event (cbd, can_read, can_write); + } } } } - else { + else if (hdl->type == LUA_WANT_WRITE) { /* * We need to plan write event if there is something in the * write request */ - if (hdl->h.w.pos < hdl->h.w.total) { + if (hdl->h.w.pos < hdl->h.w.total_bytes) { msg_debug_tcp ("plan new write"); if (can_write) { event_set (&cbd->ev, cbd->fd, EV_WRITE, lua_tcp_handler, cbd); @@ -812,8 +1108,10 @@ lua_tcp_plan_handler_event (struct lua_tcp_cbdata *cbd, gboolean can_read, else { /* Cannot write more */ lua_tcp_push_error (cbd, FALSE, "EOF, cannot write more data"); - lua_tcp_shift_handler (cbd); - lua_tcp_plan_handler_event (cbd, can_read, can_write); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler (cbd); + lua_tcp_plan_handler_event (cbd, can_read, can_write); + } } } else { @@ -821,16 +1119,23 @@ lua_tcp_plan_handler_event (struct lua_tcp_cbdata *cbd, gboolean can_read, g_assert_not_reached (); } } + else { /* LUA_WANT_CONNECT */ + msg_debug_tcp ("plan new connect"); + event_set (&cbd->ev, cbd->fd, EV_WRITE, lua_tcp_handler, cbd); + event_base_set (cbd->ev_base, &cbd->ev); + event_add (&cbd->ev, &cbd->tv); + } } } - static gboolean lua_tcp_register_event (struct lua_tcp_cbdata *cbd) { if (cbd->session) { + event_finalizer_t fin = IS_SYNC (cbd) ? lua_tcp_void_finalyser : lua_tcp_fin; + cbd->async_ev = rspamd_session_add_event (cbd->session, - (event_finalizer_t) lua_tcp_fin, + fin, cbd, g_quark_from_static_string ("lua tcp")); @@ -885,7 +1190,7 @@ lua_tcp_dns_handler (struct rdns_reply *reply, gpointer ud) rn = rdns_request_get_name (reply->request, NULL); lua_tcp_push_error (cbd, TRUE, "unable to resolve host: %s", rn->name); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } else { if (reply->entries->type == RDNS_REQUEST_A) { @@ -902,7 +1207,7 @@ lua_tcp_dns_handler (struct rdns_reply *reply, gpointer ud) if (!lua_tcp_make_connection (cbd)) { lua_tcp_push_error (cbd, TRUE, "unable to make connection to the host %s", rspamd_inet_address_to_string (cbd->addr)); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); } } } @@ -966,7 +1271,6 @@ lua_tcp_arg_toiovec (lua_State *L, gint pos, struct lua_tcp_cbdata *cbd, * - `ev_base`: event base (if no task specified) * - `resolver`: DNS resolver (no task) * - `session`: events session (no task) - * - `pool`: memory pool (no task) * - `host`: IP or name of the peer (required) * - `port`: remote port to use * - `data`: a table of strings or `rspamd_text` objects that contains data pieces @@ -1037,6 +1341,7 @@ lua_tcp_request (lua_State *L) ev_base = task->ev_base; resolver = task->resolver; session = task->s; + cfg = task->cfg; } lua_pop (L, 1); @@ -1202,6 +1507,7 @@ lua_tcp_request (lua_State *L) } cbd->task = task; + cbd->cfg = cfg; h = rspamd_random_uint64_fast (); rspamd_snprintf (cbd->tag, sizeof (cbd->tag), "%uxL", h); cbd->handlers = g_queue_new (); @@ -1213,7 +1519,7 @@ lua_tcp_request (lua_State *L) wh->type = LUA_WANT_WRITE; wh->h.w.iov = iov; wh->h.w.iovlen = niov; - wh->h.w.total = total_out; + wh->h.w.total_bytes = total_out; wh->h.w.pos = 0; /* Cannot set write handler here */ wh->h.w.cbref = -1; @@ -1269,7 +1575,7 @@ lua_tcp_request (lua_State *L) cbd->session = session; if (rspamd_session_is_destroying (session)) { - REF_RELEASE (cbd); + TCP_RELEASE (cbd); lua_pushboolean (L, FALSE); return 1; @@ -1280,7 +1586,7 @@ lua_tcp_request (lua_State *L) rspamd_inet_address_set_port (cbd->addr, port); /* Host is numeric IP, no need to resolve */ if (!lua_tcp_make_connection (cbd)) { - REF_RELEASE (cbd); + TCP_RELEASE (cbd); lua_pushboolean (L, FALSE); return 1; @@ -1294,7 +1600,7 @@ lua_tcp_request (lua_State *L) if (!make_dns_request (resolver, session, NULL, lua_tcp_dns_handler, cbd, RDNS_REQUEST_A, host)) { lua_tcp_push_error (cbd, TRUE, "cannot resolve host: %s", host); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); lua_pushboolean (L, FALSE); return 1; } @@ -1306,7 +1612,7 @@ lua_tcp_request (lua_State *L) if (!make_dns_request_task (task, lua_tcp_dns_handler, cbd, RDNS_REQUEST_A, host)) { lua_tcp_push_error (cbd, TRUE, "cannot resolve host: %s", host); - REF_RELEASE (cbd); + TCP_RELEASE (cbd); lua_pushboolean (L, FALSE); return 1; } @@ -1320,6 +1626,175 @@ lua_tcp_request (lua_State *L) return 1; } +/*** + * @function rspamd_tcp.connect_sync({params}) + * Creates new pseudo-synchronous connection to the specific address:port + * + * - `task`: rspamd task objects (implies `pool`, `session`, `ev_base` and `resolver` arguments) + * - `ev_base`: event base (if no task specified) + * - `resolver`: DNS resolver (no task) + * - `session`: events session (no task) + * - `config`: config (no task) + * - `host`: IP or name of the peer (required) + * - `port`: remote port to use + * - `timeout`: floating point value that specifies timeout for IO operations in **seconds** + * @return {boolean} true if request has been sent + */ +static gint +lua_tcp_connect_sync (lua_State *L) +{ + LUA_TRACE_POINT; + GError *err = NULL; + + gint64 port = -1; + gdouble timeout = default_tcp_timeout; + const gchar *host = NULL; + gint ret; + guint64 h; + + struct rspamd_task *task = NULL; + struct rspamd_async_session *session = NULL; + struct rspamd_dns_resolver *resolver = NULL; + struct rspamd_config *cfg = NULL; + struct event_base *ev_base = NULL; + + int arguments_validated = rspamd_lua_parse_table_arguments (L, 1, &err, + "task=U{task};session=U{session};resolver=U{resolver};ev_base=U{ev_base};" + "*host=S;*port=I;timeout=N;config=U{cfg}", + &task, &session, &resolver, &ev_base, + &host, &port, &timeout, &cfg); + + if (!arguments_validated) { + if (err) { + ret = luaL_error (L, "invalid arguments: %s", err->message); + g_error_free (err); + + return ret; + } + + return luaL_error (L, "invalid arguments"); + } + + if (0 > port || port > 65535) { + return luaL_error (L, "invalid port given (correct values: 1..65535)"); + } + + if (task == NULL && (cfg == NULL || ev_base == NULL || session == NULL)) { + return luaL_error (L, "invalid arguments: either task or config+ev_base+session should be set"); + } + + if (timeout < 0.00001) { + /* rspamd_lua_parse_table_arguments() sets missing N field to zero */ + timeout = default_tcp_timeout; + } + else { + timeout *= 1000.; + } + if (task) { + cfg = task->cfg; + ev_base = task->ev_base; + session = task->s; + } + if (resolver == NULL) { + if (task) { + resolver = task->resolver; + } + else { + resolver = lua_tcp_global_resolver (ev_base, cfg); + } + } + + struct lua_tcp_cbdata *cbd = g_new0 (struct lua_tcp_cbdata, 1); + + cbd->task = task; + cbd->cfg = cfg; + cbd->thread = lua_thread_pool_get_running_entry (cfg->lua_thread_pool); + + h = rspamd_random_uint64_fast (); + rspamd_snprintf (cbd->tag, sizeof (cbd->tag), "%uxL", h); + cbd->handlers = g_queue_new (); + + cbd->ev_base = ev_base; + cbd->flags |= LUA_TCP_FLAG_SYNC; + msec_to_tv (timeout, &cbd->tv); + cbd->fd = -1; + cbd->port = (guint16)port; + + cbd->in = g_byte_array_new (); + + cbd->connect_cb = -1; + + REF_INIT_RETAIN (cbd, lua_tcp_maybe_free); + + if (task) { + rspamd_mempool_add_destructor (task->task_pool, lua_tcp_sync_session_dtor, cbd); + } + + struct lua_tcp_handler *wh; + + wh = g_malloc0 (sizeof (*wh)); + wh->type = LUA_WANT_CONNECT; + + g_queue_push_tail (cbd->handlers, wh); + + if (session) { + cbd->session = session; + + if (rspamd_session_is_destroying (session)) { + TCP_RELEASE (cbd); + lua_pushboolean (L, FALSE); + lua_pushliteral (L, "Session is being destroyed, requests are not allowed"); + + return 2; + } + } + + if (rspamd_parse_inet_address (&cbd->addr, host, 0)) { + rspamd_inet_address_set_port (cbd->addr, (guint16)port); + /* Host is numeric IP, no need to resolve */ + if (!lua_tcp_make_connection (cbd)) { + TCP_RELEASE (cbd); + lua_pushboolean (L, FALSE); + lua_pushliteral (L, "Failed to initiate connection"); + + return 2; + } + else { + lua_tcp_register_event (cbd); + } + } + else { + if (task == NULL) { + if (!make_dns_request (resolver, session, NULL, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + TCP_RELEASE (cbd); + lua_pushboolean (L, FALSE); + lua_pushliteral (L, "Failed to initiate dns request"); + + return 2; + } + else { + lua_tcp_register_event (cbd); + } + } + else { + if (!make_dns_request_task (task, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + TCP_RELEASE (cbd); + lua_pushboolean (L, FALSE); + lua_pushliteral (L, "Failed to initiate dns request"); + + return 2; + } + else { + lua_tcp_register_event (cbd); + } + } + } + + return lua_thread_yield (cbd->thread, 0); +} + static gint lua_tcp_close (lua_State *L) { @@ -1331,7 +1806,7 @@ lua_tcp_close (lua_State *L) } cbd->flags |= LUA_TCP_FLAG_FINISHED; - REF_RELEASE (cbd); + TCP_RELEASE (cbd); return 0; } @@ -1466,7 +1941,7 @@ lua_tcp_add_write (lua_State *L) wh->type = LUA_WANT_WRITE; wh->h.w.iov = iov; wh->h.w.iovlen = niov; - wh->h.w.total = total_out; + wh->h.w.total_bytes = total_out; wh->h.w.pos = 0; /* Cannot set write handler here */ wh->h.w.cbref = cbref; @@ -1494,6 +1969,226 @@ lua_tcp_shift_callback (lua_State *L) return 0; } +static struct lua_tcp_cbdata * +lua_check_sync_tcp (lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata (L, pos, "rspamd{tcp_sync}"); + luaL_argcheck (L, ud != NULL, pos, "'tcp' expected"); + return ud ? *((struct lua_tcp_cbdata **)ud) : NULL; +} + +static int +lua_tcp_sync_close (lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + + if (cbd == NULL) { + return luaL_error (L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->fd != -1) { + event_del (&cbd->ev); + close (cbd->fd); + cbd->fd = -1; + } + + return 0; +} + +static void +lua_tcp_sync_session_dtor (gpointer ud) +{ + struct lua_tcp_cbdata *cbd = ud; + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->fd != -1) { + msg_debug ("closing sync TCP connection"); + event_del (&cbd->ev); + close (cbd->fd); + cbd->fd = -1; + } + + /* Task is gone, we should not try use it anymore */ + cbd->task = NULL; + + /* All events are removed when task is done, we should not refer them */ + cbd->async_ev = NULL; +} + +static int +lua_tcp_sync_set_timeout (lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + gdouble ms = lua_tonumber (L, 2); + + if (cbd == NULL) { + return luaL_error (L, "invalid arguments: self is not rspamd{tcp_sync}"); + } + if (lua_type (L, 2) != LUA_TNUMBER) { + return luaL_error (L, "invalid arguments: second parameter is expected to be number"); + } + + ms *= 1000.0; + double_to_tv (ms, &cbd->tv); + + return 0; +} + +static int +lua_tcp_sync_read_once (lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + struct lua_tcp_handler *rh; + + if (cbd == NULL) { + return luaL_error (L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + struct thread_entry *thread = lua_thread_pool_get_running_entry (cbd->cfg->lua_thread_pool); + + rh = g_malloc0 (sizeof (*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = -1; + + msg_debug_tcp ("added read sync event, thread: %p", thread); + + g_queue_push_tail (cbd->handlers, rh); + lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + + TCP_RETAIN (cbd); + + return lua_thread_yield (thread, 0); +} + +static int +lua_tcp_sync_write (lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + struct lua_tcp_handler *wh; + gint tp; + struct iovec *iov = NULL; + guint niov = 0; + gsize total_out = 0; + + if (cbd == NULL) { + return luaL_error (L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + struct thread_entry *thread = lua_thread_pool_get_running_entry (cbd->cfg->lua_thread_pool); + + tp = lua_type (L, 2); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc (sizeof (*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec (L, 2, cbd, iov)) { + msg_err ("tcp request has bad data argument"); + g_free (iov); + g_free (cbd); + + return luaL_error (L, "invalid arguments second parameter (data) is expected to be either string or rspamd{text}"); + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushvalue (L, 3); + + lua_pushnil (L); + while (lua_next (L, -2) != 0) { + niov ++; + lua_pop (L, 1); + } + + iov = g_malloc (sizeof (*iov) * niov); + lua_pushnil (L); + niov = 0; + + while (lua_next (L, -2) != 0) { + if (!lua_tcp_arg_toiovec (L, -1, cbd, &iov[niov])) { + msg_err ("tcp request has bad data argument at pos %d", niov); + g_free (iov); + g_free (cbd); + + return luaL_error (L, "invalid arguments second parameter (data) is expected to be either string or rspamd{text}"); + } + + total_out += iov[niov].iov_len; + niov ++; + + lua_pop (L, 1); + } + + lua_pop (L, 1); + } + + wh = g_malloc0 (sizeof (*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + wh->h.w.cbref = -1; + msg_debug_tcp ("added sync write event, thread: %p", thread); + + g_queue_push_tail (cbd->handlers, wh); + lua_tcp_plan_handler_event (cbd, TRUE, TRUE); + + TCP_RETAIN (cbd); + // lua_tcp_register_event (cbd); + + return lua_thread_yield (thread, 0); +} + +static gint +lua_tcp_sync_eof(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + if (cbd == NULL) { + return luaL_error (L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + lua_pushboolean(L, cbd->eof); + + return 1; +} + +static gint +lua_tcp_sync_shutdown (lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + if (cbd == NULL) { + return luaL_error (L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + shutdown (cbd->fd, SHUT_WR); + + return 0; +} + + +static gint +lua_tcp_sync_gc (lua_State * L) +{ + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp (L, 1); + if (!cbd) { + return luaL_error (L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + lua_tcp_maybe_free(cbd); + lua_tcp_fin (cbd); + + return 0; +} + static gint lua_load_tcp (lua_State * L) { @@ -1508,5 +2203,6 @@ luaopen_tcp (lua_State * L) { rspamd_lua_add_preload (L, "rspamd_tcp", lua_load_tcp); rspamd_lua_new_class (L, "rspamd{tcp}", tcp_libm); + rspamd_lua_new_class (L, "rspamd{tcp_sync}", tcp_sync_libm); lua_pop (L, 1); } diff --git a/src/lua/lua_thread_pool.c b/src/lua/lua_thread_pool.c index 3fc14534d..df3ed775e 100644 --- a/src/lua/lua_thread_pool.c +++ b/src/lua/lua_thread_pool.c @@ -158,6 +158,20 @@ lua_thread_pool_set_running_entry (struct lua_thread_pool *pool, struct thread_e pool->running_entry = thread_entry; } +static void +lua_thread_pool_set_running_entry_for_thread (struct thread_entry *thread_entry) +{ + struct lua_thread_pool *pool; + + if (thread_entry->task) { + pool = thread_entry->task->cfg->lua_thread_pool; + } + else { + pool = thread_entry->cfg->lua_thread_pool; + } + + lua_thread_pool_set_running_entry (pool, thread_entry); +} void lua_thread_pool_prepare_callback (struct lua_thread_pool *pool, struct lua_callback_state *cbs) @@ -206,6 +220,8 @@ lua_thread_resume (struct thread_entry *thread_entry, gint narg) */ g_assert (lua_status (thread_entry->lua_state) == LUA_YIELD); + lua_thread_pool_set_running_entry_for_thread(thread_entry); + lua_resume_thread_internal (thread_entry, narg); } @@ -263,5 +279,7 @@ lua_resume_thread_internal (struct thread_entry *thread_entry, gint narg) gint lua_thread_yield (struct thread_entry *thread_entry, gint nresults) { + g_assert (lua_status (thread_entry->lua_state) == 0); + return lua_yield (thread_entry->lua_state, nresults); }