12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361 |
- /*
- * Copyright 2024 Vsevolod Stakhov
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "lua_common.h"
- #include "lua_tensor.h"
- #include "contrib/kann/kann.h"
-
- /***
- * @module rspamd_kann
- * `rspamd_kann` is a Lua interface to kann library
- */
-
/* Lua class names of the userdata wrapping kad_node_t and kann_t pointers */
#define KANN_NODE_CLASS rspamd_kann_node_classname
#define KANN_NETWORK_CLASS rspamd_kann_classname

/*
 * Simple macros to define behaviour:
 * <KIND>_DEF(name) expands to the signature of the static Lua binding
 * lua_kann_<kind>_<name>, and <KIND>_INTERFACE(name) expands to the matching
 * {"name", lua_kann_<kind>_<name>} entry of a luaL_reg registration table.
 */

/* Layer constructors exposed as kann.layer.<name> */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L)
#define KANN_LAYER_INTERFACE(name)   \
	{                                \
		#name, lua_kann_layer_##name \
	}

/* Node transforms exposed as kann.transform.<name> */
#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L)
#define KANN_TRANSFORM_INTERFACE(name)   \
	{                                    \
		#name, lua_kann_transform_##name \
	}

/* Loss functions exposed as kann.loss.<name> */
#define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L)
#define KANN_LOSS_INTERFACE(name)   \
	{                               \
		#name, lua_kann_loss_##name \
	}

/* Node/network constructors exposed as kann.new.<name> */
#define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L)
#define KANN_NEW_INTERFACE(name)   \
	{                              \
		#name, lua_kann_new_##name \
	}
-
-
- /*
- * Forwarded declarations
- */
- static kad_node_t *lua_check_kann_node(lua_State *L, int pos);
-
- /* Layers */
- KANN_LAYER_DEF(input);
- KANN_LAYER_DEF(dense);
- KANN_LAYER_DEF(layernorm);
- KANN_LAYER_DEF(rnn);
- KANN_LAYER_DEF(lstm);
- KANN_LAYER_DEF(gru);
- KANN_LAYER_DEF(conv2d);
- KANN_LAYER_DEF(conv1d);
- KANN_LAYER_DEF(cost);
-
- static luaL_reg rspamd_kann_layers_f[] = {
- KANN_LAYER_INTERFACE(input),
- KANN_LAYER_INTERFACE(dense),
- KANN_LAYER_INTERFACE(layernorm),
- KANN_LAYER_INTERFACE(rnn),
- KANN_LAYER_INTERFACE(lstm),
- KANN_LAYER_INTERFACE(gru),
- KANN_LAYER_INTERFACE(conv2d),
- KANN_LAYER_INTERFACE(conv1d),
- KANN_LAYER_INTERFACE(cost),
- {NULL, NULL},
- };
-
- /* Transition and composition functions */
-
- /* General transform */
- KANN_TRANSFORM_DEF(add);
- KANN_TRANSFORM_DEF(sub);
- KANN_TRANSFORM_DEF(mul);
- KANN_TRANSFORM_DEF(cmul);
- KANN_TRANSFORM_DEF(matmul);
-
- KANN_TRANSFORM_DEF(square);
- KANN_TRANSFORM_DEF(sigm);
- KANN_TRANSFORM_DEF(tanh);
- KANN_TRANSFORM_DEF(relu);
- KANN_TRANSFORM_DEF(softmax);
- KANN_TRANSFORM_DEF(1minus);
- KANN_TRANSFORM_DEF(exp);
- KANN_TRANSFORM_DEF(log);
- KANN_TRANSFORM_DEF(sin);
- static luaL_reg rspamd_kann_transform_f[] = {
- KANN_TRANSFORM_INTERFACE(add),
- KANN_TRANSFORM_INTERFACE(sub),
- KANN_TRANSFORM_INTERFACE(mul),
- KANN_TRANSFORM_INTERFACE(cmul),
- KANN_TRANSFORM_INTERFACE(matmul),
-
- KANN_TRANSFORM_INTERFACE(square),
- KANN_TRANSFORM_INTERFACE(sigm),
- KANN_TRANSFORM_INTERFACE(tanh),
- KANN_TRANSFORM_INTERFACE(relu),
- KANN_TRANSFORM_INTERFACE(softmax),
- KANN_TRANSFORM_INTERFACE(1minus),
- KANN_TRANSFORM_INTERFACE(exp),
- KANN_TRANSFORM_INTERFACE(log),
- KANN_TRANSFORM_INTERFACE(sin),
- {NULL, NULL},
- };
-
- /* Loss functions */
- KANN_LOSS_DEF(mse);
- KANN_LOSS_DEF(ce_multi);
- KANN_LOSS_DEF(ce_bin);
- KANN_LOSS_DEF(ce_bin_neg);
- KANN_LOSS_DEF(ce_multi_weighted);
- static luaL_reg rspamd_kann_loss_f[] = {
- KANN_LOSS_INTERFACE(mse),
- KANN_LOSS_INTERFACE(ce_multi),
- KANN_LOSS_INTERFACE(ce_bin),
- KANN_LOSS_INTERFACE(ce_bin_neg),
- KANN_LOSS_INTERFACE(ce_multi_weighted),
- {NULL, NULL},
- };
-
- /* Creation functions */
- KANN_NEW_DEF(leaf);
- KANN_NEW_DEF(scalar);
- KANN_NEW_DEF(weight);
- KANN_NEW_DEF(bias);
- KANN_NEW_DEF(weight_conv2d);
- KANN_NEW_DEF(weight_conv1d);
- KANN_NEW_DEF(kann);
-
- static luaL_reg rspamd_kann_new_f[] = {
- KANN_NEW_INTERFACE(leaf),
- KANN_NEW_INTERFACE(scalar),
- KANN_NEW_INTERFACE(weight),
- KANN_NEW_INTERFACE(bias),
- KANN_NEW_INTERFACE(weight_conv2d),
- KANN_NEW_INTERFACE(weight_conv1d),
- KANN_NEW_INTERFACE(kann),
- {NULL, NULL},
- };
-
- LUA_FUNCTION_DEF(kann, load);
- LUA_FUNCTION_DEF(kann, destroy);
- LUA_FUNCTION_DEF(kann, save);
- LUA_FUNCTION_DEF(kann, train1);
- LUA_FUNCTION_DEF(kann, apply1);
-
- static luaL_reg rspamd_kann_m[] = {
- LUA_INTERFACE_DEF(kann, save),
- LUA_INTERFACE_DEF(kann, train1),
- LUA_INTERFACE_DEF(kann, apply1),
- {"__gc", lua_kann_destroy},
- {NULL, NULL},
- };
-
- static int
- rspamd_kann_table_to_flags(lua_State *L, int table_pos)
- {
- int result = 0;
-
- lua_pushvalue(L, table_pos);
-
- for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) {
- int fl = lua_tointeger(L, -1);
-
- result |= fl;
- }
-
- lua_pop(L, 1);
-
- return result;
- }
-
- static int
- lua_load_kann(lua_State *L)
- {
- lua_newtable(L);
-
- /* Flags */
- lua_pushstring(L, "flag");
- lua_newtable(L);
- lua_pushinteger(L, KANN_F_IN);
- lua_setfield(L, -2, "in");
- lua_pushinteger(L, KANN_F_COST);
- lua_setfield(L, -2, "cost");
- lua_pushinteger(L, KANN_F_OUT);
- lua_setfield(L, -2, "out");
- lua_pushinteger(L, KANN_F_TRUTH);
- lua_setfield(L, -2, "truth");
- lua_settable(L, -3);
-
- /* Cost type */
- lua_pushstring(L, "cost");
- lua_newtable(L);
- /* binary cross-entropy cost, used with sigmoid */
- lua_pushinteger(L, KANN_C_CEB);
- lua_setfield(L, -2, "ceb");
- /* multi-class cross-entropy cost, used with softmax */
- lua_pushinteger(L, KANN_C_CEM);
- lua_setfield(L, -2, "cem");
- /* binary cross-entropy-like cost, used with tanh */
- lua_pushinteger(L, KANN_C_CEB_NEG);
- lua_setfield(L, -2, "ceb_neg");
- lua_pushinteger(L, KANN_C_MSE);
- lua_setfield(L, -2, "mse");
- lua_settable(L, -3);
-
- /* RNN flag */
- lua_pushstring(L, "rnn");
- lua_newtable(L);
- /* apply layer normalization */
- lua_pushinteger(L, KANN_RNN_NORM);
- lua_setfield(L, -2, "norm");
- /* take the initial hidden values as variables */
- lua_pushinteger(L, KANN_RNN_VAR_H0);
- lua_setfield(L, -2, "var_h0");
- lua_settable(L, -3);
-
- /* Layers */
- lua_pushstring(L, "layer");
- lua_newtable(L);
- luaL_register(L, NULL, rspamd_kann_layers_f);
- lua_settable(L, -3);
-
- /* Transforms */
- lua_pushstring(L, "transform");
- lua_newtable(L);
- luaL_register(L, NULL, rspamd_kann_transform_f);
- lua_settable(L, -3);
-
- /* Cost */
- lua_pushstring(L, "loss");
- lua_newtable(L);
- luaL_register(L, NULL, rspamd_kann_loss_f);
- lua_settable(L, -3);
-
- /* Create functions */
- lua_pushstring(L, "new");
- lua_newtable(L);
- luaL_register(L, NULL, rspamd_kann_new_f);
- lua_settable(L, -3);
-
- /* Load ann from memory or file */
- lua_pushstring(L, "load");
- lua_pushcfunction(L, lua_kann_load);
- lua_settable(L, -3);
-
- return 1;
- }
-
- static kad_node_t *
- lua_check_kann_node(lua_State *L, int pos)
- {
- void *ud = rspamd_lua_check_udata(L, pos, KANN_NODE_CLASS);
- luaL_argcheck(L, ud != NULL, pos, "'kann_node' expected");
- return ud ? *((kad_node_t **) ud) : NULL;
- }
-
- static kann_t *
- lua_check_kann(lua_State *L, int pos)
- {
- void *ud = rspamd_lua_check_udata(L, pos, KANN_NETWORK_CLASS);
- luaL_argcheck(L, ud != NULL, pos, "'kann' expected");
- return ud ? *((kann_t **) ud) : NULL;
- }
-
- void luaopen_kann(lua_State *L)
- {
- /* Metatables */
- rspamd_lua_new_class(L, KANN_NODE_CLASS, NULL); /* TODO: add methods */
- lua_pop(L, 1); /* No need in metatable... */
- rspamd_lua_new_class(L, KANN_NETWORK_CLASS, rspamd_kann_m);
- lua_pop(L, 1); /* No need in metatable... */
- rspamd_lua_add_preload(L, "rspamd_kann", lua_load_kann);
- lua_settop(L, 0);
- }
-
/* Layers implementation */

/*
 * Pushes a new userdata of class KANN_NODE_CLASS holding the kad_node_t
 * pointer `n`. Expects a `lua_State *L` in scope.
 */
#define PUSH_KAD_NODE(n)                                \
	do {                                                \
		kad_node_t **pt;                                \
		pt = lua_newuserdata(L, sizeof(kad_node_t *));  \
		*pt = (n);                                      \
		rspamd_lua_setclass(L, KANN_NODE_CLASS, -1);    \
	} while (0)

/*
 * Pushes a new userdata of class KANN_NETWORK_CLASS holding the kann_t
 * pointer `n`. That class uses rspamd_kann_m, whose __gc (lua_kann_destroy)
 * calls kann_delete on it. Expects a `lua_State *L` in scope.
 */
#define PUSH_KAN_NETWORK(n)                             \
	do {                                                \
		kann_t **pn;                                    \
		pn = lua_newuserdata(L, sizeof(kann_t *));      \
		*pn = (n);                                      \
		rspamd_lua_setclass(L, KANN_NETWORK_CLASS, -1); \
	} while (0)

/*
 * ORs optional flags found at stack index `pos` into (n)->ext_flag.
 * The flags may be a table of integers (combined via
 * rspamd_kann_table_to_flags) or a single integer; other types add nothing.
 * `pos` is evaluated several times.
 */
#define PROCESS_KAD_FLAGS(n, pos)                                                                 \
	do {                                                                                          \
		int fl = 0;                                                                               \
		if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags(L, (pos)); } \
		else if (lua_type(L, (pos)) == LUA_TNUMBER) {                                             \
			fl = lua_tointeger(L, (pos));                                                         \
		}                                                                                         \
		(n)->ext_flag |= fl;                                                                      \
	} while (0)
-
- /***
- * @function kann.layer.input(ninputs[, flags])
- * Creates an input layer for ANN
- * @param {int} ninputs number of inputs
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_input(lua_State *L)
- {
- int nnodes = luaL_checkinteger(L, 1);
-
- if (nnodes > 0) {
- kad_node_t *t;
-
- t = kann_layer_input(nnodes);
-
- PROCESS_KAD_FLAGS(t, 2);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, nnodes required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.dense(in, ninputs[, flags])
- * Creates a dense layer (e.g. for hidden layer)
- * @param {kann_node} in kann node
- * @param {int} ninputs number of dense nodes
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_dense(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int nnodes = luaL_checkinteger(L, 2);
-
- if (in != NULL && nnodes > 0) {
- kad_node_t *t;
-
- t = kann_layer_dense(in, nnodes);
-
- PROCESS_KAD_FLAGS(t, 3);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input + nnodes required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.dropout(in, ratio[, flags])
- * Creates a dropout layer
- * @param {kann_node} in kann node
- * @param {float} ratio drop ratio
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_layerdropout(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- double r = luaL_checknumber(L, 2);
-
- if (in != NULL) {
- kad_node_t *t;
-
- t = kann_layer_dropout(in, r);
-
- PROCESS_KAD_FLAGS(t, 3);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input + rate required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.dropout(in [, flags])
- * Creates a normalisation layer
- * @param {kann_node} in kann node
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_layernorm(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
-
- if (in != NULL) {
- kad_node_t *t;
-
- t = kann_layer_layernorm(in);
-
- PROCESS_KAD_FLAGS(t, 2);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
- * Creates a recursive NN layer
- * @param {kann_node} in kann node
- * @param {int} nnodes number of cells
- * @param {int} rnnflags rnn flags
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_rnn(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int nnodes = luaL_checkinteger(L, 2);
- int rnnflags = 0;
-
- if (in != NULL && nnodes > 0) {
- kad_node_t *t;
-
- if (lua_type(L, 3) == LUA_TNUMBER) {
- rnnflags = lua_tointeger(L, 3);
- }
-
- t = kann_layer_rnn(in, nnodes, rnnflags);
-
- PROCESS_KAD_FLAGS(t, 4);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input + nnodes required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.lstm(in, nnodes[, rnn_flags, [, flags]])
- * Creates a recursive NN layer using LSTM cells
- * @param {kann_node} in kann node
- * @param {int} nnodes number of cells
- * @param {int} rnnflags rnn flags
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_lstm(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int nnodes = luaL_checkinteger(L, 2);
- int rnnflags = 0;
-
- if (in != NULL && nnodes > 0) {
- kad_node_t *t;
-
- if (lua_type(L, 3) == LUA_TNUMBER) {
- rnnflags = lua_tointeger(L, 3);
- }
-
- t = kann_layer_lstm(in, nnodes, rnnflags);
-
- PROCESS_KAD_FLAGS(t, 4);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input + nnodes required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
- * Creates a recursive NN layer using GRU cells
- * @param {kann_node} in kann node
- * @param {int} nnodes number of cells
- * @param {int} rnnflags rnn flags
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_gru(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int nnodes = luaL_checkinteger(L, 2);
- int rnnflags = 0;
-
- if (in != NULL && nnodes > 0) {
- kad_node_t *t;
-
- if (lua_type(L, 3) == LUA_TNUMBER) {
- rnnflags = lua_tointeger(L, 3);
- }
-
- t = kann_layer_gru(in, nnodes, rnnflags);
-
- PROCESS_KAD_FLAGS(t, 4);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input + nnodes required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags])
- * Creates a 2D convolution layer
- * @param {kann_node} in kann node
- * @param {int} n_flt number of filters
- * @param {int} k_rows kernel rows
- * @param {int} k_cols kernel columns
- * @param {int} stride_rows stride rows
- * @param {int} stride_cols stride columns
- * @param {int} pad_rows padding rows
- * @param {int} pad_columns padding columns
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_conv2d(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int n_flt = luaL_checkinteger(L, 2);
- int k_rows = luaL_checkinteger(L, 3);
- int k_cols = luaL_checkinteger(L, 4);
- int stride_r = luaL_checkinteger(L, 5);
- int stride_c = luaL_checkinteger(L, 6);
- int pad_r = luaL_checkinteger(L, 7);
- int pad_c = luaL_checkinteger(L, 8);
-
- if (in != NULL) {
- kad_node_t *t;
- t = kann_layer_conv2d(in, n_flt, k_rows, k_cols, stride_r, stride_c,
- pad_r, pad_c);
-
- PROCESS_KAD_FLAGS(t, 9);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags])
- * Creates 1D convolution layer
- * @param {kann_node} in kann node
- * @param {int} n_flt number of filters
- * @param {int} kern_size kernel rows
- * @param {int} stride_size stride rows
- * @param {int} pad_size padding rows
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_conv1d(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int n_flt = luaL_checkinteger(L, 2);
- int k_size = luaL_checkinteger(L, 3);
- int stride = luaL_checkinteger(L, 4);
- int pad = luaL_checkinteger(L, 5);
-
- if (in != NULL) {
- kad_node_t *t;
- t = kann_layer_conv1d(in, n_flt, k_size, stride, pad);
-
- PROCESS_KAD_FLAGS(t, 6);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input, nflt, k, stride, pad required");
- }
-
- return 1;
- }
-
- /***
- * @function kann.layer.cost(in, nout, cost_type[, flags])
- * Creates 1D convolution layer
- * @param {kann_node} in kann node
- * @param {int} nout number of outputs
- * @param {int} cost_type see kann.cost table
- * @param {table|int} flags optional flags
- * @return {kann_node} kann node object (should be used to combine ANN)
- */
- static int
- lua_kann_layer_cost(lua_State *L)
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
- int nout = luaL_checkinteger(L, 2);
- int cost_type = luaL_checkinteger(L, 3);
-
- if (in != NULL && nout > 0) {
- kad_node_t *t;
- t = kann_layer_cost(in, nout, cost_type);
-
- PROCESS_KAD_FLAGS(t, 4);
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments, input, nout and cost_type are required");
- }
-
- return 1;
- }
-
- /* Generic helpers */
- static int
- lua_kann_call_unary_function(lua_State *L, const char *name,
- kad_node_t *(*func)(kad_node_t *) )
- {
- kad_node_t *in = lua_check_kann_node(L, 1);
-
- if (in != NULL) {
- kad_node_t *t;
- t = func(in);
-
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments for %s, input required", name);
- }
-
- return 1;
- }
- static int
- lua_kann_call_binary_function(lua_State *L, const char *name,
- kad_node_t *(*func)(kad_node_t *, kad_node_t *) )
- {
- kad_node_t *x = lua_check_kann_node(L, 1);
- kad_node_t *y = lua_check_kann_node(L, 2);
-
- if (x != NULL && y != NULL) {
- kad_node_t *t;
- t = func(x, y);
-
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments for %s, 2 inputs required", name);
- }
-
- return 1;
- }
-
/*
 * Generators for the bindings declared with KANN_TRANSFORM_DEF/KANN_LOSS_DEF:
 * each defines lua_kann_<kind>_<name>, which forwards to kad_<name> through
 * lua_kann_call_unary_function or lua_kann_call_binary_function. Those
 * helpers validate the node arguments.
 */
#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name)                        \
	static int lua_kann_transform_##name(lua_State *L)             \
	{                                                              \
		return lua_kann_call_unary_function(L, #name, kad_##name); \
	}

#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name)                        \
	static int lua_kann_transform_##name(lua_State *L)              \
	{                                                               \
		return lua_kann_call_binary_function(L, #name, kad_##name); \
	}

/* Losses use the binary helper: both arguments must be kann nodes */
#define LUA_LOSS_FUNC_IMPL(name)                                    \
	static int lua_kann_loss_##name(lua_State *L)                   \
	{                                                               \
		return lua_kann_call_binary_function(L, #name, kad_##name); \
	}

/* Transform functions registered via macro helpers */
LUA_BINARY_TRANSFORM_FUNC_IMPL(add)
LUA_BINARY_TRANSFORM_FUNC_IMPL(sub)
LUA_BINARY_TRANSFORM_FUNC_IMPL(mul)
LUA_BINARY_TRANSFORM_FUNC_IMPL(cmul)
LUA_BINARY_TRANSFORM_FUNC_IMPL(matmul)

LUA_UNARY_TRANSFORM_FUNC_IMPL(square)
LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm)
LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh)
LUA_UNARY_TRANSFORM_FUNC_IMPL(relu)
LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax)
LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus)
LUA_UNARY_TRANSFORM_FUNC_IMPL(exp)
LUA_UNARY_TRANSFORM_FUNC_IMPL(log)
LUA_UNARY_TRANSFORM_FUNC_IMPL(sin)

/* Generic two-argument cost functions (ce_multi_weighted takes three and is written by hand) */
LUA_LOSS_FUNC_IMPL(mse)
LUA_LOSS_FUNC_IMPL(ce_multi)
LUA_LOSS_FUNC_IMPL(ce_bin)
LUA_LOSS_FUNC_IMPL(ce_bin_neg)
-
- /* The only case of ternary weight function */
- static int
- lua_kann_loss_ce_multi_weighted(lua_State *L)
- {
- kad_node_t *pred = lua_check_kann_node(L, 1);
- kad_node_t *truth = lua_check_kann_node(L, 2);
- kad_node_t *weight = lua_check_kann_node(L, 3);
-
- if (pred != NULL && truth != NULL && weight != NULL) {
- kad_node_t *t;
- t = kad_ce_multi_weighted(pred, truth, weight);
-
- PUSH_KAD_NODE(t);
- }
- else {
- return luaL_error(L, "invalid arguments for ce_multi_weighted, 3 inputs required");
- }
-
- return 1;
- }
-
- /* Creation functions */
- static int
- lua_kann_new_scalar(lua_State *L)
- {
- int flag = luaL_checkinteger(L, 1);
- double x = luaL_checknumber(L, 2);
- kad_node_t *t;
-
- t = kann_new_scalar(flag, x);
-
- PROCESS_KAD_FLAGS(t, 3);
- PUSH_KAD_NODE(t);
-
- return 1;
- }
-
- static int
- lua_kann_new_weight(lua_State *L)
- {
- int nrow = luaL_checkinteger(L, 1);
- int ncol = luaL_checkinteger(L, 2);
- kad_node_t *t;
-
- t = kann_new_weight(nrow, ncol);
-
- PROCESS_KAD_FLAGS(t, 3);
- PUSH_KAD_NODE(t);
-
- return 1;
- }
-
- static int
- lua_kann_new_bias(lua_State *L)
- {
- int n = luaL_checkinteger(L, 1);
- kad_node_t *t;
-
- t = kann_new_bias(n);
-
- PROCESS_KAD_FLAGS(t, 2);
- PUSH_KAD_NODE(t);
-
- return 1;
- }
-
- static int
- lua_kann_new_weight_conv2d(lua_State *L)
- {
- int nout = luaL_checkinteger(L, 1);
- int nin = luaL_checkinteger(L, 2);
- int krow = luaL_checkinteger(L, 3);
- int kcol = luaL_checkinteger(L, 4);
- kad_node_t *t;
-
- t = kann_new_weight_conv2d(nout, nin, krow, kcol);
-
- PROCESS_KAD_FLAGS(t, 5);
- PUSH_KAD_NODE(t);
-
- return 1;
- }
-
- static int
- lua_kann_new_weight_conv1d(lua_State *L)
- {
- int nout = luaL_checkinteger(L, 1);
- int nin = luaL_checkinteger(L, 2);
- int klen = luaL_checkinteger(L, 3);
- kad_node_t *t;
-
- t = kann_new_weight_conv1d(nout, nin, klen);
-
- PROCESS_KAD_FLAGS(t, 4);
- PUSH_KAD_NODE(t);
-
- return 1;
- }
-
- static int
- lua_kann_new_leaf(lua_State *L)
- {
- int dim = luaL_checkinteger(L, 1), i, *ar;
- kad_node_t *t;
-
- if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable(L, 2)) {
- ar = g_new0(int, KAD_MAX_DIM);
-
- for (i = 0; i < dim; i++) {
- lua_rawgeti(L, 2, i + 1);
- ar[i] = lua_tointeger(L, -1);
- lua_pop(L, 1);
- }
-
- t = kann_new_leaf_array(NULL, NULL, 0, 0.0, dim, ar);
-
- PROCESS_KAD_FLAGS(t, 3);
- PUSH_KAD_NODE(t);
-
- g_free(ar);
- }
- else {
- return luaL_error(L, "invalid arguments for new.leaf, "
- "dim and vector of elements are required");
- }
-
- return 1;
- }
-
- static int
- lua_kann_new_kann(lua_State *L)
- {
- kad_node_t *cost = lua_check_kann_node(L, 1);
- kann_t *k;
-
- if (cost) {
- k = kann_new(cost, 0);
-
- PUSH_KAN_NETWORK(k);
- }
- else {
- return luaL_error(L, "invalid arguments for new.kann, "
- "cost node is required");
- }
-
- return 1;
- }
-
- static int
- lua_kann_destroy(lua_State *L)
- {
- kann_t *k = lua_check_kann(L, 1);
-
- kann_delete(k);
-
- return 0;
- }
-
- static int
- lua_kann_save(lua_State *L)
- {
- kann_t *k = lua_check_kann(L, 1);
-
- if (k) {
- if (lua_istable(L, 2)) {
- lua_getfield(L, 2, "filename");
-
- if (lua_isstring(L, -1)) {
- const char *fname = lua_tostring(L, -1);
- FILE *f;
-
- f = fopen(fname, "w");
-
- if (!f) {
- lua_pop(L, 1);
-
- return luaL_error(L, "cannot open %s for writing: %s",
- fname, strerror(errno));
- }
-
- kann_save_fp(f, k);
- fclose(f);
-
- lua_pushboolean(L, true);
- }
- else {
- lua_pop(L, 1);
-
- return luaL_error(L, "invalid arguments: missing filename");
- }
-
- lua_pop(L, 1);
- }
- else {
- /* Save to Rspamd text */
- #ifndef HAVE_OPENMEMSTREAM
- return luaL_error(L, "no support of saving to memory on your system");
- #endif
- FILE *f;
- char *buf = NULL;
- size_t buflen;
- struct rspamd_lua_text *t;
-
- f = open_memstream(&buf, &buflen);
- g_assert(f != NULL);
-
- kann_save_fp(f, k);
- fclose(f);
-
- t = lua_newuserdata(L, sizeof(*t));
- rspamd_lua_setclass(L, rspamd_text_classname, -1);
- t->flags = RSPAMD_TEXT_FLAG_OWN;
- t->start = (const char *) buf;
- t->len = buflen;
- }
- }
- else {
- return luaL_error(L, "invalid arguments");
- }
-
- return 1;
- }
-
- static int
- lua_kann_load(lua_State *L)
- {
- kann_t *k;
- FILE *f = NULL;
-
- if (lua_istable(L, 1)) {
- lua_getfield(L, 2, "filename");
-
- if (lua_isstring(L, -1)) {
- const char *fname = lua_tostring(L, -1);
-
- f = fopen(fname, "rb");
- }
- else {
- lua_pop(L, 1);
-
- return luaL_error(L, "invalid arguments: missing filename");
- }
-
- lua_pop(L, 1);
- }
- else if (lua_isstring(L, 1)) {
- gsize dlen;
- const char *data;
-
- data = lua_tolstring(L, 1, &dlen);
-
- #ifndef HAVE_FMEMOPEN
- return luaL_error(L, "no support of loading from memory on your system");
- #endif
- f = fmemopen((void *) data, dlen, "rb");
- }
- else if (lua_isuserdata(L, 1)) {
- struct rspamd_lua_text *t;
-
- t = lua_check_text(L, 1);
-
- if (!t) {
- return luaL_error(L, "invalid arguments");
- }
-
- #ifndef HAVE_FMEMOPEN
- return luaL_error(L, "no support of loading from memory on your system");
- #endif
- f = fmemopen((void *) t->start, t->len, "rb");
- }
-
- if (f == NULL) {
- return luaL_error(L, "invalid arguments or cannot open file");
- }
-
- k = kann_load_fp(f);
- fclose(f);
-
- if (k == NULL) {
- lua_pushnil(L);
- }
- else {
- PUSH_KAN_NETWORK(k);
- }
-
- return 1;
- }
-
- struct rspamd_kann_train_cbdata {
- lua_State *L;
- kann_t *k;
- int cbref;
- };
-
- static void
- lua_kann_train_cb(int iter, float train_cost, float val_cost, void *ud)
- {
- struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *) ud;
-
- if (cbd->cbref != -1) {
- int err_idx;
- lua_State *L = cbd->L;
-
- lua_pushcfunction(L, &rspamd_lua_traceback);
- err_idx = lua_gettop(L);
-
- lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref);
- lua_pushinteger(L, iter);
- lua_pushnumber(L, train_cost);
- lua_pushnumber(L, val_cost);
-
- if (lua_pcall(L, 3, 0, err_idx) != 0) {
- msg_err("cannot run lua train callback: %s",
- lua_tostring(L, -1));
- }
-
- lua_settop(L, err_idx - 1);
- }
- }
-
/*
 * Frees the first `n` entries of the pointer array `a`, then `a` itself.
 * NULL entries are fine (g_free ignores NULL), but `a` itself must not be
 * NULL when n > 0 because its elements are read.
 */
#define FREE_VEC(a, n)                                \
	do {                                              \
		for (int i = 0; i < (n); i++) g_free((a)[i]); \
		g_free(a);                                    \
	} while (0)
-
- static int
- lua_kann_train1(lua_State *L)
- {
- kann_t *k = lua_check_kann(L, 1);
- struct rspamd_lua_tensor *pca = NULL;
-
- /* Default train params */
- double lr = 0.001;
- int64_t mini_size = 64;
- int64_t max_epoch = 25;
- int64_t max_drop_streak = 10;
- double frac_val = 0.1;
- int cbref = -1;
-
- if (k && lua_istable(L, 2) && lua_istable(L, 3)) {
- int n = rspamd_lua_table_size(L, 2);
- int n_in = kann_dim_in(k);
- int n_out = kann_dim_out(k);
-
- if (n_in <= 0) {
- return luaL_error(L, "invalid inputs count: %d", n_in);
- }
-
- if (n_out <= 0) {
- return luaL_error(L, "invalid outputs count: %d", n_out);
- }
-
- if (n != rspamd_lua_table_size(L, 3) || n == 0) {
- return luaL_error(L, "invalid dimensions: outputs size must be "
- "equal to inputs and non zero");
- }
-
- if (lua_istable(L, 4)) {
- GError *err = NULL;
-
- if (!rspamd_lua_parse_table_arguments(L, 4, &err,
- RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
- "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
- &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
- n = luaL_error(L, "invalid params: %s",
- err ? err->message : "unknown error");
- g_error_free(err);
-
- return n;
- }
- }
-
- if (pca) {
- /* Check pca matrix validity */
- if (pca->ndims != 2) {
- return luaL_error(L, "invalid pca tensor: matrix expected, got a row");
- }
-
- if (pca->dim[0] != n_in) {
- return luaL_error(L, "invalid pca tensor: "
- "matrix must have %d rows and it has %d rows instead",
- n_in, pca->dim[0]);
- }
- }
-
- float **x, **y, *tmp_row = NULL;
-
- /* Fill vectors row by row */
- x = (float **) g_malloc0(sizeof(float *) * n);
- y = (float **) g_malloc0(sizeof(float *) * n);
-
- if (pca) {
- tmp_row = g_malloc(sizeof(float) * pca->dim[1]);
- }
-
- for (int s = 0; s < n; s++) {
- /* Inputs */
- lua_rawgeti(L, 2, s + 1);
- x[s] = (float *) g_malloc(sizeof(float) * n_in);
-
- if (pca == NULL) {
- if (rspamd_lua_table_size(L, -1) != n_in) {
- FREE_VEC(x, n);
- FREE_VEC(y, n);
-
- n = luaL_error(L, "invalid params at pos %d: "
- "bad input dimension %d; %d expected",
- s + 1,
- (int) rspamd_lua_table_size(L, -1),
- n_in);
- lua_pop(L, 1);
-
- return n;
- }
-
- for (int i = 0; i < n_in; i++) {
- lua_rawgeti(L, -1, i + 1);
- x[s][i] = lua_tonumber(L, -1);
-
- lua_pop(L, 1);
- }
- }
- else {
- if (rspamd_lua_table_size(L, -1) != pca->dim[1]) {
- FREE_VEC(x, n);
- FREE_VEC(y, n);
- g_free(tmp_row);
-
- n = luaL_error(L, "(pca on) invalid params at pos %d: "
- "bad input dimension %d; %d expected",
- s + 1,
- (int) rspamd_lua_table_size(L, -1),
- pca->dim[1]);
- lua_pop(L, 1);
-
- return n;
- }
-
-
- for (int i = 0; i < pca->dim[1]; i++) {
- lua_rawgeti(L, -1, i + 1);
- tmp_row[i] = lua_tonumber(L, -1);
-
- lua_pop(L, 1);
- }
-
- kad_sgemm_simple(0, 1, 1, n_in,
- pca->dim[1], tmp_row, pca->data,
- x[s]);
- }
-
- lua_pop(L, 1);
-
- /* Outputs */
- y[s] = (float *) g_malloc(sizeof(float) * n_out);
- lua_rawgeti(L, 3, s + 1);
-
- if (rspamd_lua_table_size(L, -1) != n_out) {
- FREE_VEC(x, n);
- FREE_VEC(y, n);
- g_free(tmp_row);
-
- n = luaL_error(L, "invalid params at pos %d: "
- "bad output dimension %d; "
- "%d expected",
- s + 1,
- (int) rspamd_lua_table_size(L, -1),
- n_out);
- lua_pop(L, 1);
-
- return n;
- }
-
- for (int i = 0; i < n_out; i++) {
- lua_rawgeti(L, -1, i + 1);
- y[s][i] = lua_tonumber(L, -1);
-
- lua_pop(L, 1);
- }
-
- lua_pop(L, 1);
- }
-
- struct rspamd_kann_train_cbdata cbd;
-
- cbd.cbref = cbref;
- cbd.k = k;
- cbd.L = L;
-
- int niters = kann_train_fnn1(k, lr,
- mini_size, max_epoch, max_drop_streak,
- frac_val, n, x, y, lua_kann_train_cb, &cbd);
-
- lua_pushinteger(L, niters);
-
- FREE_VEC(x, n);
- FREE_VEC(y, n);
- g_free(tmp_row);
- }
- else {
- return luaL_error(L, "invalid arguments: kann, inputs, outputs and"
- " optional params are expected");
- }
-
- return 1;
- }
-
- static int
- lua_kann_apply1(lua_State *L)
- {
- kann_t *k = lua_check_kann(L, 1);
- struct rspamd_lua_tensor *pca = NULL;
-
- if (k) {
- if (lua_istable(L, 2)) {
- gsize vec_len = rspamd_lua_table_size(L, 2);
- float *vec = (float *) g_malloc(sizeof(float) * vec_len),
- *pca_out = NULL;
- int i_out;
- int n_in = kann_dim_in(k);
-
- if (n_in <= 0) {
- g_free(vec);
- return luaL_error(L, "invalid inputs count: %d", n_in);
- }
-
- if (lua_isuserdata(L, 3)) {
- pca = lua_check_tensor(L, 3);
-
- if (pca) {
- if (pca->ndims != 2) {
- g_free(vec);
- return luaL_error(L, "invalid pca tensor: matrix expected, got a row");
- }
-
- if (pca->dim[0] != n_in) {
- g_free(vec);
- return luaL_error(L, "invalid pca tensor: "
- "matrix must have %d rows and it has %d rows instead",
- n_in, pca->dim[0]);
- }
- }
- else {
- g_free(vec);
- return luaL_error(L, "invalid params: pca matrix expected");
- }
- }
- else {
- if (n_in != vec_len) {
- g_free(vec);
- return luaL_error(L, "invalid params: bad input dimension %d; %d expected",
- (int) vec_len, n_in);
- }
- }
-
- for (gsize i = 0; i < vec_len; i++) {
- lua_rawgeti(L, 2, i + 1);
- vec[i] = lua_tonumber(L, -1);
- lua_pop(L, 1);
- }
-
- i_out = kann_find(k, KANN_F_OUT, 0);
-
- if (i_out <= 0) {
- g_free(vec);
- return luaL_error(L, "invalid ANN: output layer is missing or is "
- "at the input pos");
- }
-
- kann_set_batch_size(k, 1);
- if (pca) {
- pca_out = g_malloc(sizeof(float) * n_in);
-
- kad_sgemm_simple(0, 1, 1, n_in,
- vec_len, vec, pca->data,
- pca_out);
-
- kann_feed_bind(k, KANN_F_IN, 0, &pca_out);
- }
- else {
- kann_feed_bind(k, KANN_F_IN, 0, &vec);
- }
-
- kad_eval_at(k->n, k->v, i_out);
-
- gsize outlen = kad_len(k->v[i_out]);
- lua_createtable(L, outlen, 0);
-
- for (gsize i = 0; i < outlen; i++) {
- lua_pushnumber(L, k->v[i_out]->x[i]);
- lua_rawseti(L, -2, i + 1);
- }
-
- g_free(vec);
- g_free(pca_out);
- }
- else if (lua_isuserdata(L, 2)) {
- struct rspamd_lua_tensor *t = lua_check_tensor(L, 2);
-
- if (t && t->ndims == 1) {
- int i_out;
- int n_in = kann_dim_in(k);
-
- if (n_in != t->dim[0]) {
- return luaL_error(L, "invalid params: bad input dimension %d; %d expected",
- (int) t->dim[0], n_in);
- }
-
- i_out = kann_find(k, KANN_F_OUT, 0);
-
- if (i_out <= 0) {
- return luaL_error(L, "invalid ANN: output layer is missing or is "
- "at the input pos");
- }
-
- kann_set_batch_size(k, 1);
- kann_feed_bind(k, KANN_F_IN, 0, &t->data);
- kad_eval_at(k->n, k->v, i_out);
-
- int outlen = kad_len(k->v[i_out]);
- struct rspamd_lua_tensor *out;
- out = lua_newtensor(L, 1, &outlen, false, false);
- /* Ensure that kann and tensor have the same understanding of floats */
- G_STATIC_ASSERT(sizeof(float) == sizeof(rspamd_tensor_num_t));
- memcpy(out->data, k->v[i_out]->x, outlen * sizeof(float));
- }
- else {
- return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected");
- }
- }
- else {
- return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected");
- }
- }
- else {
- return luaL_error(L, "invalid arguments: rspamd{kann} expected");
- }
-
- return 1;
- }
|