You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

cdb_backend.cxx 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. /*-
  2. * Copyright 2021 Vsevolod Stakhov
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*
  17. * CDB read only statistics backend
  18. */
  19. #include "config.h"
  20. #include "stat_internal.h"
  21. #include "contrib/cdb/cdb.h"
  22. #include <utility>
  23. #include <memory>
  24. #include <string>
  25. #include <optional>
  26. #include "contrib/expected/expected.hpp"
  27. #include "contrib/robin-hood/robin_hood.h"
  28. #include "fmt/core.h"
  29. namespace rspamd::stat::cdb {
  30. /*
  31. * Utility class to share cdb instances over statfiles instances, as each
  32. * cdb has tokens for both ham and spam classes
  33. */
  34. class cdb_shared_storage {
  35. public:
  36. using cdb_element_t = std::shared_ptr<struct cdb>;
  37. cdb_shared_storage() noexcept = default;
  38. auto get_cdb(const char *path) const -> std::optional<cdb_element_t> {
  39. auto found = elts.find(path);
  40. if (found != elts.end()) {
  41. if (!found->second.expired()) {
  42. return found->second.lock();
  43. }
  44. }
  45. return std::nullopt;
  46. }
  47. /* Create a new smart pointer over POD cdb structure */
  48. static auto new_cdb() -> cdb_element_t {
  49. auto ret = cdb_element_t(new struct cdb, cdb_deleter());
  50. memset(ret.get(), 0, sizeof(struct cdb));
  51. return ret;
  52. }
  53. /* Enclose cdb into storage */
  54. auto push_cdb(const char *path, cdb_element_t cdbp) -> cdb_element_t {
  55. auto found = elts.find(path);
  56. if (found != elts.end()) {
  57. if (found->second.expired()) {
  58. /* OK, move in lieu of the expired weak pointer */
  59. found->second = cdbp;
  60. return cdbp;
  61. }
  62. else {
  63. /*
  64. * Existing and not expired, return the existing one
  65. */
  66. return found->second.lock();
  67. }
  68. }
  69. else {
  70. /* Not existing, make a weak ptr and return the original */
  71. elts.emplace(path,std::weak_ptr<struct cdb>(cdbp));
  72. return cdbp;
  73. }
  74. }
  75. private:
  76. /*
  77. * We store weak pointers here to allow owning cdb statfiles to free
  78. * expensive cdb before this cache is terminated (e.g. on dynamic cdb reload)
  79. */
  80. robin_hood::unordered_flat_map<std::string, std::weak_ptr<struct cdb>> elts;
  81. struct cdb_deleter {
  82. void operator()(struct cdb *c) const {
  83. cdb_free(c);
  84. }
  85. };
  86. };
  87. static cdb_shared_storage cdb_shared_storage;
  88. class ro_backend final {
  89. public:
  90. explicit ro_backend(struct rspamd_statfile *_st, cdb_shared_storage::cdb_element_t _db)
  91. : st(_st), db(_db) {}
  92. ro_backend() = delete;
  93. ro_backend(const ro_backend &) = delete;
  94. ro_backend(ro_backend &&other) noexcept {
  95. *this = std::move(other);
  96. }
  97. ro_backend& operator=(ro_backend &&other) noexcept
  98. {
  99. std::swap(st, other.st);
  100. std::swap(db, other.db);
  101. return *this;
  102. }
  103. ~ro_backend() {}
  104. auto load_cdb() -> tl::expected<bool, std::string>;
  105. auto process_token(const rspamd_token_t *tok) const -> std::optional<float>;
  106. constexpr auto is_spam() const -> bool {
  107. return st->stcf->is_spam;
  108. }
  109. auto get_learns() const -> std::uint64_t {
  110. if (is_spam()) {
  111. return learns_spam;
  112. }
  113. else {
  114. return learns_ham;
  115. }
  116. }
  117. auto get_total_learns() const -> std::uint64_t {
  118. return learns_spam + learns_ham;
  119. }
  120. private:
  121. struct rspamd_statfile *st;
  122. cdb_shared_storage::cdb_element_t db;
  123. bool loaded = false;
  124. std::uint64_t learns_spam = 0;
  125. std::uint64_t learns_ham = 0;
  126. };
  127. template<typename T>
  128. static inline auto
  129. cdb_get_key_as_double(struct cdb *cdb, T key) -> std::optional<double>
  130. {
  131. auto pos = cdb_find(cdb, (void *)&key, sizeof(key));
  132. if (pos > 0) {
  133. auto vpos = cdb_datapos(cdb);
  134. auto vlen = cdb_datalen(cdb);
  135. if (vlen == sizeof(double)) {
  136. double ret;
  137. cdb_read(cdb, (void *)&ret, vlen, vpos);
  138. return ret;
  139. }
  140. }
  141. return std::nullopt;
  142. }
  143. template<typename T>
  144. static inline auto
  145. cdb_get_key_as_float_pair(struct cdb *cdb, T key) -> std::optional<std::pair<float, float>>
  146. {
  147. auto pos = cdb_find(cdb, (void *)&key, sizeof(key));
  148. if (pos > 0) {
  149. auto vpos = cdb_datapos(cdb);
  150. auto vlen = cdb_datalen(cdb);
  151. if (vlen == sizeof(float) * 2) {
  152. union {
  153. struct {
  154. float v1;
  155. float v2;
  156. } d;
  157. char c[sizeof(float) * 2];
  158. } u;
  159. cdb_read(cdb, (void *)u.c, vlen, vpos);
  160. return std::make_pair(u.d.v1, u.d.v2);
  161. }
  162. }
  163. return std::nullopt;
  164. }
  165. auto
  166. ro_backend::load_cdb() -> tl::expected<bool, std::string>
  167. {
  168. if (!db) {
  169. return tl::make_unexpected("no database loaded");
  170. }
  171. /* Now get number of learns */
  172. std::int64_t cdb_key;
  173. static const char learn_spam_key[9] = "_lrnspam", learn_ham_key[9] = "_lrnham_";
  174. auto check_key = [&](const char *key, std::uint64_t &target) -> tl::expected<bool, std::string> {
  175. memcpy((void *)&cdb_key, key, sizeof(cdb_key));
  176. auto maybe_value = cdb_get_key_as_double(db.get(), cdb_key);
  177. if (!maybe_value) {
  178. return tl::make_unexpected(fmt::format("missing {} key", key));
  179. }
  180. // Convert from double to int
  181. target = (std::uint64_t)maybe_value.value();
  182. return true;
  183. };
  184. auto res = check_key(learn_spam_key, learns_spam);
  185. if (!res) {
  186. return res;
  187. }
  188. res = check_key(learn_ham_key, learns_ham);
  189. if (!res) {
  190. return res;
  191. }
  192. loaded = true;
  193. return true; // expected
  194. }
  195. auto
  196. ro_backend::process_token(const rspamd_token_t *tok) const -> std::optional<float>
  197. {
  198. if (!loaded) {
  199. return std::nullopt;
  200. }
  201. auto maybe_value = cdb_get_key_as_float_pair(db.get(), tok->data);
  202. if (maybe_value) {
  203. auto [spam_count, ham_count] = maybe_value.value();
  204. if (is_spam()) {
  205. return spam_count;
  206. }
  207. else {
  208. return ham_count;
  209. }
  210. }
  211. return std::nullopt;
  212. }
  213. auto
  214. open_cdb(struct rspamd_statfile *st) -> tl::expected<ro_backend, std::string>
  215. {
  216. const char *path = nullptr;
  217. const auto *stf = st->stcf;
  218. auto get_filename = [](const ucl_object_t *obj) -> const char * {
  219. const auto *filename = ucl_object_lookup_any(obj,
  220. "filename", "path", "cdb", nullptr);
  221. if (filename && ucl_object_type(filename) == UCL_STRING) {
  222. return ucl_object_tostring(filename);
  223. }
  224. return nullptr;
  225. };
  226. /* First search in backend configuration */
  227. const auto *obj = ucl_object_lookup (st->classifier->cfg->opts, "backend");
  228. if (obj != NULL && ucl_object_type (obj) == UCL_OBJECT) {
  229. path = get_filename(obj);
  230. }
  231. /* Now try statfiles config */
  232. if (!path && stf->opts) {
  233. path = get_filename(stf->opts);
  234. }
  235. /* Now try classifier config */
  236. if (!path && st->classifier->cfg->opts) {
  237. path = get_filename(st->classifier->cfg->opts);
  238. }
  239. if (!path) {
  240. return tl::make_unexpected("missing/malformed filename attribute");
  241. }
  242. auto cached_cdb_maybe = cdb_shared_storage.get_cdb(path);
  243. cdb_shared_storage::cdb_element_t cdbp;
  244. if (!cached_cdb_maybe) {
  245. auto fd = rspamd_file_xopen(path, O_RDONLY, 0, true);
  246. if (fd == -1) {
  247. return tl::make_unexpected(fmt::format("cannot open {}: {}",
  248. path, strerror(errno)));
  249. }
  250. cdbp = cdb_shared_storage::new_cdb();
  251. if (cdb_init(cdbp.get(), fd) == -1) {
  252. close(fd);
  253. return tl::make_unexpected(fmt::format("cannot init cdb in {}: {}",
  254. path, strerror(errno)));
  255. }
  256. cdbp = cdb_shared_storage.push_cdb(path, cdbp);
  257. close(fd);
  258. }
  259. else {
  260. cdbp = cached_cdb_maybe.value();
  261. }
  262. if (!cdbp) {
  263. return tl::make_unexpected(fmt::format("cannot init cdb in {}: internal error",
  264. path));
  265. }
  266. ro_backend bk{st, cdbp};
  267. auto res = bk.load_cdb();
  268. if (!res) {
  269. return tl::make_unexpected(res.error());
  270. }
  271. return bk;
  272. }
  273. }
  /* Turn the opaque backend pointer passed through the C API back into a ro_backend */
  #define CDB_FROM_RAW(p) (static_cast<rspamd::stat::cdb::ro_backend *>(p))
  275. /* C exports */
  276. gpointer
  277. rspamd_cdb_init(struct rspamd_stat_ctx* ctx,
  278. struct rspamd_config* cfg,
  279. struct rspamd_statfile* st)
  280. {
  281. auto maybe_backend = rspamd::stat::cdb::open_cdb(st);
  282. if (maybe_backend) {
  283. /* Move into a new pointer */
  284. auto *result = new rspamd::stat::cdb::ro_backend(std::move(maybe_backend.value()));
  285. return result;
  286. }
  287. else {
  288. msg_err_config("cannot load cdb backend: %s", maybe_backend.error().c_str());
  289. }
  290. return nullptr;
  291. }
  292. gpointer
  293. rspamd_cdb_runtime(struct rspamd_task* task,
  294. struct rspamd_statfile_config* stcf,
  295. gboolean learn,
  296. gpointer ctx)
  297. {
  298. /* In CDB we don't have any dynamic stuff */
  299. return ctx;
  300. }
  301. gboolean
  302. rspamd_cdb_process_tokens(struct rspamd_task* task,
  303. GPtrArray* tokens,
  304. gint id,
  305. gpointer ctx)
  306. {
  307. auto *cdbp = CDB_FROM_RAW(ctx);
  308. bool seen_values = false;
  309. for (auto i = 0u; i < tokens->len; i++) {
  310. rspamd_token_t *tok;
  311. tok = reinterpret_cast<rspamd_token_t *>(g_ptr_array_index(tokens, i));
  312. auto res = cdbp->process_token(tok);
  313. if (res) {
  314. tok->values[id] = res.value();
  315. seen_values = true;
  316. }
  317. else {
  318. tok->values[id] = 0;
  319. }
  320. }
  321. if (seen_values) {
  322. if (cdbp->is_spam()) {
  323. task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS;
  324. }
  325. else {
  326. task->flags |= RSPAMD_TASK_FLAG_HAS_HAM_TOKENS;
  327. }
  328. }
  329. return true;
  330. }
  331. gboolean
  332. rspamd_cdb_finalize_process(struct rspamd_task* task,
  333. gpointer runtime,
  334. gpointer ctx)
  335. {
  336. return true;
  337. }
  338. gboolean
  339. rspamd_cdb_learn_tokens(struct rspamd_task* task,
  340. GPtrArray* tokens,
  341. gint id,
  342. gpointer ctx)
  343. {
  344. return false;
  345. }
  346. gboolean
  347. rspamd_cdb_finalize_learn(struct rspamd_task* task,
  348. gpointer runtime,
  349. gpointer ctx,
  350. GError** err)
  351. {
  352. return false;
  353. }
  354. gulong rspamd_cdb_total_learns(struct rspamd_task* task,
  355. gpointer runtime,
  356. gpointer ctx)
  357. {
  358. auto *cdbp = CDB_FROM_RAW(ctx);
  359. return cdbp->get_total_learns();
  360. }
  361. gulong
  362. rspamd_cdb_inc_learns(struct rspamd_task* task,
  363. gpointer runtime,
  364. gpointer ctx)
  365. {
  366. return (gulong)-1;
  367. }
  368. gulong
  369. rspamd_cdb_dec_learns(struct rspamd_task* task,
  370. gpointer runtime,
  371. gpointer ctx)
  372. {
  373. return (gulong)-1;
  374. }
  375. gulong
  376. rspamd_cdb_learns(struct rspamd_task* task,
  377. gpointer runtime,
  378. gpointer ctx)
  379. {
  380. auto *cdbp = CDB_FROM_RAW(ctx);
  381. return cdbp->get_learns();
  382. }
  383. ucl_object_t*
  384. rspamd_cdb_get_stat(gpointer runtime, gpointer ctx)
  385. {
  386. return nullptr;
  387. }
  388. gpointer
  389. rspamd_cdb_load_tokenizer_config(gpointer runtime, gsize* len)
  390. {
  391. return nullptr;
  392. }
  393. void
  394. rspamd_cdb_close(gpointer ctx)
  395. {
  396. auto *cdbp = CDB_FROM_RAW(ctx);
  397. delete cdbp;
  398. }