lua_kann_train1 (lua_State *L)
{
kann_t *k = lua_check_kann (L, 1);
+ struct rspamd_lua_tensor *pca = NULL;
/* Default train params */
double lr = 0.001;
if (!rspamd_lua_parse_table_arguments (L, 4, &err,
RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
- "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F",
- &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref)) {
+ "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
+ &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
n = luaL_error (L, "invalid params: %s",
err ? err->message : "unknown error");
g_error_free (err);
}
}
- float **x, **y;
+ if (pca) {
+ /* Check pca matrix validity */
+ if (pca->ndims != 2) {
+ return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
+ }
+
+ if (pca->dim[0] != n_in) {
+ return luaL_error (L, "invalid pca tensor: "
+ "matrix must have %d rows and it has %d rows instead",
+ n_in, pca->dim[0]);
+ }
+ }
- /* Fill vectors */
+ float **x, **y, *tmp_row = NULL;
+
+ /* Fill vectors row by row */
x = (float **)g_malloc0 (sizeof (float *) * n);
y = (float **)g_malloc0 (sizeof (float *) * n);
+ if (pca) {
+ tmp_row = g_malloc (sizeof (float) * pca->dim[1]);
+ }
+
for (int s = 0; s < n; s ++) {
/* Inputs */
lua_rawgeti (L, 2, s + 1);
x[s] = (float *)g_malloc (sizeof (float) * n_in);
- if (rspamd_lua_table_size (L, -1) != n_in) {
- FREE_VEC (x, n);
- FREE_VEC (y, n);
+ if (pca == NULL) {
+ if (rspamd_lua_table_size (L, -1) != n_in) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
- lua_pop (L, 1);
- n = luaL_error (L, "invalid params at pos %d: "
- "bad input dimension %d; %d expected",
- s + 1,
- (int)rspamd_lua_table_size (L, -1),
- n_in);
+ n = luaL_error (L, "invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int) rspamd_lua_table_size (L, -1),
+ n_in);
+ lua_pop (L, 1);
- return n;
+ return n;
+ }
+
+ for (int i = 0; i < n_in; i++) {
+ lua_rawgeti (L, -1, i + 1);
+ x[s][i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
}
+ else {
+ if (rspamd_lua_table_size (L, -1) != pca->dim[1]) {
+ FREE_VEC (x, n);
+ FREE_VEC (y, n);
+ g_free (tmp_row);
+
+ n = luaL_error (L, "(pca on) invalid params at pos %d: "
+ "bad input dimension %d; %d expected",
+ s + 1,
+ (int) rspamd_lua_table_size (L, -1),
+ pca->dim[1]);
+ lua_pop (L, 1);
- for (int i = 0; i < n_in; i ++) {
- lua_rawgeti (L, -1, i + 1);
- x[s][i] = lua_tonumber (L, -1);
+ return n;
+ }
- lua_pop (L, 1);
+
+ for (int i = 0; i < pca->dim[1]; i++) {
+ lua_rawgeti (L, -1, i + 1);
+ tmp_row[i] = lua_tonumber (L, -1);
+
+ lua_pop (L, 1);
+ }
+
+ kad_sgemm_simple (0, 0, pca->dim[0], 1,
+ pca->dim[1], pca->data,
+ tmp_row, x[s]);
}
lua_pop (L, 1);
lua_rawgeti (L, 3, s + 1);
if (rspamd_lua_table_size (L, -1) != n_out) {
- lua_pop (L, 1);
FREE_VEC (x, n);
FREE_VEC (y, n);
+ g_free (tmp_row);
n = luaL_error (L, "invalid params at pos %d: "
"bad output dimension %d; "
s + 1,
(int)rspamd_lua_table_size (L, -1),
n_out);
+ lua_pop (L, 1);
return n;
}
FREE_VEC (x, n);
FREE_VEC (y, n);
+ g_free (tmp_row);
}
else {
return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
-- This is an utility function for PCA training
local function fill_scatter(inputs)
- local scatter_matrix = rspamd_tensor.new(2, #inputs, #inputs)
- local row_len = #inputs[1]
+ local scatter_matrix = rspamd_tensor.new(2, #inputs[1], #inputs[1])
+ local nsamples = #inputs
- if type(inputs) == 'table' then
- -- Convert to a tensor
- inputs = rspamd_tensor.fromtable(inputs)
- end
+ -- Convert to a tensor where each row is an input dimension
+ inputs = rspamd_tensor.fromtable(inputs):transpose()
local meanv = inputs:mean()
- for i=1,row_len do
+ for i=1,nsamples do
local col = rspamd_tensor.new(1, #inputs)
for j=1,#inputs do
local x = inputs[j][i] - meanv[j]
w[i] = scatter_matrix[#scatter_matrix - i + 1]
end
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
return w
end
local profile_serialized = ucl.to_format(profile, 'json-compact', true)
rspamd_logger.infox(rspamd_config,
- 'trained ANN %s:%s, %s bytes; redis key: %s (old key %s)',
- rule.prefix, set.name, #data, set.ann.redis_key, ann_key)
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ rule.prefix, set.name,
+ #data, #ann_data,
+ #(set.ann.pca or {}), #(pca_data or {}),
+ set.ann.redis_key, ann_key)
lua_redis.exec_redis_script(redis_save_unlock_id,
{ev_base = ev_base, is_write = true},