diff options
Diffstat (limited to 'src')
113 files changed, 6325 insertions, 2055 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 92edb0b6a..6cc49e4e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,246 +1,286 @@ -MACRO(_AddModulesForced MLIST) -# Generate unique string for this build - SET(MODULES_C "${CMAKE_CURRENT_BINARY_DIR}/modules.c") - FILE(WRITE "${MODULES_C}" - "#include \"rspamd.h\"\n") - - # Handle even old cmake - LIST(LENGTH ${MLIST} MLIST_COUNT) - MATH(EXPR MLIST_MAX ${MLIST_COUNT}-1) - - FOREACH(MOD_IDX RANGE ${MLIST_MAX}) - LIST(GET ${MLIST} ${MOD_IDX} MOD) - FILE(APPEND "${MODULES_C}" "extern module_t ${MOD}_module;\n") - ENDFOREACH(MOD_IDX RANGE ${MLIST_MAX}) - - FILE(APPEND "${MODULES_C}" "\n\nmodule_t *modules[] = {\n") - - FOREACH(MOD_IDX RANGE ${MLIST_MAX}) - LIST(GET ${MLIST} ${MOD_IDX} MOD) - FILE(APPEND "${MODULES_C}" "&${MOD}_module,\n") - ENDFOREACH(MOD_IDX RANGE ${MLIST_MAX}) - - FILE(APPEND "${MODULES_C}" "NULL\n};\n") -ENDMACRO(_AddModulesForced MLIST) - -MACRO(_AddWorkersForced WLIST) - SET(WORKERS_C "${CMAKE_CURRENT_BINARY_DIR}/workers.c") - FILE(WRITE "${WORKERS_C}" - "#include \"rspamd.h\"\n") - - # Handle even old cmake - LIST(LENGTH ${WLIST} WLIST_COUNT) - MATH(EXPR WLIST_MAX ${WLIST_COUNT}-1) - FOREACH(MOD_IDX RANGE ${WLIST_MAX}) - LIST(GET ${WLIST} ${MOD_IDX} WRK) - FILE(APPEND "${WORKERS_C}" "extern worker_t ${WRK}_worker;\n") - ENDFOREACH(MOD_IDX RANGE ${WLIST_MAX}) - - FILE(APPEND "${WORKERS_C}" "\n\nworker_t *workers[] = {\n") - - FOREACH(MOD_IDX RANGE ${WLIST_MAX}) - LIST(GET ${WLIST} ${MOD_IDX} WRK) - FILE(APPEND "${WORKERS_C}" "&${WRK}_worker,\n") - ENDFOREACH(MOD_IDX RANGE ${WLIST_MAX}) - FILE(APPEND "${WORKERS_C}" "NULL\n};\n") -ENDMACRO(_AddWorkersForced WLIST) - -MACRO(AddModules MLIST WLIST) - _AddModulesForced(${MLIST}) - _AddWorkersForced(${WLIST}) - #IF(NOT EXISTS "modules.c") - # _AddModulesForced(${MLIST} ${WLIST}) - #ELSE(NOT EXISTS "modules.c") - # FILE(STRINGS "modules.c" FILE_ID_RAW REGEX "^/.*[a-zA-Z0-9]+.*/$") - # STRING(REGEX MATCH "[a-zA-Z0-9]+" FILE_ID 
"${FILE_ID_RAW}") - # IF(NOT FILE_ID STREQUAL MODULES_ID) - # MESSAGE("Regenerate modules info") - # _AddModulesForced(${MLIST} ${WLIST}) - # ENDIF(NOT FILE_ID STREQUAL MODULES_ID) - #ENDIF(NOT EXISTS "modules.c") -ENDMACRO(AddModules MLIST WLIST) - -# Rspamd core components -IF (ENABLE_CLANG_PLUGIN MATCHES "ON") - SET(CMAKE_C_FLAGS - "${CMAKE_C_FLAGS} -Xclang -load -Xclang ${CMAKE_CURRENT_BINARY_DIR}/../clang-plugin/librspamd-clang${CMAKE_SHARED_LIBRARY_SUFFIX} -Xclang -add-plugin -Xclang rspamd-ast") - IF(CLANG_EXTRA_PLUGINS_LIBS) - FOREACH(_lib ${CLANG_EXTRA_PLUGINS_LIBS}) - SET(CMAKE_C_FLAGS - "${CMAKE_C_FLAGS} -Xclang -load -Xclang ${_lib}") - SET(CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} -Xclang -load -Xclang ${_lib}") - ENDFOREACH() - ENDIF() - IF(CLANG_EXTRA_PLUGINS) - FOREACH(_plug ${CLANG_EXTRA_PLUGINS}) - SET(CMAKE_C_FLAGS - "${CMAKE_C_FLAGS} -Xclang -add-plugin -Xclang ${_plug}") - SET(CMAKE_CXX_FLAGS - "${CMAKE_C_FLAGS} -Xclang -add-plugin -Xclang ${_plug}") - ENDFOREACH() - ENDIF() -ENDIF () - -ADD_SUBDIRECTORY(lua) -ADD_SUBDIRECTORY(libcryptobox) -ADD_SUBDIRECTORY(libutil) -ADD_SUBDIRECTORY(libserver) -ADD_SUBDIRECTORY(libmime) -ADD_SUBDIRECTORY(libstat) -ADD_SUBDIRECTORY(client) -ADD_SUBDIRECTORY(rspamadm) - -SET(RSPAMDSRC controller.c - fuzzy_storage.c - rspamd.c - worker.c - rspamd_proxy.c) - -SET(PLUGINSSRC plugins/regexp.c - plugins/chartable.cxx - plugins/fuzzy_check.c - plugins/dkim_check.c - libserver/rspamd_control.c) - -SET(MODULES_LIST regexp chartable fuzzy_check dkim) -SET(WORKERS_LIST normal controller fuzzy rspamd_proxy) -IF (ENABLE_HYPERSCAN MATCHES "ON") - LIST(APPEND WORKERS_LIST "hs_helper") - LIST(APPEND RSPAMDSRC "hs_helper.c") -ENDIF() - -AddModules(MODULES_LIST WORKERS_LIST) -LIST(LENGTH PLUGINSSRC RSPAMD_MODULES_NUM) - -SET(RAGEL_DEPENDS "${CMAKE_SOURCE_DIR}/src/ragel/smtp_address.rl" - "${CMAKE_SOURCE_DIR}/src/ragel/smtp_date.rl" - "${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip.rl" - "${CMAKE_SOURCE_DIR}/src/ragel/smtp_base.rl" - 
"${CMAKE_SOURCE_DIR}/src/ragel/content_disposition.rl") -RAGEL_TARGET(ragel_smtp_addr - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_addr_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -T1 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/smtp_addr_parser.rl.c) -RAGEL_TARGET(ragel_content_disposition - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/content_disposition_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/content_disposition.rl.c) -RAGEL_TARGET(ragel_rfc2047 - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/rfc2047_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/rfc2047.rl.c) -RAGEL_TARGET(ragel_smtp_date - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_date_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/date_parser.rl.c) -RAGEL_TARGET(ragel_smtp_ip - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ip_parser.rl.c) -# Fucking cmake... 
-FOREACH(_GEN ${LIBSERVER_GENERATED}) - set_source_files_properties(${_GEN} PROPERTIES GENERATED TRUE) -ENDFOREACH() -######################### LINK SECTION ############################### - -IF(ENABLE_STATIC MATCHES "ON") - ADD_LIBRARY(rspamd-server STATIC - ${RSPAMD_CRYPTOBOX} - ${RSPAMD_UTIL} - ${RSPAMD_LUA} - ${RSPAMD_SERVER} - ${RSPAMD_STAT} - ${RSPAMD_MIME} - ${CMAKE_CURRENT_BINARY_DIR}/modules.c - ${PLUGINSSRC} - "${RAGEL_ragel_smtp_addr_OUTPUTS}" - "${RAGEL_ragel_newlines_strip_OUTPUTS}" - "${RAGEL_ragel_content_type_OUTPUTS}" - "${RAGEL_ragel_content_disposition_OUTPUTS}" - "${RAGEL_ragel_rfc2047_OUTPUTS}" - "${RAGEL_ragel_smtp_date_OUTPUTS}" - "${RAGEL_ragel_smtp_ip_OUTPUTS}" - ${BACKWARD_ENABLE}) -ELSE() - ADD_LIBRARY(rspamd-server SHARED - ${RSPAMD_CRYPTOBOX} - ${RSPAMD_UTIL} - ${RSPAMD_SERVER} - ${RSPAMD_STAT} - ${RSPAMD_MIME} - ${RSPAMD_LUA} - ${CMAKE_CURRENT_BINARY_DIR}/modules.c - ${PLUGINSSRC} - "${RAGEL_ragel_smtp_addr_OUTPUTS}" - "${RAGEL_ragel_newlines_strip_OUTPUTS}" - "${RAGEL_ragel_content_type_OUTPUTS}" - "${RAGEL_ragel_content_disposition_OUTPUTS}" - "${RAGEL_ragel_rfc2047_OUTPUTS}" - "${RAGEL_ragel_smtp_date_OUTPUTS}" - "${RAGEL_ragel_smtp_ip_OUTPUTS}" - ${BACKWARD_ENABLE}) -ENDIF() - -FOREACH(_DEP ${LIBSERVER_DEPENDS}) - ADD_DEPENDENCIES(rspamd-server "${_DEP}") -ENDFOREACH() - -TARGET_LINK_LIBRARIES(rspamd-server rspamd-http-parser) -TARGET_LINK_LIBRARIES(rspamd-server rspamd-fpconv) -TARGET_LINK_LIBRARIES(rspamd-server rspamd-cdb) -TARGET_LINK_LIBRARIES(rspamd-server rspamd-lpeg) -TARGET_LINK_LIBRARIES(rspamd-server lcbtrie) -IF(SYSTEM_ZSTD MATCHES "OFF") - TARGET_LINK_LIBRARIES(rspamd-server rspamd-zstd) -ELSE() - TARGET_LINK_LIBRARIES(rspamd-server zstd) -ENDIF() -TARGET_LINK_LIBRARIES(rspamd-server rspamd-simdutf) - -IF (ENABLE_CLANG_PLUGIN MATCHES "ON") - ADD_DEPENDENCIES(rspamd-server rspamd-clang) -ENDIF() - -IF (NOT WITH_LUAJIT) - TARGET_LINK_LIBRARIES(rspamd-server rspamd-bit) -ENDIF() - -IF (ENABLE_SNOWBALL MATCHES "ON") - 
TARGET_LINK_LIBRARIES(rspamd-server stemmer) -ENDIF() -TARGET_LINK_LIBRARIES(rspamd-server rspamd-hiredis) - -IF (ENABLE_FANN MATCHES "ON") - TARGET_LINK_LIBRARIES(rspamd-server fann) -ENDIF () - -IF (ENABLE_HYPERSCAN MATCHES "ON") - TARGET_LINK_LIBRARIES(rspamd-server hs) -ENDIF() - -IF(WITH_BLAS) - TARGET_LINK_LIBRARIES(rspamd-server ${BLAS_REQUIRED_LIBRARIES}) -ENDIF() - -TARGET_LINK_LIBRARIES(rspamd-server ${RSPAMD_REQUIRED_LIBRARIES}) -ADD_BACKWARD(rspamd-server) - -ADD_EXECUTABLE(rspamd ${RSPAMDSRC} ${CMAKE_CURRENT_BINARY_DIR}/workers.c ${CMAKE_CURRENT_BINARY_DIR}/config.h) -ADD_BACKWARD(rspamd) -SET_TARGET_PROPERTIES(rspamd PROPERTIES LINKER_LANGUAGE CXX) -SET_TARGET_PROPERTIES(rspamd-server PROPERTIES LINKER_LANGUAGE CXX) -IF(NOT NO_TARGET_VERSIONS) - SET_TARGET_PROPERTIES(rspamd PROPERTIES VERSION ${RSPAMD_VERSION}) -ENDIF() - -#TARGET_LINK_LIBRARIES(rspamd ${RSPAMD_REQUIRED_LIBRARIES}) -TARGET_LINK_LIBRARIES(rspamd rspamd-server) - -INSTALL(TARGETS rspamd RUNTIME DESTINATION bin) -INSTALL(TARGETS rspamd-server LIBRARY DESTINATION ${RSPAMD_LIBDIR})
\ No newline at end of file +# Function to generate module registrations +function(generate_modules_list MODULE_LIST) + # Generate unique string for this build + set(MODULES_C "${CMAKE_CURRENT_BINARY_DIR}/modules.c") + file(WRITE "${MODULES_C}" + "#include \"rspamd.h\"\n") + + # Process each module + foreach (MOD IN LISTS ${MODULE_LIST}) + file(APPEND "${MODULES_C}" "extern module_t ${MOD}_module;\n") + endforeach () + + file(APPEND "${MODULES_C}" "\n\nmodule_t *modules[] = {\n") + + foreach (MOD IN LISTS ${MODULE_LIST}) + file(APPEND "${MODULES_C}" "&${MOD}_module,\n") + endforeach () + + file(APPEND "${MODULES_C}" "NULL\n};\n") + + # Return the generated file path + set(MODULES_C_PATH "${MODULES_C}" PARENT_SCOPE) +endfunction() + +# Function to generate worker registrations +function(generate_workers_list WORKER_LIST) + set(WORKERS_C "${CMAKE_CURRENT_BINARY_DIR}/workers.c") + file(WRITE "${WORKERS_C}" + "#include \"rspamd.h\"\n") + + # Process each worker + foreach (WRK IN LISTS ${WORKER_LIST}) + file(APPEND "${WORKERS_C}" "extern worker_t ${WRK}_worker;\n") + endforeach () + + file(APPEND "${WORKERS_C}" "\n\nworker_t *workers[] = {\n") + + foreach (WRK IN LISTS ${WORKER_LIST}) + file(APPEND "${WORKERS_C}" "&${WRK}_worker,\n") + endforeach () + + file(APPEND "${WORKERS_C}" "NULL\n};\n") + + # Return the generated file path + set(WORKERS_C_PATH "${WORKERS_C}" PARENT_SCOPE) +endfunction() + +# Function to generate both modules and workers +function(generate_registration_code MODULE_LIST WORKER_LIST) + generate_modules_list(${MODULE_LIST}) + generate_workers_list(${WORKER_LIST}) + + # Set parent scope variables + set(MODULES_C_PATH ${MODULES_C_PATH} PARENT_SCOPE) + set(WORKERS_C_PATH ${WORKERS_C_PATH} PARENT_SCOPE) +endfunction() + +# Configure Clang Plugin if enabled +if (ENABLE_CLANG_PLUGIN) + set(CLANG_PLUGIN_FLAGS "-Xclang -load -Xclang ${CMAKE_CURRENT_BINARY_DIR}/../clang-plugin/librspamd-clang${CMAKE_SHARED_LIBRARY_SUFFIX} -Xclang -add-plugin -Xclang 
rspamd-ast") + + # Apply to both C and C++ compiler flags + add_compile_options(${CLANG_PLUGIN_FLAGS}) + + # Add any extra clang plugins + if (CLANG_EXTRA_PLUGINS_LIBS) + foreach (lib ${CLANG_EXTRA_PLUGINS_LIBS}) + add_compile_options("-Xclang" "-load" "-Xclang" "${lib}") + endforeach () + endif () + + if (CLANG_EXTRA_PLUGINS) + foreach (plug ${CLANG_EXTRA_PLUGINS}) + add_compile_options("-Xclang" "-add-plugin" "-Xclang" "${plug}") + endforeach () + endif () +endif () + +# Add subdirectories for components +add_subdirectory(lua) +add_subdirectory(libcryptobox) +add_subdirectory(libutil) +add_subdirectory(libserver) +add_subdirectory(libmime) +add_subdirectory(libstat) +add_subdirectory(client) +add_subdirectory(rspamadm) + +# Define source files +set(RSPAMD_SOURCES + controller.c + fuzzy_storage.c + rspamd.c + worker.c + rspamd_proxy.c) + +set(PLUGIN_SOURCES + plugins/regexp.c + plugins/chartable.cxx + plugins/fuzzy_check.c + plugins/dkim_check.c + libserver/rspamd_control.c) + +# Define module and worker lists +set(MODULES_LIST regexp chartable fuzzy_check dkim) +set(WORKERS_LIST normal controller fuzzy rspamd_proxy) + +# Add hyperscan worker if enabled +if (ENABLE_HYPERSCAN) + list(APPEND WORKERS_LIST hs_helper) + list(APPEND RSPAMD_SOURCES hs_helper.c) +endif () + +# Generate modules and workers registration code +generate_registration_code(MODULES_LIST WORKERS_LIST) + +# Count the number of modules +list(LENGTH PLUGIN_SOURCES RSPAMD_MODULES_NUM) + +# Configure Ragel for parsers +set(RAGEL_DEPENDS + "${CMAKE_SOURCE_DIR}/src/ragel/smtp_address.rl" + "${CMAKE_SOURCE_DIR}/src/ragel/smtp_date.rl" + "${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip.rl" + "${CMAKE_SOURCE_DIR}/src/ragel/smtp_base.rl" + "${CMAKE_SOURCE_DIR}/src/ragel/content_disposition.rl") + +# Generate parsers with Ragel +ragel_target(ragel_smtp_addr + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_addr_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -T1 + OUTPUT 
${CMAKE_CURRENT_BINARY_DIR}/smtp_addr_parser.rl.c) + +ragel_target(ragel_content_disposition + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/content_disposition_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/content_disposition.rl.c) + +ragel_target(ragel_rfc2047 + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/rfc2047_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/rfc2047.rl.c) + +ragel_target(ragel_smtp_date + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_date_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/date_parser.rl.c) + +ragel_target(ragel_smtp_ip + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ip_parser.rl.c) + +# Mark generated files correctly +foreach (_gen ${LIBSERVER_GENERATED}) + set_source_files_properties(${_gen} PROPERTIES GENERATED TRUE) +endforeach () + +# Collection of all generated Ragel outputs +set(RAGEL_OUTPUTS + ${RAGEL_ragel_smtp_addr_OUTPUTS} + ${RAGEL_ragel_newlines_strip_OUTPUTS} + ${RAGEL_ragel_content_type_OUTPUTS} + ${RAGEL_ragel_content_disposition_OUTPUTS} + ${RAGEL_ragel_rfc2047_OUTPUTS} + ${RAGEL_ragel_smtp_date_OUTPUTS} + ${RAGEL_ragel_smtp_ip_OUTPUTS}) + +# Common sources for rspamd-server +set(SERVER_COMMON_SOURCES + ${RSPAMD_CRYPTOBOX} + ${RSPAMD_UTIL} + ${RSPAMD_LUA} + ${RSPAMD_SERVER} + ${RSPAMD_STAT} + ${RSPAMD_MIME} + ${MODULES_C_PATH} + ${PLUGIN_SOURCES} + ${RAGEL_OUTPUTS} + ${BACKWARD_ENABLE}) + +# Build rspamd-server as static or shared library based on configuration +if (ENABLE_STATIC) + add_library(rspamd-server STATIC ${SERVER_COMMON_SOURCES}) +else () + add_library(rspamd-server SHARED ${SERVER_COMMON_SOURCES}) +endif () + +# Set dependencies for rspamd-server +foreach (_dep ${LIBSERVER_DEPENDS}) + add_dependencies(rspamd-server "${_dep}") +endforeach () + +# Link dependencies 
+target_link_libraries(rspamd-server + PRIVATE + rspamd-http-parser + rspamd-fpconv + rspamd-cdb + rspamd-lpeg + ottery + lcbtrie + rspamd-simdutf + rdns + ucl) + +# Handle xxhash dependency +if (SYSTEM_XXHASH) + target_link_libraries(rspamd-server PUBLIC xxhash) +else () + target_link_libraries(rspamd-server PUBLIC rspamd-xxhash) +endif () + +# Handle zstd dependency +if (SYSTEM_ZSTD) + target_link_libraries(rspamd-server PUBLIC zstd) +else () + target_link_libraries(rspamd-server PRIVATE rspamd-zstd) +endif () + +# Handle clang plugin dependency +if (ENABLE_CLANG_PLUGIN) + add_dependencies(rspamd-server rspamd-clang) +endif () + +# Handle Lua JIT/Lua dependency +if (NOT WITH_LUAJIT) + target_link_libraries(rspamd-server PRIVATE rspamd-bit) +endif () + +# Link additional optional dependencies +if (ENABLE_SNOWBALL) + target_link_libraries(rspamd-server PRIVATE stemmer) +endif () + +target_link_libraries(rspamd-server PRIVATE rspamd-hiredis) + +if (ENABLE_FANN) + target_link_libraries(rspamd-server PRIVATE fann) +endif () + +if (ENABLE_HYPERSCAN) + target_link_libraries(rspamd-server PUBLIC hs) +endif () + +if (WITH_BLAS) + target_link_libraries(rspamd-server PRIVATE ${BLAS_REQUIRED_LIBRARIES}) +endif () + +# Link all required system libraries +target_link_libraries(rspamd-server PUBLIC ${RSPAMD_REQUIRED_LIBRARIES}) + +# Add Backward support for stacktrace +add_backward(rspamd-server) + +# Build main rspamd executable +add_executable(rspamd + ${RSPAMD_SOURCES} + ${WORKERS_C_PATH} + ${CMAKE_CURRENT_BINARY_DIR}/config.h) + +# Configure rspamd executable +add_backward(rspamd) +set_target_properties(rspamd PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(rspamd-server PROPERTIES LINKER_LANGUAGE CXX) + +if (NOT NO_TARGET_VERSIONS) + set_target_properties(rspamd PROPERTIES VERSION ${RSPAMD_VERSION}) +endif () + +# Link rspamd executable with the server library +target_link_libraries(rspamd PRIVATE rspamd-server) + +# Install targets +install(TARGETS rspamd + RUNTIME 
+ DESTINATION bin) + +install(TARGETS rspamd-server + LIBRARY + DESTINATION ${RSPAMD_LIBDIR}) diff --git a/src/client/rspamc.cxx b/src/client/rspamc.cxx index 31a4aaf24..404359877 100644 --- a/src/client/rspamc.cxx +++ b/src/client/rspamc.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,8 +35,8 @@ #include "frozen/string.h" #include "frozen/unordered_map.h" -#include "fmt/format.h" -#include "fmt/color.h" +#include "contrib/fmt/include/fmt/format.h" +#include "contrib/fmt/include/fmt/color.h" #include "libutil/cxx/file_util.hxx" #include "libutil/cxx/util.hxx" diff --git a/src/client/rspamdclient.c b/src/client/rspamdclient.c index d07b24332..4d79590c5 100644 --- a/src/client/rspamdclient.c +++ b/src/client/rspamdclient.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -231,7 +231,7 @@ rspamd_client_finish_handler(struct rspamd_http_connection *conn, } } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk_full(parser, start, len, ucl_parser_get_default_priority(parser), UCL_DUPLICATE_APPEND, UCL_PARSE_AUTO)) { diff --git a/src/controller.c b/src/controller.c index 386448f93..0550ba6b8 100644 --- a/src/controller.c +++ b/src/controller.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,6 +68,7 @@ #define PATH_NEIGHBOURS "/neighbours" #define PATH_PLUGINS "/plugins" #define PATH_PING "/ping" +#define PATH_BAYES_CLASSIFIERS "/bayes/classifiers" #define msg_err_session(...) 
rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ session->pool->tag.tagname, session->pool->tag.uid, \ @@ -979,12 +980,6 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, if (bk->protocol == MAP_PROTO_FILE) { editable = rspamd_controller_can_edit_map(bk); - - if (!editable && access(bk->uri, R_OK) == -1) { - /* Skip unreadable and non-existing maps */ - continue; - } - obj = ucl_object_typed_new(UCL_OBJECT); ucl_object_insert_key(obj, ucl_object_fromint(bk->id), "map", 0, false); @@ -994,8 +989,34 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, } ucl_object_insert_key(obj, ucl_object_fromstring(bk->uri), "uri", 0, false); + ucl_object_insert_key(obj, ucl_object_fromstring("file"), + "type", 0, false); ucl_object_insert_key(obj, ucl_object_frombool(editable), "editable", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->loaded), + "loaded", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->cached), + "cached", 0, false); + ucl_array_append(top, obj); + } + else { + obj = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(obj, ucl_object_fromint(bk->id), + "map", 0, false); + if (map->description) { + ucl_object_insert_key(obj, ucl_object_fromstring(map->description), + "description", 0, false); + } + ucl_object_insert_key(obj, ucl_object_fromstring(bk->uri), + "uri", 0, false); + ucl_object_insert_key(obj, ucl_object_fromstring(rspamd_map_fetch_protocol_name(bk->protocol)), + "type", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(false), + "editable", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->loaded), + "loaded", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->cached), + "cached", 0, false); ucl_array_append(top, obj); } } @@ -1008,6 +1029,21 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, return 0; } +gboolean 
+rspamd_controller_map_traverse_callback(gconstpointer key, gconstpointer value, gsize _hits, gpointer ud) +{ + rspamd_fstring_t **target = (rspamd_fstring_t **) ud; + + *target = rspamd_fstring_append(*target, key, strlen(key)); + + if (value) { + *target = rspamd_fstring_append(*target, " ", 1); + *target = rspamd_fstring_append(*target, value, strlen(value)); + } + *target = rspamd_fstring_append(*target, "\n", 1); + + return TRUE; +} /* * Get map command handler: * request: /getmap @@ -1020,7 +1056,7 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, { struct rspamd_controller_session *session = conn_ent->ud; GList *cur; - struct rspamd_map *map; + struct rspamd_map *map = NULL; struct rspamd_map_backend *bk = NULL; const rspamd_ftok_t *idstr; struct stat st; @@ -1054,7 +1090,7 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, PTR_ARRAY_FOREACH(map->backends, i, bk) { - if (bk->id == id && bk->protocol == MAP_PROTO_FILE) { + if (bk->id == id) { found = TRUE; break; } @@ -1069,32 +1105,53 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, return 0; } - if (stat(bk->uri, &st) == -1 || (fd = open(bk->uri, O_RDONLY)) == -1) { + if (bk->protocol == MAP_PROTO_FILE) { + if (stat(bk->uri, &st) == -1 || (fd = open(bk->uri, O_RDONLY)) == -1) { + reply = rspamd_http_new_message(HTTP_RESPONSE); + reply->date = time(NULL); + reply->code = 200; + } + else { + + reply = rspamd_http_new_message(HTTP_RESPONSE); + reply->date = time(NULL); + reply->code = 200; + + if (st.st_size > 0) { + if (!rspamd_http_message_set_body_from_fd(reply, fd)) { + close(fd); + rspamd_http_message_unref(reply); + msg_err_session("cannot read map %s: %s", bk->uri, strerror(errno)); + rspamd_controller_send_error(conn_ent, 500, "Map read error"); + return 0; + } + } + else { + rspamd_fstring_t *empty_body = rspamd_fstring_new_init("", 0); + rspamd_http_message_set_body_from_fstring_steal(reply, 
empty_body); + } + + close(fd); + } + } + else if (bk->protocol == MAP_PROTO_STATIC) { + /* We can just traverse map and form reply */ reply = rspamd_http_new_message(HTTP_RESPONSE); - reply->date = time(NULL); reply->code = 200; + rspamd_fstring_t *map_body = rspamd_fstring_new(); + rspamd_map_traverse(bk->map, rspamd_controller_map_traverse_callback, &map_body, FALSE); + rspamd_http_message_set_body_from_fstring_steal(reply, map_body); } - else { - + else if (map->shared->loaded) { reply = rspamd_http_new_message(HTTP_RESPONSE); - reply->date = time(NULL); reply->code = 200; - - if (st.st_size > 0) { - if (!rspamd_http_message_set_body_from_fd(reply, fd)) { - close(fd); - rspamd_http_message_unref(reply); - msg_err_session("cannot read map %s: %s", bk->uri, strerror(errno)); - rspamd_controller_send_error(conn_ent, 500, "Map read error"); - return 0; - } - } - else { - rspamd_fstring_t *empty_body = rspamd_fstring_new_init("", 0); - rspamd_http_message_set_body_from_fstring_steal(reply, empty_body); - } - - close(fd); + rspamd_fstring_t *map_body = rspamd_fstring_new(); + rspamd_map_traverse(bk->map, rspamd_controller_map_traverse_callback, &map_body, FALSE); + rspamd_http_message_set_body_from_fstring_steal(reply, map_body); + } + else { + reply = rspamd_http_new_message(HTTP_RESPONSE); + reply->code = 404; } rspamd_http_connection_reset(conn_ent->conn); @@ -2255,7 +2312,7 @@ rspamd_controller_handle_saveactions( return 0; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, msg->body_buf.begin, msg->body_buf.len)) { if ((error = ucl_parser_get_error(parser)) != NULL) { msg_err_session("cannot parse input: %s", error); @@ -2378,7 +2435,7 @@ rspamd_controller_handle_savesymbols( return 0; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, msg->body_buf.begin, msg->body_buf.len)) { if ((error = ucl_parser_get_error(parser)) != NULL) { 
msg_err_session("cannot parse input: %s", error); @@ -3235,7 +3292,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent, rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods", "POST, GET, OPTIONS"); rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers", - "Content-Type,Password,Map,Weight,Flag"); + "Classifier,Content-Type,Password,Map,Weight,Flag,Hash"); rspamd_http_connection_reset(conn_ent->conn); rspamd_http_router_insert_headers(conn_ent->rt, rep); rspamd_http_connection_write_message(conn_ent->conn, @@ -3390,6 +3447,40 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en return 0; } +/* + * Bayes classifier list command handler: + * request: /bayes/classifiers + * headers: Password + * reply: JSON array of Bayes classifier names + * Note: list is in reverse of declaration order (GList prepend). + */ +static int +rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg) +{ + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_controller_worker_ctx *ctx = session->ctx; + ucl_object_t *arr; + struct rspamd_classifier_config *clc; + GList *cur; + + if (!rspamd_controller_check_password(conn_ent, session, msg, FALSE)) { + return 0; + } + + arr = ucl_object_typed_new(UCL_ARRAY); + cur = g_list_last(ctx->cfg->classifiers); + while (cur) { + clc = cur->data; + ucl_array_append(arr, ucl_object_fromstring(clc->name)); + cur = g_list_previous(cur); + } + + rspamd_controller_send_ucl(conn_ent, arr); + ucl_object_unref(arr); + return 0; +} + static void rspamd_controller_error_handler(struct rspamd_http_connection_entry *conn_ent, @@ -3999,6 +4090,9 @@ start_controller_worker(struct rspamd_worker *worker) rspamd_http_router_add_path(ctx->http, PATH_PING, rspamd_controller_handle_ping); + rspamd_http_router_add_path(ctx->http, + PATH_BAYES_CLASSIFIERS, + rspamd_controller_handle_bayes_classifiers); 
rspamd_controller_register_plugins_paths(ctx); #if 0 diff --git a/src/fuzzy_storage.c b/src/fuzzy_storage.c index bc69be98c..58d123712 100644 --- a/src/fuzzy_storage.c +++ b/src/fuzzy_storage.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,7 +61,7 @@ worker_t fuzzy_worker = { "fuzzy", /* Name */ init_fuzzy, /* Init function */ start_fuzzy, /* Start function */ - RSPAMD_WORKER_HAS_SOCKET | RSPAMD_WORKER_NO_STRICT_CONFIG, + RSPAMD_WORKER_HAS_SOCKET | RSPAMD_WORKER_NO_STRICT_CONFIG | RSPAMD_WORKER_FUZZY, RSPAMD_WORKER_SOCKET_UDP, /* UDP socket */ RSPAMD_WORKER_VER /* Version info */ }; @@ -124,6 +124,11 @@ fuzzy_kp_equal(gconstpointer a, gconstpointer b) return (memcmp(pa, pb, RSPAMD_FUZZY_KEYLEN) == 0); } +enum fuzzy_key_op { + FUZZY_KEY_READ = 0x1u << 0, + FUZZY_KEY_WRITE = 0x1u << 1, + FUZZY_KEY_DELETE = 0x1u << 2, +}; KHASH_SET_INIT_INT(fuzzy_key_ids_set); KHASH_INIT(fuzzy_key_flag_stat, int, struct fuzzy_key_stat, 1, kh_int_hash_func, kh_int_hash_equal); @@ -135,10 +140,12 @@ struct fuzzy_key { khash_t(fuzzy_key_flag_stat) * flags_stat; khash_t(fuzzy_key_ids_set) * forbidden_ids; struct rspamd_leaky_bucket_elt *rl_bucket; + ucl_object_t *extensions; double burst; double rate; ev_tstamp expire; bool expired; + int flags; /* enum fuzzy_key_op */ ref_entry_t ref; }; @@ -146,6 +153,11 @@ KHASH_INIT(rspamd_fuzzy_keys_hash, const unsigned char *, struct fuzzy_key *, 1, fuzzy_kp_hash, fuzzy_kp_equal); +struct rspamd_lua_fuzzy_script { + int cbref; + struct rspamd_lua_fuzzy_script *next; +}; + struct rspamd_fuzzy_storage_ctx { uint64_t magic; /* Events base */ @@ -208,9 +220,9 @@ struct rspamd_fuzzy_storage_ctx { struct rspamd_worker *worker; const ucl_object_t *skip_map; struct rspamd_hash_map_helper *skip_hashes; - int lua_pre_handler_cbref; - int lua_post_handler_cbref; - int 
lua_blacklist_cbref; + struct rspamd_lua_fuzzy_script *lua_pre_handlers; + struct rspamd_lua_fuzzy_script *lua_post_handlers; + struct rspamd_lua_fuzzy_script *lua_blacklist_handlers; khash_t(fuzzy_key_ids_set) * default_forbidden_ids; /* Ids that should not override other ids */ khash_t(fuzzy_key_ids_set) * weak_ids; @@ -330,7 +342,7 @@ ucl_keymap_fin_cb(struct map_cb_data *data, void **target) return; } - parser = ucl_parser_new(UCL_PARSER_NO_FILEVARS); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, jb->buf->str, jb->buf->len)) { msg_err_config("cannot load ucl data: parse error %s", @@ -429,6 +441,11 @@ rspamd_fuzzy_check_ratelimit_bucket(struct fuzzy_session *session, struct rspamd { gboolean ratelimited = FALSE, new_ratelimit = FALSE; + /* Nothing to check */ + if (isnan(max_burst) || isnan(max_rate)) { + return ratelimit_pass; + } + if (isnan(elt->cur)) { /* There is an issue with the previous logic: the TTL is updated each time * we see that new bucket. 
Hence, we need to check the `last` and act accordingly @@ -585,25 +602,29 @@ rspamd_fuzzy_maybe_call_blacklisted(struct rspamd_fuzzy_storage_ctx *ctx, rspamd_inet_addr_t *addr, const char *reason) { - if (ctx->lua_blacklist_cbref != -1) { - lua_State *L = ctx->cfg->lua_state; - int err_idx, ret; + if (ctx->lua_blacklist_handlers != NULL) { + struct rspamd_lua_fuzzy_script *cur; + LL_FOREACH(ctx->lua_blacklist_handlers, cur) + { + lua_State *L = ctx->cfg->lua_state; + int err_idx, ret; - lua_pushcfunction(L, &rspamd_lua_traceback); - err_idx = lua_gettop(L); - lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->lua_blacklist_cbref); - /* client IP */ - rspamd_lua_ip_push(L, addr); - /* block reason */ - lua_pushstring(L, reason); + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cur->cbref); + /* client IP */ + rspamd_lua_ip_push(L, addr); + /* block reason */ + lua_pushstring(L, reason); - if ((ret = lua_pcall(L, 2, 0, err_idx)) != 0) { - msg_err("call to lua_blacklist_cbref " - "script failed (%d): %s", - ret, lua_tostring(L, -1)); - } + if ((ret = lua_pcall(L, 2, 0, err_idx)) != 0) { + msg_err("call to lua_blacklist_cbref " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + } - lua_settop(L, 0); + lua_settop(L, 0); + } } } @@ -624,12 +645,15 @@ rspamd_fuzzy_check_client(struct rspamd_fuzzy_storage_ctx *ctx, } static gboolean -rspamd_fuzzy_check_write(struct fuzzy_session *session) +rspamd_fuzzy_check_write(struct fuzzy_session *session, uint8_t cmd) { if (session->ctx->read_only) { return FALSE; } + /* + * Check IP first + */ if (session->ctx->update_ips != NULL && session->addr) { if (rspamd_inet_address_get_af(session->addr) == AF_UNIX) { return TRUE; @@ -643,6 +667,9 @@ rspamd_fuzzy_check_write(struct fuzzy_session *session) } } + /* + * Check global list of the update keys + */ if (session->ctx->update_keys != NULL && session->key->stat && session->key->key) { static char 
base32_buf[rspamd_cryptobox_HASHBYTES * 2 + 1]; unsigned int raw_len; @@ -657,6 +684,15 @@ rspamd_fuzzy_check_write(struct fuzzy_session *session) } } + if (session->key) { + if (cmd == FUZZY_WRITE && session->key->flags & FUZZY_KEY_WRITE) { + return TRUE; + } + else if (cmd == FUZZY_DEL && session->key->flags & FUZZY_KEY_DELETE) { + return TRUE; + } + } + return FALSE; } @@ -711,6 +747,10 @@ fuzzy_key_dtor(gpointer p) g_free(key->name); } + if (key->extensions) { + ucl_object_unref(key->extensions); + } + g_free(key); } } @@ -1259,80 +1299,91 @@ rspamd_fuzzy_check_callback(struct rspamd_fuzzy_reply *result, void *ud) break; } - if (session->ctx->lua_post_handler_cbref != -1) { - /* Start lua post handler */ - lua_State *L = session->ctx->cfg->lua_state; - int err_idx, ret; + if (session->ctx->lua_post_handlers != NULL) { + struct rspamd_lua_fuzzy_script *cur; + LL_FOREACH(session->ctx->lua_post_handlers, cur) + { + /* Start lua post handler */ + lua_State *L = session->ctx->cfg->lua_state; + int err_idx, ret, nargs = 10; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + /* Preallocate stack (small opt) */ + lua_checkstack(L, err_idx + nargs + 9); + /* function */ + lua_rawgeti(L, LUA_REGISTRYINDEX, cur->cbref); + /* client IP */ + if (session->addr) { + rspamd_lua_ip_push(L, session->addr); + } + else { + lua_pushnil(L); + } + /* client command */ + lua_pushinteger(L, cmd->cmd); + /* command value (push as rspamd_text) */ + (void) lua_new_text(L, result->digest, sizeof(result->digest), FALSE); + /* is shingle */ + lua_pushboolean(L, is_shingle); + /* result value */ + lua_pushinteger(L, result->v1.value); + /* result probability */ + lua_pushnumber(L, result->v1.prob); + /* result flag */ + lua_pushinteger(L, result->v1.flag); + /* result timestamp */ + lua_pushinteger(L, result->ts); + /* TODO: add additional data maybe (encryption, pubkey, etc) */ + rspamd_fuzzy_extensions_tolua(L, session); + /* We push shingles merely for commands 
that modify content to avoid extra work */ + if (is_shingle && cmd->cmd != FUZZY_CHECK) { + lua_newshingle(L, &session->cmd.sgl); + } + else { + lua_pushnil(L); + } - lua_pushcfunction(L, &rspamd_lua_traceback); - err_idx = lua_gettop(L); - /* Preallocate stack (small opt) */ - lua_checkstack(L, err_idx + 9); - /* function */ - lua_rawgeti(L, LUA_REGISTRYINDEX, session->ctx->lua_post_handler_cbref); - /* client IP */ - if (session->addr) { - rspamd_lua_ip_push(L, session->addr); - } - else { - lua_pushnil(L); - } - /* client command */ - lua_pushinteger(L, cmd->cmd); - /* command value (push as rspamd_text) */ - (void) lua_new_text(L, result->digest, sizeof(result->digest), FALSE); - /* is shingle */ - lua_pushboolean(L, is_shingle); - /* result value */ - lua_pushinteger(L, result->v1.value); - /* result probability */ - lua_pushnumber(L, result->v1.prob); - /* result flag */ - lua_pushinteger(L, result->v1.flag); - /* result timestamp */ - lua_pushinteger(L, result->ts); - /* TODO: add additional data maybe (encryption, pubkey, etc) */ - rspamd_fuzzy_extensions_tolua(L, session); - - if ((ret = lua_pcall(L, 9, LUA_MULTRET, err_idx)) != 0) { - msg_err("call to lua_post_handler lua " - "script failed (%d): %s", - ret, lua_tostring(L, -1)); - } - else { - /* Return values order: - * the first reply will be on err_idx + 1 - * if it is true, then we need to read the former ones: - * 2-nd will be reply code - * 3-rd will be probability (or 0.0 if missing) - * 4-th value is flag (or default flag if missing) - */ - ret = lua_toboolean(L, err_idx + 1); + if ((ret = lua_pcall(L, nargs, LUA_MULTRET, err_idx)) != 0) { + msg_err("call to lua_post_handler lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + } + else { + /* Return values order: + * the first reply will be on err_idx + 1 + * if it is true, then we need to read the former ones: + * 2-nd will be reply code + * 3-rd will be probability (or 0.0 if missing) + * 4-th value is flag (or default flag if 
missing) + */ + ret = lua_toboolean(L, err_idx + 1); - if (ret) { - /* Artificial reply */ - result->v1.value = lua_tointeger(L, err_idx + 2); + if (ret) { + /* Artificial reply */ + result->v1.value = lua_tointeger(L, err_idx + 2); - if (lua_isnumber(L, err_idx + 3)) { - result->v1.prob = lua_tonumber(L, err_idx + 3); - } - else { - result->v1.prob = 0.0f; - } + if (lua_isnumber(L, err_idx + 3)) { + result->v1.prob = lua_tonumber(L, err_idx + 3); + } + else { + result->v1.prob = 0.0f; + } - if (lua_isnumber(L, err_idx + 4)) { - result->v1.flag = lua_tointeger(L, err_idx + 4); - } + if (lua_isnumber(L, err_idx + 4)) { + result->v1.flag = lua_tointeger(L, err_idx + 4); + } - lua_settop(L, 0); - rspamd_fuzzy_make_reply(cmd, result, session, send_flags); - REF_RELEASE(session); + lua_settop(L, 0); + rspamd_fuzzy_make_reply(cmd, result, session, send_flags); + REF_RELEASE(session); - return; + return; + } } - } - lua_settop(L, 0); + lua_settop(L, 0); + } } if (!isnan(session->ctx->delay) && @@ -1449,61 +1500,78 @@ rspamd_fuzzy_process_command(struct fuzzy_session *session) result.v1.flag = cmd->flag; result.v1.tag = cmd->tag; - if (session->ctx->lua_pre_handler_cbref != -1) { - /* Start lua pre handler */ - lua_State *L = session->ctx->cfg->lua_state; - int err_idx, ret; - - lua_pushcfunction(L, &rspamd_lua_traceback); - err_idx = lua_gettop(L); - /* Preallocate stack (small opt) */ - lua_checkstack(L, err_idx + 5); - /* function */ - lua_rawgeti(L, LUA_REGISTRYINDEX, session->ctx->lua_pre_handler_cbref); - /* client IP */ - rspamd_lua_ip_push(L, session->addr); - /* client command */ - lua_pushinteger(L, cmd->cmd); - /* command value (push as rspamd_text) */ - (void) lua_new_text(L, cmd->digest, sizeof(cmd->digest), FALSE); - /* is shingle */ - lua_pushboolean(L, is_shingle); - /* TODO: add additional data maybe (encryption, pubkey, etc) */ - rspamd_fuzzy_extensions_tolua(L, session); - - if ((ret = lua_pcall(L, 5, LUA_MULTRET, err_idx)) != 0) { - msg_err("call to 
lua_pre_handler lua " - "script failed (%d): %s", - ret, lua_tostring(L, -1)); - } - else { - /* Return values order: - * the first reply will be on err_idx + 1 - * if it is true, then we need to read the former ones: - * 2-nd will be reply code - * 3-rd will be probability (or 0.0 if missing) - */ - ret = lua_toboolean(L, err_idx + 1); + if (session->ctx->lua_pre_handlers != NULL) { + struct rspamd_lua_fuzzy_script *cur; + + LL_FOREACH(session->ctx->lua_pre_handlers, cur) + { + /* Start lua pre handler */ + lua_State *L = session->ctx->cfg->lua_state; + int err_idx, ret, nargs = 8; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + /* Preallocate stack (small opt) */ + lua_checkstack(L, err_idx + nargs + 1); + /* function */ + lua_rawgeti(L, LUA_REGISTRYINDEX, cur->cbref); + /* client IP */ + rspamd_lua_ip_push(L, session->addr); + /* client command */ + lua_pushinteger(L, cmd->cmd); + /* command value (push as rspamd_text) */ + (void) lua_new_text(L, cmd->digest, sizeof(cmd->digest), FALSE); + /* is shingle */ + lua_pushboolean(L, is_shingle); + /* TODO: add additional data maybe (encryption, pubkey, etc) */ + rspamd_fuzzy_extensions_tolua(L, session); + + /* We push shingles merely for commands that modify content to avoid extra work */ + if (is_shingle && cmd->cmd != FUZZY_CHECK) { + lua_newshingle(L, &session->cmd.sgl); + } + else { + lua_pushnil(L); + } - if (ret) { - /* Artificial reply */ - result.v1.value = lua_tointeger(L, err_idx + 2); + /* Flag and value */ + lua_pushinteger(L, cmd->flag); + lua_pushinteger(L, cmd->value); - if (lua_isnumber(L, err_idx + 3)) { - result.v1.prob = lua_tonumber(L, err_idx + 3); - } - else { - result.v1.prob = 0.0f; - } + if ((ret = lua_pcall(L, nargs, LUA_MULTRET, err_idx)) != 0) { + msg_err("call to lua_pre_handler lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + } + else { + /* Return values order: + * the first reply will be on err_idx + 1 + * if it is true, then we need to 
read the former ones: + * 2-nd will be reply code + * 3-rd will be probability (or 0.0 if missing) + */ + ret = lua_toboolean(L, err_idx + 1); - lua_settop(L, 0); - rspamd_fuzzy_make_reply(cmd, &result, session, send_flags); + if (ret) { + /* Artificial reply */ + result.v1.value = lua_tointeger(L, err_idx + 2); - return; + if (lua_isnumber(L, err_idx + 3)) { + result.v1.prob = lua_tonumber(L, err_idx + 3); + } + else { + result.v1.prob = 0.0f; + } + + lua_settop(L, 0); + rspamd_fuzzy_make_reply(cmd, &result, session, send_flags); + + return; + } } - } - lua_settop(L, 0); + lua_settop(L, 0); + } } @@ -1628,6 +1696,14 @@ rspamd_fuzzy_process_command(struct fuzzy_session *session) } } + /* Key is not allowed to read */ + if (session->key && !(session->key->flags & FUZZY_KEY_READ)) { + result.v1.value = 503; + result.v1.prob = 0.0f; + rspamd_fuzzy_make_reply(cmd, &result, session, send_flags); + return; + } + if (is_rate_allowed) { REF_RETAIN(session); rspamd_fuzzy_backend_check(session->ctx->backend, cmd, @@ -1655,7 +1731,7 @@ rspamd_fuzzy_process_command(struct fuzzy_session *session) rspamd_fuzzy_make_reply(cmd, &result, session, send_flags); } else { - if (rspamd_fuzzy_check_write(session)) { + if (rspamd_fuzzy_check_write(session, cmd->cmd)) { /* Check whitelist */ if (session->ctx->skip_hashes && cmd->cmd == FUZZY_WRITE) { rspamd_encode_hex_buf(cmd->digest, sizeof(cmd->digest), @@ -1676,28 +1752,31 @@ rspamd_fuzzy_process_command(struct fuzzy_session *session) cmd->version |= RSPAMD_FUZZY_FLAG_WEAK; } - if (session->worker->index == 0 || session->ctx->peer_fd == -1) { - /* Just add to the queue */ - up_cmd.is_shingle = is_shingle; - ptr = is_shingle ? (gpointer) &up_cmd.cmd.shingle : (gpointer) &up_cmd.cmd.normal; - memcpy(ptr, cmd, up_len); - g_array_append_val(session->ctx->updates_pending, up_cmd); - } - else { - /* We need to send request to the peer */ - up_req = g_malloc0(sizeof(*up_req)); - up_req->cmd.is_shingle = is_shingle; - ptr = is_shingle ? 
(gpointer) &up_req->cmd.cmd.shingle : (gpointer) &up_req->cmd.cmd.normal; - memcpy(ptr, cmd, up_len); - - if (!fuzzy_peer_try_send(session->ctx->peer_fd, up_req)) { - up_req->io_ev.data = up_req; - ev_io_init(&up_req->io_ev, fuzzy_peer_send_io, - session->ctx->peer_fd, EV_WRITE); - ev_io_start(session->ctx->event_loop, &up_req->io_ev); + /* Noop backends must skip all updates logic as irrelevant */ + if (!rspamd_fuzzy_backend_is_noop(session->ctx->backend)) { + if (session->worker->index == 0 || session->ctx->peer_fd == -1) { + /* Just add to the queue */ + up_cmd.is_shingle = is_shingle; + ptr = is_shingle ? (gpointer) &up_cmd.cmd.shingle : (gpointer) &up_cmd.cmd.normal; + memcpy(ptr, cmd, up_len); + g_array_append_val(session->ctx->updates_pending, up_cmd); } else { - g_free(up_req); + /* We need to send request to the peer */ + up_req = g_malloc0(sizeof(*up_req)); + up_req->cmd.is_shingle = is_shingle; + ptr = is_shingle ? (gpointer) &up_req->cmd.cmd.shingle : (gpointer) &up_req->cmd.cmd.normal; + memcpy(ptr, cmd, up_len); + + if (!fuzzy_peer_try_send(session->ctx->peer_fd, up_req)) { + up_req->io_ev.data = up_req; + ev_io_init(&up_req->io_ev, fuzzy_peer_send_io, + session->ctx->peer_fd, EV_WRITE); + ev_io_start(session->ctx->event_loop, &up_req->io_ev); + } + else { + g_free(up_req); + } } } @@ -2586,7 +2665,7 @@ rspamd_fuzzy_maybe_load_ratelimits(struct rspamd_fuzzy_storage_ctx *ctx) RSPAMD_DBDIR); if (access(path, R_OK) != -1) { - struct ucl_parser *parser = ucl_parser_new(UCL_PARSER_NO_IMPLICIT_ARRAYS | UCL_PARSER_DISABLE_MACRO); + struct ucl_parser *parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (ucl_parser_add_file(parser, path)) { ucl_object_t *obj = ucl_parser_get_object(parser); int loaded = 0; @@ -2712,14 +2791,12 @@ lua_fuzzy_add_pre_handler(lua_State *L) if (wrk && lua_isfunction(L, 2)) { ctx = (struct rspamd_fuzzy_storage_ctx *) wrk->ctx; + struct rspamd_lua_fuzzy_script *script; - if (ctx->lua_pre_handler_cbref != -1) { - /* Should not happen 
*/ - luaL_unref(L, LUA_REGISTRYINDEX, ctx->lua_pre_handler_cbref); - } - + script = g_malloc0(sizeof(*script)); lua_pushvalue(L, 2); - ctx->lua_pre_handler_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + script->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + LL_APPEND(ctx->lua_pre_handlers, script); } else { return luaL_error(L, "invalid arguments, worker + function are expected"); @@ -2740,17 +2817,16 @@ lua_fuzzy_add_post_handler(lua_State *L) } wrk = *pwrk; + ctx = (struct rspamd_fuzzy_storage_ctx *) wrk->ctx; if (wrk && lua_isfunction(L, 2)) { ctx = (struct rspamd_fuzzy_storage_ctx *) wrk->ctx; + struct rspamd_lua_fuzzy_script *script; - if (ctx->lua_post_handler_cbref != -1) { - /* Should not happen */ - luaL_unref(L, LUA_REGISTRYINDEX, ctx->lua_post_handler_cbref); - } - + script = g_malloc0(sizeof(*script)); lua_pushvalue(L, 2); - ctx->lua_post_handler_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + script->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + LL_APPEND(ctx->lua_post_handlers, script); } else { return luaL_error(L, "invalid arguments, worker + function are expected"); @@ -2771,17 +2847,15 @@ lua_fuzzy_add_blacklist_handler(lua_State *L) } wrk = *pwrk; + ctx = (struct rspamd_fuzzy_storage_ctx *) wrk->ctx; if (wrk && lua_isfunction(L, 2)) { - ctx = (struct rspamd_fuzzy_storage_ctx *) wrk->ctx; - - if (ctx->lua_blacklist_cbref != -1) { - /* Should not happen */ - luaL_unref(L, LUA_REGISTRYINDEX, ctx->lua_blacklist_cbref); - } + struct rspamd_lua_fuzzy_script *script; + script = g_malloc0(sizeof(*script)); lua_pushvalue(L, 2); - ctx->lua_blacklist_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + script->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + LL_APPEND(ctx->lua_blacklist_handlers, script); } else { return luaL_error(L, "invalid arguments, worker + function are expected"); @@ -2942,6 +3016,8 @@ fuzzy_add_keypair_from_ucl(struct rspamd_config *cfg, const ucl_object_t *obj, key->rate = NAN; key->expire = NAN; key->rl_bucket = NULL; + /* Allow read by default */ + key->flags = 
FUZZY_KEY_READ; /* Preallocate some space for flags */ kh_resize(fuzzy_key_flag_stat, key->flags_stat, 8); const unsigned char *pk = rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_PK, @@ -2973,6 +3049,7 @@ fuzzy_add_keypair_from_ucl(struct rspamd_config *cfg, const ucl_object_t *obj, const ucl_object_t *extensions = rspamd_keypair_get_extensions(kp); if (extensions) { + key->extensions = ucl_object_ref(extensions); lua_State *L = RSPAMD_LUA_CFG_STATE(cfg); const ucl_object_t *forbidden_ids = ucl_object_lookup(extensions, "forbidden_ids"); @@ -3052,9 +3129,48 @@ fuzzy_add_keypair_from_ucl(struct rspamd_config *cfg, const ucl_object_t *obj, if (name && ucl_object_type(name) == UCL_STRING) { key->name = g_strdup(ucl_object_tostring(name)); } + + /* Check permissions */ + const ucl_object_t *read_only = ucl_object_lookup(extensions, "read_only"); + if (read_only && ucl_object_type(read_only) == UCL_BOOLEAN) { + if (ucl_object_toboolean(read_only)) { + key->flags &= ~(FUZZY_KEY_WRITE | FUZZY_KEY_DELETE); + } + else { + key->flags |= (FUZZY_KEY_WRITE | FUZZY_KEY_DELETE); + } + } + + const ucl_object_t *allowed_ops = ucl_object_lookup(extensions, "allowed_ops"); + if (allowed_ops && ucl_object_type(allowed_ops) == UCL_ARRAY) { + const ucl_object_t *cur; + ucl_object_iter_t it = NULL; + /* Reset to only allowed */ + key->flags = 0; + + while ((cur = ucl_object_iterate(allowed_ops, &it, true)) != NULL) { + if (ucl_object_type(cur) == UCL_STRING) { + const char *op = ucl_object_tostring(cur); + + if (g_ascii_strcasecmp(op, "read") == 0) { + key->flags |= FUZZY_KEY_READ; + } + else if (g_ascii_strcasecmp(op, "write") == 0) { + key->flags |= FUZZY_KEY_WRITE; + } + else if (g_ascii_strcasecmp(op, "delete") == 0) { + key->flags |= FUZZY_KEY_DELETE; + } + else { + msg_warn_config("invalid operation: %s", op); + } + } + } + } } - msg_debug("loaded keypair %*bs; expire=%f; rate=%f; burst=%f; name=%s", (int) crypto_box_publickeybytes(), pk, + msg_debug("loaded keypair %*bs; 
expire=%f; rate=%f; burst=%f; name=%s", + (int) crypto_box_publickeybytes(), pk, key->expire, key->rate, key->burst, key->name); return key; @@ -3122,9 +3238,6 @@ init_fuzzy(struct rspamd_config *cfg) ctx->magic = rspamd_fuzzy_storage_magic; ctx->sync_timeout = DEFAULT_SYNC_TIMEOUT; ctx->keypair_cache_size = DEFAULT_KEYPAIR_CACHE_SIZE; - ctx->lua_pre_handler_cbref = -1; - ctx->lua_post_handler_cbref = -1; - ctx->lua_blacklist_cbref = -1; ctx->keys = kh_init(rspamd_fuzzy_keys_hash); rspamd_mempool_add_destructor(cfg->cfg_pool, (rspamd_mempool_destruct_t) fuzzy_hash_table_dtor, ctx->keys); @@ -3629,13 +3742,11 @@ start_fuzzy(struct rspamd_worker *worker) } /* Ratelimits */ - if (!isnan(ctx->leaky_bucket_rate) && !isnan(ctx->leaky_bucket_burst)) { - ctx->ratelimit_buckets = rspamd_lru_hash_new_full(ctx->max_buckets, - NULL, fuzzy_rl_bucket_free, - rspamd_inet_address_hash, rspamd_inet_address_equal); + ctx->ratelimit_buckets = rspamd_lru_hash_new_full(ctx->max_buckets, + NULL, fuzzy_rl_bucket_free, + rspamd_inet_address_hash, rspamd_inet_address_equal); - rspamd_fuzzy_maybe_load_ratelimits(ctx); - } + rspamd_fuzzy_maybe_load_ratelimits(ctx); /* Maps events */ ctx->resolver = rspamd_dns_resolver_init(worker->srv->logger, @@ -3677,12 +3788,12 @@ start_fuzzy(struct rspamd_worker *worker) .func = lua_fuzzy_add_pre_handler, }; rspamd_lua_add_metamethod(ctx->cfg->lua_state, rspamd_worker_classname, &fuzzy_lua_reg); - fuzzy_lua_reg = (luaL_Reg){ + fuzzy_lua_reg = (luaL_Reg) { .name = "add_fuzzy_post_handler", .func = lua_fuzzy_add_post_handler, }; rspamd_lua_add_metamethod(ctx->cfg->lua_state, rspamd_worker_classname, &fuzzy_lua_reg); - fuzzy_lua_reg = (luaL_Reg){ + fuzzy_lua_reg = (luaL_Reg) { .name = "add_fuzzy_blacklist_handler", .func = lua_fuzzy_add_blacklist_handler, }; @@ -3735,16 +3846,22 @@ start_fuzzy(struct rspamd_worker *worker) rspamd_lru_hash_destroy(ctx->ratelimit_buckets); } - if (ctx->lua_pre_handler_cbref != -1) { - luaL_unref(ctx->cfg->lua_state, 
LUA_REGISTRYINDEX, ctx->lua_pre_handler_cbref); - } + struct rspamd_lua_fuzzy_script *cur, *tmp; - if (ctx->lua_post_handler_cbref != -1) { - luaL_unref(ctx->cfg->lua_state, LUA_REGISTRYINDEX, ctx->lua_post_handler_cbref); + LL_FOREACH_SAFE(ctx->lua_pre_handlers, cur, tmp) + { + luaL_unref(ctx->cfg->lua_state, LUA_REGISTRYINDEX, cur->cbref); + g_free(cur); } - - if (ctx->lua_blacklist_cbref != -1) { - luaL_unref(ctx->cfg->lua_state, LUA_REGISTRYINDEX, ctx->lua_blacklist_cbref); + LL_FOREACH_SAFE(ctx->lua_post_handlers, cur, tmp) + { + luaL_unref(ctx->cfg->lua_state, LUA_REGISTRYINDEX, cur->cbref); + g_free(cur); + } + LL_FOREACH_SAFE(ctx->lua_blacklist_handlers, cur, tmp) + { + luaL_unref(ctx->cfg->lua_state, LUA_REGISTRYINDEX, cur->cbref); + g_free(cur); } if (ctx->default_forbidden_ids) { diff --git a/src/libmime/lang_detection.c b/src/libmime/lang_detection.c index 6e180ea66..b783b8325 100644 --- a/src/libmime/lang_detection.c +++ b/src/libmime/lang_detection.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -363,7 +363,7 @@ rspamd_language_detector_read_file(struct rspamd_config *cfg, double mean = 0, std = 0, delta = 0, delta2 = 0, m2 = 0; enum rspamd_language_category cat = RSPAMD_LANGUAGE_MAX; - parser = ucl_parser_new(UCL_PARSER_NO_FILEVARS); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_file(parser, path)) { msg_warn_config("cannot parse file %s: %s", path, ucl_parser_get_error(parser)); @@ -825,7 +825,7 @@ rspamd_language_detector_init(struct rspamd_config *cfg) languages_pattern = g_string_sized_new(PATH_MAX); rspamd_printf_gstring(languages_pattern, "%s/stop_words", languages_path); - parser = ucl_parser_new(UCL_PARSER_DEFAULT); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (ucl_parser_add_file(parser, languages_pattern->str)) { stop_words = ucl_parser_get_object(parser); @@ -936,7 +936,7 @@ end: } static void -rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, +rspamd_language_detector_random_select(rspamd_words_t *ucs_tokens, unsigned int nwords, goffset *offsets_out, uint64_t *seed) { @@ -946,7 +946,7 @@ rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, g_assert(nwords != 0); g_assert(offsets_out != NULL); - g_assert(ucs_tokens->len >= nwords); + g_assert(kv_size(*ucs_tokens) >= nwords); /* * We split input array into `nwords` parts. For each part we randomly select * an element from this particular split. Here is an example: @@ -963,22 +963,22 @@ rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, * their splits. 
It is not uniform distribution but it seems to be better * to include words from different text parts */ - step_len = ucs_tokens->len / nwords; - remainder = ucs_tokens->len % nwords; + step_len = kv_size(*ucs_tokens) / nwords; + remainder = kv_size(*ucs_tokens) % nwords; out_idx = 0; coin = rspamd_random_uint64_fast_seed(seed); sel = coin % (step_len + remainder); offsets_out[out_idx] = sel; - for (i = step_len + remainder; i < ucs_tokens->len; + for (i = step_len + remainder; i < kv_size(*ucs_tokens); i += step_len, out_idx++) { unsigned int ntries = 0; coin = rspamd_random_uint64_fast_seed(seed); sel = (coin % step_len) + i; for (;;) { - tok = &g_array_index(ucs_tokens, rspamd_stat_token_t, sel); + tok = &kv_A(*ucs_tokens, sel); /* Filter bad tokens */ if (tok->unicode.len >= 2 && @@ -995,8 +995,8 @@ rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, if (ntries < step_len) { sel = (coin % step_len) + i; } - else if (ntries < ucs_tokens->len) { - sel = coin % ucs_tokens->len; + else if (ntries < kv_size(*ucs_tokens)) { + sel = coin % kv_size(*ucs_tokens); } else { offsets_out[out_idx] = sel; @@ -1223,12 +1223,12 @@ static void rspamd_language_detector_detect_type(struct rspamd_task *task, unsigned int nwords, struct rspamd_lang_detector *d, - GArray *words, + rspamd_words_t *words, enum rspamd_language_category cat, khash_t(rspamd_candidates_hash) * candidates, struct rspamd_mime_text_part *part) { - unsigned int nparts = MIN(words->len, nwords); + unsigned int nparts = MIN(kv_size(*words), nwords); goffset *selected_words; rspamd_stat_token_t *tok; unsigned int i; @@ -1241,8 +1241,7 @@ rspamd_language_detector_detect_type(struct rspamd_task *task, msg_debug_lang_det("randomly selected %d words", nparts); for (i = 0; i < nparts; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, - selected_words[i]); + tok = &kv_A(*words, selected_words[i]); if (tok->unicode.len >= 3) { rspamd_language_detector_detect_word(task, d, tok, 
candidates, @@ -1282,7 +1281,7 @@ static enum rspamd_language_detected_type rspamd_language_detector_try_ngramm(struct rspamd_task *task, unsigned int nwords, struct rspamd_lang_detector *d, - GArray *ucs_tokens, + rspamd_words_t *ucs_tokens, enum rspamd_language_category cat, khash_t(rspamd_candidates_hash) * candidates, struct rspamd_mime_text_part *part) @@ -1863,7 +1862,7 @@ rspamd_language_detector_detect(struct rspamd_task *task, if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) { rspamd_fasttext_predict_result_t fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector, task, - part->utf_words, 4); + &part->utf_words, 4); ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result); @@ -1930,11 +1929,11 @@ rspamd_language_detector_detect(struct rspamd_task *task, if (!ret) { /* Apply trigramms detection */ candidates = kh_init(rspamd_candidates_hash); - if (part->utf_words->len < default_short_text_limit) { + if (kv_size(part->utf_words) < default_short_text_limit) { r = rs_detect_none; msg_debug_lang_det("text is too short for trigrams detection: " "%d words; at least %d words required", - (int) part->utf_words->len, + (int) kv_size(part->utf_words), (int) default_short_text_limit); switch (cat) { case RSPAMD_LANGUAGE_CYRILLIC: @@ -1960,7 +1959,7 @@ rspamd_language_detector_detect(struct rspamd_task *task, r = rspamd_language_detector_try_ngramm(task, default_words, d, - part->utf_words, + &part->utf_words, cat, candidates, part); @@ -2123,4 +2122,4 @@ int rspamd_language_detector_elt_flags(const struct rspamd_language_elt *elt) } return 0; -}
\ No newline at end of file +} diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx index 89916151f..983ff78de 100644 --- a/src/libmime/lang_detection_fasttext.cxx +++ b/src/libmime/lang_detection_fasttext.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,8 +20,9 @@ #include "fasttext/fasttext.h" #include "libserver/cfg_file.h" #include "libserver/logger.h" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "stat_api.h" +#include "libserver/word.h" #include <exception> #include <string_view> #include <vector> @@ -180,26 +181,32 @@ bool rspamd_lang_detection_fasttext_is_enabled(void *ud) rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, struct rspamd_task *task, - GArray *utf_words, + rspamd_words_t *utf_words, int k) { #ifndef WITH_FASTTEXT return nullptr; #else /* Avoid too long inputs */ - static const unsigned int max_fasttext_input_len = 1024 * 1024; + static const size_t max_fasttext_input_len = 1024 * 1024; auto *real_model = FASTTEXT_MODEL_TO_C_API(ud); std::vector<std::int32_t> words_vec; - words_vec.reserve(utf_words->len); - for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) { - const auto *w = &g_array_index(utf_words, rspamd_stat_token_t, i); + if (!utf_words || !utf_words->a) { + return nullptr; + } + + auto words_count = kv_size(*utf_words); + words_vec.reserve(words_count); + + for (auto i = 0; i < std::min(words_count, max_fasttext_input_len); i++) { + const auto *w = &kv_A(*utf_words, i); if (w->original.len > 0) { real_model->word2vec(w->original.begin, w->original.len, words_vec); } } - msg_debug_lang_det("fasttext: got %z word tokens from %ud words", words_vec.size(), utf_words->len); + msg_debug_lang_det("fasttext: got %z word tokens from 
%ud words", words_vec.size(), words_count); auto *res = real_model->detect_language(words_vec, k); @@ -266,4 +273,4 @@ void rspamd_fasttext_predict_result_destroy(rspamd_fasttext_predict_result_t res #endif } -G_END_DECLS
\ No newline at end of file +G_END_DECLS diff --git a/src/libmime/lang_detection_fasttext.h b/src/libmime/lang_detection_fasttext.h index 2a2756968..e2b67181a 100644 --- a/src/libmime/lang_detection_fasttext.h +++ b/src/libmime/lang_detection_fasttext.h @@ -17,6 +17,7 @@ #define RSPAMD_LANG_DETECTION_FASTTEXT_H #include "config.h" +#include "libserver/word.h" G_BEGIN_DECLS struct rspamd_config; @@ -53,7 +54,7 @@ typedef void *rspamd_fasttext_predict_result_t; * @return TRUE if language is detected */ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, - struct rspamd_task *task, GArray *utf_words, int k); + struct rspamd_task *task, rspamd_words_t *utf_words, int k); /** * Get number of languages detected diff --git a/src/libmime/message.c b/src/libmime/message.c index f2cabf399..8442c80ac 100644 --- a/src/libmime/message.c +++ b/src/libmime/message.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -40,6 +40,8 @@ #include "contrib/uthash/utlist.h" #include "contrib/t1ha/t1ha.h" #include "received.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "libstat/tokenizers/custom_tokenizer.h" #define GTUBE_SYMBOL "GTUBE" @@ -71,14 +73,14 @@ rspamd_mime_part_extract_words(struct rspamd_task *task, rspamd_stat_token_t *w; unsigned int i, total_len = 0, short_len = 0; - if (part->utf_words) { - rspamd_stem_words(part->utf_words, task->task_pool, part->language, + if (part->utf_words.a) { + rspamd_stem_words(&part->utf_words, task->task_pool, part->language, task->lang_det); - for (i = 0; i < part->utf_words->len; i++) { + for (i = 0; i < kv_size(part->utf_words); i++) { uint64_t h; - w = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + w = &kv_A(part->utf_words, i); if (w->stemmed.len > 0) { /* @@ -108,7 +110,7 @@ rspamd_mime_part_extract_words(struct rspamd_task *task, } } - if (part->utf_words->len) { + if (kv_size(part->utf_words)) { double *avg_len_p, *short_len_p; avg_len_p = rspamd_mempool_get_variable(task->task_pool, @@ -185,21 +187,24 @@ rspamd_mime_part_create_words(struct rspamd_task *task, tok_type = RSPAMD_TOKENIZE_RAW; } - part->utf_words = rspamd_tokenize_text( + /* Initialize kvec for words */ + kv_init(part->utf_words); + + rspamd_tokenize_text( part->utf_stripped_content->data, part->utf_stripped_content->len, &part->utf_stripped_text, tok_type, task->cfg, part->exceptions, NULL, - NULL, + &part->utf_words, task->task_pool); - if (part->utf_words) { + if (part->utf_words.a) { part->normalized_hashes = g_array_sized_new(FALSE, FALSE, - sizeof(uint64_t), part->utf_words->len); - rspamd_normalize_words(part->utf_words, task->task_pool); + sizeof(uint64_t), kv_size(part->utf_words)); + rspamd_normalize_words(&part->utf_words, task->task_pool); } } @@ -209,7 +214,7 @@ rspamd_mime_part_detect_language(struct rspamd_task *task, { struct rspamd_lang_detector_res *lang; - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words && part->utf_words->len > 0 && 
+ if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a && kv_size(part->utf_words) > 0 && task->lang_det) { if (rspamd_language_detector_detect(task, task->lang_det, part)) { lang = g_ptr_array_index(part->languages, 0); @@ -1106,8 +1111,8 @@ rspamd_message_dtor(struct rspamd_message *msg) PTR_ARRAY_FOREACH(msg->text_parts, i, tp) { - if (tp->utf_words) { - g_array_free(tp->utf_words, TRUE); + if (tp->utf_words.a) { + kv_destroy(tp->utf_words); } if (tp->normalized_hashes) { g_array_free(tp->normalized_hashes, TRUE); @@ -1583,7 +1588,7 @@ void rspamd_message_process(struct rspamd_task *task) rspamd_mime_part_extract_words(task, text_part); - if (text_part->utf_words) { + if (text_part->utf_words.a) { total_words += text_part->nwords; } } diff --git a/src/libmime/message.h b/src/libmime/message.h index cb695773e..e6b454362 100644 --- a/src/libmime/message.h +++ b/src/libmime/message.h @@ -16,6 +16,7 @@ #include "libserver/url.h" #include "libutil/ref.h" #include "libutil/str_util.h" +#include "libserver/word.h" #include <unicode/uchar.h> #include <unicode/utext.h> @@ -139,7 +140,7 @@ struct rspamd_mime_text_part { GByteArray *utf_raw_content; /* utf raw content */ GByteArray *utf_stripped_content; /* utf content with no newlines */ GArray *normalized_hashes; /* Array of uint64_t */ - GArray *utf_words; /* Array of rspamd_stat_token_t */ + rspamd_words_t utf_words; /* kvec of rspamd_word_t */ UText utf_stripped_text; /* Used by libicu to represent the utf8 content */ GPtrArray *newlines; /**< positions of newlines in text, relative to content*/ diff --git a/src/libmime/mime_string.hxx b/src/libmime/mime_string.hxx index b181576d3..d6c11d018 100644 --- a/src/libmime/mime_string.hxx +++ b/src/libmime/mime_string.hxx @@ -497,19 +497,19 @@ public: } /* Comparison */ - auto operator==(const basic_mime_string &other) + auto operator==(const basic_mime_string &other) const { return other.storage == storage; } - auto operator==(const storage_type &other) + auto 
operator==(const storage_type &other) const { return other == storage; } - auto operator==(const view_type &other) + auto operator==(const view_type &other) const { return other == storage; } - auto operator==(const CharT *other) + auto operator==(const CharT *other) const { if (other == NULL) { return false; diff --git a/src/libserver/CMakeLists.txt b/src/libserver/CMakeLists.txt index dd17865de..d3415bdb2 100644 --- a/src/libserver/CMakeLists.txt +++ b/src/libserver/CMakeLists.txt @@ -12,6 +12,7 @@ SET(LIBRSPAMDSERVERSRC ${CMAKE_CURRENT_SOURCE_DIR}/fuzzy_backend/fuzzy_backend.c ${CMAKE_CURRENT_SOURCE_DIR}/fuzzy_backend/fuzzy_backend_sqlite.c ${CMAKE_CURRENT_SOURCE_DIR}/fuzzy_backend/fuzzy_backend_redis.c + ${CMAKE_CURRENT_SOURCE_DIR}/fuzzy_backend/fuzzy_backend_noop.c ${CMAKE_CURRENT_SOURCE_DIR}/milter.c ${CMAKE_CURRENT_SOURCE_DIR}/monitored.c ${CMAKE_CURRENT_SOURCE_DIR}/protocol.c diff --git a/src/libserver/backtrace.cxx b/src/libserver/backtrace.cxx index c24e61936..5ebde677e 100644 --- a/src/libserver/backtrace.cxx +++ b/src/libserver/backtrace.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #ifdef BACKWARD_ENABLE #include "contrib/backward-cpp/backward.hpp" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "logger.h" namespace rspamd { diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index f59c6ff89..36941da7a 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -48,6 +48,7 @@ struct worker_s; struct rspamd_external_libs_ctx; struct rspamd_cryptobox_pubkey; struct rspamd_dns_resolver; +struct rspamd_tokenizer_manager; /** * Logging type @@ -395,6 +396,8 @@ struct rspamd_config { unsigned int log_error_elts; /**< number of elements in error logbuf */ unsigned int log_error_elt_maxlen; /**< maximum size of error log element */ unsigned int log_task_max_elts; /**< maximum number of elements in task logging */ + unsigned int log_max_tag_len; /**< maximum length of log tag */ + char *log_tag_strip_policy_str; /**< log tag strip policy string */ struct rspamd_worker_log_pipe *log_pipes; gboolean compat_messages; /**< use old messages in the protocol (array) */ @@ -495,9 +498,10 @@ struct rspamd_config { char *zstd_output_dictionary; /**< path to zstd output dictionary */ ucl_object_t *neighbours; /**< other servers in the cluster */ - struct rspamd_config_settings_elt *setting_ids; /**< preprocessed settings ids */ - struct rspamd_lang_detector *lang_det; /**< language detector */ - struct rspamd_worker *cur_worker; /**< set dynamically by each worker */ + struct rspamd_config_settings_elt *setting_ids; /**< preprocessed settings ids */ + struct rspamd_lang_detector *lang_det; /**< language detector */ + struct rspamd_tokenizer_manager *tokenizer_manager; /**< custom tokenizer manager */ + struct rspamd_worker *cur_worker; /**< set dynamically by each worker */ ref_entry_t ref; /**< reference counter */ }; diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx index 79509e12e..0a48e8a4f 100644 --- a/src/libserver/cfg_rcl.cxx +++ b/src/libserver/cfg_rcl.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,7 +34,7 @@ #include <algorithm>// for std::transform #include <memory> #include "contrib/ankerl/unordered_dense.h" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "libutil/cxx/util.hxx" #include "libutil/cxx/file_util.hxx" #include "frozen/unordered_set.h" @@ -299,6 +299,14 @@ rspamd_rcl_logging_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, cfg->log_flags |= RSPAMD_LOG_FLAG_USEC; } + /* Set default values for new log tag options */ + if (cfg->log_max_tag_len == 0) { + cfg->log_max_tag_len = RSPAMD_LOG_ID_LEN; /* Default to new max size */ + } + if (cfg->log_tag_strip_policy_str == NULL) { + cfg->log_tag_strip_policy_str = rspamd_mempool_strdup(cfg->cfg_pool, "right"); + } + return rspamd_rcl_section_parse_defaults(cfg, *section, cfg->cfg_pool, obj, (void *) cfg, err); } @@ -1700,6 +1708,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) G_STRUCT_OFFSET(struct rspamd_config, log_task_max_elts), RSPAMD_CL_FLAG_UINT, "Maximum number of elements in task log entry (7 by default)"); + rspamd_rcl_add_default_handler(sub, + "max_tag_len", + rspamd_rcl_parse_struct_integer, + G_STRUCT_OFFSET(struct rspamd_config, log_max_tag_len), + RSPAMD_CL_FLAG_UINT, + "Maximum length of log tag cannot exceed 32 (" G_STRINGIFY(RSPAMD_LOG_ID_LEN) ") by default)"); + rspamd_rcl_add_default_handler(sub, + "tag_strip_policy", + rspamd_rcl_parse_struct_string, + G_STRUCT_OFFSET(struct rspamd_config, log_tag_strip_policy_str), + 0, + "Log tag strip policy when tag exceeds max length: 'right', 'left', 'middle' (right by default)"); /* Documentation only options, handled in log_handler to map flags */ rspamd_rcl_add_doc_by_path(cfg, @@ -2210,7 +2230,7 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) rspamd_rcl_add_doc_by_path(cfg, "options", - "Swtich mode of gtube patterns: disable, reject, all", + "Switch mode of gtube patterns: disable, reject, all", "gtube_patterns", UCL_STRING, 
nullptr, @@ -2308,7 +2328,7 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) rspamd_rcl_parse_struct_time, G_STRUCT_OFFSET(struct rspamd_config, upstream_resolve_min_interval), RSPAMD_CL_FLAG_TIME_FLOAT, - "Minumum interval to perform resolving (60 seconds by default)"); + "Minimum interval to perform resolving (60 seconds by default)"); } if (!(skip_sections && g_hash_table_lookup(skip_sections, "actions"))) { @@ -3640,7 +3660,7 @@ rspamd_config_parse_ucl(struct rspamd_config *cfg, /* Try to load keyfile if available */ auto keyfile_name = fmt::format("{}.key", filename); rspamd::util::raii_file::open(keyfile_name, O_RDONLY).map([&](const auto &keyfile) { - auto *kp_parser = ucl_parser_new(0); + auto *kp_parser = ucl_parser_new(UCL_PARSER_DEFAULT); if (ucl_parser_add_fd(kp_parser, keyfile.get_fd())) { auto *kp_obj = ucl_parser_get_object(kp_parser); diff --git a/src/libserver/cfg_utils.cxx b/src/libserver/cfg_utils.cxx index 9612cdae4..c7bb20210 100644 --- a/src/libserver/cfg_utils.cxx +++ b/src/libserver/cfg_utils.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -65,13 +65,18 @@ #include <string> #include <string_view> #include <vector> -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "cxx/util.hxx" #include "frozen/unordered_map.h" #include "frozen/string.h" #include "contrib/expected/expected.hpp" #include "contrib/ankerl/unordered_dense.h" +#include "libserver/task.h" +#include "libserver/url.h" +#define RSPAMD_TOKENIZER_INTERNAL// We need to use internal tokenizer API +#include "libstat/tokenizers/custom_tokenizer.h" + #define DEFAULT_SCORE 10.0 #define DEFAULT_RLIMIT_NOFILE 2048 @@ -821,6 +826,65 @@ rspamd_adjust_clocks_resolution(struct rspamd_config *cfg) #endif } +extern "C" { + +gboolean +rspamd_config_load_custom_tokenizers(struct rspamd_config *cfg, GError **err) +{ + /* Load custom tokenizers */ + const ucl_object_t *custom_tokenizers = ucl_object_lookup_path(cfg->cfg_ucl_obj, + "options.custom_tokenizers"); + if (custom_tokenizers != NULL) { + msg_info_config("loading custom tokenizers"); + + if (!cfg->tokenizer_manager) { + cfg->tokenizer_manager = rspamd_tokenizer_manager_new(cfg->cfg_pool); + } + + ucl_object_iter_t it = ucl_object_iterate_new(custom_tokenizers); + const ucl_object_t *tok_obj; + const char *tok_name; + + while ((tok_obj = ucl_object_iterate_safe(it, true)) != NULL) { + tok_name = ucl_object_key(tok_obj); + GError *local_err = NULL; + + if (!rspamd_tokenizer_manager_load_tokenizer(cfg->tokenizer_manager, + tok_name, tok_obj, &local_err)) { + msg_err_config("failed to load custom tokenizer '%s': %s", + tok_name, local_err ? 
local_err->message : "unknown error"); + + if (err && !*err) { + *err = g_error_copy(local_err); + } + + if (local_err) { + g_error_free(local_err); + } + + ucl_object_iterate_free(it); + return FALSE; + } + } + ucl_object_iterate_free(it); + + msg_info_config("loaded custom tokenizers successfully"); + } + + return TRUE; +} + +void rspamd_config_unload_custom_tokenizers(struct rspamd_config *cfg) +{ + if (cfg->tokenizer_manager) { + msg_info_config("unloading custom tokenizers"); + rspamd_tokenizer_manager_destroy(cfg->tokenizer_manager); + cfg->tokenizer_manager = NULL; + } +} + +}// extern "C" + /* * Perform post load actions */ @@ -940,6 +1004,20 @@ rspamd_config_post_load(struct rspamd_config *cfg, msg_err_config("cannot configure libraries, fatal error"); return FALSE; } + + /* Load custom tokenizers using the new function */ + GError *tokenizer_err = NULL; + if (!rspamd_config_load_custom_tokenizers(cfg, &tokenizer_err)) { + msg_err_config("failed to load custom tokenizers: %s", + tokenizer_err ? 
tokenizer_err->message : "unknown error"); + if (tokenizer_err) { + g_error_free(tokenizer_err); + } + + if (opts & RSPAMD_CONFIG_INIT_VALIDATE) { + ret = tl::make_unexpected(std::string{"failed to load custom tokenizers"}); + } + } } /* Validate cache */ @@ -1363,7 +1441,7 @@ rspamd_ucl_fin_cb(struct map_cb_data *data, void **target) } /* New data available */ - auto *parser = ucl_parser_new(0); + auto *parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, (unsigned char *) cbdata->buf.data(), cbdata->buf.size())) { msg_err_config("cannot parse map %s: %s", diff --git a/src/libserver/css/css_parser.cxx b/src/libserver/css/css_parser.cxx index 11fa830f0..ade499ba4 100644 --- a/src/libserver/css/css_parser.cxx +++ b/src/libserver/css/css_parser.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include "css_rule.hxx" #include "css_util.hxx" #include "css.hxx" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include <vector> #include <unicode/utf8.h> diff --git a/src/libserver/css/css_selector.cxx b/src/libserver/css/css_selector.cxx index d2ae093cb..527b12377 100644 --- a/src/libserver/css/css_selector.cxx +++ b/src/libserver/css/css_selector.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,7 +17,7 @@ #include "css_selector.hxx" #include "css.hxx" #include "libserver/html/html.hxx" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #define DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL #include "doctest/doctest.h" diff --git a/src/libserver/css/css_value.cxx b/src/libserver/css/css_value.cxx index f2ff55363..52a61d3b6 100644 --- a/src/libserver/css/css_value.cxx +++ b/src/libserver/css/css_value.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include "frozen/string.h" #include "libutil/util.h" #include "contrib/ankerl/unordered_dense.h" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #define DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL #include "doctest/doctest.h" diff --git a/src/libserver/dkim.c b/src/libserver/dkim.c index 0f51c66c0..8b61d39a9 100644 --- a/src/libserver/dkim.c +++ b/src/libserver/dkim.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -2141,8 +2141,7 @@ rspamd_dkim_canonize_body(struct rspamd_task *task, if (ctx->body_canon_type == DKIM_CANON_SIMPLE) { /* Simple canonization */ while (rspamd_dkim_simple_body_step(ctx, ctx->body_hash, - &start, end - start, &remain)) - ; + &start, end - start, &remain)); /* * If we have l= tag then we cannot add crlf... 
@@ -2178,8 +2177,7 @@ rspamd_dkim_canonize_body(struct rspamd_task *task, size_t orig_len = remain; while (rspamd_dkim_relaxed_body_step(ctx, ctx->body_hash, - &start, end - start, &remain)) - ; + &start, end - start, &remain)); if (ctx->len > 0 && remain > (double) orig_len * 0.1) { msg_info_task("DKIM l tag does not cover enough of the body: %d (%d actual size)", @@ -2874,50 +2872,71 @@ rspamd_dkim_check(rspamd_dkim_context_t *ctx, case RSPAMD_DKIM_KEY_RSA: { GError *err = NULL; - if (!rspamd_cryptobox_verify_evp_rsa(nid, ctx->b, ctx->blen, raw_digest, dlen, - key->specific.key_ssl.key_evp, &err)) { - - if (err == NULL) { - msg_debug_dkim("headers rsa verify failed"); - ERR_clear_error(); - res->rcode = DKIM_REJECT; - res->fail_reason = "headers rsa verify failed"; - - msg_info_dkim( - "%s: headers RSA verification failure; " - "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", - rspamd_dkim_type_to_string(ctx->common.type), - (int) (body_end - body_start), ctx->common.body_canonicalised, - ctx->common.headers_canonicalised, - ctx->domain, ctx->selector, - RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), - ctx->dkim_header); - } - else { - res->rcode = DKIM_PERM_ERROR; - res->fail_reason = "openssl internal error"; - msg_err_dkim("internal OpenSSL error: %s", err->message); - msg_info_dkim( - "%s: headers RSA verification failure due to OpenSSL internal error; " - "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", - rspamd_dkim_type_to_string(ctx->common.type), - (int) (body_end - body_start), ctx->common.body_canonicalised, - ctx->common.headers_canonicalised, - ctx->domain, ctx->selector, - RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), - ctx->dkim_header); - - ERR_clear_error(); - g_error_free(err); + if (ctx->sig_alg == DKIM_SIGN_ECDSASHA256 || + ctx->sig_alg == DKIM_SIGN_EDDSASHA256 || + ctx->sig_alg == DKIM_SIGN_ECDSASHA512) { + /* RSA key provided for ECDSA/EDDSA signature */ + res->rcode 
= DKIM_PERM_ERROR; + res->fail_reason = "rsa key for ecdsa/eddsa signature"; + msg_info_dkim( + "%s: wrong RSA key for ecdsa/eddsa signature; " + "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", + rspamd_dkim_type_to_string(ctx->common.type), + (int) (body_end - body_start), ctx->common.body_canonicalised, + ctx->common.headers_canonicalised, + ctx->domain, ctx->selector, + RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), + ctx->dkim_header); + } + else { + if (!rspamd_cryptobox_verify_evp_rsa(nid, ctx->b, ctx->blen, raw_digest, dlen, + key->specific.key_ssl.key_evp, &err)) { + + if (err == NULL) { + msg_debug_dkim("headers rsa verify failed"); + ERR_clear_error(); + res->rcode = DKIM_REJECT; + res->fail_reason = "headers rsa verify failed"; + + msg_info_dkim( + "%s: headers RSA verification failure; " + "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", + rspamd_dkim_type_to_string(ctx->common.type), + (int) (body_end - body_start), ctx->common.body_canonicalised, + ctx->common.headers_canonicalised, + ctx->domain, ctx->selector, + RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), + ctx->dkim_header); + } + else { + res->rcode = DKIM_PERM_ERROR; + res->fail_reason = "openssl internal error"; + msg_err_dkim("internal OpenSSL error: %s", err->message); + msg_info_dkim( + "%s: headers RSA verification failure due to OpenSSL internal error; " + "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", + rspamd_dkim_type_to_string(ctx->common.type), + (int) (body_end - body_start), ctx->common.body_canonicalised, + ctx->common.headers_canonicalised, + ctx->domain, ctx->selector, + RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), + ctx->dkim_header); + + ERR_clear_error(); + g_error_free(err); + } } } break; } case RSPAMD_DKIM_KEY_ECDSA: - if (rspamd_cryptobox_verify_evp_ecdsa(nid, ctx->b, ctx->blen, raw_digest, dlen, - key->specific.key_ssl.key_evp) != 1) { + if (ctx->sig_alg 
!= DKIM_SIGN_ECDSASHA256 && + ctx->sig_alg != DKIM_SIGN_ECDSASHA512) { + /* ECDSA key provided for RSA/EDDSA signature */ + res->rcode = DKIM_PERM_ERROR; + res->fail_reason = "ECDSA key for rsa/eddsa signature"; msg_info_dkim( - "%s: headers ECDSA verification failure; " + "%s: ECDSA key for rsa/eddsa signature; " "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", rspamd_dkim_type_to_string(ctx->common.type), (int) (body_end - body_start), ctx->common.body_canonicalised, @@ -2925,18 +2944,34 @@ rspamd_dkim_check(rspamd_dkim_context_t *ctx, ctx->domain, ctx->selector, RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), ctx->dkim_header); - msg_debug_dkim("headers ecdsa verify failed"); - ERR_clear_error(); - res->rcode = DKIM_REJECT; - res->fail_reason = "headers ecdsa verify failed"; + } + else { + if (rspamd_cryptobox_verify_evp_ecdsa(nid, ctx->b, ctx->blen, raw_digest, dlen, + key->specific.key_ssl.key_evp) != 1) { + msg_info_dkim( + "%s: headers ECDSA verification failure; " + "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", + rspamd_dkim_type_to_string(ctx->common.type), + (int) (body_end - body_start), ctx->common.body_canonicalised, + ctx->common.headers_canonicalised, + ctx->domain, ctx->selector, + RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), + ctx->dkim_header); + msg_debug_dkim("headers ecdsa verify failed"); + ERR_clear_error(); + res->rcode = DKIM_REJECT; + res->fail_reason = "headers ecdsa verify failed"; + } } break; case RSPAMD_DKIM_KEY_EDDSA: - if (!rspamd_cryptobox_verify(ctx->b, ctx->blen, raw_digest, dlen, - key->specific.key_eddsa)) { + if (ctx->sig_alg != DKIM_SIGN_EDDSASHA256) { + /* EDDSA key provided for RSA/ECDSA signature */ + res->rcode = DKIM_PERM_ERROR; + res->fail_reason = "EDDSA key for rsa/ecdsa signature"; msg_info_dkim( - "%s: headers EDDSA verification failure; " + "%s: EDDSA key for rsa/ecdsa signature; " "body length %d->%d; headers length %d; d=%s; s=%s; 
key_md5=%*xs; orig header: %s", rspamd_dkim_type_to_string(ctx->common.type), (int) (body_end - body_start), ctx->common.body_canonicalised, @@ -2944,14 +2979,27 @@ rspamd_dkim_check(rspamd_dkim_context_t *ctx, ctx->domain, ctx->selector, RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), ctx->dkim_header); - msg_debug_dkim("headers eddsa verify failed"); - res->rcode = DKIM_REJECT; - res->fail_reason = "headers eddsa verify failed"; + } + else { + if (!rspamd_cryptobox_verify(ctx->b, ctx->blen, raw_digest, dlen, + key->specific.key_eddsa)) { + msg_info_dkim( + "%s: headers EDDSA verification failure; " + "body length %d->%d; headers length %d; d=%s; s=%s; key_md5=%*xs; orig header: %s", + rspamd_dkim_type_to_string(ctx->common.type), + (int) (body_end - body_start), ctx->common.body_canonicalised, + ctx->common.headers_canonicalised, + ctx->domain, ctx->selector, + RSPAMD_DKIM_KEY_ID_LEN, rspamd_dkim_key_id(key), + ctx->dkim_header); + msg_debug_dkim("headers eddsa verify failed"); + res->rcode = DKIM_REJECT; + res->fail_reason = "headers eddsa verify failed"; + } } break; } - if (ctx->common.type == RSPAMD_DKIM_ARC_SEAL && res->rcode == DKIM_CONTINUE) { switch (ctx->cv) { case RSPAMD_ARC_INVALID: diff --git a/src/libserver/dynamic_cfg.c b/src/libserver/dynamic_cfg.c index 984517074..6d648d745 100644 --- a/src/libserver/dynamic_cfg.c +++ b/src/libserver/dynamic_cfg.c @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -195,7 +195,7 @@ json_config_fin_cb(struct map_cb_data *data, void **target) return; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, jb->buf->str, jb->buf->len)) { msg_err("cannot load json data: parse error %s", diff --git a/src/libserver/fuzzy_backend/fuzzy_backend.c b/src/libserver/fuzzy_backend/fuzzy_backend.c index c18463618..3d5cbb863 100644 --- a/src/libserver/fuzzy_backend/fuzzy_backend.c +++ b/src/libserver/fuzzy_backend/fuzzy_backend.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,7 @@ #include "fuzzy_backend.h" #include "fuzzy_backend_sqlite.h" #include "fuzzy_backend_redis.h" +#include "fuzzy_backend_noop.h" #include "cfg_file.h" #include "fuzzy_wire.h" @@ -26,6 +27,7 @@ enum rspamd_fuzzy_backend_type { RSPAMD_FUZZY_BACKEND_SQLITE = 0, RSPAMD_FUZZY_BACKEND_REDIS = 1, + RSPAMD_FUZZY_BACKEND_NOOP = 2, }; static void *rspamd_fuzzy_backend_init_sqlite(struct rspamd_fuzzy_backend *bk, @@ -96,6 +98,16 @@ static const struct rspamd_fuzzy_backend_subr fuzzy_subrs[] = { .id = rspamd_fuzzy_backend_id_redis, .periodic = rspamd_fuzzy_backend_expire_redis, .close = rspamd_fuzzy_backend_close_redis, + }, + [RSPAMD_FUZZY_BACKEND_NOOP] = { + .init = rspamd_fuzzy_backend_init_noop, + .check = rspamd_fuzzy_backend_check_noop, + .update = rspamd_fuzzy_backend_update_noop, + .count = rspamd_fuzzy_backend_count_noop, + .version = rspamd_fuzzy_backend_version_noop, + .id = rspamd_fuzzy_backend_id_noop, + .periodic = 
rspamd_fuzzy_backend_expire_noop, + .close = rspamd_fuzzy_backend_close_noop, }}; struct rspamd_fuzzy_backend { @@ -288,6 +300,9 @@ rspamd_fuzzy_backend_create(struct ev_loop *ev_base, else if (strcmp(ucl_object_tostring(elt), "redis") == 0) { type = RSPAMD_FUZZY_BACKEND_REDIS; } + else if (strcmp(ucl_object_tostring(elt), "noop") == 0) { + type = RSPAMD_FUZZY_BACKEND_NOOP; + } else { g_set_error(err, rspamd_fuzzy_backend_quark(), EINVAL, "invalid backend type: %s", @@ -547,6 +562,11 @@ void rspamd_fuzzy_backend_close(struct rspamd_fuzzy_backend *bk) g_free(bk); } +bool rspamd_fuzzy_backend_is_noop(struct rspamd_fuzzy_backend *bk) +{ + return bk->type == RSPAMD_FUZZY_BACKEND_NOOP; +} + struct ev_loop * rspamd_fuzzy_backend_event_base(struct rspamd_fuzzy_backend *backend) { diff --git a/src/libserver/fuzzy_backend/fuzzy_backend.h b/src/libserver/fuzzy_backend/fuzzy_backend.h index fe22d473e..249c4d1c3 100644 --- a/src/libserver/fuzzy_backend/fuzzy_backend.h +++ b/src/libserver/fuzzy_backend/fuzzy_backend.h @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -124,6 +124,8 @@ double rspamd_fuzzy_backend_get_expire(struct rspamd_fuzzy_backend *backend); */ void rspamd_fuzzy_backend_close(struct rspamd_fuzzy_backend *backend); +bool rspamd_fuzzy_backend_is_noop(struct rspamd_fuzzy_backend *bk); + #ifdef __cplusplus } #endif diff --git a/src/libserver/fuzzy_backend/fuzzy_backend_noop.c b/src/libserver/fuzzy_backend/fuzzy_backend_noop.c new file mode 100644 index 000000000..024d19882 --- /dev/null +++ b/src/libserver/fuzzy_backend/fuzzy_backend_noop.c @@ -0,0 +1,97 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +#include "fuzzy_backend_noop.h" + +/* + * No operations backend (useful for scripts only stuff) + */ + +void *rspamd_fuzzy_backend_init_noop(struct rspamd_fuzzy_backend *bk, + const ucl_object_t *obj, + struct rspamd_config *cfg, + GError **err) +{ + /* Return non-NULL to distinguish from error */ + return (void *) (uintptr_t) (-1); +} + +void rspamd_fuzzy_backend_check_noop(struct rspamd_fuzzy_backend *bk, + const struct rspamd_fuzzy_cmd *cmd, + rspamd_fuzzy_check_cb cb, void *ud, + void *subr_ud) +{ + struct rspamd_fuzzy_reply rep; + + if (cb) { + memset(&rep, 0, sizeof(rep)); + cb(&rep, ud); + } + + return; +} + +void rspamd_fuzzy_backend_update_noop(struct rspamd_fuzzy_backend *bk, + GArray *updates, const char *src, + rspamd_fuzzy_update_cb cb, void *ud, + void *subr_ud) +{ + if (cb) { + cb(TRUE, 0, 0, 0, 0, ud); + } + + return; +} + +void rspamd_fuzzy_backend_count_noop(struct rspamd_fuzzy_backend *bk, + rspamd_fuzzy_count_cb cb, void *ud, + void *subr_ud) +{ + if (cb) { + cb(0, ud); + } + + return; +} + +void rspamd_fuzzy_backend_version_noop(struct rspamd_fuzzy_backend *bk, + const char *src, + rspamd_fuzzy_version_cb cb, void *ud, + void *subr_ud) +{ + if (cb) { + cb(0, ud); + } + + return; +} + +const char *rspamd_fuzzy_backend_id_noop(struct rspamd_fuzzy_backend *bk, + void *subr_ud) +{ + return NULL; +} + +void rspamd_fuzzy_backend_expire_noop(struct rspamd_fuzzy_backend *bk, + void *subr_ud) +{ +} + +void rspamd_fuzzy_backend_close_noop(struct rspamd_fuzzy_backend *bk, + void *subr_ud) +{ +} diff --git a/src/libserver/fuzzy_backend/fuzzy_backend_noop.h b/src/libserver/fuzzy_backend/fuzzy_backend_noop.h new file mode 100644 index 000000000..ac063dc39 --- /dev/null +++ b/src/libserver/fuzzy_backend/fuzzy_backend_noop.h @@ -0,0 +1,66 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FUZZY_BACKEND_NOOP_H +#define FUZZY_BACKEND_NOOP_H + +#include "config.h" +#include "fuzzy_backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Subroutines for fuzzy_backend + */ +void *rspamd_fuzzy_backend_init_noop(struct rspamd_fuzzy_backend *bk, + const ucl_object_t *obj, + struct rspamd_config *cfg, + GError **err); + +void rspamd_fuzzy_backend_check_noop(struct rspamd_fuzzy_backend *bk, + const struct rspamd_fuzzy_cmd *cmd, + rspamd_fuzzy_check_cb cb, void *ud, + void *subr_ud); + +void rspamd_fuzzy_backend_update_noop(struct rspamd_fuzzy_backend *bk, + GArray *updates, const char *src, + rspamd_fuzzy_update_cb cb, void *ud, + void *subr_ud); + +void rspamd_fuzzy_backend_count_noop(struct rspamd_fuzzy_backend *bk, + rspamd_fuzzy_count_cb cb, void *ud, + void *subr_ud); + +void rspamd_fuzzy_backend_version_noop(struct rspamd_fuzzy_backend *bk, + const char *src, + rspamd_fuzzy_version_cb cb, void *ud, + void *subr_ud); + +const char *rspamd_fuzzy_backend_id_noop(struct rspamd_fuzzy_backend *bk, + void *subr_ud); + +void rspamd_fuzzy_backend_expire_noop(struct rspamd_fuzzy_backend *bk, + void *subr_ud); + +void rspamd_fuzzy_backend_close_noop(struct rspamd_fuzzy_backend *bk, + void *subr_ud); + +#ifdef __cplusplus +} +#endif + +#endif//FUZZY_BACKEND_NOOP_H diff --git a/src/libserver/fuzzy_backend/fuzzy_backend_redis.c b/src/libserver/fuzzy_backend/fuzzy_backend_redis.c index 27c663070..f150d48be 100644 --- a/src/libserver/fuzzy_backend/fuzzy_backend_redis.c +++ b/src/libserver/fuzzy_backend/fuzzy_backend_redis.c @@ 
-116,11 +116,9 @@ rspamd_redis_get_servers(struct rspamd_fuzzy_backend_redis *ctx, res = *((struct upstream_list **) lua_touserdata(L, -1)); } else { - struct lua_logger_trace tr; char outbuf[8192]; - memset(&tr, 0, sizeof(tr)); - lua_logger_out_type(L, -2, outbuf, sizeof(outbuf) - 1, &tr, + lua_logger_out(L, -2, outbuf, sizeof(outbuf), LUA_ESCAPE_UNPRINTABLE); msg_err("cannot get %s upstreams for Redis fuzzy storage %s; table content: %s", diff --git a/src/libserver/fuzzy_backend/fuzzy_backend_redis.h b/src/libserver/fuzzy_backend/fuzzy_backend_redis.h index afeb1c573..0a536c2fa 100644 --- a/src/libserver/fuzzy_backend/fuzzy_backend_redis.h +++ b/src/libserver/fuzzy_backend/fuzzy_backend_redis.h @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,7 +15,6 @@ */ #ifndef SRC_LIBSERVER_FUZZY_BACKEND_REDIS_H_ #define SRC_LIBSERVER_FUZZY_BACKEND_REDIS_H_ - #include "config.h" #include "fuzzy_backend.h" diff --git a/src/libserver/html/html.cxx b/src/libserver/html/html.cxx index 0fe31c2a3..93d1fdf91 100644 --- a/src/libserver/html/html.cxx +++ b/src/libserver/html/html.cxx @@ -1,11 +1,11 @@ -/*- - * Copyright 2021 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -35,9 +35,9 @@ #include "html_tag.hxx" #include "html_url.hxx" -#include <frozen/unordered_map.h> -#include <frozen/string.h> -#include <fmt/core.h> +#include "contrib/frozen/include/frozen/unordered_map.h" +#include "contrib/frozen/include/frozen/string.h" +#include "contrib/fmt/include/fmt/core.h" #include <unicode/uversion.h> diff --git a/src/libserver/html/html_tests.cxx b/src/libserver/html/html_tests.cxx index 00595feaa..3be836a2d 100644 --- a/src/libserver/html/html_tests.cxx +++ b/src/libserver/html/html_tests.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include "libserver/task.h" #include <vector> -#include <fmt/core.h> +#include "contrib/fmt/include/fmt/core.h" #define DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL diff --git a/src/libserver/http/http_connection.c b/src/libserver/http/http_connection.c index baf37a385..b5d70fc1c 100644 --- a/src/libserver/http/http_connection.c +++ b/src/libserver/http/http_connection.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -1670,7 +1670,22 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, { char datebuf[64]; int meth_len = 0; - const char *conn_type = "close"; + const char *server_conn_header, *client_conn_header; + + /* Set up connection header strings based on flags and connection type */ + if (msg->flags & RSPAMD_HTTP_FLAG_HAS_CONNECTION_HEADER) { + server_conn_header = ""; + client_conn_header = ""; + } + else { + server_conn_header = "Connection: close\r\n"; + if (conn->opts & RSPAMD_HTTP_CLIENT_KEEP_ALIVE) { + client_conn_header = "Connection: keep-alive\r\n"; + } + else { + client_conn_header = "Connection: close\r\n"; + } + } if (conn->type == RSPAMD_HTTP_SERVER) { /* Format reply */ @@ -1712,12 +1727,14 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_snprintf(repbuf, replen, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n" "Content-Type: %s", /* NO \r\n at the end ! */ - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen, mime_type); } @@ -1725,11 +1742,13 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_snprintf(repbuf, replen, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z", /* NO \r\n at the end ! 
*/ - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen); } @@ -1737,11 +1756,12 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, /* External reply */ rspamd_printf_fstring(buf, "HTTP/1.1 200 OK\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n", + server_conn_header, priv->ctx->config.server_hdr, datebuf, enclen); } @@ -1750,12 +1770,14 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_printf_fstring(buf, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n" "Content-Type: %s\r\n", - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen, mime_type); } @@ -1763,11 +1785,13 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_printf_fstring(buf, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n", - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen); } @@ -1804,10 +1828,6 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, else { /* Client request */ - if (conn->opts & RSPAMD_HTTP_CLIENT_KEEP_ALIVE) { - conn_type = "keep-alive"; - } - /* Format request */ enclen += RSPAMD_FSTRING_LEN(msg->url) + strlen(http_method_str(msg->method)) + 1; @@ -1819,21 +1839,21 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, "%s %s HTTP/1.0\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n" - "Connection: %s\r\n", + "%s", "POST", "/post", enclen, - conn_type); + client_conn_header); } else { 
rspamd_printf_fstring(buf, "%s %V HTTP/1.0\r\n" "Content-Length: %z\r\n" - "Connection: %s\r\n", + "%s", http_method_str(msg->method), msg->url, bodylen, - conn_type); + client_conn_header); if (bodylen > 0) { if (mime_type == NULL) { @@ -1857,26 +1877,26 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, if (rspamd_http_message_is_standard_port(msg)) { rspamd_printf_fstring(buf, "%s %s HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n", "POST", "/post", - conn_type, + client_conn_header, host, enclen); } else { rspamd_printf_fstring(buf, "%s %s HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s:%d\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n", "POST", "/post", - conn_type, + client_conn_header, host, msg->port, enclen); @@ -1888,21 +1908,21 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, if ((msg->flags & RSPAMD_HTTP_FLAG_HAS_HOST_HEADER)) { rspamd_printf_fstring(buf, "%s %s://%s:%d/%V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Content-Length: %z\r\n", http_method_str(msg->method), (conn->opts & RSPAMD_HTTP_CLIENT_SSL) ? 
"https" : "http", host, msg->port, msg->url, - conn_type, + client_conn_header, bodylen); } else { if (rspamd_http_message_is_standard_port(msg)) { rspamd_printf_fstring(buf, "%s %s://%s:%d/%V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), @@ -1910,14 +1930,14 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, host, msg->port, msg->url, - conn_type, + client_conn_header, host, bodylen); } else { rspamd_printf_fstring(buf, "%s %s://%s:%d/%V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s:%d\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), @@ -1925,7 +1945,7 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, host, msg->port, msg->url, - conn_type, + client_conn_header, host, msg->port, bodylen); @@ -1937,35 +1957,35 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, if ((msg->flags & RSPAMD_HTTP_FLAG_HAS_HOST_HEADER)) { rspamd_printf_fstring(buf, "%s %V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Content-Length: %z\r\n", http_method_str(msg->method), msg->url, - conn_type, + client_conn_header, bodylen); } else { if (rspamd_http_message_is_standard_port(msg)) { rspamd_printf_fstring(buf, "%s %V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), msg->url, - conn_type, + client_conn_header, host, bodylen); } else { rspamd_printf_fstring(buf, "%s %V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s:%d\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), msg->url, - conn_type, + client_conn_header, host, msg->port, bodylen); @@ -2633,4 +2653,4 @@ void rspamd_http_connection_disable_encryption(struct rspamd_http_connection *co priv->peer_key = NULL; priv->flags &= ~RSPAMD_HTTP_CONN_FLAG_ENCRYPTED; } -}
\ No newline at end of file +} diff --git a/src/libserver/http/http_connection.h b/src/libserver/http/http_connection.h index f6ec03d95..466a3edd9 100644 --- a/src/libserver/http/http_connection.h +++ b/src/libserver/http/http_connection.h @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -80,9 +80,13 @@ struct rspamd_storage_shmem { */ #define RSPAMD_HTTP_FLAG_HAS_HOST_HEADER (1 << 7) /** + * Connection header has been set for a message + */ +#define RSPAMD_HTTP_FLAG_HAS_CONNECTION_HEADER (1 << 8) +/** * Message is intended for SSL connection */ -#define RSPAMD_HTTP_FLAG_WANT_SSL (1 << 8) +#define RSPAMD_HTTP_FLAG_WANT_SSL (1 << 9) /** * Options for HTTP connection */ diff --git a/src/libserver/http/http_message.c b/src/libserver/http/http_message.c index 0c9708450..e5e4a0469 100644 --- a/src/libserver/http/http_message.c +++ b/src/libserver/http/http_message.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -539,6 +539,9 @@ void rspamd_http_message_add_header_len(struct rspamd_http_message *msg, if (g_ascii_strcasecmp(name, "host") == 0) { msg->flags |= RSPAMD_HTTP_FLAG_HAS_HOST_HEADER; } + else if (g_ascii_strcasecmp(name, "connection") == 0) { + msg->flags |= RSPAMD_HTTP_FLAG_HAS_CONNECTION_HEADER; + } hdr->combined = rspamd_fstring_sized_new(nlen + vlen + 4); rspamd_printf_fstring(&hdr->combined, "%s: %*s\r\n", name, (int) vlen, @@ -746,4 +749,4 @@ const char *rspamd_http_message_get_url(struct rspamd_http_message *msg, gsize * } return NULL; -}
\ No newline at end of file +} diff --git a/src/libserver/hyperscan_tools.cxx b/src/libserver/hyperscan_tools.cxx index 5035bee2c..75863bf39 100644 --- a/src/libserver/hyperscan_tools.cxx +++ b/src/libserver/hyperscan_tools.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include <filesystem> #include "contrib/ankerl/unordered_dense.h" #include "contrib/ankerl/svector.h" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "libutil/cxx/file_util.hxx" #include "libutil/cxx/error.hxx" #include "hs.h" diff --git a/src/libserver/logger/logger.c b/src/libserver/logger/logger.c index 25818e7a5..600b7f1e1 100644 --- a/src/libserver/logger/logger.c +++ b/src/libserver/logger/logger.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,7 +22,6 @@ #include "unix-std.h" #include "logger_private.h" - static rspamd_logger_t *default_logger = NULL; static rspamd_logger_t *emergency_logger = NULL; static struct rspamd_log_modules *log_modules = NULL; @@ -30,6 +29,61 @@ static struct rspamd_log_modules *log_modules = NULL; static const char lf_chr = '\n'; unsigned int rspamd_task_log_id = (unsigned int) -1; + +/** + * Strip log tag according to the configured policy + * @param original_tag original log tag + * @param original_len length of original tag + * @param dest destination buffer + * @param max_len maximum length allowed + * @param policy stripping policy + * @return actual length of stripped tag + */ +static gsize +rspamd_log_strip_tag(const char *original_tag, gsize original_len, + char *dest, gsize max_len, + enum rspamd_log_tag_strip_policy policy) +{ + if (original_len <= max_len) { + /* No stripping needed */ + memcpy(dest, original_tag, original_len); + return original_len; + } + + switch (policy) { + case RSPAMD_LOG_TAG_STRIP_RIGHT: + /* Cut right part (current behavior) */ + memcpy(dest, original_tag, max_len); + return max_len; + + case RSPAMD_LOG_TAG_STRIP_LEFT: + /* Cut left part (take last elements) */ + memcpy(dest, original_tag + (original_len - max_len), max_len); + return max_len; + + case RSPAMD_LOG_TAG_STRIP_MIDDLE: + /* Half from start and half from end */ + if (max_len >= 2) { + gsize first_half = max_len / 2; + gsize second_half = max_len - first_half; + + memcpy(dest, original_tag, first_half); + memcpy(dest + first_half, + original_tag + (original_len - second_half), + second_half); + } + else if (max_len == 1) { + /* Just take first character */ + dest[0] = original_tag[0]; + } + return max_len; + + default: + /* Fallback to right stripping */ + memcpy(dest, original_tag, max_len); + return max_len; + } +} RSPAMD_CONSTRUCTOR(rspamd_task_log_init) { rspamd_task_log_id = rspamd_logger_add_debug_module("task"); @@ -160,6 +214,10 @@ 
rspamd_log_open_emergency(rspamd_mempool_t *pool, int flags) logger->process_type = "main"; logger->pid = getpid(); + /* Initialize log tag configuration with defaults */ + logger->max_log_tag_len = RSPAMD_LOG_ID_LEN; /* Keep backward compatibility default */ + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_RIGHT; + const struct rspamd_logger_funcs *funcs = &console_log_funcs; memcpy(&logger->ops, funcs, sizeof(*funcs)); @@ -258,6 +316,28 @@ rspamd_log_open_specific(rspamd_mempool_t *pool, logger->process_type = ptype; logger->enabled = TRUE; + /* Initialize log tag configuration with defaults */ + if (cfg && cfg->log_max_tag_len > 0) { + logger->max_log_tag_len = MIN(MEMPOOL_UID_LEN, cfg->log_max_tag_len); + } + else { + logger->max_log_tag_len = RSPAMD_LOG_ID_LEN; /* Keep backward compatibility default */ + } + + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_RIGHT; + + if (cfg && cfg->log_tag_strip_policy_str) { + if (g_ascii_strcasecmp(cfg->log_tag_strip_policy_str, "left") == 0) { + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_LEFT; + } + else if (g_ascii_strcasecmp(cfg->log_tag_strip_policy_str, "middle") == 0) { + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_MIDDLE; + } + else { + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_RIGHT; /* Default */ + } + } + /* Set up conditional logging */ if (cfg) { if (cfg->debug_ip_map != NULL) { @@ -1026,6 +1106,36 @@ log_time(double now, rspamd_logger_t *rspamd_log, char *timebuf, } } +/** + * Process log ID with stripping policy and return the effective length + * @param logger logger instance with configuration + * @param id original log ID + * @param processed_id buffer to store processed ID (should be at least max_log_tag_len + 1) + * @return effective length of processed ID + */ +static inline int +rspamd_log_process_id(rspamd_logger_t *logger, const char *id, char *processed_id) +{ + if (id == NULL) { + return 0; + } + + gsize original_len = strlen(id); + gsize max_len = 
MIN(MEMPOOL_UID_LEN, logger->max_log_tag_len); + + if (original_len <= max_len) { + /* No processing needed */ + memcpy(processed_id, id, original_len); + return original_len; + } + + /* Apply stripping policy */ + gsize processed_len = rspamd_log_strip_tag(id, original_len, processed_id, max_len, + logger->log_tag_strip_policy); + + return processed_len; +} + void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, double ts, const char *module, @@ -1059,8 +1169,17 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, if (G_UNLIKELY(log_json)) { /* Perform JSON logging */ - unsigned int slen = id ? strlen(id) : strlen("(NULL)"); - slen = MIN(RSPAMD_LOG_ID_LEN, slen); + char processed_id[MEMPOOL_UID_LEN]; + int processed_len = 0; + + if (id) { + processed_len = rspamd_log_process_id(logger, id, processed_id); + } + else { + strcpy(processed_id, "(NULL)"); + processed_len = strlen(processed_id); + } + r = rspamd_snprintf(tmpbuf, sizeof(tmpbuf), "{\"ts\": %f, " "\"pid\": %P, " "\"severity\": \"%s\", " @@ -1073,7 +1192,7 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, logger->pid, rspamd_get_log_severity_string(level_flags), logger->process_type, - slen, id, + processed_len, processed_id, module, function); iov_ctx->iov[0].iov_base = tmpbuf; @@ -1229,16 +1348,17 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, glong mremain, mr; char *m; + char processed_id[MEMPOOL_UID_LEN]; + int processed_len = 0; modulebuf[0] = '\0'; mremain = sizeof(modulebuf); m = modulebuf; if (id != NULL) { - unsigned int slen = strlen(id); - slen = MIN(RSPAMD_LOG_ID_LEN, slen); - mr = rspamd_snprintf(m, mremain, "<%*.s>; ", slen, - id); + processed_len = rspamd_log_process_id(logger, id, processed_id); + mr = rspamd_snprintf(m, mremain, "<%*.s>; ", processed_len, + processed_id); m += mr; mremain -= mr; } @@ -1289,6 +1409,17 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, if (logger->log_level == G_LOG_LEVEL_DEBUG) { 
iov_ctx->iov[niov].iov_base = (void *) timebuf; iov_ctx->iov[niov++].iov_len = strlen(timebuf); + if (id != NULL) { + char processed_id[MEMPOOL_UID_LEN]; + int processed_len = rspamd_log_process_id(logger, id, processed_id); + + iov_ctx->iov[niov].iov_base = (void *) "; "; + iov_ctx->iov[niov++].iov_len = 2; + iov_ctx->iov[niov].iov_base = (void *) processed_id; + iov_ctx->iov[niov++].iov_len = processed_len; + iov_ctx->iov[niov].iov_base = (void *) ";"; + iov_ctx->iov[niov++].iov_len = 1; + } iov_ctx->iov[niov].iov_base = (void *) " "; iov_ctx->iov[niov++].iov_len = 1; } diff --git a/src/libserver/logger/logger_private.h b/src/libserver/logger/logger_private.h index 80178ad32..387d8639b 100644 --- a/src/libserver/logger/logger_private.h +++ b/src/libserver/logger/logger_private.h @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,6 +23,12 @@ #define REPEATS_MAX 300 #define LOGBUF_LEN 8192 +enum rspamd_log_tag_strip_policy { + RSPAMD_LOG_TAG_STRIP_RIGHT = 0, /* Cut right part (current behavior) */ + RSPAMD_LOG_TAG_STRIP_LEFT, /* Cut left part (take last elements) */ + RSPAMD_LOG_TAG_STRIP_MIDDLE, /* Half from start and half from end */ +}; + struct rspamd_log_module { char *mname; unsigned int id; @@ -73,6 +79,10 @@ struct rspamd_logger_s { gboolean is_debug; gboolean no_lock; + /* Log tag configuration */ + unsigned int max_log_tag_len; + enum rspamd_log_tag_strip_policy log_tag_strip_policy; + pid_t pid; const char *process_type; struct rspamd_radix_map_helper *debug_ip; diff --git a/src/libserver/maps/map.c b/src/libserver/maps/map.c index 43a9d5d86..ac82d39bb 100644 --- a/src/libserver/maps/map.c +++ b/src/libserver/maps/map.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -84,7 +84,8 @@ RSPAMD_CONSTRUCTOR(rspamd_map_log_init) } /** - * Write HTTP request + * Write HTTP request with proper cache validation headers + * Uses ETags (If-None-Match) and Last-Modified (If-Modified-Since) for conditional requests */ static void write_http_request(struct http_callback_data *cbd) @@ -109,7 +110,8 @@ write_http_request(struct http_callback_data *cbd) } if (cbd->data->etag) { rspamd_http_message_add_header_len(msg, "If-None-Match", - cbd->data->etag->str, cbd->data->etag->len); + cbd->data->etag->str, + cbd->data->etag->len); } } @@ -295,6 +297,101 @@ rspamd_map_cache_cb(struct ev_loop *loop, ev_timer *w, int revents) } } +/** + * Calculate next check time with proper priority for different cache validation mechanisms + * Priority: ETags > Last-Modified > Cache expiration headers + * @param now current time + * @param expires time from cache expiration header + * @param map_check_interval base polling interval + * @param has_etag whether we have ETag for conditional requests + * @param has_last_modified whether we have Last-Modified for conditional requests + * @return next check time + */ +static inline time_t +rspamd_http_map_process_next_check(struct rspamd_map *map, + struct rspamd_map_backend *bk, + time_t now, + time_t expires, + time_t map_check_interval, + gboolean has_etag, + gboolean has_last_modified) +{ + static const time_t interval_mult = 4; /* Reduced from 16 to be more responsive */ + static const time_t min_respectful_interval = 5; + time_t next_check; + time_t effective_interval = map_check_interval; + + /* + * Priority order for cache validation: + * 1. ETags (most reliable) + * 2. Last-Modified dates + * 3. Cache expiration headers (least reliable) + */ + + if (has_etag || has_last_modified) { + /* + * If we have ETags or Last-Modified, we can use conditional requests + * to avoid unnecessary downloads. However, we still need to be respectful + * to servers and not DoS them with overly aggressive polling. 
+ */ + if (map_check_interval < min_respectful_interval) { + /* + * User configured very aggressive polling, but server provides cache validation. + * Enforce minimum respectful interval to avoid DoS'ing the server. + */ + effective_interval = min_respectful_interval * interval_mult; + msg_info_map("map polling interval %d too aggressive with server cache support for %s, " + "using %d seconds minimum", + (int) map_check_interval, bk->uri, (int) effective_interval); + } + + if (expires > now && (expires - now) <= effective_interval * interval_mult) { + /* Use expires header if it's reasonable (within interval_mult x poll interval) */ + next_check = expires; + } + else { + /* Use effective interval, don't extend too much */ + next_check = now + effective_interval; + } + } + else if (expires > now) { + /* + * No ETags or Last-Modified available, rely on cache expiration. + * But still cap the interval to avoid too long delays. + * No need for respectful interval protection here since no conditional requests. 
+ */ + if (expires - now > map_check_interval * interval_mult) { + next_check = now + map_check_interval * interval_mult; + } + else { + next_check = expires; + } + } + else { + /* No valid cache information, check immediately */ + next_check = now; + } + + return next_check; +} + +/** + * Calculate respectful polling interval to avoid DoS'ing servers with cache validation + * @param map_check_interval user configured interval + * @return effective interval that respects server resources + */ +static inline time_t +rspamd_map_get_respectful_interval(time_t map_check_interval) +{ + static const time_t min_respectful_interval = 5; /* Minimum 5 seconds to be respectful */ + static const time_t interval_mult = 4; /* Multiplier for respectful minimum */ + + if (map_check_interval < min_respectful_interval) { + return min_respectful_interval * interval_mult; + } + return map_check_interval; +} + static int http_map_finish(struct rspamd_http_connection *conn, struct rspamd_http_message *msg) @@ -316,12 +413,15 @@ http_map_finish(struct rspamd_http_connection *conn, if (msg->code == 200) { if (cbd->check) { - msg_info_map("need to reread map from %s", cbd->bk->uri); + msg_info_map("need to reread map from %s (reply code 200); " + "date timestamp: %z, last modified: %z", + cbd->bk->uri, (size_t) msg->date, (size_t) msg->last_modified); cbd->periodic->need_modify = TRUE; /* Reset the whole chain */ cbd->periodic->cur_backend = 0; /* Reset cache, old cached data will be cleaned on timeout */ g_atomic_int_set(&data->cache->available, 0); + g_atomic_int_set(&map->shared->loaded, 0); data->cur_cache_cbd = NULL; rspamd_map_process_periodic(cbd->periodic); @@ -330,6 +430,7 @@ http_map_finish(struct rspamd_http_connection *conn, return 0; } + /* This code is executed when we are actually reading a map */ cbd->data->last_checked = msg->date; if (msg->last_modified) { @@ -360,10 +461,11 @@ http_map_finish(struct rspamd_http_connection *conn, goto err; } - /* Check for expires */ + /* 
Check for expires + etag */ double cached_timeout = map->poll_timeout * 2; expires_hdr = rspamd_http_message_find_header(msg, "Expires"); + etag_hdr = rspamd_http_message_find_header(msg, "ETag"); if (expires_hdr) { time_t hdate; @@ -371,19 +473,27 @@ http_map_finish(struct rspamd_http_connection *conn, hdate = rspamd_http_parse_date(expires_hdr->begin, expires_hdr->len); if (hdate != (time_t) -1 && hdate > msg->date) { - cached_timeout = map->next_check - msg->date + - map->poll_timeout * 2; - - map->next_check = hdate; + map->next_check = rspamd_http_map_process_next_check(map, bk, msg->date, hdate, + (time_t) map->poll_timeout, + etag_hdr != NULL, + msg->last_modified != 0); + cached_timeout = map->next_check - msg->date; } else { msg_info_map("invalid expires header: %T, ignore it", expires_hdr); map->next_check = 0; } } - - /* Check for etag */ - etag_hdr = rspamd_http_message_find_header(msg, "ETag"); + else if (etag_hdr != NULL || msg->last_modified != 0) { + /* No expires header, but we have ETag or Last-Modified - use respectful interval */ + time_t effective_interval = rspamd_map_get_respectful_interval(map->poll_timeout); + if (effective_interval != map->poll_timeout) { + msg_info_map("map polling interval %d too aggressive with server cache support, " + "using %d seconds minimum", + (int) map->poll_timeout, (int) effective_interval); + } + map->next_check = msg->date + effective_interval; + } if (etag_hdr) { if (cbd->data->etag) { @@ -404,10 +514,7 @@ http_map_finish(struct rspamd_http_connection *conn, MAP_RETAIN(cbd->shmem_data, "shmem_data"); cbd->data->gen++; - /* - * We know that a map is in the locked state - */ - g_atomic_int_set(&data->cache->available, 1); + /* Store cached data */ rspamd_strlcpy(data->cache->shmem_name, cbd->shmem_data->shm_name, sizeof(data->cache->shmem_name)); @@ -509,6 +616,12 @@ http_map_finish(struct rspamd_http_connection *conn, cbd->periodic->cur_backend++; munmap(in, dlen); + + /* Announce for other processes */ + 
g_atomic_int_set(&data->cache->available, 1); + g_atomic_int_set(&map->shared->loaded, 1); + g_atomic_int_set(&map->shared->cached, 1); + rspamd_map_process_periodic(cbd->periodic); } else if (msg->code == 304 && cbd->check) { @@ -522,19 +635,34 @@ http_map_finish(struct rspamd_http_connection *conn, } expires_hdr = rspamd_http_message_find_header(msg, "Expires"); + bool has_expires = (expires_hdr != NULL); if (expires_hdr) { time_t hdate; hdate = rspamd_http_parse_date(expires_hdr->begin, expires_hdr->len); if (hdate != (time_t) -1 && hdate > msg->date) { - map->next_check = hdate; + map->next_check = rspamd_http_map_process_next_check(map, bk, msg->date, hdate, + (time_t) map->poll_timeout, + cbd->data->etag != NULL, + msg->last_modified != 0); } else { msg_info_map("invalid expires header: %T, ignore it", expires_hdr); map->next_check = 0; + has_expires = false; } } + else if (cbd->data->etag != NULL || msg->last_modified != 0) { + /* No expires header, but we have ETag or Last-Modified - use respectful interval */ + time_t effective_interval = rspamd_map_get_respectful_interval(map->poll_timeout); + if (effective_interval != map->poll_timeout) { + msg_info_map("map polling interval %d too aggressive with server cache support, " + "using %d seconds minimum", + (int) map->poll_timeout, (int) effective_interval); + } + map->next_check = msg->date + effective_interval; + } etag_hdr = rspamd_http_message_find_header(msg, "ETag"); @@ -547,19 +675,24 @@ http_map_finish(struct rspamd_http_connection *conn, } } - if (map->next_check) { + if (has_expires) { rspamd_http_date_format(next_check_date, sizeof(next_check_date), map->next_check); - msg_info_map("data is not modified for server %s, next check at %s " + msg_info_map("data is not modified for server %s (%s), next check at %s " "(http cache based: %T)", - cbd->data->host, next_check_date, expires_hdr); + cbd->data->host, + bk->uri, + next_check_date, + expires_hdr); } else { rspamd_http_date_format(next_check_date, 
sizeof(next_check_date), - rspamd_get_calendar_ticks() + map->poll_timeout); - msg_info_map("data is not modified for server %s, next check at %s " + map->next_check); + msg_info_map("data is not modified for server %s (%s), next check at %s " "(timer based)", - cbd->data->host, next_check_date); + cbd->data->host, + bk->uri, + next_check_date); } rspamd_map_update_http_cached_file(map, bk, cbd->data); @@ -902,6 +1035,8 @@ read_map_file(struct rspamd_map *map, struct file_map_data *data, map->read_callback(NULL, 0, &periodic->cbdata, TRUE); } + g_atomic_int_set(&map->shared->loaded, 1); + return TRUE; } @@ -986,6 +1121,7 @@ read_map_static(struct rspamd_map *map, struct static_map_data *data, } data->processed = TRUE; + g_atomic_int_set(&map->shared->loaded, 1); return TRUE; } @@ -993,9 +1129,7 @@ read_map_static(struct rspamd_map *map, struct static_map_data *data, static void rspamd_map_periodic_dtor(struct map_periodic_cbdata *periodic) { - struct rspamd_map *map; - - map = periodic->map; + struct rspamd_map *map = periodic->map; msg_debug_map("periodic dtor %p; need_modify=%d", periodic, periodic->need_modify); if (periodic->need_modify || periodic->cbdata.errored) { @@ -1010,18 +1144,13 @@ rspamd_map_periodic_dtor(struct map_periodic_cbdata *periodic) /* Not modified */ } - if (periodic->locked) { - g_atomic_int_set(periodic->map->locked, 0); - msg_debug_map("unlocked map %s", periodic->map->name); - - if (periodic->map->wrk->state == rspamd_worker_state_running) { - rspamd_map_schedule_periodic(periodic->map, - RSPAMD_SYMBOL_RESULT_NORMAL); - } - else { - msg_debug_map("stop scheduling periodics for %s; terminating state", - periodic->map->name); - } + if (periodic->map->wrk->state == rspamd_worker_state_running) { + rspamd_map_schedule_periodic(periodic->map, + RSPAMD_MAP_SCHEDULE_NORMAL); + } + else { + msg_debug_map("stop scheduling periodics for %s; terminating state", + periodic->map->name); } g_free(periodic); @@ -1458,7 +1587,7 @@ 
rspamd_map_save_http_cached_file(struct rspamd_map *map, const unsigned char *data, gsize len) { - char path[PATH_MAX]; + char path[PATH_MAX], temp_path[PATH_MAX]; unsigned char digest[rspamd_cryptobox_HASHBYTES]; struct rspamd_config *cfg = map->cfg; int fd; @@ -1471,8 +1600,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, rspamd_cryptobox_hash(digest, bk->uri, strlen(bk->uri), NULL, 0); rspamd_snprintf(path, sizeof(path), "%s%c%*xs.map", cfg->maps_cache_dir, G_DIR_SEPARATOR, 20, digest); + rspamd_snprintf(temp_path, sizeof(temp_path), "%s.tmp.%d.%d", path, + (int) getpid(), (int) rspamd_get_calendar_ticks()); - fd = rspamd_file_xopen(path, O_WRONLY | O_TRUNC | O_CREAT, + fd = rspamd_file_xopen(temp_path, O_WRONLY | O_TRUNC | O_CREAT, 00600, FALSE); if (fd == -1) { @@ -1480,8 +1611,9 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, } if (!rspamd_file_lock(fd, FALSE)) { - msg_err_map("cannot lock file %s: %s", path, strerror(errno)); + msg_err_map("cannot lock file %s: %s", temp_path, strerror(errno)); close(fd); + unlink(temp_path); return FALSE; } @@ -1500,9 +1632,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, } if (write(fd, &header, sizeof(header)) != sizeof(header)) { - msg_err_map("cannot write file %s (header stage): %s", path, strerror(errno)); + msg_err_map("cannot write file %s (header stage): %s", temp_path, strerror(errno)); rspamd_file_unlock(fd, FALSE); close(fd); + unlink(temp_path); return FALSE; } @@ -1510,9 +1643,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, if (header.etag_len > 0) { if (write(fd, RSPAMD_FSTRING_DATA(htdata->etag), header.etag_len) != header.etag_len) { - msg_err_map("cannot write file %s (etag stage): %s", path, strerror(errno)); + msg_err_map("cannot write file %s (etag stage): %s", temp_path, strerror(errno)); rspamd_file_unlock(fd, FALSE); close(fd); + unlink(temp_path); return FALSE; } @@ -1520,9 +1654,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, /* 
Now write the rest */ if (write(fd, data, len) != len) { - msg_err_map("cannot write file %s (data stage): %s", path, strerror(errno)); + msg_err_map("cannot write file %s (data stage): %s", temp_path, strerror(errno)); rspamd_file_unlock(fd, FALSE); close(fd); + unlink(temp_path); return FALSE; } @@ -1530,6 +1665,13 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, rspamd_file_unlock(fd, FALSE); close(fd); + /* Atomically move temp file to final location */ + if (rename(temp_path, path) != 0) { + msg_err_map("cannot rename %s to %s: %s", temp_path, path, strerror(errno)); + unlink(temp_path); + return FALSE; + } + msg_info_map("saved data from %s in %s, %uz bytes", bk->uri, path, len + sizeof(header) + header.etag_len); return TRUE; @@ -1663,7 +1805,11 @@ rspamd_map_read_http_cached_file(struct rspamd_map *map, double now = rspamd_get_calendar_ticks(); if (header.next_check > now) { - map->next_check = header.next_check; + /* We assume that we have this data inside the cached file */ + map->next_check = rspamd_http_map_process_next_check(map, bk, now, header.next_check, + map->poll_timeout, + header.etag_len > 0, + true); } else { map->next_check = now; @@ -1710,6 +1856,8 @@ rspamd_map_read_http_cached_file(struct rspamd_map *map, struct tm tm; char ncheck_buf[32], lm_buf[32]; + g_atomic_int_set(&map->shared->loaded, 1); + g_atomic_int_set(&map->shared->cached, 1); rspamd_localtime(map->next_check, &tm); strftime(ncheck_buf, sizeof(ncheck_buf) - 1, "%Y-%m-%d %H:%M:%S", &tm); rspamd_localtime(htdata->last_modified, &tm); @@ -1752,7 +1900,6 @@ rspamd_map_common_http_callback(struct rspamd_map *map, (int) data->last_modified, (int) data->cache->last_modified); periodic->need_modify = TRUE; - /* Reset the whole chain */ periodic->cur_backend = 0; rspamd_map_process_periodic(periodic); } @@ -2010,33 +2157,22 @@ rspamd_map_process_periodic(struct map_periodic_cbdata *cbd) map = cbd->map; map->scheduled_check = NULL; - if (!map->file_only && !cbd->locked) { - 
if (!g_atomic_int_compare_and_exchange(cbd->map->locked, - 0, 1)) { - msg_debug_map( - "don't try to reread map %s as it is locked by other process, " - "will reread it later", - cbd->map->name); - rspamd_map_schedule_periodic(map, RSPAMD_MAP_SCHEDULE_LOCKED); - MAP_RELEASE(cbd, "periodic"); + /* For each backend we need to check for modifications */ + if (cbd->cur_backend >= cbd->map->backends->len) { + /* Last backend */ + msg_debug_map("finished map: %d of %d", cbd->cur_backend, + cbd->map->backends->len); + MAP_RELEASE(cbd, "periodic"); - return; - } - else { - msg_debug_map("locked map %s", cbd->map->name); - cbd->locked = TRUE; - } + return; } + bk = g_ptr_array_index(map->backends, cbd->cur_backend); + if (cbd->errored) { /* We should not check other backends if some backend has failed*/ rspamd_map_schedule_periodic(cbd->map, RSPAMD_MAP_SCHEDULE_ERROR); - if (cbd->locked) { - g_atomic_int_set(cbd->map->locked, 0); - cbd->locked = FALSE; - } - /* Also set error flag for the map consumer */ cbd->cbdata.errored = true; @@ -2047,19 +2183,7 @@ rspamd_map_process_periodic(struct map_periodic_cbdata *cbd) return; } - /* For each backend we need to check for modifications */ - if (cbd->cur_backend >= cbd->map->backends->len) { - /* Last backend */ - msg_debug_map("finished map: %d of %d", cbd->cur_backend, - cbd->map->backends->len); - MAP_RELEASE(cbd, "periodic"); - - return; - } - if (cbd->map->wrk && cbd->map->wrk->state == rspamd_worker_state_running) { - bk = g_ptr_array_index(cbd->map->backends, cbd->cur_backend); - g_assert(bk != NULL); if (cbd->need_modify) { /* Load data from the next backend */ @@ -2764,10 +2888,6 @@ rspamd_map_parse_backend(struct rspamd_config *cfg, const char *map_line) bk->data.sd = sdata; } - bk->id = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_T1HA, - bk->uri, strlen(bk->uri), - 0xdeadbabe); - return bk; err: @@ -2798,6 +2918,13 @@ rspamd_map_calculate_hash(struct rspamd_map *map) rspamd_cryptobox_hash_init(&st, NULL, 0); + 
if (map->name) { + rspamd_cryptobox_hash_update(&st, map->name, strlen(map->name)); + } + if (map->description) { + rspamd_cryptobox_hash_update(&st, map->description, strlen(map->description)); + } + for (i = 0; i < map->backends->len; i++) { bk = g_ptr_array_index(map->backends, i); rspamd_cryptobox_hash_update(&st, bk->uri, strlen(bk->uri)); @@ -2806,6 +2933,26 @@ rspamd_map_calculate_hash(struct rspamd_map *map) rspamd_cryptobox_hash_final(&st, cksum); cksum_encoded = rspamd_encode_base32(cksum, sizeof(cksum), RSPAMD_BASE32_DEFAULT); rspamd_strlcpy(map->tag, cksum_encoded, sizeof(map->tag)); + + for (i = 0; i < map->backends->len; i++) { + bk = g_ptr_array_index(map->backends, i); + + /* Also update each backend */ + rspamd_cryptobox_fast_hash_state_t hst; + rspamd_cryptobox_fast_hash_init(&hst, 0); + rspamd_cryptobox_fast_hash_update(&hst, bk->uri, strlen(bk->uri)); + rspamd_cryptobox_fast_hash_update(&hst, map->tag, sizeof(map->tag)); + + if (bk->protocol == MAP_PROTO_STATIC) { + /* Static maps content is pre-defined */ + rspamd_cryptobox_fast_hash_update(&hst, bk->data.sd->data, + bk->data.sd->len); + } + + /* We use only 52 bits to be compatible with other numbers representation */ + bk->id = rspamd_cryptobox_fast_hash_final(&hst) & ~(0xFFFULL << 52); + } + g_free(cksum_encoded); } @@ -2871,8 +3018,8 @@ rspamd_map_add(struct rspamd_config *cfg, map->user_data = user_data; map->cfg = cfg; map->id = rspamd_random_uint64_fast(); - map->locked = - rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(int)); + map->shared = + rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(struct rspamd_map_shared_data)); map->backends = g_ptr_array_sized_new(1); map->wrk = worker; rspamd_mempool_add_destructor(cfg->cfg_pool, rspamd_ptr_array_free_hard, @@ -2971,8 +3118,8 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, map->user_data = user_data; map->cfg = cfg; map->id = rspamd_random_uint64_fast(); - map->locked = - rspamd_mempool_alloc0_shared(cfg->cfg_pool, 
sizeof(int)); + map->shared = + rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(struct rspamd_map_shared_data)); map->backends = g_ptr_array_new(); map->wrk = worker; map->no_file_read = (flags & RSPAMD_MAP_FILE_NO_READ); @@ -3091,7 +3238,7 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, goto err; } - gboolean all_local = TRUE; + gboolean all_local = TRUE, all_loaded = TRUE; PTR_ARRAY_FOREACH(map->backends, i, bk) { @@ -3110,9 +3257,8 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, map_data = g_string_sized_new(32); if (rspamd_map_add_static_string(cfg, elt, map_data)) { - bk->data.sd->data = map_data->str; bk->data.sd->len = map_data->len; - g_string_free(map_data, FALSE); + bk->data.sd->data = (unsigned char *) g_string_free(map_data, FALSE); } else { g_string_free(map_data, TRUE); @@ -3135,13 +3281,16 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, } ucl_object_iterate_free(it); - bk->data.sd->data = map_data->str; bk->data.sd->len = map_data->len; - g_string_free(map_data, FALSE); + bk->data.sd->data = (unsigned char *) g_string_free(map_data, FALSE); } } else if (bk->protocol != MAP_PROTO_FILE) { all_local = FALSE; + all_loaded = FALSE; /* Will be loaded later */ + } + else { + all_loaded = FALSE; /* Will be loaded later (even for files) */ } } @@ -3150,6 +3299,11 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, cfg->map_file_watch_multiplier); } + if (all_loaded) { + /* Static map */ + g_atomic_int_set(&map->shared->loaded, 1); + } + rspamd_map_calculate_hash(map); msg_debug_map("added map from ucl"); diff --git a/src/libserver/maps/map_private.h b/src/libserver/maps/map_private.h index d0b22fe36..65df8d7f5 100644 --- a/src/libserver/maps/map_private.h +++ b/src/libserver/maps/map_private.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -54,6 +54,23 @@ enum fetch_proto { MAP_PROTO_STATIC }; +static const char * +rspamd_map_fetch_protocol_name(enum fetch_proto proto) +{ + switch (proto) { + case MAP_PROTO_FILE: + return "file"; + case MAP_PROTO_HTTP: + return "http"; + case MAP_PROTO_HTTPS: + return "https"; + case MAP_PROTO_STATIC: + return "static"; + default: + return "unknown"; + } +} + /** * Data specific to file maps */ @@ -76,7 +93,7 @@ struct rspamd_http_map_cached_cbdata { time_t last_checked; }; -struct rspamd_map_cachepoint { +struct rspamd_http_map_cache { int available; gsize len; time_t last_modified; @@ -88,7 +105,7 @@ struct rspamd_map_cachepoint { */ struct http_map_data { /* Shared cache data */ - struct rspamd_map_cachepoint *cache; + struct rspamd_http_map_cache *cache; /* Non-shared for cache owner, used to cleanup cache */ struct rspamd_http_map_cached_cbdata *cur_cache_cbd; char *userinfo; @@ -117,6 +134,7 @@ union rspamd_map_backend_data { struct rspamd_map; + struct rspamd_map_backend { enum fetch_proto protocol; gboolean is_signed; @@ -124,7 +142,7 @@ struct rspamd_map_backend { gboolean is_fallback; struct rspamd_map *map; struct ev_loop *event_loop; - uint32_t id; + uint64_t id; struct rspamd_cryptobox_pubkey *trusted_pubkey; union rspamd_map_backend_data data; char *uri; @@ -133,6 +151,14 @@ struct rspamd_map_backend { struct map_periodic_cbdata; +/* + * Shared between workers + */ +struct rspamd_map_shared_data { + int loaded; + int cached; +}; + struct rspamd_map { struct rspamd_dns_resolver *r; struct rspamd_config *cfg; @@ -168,7 +194,7 @@ struct rspamd_map { bool no_file_read; /* Do not read files */ bool seen; /* This map has already been watched or pre-loaded */ /* Shared lock for temporary disabling of map reading (e.g. 
when this map is written by UI) */ - int *locked; + struct rspamd_map_shared_data *shared; char tag[MEMPOOL_UID_LEN]; }; @@ -185,7 +211,6 @@ struct map_periodic_cbdata { ev_timer ev; gboolean need_modify; gboolean errored; - gboolean locked; unsigned int cur_backend; ref_entry_t ref; }; diff --git a/src/libserver/milter.c b/src/libserver/milter.c index 94b0d6cc1..09ddddaba 100644 --- a/src/libserver/milter.c +++ b/src/libserver/milter.c @@ -1473,8 +1473,6 @@ rspamd_milter_macro_http(struct rspamd_milter_session *session, { rspamd_http_message_add_header_len(msg, QUEUE_ID_HEADER, found->begin, found->len); - rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, - found->begin, found->len); } else { @@ -1482,8 +1480,6 @@ rspamd_milter_macro_http(struct rspamd_milter_session *session, { rspamd_http_message_add_header_len(msg, QUEUE_ID_HEADER, found->begin, found->len); - rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, - found->begin, found->len); } } diff --git a/src/libserver/re_cache.c b/src/libserver/re_cache.c index 06e9f3328..50b155ae0 100644 --- a/src/libserver/re_cache.c +++ b/src/libserver/re_cache.c @@ -998,20 +998,21 @@ rspamd_re_cache_process_selector(struct rspamd_task *task, return result; } + static inline unsigned int -rspamd_process_words_vector(GArray *words, - const unsigned char **scvec, - unsigned int *lenvec, - struct rspamd_re_class *re_class, - unsigned int cnt, - gboolean *raw) +rspamd_process_words_vector_kvec(rspamd_words_t *words, + const unsigned char **scvec, + unsigned int *lenvec, + struct rspamd_re_class *re_class, + unsigned int cnt, + gboolean *raw) { unsigned int j; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; - if (words) { - for (j = 0; j < words->len; j++) { - tok = &g_array_index(words, rspamd_stat_token_t, j); + if (words && words->a) { + for (j = 0; j < kv_size(*words); j++) { + tok = &kv_A(*words, j); if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF)) { @@ -1432,13 
+1433,13 @@ rspamd_re_cache_exec_re(struct rspamd_task *task, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, text_part) { - if (text_part->utf_words) { - cnt += text_part->utf_words->len; + if (text_part->utf_words.a) { + cnt += kv_size(text_part->utf_words); } } - if (task->meta_words && task->meta_words->len > 0) { - cnt += task->meta_words->len; + if (task->meta_words.a && kv_size(task->meta_words) > 0) { + cnt += kv_size(task->meta_words); } if (cnt > 0) { @@ -1449,15 +1450,15 @@ rspamd_re_cache_exec_re(struct rspamd_task *task, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, text_part) { - if (text_part->utf_words) { - cnt = rspamd_process_words_vector(text_part->utf_words, - scvec, lenvec, re_class, cnt, &raw); + if (text_part->utf_words.a) { + cnt = rspamd_process_words_vector_kvec(&text_part->utf_words, + scvec, lenvec, re_class, cnt, &raw); } } - if (task->meta_words) { - cnt = rspamd_process_words_vector(task->meta_words, - scvec, lenvec, re_class, cnt, &raw); + if (task->meta_words.a) { + cnt = rspamd_process_words_vector_kvec(&task->meta_words, + scvec, lenvec, re_class, cnt, &raw); } ret = rspamd_re_cache_process_regexp_data(rt, re, diff --git a/src/libserver/redis_pool.cxx b/src/libserver/redis_pool.cxx index cea8d0c86..586260a6f 100644 --- a/src/libserver/redis_pool.cxx +++ b/src/libserver/redis_pool.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -465,6 +465,8 @@ auto redis_pool_elt::new_connection() -> redisAsyncContext * * We cannot reuse connection, so we just recursively call * this function one more time */ + msg_debug_rpool("cannot reuse the existing connection to %s:%d: %p; errno=%d", + ip.c_str(), port, conn->ctx, err); return new_connection(); } else { @@ -481,6 +483,9 @@ auto redis_pool_elt::new_connection() -> redisAsyncContext * } else { auto *nctx = redis_async_new(); + msg_debug_rpool("error in the inactive connection: %s; opened new connection to %s:%d: %p", + conn->ctx->errstr, ip.c_str(), port, nctx); + if (nctx) { active.emplace_front(std::make_unique<redis_pool_connection>(pool, this, db.c_str(), username.c_str(), password.c_str(), nctx)); @@ -492,10 +497,14 @@ auto redis_pool_elt::new_connection() -> redisAsyncContext * } else { auto *nctx = redis_async_new(); + if (nctx) { active.emplace_front(std::make_unique<redis_pool_connection>(pool, this, db.c_str(), username.c_str(), password.c_str(), nctx)); active.front()->elt_pos = active.begin(); + auto conn = active.front().get(); + msg_debug_rpool("no inactive connections; opened new connection to %s:%d: %p", + ip.c_str(), port, nctx); } return nctx; diff --git a/src/libserver/roll_history.c b/src/libserver/roll_history.c index 66a53a597..d0f145d8f 100644 --- a/src/libserver/roll_history.c +++ b/src/libserver/roll_history.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -231,7 +231,7 @@ rspamd_roll_history_load(struct roll_history *history, const char *filename) return FALSE; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_fd(parser, fd)) { msg_warn("cannot parse history file %s: %s", filename, diff --git a/src/libserver/rspamd_control.c b/src/libserver/rspamd_control.c index 1bff2ff12..9e35cb575 100644 --- a/src/libserver/rspamd_control.c +++ b/src/libserver/rspamd_control.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -214,7 +214,7 @@ rspamd_control_write_reply(struct rspamd_control_session *session) case RSPAMD_CONTROL_FUZZY_STAT: if (elt->attached_fd != -1) { /* We have some data to parse */ - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); ucl_object_insert_key(cur, ucl_object_fromint( elt->reply.reply.fuzzy_stat.status), diff --git a/src/libserver/symcache/symcache_impl.cxx b/src/libserver/symcache/symcache_impl.cxx index 7159555d2..c1ca2a6ed 100644 --- a/src/libserver/symcache/symcache_impl.cxx +++ b/src/libserver/symcache/symcache_impl.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ #include "unix-std.h" #include "libutil/cxx/file_util.hxx" #include "libutil/cxx/util.hxx" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "contrib/t1ha/t1ha.h" #ifdef __has_include @@ -126,7 +126,7 @@ auto symcache::init() -> bool } else { msg_err_cache("cannot register delayed dependency %s -> %s: " - "destionation %s is missing", + "destination %s is missing", delayed_dep.from.data(), delayed_dep.to.data(), delayed_dep.to.data()); } @@ -274,7 +274,7 @@ auto symcache::load_items() -> bool return false; } - auto *parser = ucl_parser_new(0); + auto *parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); const auto *p = (const std::uint8_t *) (hdr + 1); if (!ucl_parser_add_chunk(parser, p, cached_map->get_size() - sizeof(*hdr))) { @@ -1338,4 +1338,4 @@ auto symcache::get_max_timeout(std::vector<std::pair<double, const cache_item *> return accumulated_timeout; } -}// namespace rspamd::symcache
\ No newline at end of file +}// namespace rspamd::symcache diff --git a/src/libserver/symcache/symcache_item.cxx b/src/libserver/symcache/symcache_item.cxx index 233e8113a..f58332ea5 100644 --- a/src/libserver/symcache/symcache_item.cxx +++ b/src/libserver/symcache/symcache_item.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #include "lua/lua_common.h" #include "symcache_internal.hxx" #include "symcache_item.hxx" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "libserver/task.h" #include "libutil/cxx/util.hxx" #include <numeric> diff --git a/src/libserver/task.c b/src/libserver/task.c index bd1e07549..9f5b1f00a 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -196,8 +196,8 @@ void rspamd_task_free(struct rspamd_task *task) rspamd_email_address_free(task->from_envelope_orig); } - if (task->meta_words) { - g_array_free(task->meta_words, TRUE); + if (task->meta_words.a) { + kv_destroy(task->meta_words); } ucl_object_unref(task->messages); diff --git a/src/libserver/task.h b/src/libserver/task.h index 6be350098..1c1778fee 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -24,6 +24,7 @@ #include "dns.h" #include "re_cache.h" #include "khash.h" +#include "libserver/word.h" #ifdef __cplusplus extern "C" { @@ -187,7 +188,7 @@ struct rspamd_task { struct rspamd_scan_result *result; /**< Metric result */ khash_t(rspamd_task_lua_cache) lua_cache; /**< cache of lua objects */ GPtrArray *tokens; /**< statistics tokens */ - GArray *meta_words; /**< rspamd_stat_token_t produced from meta headers + rspamd_words_t meta_words; /**< rspamd_word_t produced from meta headers (e.g. Subject) */ GPtrArray *rcpt_envelope; /**< array of rspamd_email_address */ diff --git a/src/libserver/word.h b/src/libserver/word.h new file mode 100644 index 000000000..7698bf327 --- /dev/null +++ b/src/libserver/word.h @@ -0,0 +1,88 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef RSPAMD_WORD_H +#define RSPAMD_WORD_H + +#include "config.h" +#include "fstring.h" +#include "contrib/libucl/kvec.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @file word.h + * Word processing structures and definitions + */ + +/* Word flags */ +#define RSPAMD_WORD_FLAG_TEXT (1u << 0) +#define RSPAMD_WORD_FLAG_META (1u << 1) +#define RSPAMD_WORD_FLAG_LUA_META (1u << 2) +#define RSPAMD_WORD_FLAG_EXCEPTION (1u << 3) +#define RSPAMD_WORD_FLAG_HEADER (1u << 4) +#define RSPAMD_WORD_FLAG_UNIGRAM (1u << 5) +#define RSPAMD_WORD_FLAG_UTF (1u << 6) +#define RSPAMD_WORD_FLAG_NORMALISED (1u << 7) +#define RSPAMD_WORD_FLAG_STEMMED (1u << 8) +#define RSPAMD_WORD_FLAG_BROKEN_UNICODE (1u << 9) +#define RSPAMD_WORD_FLAG_STOP_WORD (1u << 10) +#define RSPAMD_WORD_FLAG_SKIPPED (1u << 11) +#define RSPAMD_WORD_FLAG_INVISIBLE_SPACES (1u << 12) +#define RSPAMD_WORD_FLAG_EMOJI (1u << 13) + +/** + * Word structure representing tokenized text + */ +typedef struct rspamd_word_s { + rspamd_ftok_t original; /* utf8 raw */ + rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ + rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ + rspamd_ftok_t stemmed; /* stemmed utf8 */ + unsigned int flags; +} rspamd_word_t; + +/** + * Vector of words using kvec + */ +typedef kvec_t(rspamd_word_t) rspamd_words_t; + +/* Legacy typedefs for backward compatibility */ +typedef rspamd_word_t rspamd_stat_token_t; + +/* Legacy flag aliases for backward compatibility */ +#define RSPAMD_STAT_TOKEN_FLAG_TEXT RSPAMD_WORD_FLAG_TEXT +#define RSPAMD_STAT_TOKEN_FLAG_META RSPAMD_WORD_FLAG_META +#define RSPAMD_STAT_TOKEN_FLAG_LUA_META RSPAMD_WORD_FLAG_LUA_META +#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION RSPAMD_WORD_FLAG_EXCEPTION +#define RSPAMD_STAT_TOKEN_FLAG_HEADER RSPAMD_WORD_FLAG_HEADER +#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM RSPAMD_WORD_FLAG_UNIGRAM +#define RSPAMD_STAT_TOKEN_FLAG_UTF RSPAMD_WORD_FLAG_UTF +#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED 
RSPAMD_WORD_FLAG_NORMALISED +#define RSPAMD_STAT_TOKEN_FLAG_STEMMED RSPAMD_WORD_FLAG_STEMMED +#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE RSPAMD_WORD_FLAG_BROKEN_UNICODE +#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD RSPAMD_WORD_FLAG_STOP_WORD +#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED RSPAMD_WORD_FLAG_SKIPPED +#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES RSPAMD_WORD_FLAG_INVISIBLE_SPACES +#define RSPAMD_STAT_TOKEN_FLAG_EMOJI RSPAMD_WORD_FLAG_EMOJI + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_WORD_H */ diff --git a/src/libserver/worker_util.c b/src/libserver/worker_util.c index 75836573f..685ee9cd2 100644 --- a/src/libserver/worker_util.c +++ b/src/libserver/worker_util.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -167,7 +167,7 @@ static void rspamd_worker_terminate_handlers(struct rspamd_worker *w) { if (w->nconns == 0 && - (!(w->flags & RSPAMD_WORKER_SCANNER) || w->srv->cfg->on_term_scripts == NULL)) { + (!(w->flags & (RSPAMD_WORKER_SCANNER | RSPAMD_WORKER_FUZZY)) || w->srv->cfg->on_term_scripts == NULL)) { /* * We are here either: * - No active connections are represented @@ -190,7 +190,7 @@ rspamd_worker_terminate_handlers(struct rspamd_worker *w) if (w->state != rspamd_worker_wait_final_scripts) { w->state = rspamd_worker_wait_final_scripts; - if ((w->flags & RSPAMD_WORKER_SCANNER) && + if ((w->flags & (RSPAMD_WORKER_SCANNER | RSPAMD_WORKER_FUZZY)) && rspamd_worker_call_finish_handlers(w)) { msg_info("performing async finishing actions"); w->state = rspamd_worker_wait_final_scripts; @@ -2138,7 +2138,7 @@ rspamd_controller_load_saved_stats(struct rspamd_main *rspamd_main, return; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_file(parser, cfg->stats_file)) { msg_err_config("cannot parse controller stats from %s: %s", 
diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt index 64d572a57..eddf64e49 100644 --- a/src/libstat/CMakeLists.txt +++ b/src/libstat/CMakeLists.txt @@ -1,25 +1,26 @@ # Librspamdserver -SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c - ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) +SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c + ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) -SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c - ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) +SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizer_manager.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) -SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c - ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) +SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c + ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) -SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c - ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c - ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx - ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx - ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) +SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) -SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c +SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx) SET(RSPAMD_STAT ${LIBSTATSRC} - ${TOKENIZERSSRC} - ${CLASSIFIERSSRC} - ${BACKENDSSRC} - ${CACHESSRC} PARENT_SCOPE) + ${TOKENIZERSSRC} + ${CLASSIFIERSSRC} + ${BACKENDSSRC} + ${CACHESSRC} PARENT_SCOPE) diff --git 
a/src/libstat/backends/cdb_backend.cxx b/src/libstat/backends/cdb_backend.cxx index bd05e8ef8..0f55a725c 100644 --- a/src/libstat/backends/cdb_backend.cxx +++ b/src/libstat/backends/cdb_backend.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ #include <optional> #include "contrib/expected/expected.hpp" #include "contrib/ankerl/unordered_dense.h" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" namespace rspamd::stat::cdb { diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx index 06842b078..7137904e9 100644 --- a/src/libstat/backends/redis_backend.cxx +++ b/src/libstat/backends/redis_backend.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,7 +19,7 @@ #include "stat_internal.h" #include "upstream.h" #include "libserver/mempool_vars_internal.h" -#include "fmt/base.h" +#include "contrib/fmt/include/fmt/base.h" #include "libutil/cxx/error.hxx" diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index f28922588..811566ad3 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -20,6 +20,7 @@ #include "task.h" #include "lua/lua_common.h" #include "contrib/libev/ev.h" +#include "libserver/word.h" #ifdef __cplusplus extern "C" { @@ -30,36 +31,14 @@ extern "C" { * High level statistics API */ -#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0) -#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1) -#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2) -#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3) -#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4) -#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5) -#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6) -#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7) -#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8) -#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9) -#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 10) -#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 11) -#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES (1u << 12) -#define RSPAMD_STAT_TOKEN_FLAG_EMOJI (1u << 13) - -typedef struct rspamd_stat_token_s { - rspamd_ftok_t original; /* utf8 raw */ - rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ - rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ - rspamd_ftok_t stemmed; /* stemmed utf8 */ - unsigned int flags; -} rspamd_stat_token_t; #define RSPAMD_TOKEN_VALUE_TYPE float typedef struct token_node_s { uint64_t data; unsigned int window_idx; unsigned int flags; - rspamd_stat_token_t *t1; - rspamd_stat_token_t *t2; + rspamd_word_t *t1; + rspamd_word_t *t2; RSPAMD_TOKEN_VALUE_TYPE values[0]; } rspamd_token_t; diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 17caf4cc6..176064087 
100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,12 +36,13 @@ static void rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task) { - GArray *ar; - rspamd_stat_token_t elt; + rspamd_words_t *words; + rspamd_word_t elt; unsigned int i; lua_State *L = task->cfg->lua_state; - ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16); + words = rspamd_mempool_alloc(task->task_pool, sizeof(*words)); + kv_init(*words); memset(&elt, 0, sizeof(elt)); elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; @@ -87,7 +88,7 @@ rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, elt.normalized.begin = elt.original.begin; elt.normalized.len = elt.original.len; - g_array_append_val(ar, elt); + kv_push_safe(rspamd_word_t, *words, elt, meta_words_error); } lua_pop(L, 1); @@ -99,17 +100,20 @@ rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, } - if (ar->len > 0) { + if (kv_size(*words) > 0) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - ar, + words, TRUE, "M", task->tokens); } - rspamd_mempool_add_destructor(task->task_pool, - rspamd_array_free_hard, ar); + return; +meta_words_error: + + msg_err("cannot process meta words for task: " + "memory allocation error, skipping the remaining"); } /* @@ -134,8 +138,8 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) { - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { - reserved_len += part->utf_words->len; + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a) { + reserved_len += kv_size(part->utf_words); } /* XXX: normal window size */ reserved_len += 5; @@ -149,9 +153,9 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, 
PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) { - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - part->utf_words, IS_TEXT_PART_UTF(part), + &part->utf_words, IS_TEXT_PART_UTF(part), NULL, task->tokens); } @@ -163,10 +167,10 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, } } - if (task->meta_words != NULL) { + if (task->meta_words.a) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - task->meta_words, + &task->meta_words, TRUE, "SUBJECT", task->tokens); diff --git a/src/libstat/tokenizers/custom_tokenizer.h b/src/libstat/tokenizers/custom_tokenizer.h new file mode 100644 index 000000000..bc173a1da --- /dev/null +++ b/src/libstat/tokenizers/custom_tokenizer.h @@ -0,0 +1,177 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef RSPAMD_CUSTOM_TOKENIZER_H +#define RSPAMD_CUSTOM_TOKENIZER_H + +/* Check if we're being included by internal Rspamd code or external plugins */ +#ifdef RSPAMD_TOKENIZER_INTERNAL +/* Internal Rspamd usage - use the full headers */ +#include "config.h" +#include "ucl.h" +#include "libserver/word.h" +#else +/* External plugin usage - use standalone types */ +#include "rspamd_tokenizer_types.h" +/* Forward declaration for UCL object - plugins should include ucl.h if needed */ +typedef struct ucl_object_s ucl_object_t; +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define RSPAMD_CUSTOM_TOKENIZER_API_VERSION 1 + +/** + * Tokenization result - compatible with both internal and external usage + */ +typedef rspamd_words_t rspamd_tokenizer_result_t; + +/** + * Custom tokenizer API that must be implemented by language-specific tokenizer plugins + * All functions use only plain C types to ensure clean boundaries + */ +typedef struct rspamd_custom_tokenizer_api { + /* API version for compatibility checking */ + unsigned int api_version; + + /* Name of the tokenizer (e.g., "japanese_mecab") */ + const char *name; + + /** + * Global initialization function called once when the tokenizer is loaded + * @param config UCL configuration object for this tokenizer (may be NULL) + * @param error_buf Buffer for error message (at least 256 bytes) + * @return 0 on success, non-zero on failure + */ + int (*init)(const ucl_object_t *config, char *error_buf, size_t error_buf_size); + + /** + * Global cleanup function called when the tokenizer is unloaded + */ + void (*deinit)(void); + + /** + * Quick language detection to check if this tokenizer can handle the text + * @param text UTF-8 text to analyze + * @param len Length of the text in bytes + * @return Confidence score 0.0-1.0, or -1.0 if cannot handle + */ + double (*detect_language)(const char *text, size_t len); + + /** + * Main tokenization function + * @param text UTF-8 text to tokenize + * @param len Length 
of the text in bytes + * @param result Output kvec to fill with rspamd_word_t elements + * @return 0 on success, non-zero on failure + * + * The tokenizer should allocate result->a using its own allocator + * Rspamd will call cleanup_result() to free it after processing + */ + int (*tokenize)(const char *text, size_t len, + rspamd_tokenizer_result_t *result); + + /** + * Cleanup the result from tokenize() + * @param result Result kvec returned by tokenize() + * + * This function should free result->a using the same allocator + * that was used in tokenize() and reset the kvec fields. + * This ensures proper memory management across DLL boundaries. + * Note: This does NOT free the result structure itself, only its contents. + */ + void (*cleanup_result)(rspamd_tokenizer_result_t *result); + + /** + * Optional: Get language hint for better language detection + * @return Language code (e.g., "ja", "zh") or NULL + */ + const char *(*get_language_hint)(void); + + /** + * Optional: Get minimum confidence threshold for this tokenizer + * @return Minimum confidence (0.0-1.0) or -1.0 to use default + */ + double (*get_min_confidence)(void); + +} rspamd_custom_tokenizer_api_t; + +/** + * Entry point function that plugins must export + * Must be named "rspamd_tokenizer_get_api" + */ +typedef const rspamd_custom_tokenizer_api_t *(*rspamd_tokenizer_get_api_func)(void); + +/* Internal Rspamd structures - not exposed to plugins */ +#ifdef RSPAMD_TOKENIZER_INTERNAL + +/** + * Custom tokenizer instance + */ +struct rspamd_custom_tokenizer { + char *name; /* Tokenizer name from config */ + char *path; /* Path to .so file */ + void *handle; /* dlopen handle */ + const rspamd_custom_tokenizer_api_t *api; /* API functions */ + double priority; /* Detection priority */ + double min_confidence; /* Minimum confidence threshold */ + gboolean enabled; /* Is tokenizer enabled */ + ucl_object_t *config; /* Tokenizer-specific config */ +}; + +/** + * Tokenizer manager structure + */ +struct 
rspamd_tokenizer_manager { + GHashTable *tokenizers; /* name -> rspamd_custom_tokenizer */ + GArray *detection_order; /* Ordered by priority */ + rspamd_mempool_t *pool; + double default_threshold; /* Default confidence threshold */ +}; + +/* Manager functions */ +struct rspamd_tokenizer_manager *rspamd_tokenizer_manager_new(rspamd_mempool_t *pool); +void rspamd_tokenizer_manager_destroy(struct rspamd_tokenizer_manager *mgr); + +gboolean rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr, + const char *name, + const ucl_object_t *config, + GError **err); + +struct rspamd_custom_tokenizer *rspamd_tokenizer_manager_detect( + struct rspamd_tokenizer_manager *mgr, + const char *text, size_t len, + double *confidence, + const char *lang_hint, + const char **detected_lang_hint); + +/* Helper function to tokenize with exceptions handling */ +rspamd_tokenizer_result_t *rspamd_custom_tokenizer_tokenize_with_exceptions( + struct rspamd_custom_tokenizer *tokenizer, + const char *text, + gsize len, + GList *exceptions, + rspamd_mempool_t *pool); + +#endif /* RSPAMD_TOKENIZER_INTERNAL */ + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_CUSTOM_TOKENIZER_H */ diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c index 0bc3414a5..360c71d36 100644 --- a/src/libstat/tokenizers/osb.c +++ b/src/libstat/tokenizers/osb.c @@ -21,6 +21,7 @@ #include "tokenizers.h" #include "stat_internal.h" #include "libmime/lang_detection.h" +#include "libserver/word.h" /* Size for features pipe */ #define DEFAULT_FEATURE_WINDOW_SIZE 2 @@ -268,7 +269,7 @@ struct token_pipe_entry { int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result) @@ -282,7 +283,7 @@ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, gsize token_size; unsigned int processed = 0, i, w, window_size, token_flags = 0; - if (words == NULL) { + if (words == NULL || 
!words->a) { return FALSE; } @@ -306,8 +307,8 @@ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, sizeof(RSPAMD_TOKEN_VALUE_TYPE) * ctx->statfiles->len; g_assert(token_size > 0); - for (w = 0; w < words->len; w++) { - token = &g_array_index(words, rspamd_stat_token_t, w); + for (w = 0; w < kv_size(*words); w++) { + token = &kv_A(*words, w); token_flags = token->flags; const char *begin; gsize len; diff --git a/src/libstat/tokenizers/rspamd_tokenizer_types.h b/src/libstat/tokenizers/rspamd_tokenizer_types.h new file mode 100644 index 000000000..eb8518290 --- /dev/null +++ b/src/libstat/tokenizers/rspamd_tokenizer_types.h @@ -0,0 +1,89 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_TOKENIZER_TYPES_H +#define RSPAMD_TOKENIZER_TYPES_H + +/* + * Standalone type definitions for custom tokenizers + * This header is completely self-contained and does not depend on any external libraries. + * Custom tokenizers should include only this header to get access to all necessary types. 
+ */ + +#include <stdint.h> +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Basic string token structure + */ +typedef struct rspamd_ftok { + size_t len; + const char *begin; +} rspamd_ftok_t; + +/** + * Unicode string token structure + */ +typedef struct rspamd_ftok_unicode { + size_t len; + const uint32_t *begin; +} rspamd_ftok_unicode_t; + +/* Word flags */ +#define RSPAMD_WORD_FLAG_TEXT (1u << 0u) +#define RSPAMD_WORD_FLAG_META (1u << 1u) +#define RSPAMD_WORD_FLAG_LUA_META (1u << 2u) +#define RSPAMD_WORD_FLAG_EXCEPTION (1u << 3u) +#define RSPAMD_WORD_FLAG_HEADER (1u << 4u) +#define RSPAMD_WORD_FLAG_UNIGRAM (1u << 5u) +#define RSPAMD_WORD_FLAG_UTF (1u << 6u) +#define RSPAMD_WORD_FLAG_NORMALISED (1u << 7u) +#define RSPAMD_WORD_FLAG_STEMMED (1u << 8u) +#define RSPAMD_WORD_FLAG_BROKEN_UNICODE (1u << 9u) +#define RSPAMD_WORD_FLAG_STOP_WORD (1u << 10u) +#define RSPAMD_WORD_FLAG_SKIPPED (1u << 11u) +#define RSPAMD_WORD_FLAG_INVISIBLE_SPACES (1u << 12u) +#define RSPAMD_WORD_FLAG_EMOJI (1u << 13u) + +/** + * Word structure + */ +typedef struct rspamd_word { + rspamd_ftok_t original; + rspamd_ftok_unicode_t unicode; + rspamd_ftok_t normalized; + rspamd_ftok_t stemmed; + unsigned int flags; +} rspamd_word_t; + +/** + * Array of words + */ +typedef struct rspamd_words { + rspamd_word_t *a; + size_t n; + size_t m; +} rspamd_words_t; + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_TOKENIZER_TYPES_H */ diff --git a/src/libstat/tokenizers/tokenizer_manager.c b/src/libstat/tokenizers/tokenizer_manager.c new file mode 100644 index 000000000..e6fb5e8d8 --- /dev/null +++ b/src/libstat/tokenizers/tokenizer_manager.c @@ -0,0 +1,500 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" +#include "tokenizers.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "custom_tokenizer.h" +#include "libutil/util.h" +#include "libserver/logger.h" +#include <dlfcn.h> + +#define msg_err_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_warn_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_info_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_debug_tokenizer(...) 
rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_tokenizer_log_id, "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(tokenizer) + +static void +rspamd_custom_tokenizer_dtor(gpointer p) +{ + struct rspamd_custom_tokenizer *tok = p; + + if (tok) { + if (tok->api && tok->api->deinit) { + tok->api->deinit(); + } + + if (tok->handle) { + dlclose(tok->handle); + } + + if (tok->config) { + ucl_object_unref(tok->config); + } + + g_free(tok->name); + g_free(tok->path); + g_free(tok); + } +} + +static int +rspamd_custom_tokenizer_priority_cmp(gconstpointer a, gconstpointer b) +{ + const struct rspamd_custom_tokenizer *t1 = *(const struct rspamd_custom_tokenizer **) a; + const struct rspamd_custom_tokenizer *t2 = *(const struct rspamd_custom_tokenizer **) b; + + /* Higher priority first */ + if (t1->priority > t2->priority) { + return -1; + } + else if (t1->priority < t2->priority) { + return 1; + } + + return 0; +} + +struct rspamd_tokenizer_manager * +rspamd_tokenizer_manager_new(rspamd_mempool_t *pool) +{ + struct rspamd_tokenizer_manager *mgr; + + mgr = rspamd_mempool_alloc0(pool, sizeof(*mgr)); + mgr->pool = pool; + mgr->tokenizers = g_hash_table_new_full(rspamd_strcase_hash, + rspamd_strcase_equal, + NULL, + rspamd_custom_tokenizer_dtor); + mgr->detection_order = g_array_new(FALSE, FALSE, sizeof(struct rspamd_custom_tokenizer *)); + mgr->default_threshold = 0.7; /* Default confidence threshold */ + + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) g_hash_table_unref, + mgr->tokenizers); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) rspamd_array_free_hard, + mgr->detection_order); + + msg_info_tokenizer("created custom tokenizer manager with default confidence threshold %.3f", + mgr->default_threshold); + + return mgr; +} + +void rspamd_tokenizer_manager_destroy(struct rspamd_tokenizer_manager *mgr) +{ + /* Cleanup is handled by memory pool destructors */ +} + +gboolean 
+rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr, + const char *name, + const ucl_object_t *config, + GError **err) +{ + struct rspamd_custom_tokenizer *tok; + const ucl_object_t *elt; + rspamd_tokenizer_get_api_func get_api; + const rspamd_custom_tokenizer_api_t *api; + void *handle; + const char *path; + gboolean enabled = TRUE; + double priority = 50.0; + char error_buf[256]; + + g_assert(mgr != NULL); + g_assert(name != NULL); + g_assert(config != NULL); + + msg_info_tokenizer("starting to load custom tokenizer '%s'", name); + + /* Check if enabled */ + elt = ucl_object_lookup(config, "enabled"); + if (elt && ucl_object_type(elt) == UCL_BOOLEAN) { + enabled = ucl_object_toboolean(elt); + } + + if (!enabled) { + msg_info_tokenizer("custom tokenizer '%s' is disabled", name); + return TRUE; + } + + /* Get path */ + elt = ucl_object_lookup(config, "path"); + if (!elt || ucl_object_type(elt) != UCL_STRING) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "missing 'path' for tokenizer %s", name); + return FALSE; + } + path = ucl_object_tostring(elt); + msg_info_tokenizer("custom tokenizer '%s' will be loaded from path: %s", name, path); + + /* Get priority */ + elt = ucl_object_lookup(config, "priority"); + if (elt) { + priority = ucl_object_todouble(elt); + } + msg_info_tokenizer("custom tokenizer '%s' priority set to %.1f", name, priority); + + /* Load the shared library */ + msg_info_tokenizer("loading shared library for custom tokenizer '%s'", name); + handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "cannot load tokenizer %s from %s: %s", + name, path, dlerror()); + return FALSE; + } + msg_info_tokenizer("successfully loaded shared library for custom tokenizer '%s'", name); + + /* Get the API entry point */ + msg_info_tokenizer("looking up API entry point for custom tokenizer '%s'", name); + get_api = 
(rspamd_tokenizer_get_api_func) dlsym(handle, "rspamd_tokenizer_get_api"); + if (!get_api) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "cannot find entry point in %s: %s", + path, dlerror()); + return FALSE; + } + + /* Get the API */ + msg_info_tokenizer("calling API entry point for custom tokenizer '%s'", name); + api = get_api(); + if (!api) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "tokenizer %s returned NULL API", name); + return FALSE; + } + msg_info_tokenizer("successfully obtained API from custom tokenizer '%s'", name); + + /* Check API version */ + msg_info_tokenizer("checking API version for custom tokenizer '%s' (got %u, expected %u)", + name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION); + if (api->api_version != RSPAMD_CUSTOM_TOKENIZER_API_VERSION) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "tokenizer %s has incompatible API version %u (expected %u)", + name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION); + return FALSE; + } + + /* Create tokenizer instance */ + tok = g_malloc0(sizeof(*tok)); + tok->name = g_strdup(name); + tok->path = g_strdup(path); + tok->handle = handle; + tok->api = api; + tok->priority = priority; + tok->enabled = enabled; + + /* Get tokenizer config */ + elt = ucl_object_lookup(config, "config"); + if (elt) { + tok->config = ucl_object_ref(elt); + } + + /* Get minimum confidence */ + if (api->get_min_confidence) { + tok->min_confidence = api->get_min_confidence(); + msg_info_tokenizer("custom tokenizer '%s' provides minimum confidence threshold: %.3f", + name, tok->min_confidence); + } + else { + tok->min_confidence = mgr->default_threshold; + msg_info_tokenizer("custom tokenizer '%s' using default confidence threshold: %.3f", + name, tok->min_confidence); + } + + /* Initialize the tokenizer */ + if (api->init) { + msg_info_tokenizer("initializing custom tokenizer 
'%s'", name); + error_buf[0] = '\0'; + if (api->init(tok->config, error_buf, sizeof(error_buf)) != 0) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "failed to initialize tokenizer %s: %s", + name, error_buf[0] ? error_buf : "unknown error"); + rspamd_custom_tokenizer_dtor(tok); + return FALSE; + } + msg_info_tokenizer("successfully initialized custom tokenizer '%s'", name); + } + else { + msg_info_tokenizer("custom tokenizer '%s' does not require initialization", name); + } + + /* Add to manager */ + g_hash_table_insert(mgr->tokenizers, tok->name, tok); + g_array_append_val(mgr->detection_order, tok); + + /* Re-sort by priority */ + g_array_sort(mgr->detection_order, rspamd_custom_tokenizer_priority_cmp); + msg_info_tokenizer("custom tokenizer '%s' registered and sorted by priority (total tokenizers: %u)", + name, mgr->detection_order->len); + + msg_info_tokenizer("successfully loaded custom tokenizer '%s' (priority %.1f) from %s", + name, priority, path); + + return TRUE; +} + +struct rspamd_custom_tokenizer * +rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr, + const char *text, size_t len, + double *confidence, + const char *lang_hint, + const char **detected_lang_hint) +{ + struct rspamd_custom_tokenizer *tok, *best_tok = NULL; + double conf, best_conf = 0.0; + unsigned int i; + + g_assert(mgr != NULL); + g_assert(text != NULL); + + msg_debug_tokenizer("starting tokenizer detection for text of length %zu", len); + + if (confidence) { + *confidence = 0.0; + } + + if (detected_lang_hint) { + *detected_lang_hint = NULL; + } + + /* If we have a language hint, try to find a tokenizer for that language first */ + if (lang_hint) { + msg_info_tokenizer("trying to find tokenizer for language hint: %s", lang_hint); + for (i = 0; i < mgr->detection_order->len; i++) { + tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i); + + if (!tok->enabled || !tok->api->get_language_hint) { + continue; + } + + /* 
Check if this tokenizer handles the hinted language */ + const char *tok_lang = tok->api->get_language_hint(); + if (tok_lang && g_ascii_strcasecmp(tok_lang, lang_hint) == 0) { + msg_info_tokenizer("found tokenizer '%s' for language hint '%s'", tok->name, lang_hint); + /* Found a tokenizer for this language, check if it actually detects it */ + if (tok->api->detect_language) { + conf = tok->api->detect_language(text, len); + msg_info_tokenizer("tokenizer '%s' confidence for hinted language: %.3f (threshold: %.3f)", + tok->name, conf, tok->min_confidence); + if (conf >= tok->min_confidence) { + /* Use this tokenizer */ + msg_info_tokenizer("using tokenizer '%s' for language hint '%s' with confidence %.3f", + tok->name, lang_hint, conf); + if (confidence) { + *confidence = conf; + } + if (detected_lang_hint) { + *detected_lang_hint = tok_lang; + } + return tok; + } + } + } + } + msg_info_tokenizer("no suitable tokenizer found for language hint '%s', falling back to general detection", lang_hint); + } + + /* Try each tokenizer in priority order */ + msg_info_tokenizer("trying %u tokenizers for general detection", mgr->detection_order->len); + for (i = 0; i < mgr->detection_order->len; i++) { + tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i); + + if (!tok->enabled || !tok->api->detect_language) { + msg_debug_tokenizer("skipping tokenizer '%s' (enabled: %s, has detect_language: %s)", + tok->name, tok->enabled ? "yes" : "no", + tok->api->detect_language ? 
"yes" : "no"); + continue; + } + + conf = tok->api->detect_language(text, len); + msg_info_tokenizer("tokenizer '%s' detection confidence: %.3f (threshold: %.3f, current best: %.3f)", + tok->name, conf, tok->min_confidence, best_conf); + + if (conf > best_conf && conf >= tok->min_confidence) { + best_conf = conf; + best_tok = tok; + msg_info_tokenizer("tokenizer '%s' is new best with confidence %.3f", tok->name, best_conf); + + /* Early exit if very confident */ + if (conf >= 0.95) { + msg_info_tokenizer("very high confidence (%.3f >= 0.95), using tokenizer '%s' immediately", + conf, tok->name); + break; + } + } + } + + if (best_tok) { + msg_info_tokenizer("selected tokenizer '%s' with confidence %.3f", best_tok->name, best_conf); + if (confidence) { + *confidence = best_conf; + } + + if (detected_lang_hint && best_tok->api->get_language_hint) { + *detected_lang_hint = best_tok->api->get_language_hint(); + msg_info_tokenizer("detected language hint: %s", *detected_lang_hint); + } + } + else { + msg_info_tokenizer("no suitable tokenizer found during detection"); + } + + return best_tok; +} + +/* Helper function to tokenize with a custom tokenizer handling exceptions */ +rspamd_tokenizer_result_t * +rspamd_custom_tokenizer_tokenize_with_exceptions( + struct rspamd_custom_tokenizer *tokenizer, + const char *text, + gsize len, + GList *exceptions, + rspamd_mempool_t *pool) +{ + rspamd_tokenizer_result_t *words; + rspamd_tokenizer_result_t result; + struct rspamd_process_exception *ex; + GList *cur_ex = exceptions; + gsize pos = 0; + unsigned int i; + int ret; + + /* Allocate result kvec in pool */ + words = rspamd_mempool_alloc(pool, sizeof(*words)); + kv_init(*words); + + /* If no exceptions, tokenize the whole text */ + if (!exceptions) { + kv_init(result); + + ret = tokenizer->api->tokenize(text, len, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result to output */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + 
kv_push(rspamd_word_t, *words, tok); + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + + return words; + } + + /* Process text with exceptions */ + while (pos < len && cur_ex) { + ex = (struct rspamd_process_exception *) cur_ex->data; + + /* Tokenize text before exception */ + if (ex->pos > pos) { + gsize segment_len = ex->pos - pos; + kv_init(result); + + ret = tokenizer->api->tokenize(text + pos, segment_len, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result, adjusting positions for segment offset */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + + /* Adjust pointers to point to the original text */ + gsize offset_in_segment = tok.original.begin - (text + pos); + if (offset_in_segment < segment_len) { + tok.original.begin = text + pos + offset_in_segment; + /* Ensure we don't go past the exception boundary */ + if (tok.original.begin + tok.original.len <= text + ex->pos) { + kv_push(rspamd_word_t, *words, tok); + } + } + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + } + + /* Add exception as a special token */ + rspamd_word_t ex_tok; + memset(&ex_tok, 0, sizeof(ex_tok)); + + if (ex->type == RSPAMD_EXCEPTION_URL) { + ex_tok.original.begin = "!!EX!!"; + ex_tok.original.len = 6; + } + else { + ex_tok.original.begin = text + ex->pos; + ex_tok.original.len = ex->len; + } + ex_tok.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + kv_push(rspamd_word_t, *words, ex_tok); + + /* Move past exception */ + pos = ex->pos + ex->len; + cur_ex = g_list_next(cur_ex); + } + + /* Process remaining text after last exception */ + if (pos < len) { + kv_init(result); + + ret = tokenizer->api->tokenize(text + pos, len - pos, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result, adjusting positions for segment offset */ + for (i = 0; i < 
kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + + /* Adjust pointers to point to the original text */ + gsize offset_in_segment = tok.original.begin - (text + pos); + if (offset_in_segment < (len - pos)) { + tok.original.begin = text + pos + offset_in_segment; + kv_push(rspamd_word_t, *words, tok); + } + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + } + + return words; +} diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c index 0ea1bcfc6..8a9f42992 100644 --- a/src/libstat/tokenizers/tokenizers.c +++ b/src/libstat/tokenizers/tokenizers.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ #include "contrib/mumhash/mum.h" #include "libmime/lang_detection.h" #include "libstemmer.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "custom_tokenizer.h" #include <unicode/utf8.h> #include <unicode/uchar.h> @@ -35,8 +37,8 @@ #include <math.h> -typedef gboolean (*token_get_function)(rspamd_stat_token_t *buf, char const **pos, - rspamd_stat_token_t *token, +typedef gboolean (*token_get_function)(rspamd_word_t *buf, char const **pos, + rspamd_word_t *token, GList **exceptions, gsize *rl, gboolean check_signature); const char t_delimiters[256] = { @@ -69,8 +71,8 @@ const char t_delimiters[256] = { /* Get next word from specified f_str_t buf */ static gboolean -rspamd_tokenizer_get_word_raw(rspamd_stat_token_t *buf, - char const **cur, rspamd_stat_token_t *token, +rspamd_tokenizer_get_word_raw(rspamd_word_t *buf, + char const **cur, rspamd_word_t *token, GList **exceptions, gsize *rl, gboolean unused) { gsize remain, pos; @@ -164,7 +166,7 @@ rspamd_tokenize_check_limit(gboolean decay, unsigned int nwords, uint64_t *hv, uint64_t *prob, - const 
rspamd_stat_token_t *token, + const rspamd_word_t *token, gssize remain, gssize total) { @@ -242,9 +244,9 @@ rspamd_utf_word_valid(const unsigned char *text, const unsigned char *end, } while (0) static inline void -rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) +rspamd_tokenize_exception(struct rspamd_process_exception *ex, rspamd_words_t *res) { - rspamd_stat_token_t token; + rspamd_word_t token; memset(&token, 0, sizeof(token)); @@ -253,7 +255,7 @@ rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) token.original.len = sizeof("!!EX!!") - 1; token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, exception_error); token.flags = 0; } else if (ex->type == RSPAMD_EXCEPTION_URL) { @@ -271,28 +273,33 @@ rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) } token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, exception_error); token.flags = 0; } + return; + +exception_error: + /* On error, just skip this exception token */ + return; } -GArray * +rspamd_words_t * rspamd_tokenize_text(const char *text, gsize len, const UText *utxt, enum rspamd_tokenize_type how, struct rspamd_config *cfg, GList *exceptions, uint64_t *hash, - GArray *cur_words, + rspamd_words_t *output_kvec, rspamd_mempool_t *pool) { - rspamd_stat_token_t token, buf; + rspamd_word_t token, buf; const char *pos = NULL; gsize l = 0; - GArray *res; + rspamd_words_t *res; GList *cur = exceptions; - unsigned int min_len = 0, max_len = 0, word_decay = 0, initial_size = 128; + unsigned int min_len = 0, max_len = 0, word_decay = 0; uint64_t hv = 0; gboolean decay = FALSE, long_text_mode = FALSE; uint64_t prob = 0; @@ -300,9 +307,12 @@ rspamd_tokenize_text(const char *text, gsize len, static const gsize long_text_limit = 1 * 1024 * 1024; static const ev_tstamp max_exec_time = 0.2; /* 200 ms */ 
ev_tstamp start; + struct rspamd_custom_tokenizer *custom_tok = NULL; + double custom_confidence = 0.0; + const char *detected_lang = NULL; if (text == NULL) { - return cur_words; + return output_kvec; } if (len > long_text_limit) { @@ -323,15 +333,59 @@ rspamd_tokenize_text(const char *text, gsize len, min_len = cfg->min_word_len; max_len = cfg->max_word_len; word_decay = cfg->words_decay; - initial_size = word_decay * 2; } - if (!cur_words) { - res = g_array_sized_new(FALSE, FALSE, sizeof(rspamd_stat_token_t), - initial_size); + if (!output_kvec) { + res = pool ? rspamd_mempool_alloc0(pool, sizeof(*res)) : g_malloc0(sizeof(*res)); + ; } else { - res = cur_words; + res = output_kvec; + } + + /* Try custom tokenizers first if we're in UTF mode */ + if (cfg && cfg->tokenizer_manager && how == RSPAMD_TOKENIZE_UTF && utxt != NULL) { + custom_tok = rspamd_tokenizer_manager_detect( + cfg->tokenizer_manager, + text, len, + &custom_confidence, + NULL, /* no input language hint */ + &detected_lang); + + if (custom_tok && custom_confidence >= custom_tok->min_confidence) { + /* Use custom tokenizer with exception handling */ + rspamd_tokenizer_result_t *custom_res = rspamd_custom_tokenizer_tokenize_with_exceptions( + custom_tok, text, len, exceptions, pool); + + if (custom_res) { + msg_debug_pool("using custom tokenizer %s (confidence: %.2f) for text tokenization", + custom_tok->name, custom_confidence); + + /* Copy custom tokenizer results to output kvec */ + for (unsigned int i = 0; i < kv_size(*custom_res); i++) { + kv_push_safe(rspamd_word_t, *res, kv_A(*custom_res, i), custom_tokenizer_error); + } + + /* Calculate hash if needed */ + if (hash && kv_size(*res) > 0) { + for (unsigned int i = 0; i < kv_size(*res); i++) { + rspamd_word_t *t = &kv_A(*res, i); + if (t->original.len >= sizeof(uint64_t)) { + uint64_t tmp; + memcpy(&tmp, t->original.begin, sizeof(tmp)); + hv = mum_hash_step(hv, tmp); + } + } + *hash = mum_hash_finish(hv); + } + + return res; + } + else { + 
msg_warn_pool("custom tokenizer %s failed to tokenize text, falling back to default", + custom_tok->name); + } + } } if (G_UNLIKELY(how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) { @@ -343,7 +397,7 @@ rspamd_tokenize_text(const char *text, gsize len, } if (token.original.len > 0 && - rspamd_tokenize_check_limit(decay, word_decay, res->len, + rspamd_tokenize_check_limit(decay, word_decay, kv_size(*res), &hv, &prob, &token, pos - text, len)) { if (!decay) { decay = TRUE; @@ -355,28 +409,28 @@ rspamd_tokenize_text(const char *text, gsize len, } if (long_text_mode) { - if ((res->len + 1) % 16 == 0) { + if ((kv_size(*res) + 1) % 16 == 0) { ev_tstamp now = ev_time(); if (now - start > max_exec_time) { msg_warn_pool_check( "too long time has been spent on tokenization:" - " %.1f ms, limit is %.1f ms; %d words added so far", + " %.1f ms, limit is %.1f ms; %z words added so far", (now - start) * 1e3, max_exec_time * 1e3, - res->len); + kv_size(*res)); goto end; } } } - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, tokenize_error); - if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + if (kv_size(*res) * sizeof(token) > (0x1ull << 30u)) { /* Due to bug in glib ! */ msg_err_pool_check( - "too many words found: %d, stop tokenization to avoid DoS", - res->len); + "too many words found: %z, stop tokenization to avoid DoS", + kv_size(*res)); goto end; } @@ -523,7 +577,7 @@ rspamd_tokenize_text(const char *text, gsize len, } if (token.original.len > 0 && - rspamd_tokenize_check_limit(decay, word_decay, res->len, + rspamd_tokenize_check_limit(decay, word_decay, kv_size(*res), &hv, &prob, &token, p, len)) { if (!decay) { decay = TRUE; @@ -536,15 +590,15 @@ rspamd_tokenize_text(const char *text, gsize len, if (token.original.len > 0) { /* Additional check for number of words */ - if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + if (kv_size(*res) * sizeof(token) > (0x1ull << 30u)) { /* Due to bug in glib ! 
*/ - msg_err("too many words found: %d, stop tokenization to avoid DoS", - res->len); + msg_err("too many words found: %z, stop tokenization to avoid DoS", + kv_size(*res)); goto end; } - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, tokenize_error); } /* Also check for long text mode */ @@ -552,15 +606,15 @@ rspamd_tokenize_text(const char *text, gsize len, /* Check time each 128 words added */ const int words_check_mask = 0x7F; - if ((res->len & words_check_mask) == words_check_mask) { + if ((kv_size(*res) & words_check_mask) == words_check_mask) { ev_tstamp now = ev_time(); if (now - start > max_exec_time) { msg_warn_pool_check( "too long time has been spent on tokenization:" - " %.1f ms, limit is %.1f ms; %d words added so far", + " %.1f ms, limit is %.1f ms; %z words added so far", (now - start) * 1e3, max_exec_time * 1e3, - res->len); + kv_size(*res)); goto end; } @@ -590,8 +644,14 @@ end: } return res; + +tokenize_error: +custom_tokenizer_error: + msg_err_pool("failed to allocate memory for tokenization"); + return res; } + #undef SHIFT_EX static void @@ -625,32 +685,38 @@ rspamd_add_metawords_from_str(const char *beg, gsize len, #endif } + /* Initialize meta_words kvec if not already done */ + if (!task->meta_words.a) { + kv_init(task->meta_words); + } + if (valid_utf) { utext_openUTF8(&utxt, beg, len, &uc_err); - task->meta_words = rspamd_tokenize_text(beg, len, - &utxt, RSPAMD_TOKENIZE_UTF, - task->cfg, NULL, NULL, - task->meta_words, - task->task_pool); + rspamd_tokenize_text(beg, len, + &utxt, RSPAMD_TOKENIZE_UTF, + task->cfg, NULL, NULL, + &task->meta_words, + task->task_pool); utext_close(&utxt); } else { - task->meta_words = rspamd_tokenize_text(beg, len, - NULL, RSPAMD_TOKENIZE_RAW, - task->cfg, NULL, NULL, task->meta_words, - task->task_pool); + rspamd_tokenize_text(beg, len, + NULL, RSPAMD_TOKENIZE_RAW, + task->cfg, NULL, NULL, + &task->meta_words, + task->task_pool); } } void rspamd_tokenize_meta_words(struct 
rspamd_task *task) { unsigned int i = 0; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; if (MESSAGE_FIELD(task, subject)) { rspamd_add_metawords_from_str(MESSAGE_FIELD(task, subject), @@ -667,7 +733,7 @@ void rspamd_tokenize_meta_words(struct rspamd_task *task) } } - if (task->meta_words != NULL) { + if (task->meta_words.a) { const char *language = NULL; if (MESSAGE_FIELD(task, text_parts) && @@ -680,12 +746,12 @@ void rspamd_tokenize_meta_words(struct rspamd_task *task) } } - rspamd_normalize_words(task->meta_words, task->task_pool); - rspamd_stem_words(task->meta_words, task->task_pool, language, + rspamd_normalize_words(&task->meta_words, task->task_pool); + rspamd_stem_words(&task->meta_words, task->task_pool, language, task->lang_det); - for (i = 0; i < task->meta_words->len; i++) { - tok = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(task->meta_words); i++) { + tok = &kv_A(task->meta_words, i); tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER; } } @@ -759,7 +825,7 @@ rspamd_ucs32_to_normalised(rspamd_stat_token_t *tok, tok->normalized.begin = dest; } -void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool) +void rspamd_normalize_single_word(rspamd_word_t *tok, rspamd_mempool_t *pool) { UErrorCode uc_err = U_ZERO_ERROR; UConverter *utf8_converter; @@ -858,25 +924,27 @@ void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *po } } -void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool) + +void rspamd_normalize_words(rspamd_words_t *words, rspamd_mempool_t *pool) { - rspamd_stat_token_t *tok; + rspamd_word_t *tok; unsigned int i; - for (i = 0; i < words->len; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); rspamd_normalize_single_word(tok, pool); } } -void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, + +void rspamd_stem_words(rspamd_words_t *words, rspamd_mempool_t *pool, 
const char *language, struct rspamd_lang_detector *lang_detector) { static GHashTable *stemmers = NULL; struct sb_stemmer *stem = NULL; unsigned int i; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; char *dest; gsize dlen; @@ -909,8 +977,18 @@ void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, stem = NULL; } } - for (i = 0; i < words->len; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); + + /* Skip stemming if token has already been stemmed by custom tokenizer */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_STEMMED) { + /* Already stemmed, just check for stop words */ + if (tok->stemmed.len > 0 && lang_detector != NULL && + rspamd_language_detector_is_stop_word(lang_detector, tok->stemmed.begin, tok->stemmed.len)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD; + } + continue; + } if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { if (stem) { @@ -952,4 +1030,4 @@ void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, } } } -}
\ No newline at end of file +} diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h index d4a8824a8..bb0bb54e2 100644 --- a/src/libstat/tokenizers/tokenizers.h +++ b/src/libstat/tokenizers/tokenizers.h @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include "fstring.h" #include "rspamd.h" #include "stat_api.h" +#include "libserver/word.h" #include <unicode/utext.h> @@ -43,7 +44,7 @@ struct rspamd_stat_tokenizer { int (*tokenize_func)(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result); @@ -59,20 +60,20 @@ enum rspamd_tokenize_type { int token_node_compare_func(gconstpointer a, gconstpointer b); -/* Tokenize text into array of words (rspamd_stat_token_t type) */ -GArray *rspamd_tokenize_text(const char *text, gsize len, - const UText *utxt, - enum rspamd_tokenize_type how, - struct rspamd_config *cfg, - GList *exceptions, - uint64_t *hash, - GArray *cur_words, - rspamd_mempool_t *pool); +/* Tokenize text into kvec of words (rspamd_word_t type) */ +rspamd_words_t *rspamd_tokenize_text(const char *text, gsize len, + const UText *utxt, + enum rspamd_tokenize_type how, + struct rspamd_config *cfg, + GList *exceptions, + uint64_t *hash, + rspamd_words_t *output_kvec, + rspamd_mempool_t *pool); /* OSB tokenize function */ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result); @@ -83,11 +84,11 @@ gpointer rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool, struct rspamd_lang_detector; -void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool); +void rspamd_normalize_single_word(rspamd_word_t 
*tok, rspamd_mempool_t *pool); -void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool); - -void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, +/* Word processing functions */ +void rspamd_normalize_words(rspamd_words_t *words, rspamd_mempool_t *pool); +void rspamd_stem_words(rspamd_words_t *words, rspamd_mempool_t *pool, const char *language, struct rspamd_lang_detector *lang_detector); diff --git a/src/libutil/cxx/file_util.cxx b/src/libutil/cxx/file_util.cxx index 2f031f076..bc5dcdf3b 100644 --- a/src/libutil/cxx/file_util.cxx +++ b/src/libutil/cxx/file_util.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ #include "file_util.hxx" -#include <fmt/core.h> +#include "contrib/fmt/include/fmt/core.h" #include "libutil/util.h" #include "libutil/unix-std.h" diff --git a/src/libutil/fstring.h b/src/libutil/fstring.h index 0792ab9fa..ca9f689c8 100644 --- a/src/libutil/fstring.h +++ b/src/libutil/fstring.h @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -30,8 +30,8 @@ extern "C" { */ typedef struct f_str_s { - gsize len; - gsize allocated; + size_t len; + size_t allocated; char str[]; } rspamd_fstring_t; @@ -40,12 +40,12 @@ typedef struct f_str_s { #define RSPAMD_FSTRING_LIT(lit) rspamd_fstring_new_init((lit), sizeof(lit) - 1) typedef struct f_str_tok { - gsize len; + size_t len; const char *begin; } rspamd_ftok_t; typedef struct f_str_unicode_tok { - gsize len; /* in UChar32 */ + size_t len; /* in UChar32 */ const UChar32 *begin; } rspamd_ftok_unicode_t; diff --git a/src/libutil/mem_pool.c b/src/libutil/mem_pool.c index 2912d423c..575b4e497 100644 --- a/src/libutil/mem_pool.c +++ b/src/libutil/mem_pool.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -403,9 +403,10 @@ rspamd_mempool_new_(gsize size, const char *tag, int flags, const char *loc) /* Generate new uid */ uint64_t uid = rspamd_random_uint64_fast(); - rspamd_encode_hex_buf((unsigned char *) &uid, sizeof(uid), - new_pool->tag.uid, sizeof(new_pool->tag.uid) - 1); - new_pool->tag.uid[sizeof(new_pool->tag.uid) - 1] = '\0'; + G_STATIC_ASSERT(sizeof(new_pool->tag.uid) >= sizeof(uid) * 2 + 1); + int enc_len = rspamd_encode_hex_buf((unsigned char *) &uid, sizeof(uid), + new_pool->tag.uid, sizeof(new_pool->tag.uid) - 1); + new_pool->tag.uid[enc_len] = '\0'; mem_pool_stat->pools_allocated++; @@ -984,8 +985,7 @@ __mutex_spin(rspamd_mempool_mutex_t *mutex) ts.tv_sec = 0; ts.tv_nsec = MUTEX_SLEEP_TIME; /* Spin */ - while (nanosleep(&ts, &ts) == -1 && errno == EINTR) - ; + while (nanosleep(&ts, &ts) == -1 && errno == EINTR); #else #error No methods to spin are defined #endif @@ -1157,7 +1157,6 @@ void rspamd_mempool_wunlock_rwlock(rspamd_mempool_rwlock_t *lock) } #endif -#define RSPAMD_MEMPOOL_VARS_HASH_SEED 0xb32ad7c55eb2e647ULL void rspamd_mempool_set_variable(rspamd_mempool_t *pool, const char *name, gpointer value, @@ -1175,12 +1174,10 @@ void rspamd_mempool_set_variable(rspamd_mempool_t *pool, } } - int hv = rspamd_cryptobox_fast_hash(name, strlen(name), - RSPAMD_MEMPOOL_VARS_HASH_SEED); khiter_t it; int r; - it = kh_put(rspamd_mempool_vars_hash, pool->priv->variables, hv, &r); + it = kh_put(rspamd_mempool_vars_hash, pool->priv->variables, name, &r); if (it == kh_end(pool->priv->variables)) { g_assert_not_reached(); @@ -1196,6 +1193,10 @@ void rspamd_mempool_set_variable(rspamd_mempool_t *pool, pvar->dtor(pvar->data); } } + else { + /* Store copy of the key to provide persistent storage */ + kh_key(pool->priv->variables, it) = rspamd_mempool_strdup(pool, name); + } pvar = &kh_val(pool->priv->variables, it); pvar->data = value; @@ -1211,10 +1212,8 @@ rspamd_mempool_get_variable(rspamd_mempool_t *pool, const char *name) } khiter_t it; - int hv = 
rspamd_cryptobox_fast_hash(name, strlen(name), - RSPAMD_MEMPOOL_VARS_HASH_SEED); - it = kh_get(rspamd_mempool_vars_hash, pool->priv->variables, hv); + it = kh_get(rspamd_mempool_vars_hash, pool->priv->variables, name); if (it != kh_end(pool->priv->variables)) { struct rspamd_mempool_variable *pvar; @@ -1234,10 +1233,7 @@ rspamd_mempool_steal_variable(rspamd_mempool_t *pool, const char *name) } khiter_t it; - int hv = rspamd_cryptobox_fast_hash(name, strlen(name), - RSPAMD_MEMPOOL_VARS_HASH_SEED); - - it = kh_get(rspamd_mempool_vars_hash, pool->priv->variables, hv); + it = kh_get(rspamd_mempool_vars_hash, pool->priv->variables, name); if (it != kh_end(pool->priv->variables)) { struct rspamd_mempool_variable *pvar; @@ -1255,10 +1251,8 @@ void rspamd_mempool_remove_variable(rspamd_mempool_t *pool, const char *name) { if (pool->priv->variables != NULL) { khiter_t it; - int hv = rspamd_cryptobox_fast_hash(name, strlen(name), - RSPAMD_MEMPOOL_VARS_HASH_SEED); - it = kh_get(rspamd_mempool_vars_hash, pool->priv->variables, hv); + it = kh_get(rspamd_mempool_vars_hash, pool->priv->variables, name); if (it != kh_end(pool->priv->variables)) { struct rspamd_mempool_variable *pvar; diff --git a/src/libutil/mem_pool.h b/src/libutil/mem_pool.h index 651b44661..00d1a2067 100644 --- a/src/libutil/mem_pool.h +++ b/src/libutil/mem_pool.h @@ -71,7 +71,7 @@ struct f_str_s; #endif #define MEMPOOL_TAG_LEN 16 -#define MEMPOOL_UID_LEN 16 +#define MEMPOOL_UID_LEN 32 /* All pointers are aligned as this variable */ #define MIN_MEM_ALIGNMENT G_MEM_ALIGN diff --git a/src/libutil/mem_pool_internal.h b/src/libutil/mem_pool_internal.h index 26a687b6c..2f9ad15b6 100644 --- a/src/libutil/mem_pool_internal.h +++ b/src/libutil/mem_pool_internal.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -63,8 +63,8 @@ struct rspamd_mempool_variable { }; KHASH_INIT(rspamd_mempool_vars_hash, - uint32_t, struct rspamd_mempool_variable, 1, - kh_int_hash_func, kh_int_hash_equal); + const char *, struct rspamd_mempool_variable, 1, + kh_str_hash_func, kh_str_hash_equal); struct rspamd_mempool_specific { struct _pool_chain *pools[RSPAMD_MEMPOOL_MAX]; diff --git a/src/libutil/radix.c b/src/libutil/radix.c index 2cae8e34a..bdd722b49 100644 --- a/src/libutil/radix.c +++ b/src/libutil/radix.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,7 +66,7 @@ radix_find_compressed(radix_compressed_t *tree, const uint8_t *key, gsize keylen uintptr_t radix_insert_compressed(radix_compressed_t *tree, - uint8_t *key, gsize keylen, + const uint8_t *key, gsize keylen, gsize masklen, uintptr_t value) { @@ -128,6 +128,39 @@ radix_insert_compressed(radix_compressed_t *tree, return old; } +uintptr_t +radix_insert_compressed_addr(radix_compressed_t *tree, + const rspamd_inet_addr_t *addr, + uintptr_t value) +{ + const unsigned char *key; + unsigned int klen = 0; + unsigned char buf[16]; + + if (addr == NULL) { + return RADIX_NO_VALUE; + } + + key = rspamd_inet_address_get_hash_key(addr, &klen); + + if (key && klen) { + if (klen == 4) { + /* Map to ipv6 */ + memset(buf, 0, 10); + buf[10] = 0xffu; + buf[11] = 0xffu; + memcpy(buf + 12, key, klen); + + key = buf; + klen = sizeof(buf); + } + + return radix_insert_compressed(tree, key, klen, 0, value); + } + + return RADIX_NO_VALUE; +} + radix_compressed_t * radix_create_compressed(const char *tree_name) diff --git a/src/libutil/radix.h b/src/libutil/radix.h index c4fe96441..8c1224707 100644 --- a/src/libutil/radix.h +++ b/src/libutil/radix.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache 
License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include "mem_pool.h" #include "util.h" -#define RADIX_NO_VALUE (uintptr_t) - 1 +#define RADIX_NO_VALUE (uintptr_t) -1 #ifdef __cplusplus extern "C" { @@ -39,11 +39,23 @@ typedef struct radix_tree_compressed radix_compressed_t; */ uintptr_t radix_insert_compressed(radix_compressed_t *tree, - uint8_t *key, gsize keylen, + const uint8_t *key, gsize keylen, gsize masklen, uintptr_t value); /** + * Insert new address to the radix trie (works for IPv4 or IPv6 addresses) + * @param tree radix trie + * @param addr address to insert + * @param value opaque value pointer + * @return previous value of the key or `RADIX_NO_VALUE` + */ +uintptr_t +radix_insert_compressed_addr(radix_compressed_t *tree, + const rspamd_inet_addr_t *addr, + uintptr_t value); + +/** * Find a key in a radix trie * @param tree radix trie * @param key key to find (bitstring) diff --git a/src/libutil/shingles.c b/src/libutil/shingles.c index 5fe110eb8..c69c42292 100644 --- a/src/libutil/shingles.c +++ b/src/libutil/shingles.c @@ -18,6 +18,7 @@ #include "cryptobox.h" #include "images.h" #include "libstat/stat_api.h" +#include "libserver/word.h" #define SHINGLES_WINDOW 3 #define SHINGLES_KEY_SIZE rspamd_cryptobox_SIPKEYBYTES @@ -112,7 +113,7 @@ rspamd_shingles_get_keys_cached(const unsigned char key[SHINGLES_KEY_SIZE]) } struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") - rspamd_shingles_from_text(GArray *input, + rspamd_shingles_from_text(rspamd_words_t *input, const unsigned char key[16], rspamd_mempool_t *pool, rspamd_shingles_filter filter, @@ -123,12 +124,16 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") uint64_t **hashes; unsigned char **keys; rspamd_fstring_t *row; - rspamd_stat_token_t *word; + rspamd_word_t *word; uint64_t val; int i, j, k; gsize hlen, ilen = 0, beg = 0, widx = 0; enum rspamd_cryptobox_fast_hash_type ht; + if (!input || !input->a) { + return 
NULL; + } + if (pool != NULL) { res = rspamd_mempool_alloc(pool, sizeof(*res)); } @@ -138,10 +143,10 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") row = rspamd_fstring_sized_new(256); - for (i = 0; i < input->len; i++) { - word = &g_array_index(input, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*input); i++) { + word = &kv_A(*input, i); - if (!((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0)) { + if (!((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0)) { ilen++; } } @@ -162,10 +167,10 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") for (j = beg; j < i; j++) { word = NULL; - while (widx < input->len) { - word = &g_array_index(input, rspamd_stat_token_t, widx); + while (widx < kv_size(*input)) { + word = &kv_A(*input, widx); - if ((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0) { + if ((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0) { widx++; } else { @@ -237,10 +242,10 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") word = NULL; - while (widx < input->len) { - word = &g_array_index(input, rspamd_stat_token_t, widx); + while (widx < kv_size(*input)) { + word = &kv_A(*input, widx); - if ((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0) { + if ((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0) { widx++; } else { diff --git a/src/libutil/shingles.h b/src/libutil/shingles.h index fe6f16cf8..1ab2c6842 100644 --- a/src/libutil/shingles.h +++ b/src/libutil/shingles.h @@ -18,6 +18,7 @@ #include "config.h" #include "mem_pool.h" +#include "libserver/word.h" #define RSPAMD_SHINGLE_SIZE 32 @@ -48,14 +49,14 @@ typedef uint64_t (*rspamd_shingles_filter)(uint64_t *input, gsize count, /** * Generate shingles from the input of fixed size strings using lemmatizer * if needed - * @param input array of `rspamd_fstring_t` + * @param input kvec of `rspamd_word_t` * @param key secret key used to generate shingles * 
@param pool pool to allocate shingles array * @param filter hashes filtering function * @param filterd opaque data for filtering function * @return shingles array */ -struct rspamd_shingle *rspamd_shingles_from_text(GArray *input, +struct rspamd_shingle *rspamd_shingles_from_text(rspamd_words_t *input, const unsigned char key[16], rspamd_mempool_t *pool, rspamd_shingles_filter filter, diff --git a/src/lua/CMakeLists.txt b/src/lua/CMakeLists.txt index 46de053ba..135a21da2 100644 --- a/src/lua/CMakeLists.txt +++ b/src/lua/CMakeLists.txt @@ -35,6 +35,7 @@ SET(LUASRC ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c ${CMAKE_CURRENT_SOURCE_DIR}/lua_tensor.c ${CMAKE_CURRENT_SOURCE_DIR}/lua_parsers.c ${CMAKE_CURRENT_SOURCE_DIR}/lua_compress.c - ${CMAKE_CURRENT_SOURCE_DIR}/lua_classnames.c) + ${CMAKE_CURRENT_SOURCE_DIR}/lua_classnames.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_shingles.cxx) SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
\ No newline at end of file diff --git a/src/lua/lua_classnames.c b/src/lua/lua_classnames.c index 7ce2f8abc..2b5a90fe0 100644 --- a/src/lua/lua_classnames.c +++ b/src/lua/lua_classnames.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,6 +67,7 @@ const char *rspamd_url_classname = "rspamd{url}"; const char *rspamd_worker_classname = "rspamd{worker}"; const char *rspamd_zstd_compress_classname = "rspamd{zstd_compress}"; const char *rspamd_zstd_decompress_classname = "rspamd{zstd_decompress}"; +const char *rspamd_shingle_classname = "rspamd{shingle}"; KHASH_INIT(rspamd_lua_static_classes, const char *, const char *, 1, rspamd_str_hash, rspamd_str_equal); @@ -133,6 +134,7 @@ RSPAMD_CONSTRUCTOR(rspamd_lua_init_classnames) CLASS_PUT_STR(worker); CLASS_PUT_STR(zstd_compress); CLASS_PUT_STR(zstd_decompress); + CLASS_PUT_STR(shingle); /* Check consistency */ g_assert(kh_size(lua_static_classes) == RSPAMD_MAX_LUA_CLASSES); diff --git a/src/lua/lua_classnames.h b/src/lua/lua_classnames.h index 53db5f8c2..6e3a6441f 100644 --- a/src/lua/lua_classnames.h +++ b/src/lua/lua_classnames.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -70,9 +70,10 @@ extern const char *rspamd_url_classname; extern const char *rspamd_worker_classname; extern const char *rspamd_zstd_compress_classname; extern const char *rspamd_zstd_decompress_classname; +extern const char *rspamd_shingle_classname; /* Keep it consistent when adding new classes */ -#define RSPAMD_MAX_LUA_CLASSES 48 +#define RSPAMD_MAX_LUA_CLASSES 49 /* * Return a static class name for a given name (only for known classes) or NULL diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index d79efc308..f36228680 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -985,6 +985,7 @@ rspamd_lua_init(bool wipe_mem) luaopen_tensor(L); luaopen_parsers(L); luaopen_compress(L); + luaopen_shingle(L); #ifndef WITH_LUAJIT rspamd_lua_add_preload(L, "bit", luaopen_bit); lua_settop(L, 0); @@ -2400,7 +2401,7 @@ rspamd_lua_try_load_redis(lua_State *L, const ucl_object_t *obj, return FALSE; } -void rspamd_lua_push_full_word(lua_State *L, rspamd_stat_token_t *w) +void rspamd_lua_push_full_word(lua_State *L, rspamd_word_t *w) { int fl_cnt; @@ -2520,6 +2521,54 @@ int rspamd_lua_push_words(lua_State *L, GArray *words, return 1; } +int rspamd_lua_push_words_kvec(lua_State *L, rspamd_words_t *words, + enum rspamd_lua_words_type how) +{ + rspamd_word_t *w; + unsigned int i, cnt; + + if (!words || !words->a) { + lua_createtable(L, 0, 0); + return 1; + } + + lua_createtable(L, kv_size(*words), 0); + + for (i = 0, cnt = 1; i < kv_size(*words); i++) { + w = &kv_A(*words, i); + + switch (how) { + case RSPAMD_LUA_WORDS_STEM: + if (w->stemmed.len > 0) { + lua_pushlstring(L, w->stemmed.begin, w->stemmed.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_NORM: + if (w->normalized.len > 0) { + lua_pushlstring(L, 
w->normalized.begin, w->normalized.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_RAW: + if (w->original.len > 0) { + lua_pushlstring(L, w->original.begin, w->original.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_FULL: + rspamd_lua_push_full_word(L, w); + /* Push to the resulting vector */ + lua_rawseti(L, -2, cnt++); + break; + default: + break; + } + } + + return 1; +} + char * rspamd_lua_get_module_name(lua_State *L) { @@ -2657,4 +2706,4 @@ int rspamd_lua_geti(lua_State *L, int pos, int i) return lua_type(L, -1); } -#endif
\ No newline at end of file +#endif diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index 1d39d0c52..d494f0923 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -421,6 +421,8 @@ void luaopen_tensor(lua_State *L); void luaopen_parsers(lua_State *L); +void luaopen_shingle(lua_State *L); + void rspamd_lua_dostring(const char *line); double rspamd_lua_normalize(struct rspamd_config *cfg, @@ -454,6 +456,12 @@ struct rspamd_dns_resolver *lua_check_dns_resolver(lua_State *L, int pos); struct rspamd_lua_url *lua_check_url(lua_State *L, int pos); +/** + * Creates a new shingle object from the existing shingle + */ +struct rspamd_shingle; +void lua_newshingle(lua_State *L, const void *sh); + enum rspamd_lua_parse_arguments_flags { RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT = 0, RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING, @@ -530,9 +538,8 @@ enum lua_logger_escape_type { * @param len * @return */ -gsize lua_logger_out_type(lua_State *L, int pos, char *outbuf, - gsize len, struct lua_logger_trace *trace, - enum lua_logger_escape_type esc_type); +gsize lua_logger_out(lua_State *L, int pos, char *outbuf, gsize len, + enum lua_logger_escape_type esc_type); /** * Safely checks userdata to match specified class @@ -625,7 +632,7 @@ struct rspamd_stat_token_s; * @param L * @param word */ -void rspamd_lua_push_full_word(lua_State *L, struct rspamd_stat_token_s *word); +void rspamd_lua_push_full_word(lua_State *L, rspamd_word_t *word); enum rspamd_lua_words_type { RSPAMD_LUA_WORDS_STEM = 0, @@ -644,6 +651,9 @@ enum rspamd_lua_words_type { int rspamd_lua_push_words(lua_State *L, GArray *words, enum rspamd_lua_words_type how); +int rspamd_lua_push_words_kvec(lua_State *L, rspamd_words_t *words, + enum rspamd_lua_words_type how); + /** * Returns newly 
allocated name for caller module name * @param L diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 0b4d208b4..7b3a156cd 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,10 @@ #include "utlist.h" #include <math.h> +/* Forward declarations for custom tokenizer functions */ +gboolean rspamd_config_load_custom_tokenizers(struct rspamd_config *cfg, GError **err); +void rspamd_config_unload_custom_tokenizers(struct rspamd_config *cfg); + /*** * This module is used to configure rspamd and is normally available as global * variable named `rspamd_config`. Unlike other modules, it is not necessary to @@ -73,7 +77,7 @@ LUA_FUNCTION_DEF(config, get_ucl); /*** * @method rspamd_config:get_mempool() * Returns static configuration memory pool. - * @return {mempool} [memory pool](mempool.md) object + * @return {mempool} [memory pool](rspamd_mempool.md) object */ LUA_FUNCTION_DEF(config, get_mempool); /*** @@ -118,7 +122,7 @@ local function foo(task) end */ /*** -* @method rspamd_config:radix_from_ucl(obj) +* @method rspamd_config:radix_from_ucl(obj, description) * Creates new embedded map of IP/mask addresses from object. 
* @param {ucl} obj object * @return {map} radix tree object @@ -862,6 +866,19 @@ LUA_FUNCTION_DEF(config, get_dns_max_requests); */ LUA_FUNCTION_DEF(config, get_dns_timeout); +/*** + * @method rspamd_config:load_custom_tokenizers() + * Loads custom tokenizers from configuration + * @return {boolean} true if successful + */ +LUA_FUNCTION_DEF(config, load_custom_tokenizers); + +/*** + * @method rspamd_config:unload_custom_tokenizers() + * Unloads custom tokenizers and frees memory + */ +LUA_FUNCTION_DEF(config, unload_custom_tokenizers); + static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF(config, get_module_opt), LUA_INTERFACE_DEF(config, get_mempool), @@ -937,6 +954,8 @@ static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF(config, get_tld_path), LUA_INTERFACE_DEF(config, get_dns_max_requests), LUA_INTERFACE_DEF(config, get_dns_timeout), + LUA_INTERFACE_DEF(config, load_custom_tokenizers), + LUA_INTERFACE_DEF(config, unload_custom_tokenizers), {"__tostring", rspamd_lua_class_tostring}, {"__newindex", lua_config_newindex}, {NULL, NULL}}; @@ -4485,11 +4504,14 @@ lua_config_init_subsystem(lua_State *L) nparts = g_strv_length(parts); for (i = 0; i < nparts; i++) { - if (strcmp(parts[i], "filters") == 0) { + const char *str = parts[i]; + + /* TODO: total shit, rework some day */ + if (strcmp(str, "filters") == 0) { rspamd_lua_post_load_config(cfg); rspamd_init_filters(cfg, false, false); } - else if (strcmp(parts[i], "langdet") == 0) { + else if (strcmp(str, "langdet") == 0) { if (!cfg->lang_det) { cfg->lang_det = rspamd_language_detector_init(cfg); rspamd_mempool_add_destructor(cfg->cfg_pool, @@ -4497,10 +4519,10 @@ lua_config_init_subsystem(lua_State *L) cfg->lang_det); } } - else if (strcmp(parts[i], "stat") == 0) { + else if (strcmp(str, "stat") == 0) { rspamd_stat_init(cfg, NULL); } - else if (strcmp(parts[i], "dns") == 0) { + else if (strcmp(str, "dns") == 0) { struct ev_loop *ev_base = lua_check_ev_base(L, 3); if (ev_base) { @@ -4514,11 
+4536,25 @@ lua_config_init_subsystem(lua_State *L) return luaL_error(L, "no event base specified"); } } - else if (strcmp(parts[i], "symcache") == 0) { + else if (strcmp(str, "symcache") == 0) { rspamd_symcache_init(cfg->cache); } + else if (strcmp(str, "tokenizers") == 0 || strcmp(str, "custom_tokenizers") == 0) { + GError *err = NULL; + if (!rspamd_config_load_custom_tokenizers(cfg, &err)) { + g_strfreev(parts); + if (err) { + int ret = luaL_error(L, "failed to load custom tokenizers: %s", err->message); + g_error_free(err); + return ret; + } + else { + return luaL_error(L, "failed to load custom tokenizers"); + } + } + } else { - int ret = luaL_error(L, "invalid param: %s", parts[i]); + int ret = luaL_error(L, "invalid param: %s", str); g_strfreev(parts); return ret; @@ -4772,3 +4808,43 @@ void lua_call_finish_script(struct rspamd_config_cfg_lua_script *sc, lua_thread_call(thread, 1); } + +static int +lua_config_load_custom_tokenizers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + GError *err = NULL; + gboolean ret = rspamd_config_load_custom_tokenizers(cfg, &err); + + if (!ret && err) { + lua_pushboolean(L, FALSE); + lua_pushstring(L, err->message); + g_error_free(err); + return 2; + } + + lua_pushboolean(L, ret); + return 1; + } + else { + return luaL_error(L, "invalid arguments"); + } +} + +static int +lua_config_unload_custom_tokenizers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + rspamd_config_unload_custom_tokenizers(cfg); + return 0; + } + else { + return luaL_error(L, "invalid arguments"); + } +} diff --git a/src/lua/lua_cryptobox.c b/src/lua/lua_cryptobox.c index b562c4778..2c2254920 100644 --- a/src/lua/lua_cryptobox.c +++ b/src/lua/lua_cryptobox.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may 
not use this file except in compliance with the License. @@ -121,6 +121,8 @@ LUA_FUNCTION_DEF(cryptobox_secretbox, encrypt); LUA_FUNCTION_DEF(cryptobox_secretbox, decrypt); LUA_FUNCTION_DEF(cryptobox_secretbox, gc); +static void lua_cryptobox_hash_finish(struct rspamd_lua_cryptobox_hash *h); + static const struct luaL_reg cryptoboxlib_f[] = { LUA_INTERFACE_DEF(cryptobox, verify_memory), LUA_INTERFACE_DEF(cryptobox, verify_file), @@ -402,7 +404,7 @@ lua_cryptobox_keypair_load(lua_State *L) if (lua_type(L, 1) == LUA_TSTRING) { buf = luaL_checklstring(L, 1, &len); if (buf != NULL) { - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, buf, len)) { msg_err("cannot open keypair from data: %s", @@ -1306,6 +1308,19 @@ lua_cryptobox_hash_create_keyed(lua_State *L) return 1; } +struct lua_hash_elt { + unsigned char key_hash[rspamd_cryptobox_HASHBYTES]; + unsigned char value_hash[rspamd_cryptobox_HASHBYTES]; +}; + +int lua_cryptobox_hash_elt_cmp(const void *a, const void *b) +{ + const struct lua_hash_elt *ha = (struct lua_hash_elt *) a, + *hb = (struct lua_hash_elt *) b; + + return memcmp(ha->key_hash, hb->key_hash, sizeof(ha->key_hash)); +} + /*** * @function rspamd_cryptobox_hash.create_specific_keyed(key, type, [string]) * Creates new hash context with specified key @@ -1362,13 +1377,60 @@ lua_cryptobox_hash_create_specific_keyed(lua_State *L) return 1; } +static struct rspamd_lua_cryptobox_hash * +lua_cryptobox_hash_copy(const struct rspamd_lua_cryptobox_hash *orig) +{ + struct rspamd_lua_cryptobox_hash *nhash = g_malloc(sizeof(struct rspamd_lua_cryptobox_hash)); + + memcpy(nhash, orig, sizeof(struct rspamd_lua_cryptobox_hash)); + REF_INIT_RETAIN(nhash, lua_cryptobox_hash_dtor); + + if (orig->type == LUA_CRYPTOBOX_HASH_SSL) { + EVP_MD_CTX_copy(nhash->content.c, orig->content.c); + } + else if (orig->type == LUA_CRYPTOBOX_HASH_HMAC) { +#if OPENSSL_VERSION_NUMBER < 0x10100000L || \ + 
(defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x30500000) + /* XXX: dunno what to do with this ancient crap */ +#else +#if OPENSSL_VERSION_NUMBER >= 0x30000000L + nhash->content.hmac_c = EVP_MAC_CTX_dup(orig->content.hmac_c); +#else + nhash->content.hmac_c = HMAC_CTX_new(); + HMAC_CTX_copy(nhash->content.hmac_c, orig->content.hmac_c); +#endif +#endif + } + else if (orig->type == LUA_CRYPTOBOX_HASH_BLAKE2) { + if (posix_memalign((void **) &nhash->content.h, + RSPAMD_ALIGNOF(rspamd_cryptobox_hash_state_t), + sizeof(*nhash->content.h)) != 0) { + g_assert_not_reached(); + } + memcpy(nhash->content.h, orig->content.h, sizeof(*nhash->content.h)); + } + else { + nhash->content.fh = rspamd_cryptobox_fast_hash_new(); + memcpy(nhash->content.fh, orig->content.fh, sizeof(*nhash->content.fh)); + } + + return nhash; +} + +#define MAX_HASH_UPDATE_REC 16 + static void -lua_cryptobox_update_pos(lua_State *L, struct rspamd_lua_cryptobox_hash *h, int pos) +lua_cryptobox_update_pos(lua_State *L, struct rspamd_lua_cryptobox_hash *h, int pos, int rec) { const char *data; struct rspamd_lua_text *t; gsize len; + if (rec > MAX_HASH_UPDATE_REC) { + /* Max recursion is reached, do nothing */ + return; + } + /* Inverse pos if it is relative to the top of the stack */ if (pos < 0) { pos = lua_gettop(L) + pos + 1; @@ -1412,22 +1474,44 @@ lua_cryptobox_update_pos(lua_State *L, struct rspamd_lua_cryptobox_hash *h, int for (gsize i = 1; i <= alen; i++) { lua_rawgeti(L, pos, i); - lua_cryptobox_update_pos(L, h, -1); /* Recurse */ + lua_cryptobox_update_pos(L, h, -1, rec + 1); /* Recurse */ lua_pop(L, 1); } - /* Hash key-value pairs */ + /* Hash key-value pairs and store all stuff in the array */ lua_pushnil(L); + GArray *tbl_digests = g_array_new(false, true, sizeof(struct lua_hash_elt)); while (lua_next(L, pos) != 0) { + struct rspamd_lua_cryptobox_hash *key_h = lua_cryptobox_hash_copy(h), + *value_h = lua_cryptobox_hash_copy(h); + struct lua_hash_elt he; /* Hash key */ 
lua_pushvalue(L, -2); - lua_cryptobox_update_pos(L, h, -1); + lua_cryptobox_update_pos(L, key_h, -1, rec + 1); lua_pop(L, 1); + lua_cryptobox_hash_finish(key_h); + memcpy(he.key_hash, key_h->out, sizeof(he.key_hash)); + REF_RELEASE(key_h); /* Hash value */ - lua_cryptobox_update_pos(L, h, -1); + lua_cryptobox_update_pos(L, value_h, -1, rec + 1); lua_pop(L, 1); + lua_cryptobox_hash_finish(value_h); + memcpy(he.value_hash, value_h->out, sizeof(he.value_hash)); + REF_RELEASE(value_h); + + g_array_append_val(tbl_digests, he); + } + + /* Sort elements */ + g_array_sort(tbl_digests, lua_cryptobox_hash_elt_cmp); + /* Now update the original hash context */ + for (size_t i = 0; i < tbl_digests->len; i++) { + struct lua_hash_elt *he = &g_array_index(tbl_digests, struct lua_hash_elt, i); + rspamd_lua_hash_update(h, he->key_hash, sizeof(he->key_hash)); + rspamd_lua_hash_update(h, he->value_hash, sizeof(he->value_hash)); } + g_array_free(tbl_digests, true); break; } @@ -1464,7 +1548,7 @@ lua_cryptobox_hash_update(lua_State *L) return luaL_error(L, "invalid arguments or hash is already finalized"); } - lua_cryptobox_update_pos(L, h, 2); + lua_cryptobox_update_pos(L, h, 2, 0); ph = lua_newuserdata(L, sizeof(void *)); *ph = h; diff --git a/src/lua/lua_http.c b/src/lua/lua_http.c index 904f1cbbf..731b8b057 100644 --- a/src/lua/lua_http.c +++ b/src/lua/lua_http.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,22 +29,123 @@ * This module hides all complexity: DNS resolving, sessions management, zero-copy * text transfers and so on under the hood. 
* @example +-- Basic GET request with callback local rspamd_http = require "rspamd_http" local function symbol_callback(task) local function http_callback(err_message, code, body, headers) task:insert_result('SYMBOL', 1) -- task is available via closure + + if err_message then + -- Handle error + return + end + + -- Process response + if code == 200 then + -- Process body and headers + for name, value in pairs(headers) do + -- Headers are lowercase + end + end end - rspamd_http.request({ - task=task, - url='http://example.com/data', - body=task:get_content(), - callback=http_callback, - headers={Header='Value', OtherHeader='Value'}, - mime_type='text/plain', - }) - end + rspamd_http.request({ + task=task, + url='http://example.com/data', + body=task:get_content(), + callback=http_callback, + headers={Header='Value', OtherHeader='Value', DuplicatedHeader={'Multiple', 'Values'}}, + mime_type='text/plain', + }) +end + +-- POST request with JSON body +local function post_json_example(task) + local ucl = require "ucl" + local data = { + id = task:get_queue_id(), + sender = task:get_from()[1].addr + } + + local json_data = ucl.to_json(data) + + rspamd_http.request({ + task = task, + url = "http://example.com/api/submit", + method = "POST", + body = json_data, + headers = {['Content-Type'] = 'application/json'}, + callback = function(err, code, body, headers) + if not err and code == 200 then + -- Success + end + end + }) +end + +-- Synchronous HTTP request (using coroutines) +local function sync_http_example(task) + -- No callback makes this a synchronous call + local err, response = rspamd_http.request({ + task = task, + url = "http://example.com/api/data", + method = "GET", + timeout = 10.0 + }) + + if not err then + -- Response is a table with code, content, and headers + if response.code == 200 then + -- Process response.content + return true + end + end + return false +end + +-- Using authentication +local function auth_example(task) + rspamd_http.request({ + task = 
task, + url = "https://example.com/api/protected", + method = "GET", + user = "username", + password = "secret", + callback = function(err, code, body, headers) + -- Process authenticated response + end + }) +end + +-- Using HTTPS with SSL options +local function https_example(task) + rspamd_http.request({ + task = task, + url = "https://example.com/api/secure", + method = "GET", + no_ssl_verify = false, -- Verify SSL (default) + callback = function(err, code, body, headers) + -- Process secure response + end + }) +end + +-- Using keep-alive and gzip +local function advanced_example(task) + rspamd_http.request({ + task = task, + url = "http://example.com/api/data", + method = "POST", + body = task:get_content(), + gzip = true, -- Compress request body + keepalive = true, -- Use keep-alive connection + max_size = 1024 * 1024, -- Limit response to 1MB + callback = function(err, code, body, headers) + -- Process response + end + }) +end */ #define MAX_HEADERS_SIZE 8192 @@ -600,9 +701,9 @@ lua_http_push_headers(lua_State *L, struct rspamd_http_message *msg) * - `config` * * @param {string} url specifies URL for a request in the standard URI form (e.g. 'http://example.com/path') - * @param {function} callback specifies callback function in format `function (err_message, code, body, headers)` that is called on HTTP request completion. if this parameter is missing, the function performs "pseudo-synchronous" call (see [Synchronous and Asynchronous API overview](/doc/lua/sync_async.html#API-example-http-module) + * @param {function} callback specifies callback function in format `function (err_message, code, body, headers)` that is called on HTTP request completion. 
if this parameter is missing, the function performs "pseudo-synchronous" call (see [Synchronous and Asynchronous API overview](/doc/developers/sync_async.html#API-example-http-module) * @param {task} task if called from symbol handler it is generally a good idea to use the common task objects: event base, DNS resolver and events session - * @param {table} headers optional headers in form `[name='value', name='value']` + * @param {table} headers optional headers in form `[name='value']` or `[name=['value1', 'value2']]` to duplicate a header with multiple values * @param {string} mime_type MIME type of the HTTP content (for example, `text/html`) * @param {string/text} body full body content, can be opaque `rspamd{text}` to avoid data copying * @param {number} timeout floating point request timeout value in seconds (default is 5.0 seconds) diff --git a/src/lua/lua_logger.c b/src/lua/lua_logger.c index 21aad8136..04ff81b6d 100644 --- a/src/lua/lua_logger.c +++ b/src/lua/lua_logger.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -174,6 +174,11 @@ static const struct luaL_reg loggerlib_f[] = { {"__tostring", rspamd_lua_class_tostring}, {NULL, NULL}}; +static gsize +lua_logger_out_type(lua_State *L, int pos, char *outbuf, + gsize len, struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type); + static void lua_common_log_line(GLogLevelFlags level, lua_State *L, @@ -203,23 +208,19 @@ lua_common_log_line(GLogLevelFlags level, d.currentline); } - rspamd_common_log_function(NULL, - level, - module, - uid, - func_buf, - "%s", - msg); + p = func_buf; } else { - rspamd_common_log_function(NULL, - level, - module, - uid, - G_STRFUNC, - "%s", - msg); + p = (char *) G_STRFUNC; } + + rspamd_common_log_function(NULL, + level, + module, + uid, + p, + "%s", + msg); } /*** Logger interface ***/ @@ -279,105 +280,161 @@ lua_logger_char_safe(int t, unsigned int esc_type) return true; } -static gsize -lua_logger_out_str(lua_State *L, int pos, - char *outbuf, gsize len, - struct lua_logger_trace *trace, - enum lua_logger_escape_type esc_type) +#define LUA_MAX_ARGS 32 +/* Gracefully handles argument mismatches by substituting missing args and noting extra args */ +static glong +lua_logger_log_format_str(lua_State *L, int offset, char *logbuf, gsize remain, + const char *fmt, + enum lua_logger_escape_type esc_type) { - gsize slen, flen; - const char *str = lua_tolstring(L, pos, &slen); - static const char hexdigests[16] = "0123456789abcdef"; - gsize r = 0, s; - - if (str) { - gboolean normal = TRUE; - flen = MIN(slen, len - 1); + const char *c; + gsize r; + int digit; + char *d = logbuf; + unsigned int arg_num, cur_arg = 0, arg_max = lua_gettop(L) - offset; + gboolean args_used[LUA_MAX_ARGS]; + unsigned int used_args_count = 0; + + memset(args_used, 0, sizeof(args_used)); + while (remain > 1 && *fmt) { + if (*fmt == '%') { + ++fmt; + c = fmt; + if (*fmt == 's') { + ++fmt; + ++cur_arg; + } + else { + arg_num = 0; + while ((digit = g_ascii_digit_value(*fmt)) >= 0) { + ++fmt; + arg_num = arg_num * 
10 + digit; + if (arg_num >= LUA_MAX_ARGS) { + /* Avoid ridiculously large numbers */ + fmt = c; + break; + } + } - for (r = 0; r < flen; r++) { - if (!lua_logger_char_safe(str[r], esc_type)) { - normal = FALSE; - break; + if (fmt > c) { + /* Update the current argument */ + cur_arg = arg_num; + } } - } - if (normal) { - r = rspamd_strlcpy(outbuf, str, flen + 1); - } - else { - /* Need to escape non-printed characters */ - r = 0; - s = 0; - - while (slen > 0 && len > 1) { - if (!lua_logger_char_safe(str[s], esc_type)) { - if (len >= 3) { - outbuf[r++] = '\\'; - outbuf[r++] = hexdigests[((str[s] >> 4) & 0xF)]; - outbuf[r++] = hexdigests[((str[s]) & 0xF)]; - - len -= 2; - } - else { - outbuf[r++] = '?'; - } + if (fmt > c) { + if (cur_arg < 1 || cur_arg > arg_max) { + /* Missing argument - substitute placeholder */ + r = rspamd_snprintf(d, remain, "<MISSING ARGUMENT>"); } else { - outbuf[r++] = str[s]; + /* Valid argument - output it */ + r = lua_logger_out(L, offset + cur_arg, d, remain, esc_type); + /* Track which arguments are used */ + if (cur_arg <= LUA_MAX_ARGS && !args_used[cur_arg - 1]) { + args_used[cur_arg - 1] = TRUE; + used_args_count++; + } } - s++; - slen--; - len--; + g_assert(r < remain); + remain -= r; + d += r; + continue; } - outbuf[r] = '\0'; + /* Copy % */ + --fmt; } + + *d++ = *fmt++; + --remain; } - return r; + /* Check for extra arguments and append warning if any */ + if (used_args_count > 0 && used_args_count < arg_max && remain > 1) { + unsigned int extra_args = arg_max - used_args_count; + r = rspamd_snprintf(d, remain, " <EXTRA %d ARGUMENTS>", (int) extra_args); + remain -= r; + d += r; + } + + *d = 0; + + return d - logbuf; } +#undef LUA_MAX_ARGS + static gsize -lua_logger_out_num(lua_State *L, int pos, char *outbuf, gsize len, - struct lua_logger_trace *trace) +lua_logger_out_str(lua_State *L, int pos, + char *outbuf, gsize len, + enum lua_logger_escape_type esc_type) { - double num = lua_tonumber(L, pos); - glong inum; - gsize r = 0; + 
static const char hexdigests[16] = "0123456789abcdef"; + gsize slen; + const unsigned char *str = lua_tolstring(L, pos, &slen); + unsigned char c; + char *out = outbuf; - if ((double) (glong) num == num) { - inum = num; - r = rspamd_snprintf(outbuf, len + 1, "%l", inum); + if (str) { + while (slen > 0 && len > 1) { + c = *str++; + if (lua_logger_char_safe(c, esc_type)) { + *out++ = c; + } + else if (len > 3) { + /* Need to escape non-printed characters */ + *out++ = '\\'; + *out++ = hexdigests[c >> 4]; + *out++ = hexdigests[c & 0xF]; + len -= 2; + } + else { + *out++ = '?'; + } + --slen; + --len; + } } - else { - r = rspamd_snprintf(outbuf, len + 1, "%f", num); + *out = 0; + + return out - outbuf; +} + +static gsize +lua_logger_out_num(lua_State *L, int pos, char *outbuf, gsize len) +{ + double num = lua_tonumber(L, pos); + glong inum = (glong) num; + + if ((double) inum == num) { + return rspamd_snprintf(outbuf, len, "%l", inum); } - return r; + return rspamd_snprintf(outbuf, len, "%f", num); } static gsize -lua_logger_out_boolean(lua_State *L, int pos, char *outbuf, gsize len, - struct lua_logger_trace *trace) +lua_logger_out_boolean(lua_State *L, int pos, char *outbuf, gsize len) { gboolean val = lua_toboolean(L, pos); - gsize r = 0; - - r = rspamd_strlcpy(outbuf, val ? "true" : "false", len + 1); - return r; + return rspamd_snprintf(outbuf, len, val ? 
"true" : "false"); } static gsize -lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len, - struct lua_logger_trace *trace) +lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len) { - int r = 0, top; + gsize r = 0; + int top; const char *str = NULL; gboolean converted_to_str = FALSE; top = lua_gettop(L); + if (pos < 0) { + pos += top + 1; /* Convert to absolute */ + } if (!lua_getmetatable(L, pos)) { return 0; @@ -396,26 +453,17 @@ lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len, if (lua_isfunction(L, -1)) { lua_pushvalue(L, pos); - if (lua_pcall(L, 1, 1, 0) != 0) { - lua_settop(L, top); - - return 0; - } - - str = lua_tostring(L, -1); - - if (str) { - r = rspamd_snprintf(outbuf, len, "%s", str); + if (lua_pcall(L, 1, 1, 0) == 0) { + str = lua_tostring(L, -1); + if (str) { + r = rspamd_snprintf(outbuf, len, "%s", str); + } } - - lua_settop(L, top); - - return r; } } lua_settop(L, top); - return 0; + return r; } lua_pushstring(L, "__tostring"); @@ -460,12 +508,12 @@ lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len, return r; } -#define MOVE_BUF(d, remain, r) \ - (d) += (r); \ - (remain) -= (r); \ - if ((remain) == 0) { \ - lua_settop(L, old_top); \ - break; \ +#define MOVE_BUF(d, remain, r) \ + (d) += (r); \ + (remain) -= (r); \ + if ((remain) <= 1) { \ + lua_settop(L, top); \ + goto table_oob; \ } static gsize @@ -473,169 +521,153 @@ lua_logger_out_table(lua_State *L, int pos, char *outbuf, gsize len, struct lua_logger_trace *trace, enum lua_logger_escape_type esc_type) { - char *d = outbuf; - gsize remain = len, r; + char *d = outbuf, *str; + gsize remain = len; + glong r; gboolean first = TRUE; gconstpointer self = NULL; - int i, tpos, last_seq = -1, old_top; + int i, last_seq = 0, top; + double num; + glong inum; - if (!lua_istable(L, pos) || remain == 0) { - return 0; - } + /* Type and length checks are done in logger_out_type() */ - old_top = lua_gettop(L); self = lua_topointer(L, 
pos); /* Check if we have seen this pointer */ for (i = 0; i < TRACE_POINTS; i++) { if (trace->traces[i] == self) { - r = rspamd_snprintf(d, remain + 1, "ref(%p)", self); - - d += r; - - return (d - outbuf); + if ((trace->cur_level + TRACE_POINTS - 1) % TRACE_POINTS == i) { + return rspamd_snprintf(d, remain, "__self"); + } + return rspamd_snprintf(d, remain, "ref(%p)", self); } } trace->traces[trace->cur_level % TRACE_POINTS] = self; + ++trace->cur_level; - lua_pushvalue(L, pos); - r = rspamd_snprintf(d, remain + 1, "{"); - remain -= r; - d += r; + top = lua_gettop(L); + if (pos < 0) { + pos += top + 1; /* Convert to absolute */ + } + + r = rspamd_snprintf(d, remain, "{"); + MOVE_BUF(d, remain, r); /* Get numeric keys (ipairs) */ for (i = 1;; i++) { - lua_rawgeti(L, -1, i); + lua_rawgeti(L, pos, i); if (lua_isnil(L, -1)) { lua_pop(L, 1); + last_seq = i; break; } - last_seq = i; - - if (!first) { - r = rspamd_snprintf(d, remain + 1, ", "); - MOVE_BUF(d, remain, r); - } - - r = rspamd_snprintf(d, remain + 1, "[%d] = ", i); - MOVE_BUF(d, remain, r); - tpos = lua_gettop(L); - - if (lua_topointer(L, tpos) == self) { - r = rspamd_snprintf(d, remain + 1, "__self"); + if (first) { + first = FALSE; + str = "[%d] = "; } else { - r = lua_logger_out_type(L, tpos, d, remain, trace, esc_type); + str = ", [%d] = "; } + r = rspamd_snprintf(d, remain, str, i); + MOVE_BUF(d, remain, r); + + r = lua_logger_out_type(L, -1, d, remain, trace, esc_type); MOVE_BUF(d, remain, r); - first = FALSE; lua_pop(L, 1); } /* Get string keys (pairs) */ - for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + for (lua_pushnil(L); lua_next(L, pos); lua_pop(L, 1)) { /* 'key' is at index -2 and 'value' is at index -1 */ - if (lua_type(L, -2) == LUA_TNUMBER) { - if (last_seq > 0) { - lua_pushvalue(L, -2); - if (lua_tonumber(L, -1) <= last_seq + 1) { - lua_pop(L, 1); + /* Preserve key */ + lua_pushvalue(L, -2); + if (last_seq > 0) { + if (lua_type(L, -1) == LUA_TNUMBER) { + num = lua_tonumber(L, -1); 
/* no conversion here */ + inum = (glong) num; + if ((double) inum == num && inum > 0 && inum < last_seq) { /* Already seen */ + lua_pop(L, 1); continue; } - - lua_pop(L, 1); } } - if (!first) { - r = rspamd_snprintf(d, remain + 1, ", "); - MOVE_BUF(d, remain, r); - } - - /* Preserve key */ - lua_pushvalue(L, -2); - r = rspamd_snprintf(d, remain + 1, "[%s] = ", - lua_tostring(L, -1)); - lua_pop(L, 1); /* Remove key */ - MOVE_BUF(d, remain, r); - tpos = lua_gettop(L); - - if (lua_topointer(L, tpos) == self) { - r = rspamd_snprintf(d, remain + 1, "__self"); + if (first) { + first = FALSE; + str = "[%2] = %1"; } else { - r = lua_logger_out_type(L, tpos, d, remain, trace, esc_type); + str = ", [%2] = %1"; } + r = lua_logger_log_format_str(L, top + 1, d, remain, str, esc_type); + /* lua_logger_log_format_str now handles errors gracefully */ MOVE_BUF(d, remain, r); - first = FALSE; + /* Remove key */ + lua_pop(L, 1); } - lua_settop(L, old_top); - - r = rspamd_snprintf(d, remain + 1, "}"); + r = rspamd_snprintf(d, remain, "}"); d += r; +table_oob: + --trace->cur_level; + return (d - outbuf); } #undef MOVE_BUF -gsize lua_logger_out_type(lua_State *L, int pos, - char *outbuf, gsize len, - struct lua_logger_trace *trace, - enum lua_logger_escape_type esc_type) +static gsize +lua_logger_out_type(lua_State *L, int pos, + char *outbuf, gsize len, + struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type) { - int type; - gsize r = 0; - if (len == 0) { return 0; } - type = lua_type(L, pos); - trace->cur_level++; + int type = lua_type(L, pos); switch (type) { case LUA_TNUMBER: - r = lua_logger_out_num(L, pos, outbuf, len, trace); - break; + return lua_logger_out_num(L, pos, outbuf, len); case LUA_TBOOLEAN: - r = lua_logger_out_boolean(L, pos, outbuf, len, trace); - break; + return lua_logger_out_boolean(L, pos, outbuf, len); case LUA_TTABLE: - r = lua_logger_out_table(L, pos, outbuf, len, trace, esc_type); - break; + return lua_logger_out_table(L, pos, outbuf, len, 
trace, esc_type); case LUA_TUSERDATA: - r = lua_logger_out_userdata(L, pos, outbuf, len, trace); - break; + return lua_logger_out_userdata(L, pos, outbuf, len); case LUA_TFUNCTION: - r = rspamd_snprintf(outbuf, len + 1, "function"); - break; + return rspamd_snprintf(outbuf, len, "function"); case LUA_TLIGHTUSERDATA: - r = rspamd_snprintf(outbuf, len + 1, "0x%p", lua_topointer(L, pos)); - break; + return rspamd_snprintf(outbuf, len, "0x%p", lua_topointer(L, pos)); case LUA_TNIL: - r = rspamd_snprintf(outbuf, len + 1, "nil"); - break; + return rspamd_snprintf(outbuf, len, "nil"); case LUA_TNONE: - r = rspamd_snprintf(outbuf, len + 1, "no value"); - break; - default: - /* Try to push everything as string using tostring magic */ - r = lua_logger_out_str(L, pos, outbuf, len, trace, esc_type); - break; + return rspamd_snprintf(outbuf, len, "no value"); } - trace->cur_level--; + /* Try to push everything as string using tostring magic */ + return lua_logger_out_str(L, pos, outbuf, len, esc_type); +} - return r; +gsize lua_logger_out(lua_State *L, int pos, + char *outbuf, gsize len, + enum lua_logger_escape_type esc_type) +{ + struct lua_logger_trace tr; + memset(&tr, 0, sizeof(tr)); + + return lua_logger_out_type(L, pos, outbuf, len, &tr, esc_type); } static const char * @@ -731,126 +763,13 @@ static gboolean lua_logger_log_format(lua_State *L, int fmt_pos, gboolean is_string, char *logbuf, gsize remain) { - char *d; - const char *s, *c; - gsize r, cpylen = 0; - unsigned int arg_num = 0, cur_arg; - bool num_arg = false; - struct lua_logger_trace tr; - enum { - copy_char = 0, - got_percent, - parse_arg_num - } state = copy_char; - - d = logbuf; - s = lua_tostring(L, fmt_pos); - c = s; - cur_arg = fmt_pos; - - if (s == NULL) { + const char *fmt = lua_tostring(L, fmt_pos); + if (fmt == NULL) { return FALSE; } - while (remain > 0 && *s != '\0') { - switch (state) { - case copy_char: - if (*s == '%') { - state = got_percent; - s++; - if (cpylen > 0) { - memcpy(d, c, cpylen); - 
d += cpylen; - } - cpylen = 0; - } - else { - s++; - cpylen++; - remain--; - } - break; - case got_percent: - if (g_ascii_isdigit(*s) || *s == 's') { - state = parse_arg_num; - c = s; - } - else { - *d++ = *s++; - c = s; - state = copy_char; - } - break; - case parse_arg_num: - if (g_ascii_isdigit(*s)) { - s++; - num_arg = true; - } - else { - if (num_arg) { - arg_num = strtoul(c, NULL, 10); - arg_num += fmt_pos - 1; - /* Update the current argument */ - cur_arg = arg_num; - } - else { - /* We have non numeric argument, e.g. %s */ - arg_num = cur_arg++; - s++; - } - - if (arg_num < 1 || arg_num > (unsigned int) lua_gettop(L) + 1) { - msg_err("wrong argument number: %ud", arg_num); - - return FALSE; - } - - memset(&tr, 0, sizeof(tr)); - r = lua_logger_out_type(L, arg_num + 1, d, remain, &tr, - is_string ? LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); - g_assert(r <= remain); - remain -= r; - d += r; - state = copy_char; - c = s; - } - break; - } - } - - if (state == parse_arg_num) { - if (num_arg) { - arg_num = strtoul(c, NULL, 10); - arg_num += fmt_pos - 1; - } - else { - /* We have non numeric argument, e.g. %s */ - arg_num = cur_arg; - } - - if (arg_num < 1 || arg_num > (unsigned int) lua_gettop(L) + 1) { - msg_err("wrong argument number: %ud", arg_num); - - return FALSE; - } - - memset(&tr, 0, sizeof(tr)); - r = lua_logger_out_type(L, arg_num + 1, d, remain, &tr, - is_string ? LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); - g_assert(r <= remain); - remain -= r; - d += r; - } - else if (state == copy_char) { - if (cpylen > 0 && remain > 0) { - memcpy(d, c, cpylen); - d += cpylen; - } - } - - *d = '\0'; - - + /* lua_logger_log_format_str now handles argument mismatches gracefully */ + lua_logger_log_format_str(L, fmt_pos, logbuf, remain, fmt, is_string ? 
LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); return TRUE; } @@ -862,15 +781,10 @@ lua_logger_do_log(lua_State *L, { char logbuf[RSPAMD_LOGBUF_SIZE - 128]; const char *uid = NULL; - int fmt_pos = start_pos; int ret; - GError *err = NULL; - if (lua_type(L, start_pos) == LUA_TSTRING) { - fmt_pos = start_pos; - } - else if (lua_type(L, start_pos) == LUA_TUSERDATA) { - fmt_pos = start_pos + 1; + if (lua_type(L, start_pos) == LUA_TUSERDATA) { + GError *err = NULL; uid = lua_logger_get_id(L, start_pos, &err); @@ -884,15 +798,17 @@ lua_logger_do_log(lua_State *L, return ret; } + + ++start_pos; } - else { + + if (lua_type(L, start_pos) != LUA_TSTRING) { /* Bad argument type */ return luaL_error(L, "bad format string type: %s", lua_typename(L, lua_type(L, start_pos))); } - ret = lua_logger_log_format(L, fmt_pos, is_string, - logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, start_pos, is_string, logbuf, sizeof(logbuf)); if (ret) { if (is_string) { @@ -903,12 +819,9 @@ lua_logger_do_log(lua_State *L, lua_common_log_line(level, L, logbuf, uid, "lua", 1); } } - else { - if (is_string) { - lua_pushnil(L); - - return 1; - } + else if (is_string) { + lua_pushnil(L); + return 1; } return 0; @@ -971,11 +884,11 @@ lua_logger_logx(lua_State *L) if (uid && modname) { if (lua_type(L, 4) == LUA_TSTRING) { - ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf)); } else if (lua_type(L, 4) == LUA_TNUMBER) { stack_pos = lua_tonumber(L, 4); - ret = lua_logger_log_format(L, 5, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 5, FALSE, logbuf, sizeof(logbuf)); } else { return luaL_error(L, "invalid argument on pos 4"); @@ -1013,11 +926,11 @@ lua_logger_debugm(lua_State *L) if (uid && module) { if (lua_type(L, 3) == LUA_TSTRING) { - ret = lua_logger_log_format(L, 3, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 3, FALSE, logbuf, sizeof(logbuf)); } else if 
(lua_type(L, 3) == LUA_TNUMBER) { stack_pos = lua_tonumber(L, 3); - ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf)); } else { return luaL_error(L, "invalid argument on pos 3"); diff --git a/src/lua/lua_map.c b/src/lua/lua_map.c index 062613bd7..5f55ece06 100644 --- a/src/lua/lua_map.c +++ b/src/lua/lua_map.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -319,6 +319,11 @@ int lua_config_radix_from_ucl(lua_State *L) ucl_object_insert_key(fake_obj, ucl_object_fromstring("static"), "url", 0, false); + if (lua_type(L, 3) == LUA_TSTRING) { + ucl_object_insert_key(fake_obj, ucl_object_fromstring(lua_tostring(L, 3)), + "description", 0, false); + } + if ((m = rspamd_map_add_from_ucl(cfg, fake_obj, "static radix map", rspamd_radix_read, rspamd_radix_fin, diff --git a/src/lua/lua_mimepart.c b/src/lua/lua_mimepart.c index 07dba9c93..982b10d90 100644 --- a/src/lua/lua_mimepart.c +++ b/src/lua/lua_mimepart.c @@ -901,7 +901,7 @@ lua_textpart_get_words_count(lua_State *L) return 1; } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_pushinteger(L, 0); } else { @@ -943,7 +943,7 @@ lua_textpart_get_words(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_createtable(L, 0, 0); } else { @@ -957,7 +957,7 @@ lua_textpart_get_words(lua_State *L) } } - return rspamd_lua_push_words(L, part->utf_words, how); + return rspamd_lua_push_words_kvec(L, &part->utf_words, how); } return 1; @@ -976,7 +976,7 @@ lua_textpart_filter_words(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (IS_TEXT_PART_EMPTY(part) || 
part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_createtable(L, 0, 0); } else { @@ -998,9 +998,8 @@ lua_textpart_filter_words(lua_State *L) lua_createtable(L, 8, 0); - for (i = 0, cnt = 1; i < part->utf_words->len; i++) { - rspamd_stat_token_t *w = &g_array_index(part->utf_words, - rspamd_stat_token_t, i); + for (i = 0, cnt = 1; i < kv_size(part->utf_words); i++) { + rspamd_word_t *w = &kv_A(part->utf_words, i); switch (how) { case RSPAMD_LUA_WORDS_STEM: @@ -1194,13 +1193,13 @@ struct lua_shingle_filter_cbdata { rspamd_mempool_t *pool; }; -#define STORE_TOKEN(i, t) \ - do { \ - if ((i) < part->utf_words->len) { \ - word = &g_array_index(part->utf_words, rspamd_stat_token_t, (i)); \ - sd->t.begin = word->stemmed.begin; \ - sd->t.len = word->stemmed.len; \ - } \ +#define STORE_TOKEN(i, t) \ + do { \ + if ((i) < kv_size(part->utf_words)) { \ + word = &kv_A(part->utf_words, (i)); \ + sd->t.begin = word->stemmed.begin; \ + sd->t.len = word->stemmed.len; \ + } \ } while (0) static uint64_t @@ -1210,7 +1209,7 @@ lua_shingles_filter(uint64_t *input, gsize count, uint64_t minimal = G_MAXUINT64; gsize i, min_idx = 0; struct lua_shingle_data *sd; - rspamd_stat_token_t *word; + rspamd_word_t *word; struct lua_shingle_filter_cbdata *cbd = (struct lua_shingle_filter_cbdata *) ud; struct rspamd_mime_text_part *part; @@ -1248,7 +1247,7 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) unsigned int i; struct lua_shingle_data *sd; rspamd_cryptobox_hash_state_t st; - rspamd_stat_token_t *word; + rspamd_word_t *word; struct lua_shingle_filter_cbdata cbd; @@ -1256,7 +1255,7 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_pushnil(L); lua_pushnil(L); } @@ -1269,8 +1268,8 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) /* Calculate direct hash */ rspamd_cryptobox_hash_init(&st, key, 
rspamd_cryptobox_HASHKEYBYTES); - for (i = 0; i < part->utf_words->len; i++) { - word = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(part->utf_words); i++) { + word = &kv_A(part->utf_words, i); rspamd_cryptobox_hash_update(&st, word->stemmed.begin, word->stemmed.len); } @@ -1283,7 +1282,7 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) cbd.pool = pool; cbd.part = part; - sgl = rspamd_shingles_from_text(part->utf_words, key, + sgl = rspamd_shingles_from_text(&part->utf_words, key, pool, lua_shingles_filter, &cbd, RSPAMD_SHINGLES_MUMHASH); if (sgl == NULL) { diff --git a/src/lua/lua_parsers.c b/src/lua/lua_parsers.c index f77b36952..39e1b0317 100644 --- a/src/lua/lua_parsers.c +++ b/src/lua/lua_parsers.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2020 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -108,8 +108,8 @@ int lua_parsers_tokenize_text(lua_State *L) struct rspamd_lua_text *t; struct rspamd_process_exception *ex; UText utxt = UTEXT_INITIALIZER; - GArray *res; - rspamd_stat_token_t *w; + rspamd_words_t *res; + rspamd_word_t *w; if (lua_type(L, 1) == LUA_TSTRING) { in = luaL_checklstring(L, 1, &len); @@ -175,13 +175,15 @@ int lua_parsers_tokenize_text(lua_State *L) lua_pushnil(L); } else { - lua_createtable(L, res->len, 0); + lua_createtable(L, kv_size(*res), 0); - for (i = 0; i < res->len; i++) { - w = &g_array_index(res, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*res); i++) { + w = &kv_A(*res, i); lua_pushlstring(L, w->original.begin, w->original.len); lua_rawseti(L, -2, i + 1); } + kv_destroy(*res); + 
g_free(res); } cur = exceptions; diff --git a/src/lua/lua_redis.c b/src/lua/lua_redis.c index d20c496ed..491007df3 100644 --- a/src/lua/lua_redis.c +++ b/src/lua/lua_redis.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -130,7 +130,7 @@ struct lua_redis_request_specific_userdata { unsigned int nargs; char **args; gsize *arglens; - struct lua_redis_userdata *c; + struct lua_redis_userdata *common_ud; struct lua_redis_ctx *ctx; struct lua_redis_request_specific_userdata *next; ev_timer timeout_ev; @@ -262,7 +262,7 @@ lua_redis_fin(void *arg) struct lua_redis_ctx *ctx; ctx = sp_ud->ctx; - ud = sp_ud->c; + ud = sp_ud->common_ud; if (ev_can_stop(&sp_ud->timeout_ev)) { ev_timer_stop(sp_ud->ctx->async.event_loop, &sp_ud->timeout_ev); @@ -290,7 +290,7 @@ lua_redis_push_error(const char *err, gboolean connected, ...) 
{ - struct lua_redis_userdata *ud = sp_ud->c; + struct lua_redis_userdata *ud = sp_ud->common_ud; struct lua_callback_state cbs; lua_State *L; @@ -390,7 +390,7 @@ static void lua_redis_push_data(const redisReply *r, struct lua_redis_ctx *ctx, struct lua_redis_request_specific_userdata *sp_ud) { - struct lua_redis_userdata *ud = sp_ud->c; + struct lua_redis_userdata *ud = sp_ud->common_ud; struct lua_callback_state cbs; lua_State *L; @@ -467,14 +467,14 @@ lua_redis_callback(redisAsyncContext *c, gpointer r, gpointer priv) redisAsyncContext *ac; ctx = sp_ud->ctx; - ud = sp_ud->c; + ud = sp_ud->common_ud; if (ud->terminated || !rspamd_lua_is_initialised()) { /* We are already at the termination stage, just go out */ return; } - msg_debug_lua_redis("got reply from redis %p for query %p", sp_ud->c->ctx, + msg_debug_lua_redis("got async reply from redis %p for query %p", sp_ud->common_ud->ctx, sp_ud); REDIS_RETAIN(ctx); @@ -601,7 +601,7 @@ lua_redis_callback_sync(redisAsyncContext *ac, gpointer r, gpointer priv) int results; ctx = sp_ud->ctx; - ud = sp_ud->c; + ud = sp_ud->common_ud; lua_State *L = ctx->async.cfg->lua_state; sp_ud->flags |= LUA_REDIS_SPECIFIC_REPLIED; @@ -620,7 +620,7 @@ lua_redis_callback_sync(redisAsyncContext *ac, gpointer r, gpointer priv) } if (!(sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED)) { - msg_debug_lua_redis("got reply from redis: %p for query %p", ac, sp_ud); + msg_debug_lua_redis("got sync reply from redis: %p for query %p", ac, sp_ud); struct lua_redis_result *result = g_malloc0(sizeof *result); @@ -653,17 +653,17 @@ lua_redis_callback_sync(redisAsyncContext *ac, gpointer r, gpointer priv) /* if error happened, we should terminate the connection, and release it */ - if (result->is_error && sp_ud->c->ctx) { - ac = sp_ud->c->ctx; + if (result->is_error && sp_ud->common_ud->ctx) { + ac = sp_ud->common_ud->ctx; /* Set to NULL to avoid double free in dtor */ - sp_ud->c->ctx = NULL; + sp_ud->common_ud->ctx = NULL; ctx->flags |= 
LUA_REDIS_TERMINATED; /* * This will call all callbacks pending so the entire context * will be destructed */ - rspamd_redis_pool_release_connection(sp_ud->c->pool, ac, + rspamd_redis_pool_release_connection(sp_ud->common_ud->pool, ac, RSPAMD_REDIS_RELEASE_FATAL); } @@ -679,6 +679,8 @@ lua_redis_callback_sync(redisAsyncContext *ac, gpointer r, gpointer priv) ctx->cmds_pending--; if (ctx->cmds_pending == 0) { + msg_debug_lua_redis("no more commands left for: %p for query %p", ac, sp_ud); + if (ctx->thread) { if (!(sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED)) { /* somebody yielded and waits for results */ @@ -717,16 +719,16 @@ lua_redis_timeout_sync(EV_P_ ev_timer *w, int revents) return; } - ud = sp_ud->c; + ud = sp_ud->common_ud; ctx = sp_ud->ctx; msg_debug_lua_redis("timeout while querying redis server: %p, redis: %p", sp_ud, - sp_ud->c->ctx); + sp_ud->common_ud->ctx); - if (sp_ud->c->ctx) { - ac = sp_ud->c->ctx; + if (sp_ud->common_ud->ctx) { + ac = sp_ud->common_ud->ctx; /* Set to NULL to avoid double free in dtor */ - sp_ud->c->ctx = NULL; + sp_ud->common_ud->ctx = NULL; ac->err = REDIS_ERR_IO; errno = ETIMEDOUT; ctx->flags |= LUA_REDIS_TERMINATED; @@ -735,7 +737,7 @@ lua_redis_timeout_sync(EV_P_ ev_timer *w, int revents) * This will call all callbacks pending so the entire context * will be destructed */ - rspamd_redis_pool_release_connection(sp_ud->c->pool, ac, + rspamd_redis_pool_release_connection(sp_ud->common_ud->pool, ac, RSPAMD_REDIS_RELEASE_FATAL); } } @@ -754,24 +756,24 @@ lua_redis_timeout(EV_P_ ev_timer *w, int revents) } ctx = sp_ud->ctx; - ud = sp_ud->c; + ud = sp_ud->common_ud; REDIS_RETAIN(ctx); msg_debug_lua_redis("timeout while querying redis server: %p, redis: %p", sp_ud, - sp_ud->c->ctx); + sp_ud->common_ud->ctx); lua_redis_push_error("timeout while connecting the server (%.2f sec)", ctx, sp_ud, TRUE, ud->timeout); - if (sp_ud->c->ctx) { - ac = sp_ud->c->ctx; + if (sp_ud->common_ud->ctx) { + ac = sp_ud->common_ud->ctx; /* Set to NULL to 
avoid double free in dtor */ - sp_ud->c->ctx = NULL; + sp_ud->common_ud->ctx = NULL; ac->err = REDIS_ERR_IO; errno = ETIMEDOUT; /* * This will call all callbacks pending so the entire context * will be destructed */ - rspamd_redis_pool_release_connection(sp_ud->c->pool, ac, + rspamd_redis_pool_release_connection(sp_ud->common_ud->pool, ac, RSPAMD_REDIS_RELEASE_FATAL); } @@ -1095,8 +1097,8 @@ rspamd_lua_redis_prepare_connection(lua_State *L, int *pcbref, gboolean is_async return NULL; } - msg_debug_lua_redis("opened redis connection host=%s; ctx=%p; ud=%p", - host, ctx, ud); + msg_debug_lua_redis("opened redis connection host=%s; lua_ctx=%p; redis_ctx=%p; ud=%p", + host, ctx, ud->ctx, ud); return ctx; } @@ -1137,7 +1139,7 @@ lua_redis_make_request(lua_State *L) ud = &ctx->async; sp_ud = g_malloc0(sizeof(*sp_ud)); sp_ud->cbref = cbref; - sp_ud->c = ud; + sp_ud->common_ud = ud; sp_ud->ctx = ctx; lua_pushstring(L, "cmd"); @@ -1501,21 +1503,18 @@ lua_redis_add_cmd(lua_State *L) } sp_ud = g_malloc0(sizeof(*sp_ud)); + sp_ud->common_ud = &ctx->async; + ud = &ctx->async; if (IS_ASYNC(ctx)) { - sp_ud->c = &ctx->async; - ud = &ctx->async; sp_ud->cbref = cbref; } - else { - sp_ud->c = &ctx->async; - ud = &ctx->async; - } + sp_ud->ctx = ctx; lua_redis_parse_args(L, args_pos, cmd, &sp_ud->args, &sp_ud->arglens, &sp_ud->nargs); - LL_PREPEND(sp_ud->c->specific, sp_ud); + LL_PREPEND(sp_ud->common_ud->specific, sp_ud); if (ud->s && rspamd_session_blocked(ud->s)) { lua_pushboolean(L, 0); @@ -1525,7 +1524,7 @@ lua_redis_add_cmd(lua_State *L) } if (IS_ASYNC(ctx)) { - ret = redisAsyncCommandArgv(sp_ud->c->ctx, + ret = redisAsyncCommandArgv(sp_ud->common_ud->ctx, lua_redis_callback, sp_ud, sp_ud->nargs, @@ -1533,7 +1532,7 @@ lua_redis_add_cmd(lua_State *L) sp_ud->arglens); } else { - ret = redisAsyncCommandArgv(sp_ud->c->ctx, + ret = redisAsyncCommandArgv(sp_ud->common_ud->ctx, lua_redis_callback_sync, sp_ud, sp_ud->nargs, @@ -1554,25 +1553,28 @@ lua_redis_add_cmd(lua_State *L) } 
sp_ud->timeout_ev.data = sp_ud; + ev_now_update_if_cheap(ud->event_loop); if (IS_ASYNC(ctx)) { ev_timer_init(&sp_ud->timeout_ev, lua_redis_timeout, - sp_ud->c->timeout, 0.0); + sp_ud->common_ud->timeout, 0.0); } else { ev_timer_init(&sp_ud->timeout_ev, lua_redis_timeout_sync, - sp_ud->c->timeout, 0.0); + sp_ud->common_ud->timeout, 0.0); } ev_timer_start(ud->event_loop, &sp_ud->timeout_ev); + msg_debug_lua_redis("added timeout %f for %p", sp_ud->common_ud->timeout, sp_ud); + REDIS_RETAIN(ctx); ctx->cmds_pending++; } else { msg_info("call to redis failed: %s", - sp_ud->c->ctx->errstr); + sp_ud->common_ud->ctx->errstr); lua_pushboolean(L, 0); - lua_pushstring(L, sp_ud->c->ctx->errstr); + lua_pushstring(L, sp_ud->common_ud->ctx->errstr); return 2; } @@ -1606,11 +1608,20 @@ lua_redis_exec(lua_State *L) return 0; } else { - if (ctx->cmds_pending == 0 && g_queue_get_length(ctx->replies) == 0) { + struct lua_redis_userdata *ud = &ctx->async; + int replies_pending = g_queue_get_length(ctx->replies); + + msg_debug_lua_redis("execute pending commands for %p; commands pending = %d; replies pending = %d", + ctx, + ctx->cmds_pending, + replies_pending); + + if (ctx->cmds_pending == 0 && replies_pending == 0) { lua_pushstring(L, "No pending commands to execute"); lua_error(L); } - if (ctx->cmds_pending == 0 && g_queue_get_length(ctx->replies) > 0) { + + if (ctx->cmds_pending == 0 && replies_pending > 0) { int results = lua_redis_push_results(ctx, L); return results; } diff --git a/src/lua/lua_shingles.cxx b/src/lua/lua_shingles.cxx new file mode 100644 index 000000000..7d4b277fc --- /dev/null +++ b/src/lua/lua_shingles.cxx @@ -0,0 +1,133 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "lua_classnames.h" +#include "shingles.h" +#include "contrib/fmt/include/fmt/format.h" + +/*** + * @module rspamd_shingle + * This module provides methods to work with text shingles + */ + +/*** + * @method shingle:to_table() + * Converts shingle to table of decimal strings + * @return {table} table of RSPAMD_SHINGLE_SIZE decimal strings + */ +LUA_FUNCTION_DEF(shingle, to_table); + +/*** + * @method shingle:get(index) + * Gets element at index as two lua_Integer values (high and low 32 bits) + * @param {number} index 1-based index + * @return {number,number} high and low 32-bit parts + */ +LUA_FUNCTION_DEF(shingle, get); + +/*** + * @method shingle:get_string(index) + * Gets element at index as decimal string + * @param {number} index 1-based index + * @return {string} decimal representation + */ +LUA_FUNCTION_DEF(shingle, get_string); + +static const struct luaL_reg shinglelib_m[] = { + LUA_INTERFACE_DEF(shingle, to_table), + LUA_INTERFACE_DEF(shingle, get), + LUA_INTERFACE_DEF(shingle, get_string), + {"__tostring", rspamd_lua_class_tostring}, + {nullptr, nullptr}}; + +static struct rspamd_shingle * +lua_check_shingle(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, rspamd_shingle_classname); + luaL_argcheck(L, ud != nullptr, pos, "'shingle' expected"); + return static_cast<struct rspamd_shingle *>(ud); +} + +void lua_newshingle(lua_State *L, const void *sh) +{ + auto *nsh = static_cast<struct rspamd_shingle *>( + lua_newuserdata(L, sizeof(struct rspamd_shingle))); + + if (sh != 
nullptr) { + memcpy(nsh, sh, sizeof(struct rspamd_shingle)); + } + + rspamd_lua_setclass(L, rspamd_shingle_classname, -1); +} + +static int +lua_shingle_to_table(lua_State *L) +{ + LUA_TRACE_POINT; + auto *sh = lua_check_shingle(L, 1); + + lua_createtable(L, RSPAMD_SHINGLE_SIZE, 0); + + for (int i = 0; i < RSPAMD_SHINGLE_SIZE; i++) { + auto str = fmt::format("{}", sh->hashes[i]); + lua_pushstring(L, str.c_str()); + lua_rawseti(L, -2, i + 1); + } + + return 1; +} + +static int +lua_shingle_get(lua_State *L) +{ + LUA_TRACE_POINT; + auto *sh = lua_check_shingle(L, 1); + auto idx = luaL_checkinteger(L, 2) - 1; + + if (idx < 0 || idx >= RSPAMD_SHINGLE_SIZE) { + return luaL_error(L, "index out of bounds: %d", idx + 1); + } + + uint64_t val = sh->hashes[idx]; + lua_pushinteger(L, (lua_Integer) (val >> 32)); + lua_pushinteger(L, (lua_Integer) (val & 0xFFFFFFFF)); + + return 2; +} + +static int +lua_shingle_get_string(lua_State *L) +{ + LUA_TRACE_POINT; + auto *sh = lua_check_shingle(L, 1); + auto idx = luaL_checkinteger(L, 2) - 1; + + if (idx < 0 || idx >= RSPAMD_SHINGLE_SIZE) { + return luaL_error(L, "index out of bounds: %d", idx + 1); + } + + auto str = fmt::format("{}", sh->hashes[idx]); + lua_pushstring(L, str.c_str()); + + return 1; +} + +void luaopen_shingle(lua_State *L) +{ + rspamd_lua_new_class(L, rspamd_shingle_classname, shinglelib_m); + lua_pop(L, 1); +} diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 355680881..0b1473b61 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -1226,6 +1226,13 @@ LUA_FUNCTION_DEF(task, get_all_named_results); */ LUA_FUNCTION_DEF(task, get_dns_req); +/*** + * @method task:add_timer(timeout, callback) + * Creates a delayed execution task for the specific callback at given timeout (in seconds) + * + */ +LUA_FUNCTION_DEF(task, add_timer); + static const struct luaL_reg tasklib_f[] = { LUA_INTERFACE_DEF(task, create), LUA_INTERFACE_DEF(task, load_from_file), @@ -1353,6 +1360,7 @@ static const struct luaL_reg tasklib_m[] = { LUA_INTERFACE_DEF(task, add_named_result), LUA_INTERFACE_DEF(task, get_all_named_results), LUA_INTERFACE_DEF(task, topointer), + LUA_INTERFACE_DEF(task, add_timer), {"__tostring", rspamd_lua_class_tostring}, {NULL, NULL}}; @@ -6935,7 +6943,7 @@ lua_task_get_meta_words(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (task->meta_words == NULL) { + if (!task->meta_words.a) { lua_createtable(L, 0, 0); } else { @@ -6959,7 +6967,7 @@ lua_task_get_meta_words(lua_State *L) } } - return rspamd_lua_push_words(L, task->meta_words, how); + return rspamd_lua_push_words_kvec(L, &task->meta_words, how); } return 1; @@ -7031,6 +7039,76 @@ lua_lookup_words_array(lua_State *L, return nmatched; } +static unsigned int +lua_lookup_words_kvec(lua_State *L, + int cbpos, + struct rspamd_task *task, + struct rspamd_lua_map *map, + rspamd_words_t *words) +{ + rspamd_word_t *tok; + unsigned int i, nmatched = 0; + int err_idx; + gboolean matched; + const char *key; + gsize keylen; + + if (!words || !words->a) { + return 0; + } + + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); + + matched = FALSE; + + if (tok->normalized.len == 0) { + continue; + } + + key = tok->normalized.begin; + keylen = tok->normalized.len; + + switch (map->type) { + case RSPAMD_LUA_MAP_SET: + case RSPAMD_LUA_MAP_HASH: + /* We know that tok->normalized is zero terminated in fact */ + if (rspamd_match_hash_map(map->data.hash, key, keylen)) { + matched = TRUE; + } + break; + case RSPAMD_LUA_MAP_REGEXP: + case 
RSPAMD_LUA_MAP_REGEXP_MULTIPLE: + if (rspamd_match_regexp_map_single(map->data.re_map, key, + keylen)) { + matched = TRUE; + } + break; + default: + g_assert_not_reached(); + break; + } + + if (matched) { + nmatched++; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_pushvalue(L, cbpos); /* Function */ + rspamd_lua_push_full_word(L, tok); + + if (lua_pcall(L, 1, 0, err_idx) != 0) { + msg_err_task("cannot call callback function for lookup words: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); + } + } + + return nmatched; +} + static int lua_task_lookup_words(lua_State *L) { @@ -7054,13 +7132,13 @@ lua_task_lookup_words(lua_State *L) PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) { - if (tp->utf_words) { - matches += lua_lookup_words_array(L, 3, task, map, tp->utf_words); + if (tp->utf_words.a) { + matches += lua_lookup_words_kvec(L, 3, task, map, &tp->utf_words); } } - if (task->meta_words) { - matches += lua_lookup_words_array(L, 3, task, map, task->meta_words); + if (task->meta_words.a) { + matches += lua_lookup_words_kvec(L, 3, task, map, &task->meta_words); } lua_pushinteger(L, matches); @@ -7406,6 +7484,102 @@ lua_archive_get_filename(lua_State *L) return 1; } +struct rspamd_task_timer_cbdata { + lua_State *L; + struct rspamd_task *task; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_async_event *async_ev; + int cbref; + ev_timer ev; +}; + +static void +lua_timer_fin(gpointer arg) +{ + struct rspamd_task_timer_cbdata *cbdata = (struct rspamd_task_timer_cbdata *) arg; + + ev_timer_stop(cbdata->task->event_loop, &cbdata->ev); + luaL_unref(cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref); +} + +static void +lua_task_timer_cb(struct ev_loop *loop, struct ev_timer *t, int events) +{ + struct rspamd_task_timer_cbdata *cbdata = (struct rspamd_task_timer_cbdata *) t->data; + lua_State *L; + bool schedule_more = false; + + L = cbdata->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + int err_idx 
= lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbdata->cbref); + rspamd_lua_task_push(L, cbdata->task); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err("call to periodic " + "script failed: %s", + lua_tostring(L, -1)); + } + else { + if (lua_isnumber(L, -1)) { + schedule_more = true; + ev_timer_set(&cbdata->ev, lua_tonumber(L, -1), 0.0); + } + } + + if (schedule_more) { + ev_timer_again(loop, t); + } + else { + /* Cleanup */ + if (cbdata->item) { + rspamd_symcache_item_async_dec_check(cbdata->task, cbdata->item, "timer"); + cbdata->item = NULL; + } + rspamd_session_remove_event(cbdata->task->s, lua_timer_fin, cbdata); + } +} + +static int +lua_task_add_timer(lua_State *L) +{ + struct ev_loop *ev_base; + struct rspamd_task *task; + + task = lua_check_task(L, 1); + ev_base = task->event_loop; + if (!lua_isfunction(L, 3)) { + return luaL_error(L, "invalid arguments: callback expected"); + } + + if (!lua_isnumber(L, 2)) { + return luaL_error(L, "invalid arguments: timeout expected"); + } + + struct rspamd_task_timer_cbdata *cbdata = rspamd_mempool_alloc(task->task_pool, sizeof(*cbdata)); + cbdata->L = L; + lua_pushvalue(L, 3); + cbdata->ev.data = cbdata; + cbdata->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + cbdata->task = task; + cbdata->item = rspamd_symcache_get_cur_item(task); + + if (cbdata->item) { + cbdata->async_ev = rspamd_session_add_event_full(task->s, lua_timer_fin, cbdata, "timer", + rspamd_symcache_dyn_item_name(cbdata->task, cbdata->item)); + rspamd_symcache_item_async_inc(task, cbdata->item, "timer"); + } + else { + cbdata->async_ev = rspamd_session_add_event(task->s, lua_timer_fin, cbdata, "timer"); + } + + ev_timer_init(&cbdata->ev, lua_task_timer_cb, lua_tonumber(L, 2), 0.0); + ev_timer_start(ev_base, &cbdata->ev); + + return 0; +} + /* Init part */ static int diff --git a/src/lua/lua_text.c b/src/lua/lua_text.c index 7ce7440c7..b45ee1743 100644 --- a/src/lua/lua_text.c +++ b/src/lua/lua_text.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod 
Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,6 +73,12 @@ LUA_FUNCTION_DEF(text, byte); */ LUA_FUNCTION_DEF(text, len); /*** + * @method rspamd_text:len_utf8() + * Returns length of a string + * @return {number} length of string in **bytes** + */ +LUA_FUNCTION_DEF(text, len_utf8); +/*** * @method rspamd_text:str() * Converts text to string by copying its content * @return {string} copy of text as Lua string @@ -106,6 +112,12 @@ LUA_FUNCTION_DEF(text, span); */ LUA_FUNCTION_DEF(text, sub); /*** + * @method rspamd_text:sub_utf8(start[, len]) + * Returns a substring for lua_text similar to string.sub from Lua using UTF8 points + * @return {rspamd_text} new rspamd_text with span (must be careful when using with owned texts...) + */ +LUA_FUNCTION_DEF(text, sub_utf8); +/*** * @method rspamd_text:lines([stringify]) * Returns an iter over all lines as rspamd_text objects or as strings if `stringify` is true * @param {boolean} stringify stringify lines @@ -238,12 +250,14 @@ static const struct luaL_reg textlib_f[] = { static const struct luaL_reg textlib_m[] = { LUA_INTERFACE_DEF(text, len), + LUA_INTERFACE_DEF(text, len_utf8), LUA_INTERFACE_DEF(text, str), LUA_INTERFACE_DEF(text, ptr), LUA_INTERFACE_DEF(text, take_ownership), LUA_INTERFACE_DEF(text, save_in_file), LUA_INTERFACE_DEF(text, span), LUA_INTERFACE_DEF(text, sub), + LUA_INTERFACE_DEF(text, sub_utf8), LUA_INTERFACE_DEF(text, lines), LUA_INTERFACE_DEF(text, split), LUA_INTERFACE_DEF(text, at), @@ -1538,8 +1552,7 @@ lua_text_exclude_chars(lua_State *L) pat++; patlen--; } - for (; patlen > 0 && BITOP(byteset, *(unsigned char *) pat, |=); pat++, patlen--) - ; + for (; patlen > 0 && BITOP(byteset, *(unsigned char *) pat, |=); pat++, patlen--); p = t->start; end = t->start + t->len; @@ -1762,6 +1775,91 @@ lua_text_strtoul(lua_State *L) return 1; } +static int 
+lua_text_len_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + if (t != NULL) { + const char *s = t->start; + int32_t count = 0, i = 0; + UChar32 c; + while (i < t->len) { + U8_NEXT(s, i, t->len, c); + if (c < 0) { + lua_pushnil(L); + return 1; + } + count++; + } + + lua_pushinteger(L, count); + return 1; + } + else { + return luaL_error(L, "invalid arguments"); + } +} + +static int +lua_text_sub_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + if (t == NULL) { + return luaL_error(L, "invalid arguments"); + } + + UChar32 c; + + int32_t start = luaL_checkinteger(L, 2); + int32_t finish = luaL_optinteger(L, 3, -1); + + int32_t len_utf8 = 0; + int32_t i = 0; + while (i < t->len) { + U8_NEXT(t->start, i, t->len, c); + if (c < 0) { + lua_pushnil(L); + return 1; + } + + len_utf8++; + } + + start = relative_pos_start(start, len_utf8); + finish = relative_pos_end(finish, len_utf8); + + if (start > finish) { + lua_new_text(L, "", 0, TRUE); + return 1; + } + + const char *sub_start = t->start; + const char *sub_end = t->start; + int32_t char_pos = 0, remain, utf8_idx = 0; + + while (char_pos < t->len && utf8_idx < start - 1) { + U8_NEXT(t->start, char_pos, t->len, c); + sub_start = &t->start[char_pos]; + utf8_idx++; + } + + remain = (t->start + t->len) - sub_start; + char_pos = 0; + + while (char_pos < remain && utf8_idx < finish) { + U8_NEXT(sub_start, char_pos, remain, c); + sub_end = &sub_start[char_pos]; + utf8_idx++; + } + + /* Copy as we have no other options to make it safe */ + lua_new_text(L, sub_start, sub_end - sub_start, TRUE); + + return 1; +} + /* Used to distinguish lua text metatable */ static const unsigned int rspamd_lua_text_cookie = 0x2b21ef6fU; diff --git a/src/lua/lua_util.c b/src/lua/lua_util.c index ce4d9f67c..f2e9b8fa9 100644 --- a/src/lua/lua_util.c +++ b/src/lua/lua_util.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod 
Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,12 +23,21 @@ #include "lua_parsers.h" -#ifdef WITH_LUA_REPL -#include "replxx.h" -#endif +#include "replxx.h" #include <math.h> #include <glob.h> +#include <sys/types.h> +#include <sys/time.h> +#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) +#include <sys/sysctl.h> +#ifdef __FreeBSD__ +#include <sys/user.h> +#endif +#endif +#ifdef __APPLE__ +#include <mach/mach.h> +#endif #include "unicode/uspoof.h" #include "unicode/uscript.h" @@ -629,6 +638,27 @@ LUA_FUNCTION_DEF(util, caseless_hash_fast); LUA_FUNCTION_DEF(util, get_hostname); /*** + * @function util.get_uptime() + * Returns system uptime in seconds + * @return {number} uptime in seconds + */ +LUA_FUNCTION_DEF(util, get_uptime); + +/*** + * @function util.get_pid() + * Returns current process PID + * @return {number} process ID + */ +LUA_FUNCTION_DEF(util, get_pid); + +/*** + * @function util.get_memory_usage() + * Returns memory usage information for current process + * @return {table} memory usage info with 'rss' and 'vsize' fields in bytes + */ +LUA_FUNCTION_DEF(util, get_memory_usage); + +/*** * @function util.parse_content_type(ct_string, mempool) * Parses content-type string to a table: * - `type` @@ -730,6 +760,9 @@ static const struct luaL_reg utillib_f[] = { LUA_INTERFACE_DEF(util, umask), LUA_INTERFACE_DEF(util, isatty), LUA_INTERFACE_DEF(util, get_hostname), + LUA_INTERFACE_DEF(util, get_uptime), + LUA_INTERFACE_DEF(util, get_pid), + LUA_INTERFACE_DEF(util, get_memory_usage), LUA_INTERFACE_DEF(util, parse_content_type), LUA_INTERFACE_DEF(util, mime_header_encode), LUA_INTERFACE_DEF(util, pack), @@ -755,9 +788,17 @@ static const struct luaL_reg int64lib_m[] = { {NULL, NULL}}; LUA_FUNCTION_DEF(ev_base, loop); +LUA_FUNCTION_DEF(ev_base, update_time); +LUA_FUNCTION_DEF(ev_base, timestamp); 
+LUA_FUNCTION_DEF(ev_base, pending_events); +LUA_FUNCTION_DEF(ev_base, add_timer); static const struct luaL_reg ev_baselib_m[] = { LUA_INTERFACE_DEF(ev_base, loop), + LUA_INTERFACE_DEF(ev_base, update_time), + LUA_INTERFACE_DEF(ev_base, timestamp), + LUA_INTERFACE_DEF(ev_base, pending_events), + LUA_INTERFACE_DEF(ev_base, add_timer), {"__tostring", rspamd_lua_class_tostring}, {NULL, NULL}}; @@ -2408,6 +2449,107 @@ lua_util_get_hostname(lua_State *L) } static int +lua_util_get_uptime(lua_State *L) +{ + LUA_TRACE_POINT; + double uptime = 0.0; + +#ifdef __linux__ + FILE *f = fopen("/proc/uptime", "r"); + if (f) { + if (fscanf(f, "%lf", &uptime) != 1) { + uptime = 0.0; + } + fclose(f); + } +#elif defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) + struct timeval boottime; + size_t len = sizeof(boottime); + int mib[2] = {CTL_KERN, KERN_BOOTTIME}; + + if (sysctl(mib, 2, &boottime, &len, NULL, 0) == 0) { + struct timeval now; + gettimeofday(&now, NULL); + uptime = (now.tv_sec - boottime.tv_sec) + + (now.tv_usec - boottime.tv_usec) / 1000000.0; + } +#endif + + lua_pushnumber(L, uptime); + return 1; +} + +static int +lua_util_get_pid(lua_State *L) +{ + LUA_TRACE_POINT; + lua_pushinteger(L, getpid()); + return 1; +} + +static int +lua_util_get_memory_usage(lua_State *L) +{ + LUA_TRACE_POINT; + lua_createtable(L, 0, 2); + +#ifdef __linux__ + FILE *f = fopen("/proc/self/status", "r"); + if (f) { + char line[256]; + long rss = 0, vsize = 0; + + while (fgets(line, sizeof(line), f)) { + if (sscanf(line, "VmRSS: %ld kB", &rss) == 1) { + rss *= 1024; /* Convert to bytes */ + } + else if (sscanf(line, "VmSize: %ld kB", &vsize) == 1) { + vsize *= 1024; /* Convert to bytes */ + } + } + fclose(f); + + lua_pushstring(L, "rss"); + lua_pushinteger(L, rss); + lua_settable(L, -3); + + lua_pushstring(L, "vsize"); + lua_pushinteger(L, vsize); + lua_settable(L, -3); + } +#elif defined(__APPLE__) + struct task_basic_info info; + mach_msg_type_number_t 
count = TASK_BASIC_INFO_COUNT; + + if (task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t) &info, &count) == KERN_SUCCESS) { + lua_pushstring(L, "rss"); + lua_pushinteger(L, info.resident_size); + lua_settable(L, -3); + + lua_pushstring(L, "vsize"); + lua_pushinteger(L, info.virtual_size); + lua_settable(L, -3); + } +#elif defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) + struct kinfo_proc kp; + size_t len = sizeof(kp); + int mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid()}; + + if (sysctl(mib, 4, &kp, &len, NULL, 0) == 0) { + lua_pushstring(L, "rss"); + lua_pushinteger(L, kp.ki_rssize * getpagesize()); + lua_settable(L, -3); + + lua_pushstring(L, "vsize"); + lua_pushinteger(L, kp.ki_size); + lua_settable(L, -3); + } +#endif + + return 1; +} + +static int lua_util_parse_content_type(lua_State *L) { return lua_parsers_parse_content_type(L); @@ -2502,7 +2644,7 @@ lua_util_readline(lua_State *L) if (lua_type(L, 1) == LUA_TSTRING) { prompt = lua_tostring(L, 1); } -#ifdef WITH_LUA_REPL + static Replxx *rx_instance = NULL; if (rx_instance == NULL) { @@ -2519,26 +2661,6 @@ lua_util_readline(lua_State *L) else { lua_pushnil(L); } -#else - size_t linecap = 0; - ssize_t linelen; - - fprintf(stdout, "%s ", prompt); - - linelen = getline(&input, &linecap, stdin); - - if (linelen > 0) { - if (input[linelen - 1] == '\n') { - linelen--; - } - - lua_pushlstring(L, input, linelen); - free(input); - } - else { - lua_pushnil(L); - } -#endif return 1; } @@ -3611,3 +3733,106 @@ lua_ev_base_loop(lua_State *L) return 1; } + +static int +lua_ev_base_update_time(lua_State *L) +{ + struct ev_loop *ev_base; + + ev_base = lua_check_ev_base(L, 1); + ev_now_update_if_cheap(ev_base); + + lua_pushnumber(L, ev_time()); + + return 1; +} + +static int +lua_ev_base_timestamp(lua_State *L) +{ + struct ev_loop *ev_base; + + ev_base = lua_check_ev_base(L, 1); + lua_pushnumber(L, ev_now(ev_base)); + + return 1; +} + +static int +lua_ev_base_pending_events(lua_State *L) 
+{ + struct ev_loop *ev_base; + + ev_base = lua_check_ev_base(L, 1); + lua_pushnumber(L, ev_pending_count(ev_base)); + + return 1; +} + +struct rspamd_ev_base_cbdata { + lua_State *L; + int cbref; + ev_timer ev; +}; + +static void +lua_ev_base_cb(struct ev_loop *loop, struct ev_timer *t, int events) +{ + struct rspamd_ev_base_cbdata *cbdata = (struct rspamd_ev_base_cbdata *) t->data; + lua_State *L; + bool schedule_more = false; + + L = cbdata->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + int err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbdata->cbref); + + if (lua_pcall(L, 0, 1, err_idx) != 0) { + msg_err("call to periodic " + "script failed: %s", + lua_tostring(L, -1)); + } + else { + if (lua_isnumber(L, -1)) { + schedule_more = true; + ev_timer_set(&cbdata->ev, lua_tonumber(L, -1), 0.0); + } + } + + if (schedule_more) { + ev_timer_again(loop, t); + } + else { + /* Cleanup */ + ev_timer_stop(loop, t); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->cbref); + g_free(cbdata); + } +} + +static int +lua_ev_base_add_timer(lua_State *L) +{ + struct ev_loop *ev_base; + + ev_base = lua_check_ev_base(L, 1); + if (!lua_isfunction(L, 3)) { + return luaL_error(L, "invalid arguments: callback expected"); + } + + if (!lua_isnumber(L, 2)) { + return luaL_error(L, "invalid arguments: timeout expected"); + } + + struct rspamd_ev_base_cbdata *cbdata = g_malloc(sizeof(*cbdata)); + cbdata->L = L; + lua_pushvalue(L, 3); + cbdata->ev.data = cbdata; + cbdata->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + ev_timer_init(&cbdata->ev, lua_ev_base_cb, lua_tonumber(L, 2), 0.0); + ev_timer_start(ev_base, &cbdata->ev); + + return 0; +} diff --git a/src/plugins/chartable.cxx b/src/plugins/chartable.cxx index a5c7cb899..c82748862 100644 --- a/src/plugins/chartable.cxx +++ b/src/plugins/chartable.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this 
file except in compliance with the License. @@ -1696,7 +1696,7 @@ rspamd_can_alias_latin(int ch) static double rspamd_chartable_process_word_utf(struct rspamd_task *task, - rspamd_stat_token_t *w, + rspamd_word_t *w, gboolean is_url, unsigned int *ncap, struct chartable_ctx *chartable_module_ctx, @@ -1842,7 +1842,7 @@ rspamd_chartable_process_word_utf(struct rspamd_task *task, static double rspamd_chartable_process_word_ascii(struct rspamd_task *task, - rspamd_stat_token_t *w, + rspamd_word_t *w, gboolean is_url, struct chartable_ctx *chartable_module_ctx) { @@ -1931,17 +1931,17 @@ rspamd_chartable_process_part(struct rspamd_task *task, struct chartable_ctx *chartable_module_ctx, gboolean ignore_diacritics) { - rspamd_stat_token_t *w; + rspamd_word_t *w; unsigned int i, ncap = 0; double cur_score = 0.0; - if (part == nullptr || part->utf_words == nullptr || - part->utf_words->len == 0 || part->nwords == 0) { + if (part == nullptr || part->utf_words.a == nullptr || + kv_size(part->utf_words) == 0 || part->nwords == 0) { return FALSE; } - for (i = 0; i < part->utf_words->len; i++) { - w = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(part->utf_words); i++) { + w = &kv_A(part->utf_words, i); if ((w->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT)) { @@ -2015,13 +2015,13 @@ chartable_symbol_callback(struct rspamd_task *task, ignore_diacritics = TRUE; } - if (task->meta_words != nullptr && task->meta_words->len > 0) { - rspamd_stat_token_t *w; + if (task->meta_words.a && kv_size(task->meta_words) > 0) { + rspamd_word_t *w; double cur_score = 0; - gsize arlen = task->meta_words->len; + gsize arlen = kv_size(task->meta_words); for (i = 0; i < arlen; i++) { - w = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + w = &kv_A(task->meta_words, i); cur_score += rspamd_chartable_process_word_utf(task, w, FALSE, nullptr, chartable_module_ctx, ignore_diacritics); } diff --git a/src/plugins/fuzzy_check.c b/src/plugins/fuzzy_check.c index 
ece9a91e0..7dd5162ac 100644 --- a/src/plugins/fuzzy_check.c +++ b/src/plugins/fuzzy_check.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,7 +78,8 @@ enum fuzzy_rule_mode { }; struct fuzzy_rule { - struct upstream_list *servers; + struct upstream_list *read_servers; /* Servers for read operations */ + struct upstream_list *write_servers; /* Servers for write operations */ const char *symbol; const char *algorithm_str; const char *name; @@ -543,22 +544,68 @@ fuzzy_parse_rule(struct rspamd_config *cfg, const ucl_object_t *obj, } if ((value = ucl_object_lookup(obj, "servers")) != NULL) { - rule->servers = rspamd_upstreams_create(cfg->ups_ctx); - /* pass max_error and revive_time configuration in upstream for fuzzy storage - * it allows to configure error_rate threshold and upstream dead timer - */ - rspamd_upstreams_set_limits(rule->servers, + rule->read_servers = rspamd_upstreams_create(cfg->ups_ctx); + rspamd_upstreams_set_limits(rule->read_servers, (double) fuzzy_module_ctx->revive_time, NAN, NAN, NAN, (unsigned int) fuzzy_module_ctx->max_errors, 0); rspamd_mempool_add_destructor(cfg->cfg_pool, (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, - rule->servers); - if (!rspamd_upstreams_from_ucl(rule->servers, value, DEFAULT_PORT, NULL)) { + rule->read_servers); + if (!rspamd_upstreams_from_ucl(rule->read_servers, value, DEFAULT_PORT, NULL)) { msg_err_config("cannot read servers definition"); return -1; } + + rule->write_servers = rule->read_servers; + } + else { + /* Check for read_servers and write_servers */ + gboolean has_read = FALSE, has_write = FALSE; + + if ((value = ucl_object_lookup(obj, "read_servers")) != NULL) { + rule->read_servers = rspamd_upstreams_create(cfg->ups_ctx); + rspamd_upstreams_set_limits(rule->read_servers, + (double) fuzzy_module_ctx->revive_time, NAN, 
NAN, NAN, + (unsigned int) fuzzy_module_ctx->max_errors, 0); + + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, + rule->read_servers); + if (!rspamd_upstreams_from_ucl(rule->read_servers, value, DEFAULT_PORT, NULL)) { + msg_err_config("cannot read read_servers definition"); + return -1; + } + has_read = TRUE; + } + + if ((value = ucl_object_lookup(obj, "write_servers")) != NULL) { + rule->write_servers = rspamd_upstreams_create(cfg->ups_ctx); + rspamd_upstreams_set_limits(rule->write_servers, + (double) fuzzy_module_ctx->revive_time, NAN, NAN, NAN, + (unsigned int) fuzzy_module_ctx->max_errors, 0); + + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, + rule->write_servers); + if (!rspamd_upstreams_from_ucl(rule->write_servers, value, DEFAULT_PORT, NULL)) { + msg_err_config("cannot read write_servers definition"); + return -1; + } + has_write = TRUE; + } + + /* If we have both read and write servers, we don't need the common servers list */ + if (has_read && !has_write) { + /* Use read_servers for all operations */ + rule->write_servers = rule->read_servers; + } + else if (has_write && !has_read) { + /* Use write_servers for all operations */ + rule->read_servers = rule->write_servers; + } } + if ((value = ucl_object_lookup(obj, "fuzzy_map")) != NULL) { it = NULL; while ((cur = ucl_object_iterate(value, &it, true)) != NULL) { @@ -636,7 +683,7 @@ fuzzy_parse_rule(struct rspamd_config *cfg, const ucl_object_t *obj, strlen(shingles_key_str), NULL, 0); rule->shingles_key->len = 16; - if (rspamd_upstreams_count(rule->servers) == 0) { + if (rspamd_upstreams_count(rule->read_servers) == 0) { msg_err_config("no servers defined for fuzzy rule with name: %s", rule->name); return -1; @@ -898,6 +945,24 @@ int fuzzy_check_module_init(struct rspamd_config *cfg, struct module_ctx **ctx) 0); rspamd_rcl_add_doc_by_path(cfg, "fuzzy_check.rule", + "List of servers to check 
(read-only operations)", + "read_servers", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "List of servers to learn (write operations)", + "write_servers", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", "If true then never try to learn this fuzzy storage", "read_only", UCL_BOOLEAN, @@ -1249,7 +1314,7 @@ int fuzzy_check_module_config(struct rspamd_config *cfg, bool validate) LL_FOREACH(value, cur) { - if (ucl_object_lookup(cur, "servers")) { + if (ucl_object_lookup_any(cur, "servers", "read_servers", "write_servers", NULL) != NULL) { /* Unnamed rule */ fuzzy_parse_rule(cfg, cur, NULL, cb_id); nrules++; @@ -1366,10 +1431,10 @@ fuzzy_io_fin(void *ud) close(session->fd); } -static GArray * +static rspamd_words_t * fuzzy_preprocess_words(struct rspamd_mime_text_part *part, rspamd_mempool_t *pool) { - return part->utf_words; + return &part->utf_words; } static void @@ -1715,26 +1780,30 @@ fuzzy_cmd_write_extensions(struct rspamd_task *task, struct rspamd_email_address *addr = g_ptr_array_index(MESSAGE_FIELD(task, from_mime), 0); - unsigned int to_write = MIN(MAX_FUZZY_DOMAIN, addr->domain_len) + 2; - if (to_write > 0 && to_write <= available) { - *dest++ = RSPAMD_FUZZY_EXT_SOURCE_DOMAIN; - *dest++ = to_write - 2; + if (addr->domain_len > 0) { + /* Filter invalid domains */ + unsigned int to_write = MIN(MAX_FUZZY_DOMAIN, addr->domain_len) + 2; - if (addr->domain_len < MAX_FUZZY_DOMAIN) { - memcpy(dest, addr->domain, addr->domain_len); - dest += addr->domain_len; - } - else { - /* Trim from left */ - memcpy(dest, - addr->domain + (addr->domain_len - MAX_FUZZY_DOMAIN), - MAX_FUZZY_DOMAIN); - dest += MAX_FUZZY_DOMAIN; - } + if (to_write > 0 && to_write <= available) { + *dest++ = RSPAMD_FUZZY_EXT_SOURCE_DOMAIN; + *dest++ = to_write - 2; + + if (addr->domain_len < MAX_FUZZY_DOMAIN) { + memcpy(dest, addr->domain, addr->domain_len); + dest += addr->domain_len; + } + else { + /* 
Trim from left */ + memcpy(dest, + addr->domain + (addr->domain_len - MAX_FUZZY_DOMAIN), + MAX_FUZZY_DOMAIN); + dest += MAX_FUZZY_DOMAIN; + } - available -= to_write; - written += to_write; + available -= to_write; + written += to_write; + } } } @@ -1792,7 +1861,7 @@ fuzzy_cmd_from_text_part(struct rspamd_task *task, unsigned int i; rspamd_cryptobox_hash_state_t st; rspamd_stat_token_t *word; - GArray *words; + rspamd_words_t *words; struct fuzzy_cmd_io *io; unsigned int additional_length; unsigned char *additional_data; @@ -1901,10 +1970,10 @@ fuzzy_cmd_from_text_part(struct rspamd_task *task, rspamd_cryptobox_hash_init(&st, rule->hash_key->str, rule->hash_key->len); words = fuzzy_preprocess_words(part, task->task_pool); - for (i = 0; i < words->len; i++) { - word = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + word = &kv_A(*words, i); - if (!((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0)) { + if (!((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0)) { rspamd_cryptobox_hash_update(&st, word->stemmed.begin, word->stemmed.len); } @@ -2615,7 +2684,7 @@ fuzzy_insert_metric_results(struct rspamd_task *task, struct fuzzy_rule *rule, if (task->message) { PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) { - if (!IS_TEXT_PART_EMPTY(tp) && tp->utf_words != NULL && tp->utf_words->len > 0) { + if (!IS_TEXT_PART_EMPTY(tp) && kv_size(tp->utf_words) > 0) { seen_text_part = TRUE; if (tp->utf_stripped_text.magic == UTEXT_MAGIC) { @@ -3394,8 +3463,8 @@ register_fuzzy_client_call(struct rspamd_task *task, int sock; if (!rspamd_session_blocked(task->s)) { - /* Get upstream */ - selected = rspamd_upstream_get(rule->servers, RSPAMD_UPSTREAM_ROUND_ROBIN, + /* Get upstream - use read_servers for check operations */ + selected = rspamd_upstream_get(rule->read_servers, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); if (selected) { addr = rspamd_upstream_addr_next(selected); @@ -3522,9 +3591,8 @@ 
register_fuzzy_controller_call(struct rspamd_http_connection_entry *entry, int sock; int ret = -1; - /* Get upstream */ - - while ((selected = rspamd_upstream_get_forced(rule->servers, + /* Get upstream - use write_servers for learn/unlearn operations */ + while ((selected = rspamd_upstream_get_forced(rule->write_servers, RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) { /* Create UDP socket */ addr = rspamd_upstream_addr_next(selected); @@ -3538,6 +3606,9 @@ register_fuzzy_controller_call(struct rspamd_http_connection_entry *entry, rspamd_upstream_fail(selected, TRUE, strerror(errno)); } else { + msg_info_task("fuzzy storage %s (%s rule) is used for write", + rspamd_inet_address_to_string_pretty(addr), + rule->name); s = rspamd_mempool_alloc0(session->pool, sizeof(struct fuzzy_learn_session)); @@ -3620,6 +3691,7 @@ fuzzy_modify_handler(struct rspamd_http_connection_entry *conn_ent, PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) { if (rule->mode == fuzzy_rule_read_only) { + msg_debug_task("skip rule %s as it is read-only", rule->name); continue; } @@ -3729,6 +3801,8 @@ fuzzy_modify_handler(struct rspamd_http_connection_entry *conn_ent, else { commands = fuzzy_generate_commands(task, rule, cmd, flag, value, flags); + msg_debug_task("fuzzy command %d for rule %s, flag %d, value %d", + cmd, rule->name, flag, value); if (commands != NULL) { res = register_fuzzy_controller_call(conn_ent, rule, @@ -3894,7 +3968,7 @@ fuzzy_check_send_lua_learn(struct fuzzy_rule *rule, /* Get upstream */ if (!rspamd_session_blocked(task->s)) { - while ((selected = rspamd_upstream_get(rule->servers, + while ((selected = rspamd_upstream_get(rule->write_servers, RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) { /* Create UDP socket */ addr = rspamd_upstream_addr_next(selected); @@ -4491,9 +4565,21 @@ fuzzy_lua_list_storages(lua_State *L) lua_setfield(L, -2, "read_only"); /* Push servers */ - lua_createtable(L, rspamd_upstreams_count(rule->servers), 0); - rspamd_upstreams_foreach(rule->servers, 
lua_upstream_str_inserter, L); - lua_setfield(L, -2, "servers"); + if (rule->read_servers == rule->write_servers) { + /* Same servers for both operations */ + lua_createtable(L, rspamd_upstreams_count(rule->read_servers), 0); + rspamd_upstreams_foreach(rule->read_servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "servers"); + } + else { + /* Different servers for read and write */ + lua_createtable(L, rspamd_upstreams_count(rule->read_servers), 0); + rspamd_upstreams_foreach(rule->read_servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "read_servers"); + lua_createtable(L, rspamd_upstreams_count(rule->write_servers), 0); + rspamd_upstreams_foreach(rule->write_servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "write_servers"); + } /* Push flags */ GHashTableIter it; @@ -4780,7 +4866,7 @@ fuzzy_lua_ping_storage(lua_State *L) rspamd_ptr_array_free_hard, addrs); } else { - struct upstream *selected = rspamd_upstream_get(rule_found->servers, + struct upstream *selected = rspamd_upstream_get(rule_found->read_servers, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); addr = rspamd_upstream_addr_next(selected); } @@ -4824,4 +4910,4 @@ fuzzy_lua_ping_storage(lua_State *L) lua_pushboolean(L, TRUE); return 1; -}
\ No newline at end of file +} diff --git a/src/plugins/lua/arc.lua b/src/plugins/lua/arc.lua index fb5dd93e6..45da1f5a2 100644 --- a/src/plugins/lua/arc.lua +++ b/src/plugins/lua/arc.lua @@ -147,7 +147,7 @@ local function parse_arc_header(hdr, target, is_aar) -- sort by i= attribute table.sort(target, function(a, b) - return (a.i or 0) < (b.i or 0) + return (tonumber(a.i) or 0) < (tonumber(b.i) or 0) end) end @@ -695,11 +695,11 @@ local function do_sign(task, sign_params) sign_params.pubkey = results[1] sign_params.strict_pubkey_check = not settings.allow_pubkey_mismatch elseif not settings.allow_pubkey_mismatch then - rspamd_logger.errx('public key for domain %s/%s is not found: %s, skip signing', + rspamd_logger.errx(task, 'public key for domain %s/%s is not found: %s, skip signing', sign_params.domain, sign_params.selector, err) return else - rspamd_logger.infox('public key for domain %s/%s is not found: %s', + rspamd_logger.infox(task, 'public key for domain %s/%s is not found: %s', sign_params.domain, sign_params.selector, err) end diff --git a/src/plugins/lua/contextal.lua b/src/plugins/lua/contextal.lua new file mode 100644 index 000000000..e29c21645 --- /dev/null +++ b/src/plugins/lua/contextal.lua @@ -0,0 +1,338 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local E = {} +local N = 'contextal' + +if confighelp then + return +end + +local opts = rspamd_config:get_all_opt(N) +if not opts then + return +end + +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local redis_cache = require "lua_cache" +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types +local ucl = require "ucl" + +local cache_context, redis_params + +local contextal_actions = { + ['ALERT'] = true, + ['ALLOW'] = true, + ['BLOCK'] = true, + ['QUARANTINE'] = true, + ['SPAM'] = true, +} + +local config_schema = lua_redis.enrich_schema { + action_symbol_prefix = ts.string:is_optional(), + base_url = ts.string:is_optional(), + cache_prefix = ts.string:is_optional(), + cache_timeout = ts.number:is_optional(), + cache_ttl = ts.number:is_optional(), + custom_actions = ts.array_of(ts.string):is_optional(), + defer_if_no_result = ts.boolean:is_optional(), + defer_message = ts.string:is_optional(), + enabled = ts.boolean:is_optional(), + http_timeout = ts.number:is_optional(), + request_ttl = ts.number:is_optional(), + submission_symbol = ts.string:is_optional(), +} + +local settings = { + action_symbol_prefix = 'CONTEXTAL_ACTION', + base_url = 'http://localhost:8080', + cache_prefix = 'CXAL', + cache_timeout = 5, + cache_ttl = 3600, + custom_actions = {}, + defer_if_no_result = false, + defer_message = 'Awaiting deep scan - try again later', + http_timeout = 2, + request_ttl = 4, + submission_symbol = 'CONTEXTAL_SUBMIT', +} + +local static_boundary = rspamd_util.random_hex(32) +local wait_request_ttl = true + +local function maybe_defer(task, obj) + if settings.defer_if_no_result and not ((obj or E)[1] or E).actions then + task:set_pre_result('soft reject', settings.defer_message, N) + end +end + +local function process_actions(task, obj, is_cached) + lua_util.debugm(N, task, 'got result: %s (%s)', obj, is_cached 
and 'cached' or 'fresh') + for _, match in ipairs((obj[1] or E).actions or E) do + local act = match.action + local scenario = match.scenario + if not (act and scenario) then + rspamd_logger.err(task, 'bad result: %s', match) + elseif contextal_actions[act] then + task:insert_result(settings.action_symbol_prefix .. '_' .. act, 1.0, scenario) + else + rspamd_logger.err(task, 'unknown action: %s', act) + end + end + + if not cache_context or is_cached then + maybe_defer(task, obj) + return + end + + local cache_obj + if (obj[1] or E).actions then + cache_obj = {[1] = {["actions"] = obj[1].actions}} + else + local work_id = task:get_mempool():get_variable('contextal_work_id', 'string') + if work_id then + cache_obj = {[1] = {["work_id"] = work_id}} + else + rspamd_logger.err(task, 'no work id found in mempool') + return + end + end + + redis_cache.cache_set(task, + task:get_digest(), + cache_obj, + cache_context) + + maybe_defer(task, obj) +end + +local function process_cached(task, obj) + if (obj[1] or E).actions then + lua_util.debugm(N, task, 'using cached actions: %s', obj[1].actions) + task:disable_symbol(settings.action_symbol_prefix) + return process_actions(task, obj, true) + elseif (obj[1] or E).work_id then + lua_util.debugm(N, task, 'using old work ID: %s', obj[1].work_id) + task:get_mempool():set_variable('contextal_work_id', obj[1].work_id) + else + rspamd_logger.err(task, 'bad result (cached): %s', obj) + end +end + +local function action_cb(task) + local work_id = task:get_mempool():get_variable('contextal_work_id', 'string') + if not work_id then + rspamd_logger.err(task, 'no work id found in mempool') + return + end + lua_util.debugm(N, task, 'polling for result for work id: %s', work_id) + + local function http_callback(err, code, body, hdrs) + if err then + rspamd_logger.err(task, 'http error: %s', err) + maybe_defer(task) + return + end + if code ~= 200 then + rspamd_logger.err(task, 'bad http code: %s', code) + maybe_defer(task) + return + end + 
local parser = ucl.parser() + local _, parse_err = parser:parse_string(body) + if parse_err then + rspamd_logger.err(task, 'cannot parse JSON: %s', err) + maybe_defer(task) + return + end + local obj = parser:get_object() + return process_actions(task, obj, false) + end + + rspamd_http.request({ + task = task, + url = settings.actions_url .. work_id, + callback = http_callback, + timeout = settings.http_timeout, + gzip = settings.gzip, + keepalive = settings.keepalive, + no_ssl_verify = settings.no_ssl_verify, + }) +end + +local function submit(task) + + local function http_callback(err, code, body, hdrs) + if err then + rspamd_logger.err(task, 'http error: %s', err) + maybe_defer(task) + return + end + if code ~= 201 then + rspamd_logger.err(task, 'bad http code: %s', code) + maybe_defer(task) + return + end + local parser = ucl.parser() + local _, parse_err = parser:parse_string(body) + if parse_err then + rspamd_logger.err(task, 'cannot parse JSON: %s', err) + maybe_defer(task) + return + end + local obj = parser:get_object() + local work_id = obj.work_id + if work_id then + task:get_mempool():set_variable('contextal_work_id', work_id) + end + task:insert_result(settings.submission_symbol, 1.0, + string.format('work_id=%s', work_id or 'nil')) + if wait_request_ttl then + task:add_timer(settings.request_ttl, action_cb) + end + end + + local req = { + object_data = {['data'] = task:get_content()}, + } + if settings.request_ttl then + req.ttl = {['data'] = tostring(settings.request_ttl)} + end + if settings.max_recursion then + req.maxrec = {['data'] = tostring(settings.max_recursion)} + end + rspamd_http.request({ + task = task, + url = settings.submit_url, + body = lua_util.table_to_multipart_body(req, static_boundary), + callback = http_callback, + headers = { + ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary) + }, + timeout = settings.http_timeout, + gzip = settings.gzip, + keepalive = settings.keepalive, + no_ssl_verify = 
settings.no_ssl_verify, + }) +end + +local function cache_hit(task, err, data) + if err then + rspamd_logger.err(task, 'error getting cache: %s', err) + else + process_cached(task, data) + end +end + +local function submit_cb(task) + if cache_context then + redis_cache.cache_get(task, + task:get_digest(), + cache_context, + settings.cache_timeout, + submit, + cache_hit + ) + else + submit(task) + end +end + +local function set_url_path(base, path) + local slash = base:sub(#base) == '/' and '' or '/' + return base .. slash .. path +end + +settings = lua_util.override_defaults(settings, opts) + +local res, err = config_schema:transform(settings) +if not res then + rspamd_logger.warnx(rspamd_config, 'plugin %s is misconfigured: %s', N, err) + local err_msg = string.format("schema error: %s", res) + lua_util.config_utils.push_config_error(N, err_msg) + lua_util.disable_module(N, "failed", err_msg) + return +end + +for _, k in ipairs(settings.custom_actions) do + contextal_actions[k] = true +end + +if not settings.base_url then + if not (settings.submit_url and settings.actions_url) then + rspamd_logger.err(rspamd_config, 'no URL configured for contextal') + lua_util.disable_module(N, 'config') + return + end +else + if not settings.submit_url then + settings.submit_url = set_url_path(settings.base_url, 'api/v1/submit') + end + if not settings.actions_url then + settings.actions_url = set_url_path(settings.base_url, 'api/v1/actions/') + end +end + +redis_params = lua_redis.parse_redis_server(N) +if redis_params then + cache_context = redis_cache.create_cache_context(redis_params, { + cache_prefix = settings.cache_prefix, + cache_ttl = settings.cache_ttl, + cache_format = 'json', + cache_use_hashing = false + }) +end + +local submission_id = rspamd_config:register_symbol({ + name = settings.submission_symbol, + type = 'normal', + group = N, + callback = submit_cb +}) + +local top_options = rspamd_config:get_all_opt('options') +if settings.request_ttl and 
settings.request_ttl >= (top_options.task_timeout * 0.8) then + rspamd_logger.info(rspamd_config, [[request ttl is >= 80% of task timeout, won't wait on processing]]) + wait_request_ttl = false +elseif not settings.request_ttl then + wait_request_ttl = false +end + +local parent_id +if wait_request_ttl then + parent_id = submission_id +else + parent_id = rspamd_config:register_symbol({ + name = settings.action_symbol_prefix, + type = 'postfilter', + priority = lua_util.symbols_priorities.high - 1, + group = N, + callback = action_cb + }) +end + +for k in pairs(contextal_actions) do + rspamd_config:register_symbol({ + name = settings.action_symbol_prefix .. '_' .. k, + parent = parent_id, + type = 'virtual', + group = N, + }) +end diff --git a/src/plugins/lua/elastic.lua b/src/plugins/lua/elastic.lua index 8bed9fcf4..b26fbd8e8 100644 --- a/src/plugins/lua/elastic.lua +++ b/src/plugins/lua/elastic.lua @@ -19,6 +19,7 @@ local rspamd_logger = require 'rspamd_logger' local rspamd_http = require "rspamd_http" local lua_util = require "lua_util" local rspamd_util = require "rspamd_util" +local rspamd_text = require "rspamd_text" local ucl = require "ucl" local upstream_list = require "rspamd_upstream_list" @@ -287,7 +288,7 @@ end local function handle_error(action, component, limit) if states[component]['errors'] >= limit then rspamd_logger.errx(rspamd_config, 'cannot %s elastic %s, failed attempts: %s/%s, stop trying', - action, component:gsub('_', ' '), states[component]['errors'], limit) + action, component:gsub('_', ' '), states[component]['errors'], limit) states[component]['configured'] = true else states[component]['errors'] = states[component]['errors'] + 1 @@ -315,54 +316,6 @@ local function get_received_delay(received_headers) return delay end -local function is_empty(str) - -- define a pattern that includes invisible unicode characters - local str_cleared = str:gsub('[' .. - '\xC2\xA0' .. -- U+00A0 non-breaking space - '\xE2\x80\x8B' .. 
-- U+200B zero width space - '\xEF\xBB\xBF' .. -- U+FEFF byte order mark (zero width no-break space) - '\xE2\x80\x8C' .. -- U+200C zero width non-joiner - '\xE2\x80\x8D' .. -- U+200D zero width joiner - '\xE2\x80\x8E' .. -- U+200E left-to-right mark - '\xE2\x80\x8F' .. -- U+200F right-to-left mark - '\xE2\x81\xA0' .. -- U+2060 word joiner - '\xE2\x80\xAA' .. -- U+202A left-to-right embedding - '\xE2\x80\xAB' .. -- U+202B right-to-left embedding - '\xE2\x80\xAC' .. -- U+202C pop directional formatting - '\xE2\x80\xAD' .. -- U+202D left-to-right override - '\xE2\x80\xAE' .. -- U+202E right-to-left override - '\xE2\x81\x9F' .. -- U+2061 function application - '\xE2\x81\xA1' .. -- U+2061 invisible separator - '\xE2\x81\xA2' .. -- U+2062 invisible times - '\xE2\x81\xA3' .. -- U+2063 invisible separator - '\xE2\x81\xA4' .. -- U+2064 invisible plus - ']', '') -- gsub replaces all matched characters with an empty string - if str_cleared:match('[%S]') then - return false - else - return true - end -end - -local function fill_empty_strings(tbl, empty_value) - local filled_tbl = {} - for key, value in pairs(tbl) do - if value and type(value) == 'table' then - local nested_filtered = fill_empty_strings(value, empty_value) - if next(nested_filtered) ~= nil then - filled_tbl[key] = nested_filtered - end - elseif type(value) == 'boolean' then - filled_tbl[key] = value - elseif value and type(value) == 'string' and is_empty(value) then - filled_tbl[key] = empty_value - elseif value then - filled_tbl[key] = value - end - end - return filled_tbl -end - local function create_bulk_json(es_index, logs_to_send) local tbl = {} for _, row in pairs(logs_to_send) do @@ -407,15 +360,17 @@ local function elastic_send_data(flush_all, task, cfg, ev_base) local function http_callback(err, code, body, _) local push_done = false if err then - rspamd_logger.errx(log_object, 'cannot send logs to elastic (%s): %s; failed attempts: %s/%s', - push_url, err, buffer['errors'], 
settings['limits']['max_fail']) + rspamd_logger.errx(log_object, + 'cannot send logs to elastic (%s): %s; failed attempts: %s/%s', + push_url, err, buffer['errors'], settings['limits']['max_fail']) elseif code == 200 then local parser = ucl.parser() local res, ucl_err = parser:parse_string(body) if not ucl_err and res then local obj = parser:get_object() push_done = true - lua_util.debugm(N, log_object, 'successfully sent payload with %s logs', nlogs_to_send) + lua_util.debugm(N, log_object, + 'successfully sent payload with %s logs', nlogs_to_send) if obj['errors'] then for _, value in pairs(obj['items']) do if value['index'] and value['index']['status'] >= 400 then @@ -424,15 +379,15 @@ local function elastic_send_data(flush_all, task, cfg, ev_base) local error_type = safe_get(value, 'index', 'error', 'type') or '' local error_reason = safe_get(value, 'index', 'error', 'reason') or '' rspamd_logger.warnx(log_object, - 'error while pushing logs to elastic, status: %s, index: %s, type: %s, reason: %s', - status, index, error_type, error_reason) + 'error while pushing logs to elastic, status: %s, index: %s, type: %s, reason: %s', + status, index, error_type, error_reason) end end end else rspamd_logger.errx(log_object, - 'cannot parse response from elastic (%s): %s; failed attempts: %s/%s', - push_url, ucl_err, buffer['errors'], settings['limits']['max_fail']) + 'cannot parse response from elastic (%s): %s; failed attempts: %s/%s', + push_url, ucl_err, buffer['errors'], settings['limits']['max_fail']) end else rspamd_logger.errx(log_object, @@ -448,8 +403,8 @@ local function elastic_send_data(flush_all, task, cfg, ev_base) upstream:fail() if buffer['errors'] >= settings['limits']['max_fail'] then rspamd_logger.errx(log_object, - 'failed to send %s log lines, failed attempts: %s/%s, removing failed logs from bugger', - nlogs_to_send, buffer['errors'], settings['limits']['max_fail']) + 'failed to send %s log lines, failed attempts: %s/%s, removing failed logs from 
bugger', + nlogs_to_send, buffer['errors'], settings['limits']['max_fail']) buffer['logs']:pop_first(nlogs_to_send) buffer['errors'] = 0 else @@ -494,6 +449,7 @@ local function get_general_metadata(task) local empty = settings['index_template']['empty_value'] local user = task:get_user() r.rspamd_server = rspamd_hostname or empty + r.digest = task:get_digest() or empty r.action = task:get_metric_action() or empty r.score = task:get_metric_score()[1] or 0 @@ -565,8 +521,8 @@ local function get_general_metadata(task) if task:has_from('smtp') then local from = task:get_from({ 'smtp', 'orig' })[1] if from and - from['user'] and #from['user'] > 0 and - from['domain'] and #from['domain'] > 0 + from['user'] and #from['user'] > 0 and + from['domain'] and #from['domain'] > 0 then r.from_user = from['user'] r.from_domain = from['domain']:lower() @@ -578,8 +534,8 @@ local function get_general_metadata(task) if task:has_from('mime') then local mime_from = task:get_from({ 'mime', 'orig' })[1] if mime_from and - mime_from['user'] and #mime_from['user'] > 0 and - mime_from['domain'] and #mime_from['domain'] > 0 + mime_from['user'] and #mime_from['user'] > 0 and + mime_from['domain'] and #mime_from['domain'] > 0 then r.mime_from_user = mime_from['user'] r.mime_from_domain = mime_from['domain']:lower() @@ -608,25 +564,34 @@ local function get_general_metadata(task) local function process_header(name) local hdr = task:get_header_full(name) - local headers_text_ignore_above = settings['index_template']['headers_text_ignore_above'] - 3 if hdr and #hdr > 0 then local l = {} for _, h in ipairs(hdr) do - if settings['index_template']['headers_count_ignore_above'] ~= 0 and - #l >= settings['index_template']['headers_count_ignore_above'] + if settings['index_template']['headers_count_ignore_above'] > 0 and + #l >= settings['index_template']['headers_count_ignore_above'] then table.insert(l, 'ignored above...') break end local header - if 
settings['index_template']['headers_text_ignore_above'] ~= 0 and - h.decoded and #h.decoded >= headers_text_ignore_above - then - header = h.decoded:sub(1, headers_text_ignore_above) .. '...' - elseif h.decoded and #h.decoded > 0 then - header = h.decoded + local header_len + if h.decoded then + header = rspamd_text.fromstring(h.decoded) + header_len = header:len_utf8() else - header = empty + table.insert(l, empty) + break + end + if not header_len or header_len == 0 then + table.insert(l, empty) + break + end + if settings['index_template']['headers_text_ignore_above'] > 0 and + header_len >= settings['index_template']['headers_text_ignore_above'] + then + header = header:sub_utf8(1, settings['index_template']['headers_text_ignore_above']) + table.insert(l, header .. rspamd_text.fromstring('...')) + break end table.insert(l, header) end @@ -686,7 +651,7 @@ local function get_general_metadata(task) r.received_delay = get_received_delay(task:get_received_headers()) - return fill_empty_strings(r, empty) + return r end local function elastic_collect(task) @@ -773,8 +738,8 @@ local function configure_geoip_pipeline(cfg, ev_base) upstream:ok() else rspamd_logger.errx(rspamd_config, - 'cannot configure elastic geoip pipeline (%s), status code: %s, response: %s', - geoip_url, code, body) + 'cannot configure elastic geoip pipeline (%s), status code: %s, response: %s', + geoip_url, code, body) upstream:fail() handle_error('configure', 'geoip_pipeline', settings['limits']['max_fail']) end @@ -810,8 +775,9 @@ local function put_index_policy(cfg, ev_base, upstream, host, policy_url, index_ states['index_policy']['configured'] = true upstream:ok() else - rspamd_logger.errx(rspamd_config, 'cannot configure elastic index policy (%s), status code: %s, response: %s', - policy_url, code, body) + rspamd_logger.errx(rspamd_config, + 'cannot configure elastic index policy (%s), status code: %s, response: %s', + policy_url, code, body) upstream:fail() handle_error('configure', 
'index_policy', settings['limits']['max_fail']) end @@ -867,7 +833,7 @@ local function get_index_policy(cfg, ev_base, upstream, host, policy_url, index_ if not lua_util.table_cmp(our_policy['policy']['default_state'], current_default_state) then update_needed = true elseif not lua_util.table_cmp(our_policy['policy']['ism_template'][1]['index_patterns'], - current_ism_index_patterns) then + current_ism_index_patterns) then update_needed = true elseif not lua_util.table_cmp(our_policy['policy']['states'], current_states) then update_needed = true @@ -890,8 +856,8 @@ local function get_index_policy(cfg, ev_base, upstream, host, policy_url, index_ put_index_policy(cfg, ev_base, upstream, host, policy_url, index_policy_json) else rspamd_logger.errx(rspamd_config, - 'current elastic index policy (%s) not returned correct seq_no/primary_term, policy will not be updated, response: %s', - policy_url, body) + 'current elastic index policy (%s) not returned correct seq_no/primary_term, policy will not be updated, response: %s', + policy_url, body) upstream:fail() handle_error('validate current', 'index_policy', settings['limits']['max_fail']) end @@ -909,8 +875,8 @@ local function get_index_policy(cfg, ev_base, upstream, host, policy_url, index_ end else rspamd_logger.errx(rspamd_config, - 'cannot get current elastic index policy (%s), status code: %s, response: %s', - policy_url, code, body) + 'cannot get current elastic index policy (%s), status code: %s, response: %s', + policy_url, code, body) handle_error('get current', 'index_policy', settings['limits']['max_fail']) upstream:fail() end @@ -1037,7 +1003,7 @@ local function configure_index_policy(cfg, ev_base) } index_policy['policy']['phases']['delete'] = delete_obj end - -- opensearch state policy with hot state + -- opensearch state policy with hot state elseif detected_distro['name'] == 'opensearch' then local retry = { count = 3, @@ -1313,6 +1279,7 @@ local function configure_index_template(cfg, ev_base) type = 
'object', properties = { rspamd_server = t_keyword, + digest = t_keyword, action = t_keyword, score = t_double, symbols = symbols_obj, @@ -1381,7 +1348,7 @@ local function configure_index_template(cfg, ev_base) upstream:ok() else rspamd_logger.errx(rspamd_config, 'cannot configure elastic index template (%s), status code: %s, response: %s', - template_url, code, body) + template_url, code, body) upstream:fail() handle_error('configure', 'index_template', settings['limits']['max_fail']) end @@ -1424,8 +1391,9 @@ local function verify_distro(manual) local supported_distro_info = supported_distro[detected_distro_name] -- check that detected_distro_version is valid if not detected_distro_version or type(detected_distro_version) ~= 'string' then - rspamd_logger.errx(rspamd_config, 'elastic version should be a string, but we received: %s', - type(detected_distro_version)) + rspamd_logger.errx(rspamd_config, + 'elastic version should be a string, but we received: %s', + type(detected_distro_version)) valid = false elseif detected_distro_version == '' then rspamd_logger.errx(rspamd_config, 'unsupported elastic version: empty string') @@ -1434,21 +1402,22 @@ local function verify_distro(manual) -- compare versions using compare_versions local cmp_from = compare_versions(detected_distro_version, supported_distro_info['from']) if cmp_from == -1 then - rspamd_logger.errx(rspamd_config, 'unsupported elastic version: %s, minimal supported version of %s is %s', - detected_distro_version, detected_distro_name, supported_distro_info['from']) + rspamd_logger.errx(rspamd_config, + 'unsupported elastic version: %s, minimal supported version of %s is %s', + detected_distro_version, detected_distro_name, supported_distro_info['from']) valid = false else local cmp_till = compare_versions(detected_distro_version, supported_distro_info['till']) if (cmp_till >= 0) and not supported_distro_info['till_unknown'] then rspamd_logger.errx(rspamd_config, - 'unsupported elastic version: %s, maximum 
supported version of %s is less than %s', - detected_distro_version, detected_distro_name, supported_distro_info['till']) + 'unsupported elastic version: %s, maximum supported version of %s is less than %s', + detected_distro_version, detected_distro_name, supported_distro_info['till']) valid = false elseif (cmp_till >= 0) and supported_distro_info['till_unknown'] then rspamd_logger.warnx(rspamd_config, - 'compatibility of elastic version: %s is unknown, maximum known supported version of %s is less than %s,' .. - 'use at your own risk', - detected_distro_version, detected_distro_name, supported_distro_info['till']) + 'compatibility of elastic version: %s is unknown, maximum known ' .. + 'supported version of %s is less than %s, use at your own risk', + detected_distro_version, detected_distro_name, supported_distro_info['till']) valid_unknown = true end end @@ -1460,11 +1429,12 @@ local function verify_distro(manual) else if valid and manual then rspamd_logger.infox( - rspamd_config, 'assuming elastic distro: %s, version: %s', detected_distro_name, detected_distro_version) + rspamd_config, 'assuming elastic distro: %s, version: %s', detected_distro_name, detected_distro_version) detected_distro['supported'] = true elseif valid and not manual then - rspamd_logger.infox(rspamd_config, 'successfully connected to elastic distro: %s, version: %s', - detected_distro_name, detected_distro_version) + rspamd_logger.infox(rspamd_config, + 'successfully connected to elastic distro: %s, version: %s', + detected_distro_name, detected_distro_version) detected_distro['supported'] = true else handle_error('configure', 'distro', settings['version']['autodetect_max_fail']) @@ -1477,7 +1447,7 @@ local function configure_distro(cfg, ev_base) detected_distro['name'] = settings['version']['override']['name'] detected_distro['version'] = settings['version']['override']['version'] rspamd_logger.infox(rspamd_config, - 'automatic detection of elastic distro and version is disabled, taking 
configuration from settings') + 'automatic detection of elastic distro and version is disabled, taking configuration from settings') verify_distro(true) end @@ -1490,14 +1460,16 @@ local function configure_distro(cfg, ev_base) rspamd_logger.errx(rspamd_config, 'cannot connect to elastic (%s): %s', root_url, err) upstream:fail() elseif code ~= 200 then - rspamd_logger.errx(rspamd_config, 'cannot connect to elastic (%s), status code: %s, response: %s', root_url, code, - body) + rspamd_logger.errx(rspamd_config, + 'cannot connect to elastic (%s), status code: %s, response: %s', + root_url, code, body) upstream:fail() else local parser = ucl.parser() local res, ucl_err = parser:parse_string(body) if not res then - rspamd_logger.errx(rspamd_config, 'failed to parse reply from elastic (%s): %s', root_url, ucl_err) + rspamd_logger.errx(rspamd_config, 'failed to parse reply from elastic (%s): %s', + root_url, ucl_err) upstream:fail() else local obj = parser:get_object() diff --git a/src/plugins/lua/fuzzy_collect.lua b/src/plugins/lua/fuzzy_collect.lua index 132ace90c..060cc2fc2 100644 --- a/src/plugins/lua/fuzzy_collect.lua +++ b/src/plugins/lua/fuzzy_collect.lua @@ -34,7 +34,7 @@ local settings = { local function send_data_mirror(m, cfg, ev_base, body) local function store_callback(err, _, _, _) if err then - rspamd_logger.errx(cfg, 'cannot save data on %(%s): %s', m.server, m.name, err) + rspamd_logger.errx(cfg, 'cannot save data on %s(%s): %s', m.server, m.name, err) else rspamd_logger.infox(cfg, 'saved data on %s(%s)', m.server, m.name) end diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index feccae73f..5776791a1 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -15,13 +15,14 @@ limitations under the License. 
]] -- local N = "gpt" +local REDIS_PREFIX = "rsllm" local E = {} if confighelp then rspamd_config:add_example(nil, 'gpt', - "Performs postfiltering using GPT model", - [[ -gpt { + "Performs postfiltering using GPT model", + [[ + gpt { # Supported types: openai, ollama type = "openai"; # Your key to access the API @@ -48,7 +49,11 @@ gpt { allow_passthrough = false; # Check messages that are apparent ham (no action and negative score) allow_ham = false; -} + # Add header with reason (null to disable) + reason_header = "X-GPT-Reason"; + # Use JSON format for response + json = false; + } ]]) return end @@ -57,8 +62,10 @@ local lua_util = require "lua_util" local rspamd_http = require "rspamd_http" local rspamd_logger = require "rspamd_logger" local lua_mime = require "lua_mime" +local lua_redis = require "lua_redis" local ucl = require "ucl" local fun = require "fun" +local lua_cache = require "lua_cache" -- Exclude checks if one of those is found local default_symbols_to_except = { @@ -71,6 +78,32 @@ local default_symbols_to_except = { BOUNCE = -1, } +local default_extra_symbols = { + GPT_MARKETING = { + score = 0.0, + description = 'GPT model detected marketing content', + category = 'marketing', + }, + GPT_PHISHING = { + score = 3.0, + description = 'GPT model detected phishing content', + category = 'phishing', + }, + GPT_SCAM = { + score = 3.0, + description = 'GPT model detected scam content', + category = 'scam', + }, + GPT_MALWARE = { + score = 3.0, + description = 'GPT model detected malware content', + category = 'malware', + }, +} + +-- Should be filled from extra symbols +local categories_map = {} + local settings = { type = 'openai', api_key = nil, @@ -81,11 +114,18 @@ local settings = { prompt = nil, condition = nil, autolearn = false, + reason_header = nil, url = 'https://api.openai.com/v1/chat/completions', - symbols_to_except = default_symbols_to_except, + symbols_to_except = nil, + symbols_to_trigger = nil, -- Exclude/include logic allow_passthrough = 
false, allow_ham = false, + json = false, + extra_symbols = nil, + cache_prefix = REDIS_PREFIX, } +local redis_params +local cache_context local function default_condition(task) -- Check result @@ -108,22 +148,44 @@ local function default_condition(task) return false, 'negative score, already decided as ham' end end - -- We also exclude some symbols - for s, required_weight in pairs(settings.symbols_to_except) do - if task:has_symbol(s) then - if required_weight > 0 then - -- Also check score - local sym = task:get_symbol(s) or E - -- Must exist as we checked it before with `has_symbol` - if sym.weight then - if math.abs(sym.weight) >= required_weight then - return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')' + + if settings.symbols_to_except then + for s, required_weight in pairs(settings.symbols_to_except) do + if task:has_symbol(s) then + if required_weight > 0 then + -- Also check score + local sym = task:get_symbol(s) or E + -- Must exist as we checked it before with `has_symbol` + if sym.weight then + if math.abs(sym.weight) >= required_weight then + return false, 'skip as "' .. s .. '" is found (weight: ' .. sym.weight .. ')' + end end + lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, + sym.weight, required_weight) + else + return false, 'skip as "' .. s .. '" is found' end - lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, + end + end + end + if settings.symbols_to_trigger then + for s, required_weight in pairs(settings.symbols_to_trigger) do + if task:has_symbol(s) then + if required_weight > 0 then + -- Also check score + local sym = task:get_symbol(s) or E + -- Must exist as we checked it before with `has_symbol` + if sym.weight then + if math.abs(sym.weight) < required_weight then + return false, 'skip as "' .. s .. '" is found with low weight (weight: ' .. sym.weight .. 
')' + end + end + lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, sym.weight, required_weight) + end else - return false, 'skip as "' .. s .. '" is found' + return false, 'skip as "' .. s .. '" is not found' end end end @@ -147,10 +209,10 @@ local function default_condition(task) local words = sel_part:get_words('norm') nwords = #words if nwords > settings.max_tokens then - return true, table.concat(words, ' ', 1, settings.max_tokens) + return true, table.concat(words, ' ', 1, settings.max_tokens), sel_part end end - return true, sel_part:get_content_oneline() + return true, sel_part:get_content_oneline(), sel_part end local function maybe_extract_json(str) @@ -191,7 +253,7 @@ local function maybe_extract_json(str) return nil end -local function default_conversion(task, input) +local function default_openai_json_conversion(task, input) local parser = ucl.parser() local res, err = parser:parse_string(input) if not res then @@ -239,7 +301,7 @@ local function default_conversion(task, input) elseif reply.probability == "low" then spam_score = 0.1 else - rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability) + rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability) end end @@ -247,14 +309,116 @@ local function default_conversion(task, input) rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) end - return spam_score + return spam_score, reply.reason, {} end rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) return end -local function ollama_conversion(task, input) +-- Remove what we don't need +local function clean_reply_line(line) + if not line then + return '' + end + return lua_util.str_trim(line):gsub("^%d%.%s+", "") +end + +-- Assume that we have 3 lines: probability, reason, additional symbols +local function default_openai_plain_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res 
then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.choices) ~= 'table' or type(reply.choices[1]) ~= 'table' then + rspamd_logger.errx(task, 'no choices in reply') + return + end + + local first_message = reply.choices[1].message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first message') + return + end + local lines = lua_util.str_split(first_message, '\n') + local first_line = clean_reply_line(lines[1]) + local spam_score = tonumber(first_line) + local reason = clean_reply_line(lines[2]) + local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') + + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + + if spam_score then + return spam_score, reason, categories + end + + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message) + return +end + +-- Helper function to remove <think>...</think> and trim leading newlines +local function clean_gpt_response(text) + -- Remove <think>...</think> including multiline + text = text:gsub("<think>.-</think>", "") + -- Trim leading whitespace and newlines + text = text:gsub("^%s*\n*", "") + return text +end + +local function default_ollama_plain_conversion(task, input) + local parser = ucl.parser() + local res, err = parser:parse_string(input) + if not res then + rspamd_logger.errx(task, 'cannot parse reply: %s', err) + return + end + local reply = parser:get_object() + if not reply then + rspamd_logger.errx(task, 'cannot get object from reply') + return + end + + if type(reply.message) ~= 'table' then + rspamd_logger.errx(task, 'bad message in reply') + return + end + + local first_message = reply.message.content + + if not first_message then + rspamd_logger.errx(task, 'no content in the first 
message') + return + end + + -- Clean message + first_message = clean_gpt_response(first_message) + + local lines = lua_util.str_split(first_message, '\n') + local first_line = clean_reply_line(lines[1]) + local spam_score = tonumber(first_line) + local reason = clean_reply_line(lines[2]) + local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') + + if spam_score then + return spam_score, reason, categories + end + + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message) + return +end + +local function default_ollama_json_conversion(task, input) local parser = ucl.parser() local res, err = parser:parse_string(input) if not res then @@ -302,7 +466,7 @@ local function ollama_conversion(task, input) elseif reply.probability == "low" then spam_score = 0.1 else - rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability) + rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability) end end @@ -310,13 +474,126 @@ local function ollama_conversion(task, input) rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) end - return spam_score + return spam_score, reply.reason end rspamd_logger.errx(task, 'cannot convert spam score: %s', first_message) return end +-- Make cache specific to all settings to avoid conflicts +local env_digest = nil + +local function redis_cache_key(sel_part) + if not env_digest then + local hasher = require "rspamd_cryptobox_hash" + local digest = hasher.create() + digest:update(settings.prompt) + digest:update(settings.model) + digest:update(settings.url) + env_digest = digest:hex():sub(1, 4) + end + return string.format('%s_%s', env_digest, + sel_part:get_mimepart():get_digest():sub(1, 24)) +end + +local function process_categories(task, categories) + for _, category in ipairs(categories) do + local sym = categories_map[category:lower()] + if sym then + task:insert_result(sym.name, 1.0) + end + end +end + +local function 
insert_results(task, result, sel_part) + if not result.probability then + rspamd_logger.errx(task, 'no probability in result') + return + end + + if result.probability > 0.5 then + task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability)) + if settings.autolearn then + task:set_flag("learn_spam") + end + + if result.categories then + process_categories(task, result.categories) + end + else + task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability)) + if settings.autolearn then + task:set_flag("learn_ham") + end + if result.categories then + process_categories(task, result.categories) + end + end + if result.reason and settings.reason_header then + lua_mime.modify_headers(task, + { add = { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } }) + end + + if cache_context then + lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context) + end +end + +local function check_consensus_and_insert_results(task, results, sel_part) + for _, result in ipairs(results) do + if not result.checked then + return + end + end + + local nspam, nham = 0, 0 + local max_spam_prob, max_ham_prob = 0, 0 + local reasons = {} + + for _, result in ipairs(results) do + if result.success then + if result.probability > 0.5 then + nspam = nspam + 1 + max_spam_prob = math.max(max_spam_prob, result.probability) + lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'", + result.model, result.probability, result.reason) + else + nham = nham + 1 + max_ham_prob = math.min(max_ham_prob, result.probability) + lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'", + result.model, result.probability, result.reason) + end + + if result.reason then + table.insert(reasons, result) + end + end + end + + lua_util.shuffle(reasons) + local reason = reasons[1] or nil + + if nspam > nham and max_spam_prob > 0.75 then + insert_results(task, { + probability = max_spam_prob, + reason = 
reason.reason, + categories = reason.categories, + }, + sel_part) + elseif nham > nspam and max_ham_prob < 0.25 then + insert_results(task, { + probability = max_ham_prob, + reason = reason.reason, + categories = reason.categories, + }, + sel_part) + else + -- No consensus + lua_util.debugm(N, task, "no consensus") + end +end + local function get_meta_llm_content(task) local url_content = "Url domains: no urls found" if task:has_urls() then @@ -334,57 +611,70 @@ local function get_meta_llm_content(task) return url_content, from_content end -local function default_llm_check(task) - local ret, content = settings.condition(task) +local function check_llm_uncached(task, content, sel_part) + return settings.specific_check(task, content, sel_part) +end - if not ret then - rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) - return - end +local function check_llm_cached(task, content, sel_part) + local cache_key = redis_cache_key(sel_part) - if not content then - lua_util.debugm(N, task, "no content to send to gpt classification") - return - end + lua_cache.cache_get(task, cache_key, cache_context, settings.timeout * 1.5, function() + check_llm_uncached(task, content, sel_part) + end, function(_, err, data) + if err then + rspamd_logger.errx(task, 'cannot get cache: %s', err) + check_llm_uncached(task, content, sel_part) + end + if data then + rspamd_logger.infox(task, 'found cached response %s', cache_key) + insert_results(task, data, sel_part) + else + check_llm_uncached(task, content, sel_part) + end + end) +end + +local function openai_check(task, content, sel_part) lua_util.debugm(N, task, "sending content to gpt: %s", content) local upstream - local function on_reply(err, code, body) + local results = {} - if err then - rspamd_logger.errx(task, 'request failed: %s', err) - upstream:fail() - return - end + local function gen_reply_closure(model, idx) + return function(err, code, body) + results[idx].checked = true + if err then + 
rspamd_logger.errx(task, '%s: request failed: %s', model, err) + upstream:fail() + check_consensus_and_insert_results(task, results, sel_part) + return + end - upstream:ok() - lua_util.debugm(N, task, "got reply: %s", body) - if code ~= 200 then - rspamd_logger.errx(task, 'bad reply: %s', body) - return - end + upstream:ok() + lua_util.debugm(N, task, "%s: got reply: %s", model, body) + if code ~= 200 then + rspamd_logger.errx(task, 'bad reply: %s', body) + return + end - local reply = settings.reply_conversion(task, body) - if not reply then - return - end + local reply, reason, categories = settings.reply_conversion(task, body) - if reply > 0.75 then - task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply)) - if settings.autolearn then - task:set_flag("learn_spam") - end - elseif reply < 0.25 then - task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply)) - if settings.autolearn then - task:set_flag("learn_ham") + results[idx].model = model + + if reply then + results[idx].success = true + results[idx].probability = reply + results[idx].reason = reason + + if categories then + results[idx].categories = categories + end end - else - lua_util.debugm(N, task, "uncertain result: %s", reply) - end + check_consensus_and_insert_results(task, results, sel_part) + end end local from_content, url_content = get_meta_llm_content(task) @@ -393,7 +683,6 @@ local function default_llm_check(task) model = settings.model, max_tokens = settings.max_tokens, temperature = settings.temperature, - response_format = { type = "json_object" }, messages = { { role = 'system', @@ -401,7 +690,7 @@ local function default_llm_check(task) }, { role = 'user', - content = 'Subject: ' .. task:get_subject() or '', + content = 'Subject: ' .. 
(task:get_subject() or ''), }, { role = 'user', @@ -418,87 +707,92 @@ local function default_llm_check(task) } } - upstream = settings.upstreams:get_upstream_round_robin() - local http_params = { - url = settings.url, - mime_type = 'application/json', - timeout = settings.timeout, - log_obj = task, - callback = on_reply, - headers = { - ['Authorization'] = 'Bearer ' .. settings.api_key, - }, - keepalive = true, - body = ucl.to_format(body, 'json-compact', true), - task = task, - upstream = upstream, - use_gzip = true, - } - - rspamd_http.request(http_params) -end - -local function ollama_check(task) - local ret, content = settings.condition(task) + -- Conditionally add response_format + if settings.include_response_format then + body.response_format = { type = "json_object" } + end - if not ret then - rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) - return + if type(settings.model) == 'string' then + settings.model = { settings.model } end - if not content then - lua_util.debugm(N, task, "no content to send to gpt classification") - return + upstream = settings.upstreams:get_upstream_round_robin() + for idx, model in ipairs(settings.model) do + results[idx] = { + success = false, + checked = false + } + body.model = model + local http_params = { + url = settings.url, + mime_type = 'application/json', + timeout = settings.timeout, + log_obj = task, + callback = gen_reply_closure(model, idx), + headers = { + ['Authorization'] = 'Bearer ' .. 
settings.api_key, + }, + keepalive = true, + body = ucl.to_format(body, 'json-compact', true), + task = task, + upstream = upstream, + use_gzip = true, + } + + if not rspamd_http.request(http_params) then + results[idx].checked = true + end end +end +local function ollama_check(task, content, sel_part) lua_util.debugm(N, task, "sending content to gpt: %s", content) local upstream + local results = {} + + local function gen_reply_closure(model, idx) + return function(err, code, body) + results[idx].checked = true + if err then + rspamd_logger.errx(task, '%s: request failed: %s', model, err) + upstream:fail() + check_consensus_and_insert_results(task, results, sel_part) + return + end - local function on_reply(err, code, body) - - if err then - rspamd_logger.errx(task, 'request failed: %s', err) - upstream:fail() - return - end + upstream:ok() + lua_util.debugm(N, task, "%s: got reply: %s", model, body) + if code ~= 200 then + rspamd_logger.errx(task, 'bad reply: %s', body) + return + end - upstream:ok() - lua_util.debugm(N, task, "got reply: %s", body) - if code ~= 200 then - rspamd_logger.errx(task, 'bad reply: %s', body) - return - end + local reply, reason = settings.reply_conversion(task, body) - local reply = settings.reply_conversion(task, body) - if not reply then - return - end + results[idx].model = model - if reply > 0.75 then - task:insert_result('GPT_SPAM', (reply - 0.75) * 4, tostring(reply)) - if settings.autolearn then - task:set_flag("learn_spam") + if reply then + results[idx].success = true + results[idx].probability = reply + results[idx].reason = reason end - elseif reply < 0.25 then - task:insert_result('GPT_HAM', (0.25 - reply) * 4, tostring(reply)) - if settings.autolearn then - task:set_flag("learn_ham") - end - else - lua_util.debugm(N, task, "uncertain result: %s", reply) - end + check_consensus_and_insert_results(task, results, sel_part) + end end local from_content, url_content = get_meta_llm_content(task) + if type(settings.model) == 
'string' then + settings.model = { settings.model } + end + local body = { stream = false, model = settings.model, max_tokens = settings.max_tokens, temperature = settings.temperature, - response_format = { type = "json_object" }, messages = { { role = 'system', @@ -523,50 +817,91 @@ local function ollama_check(task) } } - upstream = settings.upstreams:get_upstream_round_robin() - local http_params = { - url = settings.url, - mime_type = 'application/json', - timeout = settings.timeout, - log_obj = task, - callback = on_reply, - keepalive = true, - body = ucl.to_format(body, 'json-compact', true), - task = task, - upstream = upstream, - use_gzip = true, - } + for i, model in ipairs(settings.model) do + -- Conditionally add response_format + if settings.include_response_format then + body.response_format = { type = "json_object" } + end + + results[i] = { + success = false, + checked = false + } + body.model = model + + upstream = settings.upstreams:get_upstream_round_robin() + local http_params = { + url = settings.url, + mime_type = 'application/json', + timeout = settings.timeout, + log_obj = task, + callback = gen_reply_closure(model, i), + keepalive = true, + body = ucl.to_format(body, 'json-compact', true), + task = task, + upstream = upstream, + use_gzip = true, + } - rspamd_http.request(http_params) + rspamd_http.request(http_params) + end end local function gpt_check(task) - return settings.specific_check(task) + local ret, content, sel_part = settings.condition(task) + + if not ret then + rspamd_logger.info(task, "skip checking gpt as the condition is not met: %s", content) + return + end + + if not content then + lua_util.debugm(N, task, "no content to send to gpt classification") + return + end + + if sel_part then + -- Check digest + check_llm_cached(task, content, sel_part) + else + check_llm_uncached(task, content) + end end local types_map = { openai = { - check = default_llm_check, + check = openai_check, condition = default_condition, - conversion 
= default_conversion, + conversion = function(is_json) + return is_json and default_openai_json_conversion or default_openai_plain_conversion + end, require_passkey = true, }, ollama = { check = ollama_check, condition = default_condition, - conversion = ollama_conversion, + conversion = function(is_json) + return is_json and default_ollama_json_conversion or default_ollama_plain_conversion + end, require_passkey = false, }, } -local opts = rspamd_config:get_all_opt('gpt') +local opts = rspamd_config:get_all_opt(N) if opts then + redis_params = lua_redis.parse_redis_server(N, opts) settings = lua_util.override_defaults(settings, opts) - if not settings.prompt then - settings.prompt = "You will be provided with the email message, subject, from and url domains, " .. - "and your task is to evaluate the probability to be spam as number from 0 to 1, " .. - "output result as JSON with 'probability' field." + if redis_params then + cache_context = lua_cache.create_cache_context(redis_params, settings, N) + end + + if not settings.symbols_to_except then + settings.symbols_to_except = default_symbols_to_except + end + + if not settings.extra_symbols then + settings.extra_symbols = default_extra_symbols end local llm_type = types_map[settings.type] @@ -586,7 +921,7 @@ if opts then if settings.reply_conversion then settings.reply_conversion = load(settings.reply_conversion)() else - settings.reply_conversion = llm_type.conversion + settings.reply_conversion = llm_type.conversion(settings.json) end if not settings.api_key and llm_type.require_passkey then @@ -610,7 +945,7 @@ if opts then name = 'GPT_SPAM', type = 'virtual', parent = id, - score = 5.0, + score = 3.0, }) rspamd_config:register_symbol({ name = 'GPT_HAM', @@ -618,4 +953,35 @@ if opts then parent = id, score = -2.0, }) -end
\ No newline at end of file + + if settings.extra_symbols then + for sym, data in pairs(settings.extra_symbols) do + rspamd_config:register_symbol({ + name = sym, + type = 'virtual', + parent = id, + score = data.score, + description = data.description, + }) + data.name = sym + categories_map[data.category] = data + end + end + + if not settings.prompt then + if settings.extra_symbols then + settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. + "FROM and url domains. Evaluate spam probability (0-1). " .. + "Output ONLY 3 lines:\n" .. + "1. Numeric score (0.00-1.00)\n" .. + "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" .. + "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ') + else + settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. + "FROM and url domains. Evaluate spam probability (0-1). " .. + "Output ONLY 2 lines:\n" .. + "1. Numeric score (0.00-1.00)\n" .. + "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" + end + end +end diff --git a/src/plugins/lua/greylist.lua b/src/plugins/lua/greylist.lua index e4a633233..934e17bce 100644 --- a/src/plugins/lua/greylist.lua +++ b/src/plugins/lua/greylist.lua @@ -122,6 +122,29 @@ local function data_key(task) local h = hash.create() h:update(body, len) + local subject = task:get_subject() or '' + h:update(subject) + + -- Take recipients into account + local rcpt = task:get_recipients('smtp') + if rcpt then + table.sort(rcpt, function(r1, r2) + return r1['addr'] < r2['addr'] + end) + + fun.each(function(r) + h:update(r['addr']) + end, rcpt) + end + + -- Use from as well, but mime one + local from = task:get_from('mime') + + local addr = '<>' + if from and from[1] then + addr = from[1]['addr'] + end + h:update(addr) local b32 = settings['key_prefix'] .. 'b' .. 
h:base32():sub(1, 20) task:get_mempool():set_variable("grey_bodyhash", b32) diff --git a/src/plugins/lua/hfilter.lua b/src/plugins/lua/hfilter.lua index 6bc011b83..32102e4f8 100644 --- a/src/plugins/lua/hfilter.lua +++ b/src/plugins/lua/hfilter.lua @@ -131,6 +131,7 @@ local checks_hellohost = [[ /modem[.-][0-9]/i 5 /[0-9][.-]?dhcp/i 5 /wifi[.-][0-9]/i 5 +/[.-]vps[.-]/i 1 ]] local checks_hellohost_map @@ -199,9 +200,10 @@ local function check_regexp(str, regexp_text) return re:match(str) end -local function add_static_map(data) +local function add_static_map(data, description) return rspamd_config:add_map { type = 'regexp_multi', + description = description, url = { upstreams = 'static', data = data, @@ -568,16 +570,16 @@ local function append_t(t, a) end end if config['helo_enabled'] then - checks_hello_bareip_map = add_static_map(checks_hello_bareip) - checks_hello_badip_map = add_static_map(checks_hello_badip) - checks_hellohost_map = add_static_map(checks_hellohost) - checks_hello_map = add_static_map(checks_hello) + checks_hello_bareip_map = add_static_map(checks_hello_bareip, 'Hfilter: HELO bare ip') + checks_hello_badip_map = add_static_map(checks_hello_badip, 'Hfilter: HELO bad ip') + checks_hellohost_map = add_static_map(checks_hellohost, 'Hfilter: HELO host') + checks_hello_map = add_static_map(checks_hello, 'Hfilter: HELO') append_t(symbols_enabled, symbols_helo) timeout = math.max(timeout, rspamd_config:get_dns_timeout() * 3) end if config['hostname_enabled'] then if not checks_hellohost_map then - checks_hellohost_map = add_static_map(checks_hellohost) + checks_hellohost_map = add_static_map(checks_hellohost, 'Hfilter: HOSTNAME') end append_t(symbols_enabled, symbols_hostname) timeout = math.max(timeout, rspamd_config:get_dns_timeout()) diff --git a/src/plugins/lua/history_redis.lua b/src/plugins/lua/history_redis.lua index a3fdb0ec4..44eb40ad9 100644 --- a/src/plugins/lua/history_redis.lua +++ b/src/plugins/lua/history_redis.lua @@ -138,7 +138,7 @@ end 
local function history_save(task) local function redis_llen_cb(err, _) if err then - rspamd_logger.errx(task, 'got error %s when writing history row: %s', + rspamd_logger.errx(task, 'got error %s when writing history row', err) end end @@ -188,7 +188,7 @@ local function handle_history_request(task, conn, from, to, reset) if reset then local function redis_ltrim_cb(err, _) if err then - rspamd_logger.errx(task, 'got error %s when resetting history: %s', + rspamd_logger.errx(task, 'got error %s when resetting history', err) conn:send_error(504, '{"error": "' .. err .. '"}') else @@ -258,7 +258,7 @@ local function handle_history_request(task, conn, from, to, reset) (rspamd_util:get_ticks() - t1) * 1000.0) collectgarbage() else - rspamd_logger.errx(task, 'got error %s when getting history: %s', + rspamd_logger.errx(task, 'got error %s when getting history', err) conn:send_error(504, '{"error": "' .. err .. '"}') end diff --git a/src/plugins/lua/known_senders.lua b/src/plugins/lua/known_senders.lua index 5cb2ddcf5..0cbf3cdcf 100644 --- a/src/plugins/lua/known_senders.lua +++ b/src/plugins/lua/known_senders.lua @@ -106,21 +106,26 @@ local function configure_scripts(_, _, _) -- script checks if given recipients are in the local replies set of the sender local redis_zscore_script = [[ local replies_recipients_addrs = ARGV - if replies_recipients_addrs then + if replies_recipients_addrs and #replies_recipients_addrs > 0 then + local found = false for _, rcpt in ipairs(replies_recipients_addrs) do local score = redis.call('ZSCORE', KEYS[1], rcpt) - -- check if score is nil (for some reason redis script does not see if score is a nil value) - if type(score) == 'boolean' then - score = nil - -- 0 is stand for failure code - return 0 + if score then + -- If we found at least one recipient, consider it a match + found = true + break end end - -- first number in return statement is stands for the success/failure code - -- where success code is 1 and failure code is 0 - return 1 + 
+ if found then + -- Success code is 1 + return 1 + else + -- Failure code is 0 + return 0 + end else - -- 0 is a failure code + -- No recipients to check, failure code is 0 return 0 end ]] @@ -259,7 +264,13 @@ local function verify_local_replies_set(task) return nil end - local replies_recipients = task:get_recipients('mime') or E + local replies_recipients = task:get_recipients('smtp') or E + + -- If no recipients, don't proceed + if #replies_recipients == 0 then + lua_util.debugm(N, task, 'No recipients to verify') + return nil + end local replies_sender_string = lua_util.maybe_obfuscate_string(tostring(replies_sender), settings, settings.sender_prefix) @@ -268,13 +279,16 @@ local function verify_local_replies_set(task) local function redis_zscore_script_cb(err, data) if err ~= nil then rspamd_logger.errx(task, 'Could not verify %s local replies set %s', replies_sender_key, err) - end - if data ~= 1 then - lua_util.debugm(N, task, 'Recipients were not verified') return end - lua_util.debugm(N, task, 'Recipients were verified') - task:insert_result(settings.symbol_check_mail_local, 1.0, replies_sender_key) + + -- We need to ensure we're properly checking the result + if data == 1 then + lua_util.debugm(N, task, 'Recipients were verified') + task:insert_result(settings.symbol_check_mail_local, 1.0, replies_sender_key) + else + lua_util.debugm(N, task, 'Recipients were not verified, data=%s', data) + end end local replies_recipients_addrs = {} @@ -284,12 +298,24 @@ local function verify_local_replies_set(task) table.insert(replies_recipients_addrs, replies_recipients[i].addr) end - lua_util.debugm(N, task, 'Making redis request to local replies set') - lua_redis.exec_redis_script(zscore_script_id, + -- Only proceed if we have recipients to check + if #replies_recipients_addrs == 0 then + lua_util.debugm(N, task, 'No recipient addresses to verify') + return nil + end + + lua_util.debugm(N, task, 'Making redis request to local replies set with key %s and recipients 
%s', + replies_sender_key, table.concat(replies_recipients_addrs, ", ")) + + local ret = lua_redis.exec_redis_script(zscore_script_id, { task = task, is_write = true }, redis_zscore_script_cb, { replies_sender_key }, replies_recipients_addrs) + + if not ret then + rspamd_logger.errx(task, "redis script request wasn't scheduled") + end end local function check_known_incoming_mail_callback(task) diff --git a/src/plugins/lua/milter_headers.lua b/src/plugins/lua/milter_headers.lua index 2daeeed78..17fc90562 100644 --- a/src/plugins/lua/milter_headers.lua +++ b/src/plugins/lua/milter_headers.lua @@ -138,7 +138,7 @@ local function milter_headers(task) local function skip_wanted(hdr) if settings_override then - return true + return false end -- Normal checks local function match_extended_headers_rcpt() diff --git a/src/plugins/lua/mime_types.lua b/src/plugins/lua/mime_types.lua index c69fa1e7b..73cd63c6a 100644 --- a/src/plugins/lua/mime_types.lua +++ b/src/plugins/lua/mime_types.lua @@ -128,6 +128,7 @@ local settings = { inf = 4, its = 4, jnlp = 4, + ['library-ms'] = 4, lnk = 4, ksh = 4, mad = 4, @@ -179,6 +180,7 @@ local settings = { reg = 4, scf = 4, scr = 4, + ['search-ms'] = 4, shs = 4, theme = 4, url = 4, @@ -406,9 +408,9 @@ local function check_mime_type(task) local score2 = check_tables(ext2) -- Check if detected extension match real extension if detected_ext and detected_ext == ext then - check_extension(score1, nil) + check_extension(score1, nil) else - check_extension(score1, score2) + check_extension(score1, score2) end -- Check for archive cloaking like .zip.gz if settings['archive_extensions'][ext2] diff --git a/src/plugins/lua/multimap.lua b/src/plugins/lua/multimap.lua index a61da606b..0c82b167e 100644 --- a/src/plugins/lua/multimap.lua +++ b/src/plugins/lua/multimap.lua @@ -1267,6 +1267,7 @@ local function add_multimap_rule(key, newrule) { rules = newrule.rules, expression = newrule.expression, + description = newrule.description, on_load = 
newrule.dynamic_symbols and multimap_on_load_gen(newrule) or nil, }, N, 'Combined map for ' .. newrule.symbol) if not newrule.combined then @@ -1281,7 +1282,7 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end elseif newrule['type'] == 'received' then @@ -1302,7 +1303,7 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end else @@ -1311,7 +1312,7 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end end @@ -1327,11 +1328,14 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end elseif newrule['type'] == 'dnsbl' then ret = true + else + rspamd_logger.errx(rspamd_config, 'cannot add rule %s: invalid type %s', + key, newrule['type']) end end diff --git a/src/plugins/lua/phishing.lua b/src/plugins/lua/phishing.lua index 3f5c9e634..4dc3fd924 100644 --- a/src/plugins/lua/phishing.lua +++ b/src/plugins/lua/phishing.lua @@ -39,7 +39,7 @@ local anchor_exceptions_maps = {} local strict_domains_maps = {} local phishing_feed_exclusion_map = nil local generic_service_map = nil -local openphish_map = 'https://www.openphish.com/feed.txt' +local openphish_map = 
'https://raw.githubusercontent.com/openphish/public_feed/refs/heads/main/feed.txt' local phishtank_suffix = 'phishtank.rspamd.com' -- Not enabled by default as their feed is quite large local openphish_premium = false diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index c20e61b17..d463658fa 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -373,7 +373,7 @@ local function ratelimit_cb(task) local function gen_check_cb(prefix, bucket, lim_name, lim_key) return function(err, data) if err then - rspamd_logger.errx('cannot check limit %s: %s %s', prefix, err, data) + rspamd_logger.errx('cannot check limit %s: %s', prefix, err) elseif type(data) == 'table' and data[1] then lua_util.debugm(N, task, "got reply for limit %s (%s / %s); %s burst, %s:%s dyn, %s leaked", @@ -416,7 +416,7 @@ local function ratelimit_cb(task) task:set_pre_result('soft reject', message_func(task, lim_name, prefix, bucket, lim_key), N) else - task:set_pre_result('soft reject', bucket.message) + task:set_pre_result('soft reject', bucket.message, N) end end end @@ -476,7 +476,7 @@ local function maybe_cleanup_pending(task) local bucket = v.bucket local function cleanup_cb(err, data) if err then - rspamd_logger.errx('cannot cleanup limit %s: %s %s', k, err, data) + rspamd_logger.errx('cannot cleanup limit %s: %s', k, err) else lua_util.debugm(N, task, 'cleaned pending bucked for %s: %s', k, data) end diff --git a/src/plugins/lua/rbl.lua b/src/plugins/lua/rbl.lua index 76c84f85d..b5b904b00 100644 --- a/src/plugins/lua/rbl.lua +++ b/src/plugins/lua/rbl.lua @@ -40,6 +40,7 @@ local N = 'rbl' -- Checks that could be performed by rbl module local local_exclusions +local disabled_rbl_suffixes -- Map of disabled rbl suffixes local white_symbols = {} local black_symbols = {} local monitored_addresses = {} @@ -220,7 +221,9 @@ matchers.radix = function(_, _, real_ip, map) end matchers.equality = function(codes, to_match) - if type(codes) ~= 'table' then 
return codes == to_match end + if type(codes) ~= 'table' then + return codes == to_match + end for _, ip in ipairs(codes) do if to_match == ip then return true @@ -470,6 +473,17 @@ local function gen_rbl_callback(rule) return true end + local function is_allowed(task, _) + if disabled_rbl_suffixes then + if disabled_rbl_suffixes:get_key(rule.rbl) then + lua_util.debugm(N, task, 'skip rbl check: %s; disabled by suffix', rule.rbl) + return false + end + end + + return true + end + local function check_required_symbols(task, _) if rule.require_symbols then return fun.all(function(sym) @@ -596,7 +610,7 @@ local function gen_rbl_callback(rule) ignore_ip = rule.no_ip, need_images = rule.images, need_emails = false, - need_content = rule.content_urls or false, + need_content = rule.content_urls, esld_limit = esld_lim, no_cache = true, } @@ -698,9 +712,9 @@ local function gen_rbl_callback(rule) requests_table, 'received', whitelist) else - lua_util.debugm(N, task, 'rbl %s; skip check_received for %s:' .. - 'Received IP same as From IP and will be checked only in check_from function', - rule.symbol, rh.real_ip) + lua_util.debugm(N, task, 'rbl %s; skip check_received for %s:' .. 
+ 'Received IP same as From IP and will be checked only in check_from function', + rule.symbol, rh.real_ip) end end end @@ -838,6 +852,7 @@ local function gen_rbl_callback(rule) -- Create function pipeline depending on rbl settings local pipeline = { + is_allowed, -- check if rbl is allowed is_alive, -- check monitored status check_required_symbols -- if we have require_symbols then check those symbols } @@ -983,7 +998,7 @@ local function gen_rbl_callback(rule) if req.resolve_ip then -- Deal with both ipv4 and ipv6 -- Resolve names first - if r:resolve_a({ + if (rule.ipv4 == nil or rule.ipv4) and r:resolve_a({ task = task, name = req.n, callback = gen_rbl_ip_dns_callback(req), @@ -991,7 +1006,7 @@ local function gen_rbl_callback(rule) }) then nresolved = nresolved + 1 end - if r:resolve('aaaa', { + if (rule.ipv6 == nil or rule.ipv6) and r:resolve('aaaa', { task = task, name = req.n, callback = gen_rbl_ip_dns_callback(req), @@ -1062,7 +1077,7 @@ local function add_rbl(key, rbl, global_opts) rbl.selector_flatten) if not sel then - rspamd_logger.errx('invalid selector for rbl rule %s: %s', key, selector) + rspamd_logger.errx(rspamd_config, 'invalid selector for rbl rule %s: %s', key, selector) return false end @@ -1108,9 +1123,10 @@ local function add_rbl(key, rbl, global_opts) end for label, v in pairs(rbl.returncodes) do if type(v) ~= 'table' then - v = {v} + v = { v } end - rbl.returncodes_maps[label] = lua_maps.map_add_from_ucl(v, match_type, string.format('%s_%s RBL returncodes', label, rbl.symbol)) + rbl.returncodes_maps[label] = lua_maps.map_add_from_ucl(v, match_type, + string.format('%s_%s RBL returncodes', label, rbl.symbol)) end end @@ -1319,6 +1335,11 @@ if type(opts.attached_maps) == 'table' then end end +if opts.disabled_rbl_suffixes_map then + disabled_rbl_suffixes = lua_maps.map_add_from_ucl(opts.disabled_rbl_suffixes_map, 'set', + 'Disabled suffixes for RBL') +end + for key, rbl in pairs(opts.rbls) do if type(rbl) ~= 'table' or rbl.disabled == true or 
rbl.enabled == false then rspamd_logger.infox(rspamd_config, 'disable rbl "%s"', key) diff --git a/src/plugins/lua/replies.lua b/src/plugins/lua/replies.lua index 08fb68bc7..2f0153d00 100644 --- a/src/plugins/lua/replies.lua +++ b/src/plugins/lua/replies.lua @@ -79,8 +79,8 @@ local function configure_redis_scripts(_, _) end ]] local set_script_zadd_global = lua_util.jinja_template(redis_script_zadd_global, - { max_global_size = settings.max_global_size }) - global_replies_set_script = lua_redis.add_redis_script(set_script_zadd_global, redis_params) + { max_global_size = settings.max_global_size }) + global_replies_set_script = lua_redis.add_redis_script(set_script_zadd_global, redis_params) local redis_script_zadd_local = [[ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, -({= max_local_size =} + 1)) -- keeping size of local replies set @@ -102,7 +102,7 @@ local function configure_redis_scripts(_, _) end ]] local set_script_zadd_local = lua_util.jinja_template(redis_script_zadd_local, - { expire_time = settings.expire, max_local_size = settings.max_local_size }) + { expire_time = settings.expire, max_local_size = settings.max_local_size }) local_replies_set_script = lua_redis.add_redis_script(set_script_zadd_local, redis_params) end @@ -110,7 +110,7 @@ local function replies_check(task) local in_reply_to local function check_recipient(stored_rcpt) - local rcpts = task:get_recipients('mime') + local rcpts = task:get_recipients('smtp') lua_util.debugm(N, task, 'recipients: %s', rcpts) if rcpts then local filter_predicate = function(input_rcpt) @@ -119,7 +119,7 @@ local function replies_check(task) return real_rcpt_h == stored_rcpt end - if fun.any(filter_predicate, fun.map(function(rcpt) + if fun.all(filter_predicate, fun.map(function(rcpt) return rcpt.addr or '' end, rcpts)) then lua_util.debugm(N, task, 'reply to %s validated', in_reply_to) @@ -155,9 +155,9 @@ local function replies_check(task) end lua_redis.exec_redis_script(global_replies_set_script, - { task = task, 
is_write = true }, - zadd_global_set_cb, - { global_key }, params) + { task = task, is_write = true }, + zadd_global_set_cb, + { global_key }, params) end local function add_to_replies_set(recipients) @@ -173,7 +173,7 @@ local function replies_check(task) local params = recipients lua_util.debugm(N, task, - 'Adding recipients %s to sender %s local replies set', recipients, sender_key) + 'Adding recipients %s to sender %s local replies set', recipients, sender_key) local function zadd_cb(err, _) if err ~= nil then @@ -189,9 +189,9 @@ local function replies_check(task) table.insert(params, 1, task_time_str) lua_redis.exec_redis_script(local_replies_set_script, - { task = task, is_write = true }, - zadd_cb, - { sender_key }, params) + { task = task, is_write = true }, + zadd_cb, + { sender_key }, params) end local function redis_get_cb(err, data, addr) @@ -387,7 +387,7 @@ if opts then end lua_redis.register_prefix(settings.sender_prefix, N, - 'Prefix to identify replies sets') + 'Prefix to identify replies sets') local id = rspamd_config:register_symbol({ name = 'REPLIES_CHECK', diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua index bd7d91932..eacaee064 100644 --- a/src/plugins/lua/reputation.lua +++ b/src/plugins/lua/reputation.lua @@ -200,7 +200,9 @@ local function dkim_reputation_filter(task, rule) end end - if sel_tld and requests[sel_tld] then + if rule.selector.config.exclusion_map and sel_tld and rule.selector.config.exclusion_map:get_key(sel_tld) then + lua_util.debugm(N, task, 'DKIM domain %s is excluded from reputation scoring', sel_tld) + elseif sel_tld and requests[sel_tld] then if requests[sel_tld] == 'a' then rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task) end @@ -243,9 +245,13 @@ local function dkim_reputation_idempotent(task, rule) if sc then for dom, res in pairs(requests) do - -- tld + "." + check_result, e.g. 
example.com.+ - reputation for valid sigs - local query = string.format('%s.%s', dom, res) - rule.backend.set_token(task, rule, nil, query, sc) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(dom) then + lua_util.debugm(N, task, 'DKIM domain %s is excluded from reputation update', dom) + else + -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs + local query = string.format('%s.%s', dom, res) + rule.backend.set_token(task, rule, nil, query, sc) + end end end end @@ -277,6 +283,7 @@ local dkim_selector = { outbound = true, inbound = true, max_accept_adjustment = 2.0, -- How to adjust accepted DKIM score + exclusion_map = nil }, dependencies = { "DKIM_TRACE" }, filter = dkim_reputation_filter, -- used to get scores @@ -356,10 +363,14 @@ local function url_reputation_filter(task, rule) for i, res in pairs(results) do local req = requests[i] if req then - local url_score = generic_reputation_calc(res, rule, - req[2] / mhits, task) - lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score) - score = score + url_score + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(req[1]) then + lua_util.debugm(N, task, 'URL domain %s is excluded from reputation scoring', req[1]) + else + local url_score = generic_reputation_calc(res, rule, + req[2] / mhits, task) + lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score) + score = score + url_score + end end end @@ -386,7 +397,11 @@ local function url_reputation_idempotent(task, rule) if sc then for _, tld in ipairs(requests) do - rule.backend.set_token(task, rule, nil, tld[1], sc) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(tld[1]) then + lua_util.debugm(N, task, 'URL domain %s is excluded from reputation update', tld[1]) + else + rule.backend.set_token(task, rule, nil, tld[1], sc) + end end end end @@ -401,6 +416,7 @@ local 
url_selector = { check_from = true, outbound = true, inbound = true, + exclusion_map = nil }, filter = url_reputation_filter, -- used to get scores idempotent = url_reputation_idempotent -- used to set scores @@ -439,6 +455,11 @@ local function ip_reputation_filter(task, rule) ip = ip:apply_mask(cfg.ipv6_mask) end + if cfg.exclusion_map and cfg.exclusion_map:get_key(ip) then + lua_util.debugm(N, task, 'IP %s is excluded from reputation scoring', tostring(ip)) + return + end + local pool = task:get_mempool() local asn = pool:get_variable("asn") local country = pool:get_variable("country") @@ -554,6 +575,11 @@ local function ip_reputation_idempotent(task, rule) ip = ip:apply_mask(cfg.ipv6_mask) end + if cfg.exclusion_map and cfg.exclusion_map:get_key(ip) then + lua_util.debugm(N, task, 'IP %s is excluded from reputation update', tostring(ip)) + return + end + local pool = task:get_mempool() local asn = pool:get_variable("asn") local country = pool:get_variable("country") @@ -600,6 +626,7 @@ local ip_selector = { inbound = true, ipv4_mask = 32, -- Mask bits for ipv4 ipv6_mask = 64, -- Mask bits for ipv6 + exclusion_map = nil }, --dependencies = {"ASN"}, -- ASN is a prefilter now... 
init = ip_reputation_init, @@ -621,6 +648,11 @@ local function spf_reputation_filter(task, rule) local cr = require "rspamd_cryptobox_hash" local hkey = cr.create(spf_record):base32():sub(1, 32) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(hkey) then + lua_util.debugm(N, task, 'SPF record %s is excluded from reputation scoring', hkey) + return + end + lua_util.debugm(N, task, 'check spf record %s -> %s', spf_record, hkey) local function tokens_cb(err, token, values) @@ -649,6 +681,11 @@ local function spf_reputation_idempotent(task, rule) local cr = require "rspamd_cryptobox_hash" local hkey = cr.create(spf_record):base32():sub(1, 32) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(hkey) then + lua_util.debugm(N, task, 'SPF record %s is excluded from reputation update', hkey) + return + end + lua_util.debugm(N, task, 'set spf record %s -> %s = %s', spf_record, hkey, sc) rule.backend.set_token(task, rule, nil, hkey, sc) @@ -663,6 +700,7 @@ local spf_selector = { max_score = nil, outbound = true, inbound = true, + exclusion_map = nil }, dependencies = { "R_SPF_ALLOW" }, filter = spf_reputation_filter, -- used to get scores @@ -697,6 +735,13 @@ local function generic_reputation_init(rule) 'Whitelisted selectors') end + if cfg.exclusion_map then + cfg.exclusion_map = lua_maps.map_add('reputation', + 'generic_exclusion', + 'set', + 'Excluded selectors') + end + return true end @@ -706,6 +751,10 @@ local function generic_reputation_filter(task, rule) local function tokens_cb(err, token, values) if values then + if cfg.exclusion_map and cfg.exclusion_map:get_key(token) then + lua_util.debugm(N, task, 'Generic selector token %s is excluded from reputation scoring', token) + return + end local score = generic_reputation_calc(values, rule, 1.0, task) if math.abs(score) > 1e-3 then @@ -742,14 +791,22 @@ local function generic_reputation_idempotent(task, rule) if sc then if type(selector_res) == 
'table' then fun.each(function(e) - lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', - rule['symbol'], e, sc) - rule.backend.set_token(task, rule, nil, e, sc) + if cfg.exclusion_map and cfg.exclusion_map:get_key(e) then + lua_util.debugm(N, task, 'Generic selector token %s is excluded from reputation update', e) + else + lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', + rule['symbol'], e, sc) + rule.backend.set_token(task, rule, nil, e, sc) + end end, selector_res) else - lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', - rule['symbol'], selector_res, sc) - rule.backend.set_token(task, rule, nil, selector_res, sc) + if cfg.exclusion_map and cfg.exclusion_map:get_key(selector_res) then + lua_util.debugm(N, task, 'Generic selector token %s is excluded from reputation update', selector_res) + else + lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', + rule['symbol'], selector_res, sc) + rule.backend.set_token(task, rule, nil, selector_res, sc) + end end end end @@ -764,6 +821,7 @@ local generic_selector = { selector = ts.string, delimiter = ts.string, whitelist = ts.one_of(lua_maps.map_schema, lua_maps_exprs.schema):is_optional(), + exclusion_map = ts.one_of(lua_maps.map_schema, lua_maps_exprs.schema):is_optional() }, config = { lower_bound = 10, -- minimum number of messages to be scored @@ -773,7 +831,8 @@ local generic_selector = { inbound = true, selector = nil, delimiter = ':', - whitelist = nil + whitelist = nil, + exclusion_map = nil }, init = generic_reputation_init, filter = generic_reputation_filter, -- used to get scores @@ -1107,7 +1166,7 @@ local backends = { name = '1m', mult = 1.0, } - }, -- What buckets should be used, default 1h and 1month + }, -- What buckets should be used, default 1month }, init = reputation_redis_init, get_token = reputation_redis_get_token, @@ -1267,6 +1326,24 @@ local function parse_rule(name, tbl) end end + -- Parse exclusion_map for reputation exclusion lists + if 
rule.config.exclusion_map then + local map_type = 'set' -- Default to set for string-based selectors (dkim, url, spf, generic) + if sel_type == 'ip' or sel_type == 'sender' then + map_type = 'radix' -- Use radix for IP-based selectors + end + local map = lua_maps.map_add_from_ucl(rule.config.exclusion_map, + map_type, + sel_type .. ' reputation exclusion map') + if not map then + rspamd_logger.errx(rspamd_config, "cannot parse exclusion map config for %s: (%s)", + sel_type, + rule.config.exclusion_map) + return false + end + rule.config.exclusion_map = map + end + local symbol = rule.selector.config.symbol or name if tbl.symbol then symbol = tbl.symbol @@ -1387,4 +1464,4 @@ if opts['rules'] then end else lua_util.disable_module(N, "config") -end +end
\ No newline at end of file diff --git a/src/plugins/lua/settings.lua b/src/plugins/lua/settings.lua index 69d31d301..c576e1325 100644 --- a/src/plugins/lua/settings.lua +++ b/src/plugins/lua/settings.lua @@ -248,17 +248,13 @@ local function check_string_setting(expected, str) end local function check_ip_setting(expected, ip) - if not expected[2] then - if lua_maps.rspamd_maybe_check_map(expected[1], ip:to_string()) then + if type(expected) == "string" then + if lua_maps.rspamd_maybe_check_map(expected, ip:to_string()) then return true end else - if expected[2] ~= 0 then - local nip = ip:apply_mask(expected[2]) - if nip and nip:to_string() == expected[1] then - return true - end - elseif ip:to_string() == expected[1] then + local nip = ip:apply_mask(expected[2]) + if nip and nip:to_string() == expected[1] then return true end end @@ -464,44 +460,52 @@ local function gen_settings_external_cb(name) end -- Process IP address: converted to a table {ip, mask} -local function process_ip_condition(ip) - local out = {} - +local function process_ip_condition(ip, out) if type(ip) == "table" then for _, v in ipairs(ip) do - table.insert(out, process_ip_condition(v)) + process_ip_condition(v, out) end - elseif type(ip) == "string" then - local slash = string.find(ip, '/') + return + end - if not slash then - -- Just a plain IP address - local res = rspamd_ip.from_string(ip) + if type(ip) == "string" then + if string.sub(ip, 1, 4) == "map:" then + -- It is a map, don't apply any extra logic + table.insert(out, ip) + return + end - if res:is_valid() then - out[1] = res:to_string() - out[2] = 0 - else - -- It can still be a map - out[1] = ip - end - else - local res = rspamd_ip.from_string(string.sub(ip, 1, slash - 1)) - local mask = tonumber(string.sub(ip, slash + 1)) + local mask + local slash = string.find(ip, '/') + if slash then + mask = string.sub(ip, slash + 1) + ip = string.sub(ip, 1, slash - 1) + end + + local res = rspamd_ip.from_string(ip) + if res:is_valid() then + if 
mask then + local mask_num = tonumber(mask) + if mask_num then + -- normalize IP + res = res:apply_mask(mask_num) + if res:is_valid() then + table.insert(out, { res:to_string(), mask_num }) + return + end + end - if res:is_valid() then - out[1] = res:to_string() - out[2] = mask - else - rspamd_logger.errx(rspamd_config, "bad IP address: " .. ip) - return nil + rspamd_logger.errx(rspamd_config, "bad IP mask: %s/%s", ip, mask) + return end + + -- Just a plain IP address + table.insert(out, res:to_string()) + return end - else - return nil end - return out + rspamd_logger.errx(rspamd_config, "bad IP address: " .. ip) end -- Process email like condition, converted to a table with fields: @@ -613,6 +617,12 @@ end -- Used to create a checking closure: if value matches expected somehow, return true local function gen_check_closure(expected, check_func) + if not check_func then + check_func = function(a, b) + return a == b + end + end + return function(value) if not value then return false @@ -623,13 +633,6 @@ local function gen_check_closure(expected, check_func) end if value then - - if not check_func then - check_func = function(a, b) - return a == b - end - end - local ret if type(expected) == 'table' then ret = fun.any(function(d) @@ -659,22 +662,21 @@ local function process_settings_table(tbl, allow_ids, mempool, is_static) local checks = {} if elt.ip then - local ips_table = process_ip_condition(elt['ip']) + local ips_table = {} + process_ip_condition(elt.ip, ips_table) - if ips_table then - lua_util.debugm(N, rspamd_config, 'added ip condition to "%s": %s', - name, ips_table) - checks.ip = { - check = gen_check_closure(convert_to_table(elt.ip, ips_table), check_ip_setting), - extract = function(task) - local ip = task:get_from_ip() - if ip and ip:is_valid() then - return ip - end - return nil - end, - } - end + lua_util.debugm(N, rspamd_config, 'added ip condition to "%s": %s', + name, ips_table) + checks.ip = { + check = gen_check_closure(ips_table, 
check_ip_setting), + extract = function(task) + local ip = task:get_from_ip() + if ip and ip:is_valid() then + return ip + end + return nil + end, + } end if elt.ip_map then local ips_map = lua_maps.map_add_from_ucl(elt.ip_map, 'radix', @@ -697,23 +699,21 @@ local function process_settings_table(tbl, allow_ids, mempool, is_static) end if elt.client_ip then - local client_ips_table = process_ip_condition(elt.client_ip) - - if client_ips_table then - lua_util.debugm(N, rspamd_config, 'added client_ip condition to "%s": %s', - name, client_ips_table) - checks.client_ip = { - check = gen_check_closure(convert_to_table(elt.client_ip, client_ips_table), - check_ip_setting), - extract = function(task) - local ip = task:get_client_ip() - if ip:is_valid() then - return ip - end - return nil - end, - } - end + local client_ips_table = {} + process_ip_condition(elt.client_ip, client_ips_table) + + lua_util.debugm(N, rspamd_config, 'added client_ip condition to "%s": %s', + name, client_ips_table) + checks.client_ip = { + check = gen_check_closure(client_ips_table, check_ip_setting), + extract = function(task) + local ip = task:get_client_ip() + if ip:is_valid() then + return ip + end + return nil + end, + } end if elt.client_ip_map then local ips_map = lua_maps.map_add_from_ucl(elt.ip_map, 'radix', @@ -1275,7 +1275,7 @@ local function gen_redis_callback(handler, id) ucl_err) else local obj = parser:get_object() - rspamd_logger.infox(task, "<%1> apply settings according to redis rule %2", + rspamd_logger.infox(task, "<%s> apply settings according to redis rule %s", task:get_message_id(), id) apply_settings(task, obj, nil, 'redis') break @@ -1283,7 +1283,7 @@ local function gen_redis_callback(handler, id) end end elseif err then - rspamd_logger.errx(task, 'Redis error: %1', err) + rspamd_logger.errx(task, 'Redis error: %s', err) end end @@ -1371,7 +1371,7 @@ if set_section and set_section[1] and type(set_section[1]) == "string" then opaque_data = true } if not 
rspamd_config:add_map(map_attrs) then - rspamd_logger.errx(rspamd_config, 'cannot load settings from %1', set_section) + rspamd_logger.errx(rspamd_config, 'cannot load settings from %s', set_section) end elseif set_section and type(set_section) == "table" then settings_map_pool = rspamd_mempool.create() diff --git a/src/plugins/lua/spamassassin.lua b/src/plugins/lua/spamassassin.lua index 3ea794495..c03481de2 100644 --- a/src/plugins/lua/spamassassin.lua +++ b/src/plugins/lua/spamassassin.lua @@ -221,7 +221,7 @@ local function handle_header_def(hline, cur_rule) }) cur_rule['function'] = function(task) if not re then - rspamd_logger.errx(task, 're is missing for rule %1', h) + rspamd_logger.errx(task, 're is missing for rule %s', h) return 0 end @@ -272,7 +272,7 @@ local function handle_header_def(hline, cur_rule) elseif func == 'case' then cur_param['strong'] = true else - rspamd_logger.warnx(rspamd_config, 'Function %1 is not supported in %2', + rspamd_logger.warnx(rspamd_config, 'Function %s is not supported in %s', func, cur_rule['symbol']) end end, fun.tail(args)) @@ -314,7 +314,7 @@ end local function freemail_search(input) local res = 0 local function trie_callback(number, pos) - lua_util.debugm(N, rspamd_config, 'Matched pattern %1 at pos %2', freemail_domains[number], pos) + lua_util.debugm(N, rspamd_config, 'Matched pattern %s at pos %s', freemail_domains[number], pos) res = res + 1 end @@ -369,7 +369,7 @@ local function gen_eval_rule(arg) end return 0 else - rspamd_logger.infox(rspamd_config, 'cannot create regexp %1', re) + rspamd_logger.infox(rspamd_config, 'cannot create regexp %s', re) return 0 end end @@ -461,7 +461,7 @@ local function gen_eval_rule(arg) end end else - rspamd_logger.infox(task, 'unimplemented mime check %1', arg) + rspamd_logger.infox(task, 'unimplemented mime check %s', arg) end end @@ -576,7 +576,7 @@ local function maybe_parse_sa_function(line) local elts = split(line, '[^:]+') arg = elts[2] - lua_util.debugm(N, rspamd_config, 
'trying to parse SA function %1 with args %2', + lua_util.debugm(N, rspamd_config, 'trying to parse SA function %s with args %s', elts[1], elts[2]) local substitutions = { { '^exists:', @@ -612,7 +612,7 @@ local function maybe_parse_sa_function(line) end if not func then - rspamd_logger.errx(task, 'cannot find appropriate eval rule for function %1', + rspamd_logger.errx(task, 'cannot find appropriate eval rule for function %s', arg) else return func(task) @@ -685,7 +685,7 @@ local function process_sa_conf(f) end -- We have previous rule valid if not cur_rule['symbol'] then - rspamd_logger.errx(rspamd_config, 'bad rule definition: %1', cur_rule) + rspamd_logger.errx(rspamd_config, 'bad rule definition: %s', cur_rule) end rules[cur_rule['symbol']] = cur_rule cur_rule = {} @@ -695,15 +695,15 @@ local function process_sa_conf(f) local function parse_score(words) if #words == 3 then -- score rule <x> - lua_util.debugm(N, rspamd_config, 'found score for %1: %2', words[2], words[3]) + lua_util.debugm(N, rspamd_config, 'found score for %s: %s', words[2], words[3]) return tonumber(words[3]) elseif #words == 6 then -- score rule <x1> <x2> <x3> <x4> -- we assume here that bayes and network are enabled and select <x4> - lua_util.debugm(N, rspamd_config, 'found score for %1: %2', words[2], words[6]) + lua_util.debugm(N, rspamd_config, 'found score for %s: %s', words[2], words[6]) return tonumber(words[6]) else - rspamd_logger.errx(rspamd_config, 'invalid score for %1', words[2]) + rspamd_logger.errx(rspamd_config, 'invalid score for %s', words[2]) end return 0 @@ -812,7 +812,7 @@ local function process_sa_conf(f) cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) if not cur_rule['re'] then - rspamd_logger.warnx(rspamd_config, "Cannot parse regexp '%1' for %2", + rspamd_logger.warnx(rspamd_config, "Cannot parse regexp '%s' for %s", cur_rule['re_expr'], cur_rule['symbol']) else cur_rule['re']:set_max_hits(1) @@ -829,8 +829,8 @@ local function process_sa_conf(f) 
cur_rule['mime'] = false end - if cur_rule['re'] and cur_rule['symbol'] and - (cur_rule['header'] or cur_rule['function']) then + if cur_rule['re'] and cur_rule['symbol'] + and (cur_rule['header'] or cur_rule['function']) then valid_rule = true cur_rule['re']:set_max_hits(1) if cur_rule['header'] and cur_rule['ordinary'] then @@ -894,7 +894,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "body" then @@ -931,7 +931,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "rawbody" then @@ -968,7 +968,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "full" then @@ -1006,7 +1006,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "uri" then @@ -1265,11 +1265,11 @@ local function post_process() if res then local nre = rspamd_regexp.create(nexpr) if not nre then - rspamd_logger.errx(rspamd_config, 'cannot apply replacement for rule %1', r) + rspamd_logger.errx(rspamd_config, 'cannot apply replacement for rule %s', r) --rule['re'] = nil else local old_max_hits = rule['re']:get_max_hits() - lua_util.debugm(N, rspamd_config, 'replace %1 -> %2', r, nexpr) + lua_util.debugm(N, rspamd_config, 'replace %s -> %s', r, nexpr) rspamd_config:replace_regexp({ old_re = rule['re'], new_re = nre, @@ -1306,8 
+1306,7 @@ local function post_process() end if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1 (%2 header)', k, - h['header']) + rspamd_logger.errx(task, 're is missing for rule %s', h) return 0 end @@ -1434,7 +1433,7 @@ local function post_process() fun.each(function(k, r) local f = function(task) if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1', k) + rspamd_logger.errx(task, 're is missing for rule %s', k) return 0 end @@ -1461,7 +1460,7 @@ local function post_process() fun.each(function(k, r) local f = function(task) if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1', k) + rspamd_logger.errx(task, 're is missing for rule %s', k) return 0 end @@ -1486,7 +1485,7 @@ local function post_process() fun.each(function(k, r) local f = function(task) if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1', k) + rspamd_logger.errx(task, 're is missing for rule %s', k) return 0 end @@ -1629,8 +1628,8 @@ local function post_process() rspamd_config:register_dependency(k, rspamd_symbol) external_deps[k][rspamd_symbol] = true lua_util.debugm(N, rspamd_config, - 'atom %1 is a direct foreign dependency, ' .. - 'register dependency for %2 on %3', + 'atom %s is a direct foreign dependency, ' .. + 'register dependency for %s on %s', a, k, rspamd_symbol) end end @@ -1659,8 +1658,8 @@ local function post_process() rspamd_config:register_dependency(k, dep) external_deps[k][dep] = true lua_util.debugm(N, rspamd_config, - 'atom %1 is an indirect foreign dependency, ' .. - 'register dependency for %2 on %3', + 'atom %s is an indirect foreign dependency, ' .. 
+ 'register dependency for %s on %s', a, k, dep) nchanges = nchanges + 1 end @@ -1694,10 +1693,10 @@ local function post_process() -- Logging output if freemail_domains then freemail_trie = rspamd_trie.create(freemail_domains) - rspamd_logger.infox(rspamd_config, 'loaded %1 freemail domains definitions', + rspamd_logger.infox(rspamd_config, 'loaded %s freemail domains definitions', #freemail_domains) end - rspamd_logger.infox(rspamd_config, 'loaded %1 blacklist/whitelist elements', + rspamd_logger.infox(rspamd_config, 'loaded %s blacklist/whitelist elements', sa_lists['elts']) end @@ -1739,7 +1738,7 @@ if type(section) == "table" then process_sa_conf(f) has_rules = true else - rspamd_logger.errx(rspamd_config, "cannot open %1", matched) + rspamd_logger.errx(rspamd_config, "cannot open %s", matched) end end end @@ -1758,7 +1757,7 @@ if type(section) == "table" then process_sa_conf(f) has_rules = true else - rspamd_logger.errx(rspamd_config, "cannot open %1", matched) + rspamd_logger.errx(rspamd_config, "cannot open %s", matched) end end end diff --git a/src/plugins/lua/trie.lua b/src/plugins/lua/trie.lua index 7ba455289..7c7214b55 100644 --- a/src/plugins/lua/trie.lua +++ b/src/plugins/lua/trie.lua @@ -107,10 +107,10 @@ local function process_trie_file(symbol, cf) local file = io.open(cf['file']) if not file then - rspamd_logger.errx(rspamd_config, 'Cannot open trie file %1', cf['file']) + rspamd_logger.errx(rspamd_config, 'Cannot open trie file %s', cf['file']) else if cf['binary'] then - rspamd_logger.errx(rspamd_config, 'binary trie patterns are not implemented yet: %1', + rspamd_logger.errx(rspamd_config, 'binary trie patterns are not implemented yet: %s', cf['file']) else for line in file:lines() do @@ -123,7 +123,7 @@ end local function process_trie_conf(symbol, cf) if type(cf) ~= 'table' then - rspamd_logger.errx(rspamd_config, 'invalid value for symbol %1: "%2", expected table', + rspamd_logger.errx(rspamd_config, 'invalid value for symbol %s: "%s", expected 
table', symbol, cf) return end @@ -145,17 +145,17 @@ if opts then if #raw_patterns > 0 then raw_trie = rspamd_trie.create(raw_patterns) - rspamd_logger.infox(rspamd_config, 'registered raw search trie from %1 patterns', #raw_patterns) + rspamd_logger.infox(rspamd_config, 'registered raw search trie from %s patterns', #raw_patterns) end if #mime_patterns > 0 then mime_trie = rspamd_trie.create(mime_patterns) - rspamd_logger.infox(rspamd_config, 'registered mime search trie from %1 patterns', #mime_patterns) + rspamd_logger.infox(rspamd_config, 'registered mime search trie from %s patterns', #mime_patterns) end if #body_patterns > 0 then body_trie = rspamd_trie.create(body_patterns) - rspamd_logger.infox(rspamd_config, 'registered body search trie from %1 patterns', #body_patterns) + rspamd_logger.infox(rspamd_config, 'registered body search trie from %s patterns', #body_patterns) end local id = -1 diff --git a/src/rspamadm/configdump.c b/src/rspamadm/configdump.c index 456875cf2..d090b66f0 100644 --- a/src/rspamadm/configdump.c +++ b/src/rspamadm/configdump.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,6 +31,8 @@ static gboolean symbol_groups_only = FALSE; static gboolean symbol_full_details = FALSE; static gboolean skip_template = FALSE; static char *config = NULL; +static gboolean local_conf_only = FALSE; +static gboolean override_conf_only = FALSE; extern struct rspamd_main *rspamd_main; /* Defined in modules.c */ extern module_t *modules[]; @@ -66,6 +68,8 @@ static GOptionEntry entries[] = { "Show full symbol details only", NULL}, {"skip-template", 'T', 0, G_OPTION_ARG_NONE, &skip_template, "Do not apply Jinja templates", NULL}, + {"local", 0, 0, G_OPTION_ARG_NONE, &local_conf_only, "Show only local and override configuration", NULL}, + {"override", 0, 0, G_OPTION_ARG_NONE, &override_conf_only, "Show only override configuration", NULL}, {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}}; static const char * @@ -82,6 +86,8 @@ rspamadm_configdump_help(gboolean full_help, const struct rspamadm_command *cmd) "-c: config file to test\n" "-m: show state of modules only\n" "-h: show help for dumped options\n" + "--local: show only local (and override) configuration\n" + "--override: show only override configuration\n" "--help: shows available options and commands"; } else { @@ -96,6 +102,84 @@ config_logger(rspamd_mempool_t *pool, gpointer ud) { } +static ucl_object_t * +filter_non_default(const ucl_object_t *obj, bool override_only) +{ + ucl_object_t *result = NULL; + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + + if (obj == NULL) { + return NULL; + } + + int min_prio = override_only ? 
1 : 0; + + if (ucl_object_get_priority(obj) > min_prio) { + + switch (ucl_object_type(obj)) { + case UCL_OBJECT: + result = ucl_object_typed_new(ucl_object_type(obj)); + + while ((cur = ucl_object_iterate(obj, &it, true))) { + ucl_object_t *filtered = filter_non_default(cur, override_conf_only); + if (filtered) { + ucl_object_insert_key(result, filtered, ucl_object_key(cur), cur->keylen, true); + } + } + break; + case UCL_ARRAY: + result = ucl_object_typed_new(ucl_object_type(obj)); + + while ((cur = ucl_object_iterate(obj, &it, true))) { + ucl_object_t *filtered = filter_non_default(cur, override_conf_only); + if (filtered) { + ucl_array_append(result, filtered); + } + } + default: + result = ucl_object_ref(obj); + break; + } + + return result; + } + + if (ucl_object_type(obj) == UCL_OBJECT || ucl_object_type(obj) == UCL_ARRAY) { + bool has_non_default = false; + + result = ucl_object_typed_new(ucl_object_type(obj)); + while ((cur = ucl_object_iterate(obj, &it, true))) { + ucl_object_t *filtered = filter_non_default(cur, override_only); + if (filtered) { + has_non_default = true; + + if (ucl_object_type(obj) == UCL_OBJECT) { + ucl_object_insert_key(result, filtered, + ucl_object_key(cur), cur->keylen, true); + } + else if (ucl_object_type(obj) == UCL_ARRAY) { + ucl_array_append(result, filtered); + } + else { + g_assert_not_reached(); + } + } + } + + /* Avoid empty objects */ + if (!has_non_default) { + ucl_object_unref(result); + result = NULL; + } + + return result; + } + + + return NULL; +} + static void rspamadm_add_doc_elt(const ucl_object_t *obj, const ucl_object_t *doc_obj, ucl_object_t *comment_obj) @@ -524,7 +608,20 @@ rspamadm_configdump(int argc, char **argv, const struct rspamadm_command *cmd) /* Output configuration */ if (argc == 1) { - rspamadm_dump_section_obj(cfg, cfg->cfg_ucl_obj, cfg->doc_strings); + const ucl_object_t *output_obj = cfg->cfg_ucl_obj; + if (local_conf_only || override_conf_only) { + output_obj = 
filter_non_default(cfg->cfg_ucl_obj, override_conf_only); + if (!output_obj) { + rspamd_printf("No non-default configuration found\n"); + exit(EXIT_SUCCESS); + } + } + + rspamadm_dump_section_obj(cfg, output_obj, cfg->doc_strings); + + if (local_conf_only || override_conf_only) { + ucl_object_unref((ucl_object_t *) output_obj); + } } else { for (i = 1; i < argc; i++) { @@ -537,10 +634,18 @@ rspamadm_configdump(int argc, char **argv, const struct rspamadm_command *cmd) else { LL_FOREACH(obj, cur) { + const ucl_object_t *output_obj = cur; + if (local_conf_only || override_conf_only) { + output_obj = filter_non_default(cur, override_conf_only); + if (!output_obj) { + rspamd_printf("No non-default configuration found for section %s\n", argv[i]); + continue; + } + } if (!json && !compact) { rspamd_printf("*** Section %s ***\n", argv[i]); } - rspamadm_dump_section_obj(cfg, cur, doc_obj); + rspamadm_dump_section_obj(cfg, output_obj, doc_obj); if (!json && !compact) { rspamd_printf("\n*** End of section %s ***\n", argv[i]); @@ -548,6 +653,10 @@ rspamadm_configdump(int argc, char **argv, const struct rspamadm_command *cmd) else { rspamd_printf("\n"); } + + if (local_conf_only || override_conf_only) { + ucl_object_unref((ucl_object_t *) output_obj); + } } } } diff --git a/src/rspamadm/control.c b/src/rspamadm/control.c index 381bdaa7a..cd550c04e 100644 --- a/src/rspamadm/control.c +++ b/src/rspamadm/control.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -112,7 +112,7 @@ rspamd_control_finish_handler(struct rspamd_http_connection *conn, struct rspamadm_control_cbdata *cbdata = conn->ud; body = rspamd_http_message_get_body(msg, &body_len); - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!body || !ucl_parser_add_chunk(parser, body, body_len)) { rspamd_fprintf(stderr, "cannot parse server's reply: %s\n", diff --git a/src/rspamadm/lua_repl.c b/src/rspamadm/lua_repl.c index 1d6da5aa9..f9099d895 100644 --- a/src/rspamadm/lua_repl.c +++ b/src/rspamadm/lua_repl.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,9 +24,8 @@ #include "lua/lua_thread_pool.h" #include "message.h" #include "unix-std.h" -#ifdef WITH_LUA_REPL + #include "replxx.h" -#endif #include "worker_util.h" #ifdef WITH_LUAJIT #include <luajit.h> @@ -43,10 +42,7 @@ static int batch = -1; extern struct rspamd_async_session *rspamadm_session; static const char *default_history_file = ".rspamd_repl.hist"; - -#ifdef WITH_LUA_REPL static Replxx *rx_instance = NULL; -#endif #ifdef WITH_LUAJIT #define MAIN_PROMPT LUAJIT_VERSION "> " @@ -232,7 +228,6 @@ rspamadm_exec_input(lua_State *L, const char *input) int i, cbref; int top = 0; char outbuf[8192]; - struct lua_logger_trace tr; struct thread_entry *thread = lua_thread_pool_get_for_config(rspamd_main->cfg); L = thread->lua_state; @@ -272,9 +267,8 @@ rspamadm_exec_input(lua_State *L, const char *input) rspamd_printf("local function: %d\n", cbref); } else { - memset(&tr, 0, sizeof(tr)); - lua_logger_out_type(L, i, outbuf, sizeof(outbuf) - 1, &tr, - LUA_ESCAPE_UNPRINTABLE); + lua_logger_out(L, i, outbuf, sizeof(outbuf), + LUA_ESCAPE_UNPRINTABLE); rspamd_printf("%s\n", outbuf); } } @@ -393,7 +387,6 @@ rspamadm_lua_message_handler(lua_State *L, int argc, char **argv) gpointer map; gsize len; char outbuf[8192]; - struct lua_logger_trace tr; if (argv[1] == NULL) { rspamd_printf("no callback is specified\n"); @@ -455,9 +448,8 @@ rspamadm_lua_message_handler(lua_State *L, int argc, char **argv) rspamd_printf("lua callback for %s returned:\n", argv[i]); for (j = old_top + 1; j <= lua_gettop(L); j++) { - memset(&tr, 0, sizeof(tr)); - lua_logger_out_type(L, j, outbuf, sizeof(outbuf), &tr, - LUA_ESCAPE_UNPRINTABLE); + lua_logger_out(L, j, outbuf, sizeof(outbuf), + LUA_ESCAPE_UNPRINTABLE); rspamd_printf("%s\n", outbuf); } } @@ -503,7 +495,6 @@ rspamadm_lua_try_dot_command(lua_State *L, const char *input) return FALSE; } -#ifdef WITH_LUA_REPL static int lex_ref_idx = -1; static void @@ -599,20 +590,14 @@ lua_syntax_highlighter(const char *str, ReplxxColor *colours, int 
size, void *ud lua_settop(L, 0); } -#endif static void rspamadm_lua_run_repl(lua_State *L, bool is_batch) { char *input = NULL; -#ifdef WITH_LUA_REPL gboolean is_multiline = FALSE; GString *tb = NULL; gsize i; -#else - /* Always set is_batch */ - is_batch = TRUE; -#endif for (;;) { if (is_batch) { @@ -644,7 +629,6 @@ rspamadm_lua_run_repl(lua_State *L, bool is_batch) lua_settop(L, 0); } else { -#ifdef WITH_LUA_REPL replxx_set_highlighter_callback(rx_instance, lua_syntax_highlighter, L); @@ -706,7 +690,6 @@ rspamadm_lua_run_repl(lua_State *L, bool is_batch) } } } -#endif } } @@ -1009,16 +992,12 @@ rspamadm_lua(int argc, char **argv, const struct rspamadm_command *cmd) } if (!batch) { -#ifdef WITH_LUA_REPL rx_instance = replxx_init(); replxx_set_max_history_size(rx_instance, max_history); replxx_history_load(rx_instance, histfile); -#endif rspamadm_lua_run_repl(L, false); -#ifdef WITH_LUA_REPL replxx_history_save(rx_instance, histfile); replxx_end(rx_instance); -#endif } else { rspamadm_lua_run_repl(L, true); diff --git a/src/rspamadm/signtool.c b/src/rspamadm/signtool.c index 6d60e6700..538767b19 100644 --- a/src/rspamadm/signtool.c +++ b/src/rspamadm/signtool.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -573,7 +573,7 @@ rspamadm_signtool(int argc, char **argv, const struct rspamadm_command *cmd) else { g_assert(keypair_file != NULL); - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_file(parser, keypair_file) || (top = ucl_parser_get_object(parser)) == NULL) { diff --git a/src/rspamd.h b/src/rspamd.h index a5ef068e1..be66a192b 100644 --- a/src/rspamd.h +++ b/src/rspamd.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,6 +69,7 @@ enum rspamd_worker_flags { RSPAMD_WORKER_NO_TERMINATE_DELAY = (1 << 7), RSPAMD_WORKER_OLD_CONFIG = (1 << 8), RSPAMD_WORKER_NO_STRICT_CONFIG = (1 << 9), + RSPAMD_WORKER_FUZZY = (1 << 10), }; struct rspamd_worker_accept_event { diff --git a/src/rspamd_proxy.c b/src/rspamd_proxy.c index 694e87c12..77d2336b2 100644 --- a/src/rspamd_proxy.c +++ b/src/rspamd_proxy.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -85,6 +85,12 @@ worker_t rspamd_proxy_worker = { RSPAMD_WORKER_SOCKET_TCP, /* TCP socket */ RSPAMD_WORKER_VER}; +enum rspamd_proxy_log_tag_type { + RSPAMD_PROXY_LOG_TAG_SESSION = 0, /* Use session mempool tag (default) */ + RSPAMD_PROXY_LOG_TAG_QUEUE_ID, /* Use Queue-ID from client message */ + RSPAMD_PROXY_LOG_TAG_NONE, /* Skip log tag passing */ +}; + struct rspamd_http_upstream { char *name; char *settings_id; @@ -96,6 +102,10 @@ struct rspamd_http_upstream { gboolean local; gboolean self_scan; gboolean compress; + gboolean ssl; + gboolean keepalive; /* Whether to use keepalive for this upstream */ + enum rspamd_proxy_log_tag_type log_tag_type; + ucl_object_t *extra_headers; }; struct rspamd_http_mirror { @@ -109,6 +119,10 @@ struct rspamd_http_mirror { int parser_to_ref; gboolean local; gboolean compress; + gboolean ssl; + gboolean keepalive; /* Whether to use keepalive for this mirror */ + enum rspamd_proxy_log_tag_type log_tag_type; + ucl_object_t *extra_headers; }; static const uint64_t rspamd_rspamd_proxy_magic = 0xcdeb4fd1fc351980ULL; @@ -161,6 +175,8 @@ struct rspamd_proxy_ctx { /* Language detector */ struct rspamd_lang_detector *lang_det; double task_timeout; + /* Default log tag type for worker */ + enum rspamd_proxy_log_tag_type log_tag_type; struct rspamd_main *srv; }; @@ -214,6 +230,7 @@ struct rspamd_proxy_session { enum rspamd_proxy_legacy_support legacy_support; int retries; ref_entry_t ref; + gboolean use_keepalive; /* Whether to use keepalive for this session */ }; static gboolean proxy_send_master_message(struct rspamd_proxy_session *session); @@ -224,6 +241,77 @@ rspamd_proxy_quark(void) return g_quark_from_static_string("rspamd-proxy"); } +static enum rspamd_proxy_log_tag_type +rspamd_proxy_parse_log_tag_type(const char *str) +{ + if (str == NULL) { + return RSPAMD_PROXY_LOG_TAG_SESSION; + } + + if (g_ascii_strcasecmp(str, "session") == 0 || + g_ascii_strcasecmp(str, "session_tag") == 0) { + return RSPAMD_PROXY_LOG_TAG_SESSION; + } + else 
if (g_ascii_strcasecmp(str, "queue_id") == 0 || + g_ascii_strcasecmp(str, "queue-id") == 0) { + return RSPAMD_PROXY_LOG_TAG_QUEUE_ID; + } + else if (g_ascii_strcasecmp(str, "none") == 0 || + g_ascii_strcasecmp(str, "skip") == 0) { + return RSPAMD_PROXY_LOG_TAG_NONE; + } + + /* Default to session tag for unknown values */ + return RSPAMD_PROXY_LOG_TAG_SESSION; +} + +static void +rspamd_proxy_add_log_tag_header(struct rspamd_http_message *msg, + struct rspamd_proxy_session *session, + enum rspamd_proxy_log_tag_type log_tag_type) +{ + const rspamd_ftok_t *queue_id_hdr; + + switch (log_tag_type) { + case RSPAMD_PROXY_LOG_TAG_SESSION: + /* Use session mempool tag (current behavior) */ + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + break; + + case RSPAMD_PROXY_LOG_TAG_QUEUE_ID: + /* Try to extract Queue-ID from client message */ + if (session->client_message) { + queue_id_hdr = rspamd_http_message_find_header(session->client_message, QUEUE_ID_HEADER); + if (queue_id_hdr) { + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, + queue_id_hdr->begin, queue_id_hdr->len); + } + /* If no Queue-ID found, fall back to session tag */ + else { + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + } + } + else { + /* No client message, fall back to session tag */ + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + } + break; + + case RSPAMD_PROXY_LOG_TAG_NONE: + /* Skip adding log tag header */ + break; + + default: + /* Fall back to session tag for unknown types */ + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + break; + } +} + static gboolean 
rspamd_proxy_parse_lua_parser(lua_State *L, const ucl_object_t *obj, int *ref_from, int *ref_to, GError **err) @@ -392,6 +480,7 @@ rspamd_proxy_parse_upstream(rspamd_mempool_t *pool, up->parser_from_ref = -1; up->parser_to_ref = -1; up->timeout = ctx->timeout; + up->log_tag_type = ctx->log_tag_type; /* Inherit from worker default */ elt = ucl_object_lookup(obj, "key"); if (elt != NULL) { @@ -420,6 +509,21 @@ rspamd_proxy_parse_upstream(rspamd_mempool_t *pool, up->compress = TRUE; } + elt = ucl_object_lookup(obj, "ssl"); + if (elt && ucl_object_toboolean(elt)) { + up->ssl = TRUE; + } + + elt = ucl_object_lookup_any(obj, "keepalive", "keep_alive", NULL); + if (elt && ucl_object_toboolean(elt)) { + up->keepalive = TRUE; + } + + elt = ucl_object_lookup_any(obj, "keepalive", "keep_alive", NULL); + if (elt && ucl_object_toboolean(elt)) { + up->keepalive = TRUE; + } + elt = ucl_object_lookup(obj, "hosts"); if (elt == NULL && !up->self_scan) { @@ -469,6 +573,27 @@ rspamd_proxy_parse_upstream(rspamd_mempool_t *pool, up->settings_id = rspamd_mempool_strdup(pool, ucl_object_tostring(elt)); } + elt = ucl_object_lookup(obj, "extra_headers"); + if (elt && ucl_object_type(elt) == UCL_OBJECT) { + up->extra_headers = ucl_object_ref(elt); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) ucl_object_unref, + up->extra_headers); + } + + elt = ucl_object_lookup(obj, "extra_headers"); + if (elt && ucl_object_type(elt) == UCL_OBJECT) { + up->extra_headers = ucl_object_ref(elt); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) ucl_object_unref, + up->extra_headers); + } + + elt = ucl_object_lookup_any(obj, "log_tag", "log_tag_type", NULL); + if (elt && ucl_object_type(elt) == UCL_STRING) { + up->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(elt)); + } + /* * Accept lua function here in form * fun :: String -> UCL @@ -568,6 +693,7 @@ rspamd_proxy_parse_mirror(rspamd_mempool_t *pool, up->parser_to_ref = -1; up->parser_from_ref = -1; 
up->timeout = ctx->timeout; + up->log_tag_type = ctx->log_tag_type; /* Inherit from worker default */ elt = ucl_object_lookup(obj, "key"); if (elt != NULL) { @@ -648,6 +774,11 @@ rspamd_proxy_parse_mirror(rspamd_mempool_t *pool, up->settings_id = rspamd_mempool_strdup(pool, ucl_object_tostring(elt)); } + elt = ucl_object_lookup_any(obj, "log_tag", "log_tag_type", NULL); + if (elt && ucl_object_type(elt) == UCL_STRING) { + up->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(elt)); + } + g_ptr_array_add(ctx->mirrors, up); return TRUE; @@ -747,6 +878,29 @@ err: return FALSE; } +static gboolean +rspamd_proxy_parse_log_tag_worker_option(rspamd_mempool_t *pool, + const ucl_object_t *obj, + gpointer ud, + struct rspamd_rcl_section *section, + GError **err) +{ + struct rspamd_proxy_ctx *ctx; + struct rspamd_rcl_struct_parser *pd = ud; + + ctx = pd->user_struct; + + if (ucl_object_type(obj) != UCL_STRING) { + g_set_error(err, rspamd_proxy_quark(), 100, + "log_tag_type option must be a string"); + return FALSE; + } + + ctx->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(obj)); + + return TRUE; +} + gpointer init_rspamd_proxy(struct rspamd_config *cfg) { @@ -772,6 +926,7 @@ init_rspamd_proxy(struct rspamd_config *cfg) (rspamd_mempool_destruct_t) rspamd_array_free_hard, ctx->cmp_refs); ctx->max_retries = DEFAULT_RETRIES; ctx->spam_header = RSPAMD_MILTER_SPAM_HEADER; + ctx->log_tag_type = RSPAMD_PROXY_LOG_TAG_SESSION; /* Default to session tag */ rspamd_rcl_register_worker_option(cfg, type, @@ -895,6 +1050,16 @@ init_rspamd_proxy(struct rspamd_config *cfg) 0, "Use custom tempfail message"); + /* We need a custom parser for log_tag_type as it's an enum */ + rspamd_rcl_register_worker_option(cfg, + type, + "log_tag_type", + rspamd_proxy_parse_log_tag_worker_option, + ctx, + 0, + 0, + "Log tag type: session (default), queue_id, or none"); + return ctx; } @@ -905,7 +1070,11 @@ proxy_backend_close_connection(struct 
rspamd_proxy_backend_connection *conn) if (conn->backend_conn) { rspamd_http_connection_reset(conn->backend_conn); rspamd_http_connection_unref(conn->backend_conn); - close(conn->backend_sock); + + if (!(conn->s && conn->s->use_keepalive)) { + /* Only close socket if we're not using keepalive */ + close(conn->backend_sock); + } } conn->flags |= RSPAMD_BACKEND_CLOSED; @@ -970,7 +1139,7 @@ proxy_backend_parse_results(struct rspamd_proxy_session *session, RSPAMD_FTOK_ASSIGN(&json_ct, "application/json"); if (ct && rspamd_ftok_casecmp(ct, &json_ct) == 0) { - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, in, inlen)) { char *encoded; @@ -1384,6 +1553,8 @@ proxy_backend_mirror_finish_handler(struct rspamd_http_connection *conn, struct rspamd_proxy_backend_connection *bk_conn = conn->ud; struct rspamd_proxy_session *session; const rspamd_ftok_t *orig_ct; + const rspamd_ftok_t *conn_hdr; + gboolean is_keepalive = FALSE; session = bk_conn->s; @@ -1403,6 +1574,36 @@ proxy_backend_mirror_finish_handler(struct rspamd_http_connection *conn, bk_conn->name, msg->code); rspamd_upstream_ok(bk_conn->up); + /* Check if we can use keepalive */ + conn_hdr = rspamd_http_message_find_header(msg, "Connection"); + if (conn_hdr) { + if (rspamd_substring_search_caseless(conn_hdr->begin, conn_hdr->len, + "keep-alive", 10) != -1) { + is_keepalive = TRUE; + } + } + + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + /* Store connection in keepalive pool */ + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_prepare_keepalive(session->ctx->http_ctx, + conn, rspamd_upstream_addr_cur(bk_conn->up), + up_name, FALSE); + rspamd_http_context_push_keepalive(session->ctx->http_ctx, + conn, msg, session->ctx->event_loop); + + msg_debug_session("pushed mirror connection to %s to keepalive pool", + bk_conn->name); + + /* Mark connection as closed without 
actually closing it */ + bk_conn->flags |= RSPAMD_BACKEND_CLOSED; + REF_RELEASE(bk_conn->s); + return 0; + } + } + proxy_backend_close_connection(bk_conn); REF_RELEASE(bk_conn->s); @@ -1418,6 +1619,7 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) struct rspamd_proxy_backend_connection *bk_conn; struct rspamd_http_message *msg; GError *err = NULL; + const rspamd_inet_addr_t *keepalive_addr; coin = rspamd_random_double(); @@ -1429,6 +1631,157 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) continue; } + /* Check if we can use keepalive for this mirror */ + if (m->keepalive && session->ctx->http_ctx) { + const char *up_name = NULL; + unsigned int port = 0; + + /* Try to find a keepalive connection */ + if (m->u) { + struct upstream *up = rspamd_upstream_get(m->u, + RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); + if (up) { + up_name = rspamd_upstream_name(up); + port = rspamd_inet_address_get_port(rspamd_upstream_addr_cur(up)); + } + } + + if (up_name) { + keepalive_addr = rspamd_http_context_has_keepalive( + session->ctx->http_ctx, up_name, port, m->ssl); + + if (keepalive_addr) { + /* We found a keepalive connection, use it */ + struct rspamd_http_connection *conn; + + conn = rspamd_http_context_check_keepalive( + session->ctx->http_ctx, + (rspamd_inet_addr_t *) keepalive_addr, + up_name, + m->ssl); + + if (conn) { + /* We have a keepalive connection, set it up */ + bk_conn = rspamd_mempool_alloc0(session->pool, sizeof(*bk_conn)); + bk_conn->s = session; + bk_conn->name = m->name; + bk_conn->timeout = m->timeout; + bk_conn->parser_from_ref = m->parser_from_ref; + bk_conn->parser_to_ref = m->parser_to_ref; + bk_conn->backend_conn = conn; + bk_conn->backend_sock = conn->fd; + + msg = rspamd_http_connection_copy_msg(session->client_message, &err); + + if (msg == NULL) { + msg_err_session("cannot copy message to send to a mirror %s: %e", + m->name, err); + if (err) { + g_error_free(err); + } + continue; + } + + if (up_name) { + 
rspamd_http_message_remove_header(msg, "Host"); + rspamd_http_message_add_header(msg, "Host", up_name); + } + rspamd_http_message_remove_header(msg, "Connection"); + rspamd_http_message_add_header(msg, "Connection", "keep-alive"); + + if (msg->url->len == 0) { + msg->url = rspamd_fstring_append(msg->url, "/check", strlen("/check")); + } + + if (m->settings_id != NULL) { + rspamd_http_message_remove_header(msg, "Settings-ID"); + rspamd_http_message_add_header(msg, "Settings-ID", m->settings_id); + } + + /* Add extra headers if specified */ + if (m->extra_headers != NULL) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + const char *key, *value; + + while ((cur = ucl_object_iterate(m->extra_headers, &it, true)) != NULL) { + key = ucl_object_key(cur); + value = ucl_object_tostring(cur); + + if (key != NULL && value != NULL) { + rspamd_http_message_remove_header(msg, key); + rspamd_http_message_add_header(msg, key, value); + } + } + } + + /* Add log tag header based on mirror's configuration */ + rspamd_proxy_add_log_tag_header(msg, session, m->log_tag_type); + + /* Set handlers for the connection */ + conn->error_handler = proxy_backend_mirror_error_handler; + conn->finish_handler = proxy_backend_mirror_finish_handler; + conn->ud = bk_conn; + + if (m->key) { + msg->peer_key = rspamd_pubkey_ref(m->key); + } + + if (m->local || rspamd_inet_address_is_local(keepalive_addr)) { + if (session->fname) { + rspamd_http_message_add_header(msg, "File", session->fname); + } + + msg->method = HTTP_GET; + rspamd_http_connection_write_message_shared(conn, + msg, up_name, + NULL, bk_conn, + bk_conn->timeout); + } + else { + if (session->fname) { + msg->flags &= ~RSPAMD_HTTP_FLAG_SHMEM; + rspamd_http_message_set_body(msg, session->map, session->map_len); + } + + msg->method = HTTP_POST; + + if (m->compress) { + proxy_request_compress(msg); + + if (session->client_milter_conn) { + rspamd_http_message_add_header(msg, "Content-Type", + "application/octet-stream"); + } + } + 
else { + if (session->client_milter_conn) { + rspamd_http_message_add_header(msg, "Content-Type", + "text/plain"); + } + } + + rspamd_http_connection_write_message(conn, + msg, up_name, NULL, bk_conn, + bk_conn->timeout); + } + + g_ptr_array_add(session->mirror_conns, bk_conn); + REF_RETAIN(session); + msg_info_session("send request to %s (using keepalive)", m->name); + + /* + * We have found the existing keepalive connection, so we can + * process another mirror + */ + continue; + } + } + } + } + + /* Non-keepalive connection */ + bk_conn = rspamd_mempool_alloc0(session->pool, sizeof(*bk_conn)); bk_conn->s = session; @@ -1472,7 +1825,9 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) rspamd_http_message_remove_header(msg, "Host"); rspamd_http_message_add_header(msg, "Host", up_name); } - rspamd_http_message_add_header(msg, "Connection", "close"); + rspamd_http_message_remove_header(msg, "Connection"); + rspamd_http_message_add_header(msg, "Connection", + m->keepalive ? 
"keep-alive" : "close"); if (msg->url->len == 0) { msg->url = rspamd_fstring_append(msg->url, "/check", strlen("/check")); @@ -1483,12 +1838,38 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) rspamd_http_message_add_header(msg, "Settings-ID", m->settings_id); } + /* Add extra headers if specified */ + if (m->extra_headers != NULL) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + const char *key, *value; + + while ((cur = ucl_object_iterate(m->extra_headers, &it, true)) != NULL) { + key = ucl_object_key(cur); + value = ucl_object_tostring(cur); + + if (key != NULL && value != NULL) { + rspamd_http_message_remove_header(msg, key); + rspamd_http_message_add_header(msg, key, value); + } + } + } + + /* Add log tag header based on mirror's configuration */ + rspamd_proxy_add_log_tag_header(msg, session, m->log_tag_type); + + unsigned int http_opts = RSPAMD_HTTP_CLIENT_SIMPLE; + + if (m->ssl) { + http_opts |= RSPAMD_HTTP_CLIENT_SSL; + } + bk_conn->backend_conn = rspamd_http_connection_new_client_socket( session->ctx->http_ctx, NULL, proxy_backend_mirror_error_handler, proxy_backend_mirror_finish_handler, - RSPAMD_HTTP_CLIENT_SIMPLE, + http_opts, bk_conn->backend_sock); if (m->key) { @@ -1600,8 +1981,9 @@ proxy_backend_master_error_handler(struct rspamd_http_connection *conn, GError * session->retries++; msg_info_session("abnormally closing connection from backend: %s, error: %e," " retries left: %d", - rspamd_inet_address_to_string_pretty( - rspamd_upstream_addr_cur(session->master_conn->up)), + session->master_conn->up ? rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->master_conn->up)) + : "self-scan", err, session->ctx->max_retries - session->retries); rspamd_upstream_fail(bk_conn->up, FALSE, err ? 
err->message : "unknown"); @@ -1632,8 +2014,9 @@ proxy_backend_master_error_handler(struct rspamd_http_connection *conn, GError * else { msg_info_session("retry connection to: %s" " retries left: %d", - rspamd_inet_address_to_string( - rspamd_upstream_addr_cur(session->master_conn->up)), + session->master_conn->up ? rspamd_inet_address_to_string( + rspamd_upstream_addr_cur(session->master_conn->up)) + : "self-scan", session->ctx->max_retries - session->retries); } } @@ -1647,7 +2030,9 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, struct rspamd_proxy_session *session, *nsession; rspamd_fstring_t *reply; const rspamd_ftok_t *orig_ct; + const rspamd_ftok_t *conn_hdr; goffset body_offset = -1; + gboolean is_keepalive = FALSE; session = bk_conn->s; rspamd_http_connection_steal_msg(session->master_conn->backend_conn); @@ -1663,6 +2048,16 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_http_message_remove_header(msg, "Server"); rspamd_http_message_remove_header(msg, "Key"); orig_ct = rspamd_http_message_find_header(msg, "Content-Type"); + + /* Check if we can use keepalive */ + conn_hdr = rspamd_http_message_find_header(msg, "Connection"); + if (conn_hdr) { + if (rspamd_substring_search_caseless(conn_hdr->begin, conn_hdr->len, + "keep-alive", 10) != -1) { + is_keepalive = TRUE; + } + } + rspamd_http_connection_reset(session->master_conn->backend_conn); if (!proxy_backend_parse_results(session, bk_conn, session->ctx->lua_state, @@ -1695,6 +2090,22 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_upstream_ok(bk_conn->up); + /* Handle keepalive for master connection */ + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + /* Store connection in keepalive pool */ + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_prepare_keepalive(session->ctx->http_ctx, + conn, 
rspamd_upstream_addr_cur(bk_conn->up), + up_name, FALSE); + + /* We'll push to keepalive pool after we're done with the response */ + msg_debug_session("will push master connection to %s to keepalive pool", + up_name); + } + } + if (session->client_milter_conn) { nsession = proxy_session_refresh(session); @@ -1708,6 +2119,20 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_milter_send_task_results(nsession->client_milter_conn, session->master_conn->results, NULL, 0); } + + /* Push to keepalive if needed */ + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_push_keepalive(session->ctx->http_ctx, + conn, msg, session->ctx->event_loop); + + /* Mark connection as closed without actually closing it */ + bk_conn->flags |= RSPAMD_BACKEND_CLOSED; + } + } + REF_RELEASE(session); rspamd_http_message_free(msg); } @@ -1723,6 +2148,19 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_http_connection_write_message(session->client_conn, msg, NULL, passed_ct, session, bk_conn->timeout); + + /* Push to keepalive if needed */ + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_push_keepalive(session->ctx->http_ctx, + conn, msg, session->ctx->event_loop); + + /* Mark connection as closed without actually closing it */ + bk_conn->flags |= RSPAMD_BACKEND_CLOSED; + } + } } return 0; @@ -1982,6 +2420,9 @@ proxy_send_master_message(struct rspamd_proxy_session *session) /* Remove the original `Connection` header */ rspamd_http_message_remove_header(session->client_message, "Connection"); + /* Set keepalive flag based on backend configuration */ + session->use_keepalive = backend ? 
backend->keepalive : FALSE; + if (backend == NULL) { /* No backend */ msg_err_session("cannot find upstream for %s", host ? hostbuf : "default"); @@ -2063,14 +2504,21 @@ proxy_send_master_message(struct rspamd_proxy_session *session) if (up_name) { rspamd_http_message_add_header(msg, "Host", up_name); } - rspamd_http_message_add_header(msg, "Connection", "close"); + rspamd_http_message_add_header(msg, "Connection", + backend->keepalive ? "keep-alive" : "close"); + + unsigned int http_opts = RSPAMD_HTTP_CLIENT_SIMPLE; + + if (backend->ssl) { + http_opts |= RSPAMD_HTTP_CLIENT_SSL; + } session->master_conn->backend_conn = rspamd_http_connection_new_client_socket( session->ctx->http_ctx, NULL, proxy_backend_master_error_handler, proxy_backend_master_finish_handler, - RSPAMD_HTTP_CLIENT_SIMPLE, + http_opts, session->master_conn->backend_sock); session->master_conn->flags &= ~RSPAMD_BACKEND_CLOSED; session->master_conn->parser_from_ref = backend->parser_from_ref; @@ -2086,6 +2534,26 @@ proxy_send_master_message(struct rspamd_proxy_session *session) backend->settings_id); } + /* Add extra headers if specified */ + if (backend->extra_headers != NULL) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + const char *key, *value; + + while ((cur = ucl_object_iterate(backend->extra_headers, &it, true)) != NULL) { + key = ucl_object_key(cur); + value = ucl_object_tostring(cur); + + if (key != NULL && value != NULL) { + rspamd_http_message_remove_header(msg, key); + rspamd_http_message_add_header(msg, key, value); + } + } + } + + /* Add log tag header based on backend's configuration */ + rspamd_proxy_add_log_tag_header(msg, session, backend->log_tag_type); + if (backend->local || rspamd_inet_address_is_local( rspamd_upstream_addr_cur( @@ -2206,8 +2674,6 @@ proxy_client_finish_handler(struct rspamd_http_connection *conn, rspamd_http_message_remove_header(msg, "Keep-Alive"); rspamd_http_message_remove_header(msg, "Connection"); rspamd_http_message_remove_header(msg, 
"Key"); - rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, - strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); proxy_open_mirror_connections(session); rspamd_http_connection_reset(session->client_conn); @@ -2216,8 +2682,9 @@ proxy_client_finish_handler(struct rspamd_http_connection *conn, } else { msg_info_session("finished master connection to %s; HTTP code: %d", - rspamd_inet_address_to_string_pretty( - rspamd_upstream_addr_cur(session->master_conn->up)), + session->master_conn->up ? rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->master_conn->up)) + : "self-scan", msg->code); proxy_backend_close_connection(session->master_conn); REF_RELEASE(session); |