]> source.dussan.org Git - rspamd.git/commitdiff
[Minor] Add function to load sqlite params from the config
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 15 Feb 2018 13:13:18 +0000 (13:13 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 15 Feb 2018 13:13:18 +0000 (13:13 +0000)
lualib/stat_tools.lua

index 48ecb1be4b9d5d166af798520837ce9f9cd4abbf..049de44f4595c36df2cd4989851a99c698c2dbdb 100644 (file)
@@ -362,4 +362,103 @@ end
 
 exports.convert_sqlite_to_redis = convert_sqlite_to_redis
 
+-- Loads sqlite3 based classifiers and output data in form of array of objects:
+-- [
+--  {
+--  symbol_spam = XXX
+--  symbol_ham = YYY
+--  db_spam = XXX.sqlite
+--  db_ham = YYY.sqlite
+--  learn_cahe = ZZZ.sqlite
+--  per_user = true/false
+--  label = str
+--  }
+-- ]
+local function load_sqlite_config(cfg)
+  local result = {}
+
+  local function parse_classifier(cls)
+    local tbl = {}
+    if cls.cache then
+      local cache = cls.cache
+      if cache.type == 'sqlite3' and (cls.file or cls.path) then
+        tbl.learn_cache = (cls.file or cls.path)
+      end
+    end
+
+    if cls.per_user then
+      tbl.per_user = cls.per_user
+    end
+
+    if cls.label then
+      tbl.label = cls.label
+    end
+
+    local statfiles = cls.statfile
+    for _,stf in ipairs(statfiles) do
+      local path = (stf.file or stf.path or stf.db or stf.dbname)
+      local symbol = stf.symbol or 'undefined'
+
+      if not path then
+        logger.errx('no path defined for statfile %s', symbol)
+      else
+
+        local spam
+        if stf.spam then
+          spam = stf.spam
+        else
+          if string.match(symbol:upper(), 'SPAM') then
+            spam = true
+          else
+            spam = false
+          end
+        end
+
+        if spam then
+          tbl.symbol_spam = symbol
+          tbl.db_spam = path
+        else
+          tbl.symbol_ham = symbol
+          tbl.db_ham = path
+        end
+      end
+    end
+
+    if tbl.symbol_spam and tbl.symbol_ham and tbl.db_ham and tbl.db_spam then
+      table.insert(result, tbl)
+    end
+  end
+
+  local classifier = cfg.classifier
+
+  if classifier then
+    if classifier[1] then
+      for _,cls in ipairs(classifier) do
+        if cls.backend and cls.backend == 'sqlite3' then
+          parse_classifier(cls)
+        end
+      end
+    else
+      if classifier.bayes then
+        classifier = classifier.bayes
+        if classifier[1] then
+          for _,cls in ipairs(classifier) do
+            if cls.backend and cls.backend == 'sqlite3' then
+              parse_classifier(cls)
+            end
+          end
+        else
+          if classifier.backend and classifier.backend == 'sqlite3' then
+            parse_classifier(classifier)
+          end
+        end
+      end
+    end
+  end
+
+  return result
+end
+
+exports.load_sqlite_config = load_sqlite_config
+
 return exports