]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Lua_task: Add get_urls_filtered method
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 19 Mar 2021 17:04:41 +0000 (17:04 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Fri, 19 Mar 2021 17:04:41 +0000 (17:04 +0000)
src/lua/lua_task.c
src/lua/lua_url.c
src/lua/lua_url.h

index bce91b4fb48619ab9f64a458c5bdea88d23b08f9..2de3fb5ed577ca7658750c431a3a2042ab6ff286 100644 (file)
@@ -256,6 +256,18 @@ local function phishing_cb(task)
 end
  */
 LUA_FUNCTION_DEF (task, get_urls);
+/***
+ * @method task:get_urls_filtered([{flags_include}, [{flags_exclude}]], [{protocols_mask}])
+ * Get urls managed by either exclude or include flags list
+ * - If flags include are nil then all but excluded urls are returned
+ * - If flags exclude are nil then only included explicitly urls are returned
+ * - If both parameters are nil then all urls are included
+ * @param {table|string} flags_include included flags
+ * @param {table|string} flags_exclude excluded flags
+ * @param {table|string} protocols_mask incude only specific protocols
+ * @return {table rspamd_url} list of urls matching conditions
+ */
+LUA_FUNCTION_DEF (task, get_urls_filtered);
 /***
  * @method task:has_urls([need_emails])
  * Returns 'true' if a task has urls listed
@@ -1212,6 +1224,7 @@ static const struct luaL_reg tasklib_m[] = {
        LUA_INTERFACE_DEF (task, append_message),
        LUA_INTERFACE_DEF (task, has_urls),
        LUA_INTERFACE_DEF (task, get_urls),
+       LUA_INTERFACE_DEF (task, get_urls_filtered),
        LUA_INTERFACE_DEF (task, inject_url),
        LUA_INTERFACE_DEF (task, get_content),
        LUA_INTERFACE_DEF (task, get_filename),
@@ -2463,6 +2476,74 @@ lua_task_get_urls (lua_State * L)
        return 1;
 }
 
+static gint
+lua_task_get_urls_filtered (lua_State * L)
+{
+       LUA_TRACE_POINT;
+       struct rspamd_task *task = lua_check_task (L, 1);
+       struct lua_tree_cb_data cb;
+       struct rspamd_url *u;
+       static const gint default_protocols_mask = PROTOCOL_HTTP|PROTOCOL_HTTPS|
+                                                                                          PROTOCOL_FILE|PROTOCOL_FTP;
+       gsize sz, max_urls = 0;
+
+       if (task) {
+               if (task->cfg) {
+                       max_urls = task->cfg->max_lua_urls;
+               }
+
+               if (task->message == NULL) {
+                       lua_newtable (L);
+
+                       return 1;
+               }
+
+               if (!lua_url_cbdata_fill_exclude_include (L, 2, &cb, default_protocols_mask, max_urls)) {
+                       return luaL_error (L, "invalid arguments");
+               }
+
+               sz = kh_size (MESSAGE_FIELD (task, urls));
+               sz = lua_url_adjust_skip_prob (task->task_timestamp,
+                               MESSAGE_FIELD (task, digest), &cb, sz);
+
+               lua_createtable (L, sz, 0);
+
+               if (cb.sort) {
+                       struct rspamd_url **urls_sorted;
+                       gint i = 0;
+
+                       urls_sorted = g_new0 (struct rspamd_url *, sz);
+
+                       kh_foreach_key (MESSAGE_FIELD(task, urls), u, {
+                               if (i < sz) {
+                                       urls_sorted[i] = u;
+                                       i ++;
+                               }
+                       });
+
+                       qsort (urls_sorted, i, sizeof (struct rspamd_url *), rspamd_url_cmp_qsort);
+
+                       for (int j = 0; j < i; j ++) {
+                               lua_tree_url_callback (urls_sorted[j], urls_sorted[j], &cb);
+                       }
+
+                       g_free (urls_sorted);
+               }
+               else {
+                       kh_foreach_key (MESSAGE_FIELD(task, urls), u, {
+                               lua_tree_url_callback(u, u, &cb);
+                       });
+               }
+
+               lua_url_cbdata_dtor (&cb);
+       }
+       else {
+               return luaL_error (L, "invalid arguments, no task");
+       }
+
+       return 1;
+}
+
 static gint
 lua_task_has_urls (lua_State * L)
 {
index 69c7d79bf8343f20ed29c40118d721bc111bc0a9..b56f025c4a01723d3de7806e7eabead5ae8f542e 100644 (file)
@@ -957,15 +957,26 @@ lua_tree_url_callback (gpointer key, gpointer value, gpointer ud)
 
        if ((url->protocol & cb->protocols_mask) == url->protocol) {
 
-               if (cb->flags_mode == url_flags_mode_include_any) {
+               /* Handle different flags application logic */
+               switch (cb->flags_mode) {
+               case url_flags_mode_include_any:
                        if (url->flags != (url->flags & cb->flags_mask)) {
                                return;
                        }
-               }
-               else {
+                       break;
+               case url_flags_mode_include_explicit:
                        if ((url->flags & cb->flags_mask) != cb->flags_mask) {
                                return;
                        }
+                       break;
+               case url_flags_mode_exclude_include:
+                       if (url->flags & cb->flags_exclude_mask) {
+                               return;
+                       }
+                       if (url->flags != (url->flags & cb->flags_mask)) {
+                               return;
+                       }
+                       break;
                }
 
                if (cb->skip_prob > 0) {
@@ -1207,6 +1218,113 @@ lua_url_cbdata_fill (lua_State *L,
        return TRUE;
 }
 
+gboolean
+lua_url_cbdata_fill_exclude_include (lua_State *L,
+                                        gint pos,
+                                        struct lua_tree_cb_data *cbd,
+                                        guint default_protocols,
+                                        gsize max_urls)
+{
+       guint protocols_mask = default_protocols;
+       guint include_flags_mask, exclude_flags_mask;
+
+       gint pos_arg_type = lua_type (L, pos);
+
+       memset (cbd, 0, sizeof (*cbd));
+       cbd->flags_mode = url_flags_mode_exclude_include;
+
+       /* Include flags */
+       if (pos_arg_type == LUA_TTABLE) {
+               include_flags_mask = 0; /* Reset to no flags */
+
+               for (lua_pushnil(L); lua_next(L, pos); lua_pop (L, 1)) {
+                       int nmask = 0;
+                       const gchar *fname = lua_tostring (L, -1);
+
+                       if (rspamd_url_flag_from_string(fname, &nmask)) {
+                               include_flags_mask |= nmask;
+                       }
+                       else {
+                               msg_info ("bad url include flag: %s", fname);
+                               return FALSE;
+                       }
+               }
+       }
+       else if (pos_arg_type == LUA_TNIL) {
+               /* Include all flags */
+               include_flags_mask = ~0U;
+       }
+       else {
+               msg_info ("bad arguments: wrong include mask");
+               return FALSE;
+       }
+
+       /* Exclude flags */
+       pos_arg_type = lua_type (L, pos + 1);
+       if (pos_arg_type == LUA_TTABLE) {
+               exclude_flags_mask = 0; /* Reset to no flags */
+
+               for (lua_pushnil(L); lua_next(L, pos); lua_pop (L, 1)) {
+                       int nmask = 0;
+
+                       const gchar *fname = lua_tostring (L, -1);
+
+                       if (rspamd_url_flag_from_string(fname, &nmask)) {
+                               exclude_flags_mask |= nmask;
+                       }
+                       else {
+                               msg_info ("bad url exclude flag: %s", fname);
+                               return FALSE;
+                       }
+               }
+       }
+       else if (pos_arg_type == LUA_TNIL) {
+               /* Empty all exclude flags */
+               exclude_flags_mask = 0U;
+       }
+       else {
+               msg_info ("bad arguments: wrong exclude mask");
+               return FALSE;
+       }
+
+       if (lua_type (L, pos + 2) == LUA_TTABLE) {
+               protocols_mask = 0U; /* Reset all protocols */
+
+               for (lua_pushnil (L); lua_next (L, pos + 2); lua_pop (L, 1)) {
+                       int nmask;
+                       const gchar *pname = lua_tostring (L, -1);
+
+                       nmask = rspamd_url_protocol_from_string (pname);
+
+                       if (nmask != PROTOCOL_UNKNOWN) {
+                               protocols_mask |= nmask;
+                       }
+                       else {
+                               msg_info ("bad url protocol: %s", pname);
+                               return FALSE;
+                       }
+               }
+       }
+       else {
+               protocols_mask = default_protocols;
+       }
+
+       cbd->i = 1;
+       cbd->L = L;
+       cbd->max_urls = max_urls;
+       cbd->protocols_mask = protocols_mask;
+       cbd->flags_mask = include_flags_mask;
+       cbd->flags_exclude_mask = exclude_flags_mask;
+
+       /* This needs to be removed from the stack */
+       rspamd_lua_class_metatable (L, "rspamd{url}");
+       cbd->metatable_pos = lua_gettop (L);
+       (void)lua_checkstack (L, cbd->metatable_pos + 4);
+
+       return TRUE;
+}
+
+
 void
 lua_url_cbdata_dtor (struct lua_tree_cb_data *cbd)
 {
index 705fe16153aec8905c3087398b8de33a2cfa05ad..904a56da71b7e26822a02f465e20c4106b012856 100644 (file)
@@ -27,15 +27,17 @@ struct lua_tree_cb_data {
        int i;
        int metatable_pos;
        guint flags_mask;
+       guint flags_exclude_mask;
        guint protocols_mask;
        enum {
                url_flags_mode_include_any,
                url_flags_mode_include_explicit,
+               url_flags_mode_exclude_include,
        } flags_mode;
+       gboolean sort;
        gsize max_urls;
        gdouble skip_prob;
        guint64 xoroshiro_state[4];
-       gboolean sort;
 };
 
 void lua_tree_url_callback (gpointer key, gpointer value, gpointer ud);
@@ -53,6 +55,11 @@ gboolean lua_url_cbdata_fill (lua_State *L, gint pos,
                                                          guint default_flags,
                                                          gsize max_urls);
 
+gboolean lua_url_cbdata_fill_exclude_include (lua_State *L, gint pos,
+                                                         struct lua_tree_cb_data *cbd,
+                                                         guint default_protocols,
+                                                         gsize max_urls);
+
 /**
  * Cleanup url cbdata
  * @param cbd