aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/asn.lua
blob: c31cc46c5f0cf37786bf288575f83d5715b9f38d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
--[[
Copyright (c) 2011-2016, Vsevolod Stakhov <vsevolod@highsecure.ru>
Copyright (c) 2016, Andrew Lewis <nerf@judo.za.org>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

local rspamd_logger = require "rspamd_logger"
local rspamd_regexp = require "rspamd_regexp"
local rspamd_redis = require "rspamd_redis"

-- Module defaults. Any key present in the `asn` configuration section
-- overwrites the matching entry here (see configure_asn_module).
local options = {
  -- Name of the informational symbol; set to a false value to skip inserting it
  symbol = 'ASN',
  -- Prefix prepended to the sender IP string to form redis cache keys
  key_prefix = 'rasn',
  -- TTL, in seconds, applied to cached redis records
  expire = 86400, -- 1 day by default
  -- Lookup backend; only 'rspamd' (DNS TXT lookups) is implemented
  provider_type = 'rspamd',
  -- DNS zones queried per address family, indexed as 'ip' .. version
  provider_info = {
    ip4 = 'asn.rspamd.com',
    ip6 = 'asn6.rspamd.com',
  },
}

-- Redis connection parameters; asn_check only consults the redis cache when
-- this is non-nil. Assigned during module configuration.
local redis_params = nil

-- Splits provider TXT answers on '|' characters and whitespace
local rspamd_re = rspamd_regexp.create_cached("[\\|\\s]")

--- Prefilter callback: find ASN, network and country for the sender IP.
-- Results are stored in the task mempool variables "asn", "ipnet" and
-- "country", and, when options.symbol is set, reported as a zero-weight
-- symbol whose description lists the found fields.
-- When redis_params is set, a cached record is tried first; otherwise (or on
-- a cache miss) the configured provider is queried and complete answers are
-- written back to redis with a TTL of options.expire seconds.
-- @param task the task being scanned
local function asn_check(task)

  -- Publish whichever of the three fields are present.
  local function asn_set(asn, ipnet, country)
    local descr_t = {}
    if asn then
      task:get_mempool():set_variable("asn", asn)
      table.insert(descr_t, "asn:" .. asn)
    end
    if ipnet then
      task:get_mempool():set_variable("ipnet", ipnet)
      table.insert(descr_t, "ipnet:" .. ipnet)
    end
    if country then
      task:get_mempool():set_variable("country", country)
      table.insert(descr_t, "country:" .. country)
    end
    if options['symbol'] then
      task:insert_result(options['symbol'], 0.0, table.concat(descr_t, ', '))
    end
  end

  -- Provider implementations, keyed by options.provider_type
  local asn_check_func = {}

  -- 'rspamd' provider: TXT lookup of the reversed IP under the zone
  -- configured for the address family.
  function asn_check_func.rspamd(ip)
    local function rspamd_dns_cb(_, _, results)
      if not (results and results[1]) then return end
      -- "15169 | 8.8.8.0/24 | US | arin |" for 8.8.8.8
      local parts = rspamd_re:split(results[1])
      asn_set(parts[1], parts[2], parts[3])

      -- Only cache complete answers: a missing field would leave a nil hole
      -- in the HMSET argument list and produce a malformed redis command.
      if redis_params and parts[1] and parts[2] and parts[3] then
        local redis_key = options.key_prefix .. ip:to_string()
        local conn, upstream

        local function redis_asn_set_cb(err)
          if err then
            rspamd_logger.infox(task, 'got error %s when setting asn record on server %s',
              err, upstream and upstream:get_addr() or 'unknown')
          elseif upstream then
            upstream:ok()
          end
        end

        _, conn, upstream = rspamd_redis_make_request(task,
          redis_params, -- connect params
          redis_key, -- hash key
          true, -- is write
          redis_asn_set_cb, --callback
          'HMSET', -- command
          {redis_key, "asn", parts[1], "net", parts[2], "country", parts[3]} -- arguments
        )

        if conn then
          -- Pipeline the TTL on the same connection as the HMSET
          conn:add_cmd('EXPIRE', {
            redis_key, tostring(options['expire'])
          })
        elseif upstream then
          rspamd_logger.infox(task, 'got error while connecting to redis: %1', upstream:get_addr())
          upstream:fail()
        else
          -- No upstream was returned, so there is nothing to mark as failed
          rspamd_logger.infox(task, 'got error while connecting to redis: no upstream available')
        end
      end
    end

    local dnsbl = options['provider_info']['ip' .. ip:get_version()]
    local req_name = rspamd_logger.slog("%1.%2",
        table.concat(ip:inversed_str_octets(), '.'), dnsbl)
    task:get_resolver():resolve_txt(task:get_session(), task:get_mempool(),
        req_name, rspamd_dns_cb)
  end

  -- Try the redis cache; call continuation_func(ip) on a miss or an error.
  local function asn_check_cache(ip, continuation_func)
    local key = options.key_prefix .. ip:to_string()

    local function redis_asn_get_cb(err, data)
      if err or not data or type(data[1]) ~= 'string' then
        continuation_func(ip)
        return
      end

      asn_set(data[1], data[2], data[3])
      -- Refresh the TTL of the cached record; the reply is not needed
      rspamd_redis_make_request(task,
        redis_params, -- connect params
        key, -- hash key
        true, -- is write
        function() end, --callback
        'EXPIRE', -- command
        {key, tostring(options.expire)} -- arguments
      )
    end

    local ret = rspamd_redis_make_request(task,
      redis_params, -- connect params
      key, -- hash key
      false, -- is write
      redis_asn_get_cb, --callback
      'HMGET', -- command
      {key, "asn", "net", "country"} -- arguments
    )

    -- The request could not be issued: fall back to the live lookup now
    if not ret then
      continuation_func(ip)
    end
  end

  local ip = task:get_from_ip()
  if not (ip and ip:is_valid()) then return end

  local lookup = asn_check_func[options['provider_type']]
  if redis_params then
    asn_check_cache(ip, lookup)
  else
    lookup(ip)
  end
end

-- Configuration options
--- Merge the `asn` configuration section into options and validate it.
-- Also loads redis parameters for the module into redis_params.
-- @return true when the configuration is usable and the module should be
--   registered, false otherwise (an error is logged)
local configure_asn_module = function()
  local opts = rspamd_config:get_all_opt('asn')
  if opts then
    -- Configured values replace defaults wholesale (tables are not merged)
    for k, v in pairs(opts) do
      options[k] = v
    end
  end

  if options['provider_type'] == 'rspamd' then
    -- asn_check indexes provider_info by 'ip4'/'ip6' depending on the sender
    -- address, so both zones are required. The whole conjunction must be
    -- negated: `not` binds tighter than `and` in Lua.
    local info = options['provider_info']
    if not (type(info) == 'table' and info['ip4'] and info['ip6']) then
      rspamd_logger.errx("Missing required provider_info for rspamd")
      return false
    end
  else
    rspamd_logger.errx("Unknown provider_type: %s", options['provider_type'])
    return false
  end

  redis_params = rspamd_parse_redis_server('asn')
  return true
end

-- Register the ASN lookup as a prefilter, but only when the module
-- configuration was accepted.
local asn_configured = configure_asn_module()
if asn_configured then
  local asn_check_symbol = {
    name = 'ASN_CHECK',
    type = 'prefilter',
    priority = 10,
    callback = asn_check,
  }
  rspamd_config:register_symbol(asn_check_symbol)
end