]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Use lua_redis for redis_cache as well
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 23 Nov 2018 14:28:37 +0000 (14:28 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 23 Nov 2018 16:10:28 +0000 (16:10 +0000)
src/libstat/learn_cache/redis_cache.c

index d8b7c9c37a03d41f98be16a39f223c5f3f500cfa..c976ce9c52aeec0746537a05dea835a71d7122d0 100644 (file)
@@ -22,6 +22,7 @@
 #include "ucl.h"
 #include "hiredis.h"
 #include "adapters/libevent.h"
+#include "lua/lua_common.h"
 
 #define REDIS_DEFAULT_TIMEOUT 0.5
 #define REDIS_STAT_TIMEOUT 30
 static const gchar *M = "redis learn cache";
 
 struct rspamd_redis_cache_ctx {
+       lua_State *L;
        struct rspamd_statfile_config *stcf;
-       struct upstream_list *read_servers;
-       struct upstream_list *write_servers;
        const gchar *password;
        const gchar *dbname;
        const gchar *redis_object;
        gdouble timeout;
+       gint conf_ref;
 };
 
 struct rspamd_redis_cache_runtime {
@@ -56,6 +57,22 @@ rspamd_stat_cache_redis_quark (void)
        return g_quark_from_static_string (M);
 }
 
+static inline struct upstream_list *
+rspamd_redis_get_servers (struct rspamd_redis_cache_ctx *ctx,
+                                                 const gchar *what)
+{
+       lua_State *L = ctx->L;
+       struct upstream_list *res;
+
+       lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref);
+       lua_pushstring (L, what);
+       lua_gettable (L, -2);
+       res = *((struct upstream_list**)lua_touserdata (L, -1));
+       lua_settop (L, 0);
+
+       return res;
+}
+
 static void
 rspamd_redis_cache_maybe_auth (struct rspamd_redis_cache_ctx *ctx,
                redisAsyncContext *redis)
@@ -217,94 +234,6 @@ rspamd_stat_cache_redis_generate_id (struct rspamd_task *task)
        rspamd_mempool_set_variable (task->task_pool, "words_hash", b32out, g_free);
 }
 
-static gboolean
-rspamd_redis_cache_try_ucl (struct rspamd_redis_cache_ctx *cache_ctx,
-               const ucl_object_t *obj,
-               struct rspamd_config *cfg,
-               const gchar *symbol)
-{
-       const ucl_object_t *elt, *relt;
-
-       elt = ucl_object_lookup_any (obj, "read_servers", "servers", NULL);
-
-       if (elt == NULL) {
-               return FALSE;
-       }
-
-       cache_ctx->read_servers = rspamd_upstreams_create (cfg->ups_ctx);
-       if (!rspamd_upstreams_from_ucl (cache_ctx->read_servers, elt,
-                       REDIS_DEFAULT_PORT, NULL)) {
-               msg_err ("statfile %s cannot get read servers configuration",
-                               symbol);
-               return FALSE;
-       }
-
-       relt = elt;
-
-       elt = ucl_object_lookup (obj, "write_servers");
-       if (elt == NULL) {
-               /* Use read servers as write ones */
-               g_assert (relt != NULL);
-               cache_ctx->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
-               if (!rspamd_upstreams_from_ucl (cache_ctx->write_servers, relt,
-                               REDIS_DEFAULT_PORT, NULL)) {
-                       msg_err ("statfile %s cannot get write servers configuration",
-                                       symbol);
-                       return FALSE;
-               }
-       }
-       else {
-               cache_ctx->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
-               if (!rspamd_upstreams_from_ucl (cache_ctx->write_servers, elt,
-                               REDIS_DEFAULT_PORT, NULL)) {
-                       msg_err ("statfile %s cannot get write servers configuration",
-                                       symbol);
-                       rspamd_upstreams_destroy (cache_ctx->write_servers);
-                       cache_ctx->write_servers = NULL;
-               }
-       }
-
-
-       elt = ucl_object_lookup (obj, "timeout");
-       if (elt) {
-               cache_ctx->timeout = ucl_object_todouble (elt);
-       }
-       else {
-               cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
-       }
-
-       elt = ucl_object_lookup (obj, "password");
-       if (elt) {
-               cache_ctx->password = ucl_object_tostring (elt);
-       }
-       else {
-               cache_ctx->password = NULL;
-       }
-
-       elt = ucl_object_lookup_any (obj, "db", "database", "dbname", NULL);
-       if (elt) {
-               if (ucl_object_type (elt) == UCL_STRING) {
-                       cache_ctx->dbname = ucl_object_tostring (elt);
-               }
-               else if (ucl_object_type (elt) == UCL_INT) {
-                       cache_ctx->dbname = ucl_object_tostring_forced (elt);
-               }
-       }
-       else {
-               cache_ctx->dbname = NULL;
-       }
-
-       elt = ucl_object_lookup_any (obj, "cache_key", "key", NULL);
-       if (elt == NULL || ucl_object_type (elt) != UCL_STRING) {
-               cache_ctx->redis_object = DEFAULT_REDIS_KEY;
-       }
-       else {
-               cache_ctx->redis_object = ucl_object_tostring (elt);
-       }
-
-       return TRUE;
-}
-
 gpointer
 rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
                struct rspamd_config *cfg,
@@ -315,24 +244,27 @@ rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
        struct rspamd_statfile_config *stf = st->stcf;
        const ucl_object_t *obj;
        gboolean ret = FALSE;
+       lua_State *L = (lua_State *)cfg->lua_state;
+       gint conf_ref = -1;
 
        cache_ctx = g_malloc0 (sizeof (*cache_ctx));
+       cache_ctx->timeout = REDIS_DEFAULT_TIMEOUT;
+       cache_ctx->L = L;
 
        /* First search in backend configuration */
        obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
        if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
-               ret = rspamd_redis_cache_try_ucl (cache_ctx, obj, cfg, stf->symbol);
+               ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref);
        }
 
        /* Now try statfiles config */
-       if (!ret) {
-               ret = rspamd_redis_cache_try_ucl (cache_ctx, stf->opts, cfg, stf->symbol);
+       if (!ret && stf->opts) {
+               ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref);
        }
 
        /* Now try classifier config */
-       if (!ret) {
-               ret = rspamd_redis_cache_try_ucl (cache_ctx, st->classifier->cfg->opts, cfg,
-                               stf->symbol);
+       if (!ret && st->classifier->cfg->opts) {
+               ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref);
        }
 
        /* Now try global redis settings */
@@ -345,23 +277,61 @@ rspamd_stat_cache_redis_init (struct rspamd_stat_ctx *ctx,
                        specific_obj = ucl_object_lookup (obj, "statistics");
 
                        if (specific_obj) {
-                               ret = rspamd_redis_cache_try_ucl (cache_ctx, specific_obj, cfg,
-                                               stf->symbol);
+                               ret = rspamd_lua_try_load_redis (L,
+                                               specific_obj, cfg, &conf_ref);
                        }
                        else {
-                               ret = rspamd_redis_cache_try_ucl (cache_ctx, obj, cfg,
-                                               stf->symbol);
+                               ret = rspamd_lua_try_load_redis (L,
+                                               obj, cfg, &conf_ref);
                        }
                }
        }
 
-
        if (!ret) {
                msg_err_config ("cannot init redis cache for %s", stf->symbol);
                g_free (cache_ctx);
                return NULL;
        }
 
+       obj = ucl_object_lookup (st->classifier->cfg->opts, "cache_key");
+
+       if (obj) {
+               cache_ctx->redis_object = ucl_object_tostring (obj);
+       }
+       else {
+               cache_ctx->redis_object = DEFAULT_REDIS_KEY;
+       }
+
+       cache_ctx->conf_ref = conf_ref;
+
+       /* Check some common table values */
+       lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref);
+
+       lua_pushstring (L, "timeout");
+       lua_gettable (L, -2);
+       if (lua_type (L, -1) == LUA_TNUMBER) {
+               cache_ctx->timeout = lua_tonumber (L, -1);
+       }
+       lua_pop (L, 1);
+
+       lua_pushstring (L, "db");
+       lua_gettable (L, -2);
+       if (lua_type (L, -1) == LUA_TSTRING) {
+               cache_ctx->dbname = rspamd_mempool_strdup (cfg->cfg_pool,
+                               lua_tostring (L, -1));
+       }
+       lua_pop (L, 1);
+
+       lua_pushstring (L, "password");
+       lua_gettable (L, -2);
+       if (lua_type (L, -1) == LUA_TSTRING) {
+               cache_ctx->password = rspamd_mempool_strdup (cfg->cfg_pool,
+                               lua_tostring (L, -1));
+       }
+       lua_pop (L, 1);
+
+       lua_settop (L, 0);
+
        cache_ctx->stcf = stf;
 
        return (gpointer)cache_ctx;
@@ -374,28 +344,39 @@ rspamd_stat_cache_redis_runtime (struct rspamd_task *task,
        struct rspamd_redis_cache_ctx *ctx = c;
        struct rspamd_redis_cache_runtime *rt;
        struct upstream *up;
+       struct upstream_list *ups;
        rspamd_inet_addr_t *addr;
 
        g_assert (ctx != NULL);
 
-       if (learn && ctx->write_servers == NULL) {
-               msg_err_task ("no write servers defined for %s, cannot learn",
-                               ctx->stcf->symbol);
-               return NULL;
-       }
-
        if (task->tokens == NULL || task->tokens->len == 0) {
                return NULL;
        }
 
        if (learn) {
-               up = rspamd_upstream_get (ctx->write_servers,
+               ups = rspamd_redis_get_servers (ctx, "write_servers");
+
+               if (!ups) {
+                       msg_err_task ("no write servers defined for %s, cannot learn",
+                                       ctx->stcf->symbol);
+                       return NULL;
+               }
+
+               up = rspamd_upstream_get (ups,
                                RSPAMD_UPSTREAM_MASTER_SLAVE,
                                NULL,
                                0);
        }
        else {
-               up = rspamd_upstream_get (ctx->read_servers,
+               ups = rspamd_redis_get_servers (ctx, "read_servers");
+
+               if (!ups) {
+                       msg_err_task ("no read servers defined for %s, cannot check",
+                                       ctx->stcf->symbol);
+                       return NULL;
+               }
+
+               up = rspamd_upstream_get (ups,
                                RSPAMD_UPSTREAM_ROUND_ROBIN,
                                NULL,
                                0);
@@ -512,5 +493,12 @@ rspamd_stat_cache_redis_learn (struct rspamd_task *task,
 void
 rspamd_stat_cache_redis_close (gpointer c)
 {
+       struct rspamd_redis_cache_ctx *ctx = (struct rspamd_redis_cache_ctx *)c;
+       lua_State *L = ctx->L;
+
+       if (ctx->conf_ref) {
+               luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref);
+       }
 
+       g_free (ctx);
 }