]> source.dussan.org Git - rspamd.git/commitdiff
Start improved redis lua api
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 26 Jan 2016 10:04:45 +0000 (10:04 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Tue, 26 Jan 2016 10:04:45 +0000 (10:04 +0000)
src/lua/lua_redis.c

index b7ac637857c3a58b448945e1f541d7213bc43ec5..3ecd76d043170aa6a480a9674983c6a740a5abe2 100644 (file)
@@ -57,14 +57,25 @@ end
 
 LUA_FUNCTION_DEF (redis, make_request);
 LUA_FUNCTION_DEF (redis, make_request_sync);
+LUA_FUNCTION_DEF (redis, connect);
+LUA_FUNCTION_DEF (redis, connect_sync);
+LUA_FUNCTION_DEF (redis, add_cmd);
+LUA_FUNCTION_DEF (redis, exec);
 
-static const struct luaL_reg redislib_m[] = {
+static const struct luaL_reg redislib_f[] = {
        LUA_INTERFACE_DEF (redis, make_request),
        LUA_INTERFACE_DEF (redis, make_request_sync),
-       {"__tostring", rspamd_lua_class_tostring},
+       LUA_INTERFACE_DEF (redis, connect),
+       LUA_INTERFACE_DEF (redis, connect_sync),
        {NULL, NULL}
 };
 
+static const struct luaL_reg redislib_m[] = {
+       LUA_INTERFACE_DEF (redis, add_cmd),
+       LUA_INTERFACE_DEF (redis, exec),
+       {"__tostring", rspamd_lua_class_tostring},
+};
+
 #ifdef WITH_HIREDIS
 /**
  * Struct for userdata representation
@@ -83,6 +94,23 @@ struct lua_redis_userdata {
        guint16 terminated;
 };
 
+struct lua_redis_ctx {
+       gboolean async;
+       union {
+               struct lua_redis_userdata async;
+               redisContext *sync;
+       } d;
+       ref_entry_t ref;
+};
+
+static struct lua_redis_ctx *
+lua_check_redis (lua_State * L, gint pos)
+{
+       void *ud = luaL_checkudata (L, pos, "rspamd{redis}");
+       luaL_argcheck (L, ud != NULL, pos, "'redis' expected");
+       return ud ? *((struct lua_redis_ctx **)ud) : NULL;
+}
+
 static void
 lua_redis_free_args (char **args, guint nargs)
 {
@@ -98,17 +126,36 @@ lua_redis_free_args (char **args, guint nargs)
 }
 
 static void
-lua_redis_fin (void *arg)
+lua_redis_dtor (struct lua_redis_ctx *ctx)
 {
-       struct lua_redis_userdata *ud = arg;
-
-       if (ud->ctx) {
-               ud->terminated = 1;
-               redisAsyncFree (ud->ctx);
-               lua_redis_free_args (ud->args, ud->nargs);
-               event_del (&ud->timeout);
-               luaL_unref (ud->L, LUA_REGISTRYINDEX, ud->cbref);
+       struct lua_redis_userdata *ud;
+
+       if (ctx->async) {
+               ud = &ctx->d.async;
+
+               if (ud->ctx) {
+                       ud->terminated = 1;
+                       redisAsyncFree (ud->ctx);
+                       lua_redis_free_args (ud->args, ud->nargs);
+                       event_del (&ud->timeout);
+                       luaL_unref (ud->L, LUA_REGISTRYINDEX, ud->cbref);
+               }
        }
+       else {
+               if (ctx->d.sync) {
+                       redisFree (ctx->d.sync);
+               }
+       }
+
+       g_slice_free1 (sizeof (*ctx), ctx);
+}
+
+static void
+lua_redis_fin (void *arg)
+{
+       struct lua_redis_ctx *ctx = arg;
+
+       REF_RELEASE (ctx);
 }
 
 /**
@@ -210,10 +257,16 @@ static void
 lua_redis_callback (redisAsyncContext *c, gpointer r, gpointer priv)
 {
        redisReply *reply = r;
-       struct lua_redis_userdata *ud = priv;
+       struct lua_redis_ctx *ctx = priv;
+       struct lua_redis_userdata *ud;
+
+       REF_RETAIN (ctx);
+
+       ud = &ctx->d.async;
 
        if (ud->terminated) {
                /* We are already at the termination stage, just go out */
+               REF_RELEASE (ctx);
                return;
        }
 
@@ -238,15 +291,20 @@ lua_redis_callback (redisAsyncContext *c, gpointer r, gpointer priv)
                        lua_redis_push_error (c->errstr, ud, TRUE);
                }
        }
+       REF_RELEASE (ctx);
 }
 
 static void
 lua_redis_timeout (int fd, short what, gpointer u)
 {
-       struct lua_redis_userdata *ud = u;
+       struct lua_redis_ctx *ctx = u;
+       struct lua_redis_userdata *ud;
 
+       REF_RETAIN (ctx);
+       ud = &ctx->d.async;
        msg_info ("timeout while querying redis server");
        lua_redis_push_error ("timeout while connecting the server", ud, TRUE);
+       REF_RELEASE (ctx);
 }
 
 
@@ -312,7 +370,7 @@ lua_redis_connect_cb (const struct redisAsyncContext *c, int status)
  * @function rspamd_redis.make_request({params})
  * Make request to redis server, params is a table of key=value arguments in any order
  * @param {task} task worker task object
- * @param {ip} host server address
+ * @param {ip|string} host server address
  * @param {function} callback callback to be called in form `function (task, err, data)`
  * @param {string} cmd command to be sent to redis
  * @param {table} args numeric array of strings used as redis arguments
@@ -322,10 +380,12 @@ lua_redis_connect_cb (const struct redisAsyncContext *c, int status)
 static int
 lua_redis_make_request (lua_State *L)
 {
+       struct lua_redis_ctx *ctx;
+       rspamd_inet_addr_t *ip = NULL;
        struct lua_redis_userdata *ud;
        struct rspamd_lua_ip *addr = NULL;
        struct rspamd_task *task = NULL;
-       const gchar *cmd = NULL;
+       const gchar *cmd = NULL, *host;
        gint top, cbref = -1;
        struct timeval tv;
        gboolean ret = FALSE;
@@ -358,20 +418,43 @@ lua_redis_make_request (lua_State *L)
 
                lua_pushstring (L, "host");
                lua_gettable (L, -2);
+
                if (lua_type (L, -1) == LUA_TUSERDATA) {
                        addr = lua_check_ip (L, -1);
                }
+               else if (lua_type (L, -1) == LUA_TSTRING) {
+                       host = lua_tostring (L, -1);
+
+                       if (rspamd_parse_inet_address (&ip, host, strlen (host))) {
+                               addr = g_alloca (sizeof (*addr));
+                               addr->addr = ip;
+
+                               if (rspamd_inet_address_get_port (ip) == 0) {
+                                       rspamd_inet_address_set_port (ip, 6379);
+                               }
+
+                               if (task) {
+                                       rspamd_mempool_add_destructor (task->task_pool,
+                                                       (rspamd_mempool_destruct_t)rspamd_inet_address_destroy,
+                                                       ip);
+                               }
+                       }
+               }
+
                lua_pop (L, 1);
 
                lua_pushstring (L, "timeout");
                lua_gettable (L, -2);
-               timeout = lua_tonumber (L, -1);
+               if (lua_type (L, -1) == LUA_TNUMBER) {
+                       timeout = lua_tonumber (L, -1);
+               }
                lua_pop (L, 1);
 
                if (task != NULL && addr != NULL && cbref != -1 && cmd != NULL) {
-                       ud =
-                                       rspamd_mempool_alloc (task->task_pool,
-                                                       sizeof (struct lua_redis_userdata));
+                       ctx = g_slice_alloc0 (sizeof (struct lua_redis_ctx));
+                       REF_INIT_RETAIN (ctx, lua_redis_dtor);
+                       ctx->async = TRUE;
+                       ud = &ctx->d.async;
                        ud->task = task;
                        ud->L = L;
                        ud->cbref = cbref;
@@ -392,12 +475,14 @@ lua_redis_make_request (lua_State *L)
        else if ((task = lua_check_task (L, 1)) != NULL) {
                addr = lua_check_ip (L, 2);
                top = lua_gettop (L);
+
                /* Now get callback */
                if (lua_isfunction (L, 3) && addr != NULL && addr->addr && top >= 4) {
                        /* Create userdata */
-                       ud =
-                               rspamd_mempool_alloc (task->task_pool,
-                                       sizeof (struct lua_redis_userdata));
+                       ctx = g_slice_alloc0 (sizeof (struct lua_redis_ctx));
+                       REF_INIT_RETAIN (ctx, lua_redis_dtor);
+                       ctx->async = TRUE;
+                       ud = &ctx->d.async;
                        ud->task = task;
                        ud->L = L;
 
@@ -428,21 +513,20 @@ lua_redis_make_request (lua_State *L)
                redisAsyncSetConnectCallback (ud->ctx, lua_redis_connect_cb);
 
                if (ud->ctx == NULL || ud->ctx->err) {
-                       ud->terminated = 1;
-                       redisAsyncFree (ud->ctx);
-                       lua_redis_free_args (ud->args, ud->nargs);
-                       luaL_unref (ud->L, LUA_REGISTRYINDEX, ud->cbref);
+                       REF_RELEASE (ctx);
                        lua_pushboolean (L, FALSE);
 
                        return 1;
                }
+
                redisLibeventAttach (ud->ctx, ud->task->ev_base);
                ret = redisAsyncCommandArgv (ud->ctx,
                                        lua_redis_callback,
-                                       ud,
+                                       ctx,
                                        ud->nargs,
                                        (const gchar **)ud->args,
                                        NULL);
+
                if (ret == REDIS_OK) {
                        rspamd_session_add_event (ud->task->s,
                                        lua_redis_fin,
@@ -450,16 +534,13 @@ lua_redis_make_request (lua_State *L)
                                        g_quark_from_static_string ("lua redis"));
 
                        double_to_tv (timeout, &tv);
-                       event_set (&ud->timeout, -1, EV_TIMEOUT, lua_redis_timeout, ud);
+                       event_set (&ud->timeout, -1, EV_TIMEOUT, lua_redis_timeout, ctx);
                        event_base_set (ud->task->ev_base, &ud->timeout);
                        event_add (&ud->timeout, &tv);
                }
                else {
                        msg_info ("call to redis failed: %s", ud->ctx->errstr);
-                       ud->terminated = 1;
-                       lua_redis_free_args (ud->args, ud->nargs);
-                       redisAsyncFree (ud->ctx);
-                       luaL_unref (ud->L, LUA_REGISTRYINDEX, ud->cbref);
+                       REF_RELEASE (ctx);
                }
        }
 
@@ -471,8 +552,7 @@ lua_redis_make_request (lua_State *L)
 /***
  * @function rspamd_redis.make_request_sync({params})
  * Make blocking request to redis server, params is a table of key=value arguments in any order
- * @param {task} task worker task object
- * @param {ip} host server address
+ * @param {ip|string} host server address
  * @param {string} cmd command to be sent to redis
  * @param {table} args numeric array of strings used as redis arguments
  * @param {number} timeout timeout in seconds for request (1.0 by default)
@@ -587,6 +667,206 @@ lua_redis_make_request_sync (lua_State *L)
 
        return 1;
 }
+
+/***
+ * @function rspamd_redis.connect({params})
+ * Make request to redis server, params is a table of key=value arguments in any order
+ * @param {task} task worker task object
+ * @param {ip|string} host server address
+ * @param {number} timeout timeout in seconds for request (1.0 by default)
+ * @return {redis} new connection object or nil if connection failed
+ */
+static int
+lua_redis_connect (lua_State *L)
+{
+       struct rspamd_lua_ip *addr = NULL;
+       rspamd_inet_addr_t *ip = NULL;
+       const gchar *host;
+       struct timeval tv;
+       struct lua_redis_ctx *ctx = NULL, **pctx;
+       struct lua_redis_userdata *ud;
+       struct rspamd_task *task = NULL;
+       gboolean ret = FALSE;
+       gdouble timeout = REDIS_DEFAULT_TIMEOUT;
+
+       if (lua_istable (L, 1)) {
+               /* Table version */
+               lua_pushstring (L, "task");
+               lua_gettable (L, -2);
+               if (lua_type (L, -1) == LUA_TUSERDATA) {
+                       task = lua_check_task (L, -1);
+               }
+               lua_pop (L, 1);
+
+
+               lua_pushstring (L, "host");
+               lua_gettable (L, -2);
+
+               if (lua_type (L, -1) == LUA_TUSERDATA) {
+                       addr = lua_check_ip (L, -1);
+               }
+               else if (lua_type (L, -1) == LUA_TSTRING) {
+                       host = lua_tostring (L, -1);
+
+                       if (rspamd_parse_inet_address (&ip, host, strlen (host))) {
+                               addr = g_alloca (sizeof (*addr));
+                               addr->addr = ip;
+
+                               if (rspamd_inet_address_get_port (ip) == 0) {
+                                       rspamd_inet_address_set_port (ip, 6379);
+                               }
+
+                               if (task) {
+                                       rspamd_mempool_add_destructor (task->task_pool,
+                                                       (rspamd_mempool_destruct_t)rspamd_inet_address_destroy,
+                                                       ip);
+                               }
+                       }
+               }
+
+               lua_pop (L, 1);
+
+               lua_pushstring (L, "timeout");
+               if (lua_type (L, -1) == LUA_TNUMBER) {
+                       lua_gettable (L, -2);
+               }
+               timeout = lua_tonumber (L, -1);
+               lua_pop (L, 1);
+
+               if (task != NULL && addr != NULL) {
+                       ctx = g_slice_alloc0 (sizeof (struct lua_redis_ctx));
+                       REF_INIT_RETAIN (ctx, lua_redis_dtor);
+                       ctx->async = TRUE;
+                       ud = &ctx->d.async;
+                       ud->task = task;
+                       ud->L = L;
+                       ud->cbref = -1;
+                       ret = TRUE;
+               }
+       }
+
+       if (ret && ctx) {
+               ud->terminated = 0;
+               ud->ctx = redisAsyncConnect (rspamd_inet_address_to_string (addr->addr),
+                               rspamd_inet_address_get_port (addr->addr));
+               redisAsyncSetConnectCallback (ud->ctx, lua_redis_connect_cb);
+
+               if (ud->ctx == NULL || ud->ctx->err) {
+                       REF_RELEASE (ctx);
+                       lua_pushboolean (L, FALSE);
+
+                       return 1;
+               }
+
+               redisLibeventAttach (ud->ctx, ud->task->ev_base);
+               pctx = lua_newuserdata (L, sizeof (ctx));
+               *pctx = ctx;
+               rspamd_lua_setclass (L, "rspamd{redis}", -1);
+       }
+       else {
+               lua_pushnil (L);
+       }
+
+       return 1;
+}
+
+static int
+lua_redis_connect_sync (lua_State *L)
+{
+       struct rspamd_lua_ip *addr = NULL;
+       rspamd_inet_addr_t *ip = NULL;
+       const gchar *host;
+       struct timeval tv;
+       gboolean ret = FALSE;
+       gdouble timeout = REDIS_DEFAULT_TIMEOUT;
+       struct lua_redis_ctx *ctx, **pctx;
+       redisReply *r;
+
+       if (lua_istable (L, 1)) {
+               lua_pushstring (L, "host");
+               lua_gettable (L, -2);
+               if (lua_type (L, -1) == LUA_TUSERDATA) {
+                       addr = lua_check_ip (L, -1);
+               }
+               else if (lua_type (L, -1) == LUA_TSTRING) {
+                       host = lua_tostring (L, -1);
+                       if (rspamd_parse_inet_address (&ip, host, strlen (host))) {
+                               addr = g_alloca (sizeof (*addr));
+                               addr->addr = ip;
+
+                               if (rspamd_inet_address_get_port (ip) == 0) {
+                                       rspamd_inet_address_set_port (ip, 6379);
+                               }
+                       }
+               }
+               lua_pop (L, 1);
+
+               lua_pushstring (L, "timeout");
+               lua_gettable (L, -2);
+               if (lua_type (L, -1) == LUA_TNUMBER) {
+                       timeout = lua_tonumber (L, -1);
+               }
+               lua_pop (L, 1);
+
+               if (addr) {
+                       ret = TRUE;
+               }
+       }
+
+       if (ret) {
+               double_to_tv (timeout, &tv);
+               ctx = g_slice_alloc0 (sizeof (struct lua_redis_ctx));
+               REF_INIT_RETAIN (ctx, lua_redis_dtor);
+               ctx->async = FALSE;
+               ctx->d.sync = redisConnectWithTimeout (
+                               rspamd_inet_address_to_string (addr->addr),
+                               rspamd_inet_address_get_port (addr->addr), tv);
+
+               if (ip) {
+                       rspamd_inet_address_destroy (ip);
+               }
+
+               if (ctx->d.sync == NULL || ctx->d.sync->err) {
+                       REF_RELEASE (ctx);
+                       lua_pushboolean (L, FALSE);
+
+                       return 1;
+               }
+
+               pctx = lua_newuserdata (L, sizeof (ctx));
+               *pctx = ctx;
+               rspamd_lua_setclass (L, "rspamd{redis}", -1);
+
+       }
+       else {
+               if (ip) {
+                       rspamd_inet_address_destroy (ip);
+               }
+               msg_err ("bad arguments for redis request");
+               lua_pushboolean (L, FALSE);
+       }
+
+       return 1;
+}
+
+static int
+lua_redis_add_cmd (lua_State *L)
+{
+       msg_warn ("rspamd is compiled with no redis support");
+
+       lua_pushboolean (L, FALSE);
+
+       return 1;
+}
+static int
+lua_redis_exec (lua_State *L)
+{
+       msg_warn ("rspamd is compiled with no redis support");
+
+       lua_pushboolean (L, FALSE);
+
+       return 1;
+}
 #else
 static int
 lua_redis_make_request (lua_State *L)
@@ -606,13 +886,49 @@ lua_redis_make_request_sync (lua_State *L)
 
        return 1;
 }
+static int
+lua_redis_connect (lua_State *L)
+{
+       msg_warn ("rspamd is compiled with no redis support");
+
+       lua_pushboolean (L, FALSE);
+
+       return 1;
+}
+static int
+lua_redis_connect_sync (lua_State *L)
+{
+       msg_warn ("rspamd is compiled with no redis support");
+
+       lua_pushboolean (L, FALSE);
+
+       return 1;
+}
+static int
+lua_redis_add_cmd (lua_State *L)
+{
+       msg_warn ("rspamd is compiled with no redis support");
+
+       lua_pushboolean (L, FALSE);
+
+       return 1;
+}
+static int
+lua_redis_exec (lua_State *L)
+{
+       msg_warn ("rspamd is compiled with no redis support");
+
+       lua_pushboolean (L, FALSE);
+
+       return 1;
+}
 #endif
 
 static gint
 lua_load_redis (lua_State * L)
 {
        lua_newtable (L);
-       luaL_register (L, NULL, redislib_m);
+       luaL_register (L, NULL, redislib_f);
 
        return 1;
 }
@@ -624,5 +940,17 @@ lua_load_redis (lua_State * L)
 void
 luaopen_redis (lua_State * L)
 {
+       luaL_newmetatable (L, "rspamd{redis}");
+       lua_pushstring (L, "__index");
+       lua_pushvalue (L, -2);
+       lua_settable (L, -3);
+
+       lua_pushstring (L, "class");
+       lua_pushstring (L, "rspamd{redis}");
+       lua_rawset (L, -3);
+
+       luaL_register (L, NULL, redislib_m);
+       lua_pop (L, 1);
+
        rspamd_lua_add_preload (L, "rspamd_redis", lua_load_redis);
 }