]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Lua_trie: Allow to report start of the match
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Oct 2020 18:04:31 +0000 (18:04 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Mon, 26 Oct 2020 18:04:31 +0000 (18:04 +0000)
src/lua/lua_trie.c

index 7c63fc6870e7c8f74feb8deff994840c36f77fcc..33e5832a88e2c119681ebd125f9b89620ea9ed76 100644 (file)
@@ -160,23 +160,39 @@ lua_trie_create (lua_State *L)
        return 1;
 }
 
+#define PUSH_TRIE_MATCH(L, start, end, report_start) do { \
+       if (report_start) { \
+               lua_createtable (L, 2, 0); \
+               lua_pushinteger (L, (start)); \
+               lua_rawseti (L, -2, 1); \
+               lua_pushinteger (L, (end)); \
+               lua_rawseti (L, -2, 2); \
+       } \
+       else { \
+               lua_pushinteger (L, (end)); \
+       } \
+} while(0)
+
 /* Normal callback type */
 static gint
-lua_trie_callback (struct rspamd_multipattern *mp,
-               guint strnum,
-               gint match_start,
-               gint textpos,
-               const gchar *text,
-               gsize len,
-               void *context)
+lua_trie_lua_cb_callback (struct rspamd_multipattern *mp,
+                                                 guint strnum,
+                                                 gint match_start,
+                                                 gint textpos,
+                                                 const gchar *text,
+                                                 gsize len,
+                                                 void *context)
 {
        lua_State *L = context;
        gint ret;
 
+       gboolean report_start = lua_toboolean (L, -1);
+
        /* Function */
        lua_pushvalue (L, 3);
        lua_pushinteger (L, strnum + 1);
-       lua_pushinteger (L, textpos);
+
+       PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
 
        if (lua_pcall (L, 2, 1, 0) != 0) {
                msg_info ("call to trie callback has failed: %s",
@@ -204,13 +220,14 @@ lua_trie_table_callback (struct rspamd_multipattern *mp,
 {
        lua_State *L = context;
 
+       gint report_start = lua_toboolean (L, -2);
        /* Set table, indexed by pattern number */
        lua_rawgeti (L, -1, strnum + 1);
 
        if (lua_istable (L, -1)) {
                /* Already have table, add offset */
                gsize last = rspamd_lua_table_size (L, -1);
-               lua_pushinteger (L, textpos);
+               PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
                lua_rawseti (L, -2, last + 1);
                /* Remove table from the stack */
                lua_pop (L, 1);
@@ -220,7 +237,7 @@ lua_trie_table_callback (struct rspamd_multipattern *mp,
                lua_pop (L, 1);
                /* New table */
                lua_newtable (L);
-               lua_pushinteger (L, textpos);
+               PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
                lua_rawseti (L, -2, 1);
                lua_rawseti (L, -2, strnum + 1);
        }
@@ -247,10 +264,11 @@ lua_trie_search_str (lua_State *L, struct rspamd_multipattern *trie,
 }
 
 /***
- * @method trie:match(input, [cb])
+ * @method trie:match(input, [cb][, report_start])
  * Search for patterns in `input` invoking `cb` optionally ignoring case
  * @param {table or string} input one or several (if `input` is an array) strings of input text
  * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends
+ * @param {boolean} report_start report both start and end offset when matching patterns
  * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number
  */
 static gint
@@ -260,16 +278,29 @@ lua_trie_match (lua_State *L)
        struct rspamd_multipattern *trie = lua_check_trie (L, 1);
        const gchar *text;
        gsize len;
-       gboolean found = FALSE;
+       gboolean found = FALSE, report_start = FALSE;
        struct rspamd_lua_text *t;
-       rspamd_multipattern_cb_t cb = lua_trie_callback;
+       rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
+
+       gint old_top = lua_gettop (L);
 
        if (trie) {
                if (lua_type (L, 3) != LUA_TFUNCTION) {
+                       if (lua_isboolean (L, 3)) {
+                               report_start = lua_toboolean (L, 3);
+                       }
+
+                       lua_pushboolean (L, report_start);
                        /* Table like match */
                        lua_newtable (L);
                        cb = lua_trie_table_callback;
                }
+               else {
+                       if (lua_isboolean (L, 4)) {
+                               report_start = lua_toboolean (L, 4);
+                       }
+                       lua_pushboolean (L, report_start);
+               }
 
                if (lua_type (L, 2) == LUA_TTABLE) {
                        lua_pushvalue (L, 2);
@@ -294,8 +325,6 @@ lua_trie_match (lua_State *L)
                                }
                                lua_pop (L, 1);
                        }
-
-                       lua_pop (L, 1); /* table */
                }
                else if (lua_type (L, 2) == LUA_TSTRING) {
                        text = lua_tolstring (L, 2, &len);
@@ -314,8 +343,12 @@ lua_trie_match (lua_State *L)
        }
 
        if (lua_type (L, 3) == LUA_TFUNCTION) {
+               lua_settop (L, old_top);
                lua_pushboolean (L, found);
        }
+       else {
+               lua_remove (L, -2);
+       }
 
        return 1;
 }
@@ -338,7 +371,7 @@ lua_trie_search_mime (lua_State *L)
        const gchar *text;
        gsize len, i;
        gboolean found = FALSE;
-       rspamd_multipattern_cb_t cb = lua_trie_callback;
+       rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
 
        if (trie && task) {
                PTR_ARRAY_FOREACH (MESSAGE_FIELD (task, text_parts), i, part) {
@@ -379,7 +412,7 @@ lua_trie_search_rawmsg (lua_State *L)
                text = task->msg.begin;
                len = task->msg.len;
 
-               if (lua_trie_search_str (L, trie, text, len, lua_trie_callback) != 0) {
+               if (lua_trie_search_str (L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
                        found = TRUE;
                }
        }
@@ -417,7 +450,7 @@ lua_trie_search_rawbody (lua_State *L)
                        len = task->msg.len;
                }
 
-               if (lua_trie_search_str (L, trie, text, len, lua_trie_callback) != 0) {
+               if (lua_trie_search_str (L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
                        found = TRUE;
                }
        }