1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
--[[
Copyright (c) 2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
local rspamd_logger = require "rspamd_logger"

-- Public functions of this module; returned at the end of the file.
local exports = {}
-- Deserialized models, keyed by their entry name in the `nn_models` section.
local lua_nn_models = {}

-- Torch is optional: `torch` stays nil unless rspamd_config:has_torch()
-- reports support for it.
local torch = nil
if rspamd_config:has_torch() then
  torch = require "torch"
  -- Restrict torch to a single worker thread.
  torch.setnumthreads(1)
end
if torch then
--- Register one rspamd map per entry of the `nn_models` config section.
-- Each entry's value is passed to `rspamd_config:add_map`. The map's
-- content is deserialized with torch and stored in `lua_nn_models` under
-- the entry's key, so `exports.try_rspamd_nn(key, input)` can evaluate it.
-- A warning is logged for every map that cannot be registered.
exports.load_rspamd_nn = function()
  -- Build a map callback that deserializes content into model `name`.
  local function gen_process_callback(name)
    return function(str)
      if not str then
        return
      end

      local f = torch.MemoryFile(torch.CharStorage():string(str))
      -- The closure must *return* the deserialized object: without `return`
      -- pcall yields (true, nil) and a nil model would be stored.
      local ret, tnn_or_err = pcall(function() return f:readObject() end)
      -- Release the in-memory file; the object has already been read out.
      f:close()

      if not ret then
        rspamd_logger.errx(rspamd_config, "cannot load neural net model %s: %s",
          name, tnn_or_err)
      else
        rspamd_logger.infox(rspamd_config, "loaded NN model %s: %s bytes",
          name, #str)
        lua_nn_models[name] = tnn_or_err
      end
    end
  end

  local section = rspamd_config:get_all_opt("nn_models")
  if section and type(section) == 'table' then
    for k, v in pairs(section) do
      if not rspamd_config:add_map(v, "nn map " .. k, gen_process_callback(k)) then
        rspamd_logger.warnx(rspamd_config, 'cannot load NN map %1', k)
      end
    end
  end
end
--- Evaluate a model previously loaded by `load_rspamd_nn`.
-- @param name key of the model in the `nn_models` section
-- @param input value passed to the model's `forward` method
--   (presumably a torch tensor — TODO confirm against callers)
-- @return `true, tonumber(output)` on success; `false, 0.0` when no model is
--   loaded under `name` or when `forward` raises (the error is logged)
exports.try_rspamd_nn = function(name, input)
  -- Index by the *value* of `name`; `lua_nn_models.name` would look up the
  -- literal key "name" and never find a loaded model.
  local model = lua_nn_models[name]
  if not model then
    return false, 0.0
  end

  -- Return the forward result from the closure so pcall propagates it.
  local ret, res_or_err = pcall(function() return model:forward(input) end)
  if not ret then
    rspamd_logger.errx(rspamd_config, "cannot run neural net model %s: %s",
      name, res_or_err)
    return false, 0.0
  end

  -- NOTE(review): tonumber() only converts numbers and numeric strings; if
  -- forward() returns a tensor this yields nil — confirm the output type.
  return true, tonumber(res_or_err)
end
else
-- Torch is unavailable, so there is nothing to load.
function exports.load_rspamd_nn()
end

-- Without torch no model can ever be evaluated: always report failure.
function exports.try_rspamd_nn(name, input)
  return false, 0.0
end
end
return exports
|