]> source.dussan.org Git - rspamd.git/commitdiff
[Feature] Lua_trie: More flexible API
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 4 Sep 2019 17:40:48 +0000 (18:40 +0100)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Wed, 4 Sep 2019 17:40:48 +0000 (18:40 +0100)
src/lua/lua_trie.c

index 456610b1f7d588bd8a0fc3fc148d3a4341ba3787..b030c735add3b5027b4a1c39e883f89a8e0787b6 100644 (file)
@@ -145,6 +145,7 @@ lua_trie_create (lua_State *L)
        return 1;
 }
 
+/* Normal callback type */
 static gint
 lua_trie_callback (struct rspamd_multipattern *mp,
                guint strnum,
@@ -176,18 +177,54 @@ lua_trie_callback (struct rspamd_multipattern *mp,
        return ret;
 }
 
+/* Table like callback, expect result table on top of the stack */
+static gint
+lua_trie_table_callback (struct rspamd_multipattern *mp,
+                                  guint strnum,
+                                  gint match_start,
+                                  gint textpos,
+                                  const gchar *text,
+                                  gsize len,
+                                  void *context)
+{
+       lua_State *L = context;
+
+       /* Set table, indexed by pattern number */
+       lua_rawgeti (L, -1, strnum + 1);
+
+       if (lua_istable (L, -1)) {
+               /* Already have table, add offset */
+               gsize last = rspamd_lua_table_size (L, -1);
+               lua_pushinteger (L, textpos);
+               lua_rawseti (L, -2, last + 1);
+               /* Remove table from the stack */
+               lua_pop (L, 1);
+       }
+       else {
+               /* Pop none */
+               lua_pop (L, 1);
+               /* New table */
+               lua_newtable (L);
+               lua_pushinteger (L, textpos);
+               lua_rawseti (L, -2, 1);
+               lua_rawseti (L, -2, strnum + 1);
+       }
+
+       return 0;
+}
+
 /*
  * We assume that callback argument is at pos 3 and icase is in position 4
  */
 static gint
 lua_trie_search_str (lua_State *L, struct rspamd_multipattern *trie,
-               const gchar *str, gsize len)
+               const gchar *str, gsize len, rspamd_multipattern_cb_t cb)
 {
        gint ret;
        guint nfound = 0;
 
        if ((ret = rspamd_multipattern_lookup (trie, str, len,
-                       lua_trie_callback, L, &nfound)) == 0) {
+                       cb, L, &nfound)) == 0) {
                return nfound;
        }
 
@@ -195,12 +232,11 @@ lua_trie_search_str (lua_State *L, struct rspamd_multipattern *trie,
 }
 
 /***
- * @method trie:match(input, cb[, caseless])
+ * @method trie:match(input, [cb])
  * Search for patterns in `input` invoking `cb` optionally ignoring case
  * @param {table or string} input one or several (if `input` is an array) strings of input text
  * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends
- * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
- * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
+ * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number
  */
 static gint
 lua_trie_match (lua_State *L)
@@ -210,8 +246,16 @@ lua_trie_match (lua_State *L)
        const gchar *text;
        gsize len;
        gboolean found = FALSE;
+       struct rspamd_lua_text *t;
+       rspamd_multipattern_cb_t cb = lua_trie_callback;
 
        if (trie) {
+               if (lua_type (L, 3) != LUA_TFUNCTION) {
+                       /* Table like match */
+                       lua_newtable (L);
+                       cb = lua_trie_table_callback;
+               }
+
                if (lua_type (L, 2) == LUA_TTABLE) {
                        lua_pushvalue (L, 2);
                        lua_pushnil (L);
@@ -220,10 +264,19 @@ lua_trie_match (lua_State *L)
                                if (lua_isstring (L, -1)) {
                                        text = lua_tolstring (L, -1, &len);
 
-                                       if (lua_trie_search_str (L, trie, text, len)) {
+                                       if (lua_trie_search_str (L, trie, text, len, cb)) {
                                                found = TRUE;
                                        }
                                }
+                               else if (lua_isuserdata (L, -1)) {
+                                       t = lua_check_text (L, -1);
+
+                                       if (t) {
+                                               if (lua_trie_search_str (L, trie, t->start, t->len, cb)) {
+                                                       found = TRUE;
+                                               }
+                                       }
+                               }
                                lua_pop (L, 1);
                        }
 
@@ -232,18 +285,28 @@ lua_trie_match (lua_State *L)
                else if (lua_type (L, 2) == LUA_TSTRING) {
                        text = lua_tolstring (L, 2, &len);
 
-                       if (lua_trie_search_str (L, trie, text, len)) {
+                       if (lua_trie_search_str (L, trie, text, len, cb)) {
+                               found = TRUE;
+                       }
+               }
+               else if (lua_type (L, 2) == LUA_TUSERDATA) {
+                       t = lua_check_text (L, -1);
+
+                       if (t && lua_trie_search_str (L, trie, t->start, t->len, cb)) {
                                found = TRUE;
                        }
                }
        }
 
-       lua_pushboolean (L, found);
+       if (lua_type (L, 3) == LUA_TFUNCTION) {
+               lua_pushboolean (L, found);
+       }
+
        return 1;
 }
 
 /***
- * @method trie:search_mime(task, cb[, caseless])
+ * @method trie:search_mime(task, cb)
  * This is a helper mehthod to search pattern within text parts of a message in rspamd task
  * @param {task} task object
  * @param {function} cb callback called on each pattern match @see trie:match
@@ -260,6 +323,7 @@ lua_trie_search_mime (lua_State *L)
        const gchar *text;
        gsize len, i;
        gboolean found = FALSE;
+       rspamd_multipattern_cb_t cb = lua_trie_callback;
 
        if (trie && task) {
                PTR_ARRAY_FOREACH (MESSAGE_FIELD (task, text_parts), i, part) {
@@ -267,7 +331,7 @@ lua_trie_search_mime (lua_State *L)
                                text = part->utf_content->data;
                                len = part->utf_content->len;
 
-                               if (lua_trie_search_str (L, trie, text, len) != 0) {
+                               if (lua_trie_search_str (L, trie, text, len, cb) != 0) {
                                        found = TRUE;
                                }
                        }
@@ -300,7 +364,7 @@ lua_trie_search_rawmsg (lua_State *L)
                text = task->msg.begin;
                len = task->msg.len;
 
-               if (lua_trie_search_str (L, trie, text, len) != 0) {
+               if (lua_trie_search_str (L, trie, text, len, lua_trie_callback) != 0) {
                        found = TRUE;
                }
        }
@@ -338,7 +402,7 @@ lua_trie_search_rawbody (lua_State *L)
                        len = task->msg.len;
                }
 
-               if (lua_trie_search_str (L, trie, text, len) != 0) {
+               if (lua_trie_search_str (L, trie, text, len, lua_trie_callback) != 0) {
                        found = TRUE;
                }
        }