]> source.dussan.org Git - rspamd.git/commitdiff
[Project] Use lua_redis to configure servers in bayes Redis backend
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 23 Nov 2018 14:08:58 +0000 (14:08 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 23 Nov 2018 16:10:28 +0000 (16:10 +0000)
src/libstat/backends/redis_backend.c
src/lua/lua_common.c

index 9e6c5f46b070edffba6b02519b0c8d40d6b57fc9..20144a7ec243aab3a50e5561becc8d3c63c858b2 100644 (file)
@@ -42,9 +42,9 @@ INIT_LOG_MODULE(stat_redis)
 #define REDIS_STAT_TIMEOUT 30
 
 struct redis_stat_ctx {
+       lua_State *L;
        struct rspamd_statfile_config *stcf;
-       struct upstream_list *read_servers;
-       struct upstream_list *write_servers;
+       gint conf_ref;
        struct rspamd_stat_async_elt *stat_elt;
        const gchar *redis_object;
        const gchar *password;
@@ -113,6 +113,22 @@ rspamd_redis_stat_quark (void)
        return g_quark_from_static_string (M);
 }
 
+static inline struct upstream_list *
+rspamd_redis_get_servers (struct redis_stat_ctx *ctx,
+                                                 const gchar *what)
+{
+       lua_State *L = ctx->L;
+       struct upstream_list *res;
+
+       lua_rawgeti (L, LUA_REGISTRYINDEX, ctx->conf_ref);
+       lua_pushstring (L, what);
+       lua_gettable (L, -2);
+       res = *((struct upstream_list**)lua_touserdata (L, -1));
+       lua_settop (L, 0);
+
+       return res;
+}
+
 /*
  * Non-static for lua unit testing
  */
@@ -939,6 +955,7 @@ rspamd_redis_async_stat_cb (struct rspamd_stat_async_elt *elt, gpointer d)
        struct rspamd_redis_stat_elt *redis_elt = elt->ud;
        struct rspamd_redis_stat_cbdata *cbdata;
        rspamd_inet_addr_t *addr;
+       struct upstream_list *ups;
 
        g_assert (redis_elt != NULL);
 
@@ -952,8 +969,15 @@ rspamd_redis_async_stat_cb (struct rspamd_stat_async_elt *elt, gpointer d)
        /* Disable further events unless needed */
        elt->enabled = FALSE;
 
+       ups = rspamd_redis_get_servers (ctx, "read_servers");
+
+       if (!ups) {
+               return;
+       }
+
        cbdata = g_malloc0 (sizeof (*cbdata));
-       cbdata->selected = rspamd_upstream_get (ctx->read_servers,
+
+       cbdata->selected = rspamd_upstream_get (ups,
                                        RSPAMD_UPSTREAM_ROUND_ROBIN,
                                        NULL,
                                        0);
@@ -1250,78 +1274,6 @@ rspamd_redis_learned (redisAsyncContext *c, gpointer r, gpointer priv)
                rspamd_session_remove_event (task->s, rspamd_redis_fin_learn, rt);
        }
 }
-
-static gboolean
-rspamd_redis_try_ucl (struct redis_stat_ctx *backend,
-               const ucl_object_t *obj,
-               struct rspamd_config *cfg,
-               const gchar *symbol)
-{
-       const ucl_object_t *elt, *relt;
-
-       elt = ucl_object_lookup_any (obj, "read_servers", "servers", NULL);
-
-       if (elt == NULL) {
-               return FALSE;
-       }
-
-       backend->read_servers = rspamd_upstreams_create (cfg->ups_ctx);
-       if (!rspamd_upstreams_from_ucl (backend->read_servers, elt,
-                       REDIS_DEFAULT_PORT, NULL)) {
-               msg_err ("statfile %s cannot get read servers configuration",
-                               symbol);
-               return FALSE;
-       }
-
-       relt = elt;
-
-       elt = ucl_object_lookup (obj, "write_servers");
-       if (elt == NULL) {
-               /* Use read servers as write ones */
-               g_assert (relt != NULL);
-               backend->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
-               if (!rspamd_upstreams_from_ucl (backend->write_servers, relt,
-                               REDIS_DEFAULT_PORT, NULL)) {
-                       msg_err ("statfile %s cannot get write servers configuration",
-                                       symbol);
-                       return FALSE;
-               }
-       }
-       else {
-               backend->write_servers = rspamd_upstreams_create (cfg->ups_ctx);
-               if (!rspamd_upstreams_from_ucl (backend->write_servers, elt,
-                               REDIS_DEFAULT_PORT, NULL)) {
-                       msg_err ("statfile %s cannot get write servers configuration",
-                                       symbol);
-                       rspamd_upstreams_destroy (backend->write_servers);
-                       backend->write_servers = NULL;
-               }
-       }
-
-       elt = ucl_object_lookup_any (obj, "db", "database", "dbname", NULL);
-       if (elt) {
-               if (ucl_object_type (elt) == UCL_STRING) {
-                       backend->dbname = ucl_object_tostring (elt);
-               }
-               else if (ucl_object_type (elt) == UCL_INT) {
-                       backend->dbname = ucl_object_tostring_forced (elt);
-               }
-       }
-       else {
-               backend->dbname = NULL;
-       }
-
-       elt = ucl_object_lookup (obj, "password");
-       if (elt) {
-               backend->password = ucl_object_tostring (elt);
-       }
-       else {
-               backend->password = NULL;
-       }
-
-       return TRUE;
-}
-
 static void
 rspamd_redis_parse_classifier_opts (struct redis_stat_ctx *backend,
                const ucl_object_t *obj,
@@ -1379,14 +1331,6 @@ rspamd_redis_parse_classifier_opts (struct redis_stat_ctx *backend,
                backend->redis_object = ucl_object_tostring (elt);
        }
 
-       elt = ucl_object_lookup (obj, "timeout");
-       if (elt) {
-               backend->timeout = ucl_object_todouble (elt);
-       }
-       else {
-               backend->timeout = REDIS_DEFAULT_TIMEOUT;
-       }
-
        elt = ucl_object_lookup (obj, "store_tokens");
        if (elt) {
                backend->store_tokens = ucl_object_toboolean (elt);
@@ -1433,24 +1377,27 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx,
        struct rspamd_redis_stat_elt *st_elt;
        const ucl_object_t *obj;
        gboolean ret = FALSE;
+       gint conf_ref = -1;
+       lua_State *L = (lua_State *)cfg->lua_state;
 
        backend = g_malloc0 (sizeof (*backend));
+       backend->L = L;
+       backend->timeout = REDIS_DEFAULT_TIMEOUT;
 
        /* First search in backend configuration */
        obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
        if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
-               ret = rspamd_redis_try_ucl (backend, obj, cfg, stf->symbol);
+               ret = rspamd_lua_try_load_redis (L, obj, cfg, &conf_ref);
        }
 
        /* Now try statfiles config */
-       if (!ret) {
-               ret = rspamd_redis_try_ucl (backend, stf->opts, cfg, stf->symbol);
+       if (!ret && stf->opts) {
+               ret = rspamd_lua_try_load_redis (L, stf->opts, cfg, &conf_ref);
        }
 
        /* Now try classifier config */
-       if (!ret) {
-               ret = rspamd_redis_try_ucl (backend, st->classifier->cfg->opts, cfg,
-                               stf->symbol);
+       if (!ret && st->classifier->cfg->opts) {
+               ret = rspamd_lua_try_load_redis (L, st->classifier->cfg->opts, cfg, &conf_ref);
        }
 
        /* Now try global redis settings */
@@ -1463,12 +1410,12 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx,
                        specific_obj = ucl_object_lookup (obj, "statistics");
 
                        if (specific_obj) {
-                               ret = rspamd_redis_try_ucl (backend, specific_obj, cfg,
-                                               stf->symbol);
+                               ret = rspamd_lua_try_load_redis (L,
+                                               specific_obj, cfg, &conf_ref);
                        }
                        else {
-                               ret = rspamd_redis_try_ucl (backend, obj, cfg,
-                                               stf->symbol);
+                               ret = rspamd_lua_try_load_redis (L,
+                                               obj, cfg, &conf_ref);
                        }
                }
        }
@@ -1479,6 +1426,36 @@ rspamd_redis_init (struct rspamd_stat_ctx *ctx,
                return NULL;
        }
 
+       backend->conf_ref = conf_ref;
+
+       /* Check some common table values */
+       lua_rawgeti (L, LUA_REGISTRYINDEX, conf_ref);
+
+       lua_pushstring (L, "timeout");
+       lua_gettable (L, -2);
+       if (lua_type (L, -1) == LUA_TNUMBER) {
+               backend->timeout = lua_tonumber (L, -1);
+       }
+       lua_pop (L, 1);
+
+       lua_pushstring (L, "db");
+       lua_gettable (L, -2);
+       if (lua_type (L, -1) == LUA_TSTRING) {
+               backend->dbname = rspamd_mempool_strdup (cfg->cfg_pool,
+                               lua_tostring (L, -1));
+       }
+       lua_pop (L, 1);
+
+       lua_pushstring (L, "password");
+       lua_gettable (L, -2);
+       if (lua_type (L, -1) == LUA_TSTRING) {
+               backend->password = rspamd_mempool_strdup (cfg->cfg_pool,
+                               lua_tostring (L, -1));
+       }
+       lua_pop (L, 1);
+
+       lua_settop (L, 0);
+
        rspamd_redis_parse_classifier_opts (backend, st->classifier->cfg->opts, cfg);
        stf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND;
        backend->stcf = stf;
@@ -1504,25 +1481,35 @@ rspamd_redis_runtime (struct rspamd_task *task,
        struct redis_stat_ctx *ctx = REDIS_CTX (c);
        struct redis_stat_runtime *rt;
        struct upstream *up;
+       struct upstream_list *ups;
        char *object_expanded = NULL;
        rspamd_inet_addr_t *addr;
 
        g_assert (ctx != NULL);
        g_assert (stcf != NULL);
 
-       if (learn && ctx->write_servers == NULL) {
-               msg_err_task ("no write servers defined for %s, cannot learn", stcf->symbol);
-               return NULL;
-       }
-
        if (learn) {
-               up = rspamd_upstream_get (ctx->write_servers,
+               ups = rspamd_redis_get_servers (ctx, "write_servers");
+
+               if (!ups) {
+                       msg_err_task ("no write servers defined for %s, cannot learn",
+                                       stcf->symbol);
+                       return NULL;
+               }
+               up = rspamd_upstream_get (ups,
                                RSPAMD_UPSTREAM_MASTER_SLAVE,
                                NULL,
                                0);
        }
        else {
-               up = rspamd_upstream_get (ctx->read_servers,
+               ups = rspamd_redis_get_servers (ctx, "read_servers");
+
+               if (!ups) {
+                       msg_err_task ("no read servers defined for %s, cannot stat",
+                                       stcf->symbol);
+                       return NULL;
+               }
+               up = rspamd_upstream_get (ups,
                                RSPAMD_UPSTREAM_ROUND_ROBIN,
                                NULL,
                                0);
@@ -1576,13 +1563,10 @@ void
 rspamd_redis_close (gpointer p)
 {
        struct redis_stat_ctx *ctx = REDIS_CTX (p);
+       lua_State *L = ctx->L;
 
-       if (ctx->read_servers) {
-               rspamd_upstreams_destroy (ctx->read_servers);
-       }
-
-       if (ctx->write_servers) {
-               rspamd_upstreams_destroy (ctx->write_servers);
+       if (ctx->conf_ref) {
+               luaL_unref (L, LUA_REGISTRYINDEX, ctx->conf_ref);
        }
 
        g_free (ctx);
@@ -1685,6 +1669,7 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
 {
        struct redis_stat_runtime *rt = REDIS_RUNTIME (p);
        struct upstream *up;
+       struct upstream_list *ups;
        rspamd_inet_addr_t *addr;
        struct timeval tv;
        rspamd_fstring_t *query;
@@ -1698,7 +1683,12 @@ rspamd_redis_learn_tokens (struct rspamd_task *task, GPtrArray *tokens,
                return FALSE;
        }
 
-       up = rspamd_upstream_get (rt->ctx->write_servers,
+       ups = rspamd_redis_get_servers (rt->ctx, "write_servers");
+
+       if (!ups) {
+               return FALSE;
+       }
+       up = rspamd_upstream_get (ups,
                        RSPAMD_UPSTREAM_MASTER_SLAVE,
                        NULL,
                        0);
index 54c1dcad6e0223930980df2faf69a08cb1fbee95..ac6b11e18b266bbd233045ae65ee47f1833962a0 100644 (file)
@@ -2373,7 +2373,7 @@ rspamd_lua_try_load_redis (lua_State *L, const ucl_object_t *obj,
        *pcfg = cfg;
        lua_pushvalue (L, res_pos);
 
-       if (lua_pcall (L, 0, 1, err_idx) != 0) {
+       if (lua_pcall (L, 3, 1, err_idx) != 0) {
                GString *tb;
 
                tb = lua_touserdata (L, -1);