diff options
Diffstat (limited to 'src')
83 files changed, 4367 insertions, 1325 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 92edb0b6a..6cc49e4e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,246 +1,286 @@ -MACRO(_AddModulesForced MLIST) -# Generate unique string for this build - SET(MODULES_C "${CMAKE_CURRENT_BINARY_DIR}/modules.c") - FILE(WRITE "${MODULES_C}" - "#include \"rspamd.h\"\n") - - # Handle even old cmake - LIST(LENGTH ${MLIST} MLIST_COUNT) - MATH(EXPR MLIST_MAX ${MLIST_COUNT}-1) - - FOREACH(MOD_IDX RANGE ${MLIST_MAX}) - LIST(GET ${MLIST} ${MOD_IDX} MOD) - FILE(APPEND "${MODULES_C}" "extern module_t ${MOD}_module;\n") - ENDFOREACH(MOD_IDX RANGE ${MLIST_MAX}) - - FILE(APPEND "${MODULES_C}" "\n\nmodule_t *modules[] = {\n") - - FOREACH(MOD_IDX RANGE ${MLIST_MAX}) - LIST(GET ${MLIST} ${MOD_IDX} MOD) - FILE(APPEND "${MODULES_C}" "&${MOD}_module,\n") - ENDFOREACH(MOD_IDX RANGE ${MLIST_MAX}) - - FILE(APPEND "${MODULES_C}" "NULL\n};\n") -ENDMACRO(_AddModulesForced MLIST) - -MACRO(_AddWorkersForced WLIST) - SET(WORKERS_C "${CMAKE_CURRENT_BINARY_DIR}/workers.c") - FILE(WRITE "${WORKERS_C}" - "#include \"rspamd.h\"\n") - - # Handle even old cmake - LIST(LENGTH ${WLIST} WLIST_COUNT) - MATH(EXPR WLIST_MAX ${WLIST_COUNT}-1) - FOREACH(MOD_IDX RANGE ${WLIST_MAX}) - LIST(GET ${WLIST} ${MOD_IDX} WRK) - FILE(APPEND "${WORKERS_C}" "extern worker_t ${WRK}_worker;\n") - ENDFOREACH(MOD_IDX RANGE ${WLIST_MAX}) - - FILE(APPEND "${WORKERS_C}" "\n\nworker_t *workers[] = {\n") - - FOREACH(MOD_IDX RANGE ${WLIST_MAX}) - LIST(GET ${WLIST} ${MOD_IDX} WRK) - FILE(APPEND "${WORKERS_C}" "&${WRK}_worker,\n") - ENDFOREACH(MOD_IDX RANGE ${WLIST_MAX}) - FILE(APPEND "${WORKERS_C}" "NULL\n};\n") -ENDMACRO(_AddWorkersForced WLIST) - -MACRO(AddModules MLIST WLIST) - _AddModulesForced(${MLIST}) - _AddWorkersForced(${WLIST}) - #IF(NOT EXISTS "modules.c") - # _AddModulesForced(${MLIST} ${WLIST}) - #ELSE(NOT EXISTS "modules.c") - # FILE(STRINGS "modules.c" FILE_ID_RAW REGEX "^/.*[a-zA-Z0-9]+.*/$") - # STRING(REGEX MATCH "[a-zA-Z0-9]+" FILE_ID 
"${FILE_ID_RAW}") - # IF(NOT FILE_ID STREQUAL MODULES_ID) - # MESSAGE("Regenerate modules info") - # _AddModulesForced(${MLIST} ${WLIST}) - # ENDIF(NOT FILE_ID STREQUAL MODULES_ID) - #ENDIF(NOT EXISTS "modules.c") -ENDMACRO(AddModules MLIST WLIST) - -# Rspamd core components -IF (ENABLE_CLANG_PLUGIN MATCHES "ON") - SET(CMAKE_C_FLAGS - "${CMAKE_C_FLAGS} -Xclang -load -Xclang ${CMAKE_CURRENT_BINARY_DIR}/../clang-plugin/librspamd-clang${CMAKE_SHARED_LIBRARY_SUFFIX} -Xclang -add-plugin -Xclang rspamd-ast") - IF(CLANG_EXTRA_PLUGINS_LIBS) - FOREACH(_lib ${CLANG_EXTRA_PLUGINS_LIBS}) - SET(CMAKE_C_FLAGS - "${CMAKE_C_FLAGS} -Xclang -load -Xclang ${_lib}") - SET(CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} -Xclang -load -Xclang ${_lib}") - ENDFOREACH() - ENDIF() - IF(CLANG_EXTRA_PLUGINS) - FOREACH(_plug ${CLANG_EXTRA_PLUGINS}) - SET(CMAKE_C_FLAGS - "${CMAKE_C_FLAGS} -Xclang -add-plugin -Xclang ${_plug}") - SET(CMAKE_CXX_FLAGS - "${CMAKE_C_FLAGS} -Xclang -add-plugin -Xclang ${_plug}") - ENDFOREACH() - ENDIF() -ENDIF () - -ADD_SUBDIRECTORY(lua) -ADD_SUBDIRECTORY(libcryptobox) -ADD_SUBDIRECTORY(libutil) -ADD_SUBDIRECTORY(libserver) -ADD_SUBDIRECTORY(libmime) -ADD_SUBDIRECTORY(libstat) -ADD_SUBDIRECTORY(client) -ADD_SUBDIRECTORY(rspamadm) - -SET(RSPAMDSRC controller.c - fuzzy_storage.c - rspamd.c - worker.c - rspamd_proxy.c) - -SET(PLUGINSSRC plugins/regexp.c - plugins/chartable.cxx - plugins/fuzzy_check.c - plugins/dkim_check.c - libserver/rspamd_control.c) - -SET(MODULES_LIST regexp chartable fuzzy_check dkim) -SET(WORKERS_LIST normal controller fuzzy rspamd_proxy) -IF (ENABLE_HYPERSCAN MATCHES "ON") - LIST(APPEND WORKERS_LIST "hs_helper") - LIST(APPEND RSPAMDSRC "hs_helper.c") -ENDIF() - -AddModules(MODULES_LIST WORKERS_LIST) -LIST(LENGTH PLUGINSSRC RSPAMD_MODULES_NUM) - -SET(RAGEL_DEPENDS "${CMAKE_SOURCE_DIR}/src/ragel/smtp_address.rl" - "${CMAKE_SOURCE_DIR}/src/ragel/smtp_date.rl" - "${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip.rl" - "${CMAKE_SOURCE_DIR}/src/ragel/smtp_base.rl" - 
"${CMAKE_SOURCE_DIR}/src/ragel/content_disposition.rl") -RAGEL_TARGET(ragel_smtp_addr - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_addr_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -T1 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/smtp_addr_parser.rl.c) -RAGEL_TARGET(ragel_content_disposition - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/content_disposition_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/content_disposition.rl.c) -RAGEL_TARGET(ragel_rfc2047 - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/rfc2047_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/rfc2047.rl.c) -RAGEL_TARGET(ragel_smtp_date - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_date_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/date_parser.rl.c) -RAGEL_TARGET(ragel_smtp_ip - INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip_parser.rl - DEPENDS ${RAGEL_DEPENDS} - COMPILE_FLAGS -G2 - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ip_parser.rl.c) -# Fucking cmake... 
-FOREACH(_GEN ${LIBSERVER_GENERATED}) - set_source_files_properties(${_GEN} PROPERTIES GENERATED TRUE) -ENDFOREACH() -######################### LINK SECTION ############################### - -IF(ENABLE_STATIC MATCHES "ON") - ADD_LIBRARY(rspamd-server STATIC - ${RSPAMD_CRYPTOBOX} - ${RSPAMD_UTIL} - ${RSPAMD_LUA} - ${RSPAMD_SERVER} - ${RSPAMD_STAT} - ${RSPAMD_MIME} - ${CMAKE_CURRENT_BINARY_DIR}/modules.c - ${PLUGINSSRC} - "${RAGEL_ragel_smtp_addr_OUTPUTS}" - "${RAGEL_ragel_newlines_strip_OUTPUTS}" - "${RAGEL_ragel_content_type_OUTPUTS}" - "${RAGEL_ragel_content_disposition_OUTPUTS}" - "${RAGEL_ragel_rfc2047_OUTPUTS}" - "${RAGEL_ragel_smtp_date_OUTPUTS}" - "${RAGEL_ragel_smtp_ip_OUTPUTS}" - ${BACKWARD_ENABLE}) -ELSE() - ADD_LIBRARY(rspamd-server SHARED - ${RSPAMD_CRYPTOBOX} - ${RSPAMD_UTIL} - ${RSPAMD_SERVER} - ${RSPAMD_STAT} - ${RSPAMD_MIME} - ${RSPAMD_LUA} - ${CMAKE_CURRENT_BINARY_DIR}/modules.c - ${PLUGINSSRC} - "${RAGEL_ragel_smtp_addr_OUTPUTS}" - "${RAGEL_ragel_newlines_strip_OUTPUTS}" - "${RAGEL_ragel_content_type_OUTPUTS}" - "${RAGEL_ragel_content_disposition_OUTPUTS}" - "${RAGEL_ragel_rfc2047_OUTPUTS}" - "${RAGEL_ragel_smtp_date_OUTPUTS}" - "${RAGEL_ragel_smtp_ip_OUTPUTS}" - ${BACKWARD_ENABLE}) -ENDIF() - -FOREACH(_DEP ${LIBSERVER_DEPENDS}) - ADD_DEPENDENCIES(rspamd-server "${_DEP}") -ENDFOREACH() - -TARGET_LINK_LIBRARIES(rspamd-server rspamd-http-parser) -TARGET_LINK_LIBRARIES(rspamd-server rspamd-fpconv) -TARGET_LINK_LIBRARIES(rspamd-server rspamd-cdb) -TARGET_LINK_LIBRARIES(rspamd-server rspamd-lpeg) -TARGET_LINK_LIBRARIES(rspamd-server lcbtrie) -IF(SYSTEM_ZSTD MATCHES "OFF") - TARGET_LINK_LIBRARIES(rspamd-server rspamd-zstd) -ELSE() - TARGET_LINK_LIBRARIES(rspamd-server zstd) -ENDIF() -TARGET_LINK_LIBRARIES(rspamd-server rspamd-simdutf) - -IF (ENABLE_CLANG_PLUGIN MATCHES "ON") - ADD_DEPENDENCIES(rspamd-server rspamd-clang) -ENDIF() - -IF (NOT WITH_LUAJIT) - TARGET_LINK_LIBRARIES(rspamd-server rspamd-bit) -ENDIF() - -IF (ENABLE_SNOWBALL MATCHES "ON") - 
TARGET_LINK_LIBRARIES(rspamd-server stemmer) -ENDIF() -TARGET_LINK_LIBRARIES(rspamd-server rspamd-hiredis) - -IF (ENABLE_FANN MATCHES "ON") - TARGET_LINK_LIBRARIES(rspamd-server fann) -ENDIF () - -IF (ENABLE_HYPERSCAN MATCHES "ON") - TARGET_LINK_LIBRARIES(rspamd-server hs) -ENDIF() - -IF(WITH_BLAS) - TARGET_LINK_LIBRARIES(rspamd-server ${BLAS_REQUIRED_LIBRARIES}) -ENDIF() - -TARGET_LINK_LIBRARIES(rspamd-server ${RSPAMD_REQUIRED_LIBRARIES}) -ADD_BACKWARD(rspamd-server) - -ADD_EXECUTABLE(rspamd ${RSPAMDSRC} ${CMAKE_CURRENT_BINARY_DIR}/workers.c ${CMAKE_CURRENT_BINARY_DIR}/config.h) -ADD_BACKWARD(rspamd) -SET_TARGET_PROPERTIES(rspamd PROPERTIES LINKER_LANGUAGE CXX) -SET_TARGET_PROPERTIES(rspamd-server PROPERTIES LINKER_LANGUAGE CXX) -IF(NOT NO_TARGET_VERSIONS) - SET_TARGET_PROPERTIES(rspamd PROPERTIES VERSION ${RSPAMD_VERSION}) -ENDIF() - -#TARGET_LINK_LIBRARIES(rspamd ${RSPAMD_REQUIRED_LIBRARIES}) -TARGET_LINK_LIBRARIES(rspamd rspamd-server) - -INSTALL(TARGETS rspamd RUNTIME DESTINATION bin) -INSTALL(TARGETS rspamd-server LIBRARY DESTINATION ${RSPAMD_LIBDIR})
\ No newline at end of file
+# generate_modules_list(<list-var>): write modules.c (extern decls + NULL-terminated modules[] table); sets MODULES_C_PATH in the caller
+function(generate_modules_list MODULE_LIST)
+    # Generated source lives in the build tree and is rewritten on every configure run
+    set(MODULES_C "${CMAKE_CURRENT_BINARY_DIR}/modules.c")
+    file(WRITE "${MODULES_C}"
+            "#include \"rspamd.h\"\n")
+
+    # MODULE_LIST holds the *name* of a list variable, hence the extra dereference
+    foreach (MOD IN LISTS ${MODULE_LIST})
+        file(APPEND "${MODULES_C}" "extern module_t ${MOD}_module;\n")
+    endforeach ()
+
+    file(APPEND "${MODULES_C}" "\n\nmodule_t *modules[] = {\n")
+
+    foreach (MOD IN LISTS ${MODULE_LIST})
+        file(APPEND "${MODULES_C}" "&${MOD}_module,\n")
+    endforeach ()
+
+    file(APPEND "${MODULES_C}" "NULL\n};\n")
+
+    # Return the generated file path
+    set(MODULES_C_PATH "${MODULES_C}" PARENT_SCOPE)
+endfunction()
+
+# generate_workers_list(<list-var>): write workers.c (extern decls + NULL-terminated workers[] table); sets WORKERS_C_PATH in the caller
+function(generate_workers_list WORKER_LIST)
+    set(WORKERS_C "${CMAKE_CURRENT_BINARY_DIR}/workers.c")
+    file(WRITE "${WORKERS_C}"
+            "#include \"rspamd.h\"\n")
+
+    # WORKER_LIST holds the *name* of a list variable, hence the extra dereference
+    foreach (WRK IN LISTS ${WORKER_LIST})
+        file(APPEND "${WORKERS_C}" "extern worker_t ${WRK}_worker;\n")
+    endforeach ()
+
+    file(APPEND "${WORKERS_C}" "\n\nworker_t *workers[] = {\n")
+
+    foreach (WRK IN LISTS ${WORKER_LIST})
+        file(APPEND "${WORKERS_C}" "&${WRK}_worker,\n")
+    endforeach ()
+
+    file(APPEND "${WORKERS_C}" "NULL\n};\n")
+
+    # Return the generated file path
+    set(WORKERS_C_PATH "${WORKERS_C}" PARENT_SCOPE)
+endfunction()
+
+# generate_registration_code(<modules-var> <workers-var>): run both generators and export both paths
+function(generate_registration_code MODULE_LIST WORKER_LIST)
+    generate_modules_list(${MODULE_LIST})
+    generate_workers_list(${WORKER_LIST})
+
+    # The helpers only set the paths in this function's scope; forward them one more level
+    set(MODULES_C_PATH ${MODULES_C_PATH} PARENT_SCOPE)
+    set(WORKERS_C_PATH ${WORKERS_C_PATH} PARENT_SCOPE)
+endfunction()
+
+# Configure Clang Plugin if enabled
+if (ENABLE_CLANG_PLUGIN)
+    set(CLANG_PLUGIN_FLAGS "SHELL:-Xclang -load -Xclang \"${CMAKE_CURRENT_BINARY_DIR}/../clang-plugin/librspamd-clang${CMAKE_SHARED_LIBRARY_SUFFIX}\"" "SHELL:-Xclang -add-plugin -Xclang rspamd-ast")
+
+    # Applies to C and C++; each SHELL: group is split like a shell command line, so the "-Xclang <arg>" pairs survive option de-duplication (needs CMake >= 3.12)
+    add_compile_options(${CLANG_PLUGIN_FLAGS})
+
+    # Add any extra clang plugins (same SHELL: grouping, otherwise the repeated -Xclang would be collapsed)
+    if (CLANG_EXTRA_PLUGINS_LIBS)
+        foreach (lib ${CLANG_EXTRA_PLUGINS_LIBS})
+            add_compile_options("SHELL:-Xclang -load -Xclang \"${lib}\"")
+        endforeach ()
+    endif ()
+
+    if (CLANG_EXTRA_PLUGINS)
+        foreach (plug ${CLANG_EXTRA_PLUGINS})
+            add_compile_options("SHELL:-Xclang -add-plugin -Xclang \"${plug}\"")
+        endforeach ()
+    endif ()
+endif ()
+
+# Add subdirectories for components
+add_subdirectory(lua)
+add_subdirectory(libcryptobox)
+add_subdirectory(libutil)
+add_subdirectory(libserver)
+add_subdirectory(libmime)
+add_subdirectory(libstat)
+add_subdirectory(client)
+add_subdirectory(rspamadm)
+
+# Define source files
+set(RSPAMD_SOURCES
+        controller.c
+        fuzzy_storage.c
+        rspamd.c
+        worker.c
+        rspamd_proxy.c)
+
+set(PLUGIN_SOURCES
+        plugins/regexp.c
+        plugins/chartable.cxx
+        plugins/fuzzy_check.c
+        plugins/dkim_check.c
+        libserver/rspamd_control.c)
+
+# Define module and worker lists
+set(MODULES_LIST regexp chartable fuzzy_check dkim)
+set(WORKERS_LIST normal controller fuzzy rspamd_proxy)
+
+# Add hyperscan worker if enabled
+if (ENABLE_HYPERSCAN)
+    list(APPEND WORKERS_LIST hs_helper)
+    list(APPEND RSPAMD_SOURCES hs_helper.c)
+endif ()
+
+# Generate modules and workers registration code (arguments are list *names*, not expansions)
+generate_registration_code(MODULES_LIST WORKERS_LIST)
+
+# Number of entries in PLUGIN_SOURCES (this also counts libserver/rspamd_control.c, not only plugins/*)
+list(LENGTH PLUGIN_SOURCES RSPAMD_MODULES_NUM)
+
+# Configure Ragel for parsers
+set(RAGEL_DEPENDS
+        "${CMAKE_SOURCE_DIR}/src/ragel/smtp_address.rl"
+        "${CMAKE_SOURCE_DIR}/src/ragel/smtp_date.rl"
+        "${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip.rl"
+        "${CMAKE_SOURCE_DIR}/src/ragel/smtp_base.rl"
+        "${CMAKE_SOURCE_DIR}/src/ragel/content_disposition.rl")
+
+# Generate parsers with Ragel
+ragel_target(ragel_smtp_addr
+        INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_addr_parser.rl
+        DEPENDS ${RAGEL_DEPENDS}
+        COMPILE_FLAGS -T1
+        OUTPUT
${CMAKE_CURRENT_BINARY_DIR}/smtp_addr_parser.rl.c) + +ragel_target(ragel_content_disposition + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/content_disposition_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/content_disposition.rl.c) + +ragel_target(ragel_rfc2047 + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/rfc2047_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/rfc2047.rl.c) + +ragel_target(ragel_smtp_date + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_date_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/date_parser.rl.c) + +ragel_target(ragel_smtp_ip + INPUTS ${CMAKE_SOURCE_DIR}/src/ragel/smtp_ip_parser.rl + DEPENDS ${RAGEL_DEPENDS} + COMPILE_FLAGS -G2 + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ip_parser.rl.c) + +# Mark generated files correctly +foreach (_gen ${LIBSERVER_GENERATED}) + set_source_files_properties(${_gen} PROPERTIES GENERATED TRUE) +endforeach () + +# Collection of all generated Ragel outputs +set(RAGEL_OUTPUTS + ${RAGEL_ragel_smtp_addr_OUTPUTS} + ${RAGEL_ragel_newlines_strip_OUTPUTS} + ${RAGEL_ragel_content_type_OUTPUTS} + ${RAGEL_ragel_content_disposition_OUTPUTS} + ${RAGEL_ragel_rfc2047_OUTPUTS} + ${RAGEL_ragel_smtp_date_OUTPUTS} + ${RAGEL_ragel_smtp_ip_OUTPUTS}) + +# Common sources for rspamd-server +set(SERVER_COMMON_SOURCES + ${RSPAMD_CRYPTOBOX} + ${RSPAMD_UTIL} + ${RSPAMD_LUA} + ${RSPAMD_SERVER} + ${RSPAMD_STAT} + ${RSPAMD_MIME} + ${MODULES_C_PATH} + ${PLUGIN_SOURCES} + ${RAGEL_OUTPUTS} + ${BACKWARD_ENABLE}) + +# Build rspamd-server as static or shared library based on configuration +if (ENABLE_STATIC) + add_library(rspamd-server STATIC ${SERVER_COMMON_SOURCES}) +else () + add_library(rspamd-server SHARED ${SERVER_COMMON_SOURCES}) +endif () + +# Set dependencies for rspamd-server +foreach (_dep ${LIBSERVER_DEPENDS}) + add_dependencies(rspamd-server "${_dep}") +endforeach () + +# Link dependencies 
+target_link_libraries(rspamd-server + PRIVATE + rspamd-http-parser + rspamd-fpconv + rspamd-cdb + rspamd-lpeg + ottery + lcbtrie + rspamd-simdutf + rdns + ucl) + +# Handle xxhash dependency +if (SYSTEM_XXHASH) + target_link_libraries(rspamd-server PUBLIC xxhash) +else () + target_link_libraries(rspamd-server PUBLIC rspamd-xxhash) +endif () + +# Handle zstd dependency +if (SYSTEM_ZSTD) + target_link_libraries(rspamd-server PUBLIC zstd) +else () + target_link_libraries(rspamd-server PRIVATE rspamd-zstd) +endif () + +# Handle clang plugin dependency +if (ENABLE_CLANG_PLUGIN) + add_dependencies(rspamd-server rspamd-clang) +endif () + +# Handle Lua JIT/Lua dependency +if (NOT WITH_LUAJIT) + target_link_libraries(rspamd-server PRIVATE rspamd-bit) +endif () + +# Link additional optional dependencies +if (ENABLE_SNOWBALL) + target_link_libraries(rspamd-server PRIVATE stemmer) +endif () + +target_link_libraries(rspamd-server PRIVATE rspamd-hiredis) + +if (ENABLE_FANN) + target_link_libraries(rspamd-server PRIVATE fann) +endif () + +if (ENABLE_HYPERSCAN) + target_link_libraries(rspamd-server PUBLIC hs) +endif () + +if (WITH_BLAS) + target_link_libraries(rspamd-server PRIVATE ${BLAS_REQUIRED_LIBRARIES}) +endif () + +# Link all required system libraries +target_link_libraries(rspamd-server PUBLIC ${RSPAMD_REQUIRED_LIBRARIES}) + +# Add Backward support for stacktrace +add_backward(rspamd-server) + +# Build main rspamd executable +add_executable(rspamd + ${RSPAMD_SOURCES} + ${WORKERS_C_PATH} + ${CMAKE_CURRENT_BINARY_DIR}/config.h) + +# Configure rspamd executable +add_backward(rspamd) +set_target_properties(rspamd PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(rspamd-server PROPERTIES LINKER_LANGUAGE CXX) + +if (NOT NO_TARGET_VERSIONS) + set_target_properties(rspamd PROPERTIES VERSION ${RSPAMD_VERSION}) +endif () + +# Link rspamd executable with the server library +target_link_libraries(rspamd PRIVATE rspamd-server) + +# Install targets +install(TARGETS rspamd + RUNTIME 
+ DESTINATION bin) + +install(TARGETS rspamd-server + LIBRARY + DESTINATION ${RSPAMD_LIBDIR}) diff --git a/src/client/rspamdclient.c b/src/client/rspamdclient.c index d07b24332..4d79590c5 100644 --- a/src/client/rspamdclient.c +++ b/src/client/rspamdclient.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -231,7 +231,7 @@ rspamd_client_finish_handler(struct rspamd_http_connection *conn, } } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk_full(parser, start, len, ucl_parser_get_default_priority(parser), UCL_DUPLICATE_APPEND, UCL_PARSE_AUTO)) { diff --git a/src/controller.c b/src/controller.c index 386448f93..0550ba6b8 100644 --- a/src/controller.c +++ b/src/controller.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -68,6 +68,7 @@ #define PATH_NEIGHBOURS "/neighbours" #define PATH_PLUGINS "/plugins" #define PATH_PING "/ping" +#define PATH_BAYES_CLASSIFIERS "/bayes/classifiers" #define msg_err_session(...) 
rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ session->pool->tag.tagname, session->pool->tag.uid, \ @@ -979,12 +980,6 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, if (bk->protocol == MAP_PROTO_FILE) { editable = rspamd_controller_can_edit_map(bk); - - if (!editable && access(bk->uri, R_OK) == -1) { - /* Skip unreadable and non-existing maps */ - continue; - } - obj = ucl_object_typed_new(UCL_OBJECT); ucl_object_insert_key(obj, ucl_object_fromint(bk->id), "map", 0, false); @@ -994,8 +989,34 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, } ucl_object_insert_key(obj, ucl_object_fromstring(bk->uri), "uri", 0, false); + ucl_object_insert_key(obj, ucl_object_fromstring("file"), + "type", 0, false); ucl_object_insert_key(obj, ucl_object_frombool(editable), "editable", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->loaded), + "loaded", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->cached), + "cached", 0, false); + ucl_array_append(top, obj); + } + else { + obj = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(obj, ucl_object_fromint(bk->id), + "map", 0, false); + if (map->description) { + ucl_object_insert_key(obj, ucl_object_fromstring(map->description), + "description", 0, false); + } + ucl_object_insert_key(obj, ucl_object_fromstring(bk->uri), + "uri", 0, false); + ucl_object_insert_key(obj, ucl_object_fromstring(rspamd_map_fetch_protocol_name(bk->protocol)), + "type", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(false), + "editable", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->loaded), + "loaded", 0, false); + ucl_object_insert_key(obj, ucl_object_frombool(map->shared->cached), + "cached", 0, false); ucl_array_append(top, obj); } } @@ -1008,6 +1029,21 @@ rspamd_controller_handle_maps(struct rspamd_http_connection_entry *conn_ent, return 0; } +gboolean 
+rspamd_controller_map_traverse_callback(gconstpointer key, gconstpointer value, gsize _hits, gpointer ud) +{ + rspamd_fstring_t **target = (rspamd_fstring_t **) ud; + + *target = rspamd_fstring_append(*target, key, strlen(key)); + + if (value) { + *target = rspamd_fstring_append(*target, " ", 1); + *target = rspamd_fstring_append(*target, value, strlen(value)); + } + *target = rspamd_fstring_append(*target, "\n", 1); + + return TRUE; +} /* * Get map command handler: * request: /getmap @@ -1020,7 +1056,7 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, { struct rspamd_controller_session *session = conn_ent->ud; GList *cur; - struct rspamd_map *map; + struct rspamd_map *map = NULL; struct rspamd_map_backend *bk = NULL; const rspamd_ftok_t *idstr; struct stat st; @@ -1054,7 +1090,7 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, PTR_ARRAY_FOREACH(map->backends, i, bk) { - if (bk->id == id && bk->protocol == MAP_PROTO_FILE) { + if (bk->id == id) { found = TRUE; break; } @@ -1069,32 +1105,53 @@ rspamd_controller_handle_get_map(struct rspamd_http_connection_entry *conn_ent, return 0; } - if (stat(bk->uri, &st) == -1 || (fd = open(bk->uri, O_RDONLY)) == -1) { + if (bk->protocol == MAP_PROTO_FILE) { + if (stat(bk->uri, &st) == -1 || (fd = open(bk->uri, O_RDONLY)) == -1) { + reply = rspamd_http_new_message(HTTP_RESPONSE); + reply->date = time(NULL); + reply->code = 200; + } + else { + + reply = rspamd_http_new_message(HTTP_RESPONSE); + reply->date = time(NULL); + reply->code = 200; + + if (st.st_size > 0) { + if (!rspamd_http_message_set_body_from_fd(reply, fd)) { + close(fd); + rspamd_http_message_unref(reply); + msg_err_session("cannot read map %s: %s", bk->uri, strerror(errno)); + rspamd_controller_send_error(conn_ent, 500, "Map read error"); + return 0; + } + } + else { + rspamd_fstring_t *empty_body = rspamd_fstring_new_init("", 0); + rspamd_http_message_set_body_from_fstring_steal(reply, 
empty_body); + } + + close(fd); + } + } + else if (bk->protocol == MAP_PROTO_STATIC) { + /* We can just traverse map and form reply */ reply = rspamd_http_new_message(HTTP_RESPONSE); - reply->date = time(NULL); reply->code = 200; + rspamd_fstring_t *map_body = rspamd_fstring_new(); + rspamd_map_traverse(bk->map, rspamd_controller_map_traverse_callback, &map_body, FALSE); + rspamd_http_message_set_body_from_fstring_steal(reply, map_body); } - else { - + else if (map->shared->loaded) { reply = rspamd_http_new_message(HTTP_RESPONSE); - reply->date = time(NULL); reply->code = 200; - - if (st.st_size > 0) { - if (!rspamd_http_message_set_body_from_fd(reply, fd)) { - close(fd); - rspamd_http_message_unref(reply); - msg_err_session("cannot read map %s: %s", bk->uri, strerror(errno)); - rspamd_controller_send_error(conn_ent, 500, "Map read error"); - return 0; - } - } - else { - rspamd_fstring_t *empty_body = rspamd_fstring_new_init("", 0); - rspamd_http_message_set_body_from_fstring_steal(reply, empty_body); - } - - close(fd); + rspamd_fstring_t *map_body = rspamd_fstring_new(); + rspamd_map_traverse(bk->map, rspamd_controller_map_traverse_callback, &map_body, FALSE); + rspamd_http_message_set_body_from_fstring_steal(reply, map_body); + } + else { + reply = rspamd_http_new_message(HTTP_RESPONSE); + reply->code = 404; } rspamd_http_connection_reset(conn_ent->conn); @@ -2255,7 +2312,7 @@ rspamd_controller_handle_saveactions( return 0; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, msg->body_buf.begin, msg->body_buf.len)) { if ((error = ucl_parser_get_error(parser)) != NULL) { msg_err_session("cannot parse input: %s", error); @@ -2378,7 +2435,7 @@ rspamd_controller_handle_savesymbols( return 0; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, msg->body_buf.begin, msg->body_buf.len)) { if ((error = ucl_parser_get_error(parser)) != NULL) { 
msg_err_session("cannot parse input: %s", error); @@ -3235,7 +3292,7 @@ rspamd_controller_handle_unknown(struct rspamd_http_connection_entry *conn_ent, rspamd_http_message_add_header(rep, "Access-Control-Allow-Methods", "POST, GET, OPTIONS"); rspamd_http_message_add_header(rep, "Access-Control-Allow-Headers", - "Content-Type,Password,Map,Weight,Flag"); + "Classifier,Content-Type,Password,Map,Weight,Flag,Hash"); rspamd_http_connection_reset(conn_ent->conn); rspamd_http_router_insert_headers(conn_ent->rt, rep); rspamd_http_connection_write_message(conn_ent->conn, @@ -3390,6 +3447,40 @@ rspamd_controller_handle_lua_plugin(struct rspamd_http_connection_entry *conn_en return 0; } +/* + * Bayes classifier list command handler: + * request: /bayes/classifiers + * headers: Password + * reply: JSON array of Bayes classifier names + * Note: list is in reverse of declaration order (GList prepend). + */ +static int +rspamd_controller_handle_bayes_classifiers(struct rspamd_http_connection_entry *conn_ent, + struct rspamd_http_message *msg) +{ + struct rspamd_controller_session *session = conn_ent->ud; + struct rspamd_controller_worker_ctx *ctx = session->ctx; + ucl_object_t *arr; + struct rspamd_classifier_config *clc; + GList *cur; + + if (!rspamd_controller_check_password(conn_ent, session, msg, FALSE)) { + return 0; + } + + arr = ucl_object_typed_new(UCL_ARRAY); + cur = g_list_last(ctx->cfg->classifiers); + while (cur) { + clc = cur->data; + ucl_array_append(arr, ucl_object_fromstring(clc->name)); + cur = g_list_previous(cur); + } + + rspamd_controller_send_ucl(conn_ent, arr); + ucl_object_unref(arr); + return 0; +} + static void rspamd_controller_error_handler(struct rspamd_http_connection_entry *conn_ent, @@ -3999,6 +4090,9 @@ start_controller_worker(struct rspamd_worker *worker) rspamd_http_router_add_path(ctx->http, PATH_PING, rspamd_controller_handle_ping); + rspamd_http_router_add_path(ctx->http, + PATH_BAYES_CLASSIFIERS, + rspamd_controller_handle_bayes_classifiers); 
rspamd_controller_register_plugins_paths(ctx); #if 0 diff --git a/src/fuzzy_storage.c b/src/fuzzy_storage.c index 919ea2118..58d123712 100644 --- a/src/fuzzy_storage.c +++ b/src/fuzzy_storage.c @@ -342,7 +342,7 @@ ucl_keymap_fin_cb(struct map_cb_data *data, void **target) return; } - parser = ucl_parser_new(UCL_PARSER_NO_FILEVARS); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, jb->buf->str, jb->buf->len)) { msg_err_config("cannot load ucl data: parse error %s", @@ -1305,7 +1305,7 @@ rspamd_fuzzy_check_callback(struct rspamd_fuzzy_reply *result, void *ud) { /* Start lua post handler */ lua_State *L = session->ctx->cfg->lua_state; - int err_idx, ret, nargs = 9; + int err_idx, ret, nargs = 10; lua_pushcfunction(L, &rspamd_lua_traceback); err_idx = lua_gettop(L); @@ -1339,7 +1339,9 @@ rspamd_fuzzy_check_callback(struct rspamd_fuzzy_reply *result, void *ud) /* We push shingles merely for commands that modify content to avoid extra work */ if (is_shingle && cmd->cmd != FUZZY_CHECK) { lua_newshingle(L, &session->cmd.sgl); - nargs++; + } + else { + lua_pushnil(L); } if ((ret = lua_pcall(L, nargs, LUA_MULTRET, err_idx)) != 0) { @@ -1505,7 +1507,7 @@ rspamd_fuzzy_process_command(struct fuzzy_session *session) { /* Start lua pre handler */ lua_State *L = session->ctx->cfg->lua_state; - int err_idx, ret, nargs = 7; + int err_idx, ret, nargs = 8; lua_pushcfunction(L, &rspamd_lua_traceback); err_idx = lua_gettop(L); @@ -1527,7 +1529,9 @@ rspamd_fuzzy_process_command(struct fuzzy_session *session) /* We push shingles merely for commands that modify content to avoid extra work */ if (is_shingle && cmd->cmd != FUZZY_CHECK) { lua_newshingle(L, &session->cmd.sgl); - nargs++; + } + else { + lua_pushnil(L); } /* Flag and value */ @@ -2661,7 +2665,7 @@ rspamd_fuzzy_maybe_load_ratelimits(struct rspamd_fuzzy_storage_ctx *ctx) RSPAMD_DBDIR); if (access(path, R_OK) != -1) { - struct ucl_parser *parser = ucl_parser_new(UCL_PARSER_NO_IMPLICIT_ARRAYS | 
UCL_PARSER_DISABLE_MACRO); + struct ucl_parser *parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (ucl_parser_add_file(parser, path)) { ucl_object_t *obj = ucl_parser_get_object(parser); int loaded = 0; diff --git a/src/libmime/lang_detection.c b/src/libmime/lang_detection.c index 6e180ea66..b783b8325 100644 --- a/src/libmime/lang_detection.c +++ b/src/libmime/lang_detection.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -363,7 +363,7 @@ rspamd_language_detector_read_file(struct rspamd_config *cfg, double mean = 0, std = 0, delta = 0, delta2 = 0, m2 = 0; enum rspamd_language_category cat = RSPAMD_LANGUAGE_MAX; - parser = ucl_parser_new(UCL_PARSER_NO_FILEVARS); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_file(parser, path)) { msg_warn_config("cannot parse file %s: %s", path, ucl_parser_get_error(parser)); @@ -825,7 +825,7 @@ rspamd_language_detector_init(struct rspamd_config *cfg) languages_pattern = g_string_sized_new(PATH_MAX); rspamd_printf_gstring(languages_pattern, "%s/stop_words", languages_path); - parser = ucl_parser_new(UCL_PARSER_DEFAULT); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (ucl_parser_add_file(parser, languages_pattern->str)) { stop_words = ucl_parser_get_object(parser); @@ -936,7 +936,7 @@ end: } static void -rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, +rspamd_language_detector_random_select(rspamd_words_t *ucs_tokens, unsigned int nwords, goffset *offsets_out, uint64_t *seed) { @@ -946,7 +946,7 @@ rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, g_assert(nwords != 0); g_assert(offsets_out != NULL); - g_assert(ucs_tokens->len >= nwords); + g_assert(kv_size(*ucs_tokens) >= nwords); /* * We split input array into `nwords` parts. 
For each part we randomly select * an element from this particular split. Here is an example: @@ -963,22 +963,22 @@ rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, * their splits. It is not uniform distribution but it seems to be better * to include words from different text parts */ - step_len = ucs_tokens->len / nwords; - remainder = ucs_tokens->len % nwords; + step_len = kv_size(*ucs_tokens) / nwords; + remainder = kv_size(*ucs_tokens) % nwords; out_idx = 0; coin = rspamd_random_uint64_fast_seed(seed); sel = coin % (step_len + remainder); offsets_out[out_idx] = sel; - for (i = step_len + remainder; i < ucs_tokens->len; + for (i = step_len + remainder; i < kv_size(*ucs_tokens); i += step_len, out_idx++) { unsigned int ntries = 0; coin = rspamd_random_uint64_fast_seed(seed); sel = (coin % step_len) + i; for (;;) { - tok = &g_array_index(ucs_tokens, rspamd_stat_token_t, sel); + tok = &kv_A(*ucs_tokens, sel); /* Filter bad tokens */ if (tok->unicode.len >= 2 && @@ -995,8 +995,8 @@ rspamd_language_detector_random_select(GArray *ucs_tokens, unsigned int nwords, if (ntries < step_len) { sel = (coin % step_len) + i; } - else if (ntries < ucs_tokens->len) { - sel = coin % ucs_tokens->len; + else if (ntries < kv_size(*ucs_tokens)) { + sel = coin % kv_size(*ucs_tokens); } else { offsets_out[out_idx] = sel; @@ -1223,12 +1223,12 @@ static void rspamd_language_detector_detect_type(struct rspamd_task *task, unsigned int nwords, struct rspamd_lang_detector *d, - GArray *words, + rspamd_words_t *words, enum rspamd_language_category cat, khash_t(rspamd_candidates_hash) * candidates, struct rspamd_mime_text_part *part) { - unsigned int nparts = MIN(words->len, nwords); + unsigned int nparts = MIN(kv_size(*words), nwords); goffset *selected_words; rspamd_stat_token_t *tok; unsigned int i; @@ -1241,8 +1241,7 @@ rspamd_language_detector_detect_type(struct rspamd_task *task, msg_debug_lang_det("randomly selected %d words", nparts); for (i = 0; i < 
nparts; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, - selected_words[i]); + tok = &kv_A(*words, selected_words[i]); if (tok->unicode.len >= 3) { rspamd_language_detector_detect_word(task, d, tok, candidates, @@ -1282,7 +1281,7 @@ static enum rspamd_language_detected_type rspamd_language_detector_try_ngramm(struct rspamd_task *task, unsigned int nwords, struct rspamd_lang_detector *d, - GArray *ucs_tokens, + rspamd_words_t *ucs_tokens, enum rspamd_language_category cat, khash_t(rspamd_candidates_hash) * candidates, struct rspamd_mime_text_part *part) @@ -1863,7 +1862,7 @@ rspamd_language_detector_detect(struct rspamd_task *task, if (rspamd_lang_detection_fasttext_is_enabled(d->fasttext_detector)) { rspamd_fasttext_predict_result_t fasttext_predict_result = rspamd_lang_detection_fasttext_detect(d->fasttext_detector, task, - part->utf_words, 4); + &part->utf_words, 4); ndetected = rspamd_lang_detection_fasttext_get_nlangs(fasttext_predict_result); @@ -1930,11 +1929,11 @@ rspamd_language_detector_detect(struct rspamd_task *task, if (!ret) { /* Apply trigramms detection */ candidates = kh_init(rspamd_candidates_hash); - if (part->utf_words->len < default_short_text_limit) { + if (kv_size(part->utf_words) < default_short_text_limit) { r = rs_detect_none; msg_debug_lang_det("text is too short for trigrams detection: " "%d words; at least %d words required", - (int) part->utf_words->len, + (int) kv_size(part->utf_words), (int) default_short_text_limit); switch (cat) { case RSPAMD_LANGUAGE_CYRILLIC: @@ -1960,7 +1959,7 @@ rspamd_language_detector_detect(struct rspamd_task *task, r = rspamd_language_detector_try_ngramm(task, default_words, d, - part->utf_words, + &part->utf_words, cat, candidates, part); @@ -2123,4 +2122,4 @@ int rspamd_language_detector_elt_flags(const struct rspamd_language_elt *elt) } return 0; -}
\ No newline at end of file +} diff --git a/src/libmime/lang_detection_fasttext.cxx b/src/libmime/lang_detection_fasttext.cxx index 8ea2706e6..983ff78de 100644 --- a/src/libmime/lang_detection_fasttext.cxx +++ b/src/libmime/lang_detection_fasttext.cxx @@ -22,6 +22,7 @@ #include "libserver/logger.h" #include "contrib/fmt/include/fmt/base.h" #include "stat_api.h" +#include "libserver/word.h" #include <exception> #include <string_view> #include <vector> @@ -180,26 +181,32 @@ bool rspamd_lang_detection_fasttext_is_enabled(void *ud) rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, struct rspamd_task *task, - GArray *utf_words, + rspamd_words_t *utf_words, int k) { #ifndef WITH_FASTTEXT return nullptr; #else /* Avoid too long inputs */ - static const unsigned int max_fasttext_input_len = 1024 * 1024; + static const size_t max_fasttext_input_len = 1024 * 1024; auto *real_model = FASTTEXT_MODEL_TO_C_API(ud); std::vector<std::int32_t> words_vec; - words_vec.reserve(utf_words->len); - for (auto i = 0; i < std::min(utf_words->len, max_fasttext_input_len); i++) { - const auto *w = &g_array_index(utf_words, rspamd_stat_token_t, i); + if (!utf_words || !utf_words->a) { + return nullptr; + } + + auto words_count = kv_size(*utf_words); + words_vec.reserve(words_count); + + for (auto i = 0; i < std::min(words_count, max_fasttext_input_len); i++) { + const auto *w = &kv_A(*utf_words, i); if (w->original.len > 0) { real_model->word2vec(w->original.begin, w->original.len, words_vec); } } - msg_debug_lang_det("fasttext: got %z word tokens from %ud words", words_vec.size(), utf_words->len); + msg_debug_lang_det("fasttext: got %z word tokens from %ud words", words_vec.size(), words_count); auto *res = real_model->detect_language(words_vec, k); @@ -266,4 +273,4 @@ void rspamd_fasttext_predict_result_destroy(rspamd_fasttext_predict_result_t res #endif } -G_END_DECLS
\ No newline at end of file +G_END_DECLS diff --git a/src/libmime/lang_detection_fasttext.h b/src/libmime/lang_detection_fasttext.h index 2a2756968..e2b67181a 100644 --- a/src/libmime/lang_detection_fasttext.h +++ b/src/libmime/lang_detection_fasttext.h @@ -17,6 +17,7 @@ #define RSPAMD_LANG_DETECTION_FASTTEXT_H #include "config.h" +#include "libserver/word.h" G_BEGIN_DECLS struct rspamd_config; @@ -53,7 +54,7 @@ typedef void *rspamd_fasttext_predict_result_t; * @return TRUE if language is detected */ rspamd_fasttext_predict_result_t rspamd_lang_detection_fasttext_detect(void *ud, - struct rspamd_task *task, GArray *utf_words, int k); + struct rspamd_task *task, rspamd_words_t *utf_words, int k); /** * Get number of languages detected diff --git a/src/libmime/message.c b/src/libmime/message.c index f2cabf399..8442c80ac 100644 --- a/src/libmime/message.c +++ b/src/libmime/message.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -40,6 +40,8 @@ #include "contrib/uthash/utlist.h" #include "contrib/t1ha/t1ha.h" #include "received.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "libstat/tokenizers/custom_tokenizer.h" #define GTUBE_SYMBOL "GTUBE" @@ -71,14 +73,14 @@ rspamd_mime_part_extract_words(struct rspamd_task *task, rspamd_stat_token_t *w; unsigned int i, total_len = 0, short_len = 0; - if (part->utf_words) { - rspamd_stem_words(part->utf_words, task->task_pool, part->language, + if (part->utf_words.a) { + rspamd_stem_words(&part->utf_words, task->task_pool, part->language, task->lang_det); - for (i = 0; i < part->utf_words->len; i++) { + for (i = 0; i < kv_size(part->utf_words); i++) { uint64_t h; - w = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + w = &kv_A(part->utf_words, i); if (w->stemmed.len > 0) { /* @@ -108,7 +110,7 @@ rspamd_mime_part_extract_words(struct rspamd_task *task, } } - if (part->utf_words->len) { + if (kv_size(part->utf_words)) { double *avg_len_p, *short_len_p; avg_len_p = rspamd_mempool_get_variable(task->task_pool, @@ -185,21 +187,24 @@ rspamd_mime_part_create_words(struct rspamd_task *task, tok_type = RSPAMD_TOKENIZE_RAW; } - part->utf_words = rspamd_tokenize_text( + /* Initialize kvec for words */ + kv_init(part->utf_words); + + rspamd_tokenize_text( part->utf_stripped_content->data, part->utf_stripped_content->len, &part->utf_stripped_text, tok_type, task->cfg, part->exceptions, NULL, - NULL, + &part->utf_words, task->task_pool); - if (part->utf_words) { + if (part->utf_words.a) { part->normalized_hashes = g_array_sized_new(FALSE, FALSE, - sizeof(uint64_t), part->utf_words->len); - rspamd_normalize_words(part->utf_words, task->task_pool); + sizeof(uint64_t), kv_size(part->utf_words)); + rspamd_normalize_words(&part->utf_words, task->task_pool); } } @@ -209,7 +214,7 @@ rspamd_mime_part_detect_language(struct rspamd_task *task, { struct rspamd_lang_detector_res *lang; - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words && part->utf_words->len > 0 && 
+ if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a && kv_size(part->utf_words) > 0 && task->lang_det) { if (rspamd_language_detector_detect(task, task->lang_det, part)) { lang = g_ptr_array_index(part->languages, 0); @@ -1106,8 +1111,8 @@ rspamd_message_dtor(struct rspamd_message *msg) PTR_ARRAY_FOREACH(msg->text_parts, i, tp) { - if (tp->utf_words) { - g_array_free(tp->utf_words, TRUE); + if (tp->utf_words.a) { + kv_destroy(tp->utf_words); } if (tp->normalized_hashes) { g_array_free(tp->normalized_hashes, TRUE); @@ -1583,7 +1588,7 @@ void rspamd_message_process(struct rspamd_task *task) rspamd_mime_part_extract_words(task, text_part); - if (text_part->utf_words) { + if (text_part->utf_words.a) { total_words += text_part->nwords; } } diff --git a/src/libmime/message.h b/src/libmime/message.h index cb695773e..e6b454362 100644 --- a/src/libmime/message.h +++ b/src/libmime/message.h @@ -16,6 +16,7 @@ #include "libserver/url.h" #include "libutil/ref.h" #include "libutil/str_util.h" +#include "libserver/word.h" #include <unicode/uchar.h> #include <unicode/utext.h> @@ -139,7 +140,7 @@ struct rspamd_mime_text_part { GByteArray *utf_raw_content; /* utf raw content */ GByteArray *utf_stripped_content; /* utf content with no newlines */ GArray *normalized_hashes; /* Array of uint64_t */ - GArray *utf_words; /* Array of rspamd_stat_token_t */ + rspamd_words_t utf_words; /* kvec of rspamd_word_t */ UText utf_stripped_text; /* Used by libicu to represent the utf8 content */ GPtrArray *newlines; /**< positions of newlines in text, relative to content*/ diff --git a/src/libmime/mime_string.hxx b/src/libmime/mime_string.hxx index b181576d3..d6c11d018 100644 --- a/src/libmime/mime_string.hxx +++ b/src/libmime/mime_string.hxx @@ -497,19 +497,19 @@ public: } /* Comparison */ - auto operator==(const basic_mime_string &other) + auto operator==(const basic_mime_string &other) const { return other.storage == storage; } - auto operator==(const storage_type &other) + auto 
operator==(const storage_type &other) const { return other == storage; } - auto operator==(const view_type &other) + auto operator==(const view_type &other) const { return other == storage; } - auto operator==(const CharT *other) + auto operator==(const CharT *other) const { if (other == NULL) { return false; diff --git a/src/libserver/cfg_file.h b/src/libserver/cfg_file.h index f59c6ff89..36941da7a 100644 --- a/src/libserver/cfg_file.h +++ b/src/libserver/cfg_file.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,6 +48,7 @@ struct worker_s; struct rspamd_external_libs_ctx; struct rspamd_cryptobox_pubkey; struct rspamd_dns_resolver; +struct rspamd_tokenizer_manager; /** * Logging type @@ -395,6 +396,8 @@ struct rspamd_config { unsigned int log_error_elts; /**< number of elements in error logbuf */ unsigned int log_error_elt_maxlen; /**< maximum size of error log element */ unsigned int log_task_max_elts; /**< maximum number of elements in task logging */ + unsigned int log_max_tag_len; /**< maximum length of log tag */ + char *log_tag_strip_policy_str; /**< log tag strip policy string */ struct rspamd_worker_log_pipe *log_pipes; gboolean compat_messages; /**< use old messages in the protocol (array) */ @@ -495,9 +498,10 @@ struct rspamd_config { char *zstd_output_dictionary; /**< path to zstd output dictionary */ ucl_object_t *neighbours; /**< other servers in the cluster */ - struct rspamd_config_settings_elt *setting_ids; /**< preprocessed settings ids */ - struct rspamd_lang_detector *lang_det; /**< language detector */ - struct rspamd_worker *cur_worker; /**< set dynamically by each worker */ + struct rspamd_config_settings_elt *setting_ids; /**< preprocessed settings ids */ + struct rspamd_lang_detector *lang_det; /**< language detector */ + struct rspamd_tokenizer_manager 
*tokenizer_manager; /**< custom tokenizer manager */ + struct rspamd_worker *cur_worker; /**< set dynamically by each worker */ ref_entry_t ref; /**< reference counter */ }; diff --git a/src/libserver/cfg_rcl.cxx b/src/libserver/cfg_rcl.cxx index f38366908..0a48e8a4f 100644 --- a/src/libserver/cfg_rcl.cxx +++ b/src/libserver/cfg_rcl.cxx @@ -299,6 +299,14 @@ rspamd_rcl_logging_handler(rspamd_mempool_t *pool, const ucl_object_t *obj, cfg->log_flags |= RSPAMD_LOG_FLAG_USEC; } + /* Set default values for new log tag options */ + if (cfg->log_max_tag_len == 0) { + cfg->log_max_tag_len = RSPAMD_LOG_ID_LEN; /* Default to new max size */ + } + if (cfg->log_tag_strip_policy_str == NULL) { + cfg->log_tag_strip_policy_str = rspamd_mempool_strdup(cfg->cfg_pool, "right"); + } + return rspamd_rcl_section_parse_defaults(cfg, *section, cfg->cfg_pool, obj, (void *) cfg, err); } @@ -1700,6 +1708,18 @@ rspamd_rcl_config_init(struct rspamd_config *cfg, GHashTable *skip_sections) G_STRUCT_OFFSET(struct rspamd_config, log_task_max_elts), RSPAMD_CL_FLAG_UINT, "Maximum number of elements in task log entry (7 by default)"); + rspamd_rcl_add_default_handler(sub, + "max_tag_len", + rspamd_rcl_parse_struct_integer, + G_STRUCT_OFFSET(struct rspamd_config, log_max_tag_len), + RSPAMD_CL_FLAG_UINT, + "Maximum length of log tag cannot exceed 32 (" G_STRINGIFY(RSPAMD_LOG_ID_LEN) ") by default)"); + rspamd_rcl_add_default_handler(sub, + "tag_strip_policy", + rspamd_rcl_parse_struct_string, + G_STRUCT_OFFSET(struct rspamd_config, log_tag_strip_policy_str), + 0, + "Log tag strip policy when tag exceeds max length: 'right', 'left', 'middle' (right by default)"); /* Documentation only options, handled in log_handler to map flags */ rspamd_rcl_add_doc_by_path(cfg, @@ -3640,7 +3660,7 @@ rspamd_config_parse_ucl(struct rspamd_config *cfg, /* Try to load keyfile if available */ auto keyfile_name = fmt::format("{}.key", filename); rspamd::util::raii_file::open(keyfile_name, O_RDONLY).map([&](const auto 
&keyfile) { - auto *kp_parser = ucl_parser_new(0); + auto *kp_parser = ucl_parser_new(UCL_PARSER_DEFAULT); if (ucl_parser_add_fd(kp_parser, keyfile.get_fd())) { auto *kp_obj = ucl_parser_get_object(kp_parser); diff --git a/src/libserver/cfg_utils.cxx b/src/libserver/cfg_utils.cxx index dfbdc6bee..c7bb20210 100644 --- a/src/libserver/cfg_utils.cxx +++ b/src/libserver/cfg_utils.cxx @@ -72,6 +72,11 @@ #include "contrib/expected/expected.hpp" #include "contrib/ankerl/unordered_dense.h" +#include "libserver/task.h" +#include "libserver/url.h" +#define RSPAMD_TOKENIZER_INTERNAL// We need to use internal tokenizer API +#include "libstat/tokenizers/custom_tokenizer.h" + #define DEFAULT_SCORE 10.0 #define DEFAULT_RLIMIT_NOFILE 2048 @@ -821,6 +826,65 @@ rspamd_adjust_clocks_resolution(struct rspamd_config *cfg) #endif } +extern "C" { + +gboolean +rspamd_config_load_custom_tokenizers(struct rspamd_config *cfg, GError **err) +{ + /* Load custom tokenizers */ + const ucl_object_t *custom_tokenizers = ucl_object_lookup_path(cfg->cfg_ucl_obj, + "options.custom_tokenizers"); + if (custom_tokenizers != NULL) { + msg_info_config("loading custom tokenizers"); + + if (!cfg->tokenizer_manager) { + cfg->tokenizer_manager = rspamd_tokenizer_manager_new(cfg->cfg_pool); + } + + ucl_object_iter_t it = ucl_object_iterate_new(custom_tokenizers); + const ucl_object_t *tok_obj; + const char *tok_name; + + while ((tok_obj = ucl_object_iterate_safe(it, true)) != NULL) { + tok_name = ucl_object_key(tok_obj); + GError *local_err = NULL; + + if (!rspamd_tokenizer_manager_load_tokenizer(cfg->tokenizer_manager, + tok_name, tok_obj, &local_err)) { + msg_err_config("failed to load custom tokenizer '%s': %s", + tok_name, local_err ? 
local_err->message : "unknown error"); + + if (err && !*err) { + *err = g_error_copy(local_err); + } + + if (local_err) { + g_error_free(local_err); + } + + ucl_object_iterate_free(it); + return FALSE; + } + } + ucl_object_iterate_free(it); + + msg_info_config("loaded custom tokenizers successfully"); + } + + return TRUE; +} + +void rspamd_config_unload_custom_tokenizers(struct rspamd_config *cfg) +{ + if (cfg->tokenizer_manager) { + msg_info_config("unloading custom tokenizers"); + rspamd_tokenizer_manager_destroy(cfg->tokenizer_manager); + cfg->tokenizer_manager = NULL; + } +} + +}// extern "C" + /* * Perform post load actions */ @@ -940,6 +1004,20 @@ rspamd_config_post_load(struct rspamd_config *cfg, msg_err_config("cannot configure libraries, fatal error"); return FALSE; } + + /* Load custom tokenizers using the new function */ + GError *tokenizer_err = NULL; + if (!rspamd_config_load_custom_tokenizers(cfg, &tokenizer_err)) { + msg_err_config("failed to load custom tokenizers: %s", + tokenizer_err ? 
tokenizer_err->message : "unknown error"); + if (tokenizer_err) { + g_error_free(tokenizer_err); + } + + if (opts & RSPAMD_CONFIG_INIT_VALIDATE) { + ret = tl::make_unexpected(std::string{"failed to load custom tokenizers"}); + } + } } /* Validate cache */ @@ -1363,7 +1441,7 @@ rspamd_ucl_fin_cb(struct map_cb_data *data, void **target) } /* New data available */ - auto *parser = ucl_parser_new(0); + auto *parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, (unsigned char *) cbdata->buf.data(), cbdata->buf.size())) { msg_err_config("cannot parse map %s: %s", diff --git a/src/libserver/dynamic_cfg.c b/src/libserver/dynamic_cfg.c index 984517074..6d648d745 100644 --- a/src/libserver/dynamic_cfg.c +++ b/src/libserver/dynamic_cfg.c @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -195,7 +195,7 @@ json_config_fin_cb(struct map_cb_data *data, void **target) return; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, jb->buf->str, jb->buf->len)) { msg_err("cannot load json data: parse error %s", diff --git a/src/libserver/fuzzy_backend/fuzzy_backend_redis.c b/src/libserver/fuzzy_backend/fuzzy_backend_redis.c index 27c663070..f150d48be 100644 --- a/src/libserver/fuzzy_backend/fuzzy_backend_redis.c +++ b/src/libserver/fuzzy_backend/fuzzy_backend_redis.c @@ -116,11 +116,9 @@ rspamd_redis_get_servers(struct rspamd_fuzzy_backend_redis *ctx, res = *((struct upstream_list **) lua_touserdata(L, -1)); } else { - struct lua_logger_trace tr; char outbuf[8192]; - memset(&tr, 0, sizeof(tr)); - lua_logger_out_type(L, -2, outbuf, sizeof(outbuf) - 1, &tr, + lua_logger_out(L, -2, outbuf, sizeof(outbuf), LUA_ESCAPE_UNPRINTABLE); msg_err("cannot get %s upstreams for Redis fuzzy storage %s; table content: %s", diff --git 
a/src/libserver/http/http_connection.c b/src/libserver/http/http_connection.c index baf37a385..b5d70fc1c 100644 --- a/src/libserver/http/http_connection.c +++ b/src/libserver/http/http_connection.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1670,7 +1670,22 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, { char datebuf[64]; int meth_len = 0; - const char *conn_type = "close"; + const char *server_conn_header, *client_conn_header; + + /* Set up connection header strings based on flags and connection type */ + if (msg->flags & RSPAMD_HTTP_FLAG_HAS_CONNECTION_HEADER) { + server_conn_header = ""; + client_conn_header = ""; + } + else { + server_conn_header = "Connection: close\r\n"; + if (conn->opts & RSPAMD_HTTP_CLIENT_KEEP_ALIVE) { + client_conn_header = "Connection: keep-alive\r\n"; + } + else { + client_conn_header = "Connection: close\r\n"; + } + } if (conn->type == RSPAMD_HTTP_SERVER) { /* Format reply */ @@ -1712,12 +1727,14 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_snprintf(repbuf, replen, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n" "Content-Type: %s", /* NO \r\n at the end ! */ - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen, mime_type); } @@ -1725,11 +1742,13 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_snprintf(repbuf, replen, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z", /* NO \r\n at the end ! 
*/ - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen); } @@ -1737,11 +1756,12 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, /* External reply */ rspamd_printf_fstring(buf, "HTTP/1.1 200 OK\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n", + server_conn_header, priv->ctx->config.server_hdr, datebuf, enclen); } @@ -1750,12 +1770,14 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_printf_fstring(buf, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n" "Content-Type: %s\r\n", - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen, mime_type); } @@ -1763,11 +1785,13 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, meth_len = rspamd_printf_fstring(buf, "HTTP/1.1 %d %T\r\n" - "Connection: close\r\n" + "%s" "Server: %s\r\n" "Date: %s\r\n" "Content-Length: %z\r\n", - msg->code, &status, priv->ctx->config.server_hdr, + msg->code, &status, + server_conn_header, + priv->ctx->config.server_hdr, datebuf, bodylen); } @@ -1804,10 +1828,6 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, else { /* Client request */ - if (conn->opts & RSPAMD_HTTP_CLIENT_KEEP_ALIVE) { - conn_type = "keep-alive"; - } - /* Format request */ enclen += RSPAMD_FSTRING_LEN(msg->url) + strlen(http_method_str(msg->method)) + 1; @@ -1819,21 +1839,21 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, "%s %s HTTP/1.0\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n" - "Connection: %s\r\n", + "%s", "POST", "/post", enclen, - conn_type); + client_conn_header); } else { 
rspamd_printf_fstring(buf, "%s %V HTTP/1.0\r\n" "Content-Length: %z\r\n" - "Connection: %s\r\n", + "%s", http_method_str(msg->method), msg->url, bodylen, - conn_type); + client_conn_header); if (bodylen > 0) { if (mime_type == NULL) { @@ -1857,26 +1877,26 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, if (rspamd_http_message_is_standard_port(msg)) { rspamd_printf_fstring(buf, "%s %s HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n", "POST", "/post", - conn_type, + client_conn_header, host, enclen); } else { rspamd_printf_fstring(buf, "%s %s HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s:%d\r\n" "Content-Length: %z\r\n" "Content-Type: application/octet-stream\r\n", "POST", "/post", - conn_type, + client_conn_header, host, msg->port, enclen); @@ -1888,21 +1908,21 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, if ((msg->flags & RSPAMD_HTTP_FLAG_HAS_HOST_HEADER)) { rspamd_printf_fstring(buf, "%s %s://%s:%d/%V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Content-Length: %z\r\n", http_method_str(msg->method), (conn->opts & RSPAMD_HTTP_CLIENT_SSL) ? 
"https" : "http", host, msg->port, msg->url, - conn_type, + client_conn_header, bodylen); } else { if (rspamd_http_message_is_standard_port(msg)) { rspamd_printf_fstring(buf, "%s %s://%s:%d/%V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), @@ -1910,14 +1930,14 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, host, msg->port, msg->url, - conn_type, + client_conn_header, host, bodylen); } else { rspamd_printf_fstring(buf, "%s %s://%s:%d/%V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s:%d\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), @@ -1925,7 +1945,7 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, host, msg->port, msg->url, - conn_type, + client_conn_header, host, msg->port, bodylen); @@ -1937,35 +1957,35 @@ int rspamd_http_message_write_header(const char *mime_type, gboolean encrypted, if ((msg->flags & RSPAMD_HTTP_FLAG_HAS_HOST_HEADER)) { rspamd_printf_fstring(buf, "%s %V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Content-Length: %z\r\n", http_method_str(msg->method), msg->url, - conn_type, + client_conn_header, bodylen); } else { if (rspamd_http_message_is_standard_port(msg)) { rspamd_printf_fstring(buf, "%s %V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), msg->url, - conn_type, + client_conn_header, host, bodylen); } else { rspamd_printf_fstring(buf, "%s %V HTTP/1.1\r\n" - "Connection: %s\r\n" + "%s" "Host: %s:%d\r\n" "Content-Length: %z\r\n", http_method_str(msg->method), msg->url, - conn_type, + client_conn_header, host, msg->port, bodylen); @@ -2633,4 +2653,4 @@ void rspamd_http_connection_disable_encryption(struct rspamd_http_connection *co priv->peer_key = NULL; priv->flags &= ~RSPAMD_HTTP_CONN_FLAG_ENCRYPTED; } -}
\ No newline at end of file +} diff --git a/src/libserver/http/http_connection.h b/src/libserver/http/http_connection.h index f6ec03d95..466a3edd9 100644 --- a/src/libserver/http/http_connection.h +++ b/src/libserver/http/http_connection.h @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -80,9 +80,13 @@ struct rspamd_storage_shmem { */ #define RSPAMD_HTTP_FLAG_HAS_HOST_HEADER (1 << 7) /** + * Connection header has been set for a message + */ +#define RSPAMD_HTTP_FLAG_HAS_CONNECTION_HEADER (1 << 8) +/** * Message is intended for SSL connection */ -#define RSPAMD_HTTP_FLAG_WANT_SSL (1 << 8) +#define RSPAMD_HTTP_FLAG_WANT_SSL (1 << 9) /** * Options for HTTP connection */ diff --git a/src/libserver/http/http_message.c b/src/libserver/http/http_message.c index 0c9708450..e5e4a0469 100644 --- a/src/libserver/http/http_message.c +++ b/src/libserver/http/http_message.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -539,6 +539,9 @@ void rspamd_http_message_add_header_len(struct rspamd_http_message *msg, if (g_ascii_strcasecmp(name, "host") == 0) { msg->flags |= RSPAMD_HTTP_FLAG_HAS_HOST_HEADER; } + else if (g_ascii_strcasecmp(name, "connection") == 0) { + msg->flags |= RSPAMD_HTTP_FLAG_HAS_CONNECTION_HEADER; + } hdr->combined = rspamd_fstring_sized_new(nlen + vlen + 4); rspamd_printf_fstring(&hdr->combined, "%s: %*s\r\n", name, (int) vlen, @@ -746,4 +749,4 @@ const char *rspamd_http_message_get_url(struct rspamd_http_message *msg, gsize * } return NULL; -}
\ No newline at end of file +} diff --git a/src/libserver/logger/logger.c b/src/libserver/logger/logger.c index dc0a85a05..600b7f1e1 100644 --- a/src/libserver/logger/logger.c +++ b/src/libserver/logger/logger.c @@ -22,7 +22,6 @@ #include "unix-std.h" #include "logger_private.h" - static rspamd_logger_t *default_logger = NULL; static rspamd_logger_t *emergency_logger = NULL; static struct rspamd_log_modules *log_modules = NULL; @@ -30,6 +29,61 @@ static struct rspamd_log_modules *log_modules = NULL; static const char lf_chr = '\n'; unsigned int rspamd_task_log_id = (unsigned int) -1; + +/** + * Strip log tag according to the configured policy + * @param original_tag original log tag + * @param original_len length of original tag + * @param dest destination buffer + * @param max_len maximum length allowed + * @param policy stripping policy + * @return actual length of stripped tag + */ +static gsize +rspamd_log_strip_tag(const char *original_tag, gsize original_len, + char *dest, gsize max_len, + enum rspamd_log_tag_strip_policy policy) +{ + if (original_len <= max_len) { + /* No stripping needed */ + memcpy(dest, original_tag, original_len); + return original_len; + } + + switch (policy) { + case RSPAMD_LOG_TAG_STRIP_RIGHT: + /* Cut right part (current behavior) */ + memcpy(dest, original_tag, max_len); + return max_len; + + case RSPAMD_LOG_TAG_STRIP_LEFT: + /* Cut left part (take last elements) */ + memcpy(dest, original_tag + (original_len - max_len), max_len); + return max_len; + + case RSPAMD_LOG_TAG_STRIP_MIDDLE: + /* Half from start and half from end */ + if (max_len >= 2) { + gsize first_half = max_len / 2; + gsize second_half = max_len - first_half; + + memcpy(dest, original_tag, first_half); + memcpy(dest + first_half, + original_tag + (original_len - second_half), + second_half); + } + else if (max_len == 1) { + /* Just take first character */ + dest[0] = original_tag[0]; + } + return max_len; + + default: + /* Fallback to right stripping */ + 
memcpy(dest, original_tag, max_len); + return max_len; + } +} RSPAMD_CONSTRUCTOR(rspamd_task_log_init) { rspamd_task_log_id = rspamd_logger_add_debug_module("task"); @@ -160,6 +214,10 @@ rspamd_log_open_emergency(rspamd_mempool_t *pool, int flags) logger->process_type = "main"; logger->pid = getpid(); + /* Initialize log tag configuration with defaults */ + logger->max_log_tag_len = RSPAMD_LOG_ID_LEN; /* Keep backward compatibility default */ + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_RIGHT; + const struct rspamd_logger_funcs *funcs = &console_log_funcs; memcpy(&logger->ops, funcs, sizeof(*funcs)); @@ -258,6 +316,28 @@ rspamd_log_open_specific(rspamd_mempool_t *pool, logger->process_type = ptype; logger->enabled = TRUE; + /* Initialize log tag configuration with defaults */ + if (cfg && cfg->log_max_tag_len > 0) { + logger->max_log_tag_len = MIN(MEMPOOL_UID_LEN, cfg->log_max_tag_len); + } + else { + logger->max_log_tag_len = RSPAMD_LOG_ID_LEN; /* Keep backward compatibility default */ + } + + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_RIGHT; + + if (cfg && cfg->log_tag_strip_policy_str) { + if (g_ascii_strcasecmp(cfg->log_tag_strip_policy_str, "left") == 0) { + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_LEFT; + } + else if (g_ascii_strcasecmp(cfg->log_tag_strip_policy_str, "middle") == 0) { + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_MIDDLE; + } + else { + logger->log_tag_strip_policy = RSPAMD_LOG_TAG_STRIP_RIGHT; /* Default */ + } + } + /* Set up conditional logging */ if (cfg) { if (cfg->debug_ip_map != NULL) { @@ -1026,16 +1106,34 @@ log_time(double now, rspamd_logger_t *rspamd_log, char *timebuf, } } +/** + * Process log ID with stripping policy and return the effective length + * @param logger logger instance with configuration + * @param id original log ID + * @param processed_id buffer to store processed ID (should be at least max_log_tag_len + 1) + * @return effective length of processed ID + */ static inline int 
-rspamd_log_id_strlen(const char *id) +rspamd_log_process_id(rspamd_logger_t *logger, const char *id, char *processed_id) { - for (int i = 0; i < RSPAMD_LOG_ID_LEN; i++) { - if (G_UNLIKELY(id[i] == '\0')) { - return i; - } + if (id == NULL) { + return 0; + } + + gsize original_len = strlen(id); + gsize max_len = MIN(MEMPOOL_UID_LEN, logger->max_log_tag_len); + + if (original_len <= max_len) { + /* No processing needed */ + memcpy(processed_id, id, original_len); + return original_len; } - return RSPAMD_LOG_ID_LEN; + /* Apply stripping policy */ + gsize processed_len = rspamd_log_strip_tag(id, original_len, processed_id, max_len, + logger->log_tag_strip_policy); + + return processed_len; } void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, @@ -1071,8 +1169,17 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, if (G_UNLIKELY(log_json)) { /* Perform JSON logging */ - unsigned int slen = id ? strlen(id) : strlen("(NULL)"); - slen = MIN(RSPAMD_LOG_ID_LEN, slen); + char processed_id[MEMPOOL_UID_LEN]; + int processed_len = 0; + + if (id) { + processed_len = rspamd_log_process_id(logger, id, processed_id); + } + else { + strcpy(processed_id, "(NULL)"); + processed_len = strlen(processed_id); + } + r = rspamd_snprintf(tmpbuf, sizeof(tmpbuf), "{\"ts\": %f, " "\"pid\": %P, " "\"severity\": \"%s\", " @@ -1085,7 +1192,7 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, logger->pid, rspamd_get_log_severity_string(level_flags), logger->process_type, - slen, id, + processed_len, processed_id, module, function); iov_ctx->iov[0].iov_base = tmpbuf; @@ -1241,14 +1348,17 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, glong mremain, mr; char *m; + char processed_id[MEMPOOL_UID_LEN]; + int processed_len = 0; modulebuf[0] = '\0'; mremain = sizeof(modulebuf); m = modulebuf; if (id != NULL) { - mr = rspamd_snprintf(m, mremain, "<%*.s>; ", rspamd_log_id_strlen(id), - id); + processed_len = rspamd_log_process_id(logger, id, 
processed_id); + mr = rspamd_snprintf(m, mremain, "<%*.s>; ", processed_len, + processed_id); m += mr; mremain -= mr; } @@ -1300,10 +1410,13 @@ void rspamd_log_fill_iov(struct rspamd_logger_iov_ctx *iov_ctx, iov_ctx->iov[niov].iov_base = (void *) timebuf; iov_ctx->iov[niov++].iov_len = strlen(timebuf); if (id != NULL) { + char processed_id[MEMPOOL_UID_LEN]; + int processed_len = rspamd_log_process_id(logger, id, processed_id); + iov_ctx->iov[niov].iov_base = (void *) "; "; iov_ctx->iov[niov++].iov_len = 2; - iov_ctx->iov[niov].iov_base = (void *) id; - iov_ctx->iov[niov++].iov_len = rspamd_log_id_strlen(id); + iov_ctx->iov[niov].iov_base = (void *) processed_id; + iov_ctx->iov[niov++].iov_len = processed_len; iov_ctx->iov[niov].iov_base = (void *) ";"; iov_ctx->iov[niov++].iov_len = 1; } diff --git a/src/libserver/logger/logger_private.h b/src/libserver/logger/logger_private.h index 80178ad32..387d8639b 100644 --- a/src/libserver/logger/logger_private.h +++ b/src/libserver/logger/logger_private.h @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,6 +23,12 @@ #define REPEATS_MAX 300 #define LOGBUF_LEN 8192 +enum rspamd_log_tag_strip_policy { + RSPAMD_LOG_TAG_STRIP_RIGHT = 0, /* Cut right part (current behavior) */ + RSPAMD_LOG_TAG_STRIP_LEFT, /* Cut left part (take last elements) */ + RSPAMD_LOG_TAG_STRIP_MIDDLE, /* Half from start and half from end */ +}; + struct rspamd_log_module { char *mname; unsigned int id; @@ -73,6 +79,10 @@ struct rspamd_logger_s { gboolean is_debug; gboolean no_lock; + /* Log tag configuration */ + unsigned int max_log_tag_len; + enum rspamd_log_tag_strip_policy log_tag_strip_policy; + pid_t pid; const char *process_type; struct rspamd_radix_map_helper *debug_ip; diff --git a/src/libserver/maps/map.c b/src/libserver/maps/map.c index 97130ad7c..ac82d39bb 100644 --- a/src/libserver/maps/map.c +++ b/src/libserver/maps/map.c @@ -84,7 +84,8 @@ RSPAMD_CONSTRUCTOR(rspamd_map_log_init) } /** - * Write HTTP request + * Write HTTP request with proper cache validation headers + * Uses ETags (If-None-Match) and Last-Modified (If-Modified-Since) for conditional requests */ static void write_http_request(struct http_callback_data *cbd) @@ -109,7 +110,8 @@ write_http_request(struct http_callback_data *cbd) } if (cbd->data->etag) { rspamd_http_message_add_header_len(msg, "If-None-Match", - cbd->data->etag->str, cbd->data->etag->len); + cbd->data->etag->str, + cbd->data->etag->len); } } @@ -295,23 +297,101 @@ rspamd_map_cache_cb(struct ev_loop *loop, ev_timer *w, int revents) } } +/** + * Calculate next check time with proper priority for different cache validation mechanisms + * Priority: ETags > Last-Modified > Cache expiration headers + * @param now current time + * @param expires time from cache expiration header + * @param map_check_interval base polling interval + * @param has_etag whether we have ETag for conditional requests + * @param has_last_modified whether we have Last-Modified for conditional requests + * @return next check time + */ static inline time_t 
-rspamd_http_map_process_next_check(time_t now, time_t expires, time_t map_check_interval) +rspamd_http_map_process_next_check(struct rspamd_map *map, + struct rspamd_map_backend *bk, + time_t now, + time_t expires, + time_t map_check_interval, + gboolean has_etag, + gboolean has_last_modified) { - static const time_t interval_mult = 16; - /* By default use expires header */ - time_t next_check = expires; + static const time_t interval_mult = 4; /* Reduced from 16 to be more responsive */ + static const time_t min_respectful_interval = 5; + time_t next_check; + time_t effective_interval = map_check_interval; - if (expires < now) { - return now; + /* + * Priority order for cache validation: + * 1. ETags (most reliable) + * 2. Last-Modified dates + * 3. Cache expiration headers (least reliable) + */ + + if (has_etag || has_last_modified) { + /* + * If we have ETags or Last-Modified, we can use conditional requests + * to avoid unnecessary downloads. However, we still need to be respectful + * to servers and not DoS them with overly aggressive polling. + */ + if (map_check_interval < min_respectful_interval) { + /* + * User configured very aggressive polling, but server provides cache validation. + * Enforce minimum respectful interval to avoid DoS'ing the server. 
+ */ + effective_interval = min_respectful_interval * interval_mult; + msg_info_map("map polling interval %d too aggressive with server cache support for %s, " + "using %d seconds minimum", + (int) map_check_interval, bk->uri, (int) effective_interval); + } + + if (expires > now && (expires - now) <= effective_interval * interval_mult) { + /* Use expires header if it's reasonable (within interval_mult x poll interval) */ + next_check = expires; + } + else { + /* Use effective interval, don't extend too much */ + next_check = now + effective_interval; + } } - else if (expires - now > map_check_interval * interval_mult) { - next_check = now + map_check_interval * interval_mult; + else if (expires > now) { + /* + * No ETags or Last-Modified available, rely on cache expiration. + * But still cap the interval to avoid too long delays. + * No need for respectful interval protection here since no conditional requests. + */ + if (expires - now > map_check_interval * interval_mult) { + next_check = now + map_check_interval * interval_mult; + } + else { + next_check = expires; + } + } + else { + /* No valid cache information, check immediately */ + next_check = now; } return next_check; } +/** + * Calculate respectful polling interval to avoid DoS'ing servers with cache validation + * @param map_check_interval user configured interval + * @return effective interval that respects server resources + */ +static inline time_t +rspamd_map_get_respectful_interval(time_t map_check_interval) +{ + static const time_t min_respectful_interval = 5; /* Minimum 5 seconds to be respectful */ + static const time_t interval_mult = 4; /* Multiplier for respectful minimum */ + + if (map_check_interval < min_respectful_interval) { + return min_respectful_interval * interval_mult; + } + return map_check_interval; +} + static int http_map_finish(struct rspamd_http_connection *conn, struct rspamd_http_message *msg) @@ -333,12 +413,15 @@ http_map_finish(struct rspamd_http_connection *conn, if 
(msg->code == 200) { if (cbd->check) { - msg_info_map("need to reread map from %s", cbd->bk->uri); + msg_info_map("need to reread map from %s (reply code 200); " + "date timestamp: %z, last modified: %z", + cbd->bk->uri, (size_t) msg->date, (size_t) msg->last_modified); cbd->periodic->need_modify = TRUE; /* Reset the whole chain */ cbd->periodic->cur_backend = 0; /* Reset cache, old cached data will be cleaned on timeout */ g_atomic_int_set(&data->cache->available, 0); + g_atomic_int_set(&map->shared->loaded, 0); data->cur_cache_cbd = NULL; rspamd_map_process_periodic(cbd->periodic); @@ -347,6 +430,7 @@ http_map_finish(struct rspamd_http_connection *conn, return 0; } + /* This code is executed when we are actually reading a map */ cbd->data->last_checked = msg->date; if (msg->last_modified) { @@ -377,10 +461,11 @@ http_map_finish(struct rspamd_http_connection *conn, goto err; } - /* Check for expires */ + /* Check for expires + etag */ double cached_timeout = map->poll_timeout * 2; expires_hdr = rspamd_http_message_find_header(msg, "Expires"); + etag_hdr = rspamd_http_message_find_header(msg, "ETag"); if (expires_hdr) { time_t hdate; @@ -388,8 +473,10 @@ http_map_finish(struct rspamd_http_connection *conn, hdate = rspamd_http_parse_date(expires_hdr->begin, expires_hdr->len); if (hdate != (time_t) -1 && hdate > msg->date) { - map->next_check = rspamd_http_map_process_next_check(msg->date, hdate, - (time_t) map->poll_timeout); + map->next_check = rspamd_http_map_process_next_check(map, bk, msg->date, hdate, + (time_t) map->poll_timeout, + etag_hdr != NULL, + msg->last_modified != 0); cached_timeout = map->next_check - msg->date; } else { @@ -397,9 +484,16 @@ http_map_finish(struct rspamd_http_connection *conn, map->next_check = 0; } } - - /* Check for etag */ - etag_hdr = rspamd_http_message_find_header(msg, "ETag"); + else if (etag_hdr != NULL || msg->last_modified != 0) { + /* No expires header, but we have ETag or Last-Modified - use respectful interval */ + 
time_t effective_interval = rspamd_map_get_respectful_interval(map->poll_timeout); + if (effective_interval != map->poll_timeout) { + msg_info_map("map polling interval %d too aggressive with server cache support, " + "using %d seconds minimum", + (int) map->poll_timeout, (int) effective_interval); + } + map->next_check = msg->date + effective_interval; + } if (etag_hdr) { if (cbd->data->etag) { @@ -420,10 +514,7 @@ http_map_finish(struct rspamd_http_connection *conn, MAP_RETAIN(cbd->shmem_data, "shmem_data"); cbd->data->gen++; - /* - * We know that a map is in the locked state - */ - g_atomic_int_set(&data->cache->available, 1); + /* Store cached data */ rspamd_strlcpy(data->cache->shmem_name, cbd->shmem_data->shm_name, sizeof(data->cache->shmem_name)); @@ -525,6 +616,12 @@ http_map_finish(struct rspamd_http_connection *conn, cbd->periodic->cur_backend++; munmap(in, dlen); + + /* Announce for other processes */ + g_atomic_int_set(&data->cache->available, 1); + g_atomic_int_set(&map->shared->loaded, 1); + g_atomic_int_set(&map->shared->cached, 1); + rspamd_map_process_periodic(cbd->periodic); } else if (msg->code == 304 && cbd->check) { @@ -538,20 +635,34 @@ http_map_finish(struct rspamd_http_connection *conn, } expires_hdr = rspamd_http_message_find_header(msg, "Expires"); + bool has_expires = (expires_hdr != NULL); if (expires_hdr) { time_t hdate; hdate = rspamd_http_parse_date(expires_hdr->begin, expires_hdr->len); if (hdate != (time_t) -1 && hdate > msg->date) { - map->next_check = rspamd_http_map_process_next_check(msg->date, hdate, - (time_t) map->poll_timeout); + map->next_check = rspamd_http_map_process_next_check(map, bk, msg->date, hdate, + (time_t) map->poll_timeout, + cbd->data->etag != NULL, + msg->last_modified != 0); } else { msg_info_map("invalid expires header: %T, ignore it", expires_hdr); map->next_check = 0; + has_expires = false; } } + else if (cbd->data->etag != NULL || msg->last_modified != 0) { + /* No expires header, but we have ETag or 
Last-Modified - use respectful interval */ + time_t effective_interval = rspamd_map_get_respectful_interval(map->poll_timeout); + if (effective_interval != map->poll_timeout) { + msg_info_map("map polling interval %d too aggressive with server cache support, " + "using %d seconds minimum", + (int) map->poll_timeout, (int) effective_interval); + } + map->next_check = msg->date + effective_interval; + } etag_hdr = rspamd_http_message_find_header(msg, "ETag"); @@ -564,19 +675,24 @@ http_map_finish(struct rspamd_http_connection *conn, } } - if (map->next_check) { + if (has_expires) { rspamd_http_date_format(next_check_date, sizeof(next_check_date), map->next_check); - msg_info_map("data is not modified for server %s, next check at %s " + msg_info_map("data is not modified for server %s (%s), next check at %s " "(http cache based: %T)", - cbd->data->host, next_check_date, expires_hdr); + cbd->data->host, + bk->uri, + next_check_date, + expires_hdr); } else { rspamd_http_date_format(next_check_date, sizeof(next_check_date), - rspamd_get_calendar_ticks() + map->poll_timeout); - msg_info_map("data is not modified for server %s, next check at %s " + map->next_check); + msg_info_map("data is not modified for server %s (%s), next check at %s " "(timer based)", - cbd->data->host, next_check_date); + cbd->data->host, + bk->uri, + next_check_date); } rspamd_map_update_http_cached_file(map, bk, cbd->data); @@ -919,6 +1035,8 @@ read_map_file(struct rspamd_map *map, struct file_map_data *data, map->read_callback(NULL, 0, &periodic->cbdata, TRUE); } + g_atomic_int_set(&map->shared->loaded, 1); + return TRUE; } @@ -1003,6 +1121,7 @@ read_map_static(struct rspamd_map *map, struct static_map_data *data, } data->processed = TRUE; + g_atomic_int_set(&map->shared->loaded, 1); return TRUE; } @@ -1010,9 +1129,7 @@ read_map_static(struct rspamd_map *map, struct static_map_data *data, static void rspamd_map_periodic_dtor(struct map_periodic_cbdata *periodic) { - struct rspamd_map *map; - - 
map = periodic->map; + struct rspamd_map *map = periodic->map; msg_debug_map("periodic dtor %p; need_modify=%d", periodic, periodic->need_modify); if (periodic->need_modify || periodic->cbdata.errored) { @@ -1027,18 +1144,13 @@ rspamd_map_periodic_dtor(struct map_periodic_cbdata *periodic) /* Not modified */ } - if (periodic->locked) { - g_atomic_int_set(periodic->map->locked, 0); - msg_debug_map("unlocked map %s", periodic->map->name); - - if (periodic->map->wrk->state == rspamd_worker_state_running) { - rspamd_map_schedule_periodic(periodic->map, - RSPAMD_SYMBOL_RESULT_NORMAL); - } - else { - msg_debug_map("stop scheduling periodics for %s; terminating state", - periodic->map->name); - } + if (periodic->map->wrk->state == rspamd_worker_state_running) { + rspamd_map_schedule_periodic(periodic->map, + RSPAMD_MAP_SCHEDULE_NORMAL); + } + else { + msg_debug_map("stop scheduling periodics for %s; terminating state", + periodic->map->name); } g_free(periodic); @@ -1475,7 +1587,7 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, const unsigned char *data, gsize len) { - char path[PATH_MAX]; + char path[PATH_MAX], temp_path[PATH_MAX]; unsigned char digest[rspamd_cryptobox_HASHBYTES]; struct rspamd_config *cfg = map->cfg; int fd; @@ -1488,8 +1600,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, rspamd_cryptobox_hash(digest, bk->uri, strlen(bk->uri), NULL, 0); rspamd_snprintf(path, sizeof(path), "%s%c%*xs.map", cfg->maps_cache_dir, G_DIR_SEPARATOR, 20, digest); + rspamd_snprintf(temp_path, sizeof(temp_path), "%s.tmp.%d.%d", path, + (int) getpid(), (int) rspamd_get_calendar_ticks()); - fd = rspamd_file_xopen(path, O_WRONLY | O_TRUNC | O_CREAT, + fd = rspamd_file_xopen(temp_path, O_WRONLY | O_TRUNC | O_CREAT, 00600, FALSE); if (fd == -1) { @@ -1497,8 +1611,9 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, } if (!rspamd_file_lock(fd, FALSE)) { - msg_err_map("cannot lock file %s: %s", path, strerror(errno)); + msg_err_map("cannot lock file 
%s: %s", temp_path, strerror(errno)); close(fd); + unlink(temp_path); return FALSE; } @@ -1517,9 +1632,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, } if (write(fd, &header, sizeof(header)) != sizeof(header)) { - msg_err_map("cannot write file %s (header stage): %s", path, strerror(errno)); + msg_err_map("cannot write file %s (header stage): %s", temp_path, strerror(errno)); rspamd_file_unlock(fd, FALSE); close(fd); + unlink(temp_path); return FALSE; } @@ -1527,9 +1643,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, if (header.etag_len > 0) { if (write(fd, RSPAMD_FSTRING_DATA(htdata->etag), header.etag_len) != header.etag_len) { - msg_err_map("cannot write file %s (etag stage): %s", path, strerror(errno)); + msg_err_map("cannot write file %s (etag stage): %s", temp_path, strerror(errno)); rspamd_file_unlock(fd, FALSE); close(fd); + unlink(temp_path); return FALSE; } @@ -1537,9 +1654,10 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, /* Now write the rest */ if (write(fd, data, len) != len) { - msg_err_map("cannot write file %s (data stage): %s", path, strerror(errno)); + msg_err_map("cannot write file %s (data stage): %s", temp_path, strerror(errno)); rspamd_file_unlock(fd, FALSE); close(fd); + unlink(temp_path); return FALSE; } @@ -1547,6 +1665,13 @@ rspamd_map_save_http_cached_file(struct rspamd_map *map, rspamd_file_unlock(fd, FALSE); close(fd); + /* Atomically move temp file to final location */ + if (rename(temp_path, path) != 0) { + msg_err_map("cannot rename %s to %s: %s", temp_path, path, strerror(errno)); + unlink(temp_path); + return FALSE; + } + msg_info_map("saved data from %s in %s, %uz bytes", bk->uri, path, len + sizeof(header) + header.etag_len); return TRUE; @@ -1680,7 +1805,11 @@ rspamd_map_read_http_cached_file(struct rspamd_map *map, double now = rspamd_get_calendar_ticks(); if (header.next_check > now) { - map->next_check = rspamd_http_map_process_next_check(now, header.next_check, map->poll_timeout); 
+ /* We assume that we have this data inside the cached file */ + map->next_check = rspamd_http_map_process_next_check(map, bk, now, header.next_check, + map->poll_timeout, + header.etag_len > 0, + true); } else { map->next_check = now; @@ -1727,6 +1856,8 @@ rspamd_map_read_http_cached_file(struct rspamd_map *map, struct tm tm; char ncheck_buf[32], lm_buf[32]; + g_atomic_int_set(&map->shared->loaded, 1); + g_atomic_int_set(&map->shared->cached, 1); rspamd_localtime(map->next_check, &tm); strftime(ncheck_buf, sizeof(ncheck_buf) - 1, "%Y-%m-%d %H:%M:%S", &tm); rspamd_localtime(htdata->last_modified, &tm); @@ -1769,7 +1900,6 @@ rspamd_map_common_http_callback(struct rspamd_map *map, (int) data->last_modified, (int) data->cache->last_modified); periodic->need_modify = TRUE; - /* Reset the whole chain */ periodic->cur_backend = 0; rspamd_map_process_periodic(periodic); } @@ -2027,33 +2157,22 @@ rspamd_map_process_periodic(struct map_periodic_cbdata *cbd) map = cbd->map; map->scheduled_check = NULL; - if (!map->file_only && !cbd->locked) { - if (!g_atomic_int_compare_and_exchange(cbd->map->locked, - 0, 1)) { - msg_debug_map( - "don't try to reread map %s as it is locked by other process, " - "will reread it later", - cbd->map->name); - rspamd_map_schedule_periodic(map, RSPAMD_MAP_SCHEDULE_LOCKED); - MAP_RELEASE(cbd, "periodic"); + /* For each backend we need to check for modifications */ + if (cbd->cur_backend >= cbd->map->backends->len) { + /* Last backend */ + msg_debug_map("finished map: %d of %d", cbd->cur_backend, + cbd->map->backends->len); + MAP_RELEASE(cbd, "periodic"); - return; - } - else { - msg_debug_map("locked map %s", cbd->map->name); - cbd->locked = TRUE; - } + return; } + bk = g_ptr_array_index(map->backends, cbd->cur_backend); + if (cbd->errored) { /* We should not check other backends if some backend has failed*/ rspamd_map_schedule_periodic(cbd->map, RSPAMD_MAP_SCHEDULE_ERROR); - if (cbd->locked) { - g_atomic_int_set(cbd->map->locked, 0); - 
cbd->locked = FALSE; - } - /* Also set error flag for the map consumer */ cbd->cbdata.errored = true; @@ -2064,19 +2183,7 @@ rspamd_map_process_periodic(struct map_periodic_cbdata *cbd) return; } - /* For each backend we need to check for modifications */ - if (cbd->cur_backend >= cbd->map->backends->len) { - /* Last backend */ - msg_debug_map("finished map: %d of %d", cbd->cur_backend, - cbd->map->backends->len); - MAP_RELEASE(cbd, "periodic"); - - return; - } - if (cbd->map->wrk && cbd->map->wrk->state == rspamd_worker_state_running) { - bk = g_ptr_array_index(cbd->map->backends, cbd->cur_backend); - g_assert(bk != NULL); if (cbd->need_modify) { /* Load data from the next backend */ @@ -2781,10 +2888,6 @@ rspamd_map_parse_backend(struct rspamd_config *cfg, const char *map_line) bk->data.sd = sdata; } - bk->id = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_T1HA, - bk->uri, strlen(bk->uri), - 0xdeadbabe); - return bk; err: @@ -2815,6 +2918,13 @@ rspamd_map_calculate_hash(struct rspamd_map *map) rspamd_cryptobox_hash_init(&st, NULL, 0); + if (map->name) { + rspamd_cryptobox_hash_update(&st, map->name, strlen(map->name)); + } + if (map->description) { + rspamd_cryptobox_hash_update(&st, map->description, strlen(map->description)); + } + for (i = 0; i < map->backends->len; i++) { bk = g_ptr_array_index(map->backends, i); rspamd_cryptobox_hash_update(&st, bk->uri, strlen(bk->uri)); @@ -2823,6 +2933,26 @@ rspamd_map_calculate_hash(struct rspamd_map *map) rspamd_cryptobox_hash_final(&st, cksum); cksum_encoded = rspamd_encode_base32(cksum, sizeof(cksum), RSPAMD_BASE32_DEFAULT); rspamd_strlcpy(map->tag, cksum_encoded, sizeof(map->tag)); + + for (i = 0; i < map->backends->len; i++) { + bk = g_ptr_array_index(map->backends, i); + + /* Also update each backend */ + rspamd_cryptobox_fast_hash_state_t hst; + rspamd_cryptobox_fast_hash_init(&hst, 0); + rspamd_cryptobox_fast_hash_update(&hst, bk->uri, strlen(bk->uri)); + rspamd_cryptobox_fast_hash_update(&hst, map->tag, 
sizeof(map->tag)); + + if (bk->protocol == MAP_PROTO_STATIC) { + /* Static maps content is pre-defined */ + rspamd_cryptobox_fast_hash_update(&hst, bk->data.sd->data, + bk->data.sd->len); + } + + /* We use only 52 bits to be compatible with other numbers representation */ + bk->id = rspamd_cryptobox_fast_hash_final(&hst) & ~(0xFFFULL << 52); + } + g_free(cksum_encoded); } @@ -2888,8 +3018,8 @@ rspamd_map_add(struct rspamd_config *cfg, map->user_data = user_data; map->cfg = cfg; map->id = rspamd_random_uint64_fast(); - map->locked = - rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(int)); + map->shared = + rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(struct rspamd_map_shared_data)); map->backends = g_ptr_array_sized_new(1); map->wrk = worker; rspamd_mempool_add_destructor(cfg->cfg_pool, rspamd_ptr_array_free_hard, @@ -2988,8 +3118,8 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, map->user_data = user_data; map->cfg = cfg; map->id = rspamd_random_uint64_fast(); - map->locked = - rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(int)); + map->shared = + rspamd_mempool_alloc0_shared(cfg->cfg_pool, sizeof(struct rspamd_map_shared_data)); map->backends = g_ptr_array_new(); map->wrk = worker; map->no_file_read = (flags & RSPAMD_MAP_FILE_NO_READ); @@ -3108,7 +3238,7 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, goto err; } - gboolean all_local = TRUE; + gboolean all_local = TRUE, all_loaded = TRUE; PTR_ARRAY_FOREACH(map->backends, i, bk) { @@ -3127,9 +3257,8 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, map_data = g_string_sized_new(32); if (rspamd_map_add_static_string(cfg, elt, map_data)) { - bk->data.sd->data = map_data->str; bk->data.sd->len = map_data->len; - g_string_free(map_data, FALSE); + bk->data.sd->data = (unsigned char *) g_string_free(map_data, FALSE); } else { g_string_free(map_data, TRUE); @@ -3152,13 +3281,16 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, } ucl_object_iterate_free(it); - bk->data.sd->data = 
map_data->str; bk->data.sd->len = map_data->len; - g_string_free(map_data, FALSE); + bk->data.sd->data = (unsigned char *) g_string_free(map_data, FALSE); } } else if (bk->protocol != MAP_PROTO_FILE) { all_local = FALSE; + all_loaded = FALSE; /* Will be loaded later */ + } + else { + all_loaded = FALSE; /* Will be loaded later (even for files) */ } } @@ -3167,6 +3299,11 @@ rspamd_map_add_from_ucl(struct rspamd_config *cfg, cfg->map_file_watch_multiplier); } + if (all_loaded) { + /* Static map */ + g_atomic_int_set(&map->shared->loaded, 1); + } + rspamd_map_calculate_hash(map); msg_debug_map("added map from ucl"); diff --git a/src/libserver/maps/map_private.h b/src/libserver/maps/map_private.h index d0b22fe36..65df8d7f5 100644 --- a/src/libserver/maps/map_private.h +++ b/src/libserver/maps/map_private.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -54,6 +54,23 @@ enum fetch_proto { MAP_PROTO_STATIC }; +static const char * +rspamd_map_fetch_protocol_name(enum fetch_proto proto) +{ + switch (proto) { + case MAP_PROTO_FILE: + return "file"; + case MAP_PROTO_HTTP: + return "http"; + case MAP_PROTO_HTTPS: + return "https"; + case MAP_PROTO_STATIC: + return "static"; + default: + return "unknown"; + } +} + /** * Data specific to file maps */ @@ -76,7 +93,7 @@ struct rspamd_http_map_cached_cbdata { time_t last_checked; }; -struct rspamd_map_cachepoint { +struct rspamd_http_map_cache { int available; gsize len; time_t last_modified; @@ -88,7 +105,7 @@ struct rspamd_map_cachepoint { */ struct http_map_data { /* Shared cache data */ - struct rspamd_map_cachepoint *cache; + struct rspamd_http_map_cache *cache; /* Non-shared for cache owner, used to cleanup cache */ struct rspamd_http_map_cached_cbdata *cur_cache_cbd; char *userinfo; @@ -117,6 +134,7 @@ union rspamd_map_backend_data { struct rspamd_map; + struct rspamd_map_backend { enum fetch_proto protocol; gboolean is_signed; @@ -124,7 +142,7 @@ struct rspamd_map_backend { gboolean is_fallback; struct rspamd_map *map; struct ev_loop *event_loop; - uint32_t id; + uint64_t id; struct rspamd_cryptobox_pubkey *trusted_pubkey; union rspamd_map_backend_data data; char *uri; @@ -133,6 +151,14 @@ struct rspamd_map_backend { struct map_periodic_cbdata; +/* + * Shared between workers + */ +struct rspamd_map_shared_data { + int loaded; + int cached; +}; + struct rspamd_map { struct rspamd_dns_resolver *r; struct rspamd_config *cfg; @@ -168,7 +194,7 @@ struct rspamd_map { bool no_file_read; /* Do not read files */ bool seen; /* This map has already been watched or pre-loaded */ /* Shared lock for temporary disabling of map reading (e.g. 
when this map is written by UI) */ - int *locked; + struct rspamd_map_shared_data *shared; char tag[MEMPOOL_UID_LEN]; }; @@ -185,7 +211,6 @@ struct map_periodic_cbdata { ev_timer ev; gboolean need_modify; gboolean errored; - gboolean locked; unsigned int cur_backend; ref_entry_t ref; }; diff --git a/src/libserver/milter.c b/src/libserver/milter.c index 94b0d6cc1..09ddddaba 100644 --- a/src/libserver/milter.c +++ b/src/libserver/milter.c @@ -1473,8 +1473,6 @@ rspamd_milter_macro_http(struct rspamd_milter_session *session, { rspamd_http_message_add_header_len(msg, QUEUE_ID_HEADER, found->begin, found->len); - rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, - found->begin, found->len); } else { @@ -1482,8 +1480,6 @@ rspamd_milter_macro_http(struct rspamd_milter_session *session, { rspamd_http_message_add_header_len(msg, QUEUE_ID_HEADER, found->begin, found->len); - rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, - found->begin, found->len); } } diff --git a/src/libserver/re_cache.c b/src/libserver/re_cache.c index 06e9f3328..50b155ae0 100644 --- a/src/libserver/re_cache.c +++ b/src/libserver/re_cache.c @@ -998,20 +998,21 @@ rspamd_re_cache_process_selector(struct rspamd_task *task, return result; } + static inline unsigned int -rspamd_process_words_vector(GArray *words, - const unsigned char **scvec, - unsigned int *lenvec, - struct rspamd_re_class *re_class, - unsigned int cnt, - gboolean *raw) +rspamd_process_words_vector_kvec(rspamd_words_t *words, + const unsigned char **scvec, + unsigned int *lenvec, + struct rspamd_re_class *re_class, + unsigned int cnt, + gboolean *raw) { unsigned int j; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; - if (words) { - for (j = 0; j < words->len; j++) { - tok = &g_array_index(words, rspamd_stat_token_t, j); + if (words && words->a) { + for (j = 0; j < kv_size(*words); j++) { + tok = &kv_A(*words, j); if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { if (!(tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF)) { @@ -1432,13 
+1433,13 @@ rspamd_re_cache_exec_re(struct rspamd_task *task, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, text_part) { - if (text_part->utf_words) { - cnt += text_part->utf_words->len; + if (text_part->utf_words.a) { + cnt += kv_size(text_part->utf_words); } } - if (task->meta_words && task->meta_words->len > 0) { - cnt += task->meta_words->len; + if (task->meta_words.a && kv_size(task->meta_words) > 0) { + cnt += kv_size(task->meta_words); } if (cnt > 0) { @@ -1449,15 +1450,15 @@ rspamd_re_cache_exec_re(struct rspamd_task *task, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, text_part) { - if (text_part->utf_words) { - cnt = rspamd_process_words_vector(text_part->utf_words, - scvec, lenvec, re_class, cnt, &raw); + if (text_part->utf_words.a) { + cnt = rspamd_process_words_vector_kvec(&text_part->utf_words, + scvec, lenvec, re_class, cnt, &raw); } } - if (task->meta_words) { - cnt = rspamd_process_words_vector(task->meta_words, - scvec, lenvec, re_class, cnt, &raw); + if (task->meta_words.a) { + cnt = rspamd_process_words_vector_kvec(&task->meta_words, + scvec, lenvec, re_class, cnt, &raw); } ret = rspamd_re_cache_process_regexp_data(rt, re, diff --git a/src/libserver/roll_history.c b/src/libserver/roll_history.c index 66a53a597..d0f145d8f 100644 --- a/src/libserver/roll_history.c +++ b/src/libserver/roll_history.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -231,7 +231,7 @@ rspamd_roll_history_load(struct roll_history *history, const char *filename) return FALSE; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_fd(parser, fd)) { msg_warn("cannot parse history file %s: %s", filename, diff --git a/src/libserver/rspamd_control.c b/src/libserver/rspamd_control.c index 1bff2ff12..9e35cb575 100644 --- a/src/libserver/rspamd_control.c +++ b/src/libserver/rspamd_control.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -214,7 +214,7 @@ rspamd_control_write_reply(struct rspamd_control_session *session) case RSPAMD_CONTROL_FUZZY_STAT: if (elt->attached_fd != -1) { /* We have some data to parse */ - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); ucl_object_insert_key(cur, ucl_object_fromint( elt->reply.reply.fuzzy_stat.status), diff --git a/src/libserver/symcache/symcache_impl.cxx b/src/libserver/symcache/symcache_impl.cxx index c0278cfc1..c1ca2a6ed 100644 --- a/src/libserver/symcache/symcache_impl.cxx +++ b/src/libserver/symcache/symcache_impl.cxx @@ -274,7 +274,7 @@ auto symcache::load_items() -> bool return false; } - auto *parser = ucl_parser_new(0); + auto *parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); const auto *p = (const std::uint8_t *) (hdr + 1); if (!ucl_parser_add_chunk(parser, p, cached_map->get_size() - sizeof(*hdr))) { diff --git a/src/libserver/task.c b/src/libserver/task.c index bd1e07549..9f5b1f00a 100644 --- a/src/libserver/task.c +++ b/src/libserver/task.c @@ -1,5 +1,5 @@ 
/* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -196,8 +196,8 @@ void rspamd_task_free(struct rspamd_task *task) rspamd_email_address_free(task->from_envelope_orig); } - if (task->meta_words) { - g_array_free(task->meta_words, TRUE); + if (task->meta_words.a) { + kv_destroy(task->meta_words); } ucl_object_unref(task->messages); diff --git a/src/libserver/task.h b/src/libserver/task.h index 6be350098..1c1778fee 100644 --- a/src/libserver/task.h +++ b/src/libserver/task.h @@ -24,6 +24,7 @@ #include "dns.h" #include "re_cache.h" #include "khash.h" +#include "libserver/word.h" #ifdef __cplusplus extern "C" { @@ -187,7 +188,7 @@ struct rspamd_task { struct rspamd_scan_result *result; /**< Metric result */ khash_t(rspamd_task_lua_cache) lua_cache; /**< cache of lua objects */ GPtrArray *tokens; /**< statistics tokens */ - GArray *meta_words; /**< rspamd_stat_token_t produced from meta headers + rspamd_words_t meta_words; /**< rspamd_word_t produced from meta headers (e.g. Subject) */ GPtrArray *rcpt_envelope; /**< array of rspamd_email_address */ diff --git a/src/libserver/word.h b/src/libserver/word.h new file mode 100644 index 000000000..7698bf327 --- /dev/null +++ b/src/libserver/word.h @@ -0,0 +1,88 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef RSPAMD_WORD_H +#define RSPAMD_WORD_H + +#include "config.h" +#include "fstring.h" +#include "contrib/libucl/kvec.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @file word.h + * Word processing structures and definitions + */ + +/* Word flags */ +#define RSPAMD_WORD_FLAG_TEXT (1u << 0) +#define RSPAMD_WORD_FLAG_META (1u << 1) +#define RSPAMD_WORD_FLAG_LUA_META (1u << 2) +#define RSPAMD_WORD_FLAG_EXCEPTION (1u << 3) +#define RSPAMD_WORD_FLAG_HEADER (1u << 4) +#define RSPAMD_WORD_FLAG_UNIGRAM (1u << 5) +#define RSPAMD_WORD_FLAG_UTF (1u << 6) +#define RSPAMD_WORD_FLAG_NORMALISED (1u << 7) +#define RSPAMD_WORD_FLAG_STEMMED (1u << 8) +#define RSPAMD_WORD_FLAG_BROKEN_UNICODE (1u << 9) +#define RSPAMD_WORD_FLAG_STOP_WORD (1u << 10) +#define RSPAMD_WORD_FLAG_SKIPPED (1u << 11) +#define RSPAMD_WORD_FLAG_INVISIBLE_SPACES (1u << 12) +#define RSPAMD_WORD_FLAG_EMOJI (1u << 13) + +/** + * Word structure representing tokenized text + */ +typedef struct rspamd_word_s { + rspamd_ftok_t original; /* utf8 raw */ + rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ + rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ + rspamd_ftok_t stemmed; /* stemmed utf8 */ + unsigned int flags; +} rspamd_word_t; + +/** + * Vector of words using kvec + */ +typedef kvec_t(rspamd_word_t) rspamd_words_t; + +/* Legacy typedefs for backward compatibility */ +typedef rspamd_word_t rspamd_stat_token_t; + +/* Legacy flag aliases for backward compatibility */ +#define RSPAMD_STAT_TOKEN_FLAG_TEXT RSPAMD_WORD_FLAG_TEXT +#define RSPAMD_STAT_TOKEN_FLAG_META RSPAMD_WORD_FLAG_META +#define RSPAMD_STAT_TOKEN_FLAG_LUA_META RSPAMD_WORD_FLAG_LUA_META +#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION RSPAMD_WORD_FLAG_EXCEPTION +#define RSPAMD_STAT_TOKEN_FLAG_HEADER RSPAMD_WORD_FLAG_HEADER +#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM RSPAMD_WORD_FLAG_UNIGRAM +#define RSPAMD_STAT_TOKEN_FLAG_UTF RSPAMD_WORD_FLAG_UTF +#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED 
RSPAMD_WORD_FLAG_NORMALISED +#define RSPAMD_STAT_TOKEN_FLAG_STEMMED RSPAMD_WORD_FLAG_STEMMED +#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE RSPAMD_WORD_FLAG_BROKEN_UNICODE +#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD RSPAMD_WORD_FLAG_STOP_WORD +#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED RSPAMD_WORD_FLAG_SKIPPED +#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES RSPAMD_WORD_FLAG_INVISIBLE_SPACES +#define RSPAMD_STAT_TOKEN_FLAG_EMOJI RSPAMD_WORD_FLAG_EMOJI + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_WORD_H */ diff --git a/src/libserver/worker_util.c b/src/libserver/worker_util.c index d0ac8d8d3..685ee9cd2 100644 --- a/src/libserver/worker_util.c +++ b/src/libserver/worker_util.c @@ -2138,7 +2138,7 @@ rspamd_controller_load_saved_stats(struct rspamd_main *rspamd_main, return; } - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_file(parser, cfg->stats_file)) { msg_err_config("cannot parse controller stats from %s: %s", diff --git a/src/libstat/CMakeLists.txt b/src/libstat/CMakeLists.txt index 64d572a57..eddf64e49 100644 --- a/src/libstat/CMakeLists.txt +++ b/src/libstat/CMakeLists.txt @@ -1,25 +1,26 @@ # Librspamdserver -SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c - ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) +SET(LIBSTATSRC ${CMAKE_CURRENT_SOURCE_DIR}/stat_config.c + ${CMAKE_CURRENT_SOURCE_DIR}/stat_process.c) -SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c - ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) +SET(TOKENIZERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizers.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/tokenizer_manager.c + ${CMAKE_CURRENT_SOURCE_DIR}/tokenizers/osb.c) -SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c - ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) +SET(CLASSIFIERSSRC ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/bayes.c + ${CMAKE_CURRENT_SOURCE_DIR}/classifiers/lua_classifier.c) -SET(BACKENDSSRC 
${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c - ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c - ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx - ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx - ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) +SET(BACKENDSSRC ${CMAKE_CURRENT_SOURCE_DIR}/backends/mmaped_file.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/sqlite3_backend.c + ${CMAKE_CURRENT_SOURCE_DIR}/backends/cdb_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/http_backend.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/backends/redis_backend.cxx) -SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c +SET(CACHESSRC ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/sqlite3_cache.c ${CMAKE_CURRENT_SOURCE_DIR}/learn_cache/redis_cache.cxx) SET(RSPAMD_STAT ${LIBSTATSRC} - ${TOKENIZERSSRC} - ${CLASSIFIERSSRC} - ${BACKENDSSRC} - ${CACHESSRC} PARENT_SCOPE) + ${TOKENIZERSSRC} + ${CLASSIFIERSSRC} + ${BACKENDSSRC} + ${CACHESSRC} PARENT_SCOPE) diff --git a/src/libstat/stat_api.h b/src/libstat/stat_api.h index f28922588..811566ad3 100644 --- a/src/libstat/stat_api.h +++ b/src/libstat/stat_api.h @@ -20,6 +20,7 @@ #include "task.h" #include "lua/lua_common.h" #include "contrib/libev/ev.h" +#include "libserver/word.h" #ifdef __cplusplus extern "C" { @@ -30,36 +31,14 @@ extern "C" { * High level statistics API */ -#define RSPAMD_STAT_TOKEN_FLAG_TEXT (1u << 0) -#define RSPAMD_STAT_TOKEN_FLAG_META (1u << 1) -#define RSPAMD_STAT_TOKEN_FLAG_LUA_META (1u << 2) -#define RSPAMD_STAT_TOKEN_FLAG_EXCEPTION (1u << 3) -#define RSPAMD_STAT_TOKEN_FLAG_HEADER (1u << 4) -#define RSPAMD_STAT_TOKEN_FLAG_UNIGRAM (1u << 5) -#define RSPAMD_STAT_TOKEN_FLAG_UTF (1u << 6) -#define RSPAMD_STAT_TOKEN_FLAG_NORMALISED (1u << 7) -#define RSPAMD_STAT_TOKEN_FLAG_STEMMED (1u << 8) -#define RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE (1u << 9) -#define RSPAMD_STAT_TOKEN_FLAG_STOP_WORD (1u << 10) -#define RSPAMD_STAT_TOKEN_FLAG_SKIPPED (1u << 11) -#define RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES 
(1u << 12) -#define RSPAMD_STAT_TOKEN_FLAG_EMOJI (1u << 13) - -typedef struct rspamd_stat_token_s { - rspamd_ftok_t original; /* utf8 raw */ - rspamd_ftok_unicode_t unicode; /* array of unicode characters, normalized, lowercased */ - rspamd_ftok_t normalized; /* normalized and lowercased utf8 */ - rspamd_ftok_t stemmed; /* stemmed utf8 */ - unsigned int flags; -} rspamd_stat_token_t; #define RSPAMD_TOKEN_VALUE_TYPE float typedef struct token_node_s { uint64_t data; unsigned int window_idx; unsigned int flags; - rspamd_stat_token_t *t1; - rspamd_stat_token_t *t2; + rspamd_word_t *t1; + rspamd_word_t *t2; RSPAMD_TOKEN_VALUE_TYPE values[0]; } rspamd_token_t; diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c index 17caf4cc6..176064087 100644 --- a/src/libstat/stat_process.c +++ b/src/libstat/stat_process.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -36,12 +36,13 @@ static void rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, struct rspamd_task *task) { - GArray *ar; - rspamd_stat_token_t elt; + rspamd_words_t *words; + rspamd_word_t elt; unsigned int i; lua_State *L = task->cfg->lua_state; - ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16); + words = rspamd_mempool_alloc(task->task_pool, sizeof(*words)); + kv_init(*words); memset(&elt, 0, sizeof(elt)); elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; @@ -87,7 +88,7 @@ rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, elt.normalized.begin = elt.original.begin; elt.normalized.len = elt.original.len; - g_array_append_val(ar, elt); + kv_push_safe(rspamd_word_t, *words, elt, meta_words_error); } lua_pop(L, 1); @@ -99,17 +100,20 @@ rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, } - if (ar->len > 0) { + if (kv_size(*words) > 0) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - ar, + words, TRUE, "M", task->tokens); } - rspamd_mempool_add_destructor(task->task_pool, - rspamd_array_free_hard, ar); + return; +meta_words_error: + + msg_err("cannot process meta words for task" + "memory allocation error, skipping the remaining"); } /* @@ -134,8 +138,8 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) { - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { - reserved_len += part->utf_words->len; + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a) { + reserved_len += kv_size(part->utf_words); } /* XXX: normal window size */ reserved_len += 5; @@ -149,9 +153,9 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) { - if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words.a) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - part->utf_words, IS_TEXT_PART_UTF(part), + &part->utf_words, 
IS_TEXT_PART_UTF(part), NULL, task->tokens); } @@ -163,10 +167,10 @@ void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, } } - if (task->meta_words != NULL) { + if (task->meta_words.a) { st_ctx->tokenizer->tokenize_func(st_ctx, task, - task->meta_words, + &task->meta_words, TRUE, "SUBJECT", task->tokens); diff --git a/src/libstat/tokenizers/custom_tokenizer.h b/src/libstat/tokenizers/custom_tokenizer.h new file mode 100644 index 000000000..bc173a1da --- /dev/null +++ b/src/libstat/tokenizers/custom_tokenizer.h @@ -0,0 +1,177 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef RSPAMD_CUSTOM_TOKENIZER_H +#define RSPAMD_CUSTOM_TOKENIZER_H + +/* Check if we're being included by internal Rspamd code or external plugins */ +#ifdef RSPAMD_TOKENIZER_INTERNAL +/* Internal Rspamd usage - use the full headers */ +#include "config.h" +#include "ucl.h" +#include "libserver/word.h" +#else +/* External plugin usage - use standalone types */ +#include "rspamd_tokenizer_types.h" +/* Forward declaration for UCL object - plugins should include ucl.h if needed */ +typedef struct ucl_object_s ucl_object_t; +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define RSPAMD_CUSTOM_TOKENIZER_API_VERSION 1 + +/** + * Tokenization result - compatible with both internal and external usage + */ +typedef rspamd_words_t rspamd_tokenizer_result_t; + +/** + * Custom tokenizer API that must be implemented by language-specific tokenizer plugins + * All functions use only plain C types to ensure clean boundaries + */ +typedef struct rspamd_custom_tokenizer_api { + /* API version for compatibility checking */ + unsigned int api_version; + + /* Name of the tokenizer (e.g., "japanese_mecab") */ + const char *name; + + /** + * Global initialization function called once when the tokenizer is loaded + * @param config UCL configuration object for this tokenizer (may be NULL) + * @param error_buf Buffer for error message (at least 256 bytes) + * @return 0 on success, non-zero on failure + */ + int (*init)(const ucl_object_t *config, char *error_buf, size_t error_buf_size); + + /** + * Global cleanup function called when the tokenizer is unloaded + */ + void (*deinit)(void); + + /** + * Quick language detection to check if this tokenizer can handle the text + * @param text UTF-8 text to analyze + * @param len Length of the text in bytes + * @return Confidence score 0.0-1.0, or -1.0 if cannot handle + */ + double (*detect_language)(const char *text, size_t len); + + /** + * Main tokenization function + * @param text UTF-8 text to tokenize + * @param len Length 
of the text in bytes + * @param result Output kvec to fill with rspamd_word_t elements + * @return 0 on success, non-zero on failure + * + * The tokenizer should allocate result->a using its own allocator + * Rspamd will call cleanup_result() to free it after processing + */ + int (*tokenize)(const char *text, size_t len, + rspamd_tokenizer_result_t *result); + + /** + * Cleanup the result from tokenize() + * @param result Result kvec returned by tokenize() + * + * This function should free result->a using the same allocator + * that was used in tokenize() and reset the kvec fields. + * This ensures proper memory management across DLL boundaries. + * Note: This does NOT free the result structure itself, only its contents. + */ + void (*cleanup_result)(rspamd_tokenizer_result_t *result); + + /** + * Optional: Get language hint for better language detection + * @return Language code (e.g., "ja", "zh") or NULL + */ + const char *(*get_language_hint)(void); + + /** + * Optional: Get minimum confidence threshold for this tokenizer + * @return Minimum confidence (0.0-1.0) or -1.0 to use default + */ + double (*get_min_confidence)(void); + +} rspamd_custom_tokenizer_api_t; + +/** + * Entry point function that plugins must export + * Must be named "rspamd_tokenizer_get_api" + */ +typedef const rspamd_custom_tokenizer_api_t *(*rspamd_tokenizer_get_api_func)(void); + +/* Internal Rspamd structures - not exposed to plugins */ +#ifdef RSPAMD_TOKENIZER_INTERNAL + +/** + * Custom tokenizer instance + */ +struct rspamd_custom_tokenizer { + char *name; /* Tokenizer name from config */ + char *path; /* Path to .so file */ + void *handle; /* dlopen handle */ + const rspamd_custom_tokenizer_api_t *api; /* API functions */ + double priority; /* Detection priority */ + double min_confidence; /* Minimum confidence threshold */ + gboolean enabled; /* Is tokenizer enabled */ + ucl_object_t *config; /* Tokenizer-specific config */ +}; + +/** + * Tokenizer manager structure + */ +struct 
rspamd_tokenizer_manager { + GHashTable *tokenizers; /* name -> rspamd_custom_tokenizer */ + GArray *detection_order; /* Ordered by priority */ + rspamd_mempool_t *pool; + double default_threshold; /* Default confidence threshold */ +}; + +/* Manager functions */ +struct rspamd_tokenizer_manager *rspamd_tokenizer_manager_new(rspamd_mempool_t *pool); +void rspamd_tokenizer_manager_destroy(struct rspamd_tokenizer_manager *mgr); + +gboolean rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr, + const char *name, + const ucl_object_t *config, + GError **err); + +struct rspamd_custom_tokenizer *rspamd_tokenizer_manager_detect( + struct rspamd_tokenizer_manager *mgr, + const char *text, size_t len, + double *confidence, + const char *lang_hint, + const char **detected_lang_hint); + +/* Helper function to tokenize with exceptions handling */ +rspamd_tokenizer_result_t *rspamd_custom_tokenizer_tokenize_with_exceptions( + struct rspamd_custom_tokenizer *tokenizer, + const char *text, + gsize len, + GList *exceptions, + rspamd_mempool_t *pool); + +#endif /* RSPAMD_TOKENIZER_INTERNAL */ + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_CUSTOM_TOKENIZER_H */ diff --git a/src/libstat/tokenizers/osb.c b/src/libstat/tokenizers/osb.c index 0bc3414a5..360c71d36 100644 --- a/src/libstat/tokenizers/osb.c +++ b/src/libstat/tokenizers/osb.c @@ -21,6 +21,7 @@ #include "tokenizers.h" #include "stat_internal.h" #include "libmime/lang_detection.h" +#include "libserver/word.h" /* Size for features pipe */ #define DEFAULT_FEATURE_WINDOW_SIZE 2 @@ -268,7 +269,7 @@ struct token_pipe_entry { int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result) @@ -282,7 +283,7 @@ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, gsize token_size; unsigned int processed = 0, i, w, window_size, token_flags = 0; - if (words == NULL) { + if (words == NULL || 
!words->a) { return FALSE; } @@ -306,8 +307,8 @@ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, sizeof(RSPAMD_TOKEN_VALUE_TYPE) * ctx->statfiles->len; g_assert(token_size > 0); - for (w = 0; w < words->len; w++) { - token = &g_array_index(words, rspamd_stat_token_t, w); + for (w = 0; w < kv_size(*words); w++) { + token = &kv_A(*words, w); token_flags = token->flags; const char *begin; gsize len; diff --git a/src/libstat/tokenizers/rspamd_tokenizer_types.h b/src/libstat/tokenizers/rspamd_tokenizer_types.h new file mode 100644 index 000000000..eb8518290 --- /dev/null +++ b/src/libstat/tokenizers/rspamd_tokenizer_types.h @@ -0,0 +1,89 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_TOKENIZER_TYPES_H +#define RSPAMD_TOKENIZER_TYPES_H + +/* + * Standalone type definitions for custom tokenizers + * This header is completely self-contained and does not depend on any external libraries. + * Custom tokenizers should include only this header to get access to all necessary types. 
+ */ + +#include <stdint.h> +#include <stddef.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Basic string token structure + */ +typedef struct rspamd_ftok { + size_t len; + const char *begin; +} rspamd_ftok_t; + +/** + * Unicode string token structure + */ +typedef struct rspamd_ftok_unicode { + size_t len; + const uint32_t *begin; +} rspamd_ftok_unicode_t; + +/* Word flags */ +#define RSPAMD_WORD_FLAG_TEXT (1u << 0u) +#define RSPAMD_WORD_FLAG_META (1u << 1u) +#define RSPAMD_WORD_FLAG_LUA_META (1u << 2u) +#define RSPAMD_WORD_FLAG_EXCEPTION (1u << 3u) +#define RSPAMD_WORD_FLAG_HEADER (1u << 4u) +#define RSPAMD_WORD_FLAG_UNIGRAM (1u << 5u) +#define RSPAMD_WORD_FLAG_UTF (1u << 6u) +#define RSPAMD_WORD_FLAG_NORMALISED (1u << 7u) +#define RSPAMD_WORD_FLAG_STEMMED (1u << 8u) +#define RSPAMD_WORD_FLAG_BROKEN_UNICODE (1u << 9u) +#define RSPAMD_WORD_FLAG_STOP_WORD (1u << 10u) +#define RSPAMD_WORD_FLAG_SKIPPED (1u << 11u) +#define RSPAMD_WORD_FLAG_INVISIBLE_SPACES (1u << 12u) +#define RSPAMD_WORD_FLAG_EMOJI (1u << 13u) + +/** + * Word structure + */ +typedef struct rspamd_word { + rspamd_ftok_t original; + rspamd_ftok_unicode_t unicode; + rspamd_ftok_t normalized; + rspamd_ftok_t stemmed; + unsigned int flags; +} rspamd_word_t; + +/** + * Array of words + */ +typedef struct rspamd_words { + rspamd_word_t *a; + size_t n; + size_t m; +} rspamd_words_t; + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_TOKENIZER_TYPES_H */ diff --git a/src/libstat/tokenizers/tokenizer_manager.c b/src/libstat/tokenizers/tokenizer_manager.c new file mode 100644 index 000000000..e6fb5e8d8 --- /dev/null +++ b/src/libstat/tokenizers/tokenizer_manager.c @@ -0,0 +1,500 @@ +/* + * Copyright 2025 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" +#include "tokenizers.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "custom_tokenizer.h" +#include "libutil/util.h" +#include "libserver/logger.h" +#include <dlfcn.h> + +#define msg_err_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_CRITICAL, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_warn_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_WARNING, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_info_tokenizer(...) rspamd_default_log_function(G_LOG_LEVEL_INFO, \ + "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) +#define msg_debug_tokenizer(...) 
rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_tokenizer_log_id, "tokenizer", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(tokenizer) + +static void +rspamd_custom_tokenizer_dtor(gpointer p) +{ + struct rspamd_custom_tokenizer *tok = p; + + if (tok) { + if (tok->api && tok->api->deinit) { + tok->api->deinit(); + } + + if (tok->handle) { + dlclose(tok->handle); + } + + if (tok->config) { + ucl_object_unref(tok->config); + } + + g_free(tok->name); + g_free(tok->path); + g_free(tok); + } +} + +static int +rspamd_custom_tokenizer_priority_cmp(gconstpointer a, gconstpointer b) +{ + const struct rspamd_custom_tokenizer *t1 = *(const struct rspamd_custom_tokenizer **) a; + const struct rspamd_custom_tokenizer *t2 = *(const struct rspamd_custom_tokenizer **) b; + + /* Higher priority first */ + if (t1->priority > t2->priority) { + return -1; + } + else if (t1->priority < t2->priority) { + return 1; + } + + return 0; +} + +struct rspamd_tokenizer_manager * +rspamd_tokenizer_manager_new(rspamd_mempool_t *pool) +{ + struct rspamd_tokenizer_manager *mgr; + + mgr = rspamd_mempool_alloc0(pool, sizeof(*mgr)); + mgr->pool = pool; + mgr->tokenizers = g_hash_table_new_full(rspamd_strcase_hash, + rspamd_strcase_equal, + NULL, + rspamd_custom_tokenizer_dtor); + mgr->detection_order = g_array_new(FALSE, FALSE, sizeof(struct rspamd_custom_tokenizer *)); + mgr->default_threshold = 0.7; /* Default confidence threshold */ + + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) g_hash_table_unref, + mgr->tokenizers); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) rspamd_array_free_hard, + mgr->detection_order); + + msg_info_tokenizer("created custom tokenizer manager with default confidence threshold %.3f", + mgr->default_threshold); + + return mgr; +} + +void rspamd_tokenizer_manager_destroy(struct rspamd_tokenizer_manager *mgr) +{ + /* Cleanup is handled by memory pool destructors */ +} + +gboolean 
+rspamd_tokenizer_manager_load_tokenizer(struct rspamd_tokenizer_manager *mgr, + const char *name, + const ucl_object_t *config, + GError **err) +{ + struct rspamd_custom_tokenizer *tok; + const ucl_object_t *elt; + rspamd_tokenizer_get_api_func get_api; + const rspamd_custom_tokenizer_api_t *api; + void *handle; + const char *path; + gboolean enabled = TRUE; + double priority = 50.0; + char error_buf[256]; + + g_assert(mgr != NULL); + g_assert(name != NULL); + g_assert(config != NULL); + + msg_info_tokenizer("starting to load custom tokenizer '%s'", name); + + /* Check if enabled */ + elt = ucl_object_lookup(config, "enabled"); + if (elt && ucl_object_type(elt) == UCL_BOOLEAN) { + enabled = ucl_object_toboolean(elt); + } + + if (!enabled) { + msg_info_tokenizer("custom tokenizer '%s' is disabled", name); + return TRUE; + } + + /* Get path */ + elt = ucl_object_lookup(config, "path"); + if (!elt || ucl_object_type(elt) != UCL_STRING) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "missing 'path' for tokenizer %s", name); + return FALSE; + } + path = ucl_object_tostring(elt); + msg_info_tokenizer("custom tokenizer '%s' will be loaded from path: %s", name, path); + + /* Get priority */ + elt = ucl_object_lookup(config, "priority"); + if (elt) { + priority = ucl_object_todouble(elt); + } + msg_info_tokenizer("custom tokenizer '%s' priority set to %.1f", name, priority); + + /* Load the shared library */ + msg_info_tokenizer("loading shared library for custom tokenizer '%s'", name); + handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); + if (!handle) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "cannot load tokenizer %s from %s: %s", + name, path, dlerror()); + return FALSE; + } + msg_info_tokenizer("successfully loaded shared library for custom tokenizer '%s'", name); + + /* Get the API entry point */ + msg_info_tokenizer("looking up API entry point for custom tokenizer '%s'", name); + get_api = 
(rspamd_tokenizer_get_api_func) dlsym(handle, "rspamd_tokenizer_get_api"); + if (!get_api) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "cannot find entry point in %s: %s", + path, dlerror()); + return FALSE; + } + + /* Get the API */ + msg_info_tokenizer("calling API entry point for custom tokenizer '%s'", name); + api = get_api(); + if (!api) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "tokenizer %s returned NULL API", name); + return FALSE; + } + msg_info_tokenizer("successfully obtained API from custom tokenizer '%s'", name); + + /* Check API version */ + msg_info_tokenizer("checking API version for custom tokenizer '%s' (got %u, expected %u)", + name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION); + if (api->api_version != RSPAMD_CUSTOM_TOKENIZER_API_VERSION) { + dlclose(handle); + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "tokenizer %s has incompatible API version %u (expected %u)", + name, api->api_version, RSPAMD_CUSTOM_TOKENIZER_API_VERSION); + return FALSE; + } + + /* Create tokenizer instance */ + tok = g_malloc0(sizeof(*tok)); + tok->name = g_strdup(name); + tok->path = g_strdup(path); + tok->handle = handle; + tok->api = api; + tok->priority = priority; + tok->enabled = enabled; + + /* Get tokenizer config */ + elt = ucl_object_lookup(config, "config"); + if (elt) { + tok->config = ucl_object_ref(elt); + } + + /* Get minimum confidence */ + if (api->get_min_confidence) { + tok->min_confidence = api->get_min_confidence(); + msg_info_tokenizer("custom tokenizer '%s' provides minimum confidence threshold: %.3f", + name, tok->min_confidence); + } + else { + tok->min_confidence = mgr->default_threshold; + msg_info_tokenizer("custom tokenizer '%s' using default confidence threshold: %.3f", + name, tok->min_confidence); + } + + /* Initialize the tokenizer */ + if (api->init) { + msg_info_tokenizer("initializing custom tokenizer 
'%s'", name); + error_buf[0] = '\0'; + if (api->init(tok->config, error_buf, sizeof(error_buf)) != 0) { + g_set_error(err, g_quark_from_static_string("tokenizer"), + EINVAL, "failed to initialize tokenizer %s: %s", + name, error_buf[0] ? error_buf : "unknown error"); + rspamd_custom_tokenizer_dtor(tok); + return FALSE; + } + msg_info_tokenizer("successfully initialized custom tokenizer '%s'", name); + } + else { + msg_info_tokenizer("custom tokenizer '%s' does not require initialization", name); + } + + /* Add to manager */ + g_hash_table_insert(mgr->tokenizers, tok->name, tok); + g_array_append_val(mgr->detection_order, tok); + + /* Re-sort by priority */ + g_array_sort(mgr->detection_order, rspamd_custom_tokenizer_priority_cmp); + msg_info_tokenizer("custom tokenizer '%s' registered and sorted by priority (total tokenizers: %u)", + name, mgr->detection_order->len); + + msg_info_tokenizer("successfully loaded custom tokenizer '%s' (priority %.1f) from %s", + name, priority, path); + + return TRUE; +} + +struct rspamd_custom_tokenizer * +rspamd_tokenizer_manager_detect(struct rspamd_tokenizer_manager *mgr, + const char *text, size_t len, + double *confidence, + const char *lang_hint, + const char **detected_lang_hint) +{ + struct rspamd_custom_tokenizer *tok, *best_tok = NULL; + double conf, best_conf = 0.0; + unsigned int i; + + g_assert(mgr != NULL); + g_assert(text != NULL); + + msg_debug_tokenizer("starting tokenizer detection for text of length %zu", len); + + if (confidence) { + *confidence = 0.0; + } + + if (detected_lang_hint) { + *detected_lang_hint = NULL; + } + + /* If we have a language hint, try to find a tokenizer for that language first */ + if (lang_hint) { + msg_info_tokenizer("trying to find tokenizer for language hint: %s", lang_hint); + for (i = 0; i < mgr->detection_order->len; i++) { + tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i); + + if (!tok->enabled || !tok->api->get_language_hint) { + continue; + } + + /* 
Check if this tokenizer handles the hinted language */ + const char *tok_lang = tok->api->get_language_hint(); + if (tok_lang && g_ascii_strcasecmp(tok_lang, lang_hint) == 0) { + msg_info_tokenizer("found tokenizer '%s' for language hint '%s'", tok->name, lang_hint); + /* Found a tokenizer for this language, check if it actually detects it */ + if (tok->api->detect_language) { + conf = tok->api->detect_language(text, len); + msg_info_tokenizer("tokenizer '%s' confidence for hinted language: %.3f (threshold: %.3f)", + tok->name, conf, tok->min_confidence); + if (conf >= tok->min_confidence) { + /* Use this tokenizer */ + msg_info_tokenizer("using tokenizer '%s' for language hint '%s' with confidence %.3f", + tok->name, lang_hint, conf); + if (confidence) { + *confidence = conf; + } + if (detected_lang_hint) { + *detected_lang_hint = tok_lang; + } + return tok; + } + } + } + } + msg_info_tokenizer("no suitable tokenizer found for language hint '%s', falling back to general detection", lang_hint); + } + + /* Try each tokenizer in priority order */ + msg_info_tokenizer("trying %u tokenizers for general detection", mgr->detection_order->len); + for (i = 0; i < mgr->detection_order->len; i++) { + tok = g_array_index(mgr->detection_order, struct rspamd_custom_tokenizer *, i); + + if (!tok->enabled || !tok->api->detect_language) { + msg_debug_tokenizer("skipping tokenizer '%s' (enabled: %s, has detect_language: %s)", + tok->name, tok->enabled ? "yes" : "no", + tok->api->detect_language ? 
"yes" : "no"); + continue; + } + + conf = tok->api->detect_language(text, len); + msg_info_tokenizer("tokenizer '%s' detection confidence: %.3f (threshold: %.3f, current best: %.3f)", + tok->name, conf, tok->min_confidence, best_conf); + + if (conf > best_conf && conf >= tok->min_confidence) { + best_conf = conf; + best_tok = tok; + msg_info_tokenizer("tokenizer '%s' is new best with confidence %.3f", tok->name, best_conf); + + /* Early exit if very confident */ + if (conf >= 0.95) { + msg_info_tokenizer("very high confidence (%.3f >= 0.95), using tokenizer '%s' immediately", + conf, tok->name); + break; + } + } + } + + if (best_tok) { + msg_info_tokenizer("selected tokenizer '%s' with confidence %.3f", best_tok->name, best_conf); + if (confidence) { + *confidence = best_conf; + } + + if (detected_lang_hint && best_tok->api->get_language_hint) { + *detected_lang_hint = best_tok->api->get_language_hint(); + msg_info_tokenizer("detected language hint: %s", *detected_lang_hint); + } + } + else { + msg_info_tokenizer("no suitable tokenizer found during detection"); + } + + return best_tok; +} + +/* Helper function to tokenize with a custom tokenizer handling exceptions */ +rspamd_tokenizer_result_t * +rspamd_custom_tokenizer_tokenize_with_exceptions( + struct rspamd_custom_tokenizer *tokenizer, + const char *text, + gsize len, + GList *exceptions, + rspamd_mempool_t *pool) +{ + rspamd_tokenizer_result_t *words; + rspamd_tokenizer_result_t result; + struct rspamd_process_exception *ex; + GList *cur_ex = exceptions; + gsize pos = 0; + unsigned int i; + int ret; + + /* Allocate result kvec in pool */ + words = rspamd_mempool_alloc(pool, sizeof(*words)); + kv_init(*words); + + /* If no exceptions, tokenize the whole text */ + if (!exceptions) { + kv_init(result); + + ret = tokenizer->api->tokenize(text, len, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result to output */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + 
kv_push(rspamd_word_t, *words, tok); + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + + return words; + } + + /* Process text with exceptions */ + while (pos < len && cur_ex) { + ex = (struct rspamd_process_exception *) cur_ex->data; + + /* Tokenize text before exception */ + if (ex->pos > pos) { + gsize segment_len = ex->pos - pos; + kv_init(result); + + ret = tokenizer->api->tokenize(text + pos, segment_len, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result, adjusting positions for segment offset */ + for (i = 0; i < kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + + /* Adjust pointers to point to the original text */ + gsize offset_in_segment = tok.original.begin - (text + pos); + if (offset_in_segment < segment_len) { + tok.original.begin = text + pos + offset_in_segment; + /* Ensure we don't go past the exception boundary */ + if (tok.original.begin + tok.original.len <= text + ex->pos) { + kv_push(rspamd_word_t, *words, tok); + } + } + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + } + + /* Add exception as a special token */ + rspamd_word_t ex_tok; + memset(&ex_tok, 0, sizeof(ex_tok)); + + if (ex->type == RSPAMD_EXCEPTION_URL) { + ex_tok.original.begin = "!!EX!!"; + ex_tok.original.len = 6; + } + else { + ex_tok.original.begin = text + ex->pos; + ex_tok.original.len = ex->len; + } + ex_tok.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; + kv_push(rspamd_word_t, *words, ex_tok); + + /* Move past exception */ + pos = ex->pos + ex->len; + cur_ex = g_list_next(cur_ex); + } + + /* Process remaining text after last exception */ + if (pos < len) { + kv_init(result); + + ret = tokenizer->api->tokenize(text + pos, len - pos, &result); + if (ret == 0 && result.a) { + /* Copy tokens from result, adjusting positions for segment offset */ + for (i = 0; i < 
kv_size(result); i++) { + rspamd_word_t tok = kv_A(result, i); + + /* Adjust pointers to point to the original text */ + gsize offset_in_segment = tok.original.begin - (text + pos); + if (offset_in_segment < (len - pos)) { + tok.original.begin = text + pos + offset_in_segment; + kv_push(rspamd_word_t, *words, tok); + } + } + + /* Use tokenizer's cleanup function */ + if (tokenizer->api->cleanup_result) { + tokenizer->api->cleanup_result(&result); + } + } + } + + return words; +} diff --git a/src/libstat/tokenizers/tokenizers.c b/src/libstat/tokenizers/tokenizers.c index 0ea1bcfc6..8a9f42992 100644 --- a/src/libstat/tokenizers/tokenizers.c +++ b/src/libstat/tokenizers/tokenizers.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ #include "contrib/mumhash/mum.h" #include "libmime/lang_detection.h" #include "libstemmer.h" +#define RSPAMD_TOKENIZER_INTERNAL +#include "custom_tokenizer.h" #include <unicode/utf8.h> #include <unicode/uchar.h> @@ -35,8 +37,8 @@ #include <math.h> -typedef gboolean (*token_get_function)(rspamd_stat_token_t *buf, char const **pos, - rspamd_stat_token_t *token, +typedef gboolean (*token_get_function)(rspamd_word_t *buf, char const **pos, + rspamd_word_t *token, GList **exceptions, gsize *rl, gboolean check_signature); const char t_delimiters[256] = { @@ -69,8 +71,8 @@ const char t_delimiters[256] = { /* Get next word from specified f_str_t buf */ static gboolean -rspamd_tokenizer_get_word_raw(rspamd_stat_token_t *buf, - char const **cur, rspamd_stat_token_t *token, +rspamd_tokenizer_get_word_raw(rspamd_word_t *buf, + char const **cur, rspamd_word_t *token, GList **exceptions, gsize *rl, gboolean unused) { gsize remain, pos; @@ -164,7 +166,7 @@ rspamd_tokenize_check_limit(gboolean decay, unsigned int nwords, uint64_t *hv, uint64_t *prob, - const 
rspamd_stat_token_t *token, + const rspamd_word_t *token, gssize remain, gssize total) { @@ -242,9 +244,9 @@ rspamd_utf_word_valid(const unsigned char *text, const unsigned char *end, } while (0) static inline void -rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) +rspamd_tokenize_exception(struct rspamd_process_exception *ex, rspamd_words_t *res) { - rspamd_stat_token_t token; + rspamd_word_t token; memset(&token, 0, sizeof(token)); @@ -253,7 +255,7 @@ rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) token.original.len = sizeof("!!EX!!") - 1; token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, exception_error); token.flags = 0; } else if (ex->type == RSPAMD_EXCEPTION_URL) { @@ -271,28 +273,33 @@ rspamd_tokenize_exception(struct rspamd_process_exception *ex, GArray *res) } token.flags = RSPAMD_STAT_TOKEN_FLAG_EXCEPTION; - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, exception_error); token.flags = 0; } + return; + +exception_error: + /* On error, just skip this exception token */ + return; } -GArray * +rspamd_words_t * rspamd_tokenize_text(const char *text, gsize len, const UText *utxt, enum rspamd_tokenize_type how, struct rspamd_config *cfg, GList *exceptions, uint64_t *hash, - GArray *cur_words, + rspamd_words_t *output_kvec, rspamd_mempool_t *pool) { - rspamd_stat_token_t token, buf; + rspamd_word_t token, buf; const char *pos = NULL; gsize l = 0; - GArray *res; + rspamd_words_t *res; GList *cur = exceptions; - unsigned int min_len = 0, max_len = 0, word_decay = 0, initial_size = 128; + unsigned int min_len = 0, max_len = 0, word_decay = 0; uint64_t hv = 0; gboolean decay = FALSE, long_text_mode = FALSE; uint64_t prob = 0; @@ -300,9 +307,12 @@ rspamd_tokenize_text(const char *text, gsize len, static const gsize long_text_limit = 1 * 1024 * 1024; static const ev_tstamp max_exec_time = 0.2; /* 200 ms */ 
ev_tstamp start; + struct rspamd_custom_tokenizer *custom_tok = NULL; + double custom_confidence = 0.0; + const char *detected_lang = NULL; if (text == NULL) { - return cur_words; + return output_kvec; } if (len > long_text_limit) { @@ -323,15 +333,59 @@ rspamd_tokenize_text(const char *text, gsize len, min_len = cfg->min_word_len; max_len = cfg->max_word_len; word_decay = cfg->words_decay; - initial_size = word_decay * 2; } - if (!cur_words) { - res = g_array_sized_new(FALSE, FALSE, sizeof(rspamd_stat_token_t), - initial_size); + if (!output_kvec) { + res = pool ? rspamd_mempool_alloc0(pool, sizeof(*res)) : g_malloc0(sizeof(*res)); + ; } else { - res = cur_words; + res = output_kvec; + } + + /* Try custom tokenizers first if we're in UTF mode */ + if (cfg && cfg->tokenizer_manager && how == RSPAMD_TOKENIZE_UTF && utxt != NULL) { + custom_tok = rspamd_tokenizer_manager_detect( + cfg->tokenizer_manager, + text, len, + &custom_confidence, + NULL, /* no input language hint */ + &detected_lang); + + if (custom_tok && custom_confidence >= custom_tok->min_confidence) { + /* Use custom tokenizer with exception handling */ + rspamd_tokenizer_result_t *custom_res = rspamd_custom_tokenizer_tokenize_with_exceptions( + custom_tok, text, len, exceptions, pool); + + if (custom_res) { + msg_debug_pool("using custom tokenizer %s (confidence: %.2f) for text tokenization", + custom_tok->name, custom_confidence); + + /* Copy custom tokenizer results to output kvec */ + for (unsigned int i = 0; i < kv_size(*custom_res); i++) { + kv_push_safe(rspamd_word_t, *res, kv_A(*custom_res, i), custom_tokenizer_error); + } + + /* Calculate hash if needed */ + if (hash && kv_size(*res) > 0) { + for (unsigned int i = 0; i < kv_size(*res); i++) { + rspamd_word_t *t = &kv_A(*res, i); + if (t->original.len >= sizeof(uint64_t)) { + uint64_t tmp; + memcpy(&tmp, t->original.begin, sizeof(tmp)); + hv = mum_hash_step(hv, tmp); + } + } + *hash = mum_hash_finish(hv); + } + + return res; + } + else { + 
msg_warn_pool("custom tokenizer %s failed to tokenize text, falling back to default", + custom_tok->name); + } + } } if (G_UNLIKELY(how == RSPAMD_TOKENIZE_RAW || utxt == NULL)) { @@ -343,7 +397,7 @@ rspamd_tokenize_text(const char *text, gsize len, } if (token.original.len > 0 && - rspamd_tokenize_check_limit(decay, word_decay, res->len, + rspamd_tokenize_check_limit(decay, word_decay, kv_size(*res), &hv, &prob, &token, pos - text, len)) { if (!decay) { decay = TRUE; @@ -355,28 +409,28 @@ rspamd_tokenize_text(const char *text, gsize len, } if (long_text_mode) { - if ((res->len + 1) % 16 == 0) { + if ((kv_size(*res) + 1) % 16 == 0) { ev_tstamp now = ev_time(); if (now - start > max_exec_time) { msg_warn_pool_check( "too long time has been spent on tokenization:" - " %.1f ms, limit is %.1f ms; %d words added so far", + " %.1f ms, limit is %.1f ms; %z words added so far", (now - start) * 1e3, max_exec_time * 1e3, - res->len); + kv_size(*res)); goto end; } } } - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, tokenize_error); - if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + if (kv_size(*res) * sizeof(token) > (0x1ull << 30u)) { /* Due to bug in glib ! */ msg_err_pool_check( - "too many words found: %d, stop tokenization to avoid DoS", - res->len); + "too many words found: %z, stop tokenization to avoid DoS", + kv_size(*res)); goto end; } @@ -523,7 +577,7 @@ rspamd_tokenize_text(const char *text, gsize len, } if (token.original.len > 0 && - rspamd_tokenize_check_limit(decay, word_decay, res->len, + rspamd_tokenize_check_limit(decay, word_decay, kv_size(*res), &hv, &prob, &token, p, len)) { if (!decay) { decay = TRUE; @@ -536,15 +590,15 @@ rspamd_tokenize_text(const char *text, gsize len, if (token.original.len > 0) { /* Additional check for number of words */ - if (((gsize) res->len) * sizeof(token) > (0x1ull << 30u)) { + if (kv_size(*res) * sizeof(token) > (0x1ull << 30u)) { /* Due to bug in glib ! 
*/ - msg_err("too many words found: %d, stop tokenization to avoid DoS", - res->len); + msg_err("too many words found: %z, stop tokenization to avoid DoS", + kv_size(*res)); goto end; } - g_array_append_val(res, token); + kv_push_safe(rspamd_word_t, *res, token, tokenize_error); } /* Also check for long text mode */ @@ -552,15 +606,15 @@ rspamd_tokenize_text(const char *text, gsize len, /* Check time each 128 words added */ const int words_check_mask = 0x7F; - if ((res->len & words_check_mask) == words_check_mask) { + if ((kv_size(*res) & words_check_mask) == words_check_mask) { ev_tstamp now = ev_time(); if (now - start > max_exec_time) { msg_warn_pool_check( "too long time has been spent on tokenization:" - " %.1f ms, limit is %.1f ms; %d words added so far", + " %.1f ms, limit is %.1f ms; %z words added so far", (now - start) * 1e3, max_exec_time * 1e3, - res->len); + kv_size(*res)); goto end; } @@ -590,8 +644,14 @@ end: } return res; + +tokenize_error: +custom_tokenizer_error: + msg_err_pool("failed to allocate memory for tokenization"); + return res; } + #undef SHIFT_EX static void @@ -625,32 +685,38 @@ rspamd_add_metawords_from_str(const char *beg, gsize len, #endif } + /* Initialize meta_words kvec if not already done */ + if (!task->meta_words.a) { + kv_init(task->meta_words); + } + if (valid_utf) { utext_openUTF8(&utxt, beg, len, &uc_err); - task->meta_words = rspamd_tokenize_text(beg, len, - &utxt, RSPAMD_TOKENIZE_UTF, - task->cfg, NULL, NULL, - task->meta_words, - task->task_pool); + rspamd_tokenize_text(beg, len, + &utxt, RSPAMD_TOKENIZE_UTF, + task->cfg, NULL, NULL, + &task->meta_words, + task->task_pool); utext_close(&utxt); } else { - task->meta_words = rspamd_tokenize_text(beg, len, - NULL, RSPAMD_TOKENIZE_RAW, - task->cfg, NULL, NULL, task->meta_words, - task->task_pool); + rspamd_tokenize_text(beg, len, + NULL, RSPAMD_TOKENIZE_RAW, + task->cfg, NULL, NULL, + &task->meta_words, + task->task_pool); } } void rspamd_tokenize_meta_words(struct 
rspamd_task *task) { unsigned int i = 0; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; if (MESSAGE_FIELD(task, subject)) { rspamd_add_metawords_from_str(MESSAGE_FIELD(task, subject), @@ -667,7 +733,7 @@ void rspamd_tokenize_meta_words(struct rspamd_task *task) } } - if (task->meta_words != NULL) { + if (task->meta_words.a) { const char *language = NULL; if (MESSAGE_FIELD(task, text_parts) && @@ -680,12 +746,12 @@ void rspamd_tokenize_meta_words(struct rspamd_task *task) } } - rspamd_normalize_words(task->meta_words, task->task_pool); - rspamd_stem_words(task->meta_words, task->task_pool, language, + rspamd_normalize_words(&task->meta_words, task->task_pool); + rspamd_stem_words(&task->meta_words, task->task_pool, language, task->lang_det); - for (i = 0; i < task->meta_words->len; i++) { - tok = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(task->meta_words); i++) { + tok = &kv_A(task->meta_words, i); tok->flags |= RSPAMD_STAT_TOKEN_FLAG_HEADER; } } @@ -759,7 +825,7 @@ rspamd_ucs32_to_normalised(rspamd_stat_token_t *tok, tok->normalized.begin = dest; } -void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool) +void rspamd_normalize_single_word(rspamd_word_t *tok, rspamd_mempool_t *pool) { UErrorCode uc_err = U_ZERO_ERROR; UConverter *utf8_converter; @@ -858,25 +924,27 @@ void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *po } } -void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool) + +void rspamd_normalize_words(rspamd_words_t *words, rspamd_mempool_t *pool) { - rspamd_stat_token_t *tok; + rspamd_word_t *tok; unsigned int i; - for (i = 0; i < words->len; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); rspamd_normalize_single_word(tok, pool); } } -void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, + +void rspamd_stem_words(rspamd_words_t *words, rspamd_mempool_t *pool, 
const char *language, struct rspamd_lang_detector *lang_detector) { static GHashTable *stemmers = NULL; struct sb_stemmer *stem = NULL; unsigned int i; - rspamd_stat_token_t *tok; + rspamd_word_t *tok; char *dest; gsize dlen; @@ -909,8 +977,18 @@ void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, stem = NULL; } } - for (i = 0; i < words->len; i++) { - tok = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); + + /* Skip stemming if token has already been stemmed by custom tokenizer */ + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_STEMMED) { + /* Already stemmed, just check for stop words */ + if (tok->stemmed.len > 0 && lang_detector != NULL && + rspamd_language_detector_is_stop_word(lang_detector, tok->stemmed.begin, tok->stemmed.len)) { + tok->flags |= RSPAMD_STAT_TOKEN_FLAG_STOP_WORD; + } + continue; + } if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { if (stem) { @@ -952,4 +1030,4 @@ void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, } } } -}
\ No newline at end of file +} diff --git a/src/libstat/tokenizers/tokenizers.h b/src/libstat/tokenizers/tokenizers.h index d4a8824a8..bb0bb54e2 100644 --- a/src/libstat/tokenizers/tokenizers.h +++ b/src/libstat/tokenizers/tokenizers.h @@ -1,5 +1,5 @@ /* - * Copyright 2023 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include "fstring.h" #include "rspamd.h" #include "stat_api.h" +#include "libserver/word.h" #include <unicode/utext.h> @@ -43,7 +44,7 @@ struct rspamd_stat_tokenizer { int (*tokenize_func)(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result); @@ -59,20 +60,20 @@ enum rspamd_tokenize_type { int token_node_compare_func(gconstpointer a, gconstpointer b); -/* Tokenize text into array of words (rspamd_stat_token_t type) */ -GArray *rspamd_tokenize_text(const char *text, gsize len, - const UText *utxt, - enum rspamd_tokenize_type how, - struct rspamd_config *cfg, - GList *exceptions, - uint64_t *hash, - GArray *cur_words, - rspamd_mempool_t *pool); +/* Tokenize text into kvec of words (rspamd_word_t type) */ +rspamd_words_t *rspamd_tokenize_text(const char *text, gsize len, + const UText *utxt, + enum rspamd_tokenize_type how, + struct rspamd_config *cfg, + GList *exceptions, + uint64_t *hash, + rspamd_words_t *output_kvec, + rspamd_mempool_t *pool); /* OSB tokenize function */ int rspamd_tokenizer_osb(struct rspamd_stat_ctx *ctx, struct rspamd_task *task, - GArray *words, + rspamd_words_t *words, gboolean is_utf, const char *prefix, GPtrArray *result); @@ -83,11 +84,11 @@ gpointer rspamd_tokenizer_osb_get_config(rspamd_mempool_t *pool, struct rspamd_lang_detector; -void rspamd_normalize_single_word(rspamd_stat_token_t *tok, rspamd_mempool_t *pool); +void rspamd_normalize_single_word(rspamd_word_t 
*tok, rspamd_mempool_t *pool); -void rspamd_normalize_words(GArray *words, rspamd_mempool_t *pool); - -void rspamd_stem_words(GArray *words, rspamd_mempool_t *pool, +/* Word processing functions */ +void rspamd_normalize_words(rspamd_words_t *words, rspamd_mempool_t *pool); +void rspamd_stem_words(rspamd_words_t *words, rspamd_mempool_t *pool, const char *language, struct rspamd_lang_detector *lang_detector); diff --git a/src/libutil/fstring.h b/src/libutil/fstring.h index 0792ab9fa..ca9f689c8 100644 --- a/src/libutil/fstring.h +++ b/src/libutil/fstring.h @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -30,8 +30,8 @@ extern "C" { */ typedef struct f_str_s { - gsize len; - gsize allocated; + size_t len; + size_t allocated; char str[]; } rspamd_fstring_t; @@ -40,12 +40,12 @@ typedef struct f_str_s { #define RSPAMD_FSTRING_LIT(lit) rspamd_fstring_new_init((lit), sizeof(lit) - 1) typedef struct f_str_tok { - gsize len; + size_t len; const char *begin; } rspamd_ftok_t; typedef struct f_str_unicode_tok { - gsize len; /* in UChar32 */ + size_t len; /* in UChar32 */ const UChar32 *begin; } rspamd_ftok_unicode_t; diff --git a/src/libutil/mem_pool.c b/src/libutil/mem_pool.c index 3dc67bc5f..575b4e497 100644 --- a/src/libutil/mem_pool.c +++ b/src/libutil/mem_pool.c @@ -403,9 +403,10 @@ rspamd_mempool_new_(gsize size, const char *tag, int flags, const char *loc) /* Generate new uid */ uint64_t uid = rspamd_random_uint64_fast(); - rspamd_encode_hex_buf((unsigned char *) &uid, sizeof(uid), - new_pool->tag.uid, 
sizeof(new_pool->tag.uid) - 1); - new_pool->tag.uid[sizeof(new_pool->tag.uid) - 1] = '\0'; + G_STATIC_ASSERT(sizeof(new_pool->tag.uid) >= sizeof(uid) * 2 + 1); + int enc_len = rspamd_encode_hex_buf((unsigned char *) &uid, sizeof(uid), + new_pool->tag.uid, sizeof(new_pool->tag.uid) - 1); + new_pool->tag.uid[enc_len] = '\0'; mem_pool_stat->pools_allocated++; diff --git a/src/libutil/mem_pool.h b/src/libutil/mem_pool.h index 651b44661..00d1a2067 100644 --- a/src/libutil/mem_pool.h +++ b/src/libutil/mem_pool.h @@ -71,7 +71,7 @@ struct f_str_s; #endif #define MEMPOOL_TAG_LEN 16 -#define MEMPOOL_UID_LEN 16 +#define MEMPOOL_UID_LEN 32 /* All pointers are aligned as this variable */ #define MIN_MEM_ALIGNMENT G_MEM_ALIGN diff --git a/src/libutil/radix.c b/src/libutil/radix.c index 2cae8e34a..bdd722b49 100644 --- a/src/libutil/radix.c +++ b/src/libutil/radix.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -66,7 +66,7 @@ radix_find_compressed(radix_compressed_t *tree, const uint8_t *key, gsize keylen uintptr_t radix_insert_compressed(radix_compressed_t *tree, - uint8_t *key, gsize keylen, + const uint8_t *key, gsize keylen, gsize masklen, uintptr_t value) { @@ -128,6 +128,39 @@ radix_insert_compressed(radix_compressed_t *tree, return old; } +uintptr_t +radix_insert_compressed_addr(radix_compressed_t *tree, + const rspamd_inet_addr_t *addr, + uintptr_t value) +{ + const unsigned char *key; + unsigned int klen = 0; + unsigned char buf[16]; + + if (addr == NULL) { + return RADIX_NO_VALUE; + } + + key = rspamd_inet_address_get_hash_key(addr, &klen); + + if (key && klen) { + if (klen == 4) { + /* Map to ipv6 */ + memset(buf, 0, 10); + buf[10] = 0xffu; + buf[11] = 0xffu; + memcpy(buf + 12, key, klen); + + key = buf; + klen = sizeof(buf); + } + + return radix_insert_compressed(tree, key, klen, 0, value); + } + + return RADIX_NO_VALUE; +} + radix_compressed_t * radix_create_compressed(const char *tree_name) diff --git a/src/libutil/radix.h b/src/libutil/radix.h index c4fe96441..8c1224707 100644 --- a/src/libutil/radix.h +++ b/src/libutil/radix.h @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ #include "mem_pool.h" #include "util.h" -#define RADIX_NO_VALUE (uintptr_t) - 1 +#define RADIX_NO_VALUE (uintptr_t) -1 #ifdef __cplusplus extern "C" { @@ -39,11 +39,23 @@ typedef struct radix_tree_compressed radix_compressed_t; */ uintptr_t radix_insert_compressed(radix_compressed_t *tree, - uint8_t *key, gsize keylen, + const uint8_t *key, gsize keylen, gsize masklen, uintptr_t value); /** + * Insert new address to the radix trie (works for IPv4 or IPv6 addresses) + * @param tree radix trie + * @param addr address to insert + * @param value opaque value pointer + * @return previous value of the key or `RADIX_NO_VALUE` + */ +uintptr_t +radix_insert_compressed_addr(radix_compressed_t *tree, + const rspamd_inet_addr_t *addr, + uintptr_t value); + +/** * Find a key in a radix trie * @param tree radix trie * @param key key to find (bitstring) diff --git a/src/libutil/shingles.c b/src/libutil/shingles.c index 5fe110eb8..c69c42292 100644 --- a/src/libutil/shingles.c +++ b/src/libutil/shingles.c @@ -18,6 +18,7 @@ #include "cryptobox.h" #include "images.h" #include "libstat/stat_api.h" +#include "libserver/word.h" #define SHINGLES_WINDOW 3 #define SHINGLES_KEY_SIZE rspamd_cryptobox_SIPKEYBYTES @@ -112,7 +113,7 @@ rspamd_shingles_get_keys_cached(const unsigned char key[SHINGLES_KEY_SIZE]) } struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") - rspamd_shingles_from_text(GArray *input, + rspamd_shingles_from_text(rspamd_words_t *input, const unsigned char key[16], rspamd_mempool_t *pool, rspamd_shingles_filter filter, @@ -123,12 +124,16 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") uint64_t **hashes; unsigned char **keys; rspamd_fstring_t *row; - rspamd_stat_token_t *word; + rspamd_word_t *word; uint64_t val; int i, j, k; gsize hlen, ilen = 0, beg = 0, widx = 0; enum rspamd_cryptobox_fast_hash_type ht; + if (!input || !input->a) { + return NULL; + } + if (pool != NULL) { res = rspamd_mempool_alloc(pool, sizeof(*res)); } @@ -138,10 +143,10 @@ 
struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") row = rspamd_fstring_sized_new(256); - for (i = 0; i < input->len; i++) { - word = &g_array_index(input, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*input); i++) { + word = &kv_A(*input, i); - if (!((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0)) { + if (!((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0)) { ilen++; } } @@ -162,10 +167,10 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") for (j = beg; j < i; j++) { word = NULL; - while (widx < input->len) { - word = &g_array_index(input, rspamd_stat_token_t, widx); + while (widx < kv_size(*input)) { + word = &kv_A(*input, widx); - if ((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0) { + if ((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0) { widx++; } else { @@ -237,10 +242,10 @@ struct rspamd_shingle *RSPAMD_OPTIMIZE("unroll-loops") word = NULL; - while (widx < input->len) { - word = &g_array_index(input, rspamd_stat_token_t, widx); + while (widx < kv_size(*input)) { + word = &kv_A(*input, widx); - if ((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0) { + if ((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0) { widx++; } else { diff --git a/src/libutil/shingles.h b/src/libutil/shingles.h index fe6f16cf8..1ab2c6842 100644 --- a/src/libutil/shingles.h +++ b/src/libutil/shingles.h @@ -18,6 +18,7 @@ #include "config.h" #include "mem_pool.h" +#include "libserver/word.h" #define RSPAMD_SHINGLE_SIZE 32 @@ -48,14 +49,14 @@ typedef uint64_t (*rspamd_shingles_filter)(uint64_t *input, gsize count, /** * Generate shingles from the input of fixed size strings using lemmatizer * if needed - * @param input array of `rspamd_fstring_t` + * @param input kvec of `rspamd_word_t` * @param key secret key used to generate shingles * @param pool pool to allocate shingles array * @param filter hashes filtering function * @param filterd opaque 
data for filtering function * @return shingles array */ -struct rspamd_shingle *rspamd_shingles_from_text(GArray *input, +struct rspamd_shingle *rspamd_shingles_from_text(rspamd_words_t *input, const unsigned char key[16], rspamd_mempool_t *pool, rspamd_shingles_filter filter, diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c index 3a0f1a06c..f36228680 100644 --- a/src/lua/lua_common.c +++ b/src/lua/lua_common.c @@ -2401,7 +2401,7 @@ rspamd_lua_try_load_redis(lua_State *L, const ucl_object_t *obj, return FALSE; } -void rspamd_lua_push_full_word(lua_State *L, rspamd_stat_token_t *w) +void rspamd_lua_push_full_word(lua_State *L, rspamd_word_t *w) { int fl_cnt; @@ -2521,6 +2521,54 @@ int rspamd_lua_push_words(lua_State *L, GArray *words, return 1; } +int rspamd_lua_push_words_kvec(lua_State *L, rspamd_words_t *words, + enum rspamd_lua_words_type how) +{ + rspamd_word_t *w; + unsigned int i, cnt; + + if (!words || !words->a) { + lua_createtable(L, 0, 0); + return 1; + } + + lua_createtable(L, kv_size(*words), 0); + + for (i = 0, cnt = 1; i < kv_size(*words); i++) { + w = &kv_A(*words, i); + + switch (how) { + case RSPAMD_LUA_WORDS_STEM: + if (w->stemmed.len > 0) { + lua_pushlstring(L, w->stemmed.begin, w->stemmed.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_NORM: + if (w->normalized.len > 0) { + lua_pushlstring(L, w->normalized.begin, w->normalized.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_RAW: + if (w->original.len > 0) { + lua_pushlstring(L, w->original.begin, w->original.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_FULL: + rspamd_lua_push_full_word(L, w); + /* Push to the resulting vector */ + lua_rawseti(L, -2, cnt++); + break; + default: + break; + } + } + + return 1; +} + char * rspamd_lua_get_module_name(lua_State *L) { @@ -2658,4 +2706,4 @@ int rspamd_lua_geti(lua_State *L, int pos, int i) return lua_type(L, -1); } -#endif
\ No newline at end of file +#endif diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h index a29444394..d494f0923 100644 --- a/src/lua/lua_common.h +++ b/src/lua/lua_common.h @@ -538,9 +538,8 @@ enum lua_logger_escape_type { * @param len * @return */ -gsize lua_logger_out_type(lua_State *L, int pos, char *outbuf, - gsize len, struct lua_logger_trace *trace, - enum lua_logger_escape_type esc_type); +gsize lua_logger_out(lua_State *L, int pos, char *outbuf, gsize len, + enum lua_logger_escape_type esc_type); /** * Safely checks userdata to match specified class @@ -633,7 +632,7 @@ struct rspamd_stat_token_s; * @param L * @param word */ -void rspamd_lua_push_full_word(lua_State *L, struct rspamd_stat_token_s *word); +void rspamd_lua_push_full_word(lua_State *L, rspamd_word_t *word); enum rspamd_lua_words_type { RSPAMD_LUA_WORDS_STEM = 0, @@ -652,6 +651,9 @@ enum rspamd_lua_words_type { int rspamd_lua_push_words(lua_State *L, GArray *words, enum rspamd_lua_words_type how); +int rspamd_lua_push_words_kvec(lua_State *L, rspamd_words_t *words, + enum rspamd_lua_words_type how); + /** * Returns newly allocated name for caller module name * @param L diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c index 07ed58ad5..7b3a156cd 100644 --- a/src/lua/lua_config.c +++ b/src/lua/lua_config.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,10 @@ #include "utlist.h" #include <math.h> +/* Forward declarations for custom tokenizer functions */ +gboolean rspamd_config_load_custom_tokenizers(struct rspamd_config *cfg, GError **err); +void rspamd_config_unload_custom_tokenizers(struct rspamd_config *cfg); + /*** * This module is used to configure rspamd and is normally available as global * variable named `rspamd_config`. 
Unlike other modules, it is not necessary to @@ -118,7 +122,7 @@ local function foo(task) end */ /*** -* @method rspamd_config:radix_from_ucl(obj) +* @method rspamd_config:radix_from_ucl(obj, description) * Creates new embedded map of IP/mask addresses from object. * @param {ucl} obj object * @return {map} radix tree object @@ -862,6 +866,19 @@ LUA_FUNCTION_DEF(config, get_dns_max_requests); */ LUA_FUNCTION_DEF(config, get_dns_timeout); +/*** + * @method rspamd_config:load_custom_tokenizers() + * Loads custom tokenizers from configuration + * @return {boolean} true if successful + */ +LUA_FUNCTION_DEF(config, load_custom_tokenizers); + +/*** + * @method rspamd_config:unload_custom_tokenizers() + * Unloads custom tokenizers and frees memory + */ +LUA_FUNCTION_DEF(config, unload_custom_tokenizers); + static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF(config, get_module_opt), LUA_INTERFACE_DEF(config, get_mempool), @@ -937,6 +954,8 @@ static const struct luaL_reg configlib_m[] = { LUA_INTERFACE_DEF(config, get_tld_path), LUA_INTERFACE_DEF(config, get_dns_max_requests), LUA_INTERFACE_DEF(config, get_dns_timeout), + LUA_INTERFACE_DEF(config, load_custom_tokenizers), + LUA_INTERFACE_DEF(config, unload_custom_tokenizers), {"__tostring", rspamd_lua_class_tostring}, {"__newindex", lua_config_newindex}, {NULL, NULL}}; @@ -4485,11 +4504,14 @@ lua_config_init_subsystem(lua_State *L) nparts = g_strv_length(parts); for (i = 0; i < nparts; i++) { - if (strcmp(parts[i], "filters") == 0) { + const char *str = parts[i]; + + /* TODO: total shit, rework some day */ + if (strcmp(str, "filters") == 0) { rspamd_lua_post_load_config(cfg); rspamd_init_filters(cfg, false, false); } - else if (strcmp(parts[i], "langdet") == 0) { + else if (strcmp(str, "langdet") == 0) { if (!cfg->lang_det) { cfg->lang_det = rspamd_language_detector_init(cfg); rspamd_mempool_add_destructor(cfg->cfg_pool, @@ -4497,10 +4519,10 @@ lua_config_init_subsystem(lua_State *L) cfg->lang_det); } } - else 
if (strcmp(parts[i], "stat") == 0) { + else if (strcmp(str, "stat") == 0) { rspamd_stat_init(cfg, NULL); } - else if (strcmp(parts[i], "dns") == 0) { + else if (strcmp(str, "dns") == 0) { struct ev_loop *ev_base = lua_check_ev_base(L, 3); if (ev_base) { @@ -4514,11 +4536,25 @@ lua_config_init_subsystem(lua_State *L) return luaL_error(L, "no event base specified"); } } - else if (strcmp(parts[i], "symcache") == 0) { + else if (strcmp(str, "symcache") == 0) { rspamd_symcache_init(cfg->cache); } + else if (strcmp(str, "tokenizers") == 0 || strcmp(str, "custom_tokenizers") == 0) { + GError *err = NULL; + if (!rspamd_config_load_custom_tokenizers(cfg, &err)) { + g_strfreev(parts); + if (err) { + int ret = luaL_error(L, "failed to load custom tokenizers: %s", err->message); + g_error_free(err); + return ret; + } + else { + return luaL_error(L, "failed to load custom tokenizers"); + } + } + } else { - int ret = luaL_error(L, "invalid param: %s", parts[i]); + int ret = luaL_error(L, "invalid param: %s", str); g_strfreev(parts); return ret; @@ -4772,3 +4808,43 @@ void lua_call_finish_script(struct rspamd_config_cfg_lua_script *sc, lua_thread_call(thread, 1); } + +static int +lua_config_load_custom_tokenizers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + GError *err = NULL; + gboolean ret = rspamd_config_load_custom_tokenizers(cfg, &err); + + if (!ret && err) { + lua_pushboolean(L, FALSE); + lua_pushstring(L, err->message); + g_error_free(err); + return 2; + } + + lua_pushboolean(L, ret); + return 1; + } + else { + return luaL_error(L, "invalid arguments"); + } +} + +static int +lua_config_unload_custom_tokenizers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + rspamd_config_unload_custom_tokenizers(cfg); + return 0; + } + else { + return luaL_error(L, "invalid arguments"); + } +} diff --git a/src/lua/lua_cryptobox.c 
b/src/lua/lua_cryptobox.c index 721d71256..2c2254920 100644 --- a/src/lua/lua_cryptobox.c +++ b/src/lua/lua_cryptobox.c @@ -404,7 +404,7 @@ lua_cryptobox_keypair_load(lua_State *L) if (lua_type(L, 1) == LUA_TSTRING) { buf = luaL_checklstring(L, 1, &len); if (buf != NULL) { - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, buf, len)) { msg_err("cannot open keypair from data: %s", diff --git a/src/lua/lua_http.c b/src/lua/lua_http.c index 7e9e7b1df..731b8b057 100644 --- a/src/lua/lua_http.c +++ b/src/lua/lua_http.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,22 +29,123 @@ * This module hides all complexity: DNS resolving, sessions management, zero-copy * text transfers and so on under the hood. * @example +-- Basic GET request with callback local rspamd_http = require "rspamd_http" local function symbol_callback(task) local function http_callback(err_message, code, body, headers) task:insert_result('SYMBOL', 1) -- task is available via closure + + if err_message then + -- Handle error + return + end + + -- Process response + if code == 200 then + -- Process body and headers + for name, value in pairs(headers) do + -- Headers are lowercase + end + end end - rspamd_http.request({ - task=task, - url='http://example.com/data', - body=task:get_content(), - callback=http_callback, - headers={Header='Value', OtherHeader='Value'}, - mime_type='text/plain', - }) - end + rspamd_http.request({ + task=task, + url='http://example.com/data', + body=task:get_content(), + callback=http_callback, + headers={Header='Value', OtherHeader='Value', DuplicatedHeader={'Multiple', 'Values'}}, + mime_type='text/plain', + }) +end + +-- POST request with JSON body +local function post_json_example(task) + local ucl = require "ucl" + local data = 
{ + id = task:get_queue_id(), + sender = task:get_from()[1].addr + } + + local json_data = ucl.to_json(data) + + rspamd_http.request({ + task = task, + url = "http://example.com/api/submit", + method = "POST", + body = json_data, + headers = {['Content-Type'] = 'application/json'}, + callback = function(err, code, body, headers) + if not err and code == 200 then + -- Success + end + end + }) +end + +-- Synchronous HTTP request (using coroutines) +local function sync_http_example(task) + -- No callback makes this a synchronous call + local err, response = rspamd_http.request({ + task = task, + url = "http://example.com/api/data", + method = "GET", + timeout = 10.0 + }) + + if not err then + -- Response is a table with code, content, and headers + if response.code == 200 then + -- Process response.content + return true + end + end + return false +end + +-- Using authentication +local function auth_example(task) + rspamd_http.request({ + task = task, + url = "https://example.com/api/protected", + method = "GET", + user = "username", + password = "secret", + callback = function(err, code, body, headers) + -- Process authenticated response + end + }) +end + +-- Using HTTPS with SSL options +local function https_example(task) + rspamd_http.request({ + task = task, + url = "https://example.com/api/secure", + method = "GET", + no_ssl_verify = false, -- Verify SSL (default) + callback = function(err, code, body, headers) + -- Process secure response + end + }) +end + +-- Using keep-alive and gzip +local function advanced_example(task) + rspamd_http.request({ + task = task, + url = "http://example.com/api/data", + method = "POST", + body = task:get_content(), + gzip = true, -- Compress request body + keepalive = true, -- Use keep-alive connection + max_size = 1024 * 1024, -- Limit response to 1MB + callback = function(err, code, body, headers) + -- Process response + end + }) +end */ #define MAX_HEADERS_SIZE 8192 @@ -602,7 +703,7 @@ lua_http_push_headers(lua_State *L, struct 
rspamd_http_message *msg) * @param {string} url specifies URL for a request in the standard URI form (e.g. 'http://example.com/path') * @param {function} callback specifies callback function in format `function (err_message, code, body, headers)` that is called on HTTP request completion. if this parameter is missing, the function performs "pseudo-synchronous" call (see [Synchronous and Asynchronous API overview](/doc/developers/sync_async.html#API-example-http-module) * @param {task} task if called from symbol handler it is generally a good idea to use the common task objects: event base, DNS resolver and events session - * @param {table} headers optional headers in form `[name='value', name='value']` + * @param {table} headers optional headers in form `[name='value']` or `[name=['value1', 'value2']]` to duplicate a header with multiple values * @param {string} mime_type MIME type of the HTTP content (for example, `text/html`) * @param {string/text} body full body content, can be opaque `rspamd{text}` to avoid data copying * @param {number} timeout floating point request timeout value in seconds (default is 5.0 seconds) diff --git a/src/lua/lua_logger.c b/src/lua/lua_logger.c index 004b82e72..04ff81b6d 100644 --- a/src/lua/lua_logger.c +++ b/src/lua/lua_logger.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -174,6 +174,11 @@ static const struct luaL_reg loggerlib_f[] = { {"__tostring", rspamd_lua_class_tostring}, {NULL, NULL}}; +static gsize +lua_logger_out_type(lua_State *L, int pos, char *outbuf, + gsize len, struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type); + static void lua_common_log_line(GLogLevelFlags level, lua_State *L, @@ -203,23 +208,19 @@ lua_common_log_line(GLogLevelFlags level, d.currentline); } - rspamd_common_log_function(NULL, - level, - module, - uid, - func_buf, - "%s", - msg); + p = func_buf; } else { - rspamd_common_log_function(NULL, - level, - module, - uid, - G_STRFUNC, - "%s", - msg); + p = (char *) G_STRFUNC; } + + rspamd_common_log_function(NULL, + level, + module, + uid, + p, + "%s", + msg); } /*** Logger interface ***/ @@ -279,105 +280,161 @@ lua_logger_char_safe(int t, unsigned int esc_type) return true; } -static gsize -lua_logger_out_str(lua_State *L, int pos, - char *outbuf, gsize len, - struct lua_logger_trace *trace, - enum lua_logger_escape_type esc_type) +#define LUA_MAX_ARGS 32 +/* Gracefully handles argument mismatches by substituting missing args and noting extra args */ +static glong +lua_logger_log_format_str(lua_State *L, int offset, char *logbuf, gsize remain, + const char *fmt, + enum lua_logger_escape_type esc_type) { - gsize slen, flen; - const char *str = lua_tolstring(L, pos, &slen); - static const char hexdigests[16] = "0123456789abcdef"; - gsize r = 0, s; - - if (str) { - gboolean normal = TRUE; - flen = MIN(slen, len - 1); + const char *c; + gsize r; + int digit; + char *d = logbuf; + unsigned int arg_num, cur_arg = 0, arg_max = lua_gettop(L) - offset; + gboolean args_used[LUA_MAX_ARGS]; + unsigned int used_args_count = 0; + + memset(args_used, 0, sizeof(args_used)); + while (remain > 1 && *fmt) { + if (*fmt == '%') { + ++fmt; + c = fmt; + if (*fmt == 's') { + ++fmt; + ++cur_arg; + } + else { + arg_num = 0; + while ((digit = g_ascii_digit_value(*fmt)) >= 0) { + ++fmt; + arg_num = arg_num * 
10 + digit; + if (arg_num >= LUA_MAX_ARGS) { + /* Avoid ridiculously large numbers */ + fmt = c; + break; + } + } - for (r = 0; r < flen; r++) { - if (!lua_logger_char_safe(str[r], esc_type)) { - normal = FALSE; - break; + if (fmt > c) { + /* Update the current argument */ + cur_arg = arg_num; + } } - } - if (normal) { - r = rspamd_strlcpy(outbuf, str, flen + 1); - } - else { - /* Need to escape non-printed characters */ - r = 0; - s = 0; - - while (slen > 0 && len > 1) { - if (!lua_logger_char_safe(str[s], esc_type)) { - if (len >= 3) { - outbuf[r++] = '\\'; - outbuf[r++] = hexdigests[((str[s] >> 4) & 0xF)]; - outbuf[r++] = hexdigests[((str[s]) & 0xF)]; - - len -= 2; - } - else { - outbuf[r++] = '?'; - } + if (fmt > c) { + if (cur_arg < 1 || cur_arg > arg_max) { + /* Missing argument - substitute placeholder */ + r = rspamd_snprintf(d, remain, "<MISSING ARGUMENT>"); } else { - outbuf[r++] = str[s]; + /* Valid argument - output it */ + r = lua_logger_out(L, offset + cur_arg, d, remain, esc_type); + /* Track which arguments are used */ + if (cur_arg <= LUA_MAX_ARGS && !args_used[cur_arg - 1]) { + args_used[cur_arg - 1] = TRUE; + used_args_count++; + } } - s++; - slen--; - len--; + g_assert(r < remain); + remain -= r; + d += r; + continue; } - outbuf[r] = '\0'; + /* Copy % */ + --fmt; } + + *d++ = *fmt++; + --remain; } - return r; + /* Check for extra arguments and append warning if any */ + if (used_args_count > 0 && used_args_count < arg_max && remain > 1) { + unsigned int extra_args = arg_max - used_args_count; + r = rspamd_snprintf(d, remain, " <EXTRA %d ARGUMENTS>", (int) extra_args); + remain -= r; + d += r; + } + + *d = 0; + + return d - logbuf; } +#undef LUA_MAX_ARGS + static gsize -lua_logger_out_num(lua_State *L, int pos, char *outbuf, gsize len, - struct lua_logger_trace *trace) +lua_logger_out_str(lua_State *L, int pos, + char *outbuf, gsize len, + enum lua_logger_escape_type esc_type) { - double num = lua_tonumber(L, pos); - glong inum; - gsize r = 0; + 
static const char hexdigests[16] = "0123456789abcdef"; + gsize slen; + const unsigned char *str = lua_tolstring(L, pos, &slen); + unsigned char c; + char *out = outbuf; - if ((double) (glong) num == num) { - inum = num; - r = rspamd_snprintf(outbuf, len + 1, "%l", inum); + if (str) { + while (slen > 0 && len > 1) { + c = *str++; + if (lua_logger_char_safe(c, esc_type)) { + *out++ = c; + } + else if (len > 3) { + /* Need to escape non-printed characters */ + *out++ = '\\'; + *out++ = hexdigests[c >> 4]; + *out++ = hexdigests[c & 0xF]; + len -= 2; + } + else { + *out++ = '?'; + } + --slen; + --len; + } } - else { - r = rspamd_snprintf(outbuf, len + 1, "%f", num); + *out = 0; + + return out - outbuf; +} + +static gsize +lua_logger_out_num(lua_State *L, int pos, char *outbuf, gsize len) +{ + double num = lua_tonumber(L, pos); + glong inum = (glong) num; + + if ((double) inum == num) { + return rspamd_snprintf(outbuf, len, "%l", inum); } - return r; + return rspamd_snprintf(outbuf, len, "%f", num); } static gsize -lua_logger_out_boolean(lua_State *L, int pos, char *outbuf, gsize len, - struct lua_logger_trace *trace) +lua_logger_out_boolean(lua_State *L, int pos, char *outbuf, gsize len) { gboolean val = lua_toboolean(L, pos); - gsize r = 0; - r = rspamd_strlcpy(outbuf, val ? "true" : "false", len + 1); - - return r; + return rspamd_snprintf(outbuf, len, val ? 
"true" : "false"); } static gsize -lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len, - struct lua_logger_trace *trace) +lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len) { - int r = 0, top; + gsize r = 0; + int top; const char *str = NULL; gboolean converted_to_str = FALSE; top = lua_gettop(L); + if (pos < 0) { + pos += top + 1; /* Convert to absolute */ + } if (!lua_getmetatable(L, pos)) { return 0; @@ -396,26 +453,17 @@ lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len, if (lua_isfunction(L, -1)) { lua_pushvalue(L, pos); - if (lua_pcall(L, 1, 1, 0) != 0) { - lua_settop(L, top); - - return 0; - } - - str = lua_tostring(L, -1); - - if (str) { - r = rspamd_snprintf(outbuf, len, "%s", str); + if (lua_pcall(L, 1, 1, 0) == 0) { + str = lua_tostring(L, -1); + if (str) { + r = rspamd_snprintf(outbuf, len, "%s", str); + } } - - lua_settop(L, top); - - return r; } } lua_settop(L, top); - return 0; + return r; } lua_pushstring(L, "__tostring"); @@ -460,12 +508,12 @@ lua_logger_out_userdata(lua_State *L, int pos, char *outbuf, gsize len, return r; } -#define MOVE_BUF(d, remain, r) \ - (d) += (r); \ - (remain) -= (r); \ - if ((remain) == 0) { \ - lua_settop(L, old_top); \ - break; \ +#define MOVE_BUF(d, remain, r) \ + (d) += (r); \ + (remain) -= (r); \ + if ((remain) <= 1) { \ + lua_settop(L, top); \ + goto table_oob; \ } static gsize @@ -473,169 +521,153 @@ lua_logger_out_table(lua_State *L, int pos, char *outbuf, gsize len, struct lua_logger_trace *trace, enum lua_logger_escape_type esc_type) { - char *d = outbuf; - gsize remain = len, r; + char *d = outbuf, *str; + gsize remain = len; + glong r; gboolean first = TRUE; gconstpointer self = NULL; - int i, tpos, last_seq = -1, old_top; + int i, last_seq = 0, top; + double num; + glong inum; - if (!lua_istable(L, pos) || remain == 0) { - return 0; - } + /* Type and length checks are done in logger_out_type() */ - old_top = lua_gettop(L); self = lua_topointer(L, 
pos); /* Check if we have seen this pointer */ for (i = 0; i < TRACE_POINTS; i++) { if (trace->traces[i] == self) { - r = rspamd_snprintf(d, remain + 1, "ref(%p)", self); - - d += r; - - return (d - outbuf); + if ((trace->cur_level + TRACE_POINTS - 1) % TRACE_POINTS == i) { + return rspamd_snprintf(d, remain, "__self"); + } + return rspamd_snprintf(d, remain, "ref(%p)", self); } } trace->traces[trace->cur_level % TRACE_POINTS] = self; + ++trace->cur_level; - lua_pushvalue(L, pos); - r = rspamd_snprintf(d, remain + 1, "{"); - remain -= r; - d += r; + top = lua_gettop(L); + if (pos < 0) { + pos += top + 1; /* Convert to absolute */ + } + + r = rspamd_snprintf(d, remain, "{"); + MOVE_BUF(d, remain, r); /* Get numeric keys (ipairs) */ for (i = 1;; i++) { - lua_rawgeti(L, -1, i); + lua_rawgeti(L, pos, i); if (lua_isnil(L, -1)) { lua_pop(L, 1); + last_seq = i; break; } - last_seq = i; - - if (!first) { - r = rspamd_snprintf(d, remain + 1, ", "); - MOVE_BUF(d, remain, r); - } - - r = rspamd_snprintf(d, remain + 1, "[%d] = ", i); - MOVE_BUF(d, remain, r); - tpos = lua_gettop(L); - - if (lua_topointer(L, tpos) == self) { - r = rspamd_snprintf(d, remain + 1, "__self"); + if (first) { + first = FALSE; + str = "[%d] = "; } else { - r = lua_logger_out_type(L, tpos, d, remain, trace, esc_type); + str = ", [%d] = "; } + r = rspamd_snprintf(d, remain, str, i); + MOVE_BUF(d, remain, r); + + r = lua_logger_out_type(L, -1, d, remain, trace, esc_type); MOVE_BUF(d, remain, r); - first = FALSE; lua_pop(L, 1); } /* Get string keys (pairs) */ - for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + for (lua_pushnil(L); lua_next(L, pos); lua_pop(L, 1)) { /* 'key' is at index -2 and 'value' is at index -1 */ - if (lua_type(L, -2) == LUA_TNUMBER) { - if (last_seq > 0) { - lua_pushvalue(L, -2); - if (lua_tonumber(L, -1) <= last_seq + 1) { - lua_pop(L, 1); + /* Preserve key */ + lua_pushvalue(L, -2); + if (last_seq > 0) { + if (lua_type(L, -1) == LUA_TNUMBER) { + num = lua_tonumber(L, -1); 
/* no conversion here */ + inum = (glong) num; + if ((double) inum == num && inum > 0 && inum < last_seq) { /* Already seen */ + lua_pop(L, 1); continue; } - - lua_pop(L, 1); } } - if (!first) { - r = rspamd_snprintf(d, remain + 1, ", "); - MOVE_BUF(d, remain, r); - } - - /* Preserve key */ - lua_pushvalue(L, -2); - r = rspamd_snprintf(d, remain + 1, "[%s] = ", - lua_tostring(L, -1)); - lua_pop(L, 1); /* Remove key */ - MOVE_BUF(d, remain, r); - tpos = lua_gettop(L); - - if (lua_topointer(L, tpos) == self) { - r = rspamd_snprintf(d, remain + 1, "__self"); + if (first) { + first = FALSE; + str = "[%2] = %1"; } else { - r = lua_logger_out_type(L, tpos, d, remain, trace, esc_type); + str = ", [%2] = %1"; } + r = lua_logger_log_format_str(L, top + 1, d, remain, str, esc_type); + /* lua_logger_log_format_str now handles errors gracefully */ MOVE_BUF(d, remain, r); - first = FALSE; + /* Remove key */ + lua_pop(L, 1); } - lua_settop(L, old_top); - - r = rspamd_snprintf(d, remain + 1, "}"); + r = rspamd_snprintf(d, remain, "}"); d += r; +table_oob: + --trace->cur_level; + return (d - outbuf); } #undef MOVE_BUF -gsize lua_logger_out_type(lua_State *L, int pos, - char *outbuf, gsize len, - struct lua_logger_trace *trace, - enum lua_logger_escape_type esc_type) +static gsize +lua_logger_out_type(lua_State *L, int pos, + char *outbuf, gsize len, + struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type) { - int type; - gsize r = 0; - if (len == 0) { return 0; } - type = lua_type(L, pos); - trace->cur_level++; + int type = lua_type(L, pos); switch (type) { case LUA_TNUMBER: - r = lua_logger_out_num(L, pos, outbuf, len, trace); - break; + return lua_logger_out_num(L, pos, outbuf, len); case LUA_TBOOLEAN: - r = lua_logger_out_boolean(L, pos, outbuf, len, trace); - break; + return lua_logger_out_boolean(L, pos, outbuf, len); case LUA_TTABLE: - r = lua_logger_out_table(L, pos, outbuf, len, trace, esc_type); - break; + return lua_logger_out_table(L, pos, outbuf, len, 
trace, esc_type); case LUA_TUSERDATA: - r = lua_logger_out_userdata(L, pos, outbuf, len, trace); - break; + return lua_logger_out_userdata(L, pos, outbuf, len); case LUA_TFUNCTION: - r = rspamd_snprintf(outbuf, len + 1, "function"); - break; + return rspamd_snprintf(outbuf, len, "function"); case LUA_TLIGHTUSERDATA: - r = rspamd_snprintf(outbuf, len + 1, "0x%p", lua_topointer(L, pos)); - break; + return rspamd_snprintf(outbuf, len, "0x%p", lua_topointer(L, pos)); case LUA_TNIL: - r = rspamd_snprintf(outbuf, len + 1, "nil"); - break; + return rspamd_snprintf(outbuf, len, "nil"); case LUA_TNONE: - r = rspamd_snprintf(outbuf, len + 1, "no value"); - break; - default: - /* Try to push everything as string using tostring magic */ - r = lua_logger_out_str(L, pos, outbuf, len, trace, esc_type); - break; + return rspamd_snprintf(outbuf, len, "no value"); } - trace->cur_level--; + /* Try to push everything as string using tostring magic */ + return lua_logger_out_str(L, pos, outbuf, len, esc_type); +} - return r; +gsize lua_logger_out(lua_State *L, int pos, + char *outbuf, gsize len, + enum lua_logger_escape_type esc_type) +{ + struct lua_logger_trace tr; + memset(&tr, 0, sizeof(tr)); + + return lua_logger_out_type(L, pos, outbuf, len, &tr, esc_type); } static const char * @@ -731,72 +763,13 @@ static gboolean lua_logger_log_format(lua_State *L, int fmt_pos, gboolean is_string, char *logbuf, gsize remain) { - char *d; - const char *s, *c; - gsize r; - unsigned int arg_num, arg_max, cur_arg; - struct lua_logger_trace tr; - int digit; - - s = lua_tostring(L, fmt_pos); - if (s == NULL) { + const char *fmt = lua_tostring(L, fmt_pos); + if (fmt == NULL) { return FALSE; } - arg_max = (unsigned int) lua_gettop(L) - fmt_pos; - d = logbuf; - cur_arg = 0; - - while (remain > 0 && *s) { - if (*s == '%') { - ++s; - c = s; - if (*s == 's') { - ++s; - ++cur_arg; - } else { - arg_num = 0; - while ((digit = g_ascii_digit_value(*s)) >= 0) { - ++s; - arg_num = arg_num * 10 + digit; - if 
(arg_num >= 100) { - /* Avoid ridiculously large numbers */ - s = c; - break; - } - } - - if (s > c) { - /* Update the current argument */ - cur_arg = arg_num; - } - } - - if (s > c) { - if (cur_arg < 1 || cur_arg > arg_max) { - msg_err("wrong argument number: %ud", cur_arg); - return FALSE; - } - - memset(&tr, 0, sizeof(tr)); - r = lua_logger_out_type(L, fmt_pos + cur_arg, d, remain, &tr, - is_string ? LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); - g_assert(r <= remain); - remain -= r; - d += r; - continue; - } - - /* Copy % */ - --s; - } - - *d++ = *s++; - --remain; - } - - *d = '\0'; - + /* lua_logger_log_format_str now handles argument mismatches gracefully */ + lua_logger_log_format_str(L, fmt_pos, logbuf, remain, fmt, is_string ? LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); return TRUE; } @@ -808,15 +781,10 @@ lua_logger_do_log(lua_State *L, { char logbuf[RSPAMD_LOGBUF_SIZE - 128]; const char *uid = NULL; - int fmt_pos = start_pos; int ret; - GError *err = NULL; - if (lua_type(L, start_pos) == LUA_TSTRING) { - fmt_pos = start_pos; - } - else if (lua_type(L, start_pos) == LUA_TUSERDATA) { - fmt_pos = start_pos + 1; + if (lua_type(L, start_pos) == LUA_TUSERDATA) { + GError *err = NULL; uid = lua_logger_get_id(L, start_pos, &err); @@ -830,15 +798,17 @@ lua_logger_do_log(lua_State *L, return ret; } + + ++start_pos; } - else { + + if (lua_type(L, start_pos) != LUA_TSTRING) { /* Bad argument type */ return luaL_error(L, "bad format string type: %s", lua_typename(L, lua_type(L, start_pos))); } - ret = lua_logger_log_format(L, fmt_pos, is_string, - logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, start_pos, is_string, logbuf, sizeof(logbuf)); if (ret) { if (is_string) { @@ -849,12 +819,9 @@ lua_logger_do_log(lua_State *L, lua_common_log_line(level, L, logbuf, uid, "lua", 1); } } - else { - if (is_string) { - lua_pushnil(L); - - return 1; - } + else if (is_string) { + lua_pushnil(L); + return 1; } return 0; @@ -917,11 +884,11 @@ lua_logger_logx(lua_State *L) if 
(uid && modname) { if (lua_type(L, 4) == LUA_TSTRING) { - ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf)); } else if (lua_type(L, 4) == LUA_TNUMBER) { stack_pos = lua_tonumber(L, 4); - ret = lua_logger_log_format(L, 5, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 5, FALSE, logbuf, sizeof(logbuf)); } else { return luaL_error(L, "invalid argument on pos 4"); @@ -959,11 +926,11 @@ lua_logger_debugm(lua_State *L) if (uid && module) { if (lua_type(L, 3) == LUA_TSTRING) { - ret = lua_logger_log_format(L, 3, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 3, FALSE, logbuf, sizeof(logbuf)); } else if (lua_type(L, 3) == LUA_TNUMBER) { stack_pos = lua_tonumber(L, 3); - ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf) - 1); + ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf)); } else { return luaL_error(L, "invalid argument on pos 3"); diff --git a/src/lua/lua_map.c b/src/lua/lua_map.c index 062613bd7..5f55ece06 100644 --- a/src/lua/lua_map.c +++ b/src/lua/lua_map.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -319,6 +319,11 @@ int lua_config_radix_from_ucl(lua_State *L) ucl_object_insert_key(fake_obj, ucl_object_fromstring("static"), "url", 0, false); + if (lua_type(L, 3) == LUA_TSTRING) { + ucl_object_insert_key(fake_obj, ucl_object_fromstring(lua_tostring(L, 3)), + "description", 0, false); + } + if ((m = rspamd_map_add_from_ucl(cfg, fake_obj, "static radix map", rspamd_radix_read, rspamd_radix_fin, diff --git a/src/lua/lua_mimepart.c b/src/lua/lua_mimepart.c index 07dba9c93..982b10d90 100644 --- a/src/lua/lua_mimepart.c +++ b/src/lua/lua_mimepart.c @@ -901,7 +901,7 @@ lua_textpart_get_words_count(lua_State *L) return 1; } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_pushinteger(L, 0); } else { @@ -943,7 +943,7 @@ lua_textpart_get_words(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_createtable(L, 0, 0); } else { @@ -957,7 +957,7 @@ lua_textpart_get_words(lua_State *L) } } - return rspamd_lua_push_words(L, part->utf_words, how); + return rspamd_lua_push_words_kvec(L, &part->utf_words, how); } return 1; @@ -976,7 +976,7 @@ lua_textpart_filter_words(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_createtable(L, 0, 0); } else { @@ -998,9 +998,8 @@ lua_textpart_filter_words(lua_State *L) lua_createtable(L, 8, 0); - for (i = 0, cnt = 1; i < part->utf_words->len; i++) { - rspamd_stat_token_t *w = &g_array_index(part->utf_words, - rspamd_stat_token_t, i); + for (i = 0, cnt = 1; i < kv_size(part->utf_words); i++) { + rspamd_word_t *w = &kv_A(part->utf_words, i); switch (how) { case RSPAMD_LUA_WORDS_STEM: @@ -1194,13 +1193,13 @@ struct lua_shingle_filter_cbdata { rspamd_mempool_t *pool; }; -#define STORE_TOKEN(i, t) \ - do { \ - if ((i) < 
part->utf_words->len) { \ - word = &g_array_index(part->utf_words, rspamd_stat_token_t, (i)); \ - sd->t.begin = word->stemmed.begin; \ - sd->t.len = word->stemmed.len; \ - } \ +#define STORE_TOKEN(i, t) \ + do { \ + if ((i) < kv_size(part->utf_words)) { \ + word = &kv_A(part->utf_words, (i)); \ + sd->t.begin = word->stemmed.begin; \ + sd->t.len = word->stemmed.len; \ + } \ } while (0) static uint64_t @@ -1210,7 +1209,7 @@ lua_shingles_filter(uint64_t *input, gsize count, uint64_t minimal = G_MAXUINT64; gsize i, min_idx = 0; struct lua_shingle_data *sd; - rspamd_stat_token_t *word; + rspamd_word_t *word; struct lua_shingle_filter_cbdata *cbd = (struct lua_shingle_filter_cbdata *) ud; struct rspamd_mime_text_part *part; @@ -1248,7 +1247,7 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) unsigned int i; struct lua_shingle_data *sd; rspamd_cryptobox_hash_state_t st; - rspamd_stat_token_t *word; + rspamd_word_t *word; struct lua_shingle_filter_cbdata cbd; @@ -1256,7 +1255,7 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + if (IS_TEXT_PART_EMPTY(part) || !part->utf_words.a) { lua_pushnil(L); lua_pushnil(L); } @@ -1269,8 +1268,8 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) /* Calculate direct hash */ rspamd_cryptobox_hash_init(&st, key, rspamd_cryptobox_HASHKEYBYTES); - for (i = 0; i < part->utf_words->len; i++) { - word = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(part->utf_words); i++) { + word = &kv_A(part->utf_words, i); rspamd_cryptobox_hash_update(&st, word->stemmed.begin, word->stemmed.len); } @@ -1283,7 +1282,7 @@ lua_textpart_get_fuzzy_hashes(lua_State *L) cbd.pool = pool; cbd.part = part; - sgl = rspamd_shingles_from_text(part->utf_words, key, + sgl = rspamd_shingles_from_text(&part->utf_words, key, pool, lua_shingles_filter, &cbd, RSPAMD_SHINGLES_MUMHASH); if (sgl == NULL) { diff --git a/src/lua/lua_parsers.c 
b/src/lua/lua_parsers.c index f77b36952..39e1b0317 100644 --- a/src/lua/lua_parsers.c +++ b/src/lua/lua_parsers.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2020 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -108,8 +108,8 @@ int lua_parsers_tokenize_text(lua_State *L) struct rspamd_lua_text *t; struct rspamd_process_exception *ex; UText utxt = UTEXT_INITIALIZER; - GArray *res; - rspamd_stat_token_t *w; + rspamd_words_t *res; + rspamd_word_t *w; if (lua_type(L, 1) == LUA_TSTRING) { in = luaL_checklstring(L, 1, &len); @@ -175,13 +175,15 @@ int lua_parsers_tokenize_text(lua_State *L) lua_pushnil(L); } else { - lua_createtable(L, res->len, 0); + lua_createtable(L, kv_size(*res), 0); - for (i = 0; i < res->len; i++) { - w = &g_array_index(res, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*res); i++) { + w = &kv_A(*res, i); lua_pushlstring(L, w->original.begin, w->original.len); lua_rawseti(L, -2, i + 1); } + kv_destroy(*res); + g_free(res); } cur = exceptions; diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c index 97f9c496e..0b1473b61 100644 --- a/src/lua/lua_task.c +++ b/src/lua/lua_task.c @@ -6943,7 +6943,7 @@ lua_task_get_meta_words(lua_State *L) return luaL_error(L, "invalid arguments"); } - if (task->meta_words == NULL) { + if (!task->meta_words.a) { lua_createtable(L, 0, 0); } else { @@ -6967,7 +6967,7 @@ lua_task_get_meta_words(lua_State *L) } } - return rspamd_lua_push_words(L, task->meta_words, how); + return rspamd_lua_push_words_kvec(L, &task->meta_words, how); } return 1; @@ -7039,6 +7039,76 @@ lua_lookup_words_array(lua_State *L, 
return nmatched; } +static unsigned int +lua_lookup_words_kvec(lua_State *L, + int cbpos, + struct rspamd_task *task, + struct rspamd_lua_map *map, + rspamd_words_t *words) +{ + rspamd_word_t *tok; + unsigned int i, nmatched = 0; + int err_idx; + gboolean matched; + const char *key; + gsize keylen; + + if (!words || !words->a) { + return 0; + } + + for (i = 0; i < kv_size(*words); i++) { + tok = &kv_A(*words, i); + + matched = FALSE; + + if (tok->normalized.len == 0) { + continue; + } + + key = tok->normalized.begin; + keylen = tok->normalized.len; + + switch (map->type) { + case RSPAMD_LUA_MAP_SET: + case RSPAMD_LUA_MAP_HASH: + /* We know that tok->normalized is zero terminated in fact */ + if (rspamd_match_hash_map(map->data.hash, key, keylen)) { + matched = TRUE; + } + break; + case RSPAMD_LUA_MAP_REGEXP: + case RSPAMD_LUA_MAP_REGEXP_MULTIPLE: + if (rspamd_match_regexp_map_single(map->data.re_map, key, + keylen)) { + matched = TRUE; + } + break; + default: + g_assert_not_reached(); + break; + } + + if (matched) { + nmatched++; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_pushvalue(L, cbpos); /* Function */ + rspamd_lua_push_full_word(L, tok); + + if (lua_pcall(L, 1, 0, err_idx) != 0) { + msg_err_task("cannot call callback function for lookup words: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); + } + } + + return nmatched; +} + static int lua_task_lookup_words(lua_State *L) { @@ -7062,13 +7132,13 @@ lua_task_lookup_words(lua_State *L) PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) { - if (tp->utf_words) { - matches += lua_lookup_words_array(L, 3, task, map, tp->utf_words); + if (tp->utf_words.a) { + matches += lua_lookup_words_kvec(L, 3, task, map, &tp->utf_words); } } - if (task->meta_words) { - matches += lua_lookup_words_array(L, 3, task, map, task->meta_words); + if (task->meta_words.a) { + matches += lua_lookup_words_kvec(L, 3, task, map, &task->meta_words); } lua_pushinteger(L, matches); 
diff --git a/src/lua/lua_util.c b/src/lua/lua_util.c index 9fe862757..f2e9b8fa9 100644 --- a/src/lua/lua_util.c +++ b/src/lua/lua_util.c @@ -23,12 +23,21 @@ #include "lua_parsers.h" -#ifdef WITH_LUA_REPL -#include "replxx.h" -#endif +#include "replxx.h" #include <math.h> #include <glob.h> +#include <sys/types.h> +#include <sys/time.h> +#if defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) +#include <sys/sysctl.h> +#ifdef __FreeBSD__ +#include <sys/user.h> +#endif +#endif +#ifdef __APPLE__ +#include <mach/mach.h> +#endif #include "unicode/uspoof.h" #include "unicode/uscript.h" @@ -629,6 +638,27 @@ LUA_FUNCTION_DEF(util, caseless_hash_fast); LUA_FUNCTION_DEF(util, get_hostname); /*** + * @function util.get_uptime() + * Returns system uptime in seconds + * @return {number} uptime in seconds + */ +LUA_FUNCTION_DEF(util, get_uptime); + +/*** + * @function util.get_pid() + * Returns current process PID + * @return {number} process ID + */ +LUA_FUNCTION_DEF(util, get_pid); + +/*** + * @function util.get_memory_usage() + * Returns memory usage information for current process + * @return {table} memory usage info with 'rss' and 'vsize' fields in bytes + */ +LUA_FUNCTION_DEF(util, get_memory_usage); + +/*** * @function util.parse_content_type(ct_string, mempool) * Parses content-type string to a table: * - `type` @@ -730,6 +760,9 @@ static const struct luaL_reg utillib_f[] = { LUA_INTERFACE_DEF(util, umask), LUA_INTERFACE_DEF(util, isatty), LUA_INTERFACE_DEF(util, get_hostname), + LUA_INTERFACE_DEF(util, get_uptime), + LUA_INTERFACE_DEF(util, get_pid), + LUA_INTERFACE_DEF(util, get_memory_usage), LUA_INTERFACE_DEF(util, parse_content_type), LUA_INTERFACE_DEF(util, mime_header_encode), LUA_INTERFACE_DEF(util, pack), @@ -2416,6 +2449,107 @@ lua_util_get_hostname(lua_State *L) } static int +lua_util_get_uptime(lua_State *L) +{ + LUA_TRACE_POINT; + double uptime = 0.0; + +#ifdef __linux__ + FILE *f = fopen("/proc/uptime", "r"); + if (f) { 
+ if (fscanf(f, "%lf", &uptime) != 1) { + uptime = 0.0; + } + fclose(f); + } +#elif defined(__APPLE__) || defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) + struct timeval boottime; + size_t len = sizeof(boottime); + int mib[2] = {CTL_KERN, KERN_BOOTTIME}; + + if (sysctl(mib, 2, &boottime, &len, NULL, 0) == 0) { + struct timeval now; + gettimeofday(&now, NULL); + uptime = (now.tv_sec - boottime.tv_sec) + + (now.tv_usec - boottime.tv_usec) / 1000000.0; + } +#endif + + lua_pushnumber(L, uptime); + return 1; +} + +static int +lua_util_get_pid(lua_State *L) +{ + LUA_TRACE_POINT; + lua_pushinteger(L, getpid()); + return 1; +} + +static int +lua_util_get_memory_usage(lua_State *L) +{ + LUA_TRACE_POINT; + lua_createtable(L, 0, 2); + +#ifdef __linux__ + FILE *f = fopen("/proc/self/status", "r"); + if (f) { + char line[256]; + long rss = 0, vsize = 0; + + while (fgets(line, sizeof(line), f)) { + if (sscanf(line, "VmRSS: %ld kB", &rss) == 1) { + rss *= 1024; /* Convert to bytes */ + } + else if (sscanf(line, "VmSize: %ld kB", &vsize) == 1) { + vsize *= 1024; /* Convert to bytes */ + } + } + fclose(f); + + lua_pushstring(L, "rss"); + lua_pushinteger(L, rss); + lua_settable(L, -3); + + lua_pushstring(L, "vsize"); + lua_pushinteger(L, vsize); + lua_settable(L, -3); + } +#elif defined(__APPLE__) + struct task_basic_info info; + mach_msg_type_number_t count = TASK_BASIC_INFO_COUNT; + + if (task_info(mach_task_self(), TASK_BASIC_INFO, (task_info_t) &info, &count) == KERN_SUCCESS) { + lua_pushstring(L, "rss"); + lua_pushinteger(L, info.resident_size); + lua_settable(L, -3); + + lua_pushstring(L, "vsize"); + lua_pushinteger(L, info.virtual_size); + lua_settable(L, -3); + } +#elif defined(__FreeBSD__) || defined(__OpenBSD__) || defined(__NetBSD__) + struct kinfo_proc kp; + size_t len = sizeof(kp); + int mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid()}; + + if (sysctl(mib, 4, &kp, &len, NULL, 0) == 0) { + lua_pushstring(L, "rss"); + lua_pushinteger(L, 
kp.ki_rssize * getpagesize()); + lua_settable(L, -3); + + lua_pushstring(L, "vsize"); + lua_pushinteger(L, kp.ki_size); + lua_settable(L, -3); + } +#endif + + return 1; +} + +static int lua_util_parse_content_type(lua_State *L) { return lua_parsers_parse_content_type(L); @@ -2510,7 +2644,7 @@ lua_util_readline(lua_State *L) if (lua_type(L, 1) == LUA_TSTRING) { prompt = lua_tostring(L, 1); } -#ifdef WITH_LUA_REPL + static Replxx *rx_instance = NULL; if (rx_instance == NULL) { @@ -2527,26 +2661,6 @@ lua_util_readline(lua_State *L) else { lua_pushnil(L); } -#else - size_t linecap = 0; - ssize_t linelen; - - fprintf(stdout, "%s ", prompt); - - linelen = getline(&input, &linecap, stdin); - - if (linelen > 0) { - if (input[linelen - 1] == '\n') { - linelen--; - } - - lua_pushlstring(L, input, linelen); - free(input); - } - else { - lua_pushnil(L); - } -#endif return 1; } @@ -3721,4 +3835,4 @@ lua_ev_base_add_timer(lua_State *L) ev_timer_start(ev_base, &cbdata->ev); return 0; -}
\ No newline at end of file +} diff --git a/src/plugins/chartable.cxx b/src/plugins/chartable.cxx index a5c7cb899..c82748862 100644 --- a/src/plugins/chartable.cxx +++ b/src/plugins/chartable.cxx @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1696,7 +1696,7 @@ rspamd_can_alias_latin(int ch) static double rspamd_chartable_process_word_utf(struct rspamd_task *task, - rspamd_stat_token_t *w, + rspamd_word_t *w, gboolean is_url, unsigned int *ncap, struct chartable_ctx *chartable_module_ctx, @@ -1842,7 +1842,7 @@ rspamd_chartable_process_word_utf(struct rspamd_task *task, static double rspamd_chartable_process_word_ascii(struct rspamd_task *task, - rspamd_stat_token_t *w, + rspamd_word_t *w, gboolean is_url, struct chartable_ctx *chartable_module_ctx) { @@ -1931,17 +1931,17 @@ rspamd_chartable_process_part(struct rspamd_task *task, struct chartable_ctx *chartable_module_ctx, gboolean ignore_diacritics) { - rspamd_stat_token_t *w; + rspamd_word_t *w; unsigned int i, ncap = 0; double cur_score = 0.0; - if (part == nullptr || part->utf_words == nullptr || - part->utf_words->len == 0 || part->nwords == 0) { + if (part == nullptr || part->utf_words.a == nullptr || + kv_size(part->utf_words) == 0 || part->nwords == 0) { return FALSE; } - for (i = 0; i < part->utf_words->len; i++) { - w = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(part->utf_words); i++) { + w = &kv_A(part->utf_words, i); if ((w->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT)) { @@ -2015,13 +2015,13 @@ chartable_symbol_callback(struct rspamd_task *task, ignore_diacritics = TRUE; } - if (task->meta_words != nullptr && task->meta_words->len > 0) { - rspamd_stat_token_t *w; + if (task->meta_words.a && kv_size(task->meta_words) > 0) { + rspamd_word_t *w; double cur_score = 0; - gsize arlen = 
task->meta_words->len; + gsize arlen = kv_size(task->meta_words); for (i = 0; i < arlen; i++) { - w = &g_array_index(task->meta_words, rspamd_stat_token_t, i); + w = &kv_A(task->meta_words, i); cur_score += rspamd_chartable_process_word_utf(task, w, FALSE, nullptr, chartable_module_ctx, ignore_diacritics); } diff --git a/src/plugins/fuzzy_check.c b/src/plugins/fuzzy_check.c index ece9a91e0..7dd5162ac 100644 --- a/src/plugins/fuzzy_check.c +++ b/src/plugins/fuzzy_check.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,7 +78,8 @@ enum fuzzy_rule_mode { }; struct fuzzy_rule { - struct upstream_list *servers; + struct upstream_list *read_servers; /* Servers for read operations */ + struct upstream_list *write_servers; /* Servers for write operations */ const char *symbol; const char *algorithm_str; const char *name; @@ -543,22 +544,68 @@ fuzzy_parse_rule(struct rspamd_config *cfg, const ucl_object_t *obj, } if ((value = ucl_object_lookup(obj, "servers")) != NULL) { - rule->servers = rspamd_upstreams_create(cfg->ups_ctx); - /* pass max_error and revive_time configuration in upstream for fuzzy storage - * it allows to configure error_rate threshold and upstream dead timer - */ - rspamd_upstreams_set_limits(rule->servers, + rule->read_servers = rspamd_upstreams_create(cfg->ups_ctx); + rspamd_upstreams_set_limits(rule->read_servers, (double) fuzzy_module_ctx->revive_time, NAN, NAN, NAN, (unsigned int) fuzzy_module_ctx->max_errors, 0); rspamd_mempool_add_destructor(cfg->cfg_pool, (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, - rule->servers); - if (!rspamd_upstreams_from_ucl(rule->servers, value, DEFAULT_PORT, NULL)) { + rule->read_servers); + if (!rspamd_upstreams_from_ucl(rule->read_servers, value, DEFAULT_PORT, NULL)) { msg_err_config("cannot read servers definition"); return -1; } 
+ + rule->write_servers = rule->read_servers; + } + else { + /* Check for read_servers and write_servers */ + gboolean has_read = FALSE, has_write = FALSE; + + if ((value = ucl_object_lookup(obj, "read_servers")) != NULL) { + rule->read_servers = rspamd_upstreams_create(cfg->ups_ctx); + rspamd_upstreams_set_limits(rule->read_servers, + (double) fuzzy_module_ctx->revive_time, NAN, NAN, NAN, + (unsigned int) fuzzy_module_ctx->max_errors, 0); + + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, + rule->read_servers); + if (!rspamd_upstreams_from_ucl(rule->read_servers, value, DEFAULT_PORT, NULL)) { + msg_err_config("cannot read read_servers definition"); + return -1; + } + has_read = TRUE; + } + + if ((value = ucl_object_lookup(obj, "write_servers")) != NULL) { + rule->write_servers = rspamd_upstreams_create(cfg->ups_ctx); + rspamd_upstreams_set_limits(rule->write_servers, + (double) fuzzy_module_ctx->revive_time, NAN, NAN, NAN, + (unsigned int) fuzzy_module_ctx->max_errors, 0); + + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_upstreams_destroy, + rule->write_servers); + if (!rspamd_upstreams_from_ucl(rule->write_servers, value, DEFAULT_PORT, NULL)) { + msg_err_config("cannot read write_servers definition"); + return -1; + } + has_write = TRUE; + } + + /* If we have both read and write servers, we don't need the common servers list */ + if (has_read && !has_write) { + /* Use read_servers for all operations */ + rule->write_servers = rule->read_servers; + } + else if (has_write && !has_read) { + /* Use write_servers for all operations */ + rule->read_servers = rule->write_servers; + } } + if ((value = ucl_object_lookup(obj, "fuzzy_map")) != NULL) { it = NULL; while ((cur = ucl_object_iterate(value, &it, true)) != NULL) { @@ -636,7 +683,7 @@ fuzzy_parse_rule(struct rspamd_config *cfg, const ucl_object_t *obj, strlen(shingles_key_str), NULL, 0); rule->shingles_key->len = 16; - if 
(rspamd_upstreams_count(rule->servers) == 0) { + if (rspamd_upstreams_count(rule->read_servers) == 0) { msg_err_config("no servers defined for fuzzy rule with name: %s", rule->name); return -1; @@ -898,6 +945,24 @@ int fuzzy_check_module_init(struct rspamd_config *cfg, struct module_ctx **ctx) 0); rspamd_rcl_add_doc_by_path(cfg, "fuzzy_check.rule", + "List of servers to check (read-only operations)", + "read_servers", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", + "List of servers to learn (write operations)", + "write_servers", + UCL_STRING, + NULL, + 0, + NULL, + 0); + rspamd_rcl_add_doc_by_path(cfg, + "fuzzy_check.rule", "If true then never try to learn this fuzzy storage", "read_only", UCL_BOOLEAN, @@ -1249,7 +1314,7 @@ int fuzzy_check_module_config(struct rspamd_config *cfg, bool validate) LL_FOREACH(value, cur) { - if (ucl_object_lookup(cur, "servers")) { + if (ucl_object_lookup_any(cur, "servers", "read_servers", "write_servers", NULL) != NULL) { /* Unnamed rule */ fuzzy_parse_rule(cfg, cur, NULL, cb_id); nrules++; @@ -1366,10 +1431,10 @@ fuzzy_io_fin(void *ud) close(session->fd); } -static GArray * +static rspamd_words_t * fuzzy_preprocess_words(struct rspamd_mime_text_part *part, rspamd_mempool_t *pool) { - return part->utf_words; + return &part->utf_words; } static void @@ -1715,26 +1780,30 @@ fuzzy_cmd_write_extensions(struct rspamd_task *task, struct rspamd_email_address *addr = g_ptr_array_index(MESSAGE_FIELD(task, from_mime), 0); - unsigned int to_write = MIN(MAX_FUZZY_DOMAIN, addr->domain_len) + 2; - if (to_write > 0 && to_write <= available) { - *dest++ = RSPAMD_FUZZY_EXT_SOURCE_DOMAIN; - *dest++ = to_write - 2; + if (addr->domain_len > 0) { + /* Filter invalid domains */ + unsigned int to_write = MIN(MAX_FUZZY_DOMAIN, addr->domain_len) + 2; - if (addr->domain_len < MAX_FUZZY_DOMAIN) { - memcpy(dest, addr->domain, addr->domain_len); - dest += addr->domain_len; - } - else { - /* Trim from left */ - 
memcpy(dest, - addr->domain + (addr->domain_len - MAX_FUZZY_DOMAIN), - MAX_FUZZY_DOMAIN); - dest += MAX_FUZZY_DOMAIN; - } + if (to_write > 0 && to_write <= available) { + *dest++ = RSPAMD_FUZZY_EXT_SOURCE_DOMAIN; + *dest++ = to_write - 2; + + if (addr->domain_len < MAX_FUZZY_DOMAIN) { + memcpy(dest, addr->domain, addr->domain_len); + dest += addr->domain_len; + } + else { + /* Trim from left */ + memcpy(dest, + addr->domain + (addr->domain_len - MAX_FUZZY_DOMAIN), + MAX_FUZZY_DOMAIN); + dest += MAX_FUZZY_DOMAIN; + } - available -= to_write; - written += to_write; + available -= to_write; + written += to_write; + } } } @@ -1792,7 +1861,7 @@ fuzzy_cmd_from_text_part(struct rspamd_task *task, unsigned int i; rspamd_cryptobox_hash_state_t st; rspamd_stat_token_t *word; - GArray *words; + rspamd_words_t *words; struct fuzzy_cmd_io *io; unsigned int additional_length; unsigned char *additional_data; @@ -1901,10 +1970,10 @@ fuzzy_cmd_from_text_part(struct rspamd_task *task, rspamd_cryptobox_hash_init(&st, rule->hash_key->str, rule->hash_key->len); words = fuzzy_preprocess_words(part, task->task_pool); - for (i = 0; i < words->len; i++) { - word = &g_array_index(words, rspamd_stat_token_t, i); + for (i = 0; i < kv_size(*words); i++) { + word = &kv_A(*words, i); - if (!((word->flags & RSPAMD_STAT_TOKEN_FLAG_SKIPPED) || word->stemmed.len == 0)) { + if (!((word->flags & RSPAMD_WORD_FLAG_SKIPPED) || word->stemmed.len == 0)) { rspamd_cryptobox_hash_update(&st, word->stemmed.begin, word->stemmed.len); } @@ -2615,7 +2684,7 @@ fuzzy_insert_metric_results(struct rspamd_task *task, struct fuzzy_rule *rule, if (task->message) { PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) { - if (!IS_TEXT_PART_EMPTY(tp) && tp->utf_words != NULL && tp->utf_words->len > 0) { + if (!IS_TEXT_PART_EMPTY(tp) && kv_size(tp->utf_words) > 0) { seen_text_part = TRUE; if (tp->utf_stripped_text.magic == UTEXT_MAGIC) { @@ -3394,8 +3463,8 @@ register_fuzzy_client_call(struct rspamd_task *task, int 
sock; if (!rspamd_session_blocked(task->s)) { - /* Get upstream */ - selected = rspamd_upstream_get(rule->servers, RSPAMD_UPSTREAM_ROUND_ROBIN, + /* Get upstream - use read_servers for check operations */ + selected = rspamd_upstream_get(rule->read_servers, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); if (selected) { addr = rspamd_upstream_addr_next(selected); @@ -3522,9 +3591,8 @@ register_fuzzy_controller_call(struct rspamd_http_connection_entry *entry, int sock; int ret = -1; - /* Get upstream */ - - while ((selected = rspamd_upstream_get_forced(rule->servers, + /* Get upstream - use write_servers for learn/unlearn operations */ + while ((selected = rspamd_upstream_get_forced(rule->write_servers, RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) { /* Create UDP socket */ addr = rspamd_upstream_addr_next(selected); @@ -3538,6 +3606,9 @@ register_fuzzy_controller_call(struct rspamd_http_connection_entry *entry, rspamd_upstream_fail(selected, TRUE, strerror(errno)); } else { + msg_info_task("fuzzy storage %s (%s rule) is used for write", + rspamd_inet_address_to_string_pretty(addr), + rule->name); s = rspamd_mempool_alloc0(session->pool, sizeof(struct fuzzy_learn_session)); @@ -3620,6 +3691,7 @@ fuzzy_modify_handler(struct rspamd_http_connection_entry *conn_ent, PTR_ARRAY_FOREACH(fuzzy_module_ctx->fuzzy_rules, i, rule) { if (rule->mode == fuzzy_rule_read_only) { + msg_debug_task("skip rule %s as it is read-only", rule->name); continue; } @@ -3729,6 +3801,8 @@ fuzzy_modify_handler(struct rspamd_http_connection_entry *conn_ent, else { commands = fuzzy_generate_commands(task, rule, cmd, flag, value, flags); + msg_debug_task("fuzzy command %d for rule %s, flag %d, value %d", + cmd, rule->name, flag, value); if (commands != NULL) { res = register_fuzzy_controller_call(conn_ent, rule, @@ -3894,7 +3968,7 @@ fuzzy_check_send_lua_learn(struct fuzzy_rule *rule, /* Get upstream */ if (!rspamd_session_blocked(task->s)) { - while ((selected = rspamd_upstream_get(rule->servers, + while 
((selected = rspamd_upstream_get(rule->write_servers, RSPAMD_UPSTREAM_SEQUENTIAL, NULL, 0))) { /* Create UDP socket */ addr = rspamd_upstream_addr_next(selected); @@ -4491,9 +4565,21 @@ fuzzy_lua_list_storages(lua_State *L) lua_setfield(L, -2, "read_only"); /* Push servers */ - lua_createtable(L, rspamd_upstreams_count(rule->servers), 0); - rspamd_upstreams_foreach(rule->servers, lua_upstream_str_inserter, L); - lua_setfield(L, -2, "servers"); + if (rule->read_servers == rule->write_servers) { + /* Same servers for both operations */ + lua_createtable(L, rspamd_upstreams_count(rule->read_servers), 0); + rspamd_upstreams_foreach(rule->read_servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "servers"); + } + else { + /* Different servers for read and write */ + lua_createtable(L, rspamd_upstreams_count(rule->read_servers), 0); + rspamd_upstreams_foreach(rule->read_servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "read_servers"); + lua_createtable(L, rspamd_upstreams_count(rule->write_servers), 0); + rspamd_upstreams_foreach(rule->write_servers, lua_upstream_str_inserter, L); + lua_setfield(L, -2, "write_servers"); + } /* Push flags */ GHashTableIter it; @@ -4780,7 +4866,7 @@ fuzzy_lua_ping_storage(lua_State *L) rspamd_ptr_array_free_hard, addrs); } else { - struct upstream *selected = rspamd_upstream_get(rule_found->servers, + struct upstream *selected = rspamd_upstream_get(rule_found->read_servers, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); addr = rspamd_upstream_addr_next(selected); } @@ -4824,4 +4910,4 @@ fuzzy_lua_ping_storage(lua_State *L) lua_pushboolean(L, TRUE); return 1; -}
\ No newline at end of file +} diff --git a/src/plugins/lua/arc.lua b/src/plugins/lua/arc.lua index fb5dd93e6..45da1f5a2 100644 --- a/src/plugins/lua/arc.lua +++ b/src/plugins/lua/arc.lua @@ -147,7 +147,7 @@ local function parse_arc_header(hdr, target, is_aar) -- sort by i= attribute table.sort(target, function(a, b) - return (a.i or 0) < (b.i or 0) + return (tonumber(a.i) or 0) < (tonumber(b.i) or 0) end) end @@ -695,11 +695,11 @@ local function do_sign(task, sign_params) sign_params.pubkey = results[1] sign_params.strict_pubkey_check = not settings.allow_pubkey_mismatch elseif not settings.allow_pubkey_mismatch then - rspamd_logger.errx('public key for domain %s/%s is not found: %s, skip signing', + rspamd_logger.errx(task, 'public key for domain %s/%s is not found: %s, skip signing', sign_params.domain, sign_params.selector, err) return else - rspamd_logger.infox('public key for domain %s/%s is not found: %s', + rspamd_logger.infox(task, 'public key for domain %s/%s is not found: %s', sign_params.domain, sign_params.selector, err) end diff --git a/src/plugins/lua/contextal.lua b/src/plugins/lua/contextal.lua new file mode 100644 index 000000000..e29c21645 --- /dev/null +++ b/src/plugins/lua/contextal.lua @@ -0,0 +1,338 @@ +--[[ +Copyright (c) 2025, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+]]-- + +local E = {} +local N = 'contextal' + +if confighelp then + return +end + +local opts = rspamd_config:get_all_opt(N) +if not opts then + return +end + +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local redis_cache = require "lua_cache" +local rspamd_http = require "rspamd_http" +local rspamd_logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types +local ucl = require "ucl" + +local cache_context, redis_params + +local contextal_actions = { + ['ALERT'] = true, + ['ALLOW'] = true, + ['BLOCK'] = true, + ['QUARANTINE'] = true, + ['SPAM'] = true, +} + +local config_schema = lua_redis.enrich_schema { + action_symbol_prefix = ts.string:is_optional(), + base_url = ts.string:is_optional(), + cache_prefix = ts.string:is_optional(), + cache_timeout = ts.number:is_optional(), + cache_ttl = ts.number:is_optional(), + custom_actions = ts.array_of(ts.string):is_optional(), + defer_if_no_result = ts.boolean:is_optional(), + defer_message = ts.string:is_optional(), + enabled = ts.boolean:is_optional(), + http_timeout = ts.number:is_optional(), + request_ttl = ts.number:is_optional(), + submission_symbol = ts.string:is_optional(), +} + +local settings = { + action_symbol_prefix = 'CONTEXTAL_ACTION', + base_url = 'http://localhost:8080', + cache_prefix = 'CXAL', + cache_timeout = 5, + cache_ttl = 3600, + custom_actions = {}, + defer_if_no_result = false, + defer_message = 'Awaiting deep scan - try again later', + http_timeout = 2, + request_ttl = 4, + submission_symbol = 'CONTEXTAL_SUBMIT', +} + +local static_boundary = rspamd_util.random_hex(32) +local wait_request_ttl = true + +local function maybe_defer(task, obj) + if settings.defer_if_no_result and not ((obj or E)[1] or E).actions then + task:set_pre_result('soft reject', settings.defer_message, N) + end +end + +local function process_actions(task, obj, is_cached) + lua_util.debugm(N, task, 'got result: %s (%s)', obj, is_cached 
and 'cached' or 'fresh') + for _, match in ipairs((obj[1] or E).actions or E) do + local act = match.action + local scenario = match.scenario + if not (act and scenario) then + rspamd_logger.err(task, 'bad result: %s', match) + elseif contextal_actions[act] then + task:insert_result(settings.action_symbol_prefix .. '_' .. act, 1.0, scenario) + else + rspamd_logger.err(task, 'unknown action: %s', act) + end + end + + if not cache_context or is_cached then + maybe_defer(task, obj) + return + end + + local cache_obj + if (obj[1] or E).actions then + cache_obj = {[1] = {["actions"] = obj[1].actions}} + else + local work_id = task:get_mempool():get_variable('contextal_work_id', 'string') + if work_id then + cache_obj = {[1] = {["work_id"] = work_id}} + else + rspamd_logger.err(task, 'no work id found in mempool') + return + end + end + + redis_cache.cache_set(task, + task:get_digest(), + cache_obj, + cache_context) + + maybe_defer(task, obj) +end + +local function process_cached(task, obj) + if (obj[1] or E).actions then + lua_util.debugm(N, task, 'using cached actions: %s', obj[1].actions) + task:disable_symbol(settings.action_symbol_prefix) + return process_actions(task, obj, true) + elseif (obj[1] or E).work_id then + lua_util.debugm(N, task, 'using old work ID: %s', obj[1].work_id) + task:get_mempool():set_variable('contextal_work_id', obj[1].work_id) + else + rspamd_logger.err(task, 'bad result (cached): %s', obj) + end +end + +local function action_cb(task) + local work_id = task:get_mempool():get_variable('contextal_work_id', 'string') + if not work_id then + rspamd_logger.err(task, 'no work id found in mempool') + return + end + lua_util.debugm(N, task, 'polling for result for work id: %s', work_id) + + local function http_callback(err, code, body, hdrs) + if err then + rspamd_logger.err(task, 'http error: %s', err) + maybe_defer(task) + return + end + if code ~= 200 then + rspamd_logger.err(task, 'bad http code: %s', code) + maybe_defer(task) + return + end + 
local parser = ucl.parser() + local _, parse_err = parser:parse_string(body) + if parse_err then + rspamd_logger.err(task, 'cannot parse JSON: %s', err) + maybe_defer(task) + return + end + local obj = parser:get_object() + return process_actions(task, obj, false) + end + + rspamd_http.request({ + task = task, + url = settings.actions_url .. work_id, + callback = http_callback, + timeout = settings.http_timeout, + gzip = settings.gzip, + keepalive = settings.keepalive, + no_ssl_verify = settings.no_ssl_verify, + }) +end + +local function submit(task) + + local function http_callback(err, code, body, hdrs) + if err then + rspamd_logger.err(task, 'http error: %s', err) + maybe_defer(task) + return + end + if code ~= 201 then + rspamd_logger.err(task, 'bad http code: %s', code) + maybe_defer(task) + return + end + local parser = ucl.parser() + local _, parse_err = parser:parse_string(body) + if parse_err then + rspamd_logger.err(task, 'cannot parse JSON: %s', err) + maybe_defer(task) + return + end + local obj = parser:get_object() + local work_id = obj.work_id + if work_id then + task:get_mempool():set_variable('contextal_work_id', work_id) + end + task:insert_result(settings.submission_symbol, 1.0, + string.format('work_id=%s', work_id or 'nil')) + if wait_request_ttl then + task:add_timer(settings.request_ttl, action_cb) + end + end + + local req = { + object_data = {['data'] = task:get_content()}, + } + if settings.request_ttl then + req.ttl = {['data'] = tostring(settings.request_ttl)} + end + if settings.max_recursion then + req.maxrec = {['data'] = tostring(settings.max_recursion)} + end + rspamd_http.request({ + task = task, + url = settings.submit_url, + body = lua_util.table_to_multipart_body(req, static_boundary), + callback = http_callback, + headers = { + ['Content-Type'] = string.format('multipart/form-data; boundary="%s"', static_boundary) + }, + timeout = settings.http_timeout, + gzip = settings.gzip, + keepalive = settings.keepalive, + no_ssl_verify = 
settings.no_ssl_verify, + }) +end + +local function cache_hit(task, err, data) + if err then + rspamd_logger.err(task, 'error getting cache: %s', err) + else + process_cached(task, data) + end +end + +local function submit_cb(task) + if cache_context then + redis_cache.cache_get(task, + task:get_digest(), + cache_context, + settings.cache_timeout, + submit, + cache_hit + ) + else + submit(task) + end +end + +local function set_url_path(base, path) + local slash = base:sub(#base) == '/' and '' or '/' + return base .. slash .. path +end + +settings = lua_util.override_defaults(settings, opts) + +local res, err = config_schema:transform(settings) +if not res then + rspamd_logger.warnx(rspamd_config, 'plugin %s is misconfigured: %s', N, err) + local err_msg = string.format("schema error: %s", res) + lua_util.config_utils.push_config_error(N, err_msg) + lua_util.disable_module(N, "failed", err_msg) + return +end + +for _, k in ipairs(settings.custom_actions) do + contextal_actions[k] = true +end + +if not settings.base_url then + if not (settings.submit_url and settings.actions_url) then + rspamd_logger.err(rspamd_config, 'no URL configured for contextal') + lua_util.disable_module(N, 'config') + return + end +else + if not settings.submit_url then + settings.submit_url = set_url_path(settings.base_url, 'api/v1/submit') + end + if not settings.actions_url then + settings.actions_url = set_url_path(settings.base_url, 'api/v1/actions/') + end +end + +redis_params = lua_redis.parse_redis_server(N) +if redis_params then + cache_context = redis_cache.create_cache_context(redis_params, { + cache_prefix = settings.cache_prefix, + cache_ttl = settings.cache_ttl, + cache_format = 'json', + cache_use_hashing = false + }) +end + +local submission_id = rspamd_config:register_symbol({ + name = settings.submission_symbol, + type = 'normal', + group = N, + callback = submit_cb +}) + +local top_options = rspamd_config:get_all_opt('options') +if settings.request_ttl and 
settings.request_ttl >= (top_options.task_timeout * 0.8) then + rspamd_logger.info(rspamd_config, [[request ttl is >= 80% of task timeout, won't wait on processing]]) + wait_request_ttl = false +elseif not settings.request_ttl then + wait_request_ttl = false +end + +local parent_id +if wait_request_ttl then + parent_id = submission_id +else + parent_id = rspamd_config:register_symbol({ + name = settings.action_symbol_prefix, + type = 'postfilter', + priority = lua_util.symbols_priorities.high - 1, + group = N, + callback = action_cb + }) +end + +for k in pairs(contextal_actions) do + rspamd_config:register_symbol({ + name = settings.action_symbol_prefix .. '_' .. k, + parent = parent_id, + type = 'virtual', + group = N, + }) +end diff --git a/src/plugins/lua/fuzzy_collect.lua b/src/plugins/lua/fuzzy_collect.lua index 132ace90c..060cc2fc2 100644 --- a/src/plugins/lua/fuzzy_collect.lua +++ b/src/plugins/lua/fuzzy_collect.lua @@ -34,7 +34,7 @@ local settings = { local function send_data_mirror(m, cfg, ev_base, body) local function store_callback(err, _, _, _) if err then - rspamd_logger.errx(cfg, 'cannot save data on %(%s): %s', m.server, m.name, err) + rspamd_logger.errx(cfg, 'cannot save data on %s(%s): %s', m.server, m.name, err) else rspamd_logger.infox(cfg, 'saved data on %s(%s)', m.server, m.name) end diff --git a/src/plugins/lua/gpt.lua b/src/plugins/lua/gpt.lua index 98a3e38ee..5776791a1 100644 --- a/src/plugins/lua/gpt.lua +++ b/src/plugins/lua/gpt.lua @@ -20,9 +20,9 @@ local E = {} if confighelp then rspamd_config:add_example(nil, 'gpt', - "Performs postfiltering using GPT model", - [[ -gpt { + "Performs postfiltering using GPT model", + [[ + gpt { # Supported types: openai, ollama type = "openai"; # Your key to access the API @@ -53,7 +53,7 @@ gpt { reason_header = "X-GPT-Reason"; # Use JSON format for response json = false; -} + } ]]) return end @@ -162,7 +162,7 @@ local function default_condition(task) end end lua_util.debugm(N, task, 'symbol %s has 
weight %s, but required %s', s, - sym.weight, required_weight) + sym.weight, required_weight) else return false, 'skip as "' .. s .. '" is found' end @@ -182,7 +182,7 @@ local function default_condition(task) end end lua_util.debugm(N, task, 'symbol %s has weight %s, but required %s', s, - sym.weight, required_weight) + sym.weight, required_weight) end else return false, 'skip as "' .. s .. '" is not found' @@ -301,7 +301,7 @@ local function default_openai_json_conversion(task, input) elseif reply.probability == "low" then spam_score = 0.1 else - rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability) + rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability) end end @@ -355,14 +355,27 @@ local function default_openai_plain_conversion(task, input) local reason = clean_reply_line(lines[2]) local categories = lua_util.str_split(clean_reply_line(lines[3]), ',') + if type(reply.usage) == 'table' then + rspamd_logger.infox(task, 'usage: %s tokens', reply.usage.total_tokens) + end + if spam_score then return spam_score, reason, categories end - rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1]) + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message) return end +-- Helper function to remove <think>...</think> and trim leading newlines +local function clean_gpt_response(text) + -- Remove <think>...</think> including multiline + text = text:gsub("<think>.-</think>", "") + -- Trim leading whitespace and newlines + text = text:gsub("^%s*\n*", "") + return text +end + local function default_ollama_plain_conversion(task, input) local parser = ucl.parser() local res, err = parser:parse_string(input) @@ -387,6 +400,10 @@ local function default_ollama_plain_conversion(task, input) rspamd_logger.errx(task, 'no content in the first message') return end + + -- Clean message + first_message = clean_gpt_response(first_message) + local lines = 
lua_util.str_split(first_message, '\n') local first_line = clean_reply_line(lines[1]) local spam_score = tonumber(first_line) @@ -397,7 +414,7 @@ local function default_ollama_plain_conversion(task, input) return spam_score, reason, categories end - rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s', lines[1]) + rspamd_logger.errx(task, 'cannot parse plain gpt reply: %s (all: %s)', lines[1], first_message) return end @@ -449,7 +466,7 @@ local function default_ollama_json_conversion(task, input) elseif reply.probability == "low" then spam_score = 0.1 else - rspamd_logger.infox("cannot convert to spam probability: %s", reply.probability) + rspamd_logger.infox(task, "cannot convert to spam probability: %s", reply.probability) end end @@ -477,7 +494,7 @@ local function redis_cache_key(sel_part) env_digest = digest:hex():sub(1, 4) end return string.format('%s_%s', env_digest, - sel_part:get_mimepart():get_digest():sub(1, 24)) + sel_part:get_mimepart():get_digest():sub(1, 24)) end local function process_categories(task, categories) @@ -494,6 +511,7 @@ local function insert_results(task, result, sel_part) rspamd_logger.errx(task, 'no probability in result') return end + if result.probability > 0.5 then task:insert_result('GPT_SPAM', (result.probability - 0.5) * 2, tostring(result.probability)) if settings.autolearn then @@ -504,10 +522,6 @@ local function insert_results(task, result, sel_part) process_categories(task, result.categories) end else - if result.reason and settings.reason_header then - lua_mime.modify_headers(task, - { add = { [settings.reason_header] = { value = 'value', order = 1 } } }) - end task:insert_result('GPT_HAM', (0.5 - result.probability) * 2, tostring(result.probability)) if settings.autolearn then task:set_flag("learn_ham") @@ -516,6 +530,10 @@ local function insert_results(task, result, sel_part) process_categories(task, result.categories) end end + if result.reason and settings.reason_header then + lua_mime.modify_headers(task, + { add 
= { [settings.reason_header] = { value = tostring(result.reason), order = 1 } } }) + end if cache_context then lua_cache.cache_set(task, redis_cache_key(sel_part), result, cache_context) @@ -539,12 +557,12 @@ local function check_consensus_and_insert_results(task, results, sel_part) nspam = nspam + 1 max_spam_prob = math.max(max_spam_prob, result.probability) lua_util.debugm(N, task, "model: %s; spam: %s; reason: '%s'", - result.model, result.probability, result.reason) + result.model, result.probability, result.reason) else nham = nham + 1 max_ham_prob = math.min(max_ham_prob, result.probability) lua_util.debugm(N, task, "model: %s; ham: %s; reason: '%s'", - result.model, result.probability, result.reason) + result.model, result.probability, result.reason) end if result.reason then @@ -558,23 +576,22 @@ local function check_consensus_and_insert_results(task, results, sel_part) if nspam > nham and max_spam_prob > 0.75 then insert_results(task, { - probability = max_spam_prob, - reason = reason.reason, - categories = reason.categories, - }, - sel_part) + probability = max_spam_prob, + reason = reason.reason, + categories = reason.categories, + }, + sel_part) elseif nham > nspam and max_ham_prob < 0.25 then insert_results(task, { - probability = max_ham_prob, - reason = reason.reason, - categories = reason.categories, - }, - sel_part) + probability = max_ham_prob, + reason = reason.reason, + categories = reason.categories, + }, + sel_part) else -- No consensus lua_util.debugm(N, task, "no consensus") end - end local function get_meta_llm_content(task) @@ -673,7 +690,7 @@ local function openai_check(task, content, sel_part) }, { role = 'user', - content = 'Subject: ' .. task:get_subject() or '', + content = 'Subject: ' .. 
(task:get_subject() or ''), }, { role = 'user', @@ -725,7 +742,6 @@ local function openai_check(task, content, sel_part) if not rspamd_http.request(http_params) then results[idx].checked = true end - end end @@ -958,14 +974,14 @@ if opts then "FROM and url domains. Evaluate spam probability (0-1). " .. "Output ONLY 3 lines:\n" .. "1. Numeric score (0.00-1.00)\n" .. - "2. One-sentence reason citing strongest red flag\n" .. + "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" .. "3. Primary concern category if found from the list: " .. table.concat(lua_util.keys(categories_map), ', ') else settings.prompt = "Analyze this email strictly as a spam detector given the email message, subject, " .. "FROM and url domains. Evaluate spam probability (0-1). " .. "Output ONLY 2 lines:\n" .. "1. Numeric score (0.00-1.00)\n" .. - "2. One-sentence reason citing strongest red flag\n" + "2. One-sentence reason citing whether it is spam, the strongest red flag, or why it is ham\n" end end end diff --git a/src/plugins/lua/greylist.lua b/src/plugins/lua/greylist.lua index e4a633233..934e17bce 100644 --- a/src/plugins/lua/greylist.lua +++ b/src/plugins/lua/greylist.lua @@ -122,6 +122,29 @@ local function data_key(task) local h = hash.create() h:update(body, len) + local subject = task:get_subject() or '' + h:update(subject) + + -- Take recipients into account + local rcpt = task:get_recipients('smtp') + if rcpt then + table.sort(rcpt, function(r1, r2) + return r1['addr'] < r2['addr'] + end) + + fun.each(function(r) + h:update(r['addr']) + end, rcpt) + end + + -- Use from as well, but mime one + local from = task:get_from('mime') + + local addr = '<>' + if from and from[1] then + addr = from[1]['addr'] + end + h:update(addr) local b32 = settings['key_prefix'] .. 'b' .. 
h:base32():sub(1, 20) task:get_mempool():set_variable("grey_bodyhash", b32) diff --git a/src/plugins/lua/hfilter.lua b/src/plugins/lua/hfilter.lua index 6bc011b83..32102e4f8 100644 --- a/src/plugins/lua/hfilter.lua +++ b/src/plugins/lua/hfilter.lua @@ -131,6 +131,7 @@ local checks_hellohost = [[ /modem[.-][0-9]/i 5 /[0-9][.-]?dhcp/i 5 /wifi[.-][0-9]/i 5 +/[.-]vps[.-]/i 1 ]] local checks_hellohost_map @@ -199,9 +200,10 @@ local function check_regexp(str, regexp_text) return re:match(str) end -local function add_static_map(data) +local function add_static_map(data, description) return rspamd_config:add_map { type = 'regexp_multi', + description = description, url = { upstreams = 'static', data = data, @@ -568,16 +570,16 @@ local function append_t(t, a) end end if config['helo_enabled'] then - checks_hello_bareip_map = add_static_map(checks_hello_bareip) - checks_hello_badip_map = add_static_map(checks_hello_badip) - checks_hellohost_map = add_static_map(checks_hellohost) - checks_hello_map = add_static_map(checks_hello) + checks_hello_bareip_map = add_static_map(checks_hello_bareip, 'Hfilter: HELO bare ip') + checks_hello_badip_map = add_static_map(checks_hello_badip, 'Hfilter: HELO bad ip') + checks_hellohost_map = add_static_map(checks_hellohost, 'Hfilter: HELO host') + checks_hello_map = add_static_map(checks_hello, 'Hfilter: HELO') append_t(symbols_enabled, symbols_helo) timeout = math.max(timeout, rspamd_config:get_dns_timeout() * 3) end if config['hostname_enabled'] then if not checks_hellohost_map then - checks_hellohost_map = add_static_map(checks_hellohost) + checks_hellohost_map = add_static_map(checks_hellohost, 'Hfilter: HOSTNAME') end append_t(symbols_enabled, symbols_hostname) timeout = math.max(timeout, rspamd_config:get_dns_timeout()) diff --git a/src/plugins/lua/history_redis.lua b/src/plugins/lua/history_redis.lua index a3fdb0ec4..44eb40ad9 100644 --- a/src/plugins/lua/history_redis.lua +++ b/src/plugins/lua/history_redis.lua @@ -138,7 +138,7 @@ end 
local function history_save(task) local function redis_llen_cb(err, _) if err then - rspamd_logger.errx(task, 'got error %s when writing history row: %s', + rspamd_logger.errx(task, 'got error %s when writing history row', err) end end @@ -188,7 +188,7 @@ local function handle_history_request(task, conn, from, to, reset) if reset then local function redis_ltrim_cb(err, _) if err then - rspamd_logger.errx(task, 'got error %s when resetting history: %s', + rspamd_logger.errx(task, 'got error %s when resetting history', err) conn:send_error(504, '{"error": "' .. err .. '"}') else @@ -258,7 +258,7 @@ local function handle_history_request(task, conn, from, to, reset) (rspamd_util:get_ticks() - t1) * 1000.0) collectgarbage() else - rspamd_logger.errx(task, 'got error %s when getting history: %s', + rspamd_logger.errx(task, 'got error %s when getting history', err) conn:send_error(504, '{"error": "' .. err .. '"}') end diff --git a/src/plugins/lua/known_senders.lua b/src/plugins/lua/known_senders.lua index 5cb2ddcf5..0cbf3cdcf 100644 --- a/src/plugins/lua/known_senders.lua +++ b/src/plugins/lua/known_senders.lua @@ -106,21 +106,26 @@ local function configure_scripts(_, _, _) -- script checks if given recipients are in the local replies set of the sender local redis_zscore_script = [[ local replies_recipients_addrs = ARGV - if replies_recipients_addrs then + if replies_recipients_addrs and #replies_recipients_addrs > 0 then + local found = false for _, rcpt in ipairs(replies_recipients_addrs) do local score = redis.call('ZSCORE', KEYS[1], rcpt) - -- check if score is nil (for some reason redis script does not see if score is a nil value) - if type(score) == 'boolean' then - score = nil - -- 0 is stand for failure code - return 0 + if score then + -- If we found at least one recipient, consider it a match + found = true + break end end - -- first number in return statement is stands for the success/failure code - -- where success code is 1 and failure code is 0 - return 1 + 
+ if found then + -- Success code is 1 + return 1 + else + -- Failure code is 0 + return 0 + end else - -- 0 is a failure code + -- No recipients to check, failure code is 0 return 0 end ]] @@ -259,7 +264,13 @@ local function verify_local_replies_set(task) return nil end - local replies_recipients = task:get_recipients('mime') or E + local replies_recipients = task:get_recipients('smtp') or E + + -- If no recipients, don't proceed + if #replies_recipients == 0 then + lua_util.debugm(N, task, 'No recipients to verify') + return nil + end local replies_sender_string = lua_util.maybe_obfuscate_string(tostring(replies_sender), settings, settings.sender_prefix) @@ -268,13 +279,16 @@ local function verify_local_replies_set(task) local function redis_zscore_script_cb(err, data) if err ~= nil then rspamd_logger.errx(task, 'Could not verify %s local replies set %s', replies_sender_key, err) - end - if data ~= 1 then - lua_util.debugm(N, task, 'Recipients were not verified') return end - lua_util.debugm(N, task, 'Recipients were verified') - task:insert_result(settings.symbol_check_mail_local, 1.0, replies_sender_key) + + -- We need to ensure we're properly checking the result + if data == 1 then + lua_util.debugm(N, task, 'Recipients were verified') + task:insert_result(settings.symbol_check_mail_local, 1.0, replies_sender_key) + else + lua_util.debugm(N, task, 'Recipients were not verified, data=%s', data) + end end local replies_recipients_addrs = {} @@ -284,12 +298,24 @@ local function verify_local_replies_set(task) table.insert(replies_recipients_addrs, replies_recipients[i].addr) end - lua_util.debugm(N, task, 'Making redis request to local replies set') - lua_redis.exec_redis_script(zscore_script_id, + -- Only proceed if we have recipients to check + if #replies_recipients_addrs == 0 then + lua_util.debugm(N, task, 'No recipient addresses to verify') + return nil + end + + lua_util.debugm(N, task, 'Making redis request to local replies set with key %s and recipients 
%s', + replies_sender_key, table.concat(replies_recipients_addrs, ", ")) + + local ret = lua_redis.exec_redis_script(zscore_script_id, { task = task, is_write = true }, redis_zscore_script_cb, { replies_sender_key }, replies_recipients_addrs) + + if not ret then + rspamd_logger.errx(task, "redis script request wasn't scheduled") + end end local function check_known_incoming_mail_callback(task) diff --git a/src/plugins/lua/milter_headers.lua b/src/plugins/lua/milter_headers.lua index 2daeeed78..17fc90562 100644 --- a/src/plugins/lua/milter_headers.lua +++ b/src/plugins/lua/milter_headers.lua @@ -138,7 +138,7 @@ local function milter_headers(task) local function skip_wanted(hdr) if settings_override then - return true + return false end -- Normal checks local function match_extended_headers_rcpt() diff --git a/src/plugins/lua/mime_types.lua b/src/plugins/lua/mime_types.lua index c69fa1e7b..73cd63c6a 100644 --- a/src/plugins/lua/mime_types.lua +++ b/src/plugins/lua/mime_types.lua @@ -128,6 +128,7 @@ local settings = { inf = 4, its = 4, jnlp = 4, + ['library-ms'] = 4, lnk = 4, ksh = 4, mad = 4, @@ -179,6 +180,7 @@ local settings = { reg = 4, scf = 4, scr = 4, + ['search-ms'] = 4, shs = 4, theme = 4, url = 4, @@ -406,9 +408,9 @@ local function check_mime_type(task) local score2 = check_tables(ext2) -- Check if detected extension match real extension if detected_ext and detected_ext == ext then - check_extension(score1, nil) + check_extension(score1, nil) else - check_extension(score1, score2) + check_extension(score1, score2) end -- Check for archive cloaking like .zip.gz if settings['archive_extensions'][ext2] diff --git a/src/plugins/lua/multimap.lua b/src/plugins/lua/multimap.lua index b96c105b1..0c82b167e 100644 --- a/src/plugins/lua/multimap.lua +++ b/src/plugins/lua/multimap.lua @@ -1282,7 +1282,7 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: 
%1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end elseif newrule['type'] == 'received' then @@ -1303,7 +1303,7 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end else @@ -1312,7 +1312,7 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end end @@ -1328,11 +1328,14 @@ local function add_multimap_rule(key, newrule) if newrule.map_obj then ret = true else - rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %1', + rspamd_logger.warnx(rspamd_config, 'Cannot add rule: map doesn\'t exists: %s', newrule['map']) end elseif newrule['type'] == 'dnsbl' then ret = true + else + rspamd_logger.errx(rspamd_config, 'cannot add rule %s: invalid type %s', + key, newrule['type']) end end diff --git a/src/plugins/lua/ratelimit.lua b/src/plugins/lua/ratelimit.lua index c20e61b17..d463658fa 100644 --- a/src/plugins/lua/ratelimit.lua +++ b/src/plugins/lua/ratelimit.lua @@ -373,7 +373,7 @@ local function ratelimit_cb(task) local function gen_check_cb(prefix, bucket, lim_name, lim_key) return function(err, data) if err then - rspamd_logger.errx('cannot check limit %s: %s %s', prefix, err, data) + rspamd_logger.errx('cannot check limit %s: %s', prefix, err) elseif type(data) == 'table' and data[1] then lua_util.debugm(N, task, "got reply for limit %s (%s / %s); %s burst, %s:%s dyn, %s leaked", @@ -416,7 +416,7 @@ local function ratelimit_cb(task) task:set_pre_result('soft reject', message_func(task, lim_name, prefix, bucket, lim_key), N) else - task:set_pre_result('soft reject', 
bucket.message) + task:set_pre_result('soft reject', bucket.message, N) end end end @@ -476,7 +476,7 @@ local function maybe_cleanup_pending(task) local bucket = v.bucket local function cleanup_cb(err, data) if err then - rspamd_logger.errx('cannot cleanup limit %s: %s %s', k, err, data) + rspamd_logger.errx('cannot cleanup limit %s: %s', k, err) else lua_util.debugm(N, task, 'cleaned pending bucked for %s: %s', k, data) end diff --git a/src/plugins/lua/rbl.lua b/src/plugins/lua/rbl.lua index af4a4cd15..b5b904b00 100644 --- a/src/plugins/lua/rbl.lua +++ b/src/plugins/lua/rbl.lua @@ -1077,7 +1077,7 @@ local function add_rbl(key, rbl, global_opts) rbl.selector_flatten) if not sel then - rspamd_logger.errx('invalid selector for rbl rule %s: %s', key, selector) + rspamd_logger.errx(rspamd_config, 'invalid selector for rbl rule %s: %s', key, selector) return false end diff --git a/src/plugins/lua/replies.lua b/src/plugins/lua/replies.lua index 08fb68bc7..2f0153d00 100644 --- a/src/plugins/lua/replies.lua +++ b/src/plugins/lua/replies.lua @@ -79,8 +79,8 @@ local function configure_redis_scripts(_, _) end ]] local set_script_zadd_global = lua_util.jinja_template(redis_script_zadd_global, - { max_global_size = settings.max_global_size }) - global_replies_set_script = lua_redis.add_redis_script(set_script_zadd_global, redis_params) + { max_global_size = settings.max_global_size }) + global_replies_set_script = lua_redis.add_redis_script(set_script_zadd_global, redis_params) local redis_script_zadd_local = [[ redis.call('ZREMRANGEBYRANK', KEYS[1], 0, -({= max_local_size =} + 1)) -- keeping size of local replies set @@ -102,7 +102,7 @@ local function configure_redis_scripts(_, _) end ]] local set_script_zadd_local = lua_util.jinja_template(redis_script_zadd_local, - { expire_time = settings.expire, max_local_size = settings.max_local_size }) + { expire_time = settings.expire, max_local_size = settings.max_local_size }) local_replies_set_script = 
lua_redis.add_redis_script(set_script_zadd_local, redis_params) end @@ -110,7 +110,7 @@ local function replies_check(task) local in_reply_to local function check_recipient(stored_rcpt) - local rcpts = task:get_recipients('mime') + local rcpts = task:get_recipients('smtp') lua_util.debugm(N, task, 'recipients: %s', rcpts) if rcpts then local filter_predicate = function(input_rcpt) @@ -119,7 +119,7 @@ local function replies_check(task) return real_rcpt_h == stored_rcpt end - if fun.any(filter_predicate, fun.map(function(rcpt) + if fun.all(filter_predicate, fun.map(function(rcpt) return rcpt.addr or '' end, rcpts)) then lua_util.debugm(N, task, 'reply to %s validated', in_reply_to) @@ -155,9 +155,9 @@ local function replies_check(task) end lua_redis.exec_redis_script(global_replies_set_script, - { task = task, is_write = true }, - zadd_global_set_cb, - { global_key }, params) + { task = task, is_write = true }, + zadd_global_set_cb, + { global_key }, params) end local function add_to_replies_set(recipients) @@ -173,7 +173,7 @@ local function replies_check(task) local params = recipients lua_util.debugm(N, task, - 'Adding recipients %s to sender %s local replies set', recipients, sender_key) + 'Adding recipients %s to sender %s local replies set', recipients, sender_key) local function zadd_cb(err, _) if err ~= nil then @@ -189,9 +189,9 @@ local function replies_check(task) table.insert(params, 1, task_time_str) lua_redis.exec_redis_script(local_replies_set_script, - { task = task, is_write = true }, - zadd_cb, - { sender_key }, params) + { task = task, is_write = true }, + zadd_cb, + { sender_key }, params) end local function redis_get_cb(err, data, addr) @@ -387,7 +387,7 @@ if opts then end lua_redis.register_prefix(settings.sender_prefix, N, - 'Prefix to identify replies sets') + 'Prefix to identify replies sets') local id = rspamd_config:register_symbol({ name = 'REPLIES_CHECK', diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua index 
bd7d91932..eacaee064 100644 --- a/src/plugins/lua/reputation.lua +++ b/src/plugins/lua/reputation.lua @@ -200,7 +200,9 @@ local function dkim_reputation_filter(task, rule) end end - if sel_tld and requests[sel_tld] then + if rule.selector.config.exclusion_map and sel_tld and rule.selector.config.exclusion_map:get_key(sel_tld) then + lua_util.debugm(N, task, 'DKIM domain %s is excluded from reputation scoring', sel_tld) + elseif sel_tld and requests[sel_tld] then if requests[sel_tld] == 'a' then rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task) end @@ -243,9 +245,13 @@ local function dkim_reputation_idempotent(task, rule) if sc then for dom, res in pairs(requests) do - -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs - local query = string.format('%s.%s', dom, res) - rule.backend.set_token(task, rule, nil, query, sc) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(dom) then + lua_util.debugm(N, task, 'DKIM domain %s is excluded from reputation update', dom) + else + -- tld + "." + check_result, e.g. 
example.com.+ - reputation for valid sigs + local query = string.format('%s.%s', dom, res) + rule.backend.set_token(task, rule, nil, query, sc) + end end end end @@ -277,6 +283,7 @@ local dkim_selector = { outbound = true, inbound = true, max_accept_adjustment = 2.0, -- How to adjust accepted DKIM score + exclusion_map = nil }, dependencies = { "DKIM_TRACE" }, filter = dkim_reputation_filter, -- used to get scores @@ -356,10 +363,14 @@ local function url_reputation_filter(task, rule) for i, res in pairs(results) do local req = requests[i] if req then - local url_score = generic_reputation_calc(res, rule, - req[2] / mhits, task) - lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score) - score = score + url_score + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(req[1]) then + lua_util.debugm(N, task, 'URL domain %s is excluded from reputation scoring', req[1]) + else + local url_score = generic_reputation_calc(res, rule, + req[2] / mhits, task) + lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score) + score = score + url_score + end end end @@ -386,7 +397,11 @@ local function url_reputation_idempotent(task, rule) if sc then for _, tld in ipairs(requests) do - rule.backend.set_token(task, rule, nil, tld[1], sc) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(tld[1]) then + lua_util.debugm(N, task, 'URL domain %s is excluded from reputation update', tld[1]) + else + rule.backend.set_token(task, rule, nil, tld[1], sc) + end end end end @@ -401,6 +416,7 @@ local url_selector = { check_from = true, outbound = true, inbound = true, + exclusion_map = nil }, filter = url_reputation_filter, -- used to get scores idempotent = url_reputation_idempotent -- used to set scores @@ -439,6 +455,11 @@ local function ip_reputation_filter(task, rule) ip = ip:apply_mask(cfg.ipv6_mask) end + if cfg.exclusion_map and cfg.exclusion_map:get_key(ip) 
then + lua_util.debugm(N, task, 'IP %s is excluded from reputation scoring', tostring(ip)) + return + end + local pool = task:get_mempool() local asn = pool:get_variable("asn") local country = pool:get_variable("country") @@ -554,6 +575,11 @@ local function ip_reputation_idempotent(task, rule) ip = ip:apply_mask(cfg.ipv6_mask) end + if cfg.exclusion_map and cfg.exclusion_map:get_key(ip) then + lua_util.debugm(N, task, 'IP %s is excluded from reputation update', tostring(ip)) + return + end + local pool = task:get_mempool() local asn = pool:get_variable("asn") local country = pool:get_variable("country") @@ -600,6 +626,7 @@ local ip_selector = { inbound = true, ipv4_mask = 32, -- Mask bits for ipv4 ipv6_mask = 64, -- Mask bits for ipv6 + exclusion_map = nil }, --dependencies = {"ASN"}, -- ASN is a prefilter now... init = ip_reputation_init, @@ -621,6 +648,11 @@ local function spf_reputation_filter(task, rule) local cr = require "rspamd_cryptobox_hash" local hkey = cr.create(spf_record):base32():sub(1, 32) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(hkey) then + lua_util.debugm(N, task, 'SPF record %s is excluded from reputation scoring', hkey) + return + end + lua_util.debugm(N, task, 'check spf record %s -> %s', spf_record, hkey) local function tokens_cb(err, token, values) @@ -649,6 +681,11 @@ local function spf_reputation_idempotent(task, rule) local cr = require "rspamd_cryptobox_hash" local hkey = cr.create(spf_record):base32():sub(1, 32) + if rule.selector.config.exclusion_map and rule.selector.config.exclusion_map:get_key(hkey) then + lua_util.debugm(N, task, 'SPF record %s is excluded from reputation update', hkey) + return + end + lua_util.debugm(N, task, 'set spf record %s -> %s = %s', spf_record, hkey, sc) rule.backend.set_token(task, rule, nil, hkey, sc) @@ -663,6 +700,7 @@ local spf_selector = { max_score = nil, outbound = true, inbound = true, + exclusion_map = nil }, dependencies = { "R_SPF_ALLOW" }, filter = 
spf_reputation_filter, -- used to get scores @@ -697,6 +735,13 @@ local function generic_reputation_init(rule) 'Whitelisted selectors') end + if cfg.exclusion_map then + cfg.exclusion_map = lua_maps.map_add('reputation', + 'generic_exclusion', + 'set', + 'Excluded selectors') + end + return true end @@ -706,6 +751,10 @@ local function generic_reputation_filter(task, rule) local function tokens_cb(err, token, values) if values then + if cfg.exclusion_map and cfg.exclusion_map:get_key(token) then + lua_util.debugm(N, task, 'Generic selector token %s is excluded from reputation scoring', token) + return + end local score = generic_reputation_calc(values, rule, 1.0, task) if math.abs(score) > 1e-3 then @@ -742,14 +791,22 @@ local function generic_reputation_idempotent(task, rule) if sc then if type(selector_res) == 'table' then fun.each(function(e) - lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', - rule['symbol'], e, sc) - rule.backend.set_token(task, rule, nil, e, sc) + if cfg.exclusion_map and cfg.exclusion_map:get_key(e) then + lua_util.debugm(N, task, 'Generic selector token %s is excluded from reputation update', e) + else + lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', + rule['symbol'], e, sc) + rule.backend.set_token(task, rule, nil, e, sc) + end end, selector_res) else - lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', - rule['symbol'], selector_res, sc) - rule.backend.set_token(task, rule, nil, selector_res, sc) + if cfg.exclusion_map and cfg.exclusion_map:get_key(selector_res) then + lua_util.debugm(N, task, 'Generic selector token %s is excluded from reputation update', selector_res) + else + lua_util.debugm(N, task, 'set generic selector (%s) %s = %s', + rule['symbol'], selector_res, sc) + rule.backend.set_token(task, rule, nil, selector_res, sc) + end end end end @@ -764,6 +821,7 @@ local generic_selector = { selector = ts.string, delimiter = ts.string, whitelist = ts.one_of(lua_maps.map_schema, 
lua_maps_exprs.schema):is_optional(), + exclusion_map = ts.one_of(lua_maps.map_schema, lua_maps_exprs.schema):is_optional() }, config = { lower_bound = 10, -- minimum number of messages to be scored @@ -773,7 +831,8 @@ local generic_selector = { inbound = true, selector = nil, delimiter = ':', - whitelist = nil + whitelist = nil, + exclusion_map = nil }, init = generic_reputation_init, filter = generic_reputation_filter, -- used to get scores @@ -1107,7 +1166,7 @@ local backends = { name = '1m', mult = 1.0, } - }, -- What buckets should be used, default 1h and 1month + }, -- What buckets should be used, default 1month }, init = reputation_redis_init, get_token = reputation_redis_get_token, @@ -1267,6 +1326,24 @@ local function parse_rule(name, tbl) end end + -- Parse exclusion_map for reputation exclusion lists + if rule.config.exclusion_map then + local map_type = 'set' -- Default to set for string-based selectors (dkim, url, spf, generic) + if sel_type == 'ip' or sel_type == 'sender' then + map_type = 'radix' -- Use radix for IP-based selectors + end + local map = lua_maps.map_add_from_ucl(rule.config.exclusion_map, + map_type, + sel_type .. ' reputation exclusion map') + if not map then + rspamd_logger.errx(rspamd_config, "cannot parse exclusion map config for %s: (%s)", + sel_type, + rule.config.exclusion_map) + return false + end + rule.config.exclusion_map = map + end + local symbol = rule.selector.config.symbol or name if tbl.symbol then symbol = tbl.symbol @@ -1387,4 +1464,4 @@ if opts['rules'] then end else lua_util.disable_module(N, "config") -end +end
\ No newline at end of file diff --git a/src/plugins/lua/settings.lua b/src/plugins/lua/settings.lua index 0f8e00723..c576e1325 100644 --- a/src/plugins/lua/settings.lua +++ b/src/plugins/lua/settings.lua @@ -1275,7 +1275,7 @@ local function gen_redis_callback(handler, id) ucl_err) else local obj = parser:get_object() - rspamd_logger.infox(task, "<%1> apply settings according to redis rule %2", + rspamd_logger.infox(task, "<%s> apply settings according to redis rule %s", task:get_message_id(), id) apply_settings(task, obj, nil, 'redis') break @@ -1283,7 +1283,7 @@ local function gen_redis_callback(handler, id) end end elseif err then - rspamd_logger.errx(task, 'Redis error: %1', err) + rspamd_logger.errx(task, 'Redis error: %s', err) end end @@ -1371,7 +1371,7 @@ if set_section and set_section[1] and type(set_section[1]) == "string" then opaque_data = true } if not rspamd_config:add_map(map_attrs) then - rspamd_logger.errx(rspamd_config, 'cannot load settings from %1', set_section) + rspamd_logger.errx(rspamd_config, 'cannot load settings from %s', set_section) end elseif set_section and type(set_section) == "table" then settings_map_pool = rspamd_mempool.create() diff --git a/src/plugins/lua/spamassassin.lua b/src/plugins/lua/spamassassin.lua index 3ea794495..c03481de2 100644 --- a/src/plugins/lua/spamassassin.lua +++ b/src/plugins/lua/spamassassin.lua @@ -221,7 +221,7 @@ local function handle_header_def(hline, cur_rule) }) cur_rule['function'] = function(task) if not re then - rspamd_logger.errx(task, 're is missing for rule %1', h) + rspamd_logger.errx(task, 're is missing for rule %s', h) return 0 end @@ -272,7 +272,7 @@ local function handle_header_def(hline, cur_rule) elseif func == 'case' then cur_param['strong'] = true else - rspamd_logger.warnx(rspamd_config, 'Function %1 is not supported in %2', + rspamd_logger.warnx(rspamd_config, 'Function %s is not supported in %s', func, cur_rule['symbol']) end end, fun.tail(args)) @@ -314,7 +314,7 @@ end local 
function freemail_search(input) local res = 0 local function trie_callback(number, pos) - lua_util.debugm(N, rspamd_config, 'Matched pattern %1 at pos %2', freemail_domains[number], pos) + lua_util.debugm(N, rspamd_config, 'Matched pattern %s at pos %s', freemail_domains[number], pos) res = res + 1 end @@ -369,7 +369,7 @@ local function gen_eval_rule(arg) end return 0 else - rspamd_logger.infox(rspamd_config, 'cannot create regexp %1', re) + rspamd_logger.infox(rspamd_config, 'cannot create regexp %s', re) return 0 end end @@ -461,7 +461,7 @@ local function gen_eval_rule(arg) end end else - rspamd_logger.infox(task, 'unimplemented mime check %1', arg) + rspamd_logger.infox(task, 'unimplemented mime check %s', arg) end end @@ -576,7 +576,7 @@ local function maybe_parse_sa_function(line) local elts = split(line, '[^:]+') arg = elts[2] - lua_util.debugm(N, rspamd_config, 'trying to parse SA function %1 with args %2', + lua_util.debugm(N, rspamd_config, 'trying to parse SA function %s with args %s', elts[1], elts[2]) local substitutions = { { '^exists:', @@ -612,7 +612,7 @@ local function maybe_parse_sa_function(line) end if not func then - rspamd_logger.errx(task, 'cannot find appropriate eval rule for function %1', + rspamd_logger.errx(task, 'cannot find appropriate eval rule for function %s', arg) else return func(task) @@ -685,7 +685,7 @@ local function process_sa_conf(f) end -- We have previous rule valid if not cur_rule['symbol'] then - rspamd_logger.errx(rspamd_config, 'bad rule definition: %1', cur_rule) + rspamd_logger.errx(rspamd_config, 'bad rule definition: %s', cur_rule) end rules[cur_rule['symbol']] = cur_rule cur_rule = {} @@ -695,15 +695,15 @@ local function process_sa_conf(f) local function parse_score(words) if #words == 3 then -- score rule <x> - lua_util.debugm(N, rspamd_config, 'found score for %1: %2', words[2], words[3]) + lua_util.debugm(N, rspamd_config, 'found score for %s: %s', words[2], words[3]) return tonumber(words[3]) elseif #words == 6 
then -- score rule <x1> <x2> <x3> <x4> -- we assume here that bayes and network are enabled and select <x4> - lua_util.debugm(N, rspamd_config, 'found score for %1: %2', words[2], words[6]) + lua_util.debugm(N, rspamd_config, 'found score for %s: %s', words[2], words[6]) return tonumber(words[6]) else - rspamd_logger.errx(rspamd_config, 'invalid score for %1', words[2]) + rspamd_logger.errx(rspamd_config, 'invalid score for %s', words[2]) end return 0 @@ -812,7 +812,7 @@ local function process_sa_conf(f) cur_rule['re'] = rspamd_regexp.create(cur_rule['re_expr']) if not cur_rule['re'] then - rspamd_logger.warnx(rspamd_config, "Cannot parse regexp '%1' for %2", + rspamd_logger.warnx(rspamd_config, "Cannot parse regexp '%s' for %s", cur_rule['re_expr'], cur_rule['symbol']) else cur_rule['re']:set_max_hits(1) @@ -829,8 +829,8 @@ local function process_sa_conf(f) cur_rule['mime'] = false end - if cur_rule['re'] and cur_rule['symbol'] and - (cur_rule['header'] or cur_rule['function']) then + if cur_rule['re'] and cur_rule['symbol'] + and (cur_rule['header'] or cur_rule['function']) then valid_rule = true cur_rule['re']:set_max_hits(1) if cur_rule['header'] and cur_rule['ordinary'] then @@ -894,7 +894,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "body" then @@ -931,7 +931,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "rawbody" then @@ -968,7 +968,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end 
elseif words[1] == "full" then @@ -1006,7 +1006,7 @@ local function process_sa_conf(f) cur_rule['function'] = func valid_rule = true else - rspamd_logger.infox(rspamd_config, 'unknown function %1', args) + rspamd_logger.infox(rspamd_config, 'unknown function %s', args) end end elseif words[1] == "uri" then @@ -1265,11 +1265,11 @@ local function post_process() if res then local nre = rspamd_regexp.create(nexpr) if not nre then - rspamd_logger.errx(rspamd_config, 'cannot apply replacement for rule %1', r) + rspamd_logger.errx(rspamd_config, 'cannot apply replacement for rule %s', r) --rule['re'] = nil else local old_max_hits = rule['re']:get_max_hits() - lua_util.debugm(N, rspamd_config, 'replace %1 -> %2', r, nexpr) + lua_util.debugm(N, rspamd_config, 'replace %s -> %s', r, nexpr) rspamd_config:replace_regexp({ old_re = rule['re'], new_re = nre, @@ -1306,8 +1306,7 @@ local function post_process() end if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1 (%2 header)', k, - h['header']) + rspamd_logger.errx(task, 're is missing for rule %s', h) return 0 end @@ -1434,7 +1433,7 @@ local function post_process() fun.each(function(k, r) local f = function(task) if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1', k) + rspamd_logger.errx(task, 're is missing for rule %s', k) return 0 end @@ -1461,7 +1460,7 @@ local function post_process() fun.each(function(k, r) local f = function(task) if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1', k) + rspamd_logger.errx(task, 're is missing for rule %s', k) return 0 end @@ -1486,7 +1485,7 @@ local function post_process() fun.each(function(k, r) local f = function(task) if not r['re'] then - rspamd_logger.errx(task, 're is missing for rule %1', k) + rspamd_logger.errx(task, 're is missing for rule %s', k) return 0 end @@ -1629,8 +1628,8 @@ local function post_process() rspamd_config:register_dependency(k, rspamd_symbol) external_deps[k][rspamd_symbol] = true 
lua_util.debugm(N, rspamd_config, - 'atom %1 is a direct foreign dependency, ' .. - 'register dependency for %2 on %3', + 'atom %s is a direct foreign dependency, ' .. + 'register dependency for %s on %s', a, k, rspamd_symbol) end end @@ -1659,8 +1658,8 @@ local function post_process() rspamd_config:register_dependency(k, dep) external_deps[k][dep] = true lua_util.debugm(N, rspamd_config, - 'atom %1 is an indirect foreign dependency, ' .. - 'register dependency for %2 on %3', + 'atom %s is an indirect foreign dependency, ' .. + 'register dependency for %s on %s', a, k, dep) nchanges = nchanges + 1 end @@ -1694,10 +1693,10 @@ local function post_process() -- Logging output if freemail_domains then freemail_trie = rspamd_trie.create(freemail_domains) - rspamd_logger.infox(rspamd_config, 'loaded %1 freemail domains definitions', + rspamd_logger.infox(rspamd_config, 'loaded %s freemail domains definitions', #freemail_domains) end - rspamd_logger.infox(rspamd_config, 'loaded %1 blacklist/whitelist elements', + rspamd_logger.infox(rspamd_config, 'loaded %s blacklist/whitelist elements', sa_lists['elts']) end @@ -1739,7 +1738,7 @@ if type(section) == "table" then process_sa_conf(f) has_rules = true else - rspamd_logger.errx(rspamd_config, "cannot open %1", matched) + rspamd_logger.errx(rspamd_config, "cannot open %s", matched) end end end @@ -1758,7 +1757,7 @@ if type(section) == "table" then process_sa_conf(f) has_rules = true else - rspamd_logger.errx(rspamd_config, "cannot open %1", matched) + rspamd_logger.errx(rspamd_config, "cannot open %s", matched) end end end diff --git a/src/plugins/lua/trie.lua b/src/plugins/lua/trie.lua index 7ba455289..7c7214b55 100644 --- a/src/plugins/lua/trie.lua +++ b/src/plugins/lua/trie.lua @@ -107,10 +107,10 @@ local function process_trie_file(symbol, cf) local file = io.open(cf['file']) if not file then - rspamd_logger.errx(rspamd_config, 'Cannot open trie file %1', cf['file']) + rspamd_logger.errx(rspamd_config, 'Cannot open trie 
file %s', cf['file']) else if cf['binary'] then - rspamd_logger.errx(rspamd_config, 'binary trie patterns are not implemented yet: %1', + rspamd_logger.errx(rspamd_config, 'binary trie patterns are not implemented yet: %s', cf['file']) else for line in file:lines() do @@ -123,7 +123,7 @@ end local function process_trie_conf(symbol, cf) if type(cf) ~= 'table' then - rspamd_logger.errx(rspamd_config, 'invalid value for symbol %1: "%2", expected table', + rspamd_logger.errx(rspamd_config, 'invalid value for symbol %s: "%s", expected table', symbol, cf) return end @@ -145,17 +145,17 @@ if opts then if #raw_patterns > 0 then raw_trie = rspamd_trie.create(raw_patterns) - rspamd_logger.infox(rspamd_config, 'registered raw search trie from %1 patterns', #raw_patterns) + rspamd_logger.infox(rspamd_config, 'registered raw search trie from %s patterns', #raw_patterns) end if #mime_patterns > 0 then mime_trie = rspamd_trie.create(mime_patterns) - rspamd_logger.infox(rspamd_config, 'registered mime search trie from %1 patterns', #mime_patterns) + rspamd_logger.infox(rspamd_config, 'registered mime search trie from %s patterns', #mime_patterns) end if #body_patterns > 0 then body_trie = rspamd_trie.create(body_patterns) - rspamd_logger.infox(rspamd_config, 'registered body search trie from %1 patterns', #body_patterns) + rspamd_logger.infox(rspamd_config, 'registered body search trie from %s patterns', #body_patterns) end local id = -1 diff --git a/src/rspamadm/configdump.c b/src/rspamadm/configdump.c index 456875cf2..d090b66f0 100644 --- a/src/rspamadm/configdump.c +++ b/src/rspamadm/configdump.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -31,6 +31,8 @@ static gboolean symbol_groups_only = FALSE; static gboolean symbol_full_details = FALSE; static gboolean skip_template = FALSE; static char *config = NULL; +static gboolean local_conf_only = FALSE; +static gboolean override_conf_only = FALSE; extern struct rspamd_main *rspamd_main; /* Defined in modules.c */ extern module_t *modules[]; @@ -66,6 +68,8 @@ static GOptionEntry entries[] = { "Show full symbol details only", NULL}, {"skip-template", 'T', 0, G_OPTION_ARG_NONE, &skip_template, "Do not apply Jinja templates", NULL}, + {"local", 0, 0, G_OPTION_ARG_NONE, &local_conf_only, "Show only local and override configuration", NULL}, + {"override", 0, 0, G_OPTION_ARG_NONE, &override_conf_only, "Show only override configuration", NULL}, {NULL, 0, 0, G_OPTION_ARG_NONE, NULL, NULL, NULL}}; static const char * @@ -82,6 +86,8 @@ rspamadm_configdump_help(gboolean full_help, const struct rspamadm_command *cmd) "-c: config file to test\n" "-m: show state of modules only\n" "-h: show help for dumped options\n" + "--local: show only local (and override) configuration\n" + "--override: show only override configuration\n" "--help: shows available options and commands"; } else { @@ -96,6 +102,84 @@ config_logger(rspamd_mempool_t *pool, gpointer ud) { } +static ucl_object_t * +filter_non_default(const ucl_object_t *obj, bool override_only) +{ + ucl_object_t *result = NULL; + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + + if (obj == NULL) { + return NULL; + } + + int min_prio = override_only ? 
1 : 0; + + if (ucl_object_get_priority(obj) > min_prio) { + + switch (ucl_object_type(obj)) { + case UCL_OBJECT: + result = ucl_object_typed_new(ucl_object_type(obj)); + + while ((cur = ucl_object_iterate(obj, &it, true))) { + ucl_object_t *filtered = filter_non_default(cur, override_conf_only); + if (filtered) { + ucl_object_insert_key(result, filtered, ucl_object_key(cur), cur->keylen, true); + } + } + break; + case UCL_ARRAY: + result = ucl_object_typed_new(ucl_object_type(obj)); + + while ((cur = ucl_object_iterate(obj, &it, true))) { + ucl_object_t *filtered = filter_non_default(cur, override_conf_only); + if (filtered) { + ucl_array_append(result, filtered); + } + } + default: + result = ucl_object_ref(obj); + break; + } + + return result; + } + + if (ucl_object_type(obj) == UCL_OBJECT || ucl_object_type(obj) == UCL_ARRAY) { + bool has_non_default = false; + + result = ucl_object_typed_new(ucl_object_type(obj)); + while ((cur = ucl_object_iterate(obj, &it, true))) { + ucl_object_t *filtered = filter_non_default(cur, override_only); + if (filtered) { + has_non_default = true; + + if (ucl_object_type(obj) == UCL_OBJECT) { + ucl_object_insert_key(result, filtered, + ucl_object_key(cur), cur->keylen, true); + } + else if (ucl_object_type(obj) == UCL_ARRAY) { + ucl_array_append(result, filtered); + } + else { + g_assert_not_reached(); + } + } + } + + /* Avoid empty objects */ + if (!has_non_default) { + ucl_object_unref(result); + result = NULL; + } + + return result; + } + + + return NULL; +} + static void rspamadm_add_doc_elt(const ucl_object_t *obj, const ucl_object_t *doc_obj, ucl_object_t *comment_obj) @@ -524,7 +608,20 @@ rspamadm_configdump(int argc, char **argv, const struct rspamadm_command *cmd) /* Output configuration */ if (argc == 1) { - rspamadm_dump_section_obj(cfg, cfg->cfg_ucl_obj, cfg->doc_strings); + const ucl_object_t *output_obj = cfg->cfg_ucl_obj; + if (local_conf_only || override_conf_only) { + output_obj = 
filter_non_default(cfg->cfg_ucl_obj, override_conf_only); + if (!output_obj) { + rspamd_printf("No non-default configuration found\n"); + exit(EXIT_SUCCESS); + } + } + + rspamadm_dump_section_obj(cfg, output_obj, cfg->doc_strings); + + if (local_conf_only || override_conf_only) { + ucl_object_unref((ucl_object_t *) output_obj); + } } else { for (i = 1; i < argc; i++) { @@ -537,10 +634,18 @@ rspamadm_configdump(int argc, char **argv, const struct rspamadm_command *cmd) else { LL_FOREACH(obj, cur) { + const ucl_object_t *output_obj = cur; + if (local_conf_only || override_conf_only) { + output_obj = filter_non_default(cur, override_conf_only); + if (!output_obj) { + rspamd_printf("No non-default configuration found for section %s\n", argv[i]); + continue; + } + } if (!json && !compact) { rspamd_printf("*** Section %s ***\n", argv[i]); } - rspamadm_dump_section_obj(cfg, cur, doc_obj); + rspamadm_dump_section_obj(cfg, output_obj, doc_obj); if (!json && !compact) { rspamd_printf("\n*** End of section %s ***\n", argv[i]); @@ -548,6 +653,10 @@ rspamadm_configdump(int argc, char **argv, const struct rspamadm_command *cmd) else { rspamd_printf("\n"); } + + if (local_conf_only || override_conf_only) { + ucl_object_unref((ucl_object_t *) output_obj); + } } } } diff --git a/src/rspamadm/control.c b/src/rspamadm/control.c index 381bdaa7a..cd550c04e 100644 --- a/src/rspamadm/control.c +++ b/src/rspamadm/control.c @@ -1,11 +1,11 @@ -/*- - * Copyright 2016 Vsevolod Stakhov +/* + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -112,7 +112,7 @@ rspamd_control_finish_handler(struct rspamd_http_connection *conn, struct rspamadm_control_cbdata *cbdata = conn->ud; body = rspamd_http_message_get_body(msg, &body_len); - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!body || !ucl_parser_add_chunk(parser, body, body_len)) { rspamd_fprintf(stderr, "cannot parse server's reply: %s\n", diff --git a/src/rspamadm/lua_repl.c b/src/rspamadm/lua_repl.c index 1d6da5aa9..f9099d895 100644 --- a/src/rspamadm/lua_repl.c +++ b/src/rspamadm/lua_repl.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,9 +24,8 @@ #include "lua/lua_thread_pool.h" #include "message.h" #include "unix-std.h" -#ifdef WITH_LUA_REPL + #include "replxx.h" -#endif #include "worker_util.h" #ifdef WITH_LUAJIT #include <luajit.h> @@ -43,10 +42,7 @@ static int batch = -1; extern struct rspamd_async_session *rspamadm_session; static const char *default_history_file = ".rspamd_repl.hist"; - -#ifdef WITH_LUA_REPL static Replxx *rx_instance = NULL; -#endif #ifdef WITH_LUAJIT #define MAIN_PROMPT LUAJIT_VERSION "> " @@ -232,7 +228,6 @@ rspamadm_exec_input(lua_State *L, const char *input) int i, cbref; int top = 0; char outbuf[8192]; - struct lua_logger_trace tr; struct thread_entry *thread = lua_thread_pool_get_for_config(rspamd_main->cfg); L = thread->lua_state; @@ -272,9 +267,8 @@ rspamadm_exec_input(lua_State *L, const char *input) rspamd_printf("local function: %d\n", cbref); } else { - memset(&tr, 0, sizeof(tr)); - lua_logger_out_type(L, i, outbuf, sizeof(outbuf) - 1, &tr, - LUA_ESCAPE_UNPRINTABLE); + lua_logger_out(L, i, outbuf, sizeof(outbuf), + LUA_ESCAPE_UNPRINTABLE); rspamd_printf("%s\n", outbuf); } } @@ -393,7 +387,6 @@ rspamadm_lua_message_handler(lua_State *L, int argc, char **argv) gpointer map; gsize len; char outbuf[8192]; - struct lua_logger_trace tr; if (argv[1] == NULL) { rspamd_printf("no callback is specified\n"); @@ -455,9 +448,8 @@ rspamadm_lua_message_handler(lua_State *L, int argc, char **argv) rspamd_printf("lua callback for %s returned:\n", argv[i]); for (j = old_top + 1; j <= lua_gettop(L); j++) { - memset(&tr, 0, sizeof(tr)); - lua_logger_out_type(L, j, outbuf, sizeof(outbuf), &tr, - LUA_ESCAPE_UNPRINTABLE); + lua_logger_out(L, j, outbuf, sizeof(outbuf), + LUA_ESCAPE_UNPRINTABLE); rspamd_printf("%s\n", outbuf); } } @@ -503,7 +495,6 @@ rspamadm_lua_try_dot_command(lua_State *L, const char *input) return FALSE; } -#ifdef WITH_LUA_REPL static int lex_ref_idx = -1; static void @@ -599,20 +590,14 @@ lua_syntax_highlighter(const char *str, ReplxxColor *colours, int 
size, void *ud lua_settop(L, 0); } -#endif static void rspamadm_lua_run_repl(lua_State *L, bool is_batch) { char *input = NULL; -#ifdef WITH_LUA_REPL gboolean is_multiline = FALSE; GString *tb = NULL; gsize i; -#else - /* Always set is_batch */ - is_batch = TRUE; -#endif for (;;) { if (is_batch) { @@ -644,7 +629,6 @@ rspamadm_lua_run_repl(lua_State *L, bool is_batch) lua_settop(L, 0); } else { -#ifdef WITH_LUA_REPL replxx_set_highlighter_callback(rx_instance, lua_syntax_highlighter, L); @@ -706,7 +690,6 @@ rspamadm_lua_run_repl(lua_State *L, bool is_batch) } } } -#endif } } @@ -1009,16 +992,12 @@ rspamadm_lua(int argc, char **argv, const struct rspamadm_command *cmd) } if (!batch) { -#ifdef WITH_LUA_REPL rx_instance = replxx_init(); replxx_set_max_history_size(rx_instance, max_history); replxx_history_load(rx_instance, histfile); -#endif rspamadm_lua_run_repl(L, false); -#ifdef WITH_LUA_REPL replxx_history_save(rx_instance, histfile); replxx_end(rx_instance); -#endif } else { rspamadm_lua_run_repl(L, true); diff --git a/src/rspamadm/signtool.c b/src/rspamadm/signtool.c index 6d60e6700..538767b19 100644 --- a/src/rspamadm/signtool.c +++ b/src/rspamadm/signtool.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -573,7 +573,7 @@ rspamadm_signtool(int argc, char **argv, const struct rspamadm_command *cmd) else { g_assert(keypair_file != NULL); - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_file(parser, keypair_file) || (top = ucl_parser_get_object(parser)) == NULL) { diff --git a/src/rspamd_proxy.c b/src/rspamd_proxy.c index 694e87c12..77d2336b2 100644 --- a/src/rspamd_proxy.c +++ b/src/rspamd_proxy.c @@ -1,5 +1,5 @@ /* - * Copyright 2024 Vsevolod Stakhov + * Copyright 2025 Vsevolod Stakhov * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -85,6 +85,12 @@ worker_t rspamd_proxy_worker = { RSPAMD_WORKER_SOCKET_TCP, /* TCP socket */ RSPAMD_WORKER_VER}; +enum rspamd_proxy_log_tag_type { + RSPAMD_PROXY_LOG_TAG_SESSION = 0, /* Use session mempool tag (default) */ + RSPAMD_PROXY_LOG_TAG_QUEUE_ID, /* Use Queue-ID from client message */ + RSPAMD_PROXY_LOG_TAG_NONE, /* Skip log tag passing */ +}; + struct rspamd_http_upstream { char *name; char *settings_id; @@ -96,6 +102,10 @@ struct rspamd_http_upstream { gboolean local; gboolean self_scan; gboolean compress; + gboolean ssl; + gboolean keepalive; /* Whether to use keepalive for this upstream */ + enum rspamd_proxy_log_tag_type log_tag_type; + ucl_object_t *extra_headers; }; struct rspamd_http_mirror { @@ -109,6 +119,10 @@ struct rspamd_http_mirror { int parser_to_ref; gboolean local; gboolean compress; + gboolean ssl; + gboolean keepalive; /* Whether to use keepalive for this mirror */ + enum rspamd_proxy_log_tag_type log_tag_type; + ucl_object_t *extra_headers; }; static const uint64_t rspamd_rspamd_proxy_magic = 0xcdeb4fd1fc351980ULL; @@ -161,6 +175,8 @@ struct rspamd_proxy_ctx { /* Language detector */ struct rspamd_lang_detector *lang_det; double task_timeout; + /* Default log tag type for worker */ + enum rspamd_proxy_log_tag_type log_tag_type; struct rspamd_main *srv; }; @@ -214,6 
+230,7 @@ struct rspamd_proxy_session { enum rspamd_proxy_legacy_support legacy_support; int retries; ref_entry_t ref; + gboolean use_keepalive; /* Whether to use keepalive for this session */ }; static gboolean proxy_send_master_message(struct rspamd_proxy_session *session); @@ -224,6 +241,77 @@ rspamd_proxy_quark(void) return g_quark_from_static_string("rspamd-proxy"); } +static enum rspamd_proxy_log_tag_type +rspamd_proxy_parse_log_tag_type(const char *str) +{ + if (str == NULL) { + return RSPAMD_PROXY_LOG_TAG_SESSION; + } + + if (g_ascii_strcasecmp(str, "session") == 0 || + g_ascii_strcasecmp(str, "session_tag") == 0) { + return RSPAMD_PROXY_LOG_TAG_SESSION; + } + else if (g_ascii_strcasecmp(str, "queue_id") == 0 || + g_ascii_strcasecmp(str, "queue-id") == 0) { + return RSPAMD_PROXY_LOG_TAG_QUEUE_ID; + } + else if (g_ascii_strcasecmp(str, "none") == 0 || + g_ascii_strcasecmp(str, "skip") == 0) { + return RSPAMD_PROXY_LOG_TAG_NONE; + } + + /* Default to session tag for unknown values */ + return RSPAMD_PROXY_LOG_TAG_SESSION; +} + +static void +rspamd_proxy_add_log_tag_header(struct rspamd_http_message *msg, + struct rspamd_proxy_session *session, + enum rspamd_proxy_log_tag_type log_tag_type) +{ + const rspamd_ftok_t *queue_id_hdr; + + switch (log_tag_type) { + case RSPAMD_PROXY_LOG_TAG_SESSION: + /* Use session mempool tag (current behavior) */ + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + break; + + case RSPAMD_PROXY_LOG_TAG_QUEUE_ID: + /* Try to extract Queue-ID from client message */ + if (session->client_message) { + queue_id_hdr = rspamd_http_message_find_header(session->client_message, QUEUE_ID_HEADER); + if (queue_id_hdr) { + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, + queue_id_hdr->begin, queue_id_hdr->len); + } + /* If no Queue-ID found, fall back to session tag */ + else { + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, 
session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + } + } + else { + /* No client message, fall back to session tag */ + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + } + break; + + case RSPAMD_PROXY_LOG_TAG_NONE: + /* Skip adding log tag header */ + break; + + default: + /* Fall back to session tag for unknown types */ + rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, + strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); + break; + } +} + static gboolean rspamd_proxy_parse_lua_parser(lua_State *L, const ucl_object_t *obj, int *ref_from, int *ref_to, GError **err) @@ -392,6 +480,7 @@ rspamd_proxy_parse_upstream(rspamd_mempool_t *pool, up->parser_from_ref = -1; up->parser_to_ref = -1; up->timeout = ctx->timeout; + up->log_tag_type = ctx->log_tag_type; /* Inherit from worker default */ elt = ucl_object_lookup(obj, "key"); if (elt != NULL) { @@ -420,6 +509,21 @@ rspamd_proxy_parse_upstream(rspamd_mempool_t *pool, up->compress = TRUE; } + elt = ucl_object_lookup(obj, "ssl"); + if (elt && ucl_object_toboolean(elt)) { + up->ssl = TRUE; + } + + elt = ucl_object_lookup_any(obj, "keepalive", "keep_alive", NULL); + if (elt && ucl_object_toboolean(elt)) { + up->keepalive = TRUE; + } + + elt = ucl_object_lookup_any(obj, "keepalive", "keep_alive", NULL); + if (elt && ucl_object_toboolean(elt)) { + up->keepalive = TRUE; + } + elt = ucl_object_lookup(obj, "hosts"); if (elt == NULL && !up->self_scan) { @@ -469,6 +573,27 @@ rspamd_proxy_parse_upstream(rspamd_mempool_t *pool, up->settings_id = rspamd_mempool_strdup(pool, ucl_object_tostring(elt)); } + elt = ucl_object_lookup(obj, "extra_headers"); + if (elt && ucl_object_type(elt) == UCL_OBJECT) { + up->extra_headers = ucl_object_ref(elt); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) ucl_object_unref, + 
up->extra_headers); + } + + elt = ucl_object_lookup(obj, "extra_headers"); + if (elt && ucl_object_type(elt) == UCL_OBJECT) { + up->extra_headers = ucl_object_ref(elt); + rspamd_mempool_add_destructor(pool, + (rspamd_mempool_destruct_t) ucl_object_unref, + up->extra_headers); + } + + elt = ucl_object_lookup_any(obj, "log_tag", "log_tag_type", NULL); + if (elt && ucl_object_type(elt) == UCL_STRING) { + up->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(elt)); + } + /* * Accept lua function here in form * fun :: String -> UCL @@ -568,6 +693,7 @@ rspamd_proxy_parse_mirror(rspamd_mempool_t *pool, up->parser_to_ref = -1; up->parser_from_ref = -1; up->timeout = ctx->timeout; + up->log_tag_type = ctx->log_tag_type; /* Inherit from worker default */ elt = ucl_object_lookup(obj, "key"); if (elt != NULL) { @@ -648,6 +774,11 @@ rspamd_proxy_parse_mirror(rspamd_mempool_t *pool, up->settings_id = rspamd_mempool_strdup(pool, ucl_object_tostring(elt)); } + elt = ucl_object_lookup_any(obj, "log_tag", "log_tag_type", NULL); + if (elt && ucl_object_type(elt) == UCL_STRING) { + up->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(elt)); + } + g_ptr_array_add(ctx->mirrors, up); return TRUE; @@ -747,6 +878,29 @@ err: return FALSE; } +static gboolean +rspamd_proxy_parse_log_tag_worker_option(rspamd_mempool_t *pool, + const ucl_object_t *obj, + gpointer ud, + struct rspamd_rcl_section *section, + GError **err) +{ + struct rspamd_proxy_ctx *ctx; + struct rspamd_rcl_struct_parser *pd = ud; + + ctx = pd->user_struct; + + if (ucl_object_type(obj) != UCL_STRING) { + g_set_error(err, rspamd_proxy_quark(), 100, + "log_tag_type option must be a string"); + return FALSE; + } + + ctx->log_tag_type = rspamd_proxy_parse_log_tag_type(ucl_object_tostring(obj)); + + return TRUE; +} + gpointer init_rspamd_proxy(struct rspamd_config *cfg) { @@ -772,6 +926,7 @@ init_rspamd_proxy(struct rspamd_config *cfg) (rspamd_mempool_destruct_t) rspamd_array_free_hard, 
ctx->cmp_refs); ctx->max_retries = DEFAULT_RETRIES; ctx->spam_header = RSPAMD_MILTER_SPAM_HEADER; + ctx->log_tag_type = RSPAMD_PROXY_LOG_TAG_SESSION; /* Default to session tag */ rspamd_rcl_register_worker_option(cfg, type, @@ -895,6 +1050,16 @@ init_rspamd_proxy(struct rspamd_config *cfg) 0, "Use custom tempfail message"); + /* We need a custom parser for log_tag_type as it's an enum */ + rspamd_rcl_register_worker_option(cfg, + type, + "log_tag_type", + rspamd_proxy_parse_log_tag_worker_option, + ctx, + 0, + 0, + "Log tag type: session (default), queue_id, or none"); + return ctx; } @@ -905,7 +1070,11 @@ proxy_backend_close_connection(struct rspamd_proxy_backend_connection *conn) if (conn->backend_conn) { rspamd_http_connection_reset(conn->backend_conn); rspamd_http_connection_unref(conn->backend_conn); - close(conn->backend_sock); + + if (!(conn->s && conn->s->use_keepalive)) { + /* Only close socket if we're not using keepalive */ + close(conn->backend_sock); + } } conn->flags |= RSPAMD_BACKEND_CLOSED; @@ -970,7 +1139,7 @@ proxy_backend_parse_results(struct rspamd_proxy_session *session, RSPAMD_FTOK_ASSIGN(&json_ct, "application/json"); if (ct && rspamd_ftok_casecmp(ct, &json_ct) == 0) { - parser = ucl_parser_new(0); + parser = ucl_parser_new(UCL_PARSER_SAFE_FLAGS); if (!ucl_parser_add_chunk(parser, in, inlen)) { char *encoded; @@ -1384,6 +1553,8 @@ proxy_backend_mirror_finish_handler(struct rspamd_http_connection *conn, struct rspamd_proxy_backend_connection *bk_conn = conn->ud; struct rspamd_proxy_session *session; const rspamd_ftok_t *orig_ct; + const rspamd_ftok_t *conn_hdr; + gboolean is_keepalive = FALSE; session = bk_conn->s; @@ -1403,6 +1574,36 @@ proxy_backend_mirror_finish_handler(struct rspamd_http_connection *conn, bk_conn->name, msg->code); rspamd_upstream_ok(bk_conn->up); + /* Check if we can use keepalive */ + conn_hdr = rspamd_http_message_find_header(msg, "Connection"); + if (conn_hdr) { + if (rspamd_substring_search_caseless(conn_hdr->begin, 
conn_hdr->len, + "keep-alive", 10) != -1) { + is_keepalive = TRUE; + } + } + + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + /* Store connection in keepalive pool */ + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_prepare_keepalive(session->ctx->http_ctx, + conn, rspamd_upstream_addr_cur(bk_conn->up), + up_name, FALSE); + rspamd_http_context_push_keepalive(session->ctx->http_ctx, + conn, msg, session->ctx->event_loop); + + msg_debug_session("pushed mirror connection to %s to keepalive pool", + bk_conn->name); + + /* Mark connection as closed without actually closing it */ + bk_conn->flags |= RSPAMD_BACKEND_CLOSED; + REF_RELEASE(bk_conn->s); + return 0; + } + } + proxy_backend_close_connection(bk_conn); REF_RELEASE(bk_conn->s); @@ -1418,6 +1619,7 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) struct rspamd_proxy_backend_connection *bk_conn; struct rspamd_http_message *msg; GError *err = NULL; + const rspamd_inet_addr_t *keepalive_addr; coin = rspamd_random_double(); @@ -1429,6 +1631,157 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) continue; } + /* Check if we can use keepalive for this mirror */ + if (m->keepalive && session->ctx->http_ctx) { + const char *up_name = NULL; + unsigned int port = 0; + + /* Try to find a keepalive connection */ + if (m->u) { + struct upstream *up = rspamd_upstream_get(m->u, + RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); + if (up) { + up_name = rspamd_upstream_name(up); + port = rspamd_inet_address_get_port(rspamd_upstream_addr_cur(up)); + } + } + + if (up_name) { + keepalive_addr = rspamd_http_context_has_keepalive( + session->ctx->http_ctx, up_name, port, m->ssl); + + if (keepalive_addr) { + /* We found a keepalive connection, use it */ + struct rspamd_http_connection *conn; + + conn = rspamd_http_context_check_keepalive( + session->ctx->http_ctx, + (rspamd_inet_addr_t *) keepalive_addr, + up_name, 
+ m->ssl); + + if (conn) { + /* We have a keepalive connection, set it up */ + bk_conn = rspamd_mempool_alloc0(session->pool, sizeof(*bk_conn)); + bk_conn->s = session; + bk_conn->name = m->name; + bk_conn->timeout = m->timeout; + bk_conn->parser_from_ref = m->parser_from_ref; + bk_conn->parser_to_ref = m->parser_to_ref; + bk_conn->backend_conn = conn; + bk_conn->backend_sock = conn->fd; + + msg = rspamd_http_connection_copy_msg(session->client_message, &err); + + if (msg == NULL) { + msg_err_session("cannot copy message to send to a mirror %s: %e", + m->name, err); + if (err) { + g_error_free(err); + } + continue; + } + + if (up_name) { + rspamd_http_message_remove_header(msg, "Host"); + rspamd_http_message_add_header(msg, "Host", up_name); + } + rspamd_http_message_remove_header(msg, "Connection"); + rspamd_http_message_add_header(msg, "Connection", "keep-alive"); + + if (msg->url->len == 0) { + msg->url = rspamd_fstring_append(msg->url, "/check", strlen("/check")); + } + + if (m->settings_id != NULL) { + rspamd_http_message_remove_header(msg, "Settings-ID"); + rspamd_http_message_add_header(msg, "Settings-ID", m->settings_id); + } + + /* Add extra headers if specified */ + if (m->extra_headers != NULL) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + const char *key, *value; + + while ((cur = ucl_object_iterate(m->extra_headers, &it, true)) != NULL) { + key = ucl_object_key(cur); + value = ucl_object_tostring(cur); + + if (key != NULL && value != NULL) { + rspamd_http_message_remove_header(msg, key); + rspamd_http_message_add_header(msg, key, value); + } + } + } + + /* Add log tag header based on mirror's configuration */ + rspamd_proxy_add_log_tag_header(msg, session, m->log_tag_type); + + /* Set handlers for the connection */ + conn->error_handler = proxy_backend_mirror_error_handler; + conn->finish_handler = proxy_backend_mirror_finish_handler; + conn->ud = bk_conn; + + if (m->key) { + msg->peer_key = rspamd_pubkey_ref(m->key); + } + + if 
(m->local || rspamd_inet_address_is_local(keepalive_addr)) { + if (session->fname) { + rspamd_http_message_add_header(msg, "File", session->fname); + } + + msg->method = HTTP_GET; + rspamd_http_connection_write_message_shared(conn, + msg, up_name, + NULL, bk_conn, + bk_conn->timeout); + } + else { + if (session->fname) { + msg->flags &= ~RSPAMD_HTTP_FLAG_SHMEM; + rspamd_http_message_set_body(msg, session->map, session->map_len); + } + + msg->method = HTTP_POST; + + if (m->compress) { + proxy_request_compress(msg); + + if (session->client_milter_conn) { + rspamd_http_message_add_header(msg, "Content-Type", + "application/octet-stream"); + } + } + else { + if (session->client_milter_conn) { + rspamd_http_message_add_header(msg, "Content-Type", + "text/plain"); + } + } + + rspamd_http_connection_write_message(conn, + msg, up_name, NULL, bk_conn, + bk_conn->timeout); + } + + g_ptr_array_add(session->mirror_conns, bk_conn); + REF_RETAIN(session); + msg_info_session("send request to %s (using keepalive)", m->name); + + /* + * We have found the existing keepalive connection, so we can + * process another mirror + */ + continue; + } + } + } + } + + /* Non-keepalive connection */ + bk_conn = rspamd_mempool_alloc0(session->pool, sizeof(*bk_conn)); bk_conn->s = session; @@ -1472,7 +1825,9 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) rspamd_http_message_remove_header(msg, "Host"); rspamd_http_message_add_header(msg, "Host", up_name); } - rspamd_http_message_add_header(msg, "Connection", "close"); + rspamd_http_message_remove_header(msg, "Connection"); + rspamd_http_message_add_header(msg, "Connection", + m->keepalive ? 
"keep-alive" : "close"); if (msg->url->len == 0) { msg->url = rspamd_fstring_append(msg->url, "/check", strlen("/check")); @@ -1483,12 +1838,38 @@ proxy_open_mirror_connections(struct rspamd_proxy_session *session) rspamd_http_message_add_header(msg, "Settings-ID", m->settings_id); } + /* Add extra headers if specified */ + if (m->extra_headers != NULL) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + const char *key, *value; + + while ((cur = ucl_object_iterate(m->extra_headers, &it, true)) != NULL) { + key = ucl_object_key(cur); + value = ucl_object_tostring(cur); + + if (key != NULL && value != NULL) { + rspamd_http_message_remove_header(msg, key); + rspamd_http_message_add_header(msg, key, value); + } + } + } + + /* Add log tag header based on mirror's configuration */ + rspamd_proxy_add_log_tag_header(msg, session, m->log_tag_type); + + unsigned int http_opts = RSPAMD_HTTP_CLIENT_SIMPLE; + + if (m->ssl) { + http_opts |= RSPAMD_HTTP_CLIENT_SSL; + } + bk_conn->backend_conn = rspamd_http_connection_new_client_socket( session->ctx->http_ctx, NULL, proxy_backend_mirror_error_handler, proxy_backend_mirror_finish_handler, - RSPAMD_HTTP_CLIENT_SIMPLE, + http_opts, bk_conn->backend_sock); if (m->key) { @@ -1600,8 +1981,9 @@ proxy_backend_master_error_handler(struct rspamd_http_connection *conn, GError * session->retries++; msg_info_session("abnormally closing connection from backend: %s, error: %e," " retries left: %d", - rspamd_inet_address_to_string_pretty( - rspamd_upstream_addr_cur(session->master_conn->up)), + session->master_conn->up ? rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->master_conn->up)) + : "self-scan", err, session->ctx->max_retries - session->retries); rspamd_upstream_fail(bk_conn->up, FALSE, err ? 
err->message : "unknown"); @@ -1632,8 +2014,9 @@ proxy_backend_master_error_handler(struct rspamd_http_connection *conn, GError * else { msg_info_session("retry connection to: %s" " retries left: %d", - rspamd_inet_address_to_string( - rspamd_upstream_addr_cur(session->master_conn->up)), + session->master_conn->up ? rspamd_inet_address_to_string( + rspamd_upstream_addr_cur(session->master_conn->up)) + : "self-scan", session->ctx->max_retries - session->retries); } } @@ -1647,7 +2030,9 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, struct rspamd_proxy_session *session, *nsession; rspamd_fstring_t *reply; const rspamd_ftok_t *orig_ct; + const rspamd_ftok_t *conn_hdr; goffset body_offset = -1; + gboolean is_keepalive = FALSE; session = bk_conn->s; rspamd_http_connection_steal_msg(session->master_conn->backend_conn); @@ -1663,6 +2048,16 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_http_message_remove_header(msg, "Server"); rspamd_http_message_remove_header(msg, "Key"); orig_ct = rspamd_http_message_find_header(msg, "Content-Type"); + + /* Check if we can use keepalive */ + conn_hdr = rspamd_http_message_find_header(msg, "Connection"); + if (conn_hdr) { + if (rspamd_substring_search_caseless(conn_hdr->begin, conn_hdr->len, + "keep-alive", 10) != -1) { + is_keepalive = TRUE; + } + } + rspamd_http_connection_reset(session->master_conn->backend_conn); if (!proxy_backend_parse_results(session, bk_conn, session->ctx->lua_state, @@ -1695,6 +2090,22 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_upstream_ok(bk_conn->up); + /* Handle keepalive for master connection */ + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + /* Store connection in keepalive pool */ + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_prepare_keepalive(session->ctx->http_ctx, + conn, 
rspamd_upstream_addr_cur(bk_conn->up), + up_name, FALSE); + + /* We'll push to keepalive pool after we're done with the response */ + msg_debug_session("will push master connection to %s to keepalive pool", + up_name); + } + } + if (session->client_milter_conn) { nsession = proxy_session_refresh(session); @@ -1708,6 +2119,20 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_milter_send_task_results(nsession->client_milter_conn, session->master_conn->results, NULL, 0); } + + /* Push to keepalive if needed */ + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_push_keepalive(session->ctx->http_ctx, + conn, msg, session->ctx->event_loop); + + /* Mark connection as closed without actually closing it */ + bk_conn->flags |= RSPAMD_BACKEND_CLOSED; + } + } + REF_RELEASE(session); rspamd_http_message_free(msg); } @@ -1723,6 +2148,19 @@ proxy_backend_master_finish_handler(struct rspamd_http_connection *conn, rspamd_http_connection_write_message(session->client_conn, msg, NULL, passed_ct, session, bk_conn->timeout); + + /* Push to keepalive if needed */ + if (is_keepalive && session->use_keepalive && + bk_conn->up && session->ctx->http_ctx) { + const char *up_name = rspamd_upstream_name(bk_conn->up); + if (up_name) { + rspamd_http_context_push_keepalive(session->ctx->http_ctx, + conn, msg, session->ctx->event_loop); + + /* Mark connection as closed without actually closing it */ + bk_conn->flags |= RSPAMD_BACKEND_CLOSED; + } + } } return 0; @@ -1982,6 +2420,9 @@ proxy_send_master_message(struct rspamd_proxy_session *session) /* Remove the original `Connection` header */ rspamd_http_message_remove_header(session->client_message, "Connection"); + /* Set keepalive flag based on backend configuration */ + session->use_keepalive = backend ? 
backend->keepalive : FALSE; + if (backend == NULL) { /* No backend */ msg_err_session("cannot find upstream for %s", host ? hostbuf : "default"); @@ -2063,14 +2504,21 @@ proxy_send_master_message(struct rspamd_proxy_session *session) if (up_name) { rspamd_http_message_add_header(msg, "Host", up_name); } - rspamd_http_message_add_header(msg, "Connection", "close"); + rspamd_http_message_add_header(msg, "Connection", + backend->keepalive ? "keep-alive" : "close"); + + unsigned int http_opts = RSPAMD_HTTP_CLIENT_SIMPLE; + + if (backend->ssl) { + http_opts |= RSPAMD_HTTP_CLIENT_SSL; + } session->master_conn->backend_conn = rspamd_http_connection_new_client_socket( session->ctx->http_ctx, NULL, proxy_backend_master_error_handler, proxy_backend_master_finish_handler, - RSPAMD_HTTP_CLIENT_SIMPLE, + http_opts, session->master_conn->backend_sock); session->master_conn->flags &= ~RSPAMD_BACKEND_CLOSED; session->master_conn->parser_from_ref = backend->parser_from_ref; @@ -2086,6 +2534,26 @@ proxy_send_master_message(struct rspamd_proxy_session *session) backend->settings_id); } + /* Add extra headers if specified */ + if (backend->extra_headers != NULL) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + const char *key, *value; + + while ((cur = ucl_object_iterate(backend->extra_headers, &it, true)) != NULL) { + key = ucl_object_key(cur); + value = ucl_object_tostring(cur); + + if (key != NULL && value != NULL) { + rspamd_http_message_remove_header(msg, key); + rspamd_http_message_add_header(msg, key, value); + } + } + } + + /* Add log tag header based on backend's configuration */ + rspamd_proxy_add_log_tag_header(msg, session, backend->log_tag_type); + if (backend->local || rspamd_inet_address_is_local( rspamd_upstream_addr_cur( @@ -2206,8 +2674,6 @@ proxy_client_finish_handler(struct rspamd_http_connection *conn, rspamd_http_message_remove_header(msg, "Keep-Alive"); rspamd_http_message_remove_header(msg, "Connection"); rspamd_http_message_remove_header(msg, 
"Key"); - rspamd_http_message_add_header_len(msg, LOG_TAG_HEADER, session->pool->tag.uid, - strnlen(session->pool->tag.uid, sizeof(session->pool->tag.uid))); proxy_open_mirror_connections(session); rspamd_http_connection_reset(session->client_conn); @@ -2216,8 +2682,9 @@ proxy_client_finish_handler(struct rspamd_http_connection *conn, } else { msg_info_session("finished master connection to %s; HTTP code: %d", - rspamd_inet_address_to_string_pretty( - rspamd_upstream_addr_cur(session->master_conn->up)), + session->master_conn->up ? rspamd_inet_address_to_string_pretty( + rspamd_upstream_addr_cur(session->master_conn->up)) + : "self-scan", msg->code); proxy_backend_close_connection(session->master_conn); REF_RELEASE(session); |