aboutsummaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm
diff options
context:
space:
mode:
authorVsevolod Stakhov <vsevolod@highsecure.ru>2017-11-04 12:42:14 +0000
committerVsevolod Stakhov <vsevolod@highsecure.ru>2017-11-04 12:42:14 +0000
commit3cf2f3130034e46ff0224cc576626599422e8401 (patch)
tree11c85e27beda01882a63b6ff3920610816d5ab4b /lualib/rspamadm
parent4d589bd8919effafcdb79005a3c2a5eb461480f3 (diff)
downloadrspamd-3cf2f3130034e46ff0224cc576626599422e8401.tar.gz
rspamd-3cf2f3130034e46ff0224cc576626599422e8401.zip
[Rework] Stop embedding rspamadm scripts into C
Diffstat (limited to 'lualib/rspamadm')
-rw-r--r--lualib/rspamadm/ansicolors.lua58
-rw-r--r--lualib/rspamadm/confighelp.lua111
-rw-r--r--lualib/rspamadm/fuzzy_convert.lua198
-rw-r--r--lualib/rspamadm/fuzzy_stat.lua274
-rw-r--r--lualib/rspamadm/getopt.lua34
-rw-r--r--lualib/rspamadm/grep.lua112
-rw-r--r--lualib/rspamadm/stat_convert.lua225
7 files changed, 1012 insertions, 0 deletions
diff --git a/lualib/rspamadm/ansicolors.lua b/lualib/rspamadm/ansicolors.lua
new file mode 100644
index 000000000..739cf427c
--- /dev/null
+++ b/lualib/rspamadm/ansicolors.lua
@@ -0,0 +1,58 @@
+local colormt = {}
+local ansicolors = {}
+
+function colormt:__tostring()
+ return self.value
+end
+
+function colormt:__concat(other)
+ return tostring(self) .. tostring(other)
+end
+
+function colormt:__call(s)
+ return self .. s .. ansicolors.reset
+end
+
+colormt.__metatable = {}
+
+local function makecolor(value)
+ return setmetatable({ value = string.char(27) .. '[' .. tostring(value) .. 'm' }, colormt)
+end
+
+local colors = {
+ -- attributes
+ reset = 0,
+ clear = 0,
+ bright = 1,
+ dim = 2,
+ underscore = 4,
+ blink = 5,
+ reverse = 7,
+ hidden = 8,
+
+ -- foreground
+ black = 30,
+ red = 31,
+ green = 32,
+ yellow = 33,
+ blue = 34,
+ magenta = 35,
+ cyan = 36,
+ white = 37,
+
+ -- background
+ onblack = 40,
+ onred = 41,
+ ongreen = 42,
+ onyellow = 43,
+ onblue = 44,
+ onmagenta = 45,
+ oncyan = 46,
+ onwhite = 47,
+}
+
+for c, v in pairs(colors) do
+ ansicolors[c] = makecolor(v)
+end
+
+return ansicolors \ No newline at end of file
diff --git a/lualib/rspamadm/confighelp.lua b/lualib/rspamadm/confighelp.lua
new file mode 100644
index 000000000..a03578b6e
--- /dev/null
+++ b/lualib/rspamadm/confighelp.lua
@@ -0,0 +1,111 @@
+local opts = {}
+local known_attrs = {
+ data = 1,
+ example = 1,
+ type = 1,
+ required = 1,
+ default = 1,
+}
+
+local getopt = require "rspamadm/getopt"
+local ansicolors = require "rspamadm/ansicolors"
+
+local function maybe_print_color(key)
+ if not opts['no-color'] then
+ return ansicolors.white .. key .. ansicolors.reset
+ else
+ return key
+ end
+end
+
+local function sort_values(tbl)
+ local res = {}
+ for k, v in pairs(tbl) do
+ table.insert(res, { key = k, value = v })
+ end
+
+ -- Sort order
+ local order = {
+ options = 1,
+ dns = 2,
+ upstream = 3,
+ logging = 4,
+ metric = 5,
+ composite = 6,
+ classifier = 7,
+ modules = 8,
+ lua = 9,
+ worker = 10,
+ workers = 11,
+ }
+
+ table.sort(res, function(a, b)
+ local oa = order[a['key']]
+ local ob = order[b['key']]
+
+ if oa and ob then
+ return oa < ob
+ elseif oa then
+ return -1 < 0
+ elseif ob then
+ return 1 < 0
+ else
+ return a['key'] < b['key']
+ end
+
+ end)
+
+ return res
+end
+
+local function print_help(key, value, tabs)
+ print(string.format('%sConfiguration element: %s', tabs, maybe_print_color(key)))
+
+ if not opts['short'] then
+ if value['data'] then
+ local nv = string.match(value['data'], '^#%s*(.*)%s*$') or value.data
+ print(string.format('%s\tDescription: %s', tabs, nv))
+ end
+ if value['type'] then
+ print(string.format('%s\tType: %s', tabs, value['type']))
+ end
+ if type(value['required']) == 'boolean' then
+ if value['required'] then
+ print(string.format('%s\tRequired: %s', tabs,
+ maybe_print_color(tostring(value['required']))))
+ else
+ print(string.format('%s\tRequired: %s', tabs,
+ tostring(value['required'])))
+ end
+ end
+ if value['default'] then
+ print(string.format('%s\tDefault: %s', tabs, value['default']))
+ end
+ if not opts['no-examples'] and value['example'] then
+ local nv = string.match(value['example'], '^%s*(.*[^%s])%s*$') or value.example
+ print(string.format('%s\tExample:\n%s', tabs, nv))
+ end
+ if value.type and value.type == 'object' then
+ print('')
+ end
+ end
+
+ local sorted = sort_values(value)
+ for _, v in ipairs(sorted) do
+ if not known_attrs[v['key']] then
+ -- We need to go deeper
+ print_help(v['key'], v['value'], tabs .. '\t')
+ end
+ end
+end
+
+return function(args, res)
+ opts = getopt.getopt(args, '')
+
+ local sorted = sort_values(res)
+
+ for _,v in ipairs(sorted) do
+ print_help(v['key'], v['value'], '')
+ print('')
+ end
+end
diff --git a/lualib/rspamadm/fuzzy_convert.lua b/lualib/rspamadm/fuzzy_convert.lua
new file mode 100644
index 000000000..2d473ca46
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_convert.lua
@@ -0,0 +1,198 @@
+local sqlite3 = require "rspamd_sqlite3"
+local redis = require "rspamd_redis"
+local util = require "rspamd_util"
+
+local function connect_redis(server, password, db)
+ local ret
+ local conn, err = redis.connect_sync({
+ host = server,
+ })
+
+ if not conn then
+ return nil, 'Cannot connect: ' .. err
+ end
+
+ if password then
+ ret = conn:add_cmd('AUTH', {password})
+ if not ret then
+ return nil, 'Cannot queue command'
+ end
+ end
+ if db then
+ ret = conn:add_cmd('SELECT', {db})
+ if not ret then
+ return nil, 'Cannot queue command'
+ end
+ end
+
+ return conn, nil
+end
+
+local function send_digests(digests, redis_host, redis_password, redis_db)
+ local conn, err = connect_redis(redis_host, redis_password, redis_db)
+ if err then
+ print(err)
+ return false
+ end
+ local ret
+ for _, v in ipairs(digests) do
+ ret = conn:add_cmd('HMSET', {
+ 'fuzzy' .. v[1],
+ 'F', v[2],
+ 'V', v[3],
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ ret = conn:add_cmd('EXPIRE', {
+ 'fuzzy' .. v[1],
+ tostring(v[4]),
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ end
+ ret, err = conn:exec()
+ if not ret then
+ print('Cannot execute batched commands: ' .. err)
+ return false
+ end
+ return true
+end
+
+local function send_shingles(shingles, redis_host, redis_password, redis_db)
+ local conn, err = connect_redis(redis_host, redis_password, redis_db)
+ if err then
+ print("Redis error: " .. err)
+ return false
+ end
+ local ret
+ for _, v in ipairs(shingles) do
+ ret = conn:add_cmd('SET', {
+ 'fuzzy_' .. v[2] .. '_' .. v[1],
+ v[4],
+ })
+ if not ret then
+ print('Cannot batch SET command: ' .. err)
+ return false
+ end
+ ret = conn:add_cmd('EXPIRE', {
+ 'fuzzy_' .. v[2] .. '_' .. v[1],
+ tostring(v[3]),
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ end
+ ret, err = conn:exec()
+ if not ret then
+ print('Cannot execute batched commands: ' .. err)
+ return false
+ end
+ return true
+end
+
+local function update_counters(total, redis_host, redis_password, redis_db)
+ local conn, err = connect_redis(redis_host, redis_password, redis_db)
+ if err then
+ print(err)
+ return false
+ end
+ local ret
+ ret = conn:add_cmd('SET', {
+ 'fuzzylocal',
+ total,
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ ret = conn:add_cmd('SET', {
+ 'fuzzy_count',
+ total,
+ })
+ if not ret then
+ print('Cannot batch command')
+ return false
+ end
+ ret, err = conn:exec()
+ if not ret then
+ print('Cannot execute batched commands: ' .. err)
+ return false
+ end
+ return true
+end
+
+return function (_, res)
+ local db = sqlite3.open(res['source_db'])
+ local shingles = {}
+ local digests = {}
+ local num_batch_digests = 0
+ local num_batch_shingles = 0
+ local total_digests = 0
+ local total_shingles = 0
+ local lim_batch = 1000 -- Update each 1000 entries
+ local redis_password = res['redis_password']
+ local redis_db = nil
+
+ if res['redis_db'] then
+ redis_db = tostring(res['redis_db'])
+ end
+
+ if not db then
+ print('Cannot open source db: ' .. res['source_db'])
+ return
+ end
+
+ local now = util.get_time()
+ for row in db:rows('SELECT id, flag, digest, value, time FROM digests') do
+
+ local expire_in = math.floor(now - row.time + res['expiry'])
+ if expire_in >= 1 then
+ table.insert(digests, {row.digest, row.flag, row.value, expire_in})
+ num_batch_digests = num_batch_digests + 1
+ total_digests = total_digests + 1
+ for srow in db:rows('SELECT value, number FROM shingles WHERE digest_id = ' .. row.id) do
+ table.insert(shingles, {srow.value, srow.number, expire_in, row.digest})
+ total_shingles = total_shingles + 1
+ num_batch_shingles = num_batch_shingles + 1
+ end
+ end
+ if num_batch_digests >= lim_batch then
+ if not send_digests(digests, res['redis_host'], redis_password, redis_db) then
+ return
+ end
+ num_batch_digests = 0
+ digests = {}
+ end
+ if num_batch_shingles >= lim_batch then
+ if not send_shingles(shingles, res['redis_host'], redis_password, redis_db) then
+ return
+ end
+ num_batch_shingles = 0
+ shingles = {}
+ end
+ end
+ if digests[1] then
+ if not send_digests(digests, res['redis_host'], redis_password, redis_db) then
+ return
+ end
+ end
+ if shingles[1] then
+ if not send_shingles(shingles, res['redis_host'], redis_password, redis_db) then
+ return
+ end
+ end
+
+ local message = string.format(
+ 'Migrated %d digests and %d shingles',
+ total_digests, total_shingles
+ )
+ if not update_counters(total_digests, res['redis_host'], redis_password, redis_db) then
+ message = message .. ' but failed to update counters'
+ end
+ print(message)
+end
diff --git a/lualib/rspamadm/fuzzy_stat.lua b/lualib/rspamadm/fuzzy_stat.lua
new file mode 100644
index 000000000..748dbda20
--- /dev/null
+++ b/lualib/rspamadm/fuzzy_stat.lua
@@ -0,0 +1,274 @@
+local util = require "rspamd_util"
+local opts = {}
+
+local function add_data(target, src)
+ for k,v in pairs(src) do
+ if k ~= 'ips' then
+ if target[k] then
+ target[k] = target[k] + v
+ else
+ target[k] = v
+ end
+ else
+ if not target['ips'] then target['ips'] = {} end
+ -- Iterate over IPs
+ for ip,st in pairs(v) do
+ if not target['ips'][ip] then target['ips'][ip] = {} end
+ add_data(target['ips'][ip], st)
+ end
+ end
+ end
+end
+
+local function print_num(num)
+ if opts['n'] or opts['number'] then
+ return tostring(num)
+ else
+ return util.humanize_number(num)
+ end
+end
+
+local function print_stat(st, tabs)
+ if st['checked'] then
+ print(string.format('%sChecked: %s', tabs, print_num(st['checked'])))
+ end
+ if st['matched'] then
+ print(string.format('%sMatched: %s', tabs, print_num(st['matched'])))
+ end
+ if st['errors'] then
+ print(string.format('%sErrors: %s', tabs, print_num(st['errors'])))
+ end
+ if st['added'] then
+ print(string.format('%sAdded: %s', tabs, print_num(st['added'])))
+ end
+ if st['deleted'] then
+ print(string.format('%sDeleted: %s', tabs, print_num(st['deleted'])))
+ end
+end
+
+-- Sort by checked
+local function sort_ips(tbl, _opts)
+ local res = {}
+ for k,v in pairs(tbl) do
+ table.insert(res, {ip = k, data = v})
+ end
+
+ local function sort_order(elt)
+ local key = 'checked'
+ local _res = 0
+
+ if _opts['sort'] then
+ if _opts['sort'] == 'matched' then
+ key = 'matched'
+ elseif _opts['sort'] == 'errors' then
+ key = 'errors'
+ elseif _opts['sort'] == 'ip' then
+ return elt['ip']
+ end
+ end
+
+ if elt['data'][key] then
+ _res = elt['data'][key]
+ end
+
+ return _res
+ end
+
+ table.sort(res, function(a, b)
+ return sort_order(a) > sort_order(b)
+ end)
+
+ return res
+end
+
+local function add_result(dst, src, k)
+ if type(src) == 'table' then
+ if type(dst) == 'number' then
+ -- Convert dst to table
+ dst = {dst}
+ elseif type(dst) == 'nil' then
+ dst = {}
+ end
+
+ for i,v in ipairs(src) do
+ if dst[i] and k ~= 'fuzzy_stored' then
+ dst[i] = dst[i] + v
+ else
+ dst[i] = v
+ end
+ end
+ else
+ if type(dst) == 'table' then
+ if k ~= 'fuzzy_stored' then
+ dst[1] = dst[1] + src
+ else
+ dst[1] = src
+ end
+ else
+ if dst and k ~= 'fuzzy_stored' then
+ dst = dst + src
+ else
+ dst = src
+ end
+ end
+ end
+
+ return dst
+end
+
+local function print_result(r)
+ local function num_to_epoch(num)
+ if num == 1 then
+ return 'v0.6'
+ elseif num == 2 then
+ return 'v0.8'
+ elseif num == 3 then
+ return 'v0.9'
+ elseif num == 4 then
+ return 'v1.0+'
+ end
+ return '???'
+ end
+ if type(r) == 'table' then
+ local res = {}
+ for i,num in ipairs(r) do
+ res[i] = string.format('(%s: %s)', num_to_epoch(i), print_num(num))
+ end
+
+ return table.concat(res, ', ')
+ end
+
+ return print_num(r)
+end
+
+local getopt = require "rspamadm/getopt"
+
+return function(args, res)
+ local res_ips = {}
+ local res_databases = {}
+ local wrk = res['workers']
+ opts = getopt.getopt(args, '')
+
+ if wrk then
+ for _,pr in pairs(wrk) do
+ -- processes cycle
+ if pr['data'] then
+ local id = pr['id']
+
+ if id then
+ local res_db = res_databases[id]
+ if not res_db then
+ res_db = {
+ keys = {}
+ }
+ res_databases[id] = res_db
+ end
+
+ -- General stats
+ for k,v in pairs(pr['data']) do
+ if k ~= 'keys' and k ~= 'errors_ips' then
+ res_db[k] = add_result(res_db[k], v, k)
+ elseif k == 'errors_ips' then
+ -- Errors ips
+ if not res_db['errors_ips'] then
+ res_db['errors_ips'] = {}
+ end
+ for ip,nerrors in pairs(v) do
+ if not res_db['errors_ips'][ip] then
+ res_db['errors_ips'][ip] = nerrors
+ else
+ res_db['errors_ips'][ip] = nerrors + res_db['errors_ips'][ip]
+ end
+ end
+ end
+ end
+
+ if pr['data']['keys'] then
+ local res_keys = res_db['keys']
+ if not res_keys then
+ res_keys = {}
+ res_db['keys'] = res_keys
+ end
+ -- Go through keys in input
+ for k,elts in pairs(pr['data']['keys']) do
+ -- keys cycle
+ if not res_keys[k] then
+ res_keys[k] = {}
+ end
+
+ add_data(res_keys[k], elts)
+
+ if elts['ips'] then
+ for ip,v in pairs(elts['ips']) do
+ if not res_ips[ip] then
+ res_ips[ip] = {}
+ end
+ add_data(res_ips[ip], v)
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+ end
+
+ -- General stats
+ for db,st in pairs(res_databases) do
+ print(string.format('Statistics for storage %s', db))
+
+ for k,v in pairs(st) do
+ if k ~= 'keys' and k ~= 'errors_ips' then
+ print(string.format('%s: %s', k, print_result(v)))
+ end
+ end
+ print('')
+
+ local res_keys = st['keys']
+ if res_keys and not opts['no-keys'] and not opts['short'] then
+ print('Keys statistics:')
+ for k,_st in pairs(res_keys) do
+ print(string.format('Key id: %s', k))
+ print_stat(_st, '\t')
+
+ if _st['ips'] and not opts['no-ips'] then
+ print('')
+ print('\tIPs stat:')
+ local sorted_ips = sort_ips(_st['ips'], opts)
+
+ for _,v in ipairs(sorted_ips) do
+ print(string.format('\t%s', v['ip']))
+ print_stat(v['data'], '\t\t')
+ print('')
+ end
+ end
+
+ print('')
+ end
+ end
+ if st['errors_ips'] and not opts['no-ips'] and not opts['short'] then
+ print('')
+ print('Errors IPs statistics:')
+ table.sort(st['errors_ips'], function(a, b)
+ return a > b
+ end)
+ for i, v in pairs(st['errors_ips']) do
+ print(string.format('%s: %s', i, print_result(v)))
+ end
+ print('')
+ end
+ end
+
+ if not opts['no-ips'] and not opts['short'] then
+ print('')
+ print('IPs statistics:')
+
+ local sorted_ips = sort_ips(res_ips, opts)
+ for _, v in ipairs(sorted_ips) do
+ print(string.format('%s', v['ip']))
+ print_stat(v['data'], '\t')
+ print('')
+ end
+ end
+end
+
diff --git a/lualib/rspamadm/getopt.lua b/lualib/rspamadm/getopt.lua
new file mode 100644
index 000000000..bd0a2f67e
--- /dev/null
+++ b/lualib/rspamadm/getopt.lua
@@ -0,0 +1,34 @@
+local function getopt(arg, options)
+ local tab = {}
+ for k, v in ipairs(arg) do
+ if string.sub(v, 1, 2) == "--" then
+ local x = string.find(v, "=", 1, true)
+ if x then tab[string.sub(v, 3, x - 1)] = string.sub(v, x + 1)
+ else tab[string.sub(v, 3)] = true
+ end
+ elseif string.sub(v, 1, 1) == "-" then
+ local y = 2
+ local l = string.len(v)
+ local jopt
+ while (y <= l) do
+ jopt = string.sub(v, y, y)
+ if string.find(options, jopt, 1, true) then
+ if y < l then
+ tab[jopt] = string.sub(v, y + 1)
+ y = l
+ else
+ tab[jopt] = arg[k + 1]
+ end
+ else
+ tab[jopt] = true
+ end
+ y = y + 1
+ end
+ end
+ end
+ return tab
+end
+
+return {
+ getopt = getopt
+}
diff --git a/lualib/rspamadm/grep.lua b/lualib/rspamadm/grep.lua
new file mode 100644
index 000000000..a9d4b084a
--- /dev/null
+++ b/lualib/rspamadm/grep.lua
@@ -0,0 +1,112 @@
+return function(_, res)
+
+ local rspamd_regexp = require 'rspamd_regexp'
+
+ local buffer = {}
+ local matches = {}
+
+ local pattern = res['pattern']
+ local re
+ if pattern then
+ re = rspamd_regexp.create(pattern)
+ if not re then
+ io.stderr:write("Couldn't compile regex: " .. pattern .. '\n')
+ os.exit(1)
+ end
+ end
+
+ local plainm = true
+ if res['luapat'] then
+ plainm = false
+ end
+ local orphans = res['orphans']
+ local search_str = res['string']
+ local sensitive = res['sensitive']
+ local partial = res['partial']
+ if search_str and not sensitive then
+ search_str = string.lower(search_str)
+ end
+ local inputs = res['inputs']
+
+ for _, n in ipairs(inputs) do
+ local h, err
+ if string.match(n, '%.xz$') then
+ h, err = io.popen('xzcat ' .. n, 'r')
+ elseif string.match(n, '%.bz2$') then
+ h, err = io.popen('bzcat ' .. n, 'r')
+ elseif string.match(n, '%.gz$') then
+ h, err = io.popen('zcat ' .. n, 'r')
+ elseif n == 'stdin' then
+ h = io.input()
+ else
+ h, err = io.open(n, 'r')
+ end
+ if not h then
+ if err then
+ io.stderr:write("Couldn't open file (" .. n .. '): ' .. err .. '\n')
+ else
+ io.stderr:write("Couldn't open file (" .. n .. '): no error\n')
+ end
+ else
+ for line in h:lines() do
+ local hash = string.match(line, '<(%x+)>')
+ local already_matching = false
+ if hash then
+ if matches[hash] then
+ table.insert(matches[hash], line)
+ already_matching = true
+ else
+ if buffer[hash] then
+ table.insert(buffer[hash], line)
+ else
+ buffer[hash] = {line}
+ end
+ end
+ end
+ local ismatch = false
+ if re then
+ ismatch = re:match(line)
+ elseif sensitive and search_str then
+ ismatch = string.find(line, search_str, 1, plainm)
+ elseif search_str then
+ local lwr = string.lower(line)
+ ismatch = string.find(lwr, search_str, 1, plainm)
+ end
+ if ismatch then
+ if not hash then
+ if orphans then
+ print('*** orphaned ***')
+ print(line)
+ print()
+ end
+ elseif not already_matching then
+ matches[hash] = buffer[hash]
+ end
+ end
+ local is_end = string.match(line, '<%x+>; task; rspamd_protocol_http_reply:')
+ if is_end then
+ buffer[hash] = nil
+ if matches[hash] then
+ for _, v in ipairs(matches[hash]) do
+ print(v)
+ end
+ print()
+ matches[hash] = nil
+ end
+ end
+ end
+ if partial then
+ for k, v in pairs(matches) do
+ print('*** partial ***')
+ for _, vv in ipairs(v) do
+ print(vv)
+ end
+ print()
+ matches[k] = nil
+ end
+ else
+ matches = {}
+ end
+ end
+ end
+end
diff --git a/lualib/rspamadm/stat_convert.lua b/lualib/rspamadm/stat_convert.lua
new file mode 100644
index 000000000..7b6de9836
--- /dev/null
+++ b/lualib/rspamadm/stat_convert.lua
@@ -0,0 +1,225 @@
+local sqlite3 = require "rspamd_sqlite3"
+local redis = require "rspamd_redis"
+local util = require "rspamd_util"
+
+local function send_redis(server, symbol, tokens, password, db, cmd)
+ local ret = true
+ local conn,err = redis.connect_sync({
+ host = server,
+ })
+
+ local err_str
+
+ if not conn then
+ print('Cannot connect to ' .. server .. ' error: ' .. err)
+ return false, err
+ end
+
+ if password then
+ conn:add_cmd('AUTH', {password})
+ end
+ if db then
+ conn:add_cmd('SELECT', {db})
+ end
+
+ for _,t in ipairs(tokens) do
+ if not conn:add_cmd(cmd, {symbol .. t[3], t[1], t[2]}) then
+ ret = false
+ err_str = 'add command failure' .. string.format('%s %s',
+ cmd, table.concat({symbol .. t[3], t[1], t[2]}, ' '))
+ end
+ end
+
+ if ret then
+ ret,err_str = conn:exec()
+ end
+
+ return ret,err_str
+end
+
+local function convert_learned(cache, server, password, redis_db)
+ local converted = 0
+ local db = sqlite3.open(cache)
+ local ret = true
+ local err_str
+
+ if not db then
+ print('Cannot open cache database: ' .. cache)
+ return false
+ end
+
+ db:sql('BEGIN;')
+
+ local conn,err = redis.connect_sync({
+ host = server,
+ })
+
+ if not conn then
+ print('Cannot connect to ' .. server .. ' error: ' .. err)
+ return false
+ end
+
+ if password then
+ conn:add_cmd('AUTH', {password})
+ end
+ if redis_db then
+ conn:add_cmd('SELECT', {redis_db})
+ end
+
+ for row in db:rows('SELECT * FROM learns;') do
+ local is_spam
+ local digest = tostring(util.encode_base32(row.digest))
+
+ if row.flag == '0' then
+ is_spam = '-1'
+ else
+ is_spam = '1'
+ end
+
+ if not conn:add_cmd('HSET', {'learned_ids', digest, is_spam}) then
+ print('Cannot add hash: ' .. digest)
+ ret = false
+ else
+ converted = converted + 1
+ end
+ end
+ db:sql('COMMIT;')
+
+ if ret then
+ ret,err_str = conn:exec()
+ end
+
+ if ret then
+ print(string.format('Converted %d cached items from sqlite3 learned cache to redis',
+ converted))
+ else
+ print('Error occurred during sending data to redis: ' .. err_str)
+ end
+
+ return ret
+end
+
+return function (_, res)
+ local db = sqlite3.open(res['source_db'])
+ local tokens = {}
+ local num = 0
+ local total = 0
+ local nusers = 0
+ local lim = 1000 -- Update each 1000 tokens
+ local users_map = {}
+ local learns = {}
+ local redis_password = res['redis_password']
+ local redis_db = nil
+ local cmd = 'HINCRBY'
+ local ret, err_str
+
+ if res['redis_db'] then
+ redis_db = tostring(res['redis_db'])
+ end
+ if res['reset_previous'] then
+ cmd = 'HSET'
+ end
+
+ if res['cache_db'] then
+ if not convert_learned(res['cache_db'], res['redis_host'],
+ redis_password, redis_db) then
+ print('Cannot convert learned cache to redis')
+ return
+ end
+ end
+
+ if not db then
+ print('Cannot open source db: ' .. res['source_db'])
+ return
+ end
+
+ db:sql('BEGIN;')
+ -- Fill users mapping
+ for row in db:rows('SELECT * FROM users;') do
+ if row.id == '0' then
+ users_map[row.id] = ''
+ else
+ users_map[row.id] = row.name
+ end
+ learns[row.id] = row.learns
+ nusers = nusers + 1
+ end
+
+ -- Workaround for old databases
+ for row in db:rows('SELECT * FROM languages') do
+ if learns['0'] then
+ learns['0'] = learns['0'] + row.learns
+ else
+ learns['0'] = row.learns
+ end
+ end
+
+ -- Fill tokens, sending data to redis each `lim` records
+ for row in db:rows('SELECT token,value,user FROM tokens;') do
+ local user = ''
+ if row.user ~= 0 and users_map[row.user] then
+ user = users_map[row.user]
+ end
+
+ table.insert(tokens, {row.token, row.value, user})
+
+ num = num + 1
+ total = total + 1
+ if num > lim then
+ ret,err_str = send_redis(res['redis_host'], res['symbol'],
+ tokens, redis_password, redis_db, cmd)
+ if not ret then
+ print('Cannot send tokens to the redis server: ' .. err_str)
+ return
+ end
+
+ num = 0
+ tokens = {}
+ end
+ end
+ if #tokens > 0 then
+ ret, err_str = send_redis(res['redis_host'], res['symbol'], tokens,
+ redis_password, redis_db, cmd)
+
+ if not ret then
+ print('Cannot send tokens to the redis server: ' .. err_str)
+ return
+ end
+ end
+ -- Now update all users
+ local conn,err = redis.connect_sync({
+ host = res['redis_host'],
+ })
+
+ if not conn then
+ print('Cannot connect to ' .. res['redis_host'] .. ' error: ' .. err)
+ return false
+ end
+
+ if redis_password then
+ conn:add_cmd('AUTH', {redis_password})
+ end
+ if redis_db then
+ conn:add_cmd('SELECT', {redis_db})
+ end
+
+ for id,learned in pairs(learns) do
+ local user = users_map[id]
+ if not conn:add_cmd(cmd, {res['symbol'] .. user, 'learns', learned}) then
+ print('Cannot update learns for user: ' .. user)
+ end
+ if not conn:add_cmd('SADD', {res['symbol'] .. '_keys', res['symbol'] .. user}) then
+ print('Cannot update learns for user: ' .. user)
+ end
+ end
+ db:sql('COMMIT;')
+
+ ret = conn:exec()
+
+ if ret then
+ print(string.format('Migrated %d tokens for %d users for symbol %s',
+ total, nusers, res['symbol']))
+ else
+ print('Error occurred during sending data to redis')
+ end
+end