aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@rspamd.com>2023-05-27 15:02:34 +0100
committerVsevolod Stakhov <vsevolod@rspamd.com>2023-05-27 15:02:34 +0100
commit02c28f369b72d97d75cded4000675b98411a0fcb (patch)
tree30a3aaa2025d55b691f89445d6ee6189ec2152a0 /src
parent92f695cc48a3611a2aaf14944d641343c52e504a (diff)
downloadrspamd-02c28f369b72d97d75cded4000675b98411a0fcb.tar.gz
rspamd-02c28f369b72d97d75cded4000675b98411a0fcb.zip
[Feature] Maps: Add on_load support
Diffstat (limited to 'src')
-rw-r--r--src/libserver/maps/map.c23
-rw-r--r--src/libserver/maps/map.h22
-rw-r--r--src/libserver/maps/map_private.h3
-rw-r--r--src/lua/lua_map.c61
4 files changed, 103 insertions, 6 deletions
diff --git a/src/libserver/maps/map.c b/src/libserver/maps/map.c
index bb1f1f3fc..04557f0fe 100644
--- a/src/libserver/maps/map.c
+++ b/src/libserver/maps/map.c
@@ -995,6 +995,10 @@ rspamd_map_periodic_dtor (struct map_periodic_cbdata *periodic)
if (periodic->need_modify || periodic->cbdata.errored) {
/* Need to notify the real data structure */
periodic->map->fin_callback (&periodic->cbdata, periodic->map->user_data);
+
+ if (map->on_load_function) {
+ map->on_load_function(map, map->on_load_ud);
+ }
}
else {
/* Not modified */
@@ -2300,6 +2304,10 @@ rspamd_map_preload (struct rspamd_config *cfg)
if (succeed) {
map->fin_callback (&fake_cbd.cbdata, map->user_data);
+
+ if (map->on_load_function) {
+ map->on_load_function(map, map->on_load_ud);
+ }
}
else {
msg_info_map ("preload of %s failed", map->name);
@@ -2336,6 +2344,10 @@ rspamd_map_remove_all (struct rspamd_config *cfg)
*map->user_data = NULL;
}
+ if (map->on_load_ud_dtor) {
+ map->on_load_ud_dtor(map->on_load_ud);
+ }
+
for (i = 0; i < map->backends->len; i ++) {
bk = g_ptr_array_index (map->backends, i);
@@ -3106,3 +3118,14 @@ rspamd_map_traverse (struct rspamd_map *map, rspamd_map_traverse_cb cb,
map->traverse_function (*map->user_data, cb, cbdata, reset_hits);
}
}
+
+void
+rspamd_map_set_on_load_function (struct rspamd_map *map, rspamd_map_on_load_function cb,
+ gpointer cbdata, GDestroyNotify dtor)
+{
+ if (map) {
+ map->on_load_function = cb;
+ map->on_load_ud = cbdata;
+ map->on_load_ud_dtor = dtor;
+ }
+}
diff --git a/src/libserver/maps/map.h b/src/libserver/maps/map.h
index 6d77454fb..ac2edc82a 100644
--- a/src/libserver/maps/map.h
+++ b/src/libserver/maps/map.h
@@ -22,6 +22,12 @@ struct map_cb_data;
struct rspamd_worker;
/**
+ * Common map object
+ */
+struct rspamd_config;
+struct rspamd_map;
+
+/**
* Callback types
*/
typedef gchar *(*map_cb_t) (gchar *chunk, gint len,
@@ -37,12 +43,7 @@ typedef gboolean (*rspamd_map_traverse_cb) (gconstpointer key,
typedef void (*rspamd_map_traverse_function) (void *data,
rspamd_map_traverse_cb cb,
gpointer cbdata, gboolean reset_hits);
-
-/**
- * Common map object
- */
-struct rspamd_config;
-struct rspamd_map;
+typedef void (*rspamd_map_on_load_function) (struct rspamd_map *map, gpointer ud);
/**
* Callback data for async load
@@ -151,6 +152,15 @@ rspamd_map_traverse_function rspamd_map_get_traverse_function (struct rspamd_map
void rspamd_map_traverse (struct rspamd_map *map, rspamd_map_traverse_cb cb,
gpointer cbdata, gboolean reset_hits);
+/**
+ * Set map on load callback
+ * @param map
+ * @param cb
+ * @param cbdata
+ */
+void rspamd_map_set_on_load_function (struct rspamd_map *map, rspamd_map_on_load_function cb,
+ gpointer cbdata, GDestroyNotify dtor);
+
#ifdef __cplusplus
}
#endif
diff --git a/src/libserver/maps/map_private.h b/src/libserver/maps/map_private.h
index 74b2ea042..bbbac0cd6 100644
--- a/src/libserver/maps/map_private.h
+++ b/src/libserver/maps/map_private.h
@@ -151,6 +151,9 @@ struct rspamd_map {
rspamd_map_tmp_dtor tmp_dtor;
gpointer tmp_dtor_data;
rspamd_map_traverse_function traverse_function;
+ rspamd_map_on_load_function on_load_function;
+ gpointer on_load_ud;
+ GDestroyNotify on_load_ud_dtor;
gpointer lua_map;
gsize nelts;
guint64 digest;
diff --git a/src/lua/lua_map.c b/src/lua/lua_map.c
index fe01c3031..29e2053f2 100644
--- a/src/lua/lua_map.c
+++ b/src/lua/lua_map.c
@@ -136,6 +136,13 @@ LUA_FUNCTION_DEF (map, get_stats);
LUA_FUNCTION_DEF (map, foreach);
/***
+ * @method map:on_load(callback)
+ * Sets a callback for a map that is called when map is loaded
+ * @param {function} callback callback function, that accepts no arguments (pass maps in a closure if needed)
+ */
+LUA_FUNCTION_DEF (map, on_load);
+
+/***
* @method map:get_data_digest()
* Get data digest for specific map
* @return {string} 64 bit number represented as string (due to Lua limitations)
@@ -159,6 +166,7 @@ static const struct luaL_reg maplib_m[] = {
LUA_INTERFACE_DEF (map, get_uri),
LUA_INTERFACE_DEF (map, get_stats),
LUA_INTERFACE_DEF (map, foreach),
+ LUA_INTERFACE_DEF (map, on_load),
LUA_INTERFACE_DEF (map, get_data_digest),
LUA_INTERFACE_DEF (map, get_nelts),
{"__tostring", rspamd_lua_class_tostring},
@@ -1131,6 +1139,7 @@ lua_map_foreach (lua_State * L)
cbdata.L = L;
lua_pushvalue (L, 2); /* func */
cbdata.cbref = lua_gettop (L);
+ cbdata.use_text = use_text;
if (map->map->traverse_function) {
rspamd_map_traverse (map->map, lua_map_foreach_cb, &cbdata, FALSE);
@@ -1364,6 +1373,58 @@ lua_map_get_uri (lua_State *L)
return map->map->backends->len;
}
+struct lua_map_on_load_cbdata {
+ lua_State *L;
+ gint ref;
+};
+
+static void
+lua_map_on_load_dtor (gpointer p)
+{
+ struct lua_map_on_load_cbdata *cbd = p;
+
+ luaL_unref (cbd->L, LUA_REGISTRYINDEX, cbd->ref);
+ g_free (cbd);
+}
+
+static void
+lua_map_on_load_handler (struct rspamd_map *map, gpointer ud)
+{
+ struct lua_map_on_load_cbdata *cbd = ud;
+ lua_State *L = cbd->L;
+
+ lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->ref);
+
+ if (lua_pcall(L, 0, 0, 0) != 0) {
+ msg_err_map ("call to on_load function failed: %s", lua_tostring (L, -1));
+ }
+}
+
+static gint
+lua_map_on_load (lua_State *L)
+{
+ LUA_TRACE_POINT;
+ struct rspamd_lua_map *map = lua_check_map (L, 1);
+
+ if (map == NULL) {
+ return luaL_error (L, "invalid arguments");
+ }
+
+ if (lua_type (L, 2) == LUA_TFUNCTION) {
+ lua_pushvalue (L, 2);
+ struct lua_map_on_load_cbdata *cbd = g_malloc (sizeof (struct lua_map_on_load_cbdata));
+ cbd->L = L;
+ cbd->ref = luaL_ref (L, LUA_REGISTRYINDEX);
+
+ rspamd_map_set_on_load_function(map->map, lua_map_on_load_handler, cbd, lua_map_on_load_dtor);
+ }
+ else {
+ return luaL_error (L, "invalid callback");
+ }
+
+ return 0;
+}
+
void
luaopen_map (lua_State * L)
{