aboutsummaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/llm_embeddings.lua
blob: d591b4db186d1e1a15b24f1469978b7f4e5ce3ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
--[[
Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]] --

local N = "llm_embeddings"

-- When rspamd only collects configuration help, register the example
-- configuration below and stop loading the rest of the module.
if confighelp then
  rspamd_config:add_example(nil, N,
      "Performs statistical analysis of messages using LLM for embeddings and NN for classification",
      [[
llm_embeddings {
  # API flavour; only the ollama request format is currently implemented
  type = "ollama";
  # Your key to access the API
  api_key = "xxx";
  # Model name
  model = "nomic-embed-text";
  # Check the documentation for the model for this value
  dimensions = 8192;
  # Maximum number of words sent to the embeddings model (longer texts are truncated)
  max_tokens = 1000;
  # URL for the API
  url = "http://localhost:11434/api/embeddings";
  # Redis parameters to save the resulting classifier
  servers = "localhost:6379";
  # Prefix for keys
  prefix = "llm";
  # How many learns are required to start classifying
  min_learns = 100;
  # Check messages with passthrough result
  allow_passthrough = false;
  # Check messages that are apparent ham (no action and negative score)
  allow_ham = false;
}
  ]])
  return
end

local lua_util = require "lua_util"
local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local lua_mime = require "lua_mime"
local ucl = require "ucl"
local rspamd_kann = require "rspamd_kann"
local rspamd_tensor = require "rspamd_tensor"
local lua_redis = require "lua_redis"

-- Default module settings; the `llm_embeddings` configuration section is
-- merged over these when the module is loaded.
local settings = {
  -- Only the ollama request/reply format is implemented by this module
  type = 'ollama',
  -- Sent verbatim as the `Authorization` header when set
  api_key = nil,
  -- Embeddings model name; must be a model served at `url`
  model = 'nomic-embed-text',
  -- Maximum number of words sent to the model; longer texts are truncated
  max_tokens = 5000,
  -- Passed as the timeout of the embeddings HTTP request
  timeout = 10,
  prompt = nil,
  condition = nil,
  autolearn = false,
  url = 'http://localhost:11434/api/embeddings',
  -- Also process messages that already have a passthrough result
  allow_passthrough = false,
  allow_ham = false,
  -- Size of the embedding vector, i.e. of the network input layer
  dimensions = 8192,
  hidden_layer_mult = 0.5, -- Compress in hidden layer
  -- Prefix for all Redis keys of this module (`<prefix>_<model>...`);
  -- without a default, building the key prefix fails on nil concatenation
  prefix = 'llm',
}

-- Mutable module state shared by the functions below:
-- the KANN network object and the parsed Redis server settings
local kann_model, redis_params
local model_learns = 0
-- Queried once at load time; not referenced elsewhere in this file yet
local has_blas = rspamd_tensor.has_blas()

--- Pick the text of `task` that should be sent to the embeddings model.
-- @param task rspamd task
-- @return false and a skip reason when the message must not be processed
--   (passthrough result, no displayed text part, fewer than 5 words);
--   otherwise true and the text: the whole part on one line, or the first
--   `settings.max_tokens` normalized words joined by spaces when longer.
local function extract_data(task)
  local metric_result = task:get_metric_result()
  if metric_result and metric_result.passthrough and not settings.allow_passthrough then
    return false, 'passthrough'
  end

  local text_part = lua_mime.get_displayed_text_part(task)
  if not text_part then
    return false, 'no text part found'
  end

  local limit = settings.max_tokens
  local words_count = text_part:get_words_count()
  if words_count < 5 then
    return false, 'less than 5 words'
  end

  if words_count <= limit then
    return true, text_part:get_content_oneline()
  end

  -- get_words_count() can disagree with get_words(), so recount on the real list
  local norm_words = text_part:get_words('norm')
  if #norm_words <= limit then
    return true, text_part:get_content_oneline()
  end

  return true, table.concat(norm_words, ' ', 1, limit)
end

--- Request an embedding vector for the text of `task` from the embeddings API.
-- The request body is `{ model = settings.model, prompt = <text> }` encoded as
-- JSON and POSTed to `settings.url`. A reply carrying an `embedding` array is
-- passed on as `continuation_cb(task, embedding)`. Messages rejected by
-- `extract_data` are skipped (reason logged at debug level). Transport errors,
-- non-200 replies, empty and unparsable bodies are logged, and the
-- continuation is not called.
-- @param task rspamd task
-- @param continuation_cb function(task, embedding) called on success
local function gen_embeddings_ollama(task, continuation_cb)
  local condition, content = extract_data(task)
  if not condition then
    -- In this case `content` is the reason why the message was skipped
    lua_util.debugm(N, task, 'skip embeddings: %s', content)
    return
  end

  local function embeddings_cb(err, code, data)
    if err then
      rspamd_logger.errx(task, 'cannot get embeddings: %s', err)
      return
    end

    -- A non-200 reply may still carry a body; never treat it as embeddings
    if code ~= 200 then
      rspamd_logger.errx(task, 'bad reply from embeddings model: code %s, body: %s',
          code, data)
      return
    end

    if not data then
      rspamd_logger.errx(task, 'empty reply from embeddings model')
      return
    end

    lua_util.debugm(N, task, 'got reply from embeddings model: %s', data)
    local parser = ucl.parser()
    local res, parse_err = parser:parse_string(data)
    if not res then
      rspamd_logger.errx(task, 'cannot parse reply: %s', parse_err)
      return
    end
    local reply = parser:get_object()

    if type(reply) == 'table' and type(reply.embedding) == 'table' then
      lua_util.debugm(N, task, 'got embeddings: %s', #reply.embedding)
      continuation_cb(task, reply.embedding)
    else
      rspamd_logger.errx(task, 'cannot parse embeddings: %s', data)
    end
  end

  local post_data = {
    model = settings.model,
    prompt = content,
  }

  rspamd_http.request({
    url = settings.url,
    task = task,
    callback = embeddings_cb,
    body = ucl.to_json(post_data),
    timeout = settings.timeout,
    headers = {
      ['Authorization'] = settings.api_key,
      ['Content-Type'] = 'application/json',
    },
  })
end

--- Build a fresh (untrained) KANN network and store it in `kann_model`.
-- Topology: input(dimensions) -> relu -> dense(hidden) -> cost(1, ceb_neg).
-- The hidden size is `dimensions * hidden_layer_mult` rounded down to an
-- integer (and at least 1). Fractional multipliers would otherwise hand a
-- non-integral size to the dense layer, which integer-checking bindings
-- reject on Lua 5.3+.
local function kann_model_create()
  local nhidden = math.max(1, math.floor(settings.dimensions * settings.hidden_layer_mult))
  local t = rspamd_kann.layer.input(settings.dimensions)
  t = rspamd_kann.transform.relu(t)
  t = rspamd_kann.layer.dense(t, nhidden)
  t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
  kann_model = rspamd_kann.new.kann(t)
end

--- Build the common Redis key prefix of this module: `<prefix>_<model>`.
-- Falls back to 'llm' (the prefix used in the module's config example) when
-- `prefix` is not configured, instead of failing on nil concatenation.
-- @return string key prefix
local function redis_prefix()
  local prefix = settings.prefix or 'llm'
  return prefix .. '_' .. settings.model
end

--- Serialize the current KANN model and store it in Redis under `<prefix>_model`.
-- Does nothing when Redis is not configured or no model has been created yet.
-- Failures to issue the request, and errors reported by Redis, are logged.
-- @param ev_base event base used for the taskless Redis request
local function kann_model_save(ev_base)
  if not redis_params or not kann_model then
    return
  end

  local function save_cb(err, _)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot save model: %s', err)
    end
  end

  local packed_model = kann_model:save()
  local key = redis_prefix() .. '_model'
  local ret = lua_redis.redis_make_request_taskless(ev_base, rspamd_config,
      redis_params, key, true,
      save_cb, 'SET', { key, packed_model })
  if not ret then
    rspamd_logger.errx(rspamd_config, 'cannot make redis request to save model')
  end
end

--- Load a previously saved KANN model from Redis (`<prefix>_model`), if any.
-- `kann_model` is replaced only when Redis returns a string that
-- `rspamd_kann.load` accepts. Errors, non-string replies and failed loads
-- leave the current model untouched.
-- @param ev_base event base used for the taskless Redis request
local function kann_model_maybe_load(ev_base)
  if not redis_params then
    return
  end

  local function load_cb(err, data)
    if err then
      rspamd_logger.errx(rspamd_config, 'cannot load model: %s', err)
      return
    end

    -- NOTE(review): a missing key is presumably reported as a non-string
    -- (null) value by the redis binding; only real payloads are deserialized
    if type(data) ~= 'string' then
      return
    end

    local loaded = rspamd_kann.load(data)
    if loaded then
      kann_model = loaded
    else
      rspamd_logger.errx(rspamd_config, 'cannot deserialize model loaded from redis')
    end
  end

  local key = redis_prefix() .. '_model'
  lua_redis.redis_make_request_taskless(ev_base, rspamd_config,
      redis_params, key, false,
      load_cb, 'GET', { key })
end

--- Compute embeddings for `task` and append them to the per-class Redis list.
-- The vector is msgpack-encoded and LPUSHed to `<prefix>_spam` or
-- `<prefix>_ham` depending on `is_spam`. Redis errors are logged.
-- @param task rspamd task
-- @param is_spam true to store a spam sample, false for ham
local function save_embeddings_vector(task, is_spam)
  local function save_cb(err, _)
    if err then
      rspamd_logger.errx(task, 'cannot save embeddings: %s', err)
    end
  end

  -- gen_embeddings_ollama invokes its continuation as
  -- `continuation_cb(task, embedding)`, so the vector is the SECOND argument
  local function save_vector(_, emb)
    local key = redis_prefix() .. (is_spam and '_spam' or '_ham')
    local packed_vector = ucl.to_format(emb, 'msgpack')
    lua_redis.redis_make_request(task,
        redis_params, key, true,
        save_cb, 'LPUSH', { key, packed_vector })
  end

  gen_embeddings_ollama(task, save_vector)
end

--- Shared learning path: lazily create the network on first use, then store
-- the embeddings of `task` as a spam (`is_spam == true`) or ham sample.
local function nn_learn(task, is_spam)
  local need_model = not kann_model
  if need_model then
    kann_model_create()
  end

  save_embeddings_vector(task, is_spam)
end

--- Callback of the LLM_LEARN_SPAM symbol: store `task` as a spam sample.
local function nn_learn_spam(task)
  local spam_class = true
  lua_util.debugm(N, task, 'learn spam')
  nn_learn(task, spam_class)
end

--- Callback of the LLM_LEARN_HAM symbol: store `task` as a ham sample.
local function nn_learn_ham(task)
  local spam_class = false
  lua_util.debugm(N, task, 'learn ham')
  nn_learn(task, spam_class)
end

--- Callback of the LLM_CLASSIFY_CHECK symbol.
-- Not implemented yet: it inserts nothing and returns no values.
local function nn_classify(task)
  -- TODO: Implement classification (LLM_EMBEDDINGS_SPAM/HAM are never set yet)
  local _ = task -- unused until classification is implemented
  return
end

-- Merge the user's `llm_embeddings` configuration over the defaults above
local module_config = rspamd_config:get_all_opt(N)
settings = lua_util.override_defaults(settings, module_config)
-- Redis is mandatory: learned embeddings vectors and the model are kept there
redis_params = lua_redis.parse_redis_server(N)

if not redis_params then
  -- Without Redis the module cannot work: disable it and stop loading the
  -- file, so none of the symbols below get registered
  rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
  lua_util.disable_module(N, "redis")
  return
end

-- Classification entry point; its SPAM/HAM results are virtual children of it
local id = rspamd_config:register_symbol {
  name = "LLM_CLASSIFY_CHECK",
  type = 'callback',
  callback = nn_classify,
}
for _, child_name in ipairs({ "LLM_EMBEDDINGS_SPAM", "LLM_EMBEDDINGS_HAM" }) do
  rspamd_config:register_symbol {
    name = child_name,
    type = 'virtual',
    parent = id,
  }
end

-- Learning symbols carry the 'explicit_enable' flag, so they run only when
-- explicitly enabled for a check
for _, learner in ipairs({
  { "LLM_LEARN_SPAM", nn_learn_spam },
  { "LLM_LEARN_HAM", nn_learn_ham },
}) do
  rspamd_config:register_symbol {
    name = learner[1],
    type = 'callback',
    callback = learner[2],
    flags = 'explicit_enable',
  }
end