diff options
author | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-12-19 16:52:27 +0000 |
---|---|---|
committer | Vsevolod Stakhov <vsevolod@rspamd.com> | 2024-12-19 16:52:27 +0000 |
commit | 4d0cb6a366ae7c3d2c3f19e2d4c91d9b46ce4870 (patch) | |
tree | 401acd9569ac5ed39fdf1e767d3b2d1c7cd13866 /src | |
parent | d35385f951ee38dfdd0bedc77eb7e2d1e5809e40 (diff) | |
download | rspamd-vstakhov-llm-embeddings.tar.gz rspamd-vstakhov-llm-embeddings.zip |
[Project] Add skeleton for embeddings modulevstakhov-llm-embeddings
Diffstat (limited to 'src')
-rw-r--r-- | src/plugins/lua/llm_embeddings.lua | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/src/plugins/lua/llm_embeddings.lua b/src/plugins/lua/llm_embeddings.lua new file mode 100644 index 000000000..d591b4db1 --- /dev/null +++ b/src/plugins/lua/llm_embeddings.lua @@ -0,0 +1,295 @@ +--[[ +Copyright (c) 2024, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +local N = "llm_embeddings" + +if confighelp then + rspamd_config:add_example(nil, N, + "Performs statistical analysis of messages using LLM for embeddings and NN for classification", + [[ +llm_embeddings { + # Supported types: openai, ollama + type = "ollama"; + # Your key to access the API + api_key = "xxx"; + # Model name + model = "nomic-embed-text"; + # Check the documentation for the model for this value + dimensions = 8192; + # Maximum tokens to generate + max_tokens = 1000; + # URL for the API + url = "http://localhost:11434/api/embeddings"; + # Redis parameters to save the resulting classifier + servers = "localhost:6379"; + # Prefix for keys + prefix = "llm"; + # How many learns are required to start classifying + min_learns = 100; + # Check messages with passthrough result + allow_passthrough = false; + # Check messages that are apparent ham (no action and negative score) + allow_ham = false; +} + ]]) + return +end + +local lua_util = require "lua_util" +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local lua_mime = require "lua_mime" +local ucl = require "ucl" +local rspamd_kann = require "rspamd_kann" +local rspamd_tensor = require "rspamd_tensor" +local lua_redis = require "lua_redis" + +local settings = { + type = 'ollama', + api_key = nil, + model = 'gpt-4o-mini', + max_tokens = 5000, + timeout = 10, + prompt = nil, + condition = nil, + autolearn = false, + url = 'http://localhost:11434/api/embeddings', + allow_passthrough = false, + allow_ham = false, + dimensions = 8192, + hidden_layer_mult = 0.5, -- Compress in hidden layer +} + +local has_blas = rspamd_tensor.has_blas() +local kann_model +local model_learns = 0 +local redis_params + +local function extract_data(task) + -- Check result + -- 1) Skip passthrough + local result = task:get_metric_result() + if result then + if result.passthrough and not settings.allow_passthrough then + return false, 'passthrough' + end + end + + -- Check if we have text at all + local sel_part = lua_mime.get_displayed_text_part(task) + + if not sel_part then + return false, 'no text part found' + end + + -- Check limits and size sanity + local nwords = sel_part:get_words_count() + + if nwords < 5 then + return false, 'less than 5 words' + end + + if nwords > settings.max_tokens then + -- We need to truncate words (sometimes get_words_count returns a different number comparing to `get_words`) + local words = sel_part:get_words('norm') + nwords = #words + if nwords > settings.max_tokens then + return true, table.concat(words, ' ', 1, settings.max_tokens) + end + end + return true, sel_part:get_content_oneline() +end + +local function gen_embeddings_ollama(task, continuation_cb) + local condition, content = extract_data(task) + if not condition then + return + end + + local function embeddings_cb(err, code, data) + if err then + rspamd_logger.errx(task, 'cannot get embeddings: %s', err) + return + end + + if data then + lua_util.debugm(N, task, 'got reply from embeddings model: %s', data) + local parser = ucl.parser() + local res, err = parser:parse_string(data) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + + if reply and type(reply) == 'table' and type(reply.embedding) == 'table' then + lua_util.debugm(N, task, 'got embeddings: %s', #reply.embedding) + continuation_cb(task, reply.embedding) + else + rspamd_logger.errx(task, 'cannot parse embeddings: %s', data) + end + end + end + + local post_data = { + model = settings.model, + prompt = content, + } + + rspamd_http.request({ + url = settings.url, + task = task, + callback = embeddings_cb, + body = ucl.to_json(post_data), + timeout = settings.timeout, + headers = { + ['Authorization'] = settings.api_key, + ['Content-Type'] = 'application/json', + }, + }) +end + +local function kann_model_create() + local t = rspamd_kann.layer.input(settings.dimensions) + t = rspamd_kann.transform.relu(t) + t = rspamd_kann.layer.dense(t, settings.dimensions * settings.hidden_layer_mult); + t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg) + kann_model = rspamd_kann.new.kann(t) +end + +local function redis_prefix() + return settings.prefix .. '_' .. settings.model +end + +local function kann_model_save(ev_base) + if not redis_params then + return + end + + local function save_cb(err, _) + if err then + rspamd_logger.errx(rspamd_config, 'cannot save model: %s', err) + end + end + + local packed_model = kann_model:save() + local key = redis_prefix() .. '_model' + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, + redis_params, key, true, + save_cb, 'SET', { key, packed_model }) +end + +local function kann_model_maybe_load(ev_base) + if not redis_params then + return + end + + local function load_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot load model: %s', err) + else + if data then + kann_model = rspamd_kann.load(data) + end + end + end + + local key = redis_prefix() .. '_model' + lua_redis.redis_make_request_taskless(ev_base, rspamd_config, + redis_params, key, false, + load_cb, 'GET', { key }) +end + +local function save_embeddings_vector(task, is_spam) + local function save_cb(err, _) + if err then + rspamd_logger.errx(task, 'cannot save embeddings: %s', err) + end + end + + local function save_vector(emb) + local key = redis_prefix() .. (is_spam and '_spam' or '_ham') + local packed_vector = ucl.to_format(emb, 'msgpack') + lua_redis.redis_make_request(task, + redis_params, key, true, + save_cb, 'LPUSH', { key, packed_vector }) + end + + gen_embeddings_ollama(task, save_vector) +end + +local function nn_learn(task, is_spam) + if not kann_model then + kann_model_create() + end + + save_embeddings_vector(task, is_spam) +end + +local function nn_learn_spam(task) + lua_util.debugm(N, task, 'learn spam') + nn_learn(task, true) +end + +local function nn_learn_ham(task) + lua_util.debugm(N, task, 'learn ham') + nn_learn(task, false) +end + +local function nn_classify(task) + -- TODO: Implement +end + +local module_config = rspamd_config:get_all_opt(N) +settings = lua_util.override_defaults(settings, module_config) +redis_params = lua_redis.parse_redis_server(N) + +if not redis_params then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + lua_util.disable_module(N, "redis") + return +end + +local id = rspamd_config:register_symbol { + name = "LLM_CLASSIFY_CHECK", + type = 'callback', + callback = nn_classify, +} +rspamd_config:register_symbol { + name = "LLM_EMBEDDINGS_SPAM", + type = 'virtual', + parent = id, +} +rspamd_config:register_symbol { + name = "LLM_EMBEDDINGS_HAM", + type = 'virtual', + parent = id, +} + +-- Allow this symbol to be enabled merely explicitly when we need to learn +rspamd_config:register_symbol { + name = "LLM_LEARN_SPAM", + type = 'callback', + callback = nn_learn_spam, + flags = 'explicit_enable', +} +-- Allow this symbol to be enabled merely explicitly when we need to learn +rspamd_config:register_symbol { + name = "LLM_LEARN_HAM", + type = 'callback', + callback = nn_learn_ham, + flags = 'explicit_enable', +}
\ No newline at end of file |