1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML//EN">
<html> <head>
<title>How to include additional info in day cells</title>
<script type="text/javascript" src="calendar.js"></script>
<script type="text/javascript" src="lang/calendar-en.js"></script>
<script type="text/javascript" src="calendar-setup.js"></script>
<script type="text/javascript">
// Additional text shown in day cells, keyed by the date formatted with
// "%Y%m%d" (e.g. "20050308"); getDateText and flatCallback look it up.
// Dates without an entry get no extra text.
var dateInfo = {
"20050308" : "Mishoo's birthday",
"20050310" : "foo",
"20050315" : "bar",
"20050318" : "25$",
"20050324" : "60$"
};
</script>
<style type="text/css">
@import url(calendar-win2k-1.css);
.calendar .inf { font-size: 80%; color: #444; }
.calendar .wn { font-weight: bold; vertical-align: top; }
</style>
</head>
<body>
<h1>How to include additional info in day cells</h1>
<div id="flatcal" style="float: right"></div>
<script type="text/javascript">
// Calendar "dateText" callback: returns the HTML for one day cell, i.e. the
// day number followed by a div.inf holding the info from dateInfo (or an
// empty placeholder div when the day has no info).
function getDateText(date, d) {
  var info = dateInfo[date.print("%Y%m%d")];
  if (info) {
    return d + "<div class='inf'>" + info + "</div>";
  }
  return d + "<div class='inf'> </div>";
};
// Selection callback for the flat calendar: when cal.dateClicked is set,
// shows the selected date (plus any dateInfo text) in window.status.
function flatCallback(cal) {
  if (!cal.dateClicked) {
    return;
  }
  // do something here
  window.status = "Selected: " + cal.date;
  var extra = dateInfo[cal.date.print("%Y%m%d")];
  if (extra) {
    window.status += ". Additional info: " + extra;
  }
};
// Render a flat calendar (presumably inline rather than a popup) inside the
// #flatcal div, using getDateText for cell contents and flatCallback for
// date selections.
Calendar.setup({
flat: "flatcal",
dateText: getDateText,
flatCallback: flatCallback
});
</script>
<p>The idea is simple:</p>
<ol>
<li>
<p>Define a callback that takes two parameters like this:</p>
<pre>function getDateText(date, d)</pre>
<p>
This function receives the Date object as the first parameter
and the day of the month (1..31) as the second. (You could also
get the day by calling date.getDate(), but since it is very likely
to be needed it is passed directly, which saves a function call.)
</p>
<p>
This function <em>must</em> return the text to be inserted in
the cell of the passed date. That is, one should at least
"return d;".
</p>
</li>
<li>
Pass the above function as the "dateText" parameter to
Calendar.setup.
</li>
</ol>
<p>
The function could simply look like:
</p>
<pre
> function getDateText(date, d) {
if (d == 12) {
return "12th";
} else if (d == 13) {
return "bad luck";
} /* ... etc ... */
}</pre>
<p>
but it is easy to see that this approach does not scale. For a better
way, see the source of this page and note the use of an externally
defined "dateInfo" object which maps each date (formatted as YYYYMMDD,
so the year and month are taken into account) to its additional info.
This object can easily be generated from a database, and the
getDateText function stays short and static.
</p>
<p>
Cheers!
</p>
<hr />
<address><a href="http://dynarch.com/mishoo/">mishoo</a></address>
<!-- hhmts start --> Last modified: Sat Mar 5 17:18:06 EET 2005 <!-- hhmts end -->
</pre { line-height: 125%; }
td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
.highlight .hll { background-color: #ffffcc }
.highlight .c { color: #888888 } /* Comment */
.highlight .err { color: #a61717; background-color: #e3d2d2 } /* Error */
.highlight .k { color: #008800; font-weight: bold } /* Keyword */
.highlight .ch { color: #888888 } /* Comment.Hashbang */
.highlight .cm { color: #888888 } /* Comment.Multiline */
.highlight .cp { color: #cc0000; font-weight: bold } /* Comment.Preproc */
.highlight .cpf { color: #888888 } /* Comment.PreprocFile */
.highlight .c1 { color: #888888 } /* Comment.Single */
.highlight .cs { color: #cc0000; font-weight: bold; background-color: #fff0f0 } /* Comment.Special */
.highlight .gd { color: #000000; background-color: #ffdddd } /* Generic.Deleted */
.highlight .ge { font-style: italic } /* Generic.Emph */
.highlight .gr { color: #aa0000 } /* Generic.Error */
.highlight .gh { color: #333333 } /* Generic.Heading */
.highlight .gi { color: #000000; background-color: #ddffdd } /* Generic.Inserted */
.highlight .go { color: #888888 } /* Generic.Output */
.highlight .gp { color: #555555 } /* Generic.Prompt */
.highlight .gs { font-weight: bold } /* Generic.Strong */
.highlight .gu { color: #666666 } /* Generic.Subheading */
.highlight .gt { color: #aa0000 } /* Generic.Traceback */
.highlight .kc { color: #008800; font-weight: bold } /* Keyword.Constant */
.highlight .kd { color: #008800; font-weight: bold } /* Keyword.Declaration */
.highlight .kn { color: #008800; font-weight: bold } /* Keyword.Namespace */
.highlight .kp { color: #008800 } /* Keyword.Pseudo */
.highlight .kr { color: #008800; font-weight: bold } /* Keyword.Reserved */
.highlight .kt { color: #888888; font-weight: bold } /* Keyword.Type */
.highlight .m { color: #0000DD; font-weight: bold } /* Literal.Number */
.highlight .s { color: #dd2200; background-color: #fff0f0 } /* Literal.String */
.highlight .na { color: #336699 } /* Name.Attribute */
.highlight .nb { color: #003388 } /* Name.Builtin */
.highlight .nc { color: #bb0066; font-weight: bold } /* Name.Class */
.highlight .no { color: #003366; font-weight: bold } /* Name.Constant */
.highlight .nd { color: #555555 } /* Name.Decorator */
.highlight .ne { color: #bb0066; font-weight: bold } /* Name.Exception */
.highlight .nf { color: #0066bb; font-weight: bold } /* Name.Function */
.highlight .nl { color: #336699; font-style: italic } /* Name.Label */
.highlight .nn { color: #bb0066; font-weight: bold } /* Name.Namespace */
.highlight .py { color: #336699; font-weight: bold } /* Name.Property */
.highlight .nt { color: #bb0066; font-weight: bold } /* Name.Tag */
.highlight .nv { color: #336699 } /* Name.Variable */
.highlight .ow { color: #008800 } /* Operator.Word */
.highlight .w { color: #bbbbbb } /* Text.Whitespace */
.highlight .mb { color: #0000DD; font-weight: bold } /* Literal.Number.Bin */
.highlight .mf { color: #0000DD; font-weight: bold } /* Literal.Number.Float */
.highlight .mh { color: #0000DD; font-weight: bold } /* Literal.Number.Hex */
.highlight .mi { color: #0000DD; font-weight: bold } /* Literal.Number.Integer */
.highlight .mo { color: #0000DD; font-weight: bold } /* Literal.Number.Oct */
.highlight .sa { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Affix */
.highlight .sb { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Backtick */
.highlight .sc { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Char */
.highlight .dl { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Delimiter */
.highlight .sd { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Doc */
.highlight .s2 { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Double */
.highlight .se { color: #0044dd; background-color: #fff0f0 } /* Literal.String.Escape */
.highlight .sh { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Heredoc */
.highlight .si { color: #3333bb; background-color: #fff0f0 } /* Literal.String.Interpol */
.highlight .sx { color: #22bb22; background-color: #f0fff0 } /* Literal.String.Other */
.highlight .sr { color: #008800; background-color: #fff0ff } /* Literal.String.Regex */
.highlight .s1 { color: #dd2200; background-color: #fff0f0 } /* Literal.String.Single */
.highlight .ss { color: #aa6600; background-color: #fff0f0 } /* Literal.String.Symbol */
.highlight .bp { color: #003388 } /* Name.Builtin.Pseudo */
.highlight .fm { color: #0066bb; font-weight: bold } /* Name.Function.Magic */
.highlight .vc { color: #336699 } /* Name.Variable.Class */
.highlight .vg { color: #dd7700 } /* Name.Variable.Global */
.highlight .vi { color: #3333bb } /* Name.Variable.Instance */
.highlight .vm { color: #336699 } /* Name.Variable.Magic */
.highlight .il { color: #0000DD; font-weight: bold } /* Literal.Number.Integer.Long *//*-
* Copyright 2020 Vsevolod Stakhov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "lua_common.h"
#include "lua_tensor.h"
#include "contrib/kann/kautodiff.h"
/***
* @module rspamd_tensor
* `rspamd_tensor` is a simple Lua library to abstract matrices and vectors
* Internally, they are represented as arrays of float variables
* So far, merely 1D and 2D tensors are supported
*/
/* Forward declarations of the Lua C functions implemented below */
LUA_FUNCTION_DEF (tensor, load);
LUA_FUNCTION_DEF (tensor, save);
LUA_FUNCTION_DEF (tensor, new);
LUA_FUNCTION_DEF (tensor, fromtable);
LUA_FUNCTION_DEF (tensor, destroy);
LUA_FUNCTION_DEF (tensor, mul);
LUA_FUNCTION_DEF (tensor, tostring);
LUA_FUNCTION_DEF (tensor, index);
LUA_FUNCTION_DEF (tensor, newindex);
LUA_FUNCTION_DEF (tensor, len);
LUA_FUNCTION_DEF (tensor, eugen);
LUA_FUNCTION_DEF (tensor, mean);
LUA_FUNCTION_DEF (tensor, transpose);
/*
* Module-level functions: lua_load_tensor registers these into the table
* returned by require "rspamd_tensor".
*/
static luaL_reg rspamd_tensor_f[] = {
LUA_INTERFACE_DEF (tensor, load),
LUA_INTERFACE_DEF (tensor, new),
LUA_INTERFACE_DEF (tensor, fromtable),
{NULL, NULL},
};
/*
* Methods and metamethods of the TENSOR_CLASS metatable (installed by
* luaopen_tensor). "__index" is a function, so method lookup for string
* keys is done inside lua_tensor_index; "mul" and "__mul" share one
* implementation, as do "tostring" and "__tostring".
*/
static luaL_reg rspamd_tensor_m[] = {
LUA_INTERFACE_DEF (tensor, save),
{"__gc", lua_tensor_destroy},
{"__mul", lua_tensor_mul},
{"mul", lua_tensor_mul},
{"tostring", lua_tensor_tostring},
{"__tostring", lua_tensor_tostring},
{"__index", lua_tensor_index},
{"__newindex", lua_tensor_newindex},
{"__len", lua_tensor_len},
LUA_INTERFACE_DEF (tensor, eugen),
LUA_INTERFACE_DEF (tensor, mean),
LUA_INTERFACE_DEF (tensor, transpose),
{NULL, NULL},
};
/*
 * Creates a tensor userdata, pushes it onto the Lua stack with the
 * TENSOR_CLASS metatable and returns a pointer to it.
 *
 * ndims     - number of dimensions; dim[0..ndims-1] are copied into res->dim
 *             (callers in this file pass 1 or 2)
 * dim       - size of every dimension
 * zero_fill - when own is true, zero the freshly allocated elements
 * own       - true: elements are allocated here on the C heap and released by
 *             __gc (lua_tensor_destroy); false: res->data stays NULL for the
 *             caller to point at memory owned elsewhere, and res->size is
 *             stored negated to mark the tensor as non-owning
 */
struct rspamd_lua_tensor *
lua_newtensor (lua_State *L, int ndims, const int *dim, bool zero_fill, bool own)
{
	struct rspamd_lua_tensor *res;

	res = lua_newuserdata (L, sizeof (struct rspamd_lua_tensor));
	memset (res, 0, sizeof (*res));

	res->ndims = ndims;
	res->size = 1;

	/* Signed index: an unsigned one would turn a negative ndims into a huge bound */
	for (int i = 0; i < ndims; i ++) {
		res->size *= dim[i];
		res->dim[i] = dim[i];
	}

	/* Keep the (possibly large) element storage outside of the Lua heap */
	if (own) {
		/*
		 * The _n variants check count * element size for overflow (and
		 * return NULL for an empty tensor without touching memory).
		 */
		if (zero_fill) {
			res->data = g_malloc0_n (res->size, sizeof (rspamd_tensor_num_t));
		}
		else {
			res->data = g_malloc_n (res->size, sizeof (rspamd_tensor_num_t));
		}
	}
	else {
		/* Mark size negative to distinguish non-owning views */
		res->size = -(res->size);
	}

	rspamd_lua_setclass (L, TENSOR_CLASS, -1);

	return res;
}
/***
 * @function tensor.new(ndims, [dim1, ... dimN])
 * Creates a new zero filled tensor with the specific number of dimensions
 * @param {number} ndims number of dimensions, 1 or 2
 * @param {number} dim1..dimN size of each dimension (a missing size is 0)
 * @return {tensor} new zero filled tensor
 */
static gint
lua_tensor_new (lua_State *L)
{
	gint ndims = luaL_checkinteger (L, 1);

	if (ndims > 0 && ndims <= 2) {
		gint dims[2] = {0, 0};
		/* 64-bit accumulator: two gint dimensions cannot overflow it */
		gint64 nelts = 1;

		for (gint i = 0; i < ndims; i ++) {
			dims[i] = lua_tointeger (L, i + 2);

			/* Negative sizes would produce a meaningless element count */
			if (dims[i] < 0) {
				return luaL_error (L, "incorrect dimension %d: %d",
						i + 1, dims[i]);
			}

			nelts *= dims[i];
		}

		/* Keep the element count representable as a gint */
		if (nelts > G_MAXINT) {
			return luaL_error (L, "tensor is too large");
		}

		(void)lua_newtensor (L, ndims, dims, true, true);
	}
	else {
		return luaL_error (L, "incorrect dimensions number: %d", ndims);
	}

	return 1;
}
/***
 * @function tensor.fromtable(tbl)
 * Creates a new tensor from a Lua table.
 * A flat table of numbers {x1, ..., xN} produces a 2D 1xN tensor; a table of
 * equally sized tables {{...}, ...} produces a 2D NxM tensor.
 * Elements are read with lua_tonumber, so non-numeric values become 0.
 * @param {table} tbl vector or matrix
 * @return {tensor} new tensor
 */
static gint
lua_tensor_fromtable (lua_State *L)
{
	if (!lua_istable (L, 1)) {
		return luaL_error (L, "incorrect input");
	}

	lua_rawgeti (L, 1, 1);

	if (lua_isnumber (L, -1)) {
		lua_pop (L, 1);
		/* Input vector, stored as a single-row matrix */
		gint dims[2];

		dims[0] = 1;
		dims[1] = rspamd_lua_table_size (L, 1);

		struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
				dims, false, true);

		for (gint i = 0; i < dims[1]; i ++) {
			lua_rawgeti (L, 1, i + 1);
			res->data[i] = lua_tonumber (L, -1);
			lua_pop (L, 1);
		}
	}
	else if (lua_istable (L, -1)) {
		/* Input matrix */
		lua_pop (L, 1);
		/* Validate every row and calculate the overall size */
		gint nrows = rspamd_lua_table_size (L, 1), ncols = 0;

		for (gint i = 0; i < nrows; i ++) {
			lua_rawgeti (L, 1, i + 1);

			/* The fill loop below indexes every row, so each must be a table */
			if (!lua_istable (L, -1)) {
				lua_pop (L, 1);
				return luaL_error (L, "invalid params at pos %d: "
						"row is not a table", i + 1);
			}

			gint row_len = rspamd_lua_table_size (L, -1);

			if (i == 0) {
				if (row_len == 0) {
					lua_pop (L, 1);
					return luaL_error (L, "invalid params at pos %d: "
							"bad input dimension %d",
							i + 1,
							(int)row_len);
				}

				ncols = row_len;
			}
			else if (row_len != ncols) {
				lua_pop (L, 1);
				return luaL_error (L, "invalid params at pos %d: "
						"bad input dimension %d; %d expected",
						i + 1,
						row_len,
						ncols);
			}

			lua_pop (L, 1);
		}

		gint dims[2];

		dims[0] = nrows;
		dims[1] = ncols;

		struct rspamd_lua_tensor *res = lua_newtensor (L, 2,
				dims, false, true);

		for (gint i = 0; i < nrows; i ++) {
			lua_rawgeti (L, 1, i + 1);

			for (gint j = 0; j < ncols; j ++) {
				lua_rawgeti (L, -1, j + 1);
				res->data[i * ncols + j] = lua_tonumber (L, -1);
				lua_pop (L, 1);
			}

			lua_pop (L, 1);
		}
	}
	else {
		lua_pop (L, 1);
		return luaL_error (L, "incorrect table");
	}

	return 1;
}
/***
 * @method tensor:destroy()
 * Tensor destructor, installed as the __gc metamethod.
 * Only owning tensors (positive size) free their element storage; row
 * views created by indexing a matrix point into the parent's memory.
 * @return nothing
 */
static gint
lua_tensor_destroy (lua_State *L)
{
	struct rspamd_lua_tensor *victim = lua_check_tensor (L, 1);

	if (victim != NULL && victim->size > 0) {
		g_free (victim->data);
	}

	return 0;
}
/***
 * @method tensor:save()
 * Serialises the tensor into an rspamd_text laid out as four native ints
 * (ndims, element count, dim[0], dim[1]) followed by the raw elements,
 * which is the layout tensor.load() parses.
 * Works for both owning tensors and non-owning row views.
 * @return {rspamd_text} serialised tensor
 */
static gint
lua_tensor_save (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);

	if (t == NULL) {
		return luaL_error (L, "invalid arguments");
	}

	/* Non-owning views keep a negated size; serialise the magnitude */
	gint nelts = (t->size > 0) ? t->size : -(t->size);
	gsize payload_len = nelts * sizeof (rspamd_tensor_num_t);
	gsize total_len = sizeof (gint) * 4 + payload_len;
	struct rspamd_lua_text *out = lua_new_text (L, NULL, 0, TRUE);
	guchar *buf = g_malloc (total_len);
	guchar *p = buf;

	memcpy (p, &t->ndims, sizeof (int));
	p += sizeof (int);
	memcpy (p, &nelts, sizeof (int));
	p += sizeof (int);
	memcpy (p, t->dim, sizeof (int) * 2);
	p += sizeof (int) * 2;
	memcpy (p, t->data, payload_len);

	out->start = (const gchar *)buf;
	out->len = total_len;

	return 1;
}
/***
 * @method tensor:tostring()
 * Formats the tensor as text using "%.4f" per element: a 1D tensor is a
 * single space separated line, a 2D tensor has one line per row.
 * Empty tensors (or rows) produce empty text instead of reading out of bounds.
 * @return {string} textual representation
 */
static gint
lua_tensor_tostring (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);

	if (!t) {
		return luaL_error (L, "invalid arguments");
	}

	GString *out = g_string_sized_new (128);

	if (t->ndims == 1) {
		/* Print as a vector */
		for (gint i = 0; i < t->dim[0]; i ++) {
			rspamd_printf_gstring (out, "%.4f ", t->data[i]);
		}

		/* Trim the trailing space; nothing was printed for an empty vector */
		if (out->len > 0) {
			g_string_truncate (out, out->len - 1);
		}
	}
	else {
		for (gint i = 0; i < t->dim[0]; i ++) {
			for (gint j = 0; j < t->dim[1]; j ++) {
				rspamd_printf_gstring (out, "%.4f ",
						t->data[i * t->dim[1] + j]);
			}

			/*
			 * Trim this row's trailing space; an empty row printed nothing,
			 * and out->len > 0 would wrongly eat the previous newline.
			 */
			if (t->dim[1] > 0) {
				g_string_truncate (out, out->len - 1);
			}

			g_string_append_c (out, '\n');
		}

		/* Trim the last newline */
		if (out->len > 0) {
			g_string_truncate (out, out->len - 1);
		}
	}

	lua_pushlstring (L, out->str, out->len);
	g_string_free (out, TRUE);

	return 1;
}
/***
 * @method tensor[idx]
 * Index metamethod (__index).
 * Numeric keys are 1-based: a 1D tensor returns element `idx`, a 2D tensor
 * returns row `idx` as a 1D non-owning tensor sharing the matrix memory.
 * Out-of-range indices, including idx <= 0, return nil.
 * String keys are looked up in the metatable so methods remain reachable;
 * any other key type returns nil.
 */
static gint
lua_tensor_index (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
	gint idx;

	if (!t) {
		lua_pushnil (L);
		return 1;
	}

	if (lua_isnumber (L, 2)) {
		idx = lua_tointeger (L, 2);

		if (idx <= 0 || idx > t->dim[0]) {
			/* Lua indices start at 1; idx == 0 must not read data[-1] */
			lua_pushnil (L);
		}
		else if (t->ndims == 1) {
			/* Individual element */
			lua_pushnumber (L, t->data[idx - 1]);
		}
		else {
			/* Push row as a non-owning tensor */
			gint dim = t->dim[1];
			struct rspamd_lua_tensor *res =
					lua_newtensor (L, 1, &dim, false, false);

			/*
			 * NOTE(review): the row does not keep the matrix alive; once the
			 * matrix is collected (__gc frees its data) the row dangles.
			 * Consider storing a reference to the parent alongside the row.
			 */
			res->data = &t->data[(idx - 1) * t->dim[1]];
		}
	}
	else if (lua_isstring (L, 2)) {
		/* Access to methods */
		lua_getmetatable (L, 1);
		lua_pushvalue (L, 2);
		lua_rawget (L, -2);
	}
	else {
		/* Unsupported key type */
		lua_pushnil (L);
	}

	return 1;
}
/***
 * @method tensor[idx] = value
 * Assignment metamethod (__newindex).
 * 1D tensor: stores number `value` at element `idx` (1-based).
 * 2D tensor: copies the 1D tensor `value` into row `idx` (1-based); its
 * length must equal the number of columns.
 * Raises a Lua error for out-of-range indices, mismatched row lengths,
 * non-tensor values and non-numeric keys.
 */
static gint
lua_tensor_newindex (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);
	gint idx;

	if (t) {
		if (lua_isnumber (L, 2)) {
			idx = lua_tointeger (L, 2);

			if (t->ndims == 1) {
				/* Individual element */
				if (idx <= t->dim[0] && idx > 0) {
					rspamd_tensor_num_t value = lua_tonumber (L, 3), old;

					old = t->data[idx - 1];
					t->data[idx - 1] = value;
					lua_pushnumber (L, old);
				}
				else {
					return luaL_error (L, "invalid index: %d", idx);
				}
			}
			else {
				if (lua_isnumber (L, 3)) {
					return luaL_error (L, "cannot assign number to a row");
				}
				else if (lua_isuserdata (L, 3)) {
					/* Tensor assignment */
					struct rspamd_lua_tensor *row = lua_check_tensor (L, 3);

					if (row) {
						if (row->ndims == 1) {
							/* A length mismatch must not be silently ignored */
							if (row->dim[0] != t->dim[1]) {
								return luaL_error (L, "cannot assign row of size %d "
										"to a matrix with %d columns",
										row->dim[0], t->dim[1]);
							}

							if (idx > 0 && idx <= t->dim[0]) {
								idx --; /* Zero based index */

								/*
								 * memmove: `m[i] = m[i]` passes a view of the very
								 * same row, i.e. source and destination coincide.
								 */
								if (t->dim[1] > 0) {
									memmove (&t->data[idx * t->dim[1]],
											row->data,
											t->dim[1] * sizeof (rspamd_tensor_num_t));
								}

								return 0;
							}
							else {
								return luaL_error (L, "invalid index: %d", idx);
							}
						}
						else {
							return luaL_error (L, "cannot assign matrix to row");
						}
					}
					else {
						return luaL_error (L, "cannot assign row, invalid tensor");
					}
				}
				else {
					/* TODO: add table assignment */
					return luaL_error (L, "cannot assign row, not a tensor");
				}
			}
		}
		else {
			/* Access to methods? NYI */
			return luaL_error (L, "cannot assign method of a tensor");
		}
	}

	return 1;
}
/***
* @method tensor:mul(other, [transA, [transB]])
* Multiply two tensors (optionally transposed) and return a new tensor.
* Also installed as the `__mul` metamethod, where only the two operands are
* passed, so transA/transB stay 0.
*
* Shape handling: lua_newtensor zeroes the struct, so a 1D tensor has
* dim[1] == 0; the zero checks below are how vector operands are detected.
* The product itself is delegated to kad_sgemm_simple with
* M = dims[0], N = dims[1], K = shadow_dims[0]
* (NOTE(review): argument meaning assumed from the BLAS sgemm convention —
* confirm against contrib/kann/kautodiff.h).
* @param {tensor} other right hand operand
* @param {boolean} transA use the transpose of this tensor
* @param {boolean} transB use the transpose of `other`
* @return {tensor} newly allocated, zero-initialised then filled result
*/
static gint
lua_tensor_mul (lua_State *L)
{
struct rspamd_lua_tensor *t1 = lua_check_tensor (L, 1),
*t2 = lua_check_tensor (L, 2), *res;
int transA = 0, transB = 0;
/* Optional transpose flags; non-boolean arguments are ignored */
if (lua_isboolean (L, 3)) {
transA = lua_toboolean (L, 3);
}
if (lua_isboolean (L, 4)) {
transB = lua_toboolean (L, 4);
}
if (t1 && t2) {
/*
* dims: rows of op(t1) x columns of op(t2), i.e. the result shape.
* shadow_dims: inner dimensions (rows of op(t2), columns of op(t1)),
* which must agree.
* NOTE(review): abs() presumably tolerates negative sizes; the reason
* is not visible here.
*/
gint dims[2], shadow_dims[2];
dims[0] = abs (transA ? t1->dim[1] : t1->dim[0]);
shadow_dims[0] = abs (transB ? t2->dim[1] : t2->dim[0]);
dims[1] = abs (transB ? t2->dim[0] : t2->dim[1]);
shadow_dims[1] = abs (transA ? t1->dim[0] : t1->dim[1]);
if (shadow_dims[0] != shadow_dims[1]) {
/* Reported as: shape of op(t1) * shape of op(t2) */
return luaL_error (L, "incompatible dimensions %d x %d * %d x %d",
dims[0], shadow_dims[1], shadow_dims[0], dims[1]);
}
else if (shadow_dims[0] == 0) {
/* Row * Column -> matrix */
shadow_dims[0] = 1;
shadow_dims[1] = 1;
}
if (dims[0] == 0) {
/* Column */
dims[0] = 1;
if (dims[1] == 0) {
/* Column * row -> number */
dims[1] = 1;
}
res = lua_newtensor (L, 2, dims, true, true);
}
else if (dims[1] == 0) {
/* Row: a 1D result of dims[0] elements; N = 1 is only set for sgemm */
res = lua_newtensor (L, 1, dims, true, true);
dims[1] = 1;
}
else {
res = lua_newtensor (L, 2, dims, true, true);
}
/* res was zero filled by lua_newtensor (zero_fill = true) */
kad_sgemm_simple (transA, transB, dims[0], dims[1], shadow_dims[0],
t1->data, t2->data, res->data);
}
else {
return luaL_error (L, "invalid arguments");
}
return 1;
}
/***
 * @function tensor.load(rspamd_text)
 * Deserialize tensor
 * Accepts an rspamd_text or a Lua string in the layout written by
 * tensor:save(): four native ints (ndims, nelts, dim[0], dim[1]) followed by
 * nelts elements. The header is treated as untrusted: negative values and
 * dimension products that overflow are rejected.
 * @param {rspamd_text|string} data serialised tensor
 * @return {tensor} new owning tensor
 */
static gint
lua_tensor_load (lua_State *L)
{
	const guchar *data = NULL;
	gsize sz = 0;
	const gsize hdr_len = sizeof (int) * 4;

	if (lua_type (L, 1) == LUA_TUSERDATA) {
		struct rspamd_lua_text *t = lua_check_text (L, 1);

		if (!t) {
			return luaL_error (L, "invalid argument");
		}

		data = (const guchar *)t->start;
		sz = t->len;
	}
	else {
		data = (const guchar *)lua_tolstring (L, 1, &sz);

		if (data == NULL) {
			/* Not a string: make sure we take the "invalid arguments" path */
			sz = 0;
		}
	}

	if (sz >= hdr_len) {
		int ndims, nelts, dims[2];
		gsize payload = sz - hdr_len;

		memcpy (&ndims, data, sizeof (int));
		memcpy (&nelts, data + sizeof (int), sizeof (int));
		memcpy (dims, data + sizeof (int) * 2, sizeof (int) * 2);

		/*
		 * Compare element counts instead of computing a byte size from nelts,
		 * so a hostile (negative or huge) nelts cannot wrap the arithmetic.
		 */
		if (nelts >= 0 &&
				payload % sizeof (rspamd_tensor_num_t) == 0 &&
				payload / sizeof (rspamd_tensor_num_t) == (gsize)nelts) {
			if (ndims == 1) {
				/* nelts >= 0, so equality also guarantees dims[0] >= 0 */
				if (nelts == dims[0]) {
					struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);

					if (nelts > 0) {
						memcpy (t->data, data + hdr_len, nelts *
								sizeof (rspamd_tensor_num_t));
					}
				}
				else {
					return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
							dims[0], 1, nelts);
				}
			}
			else if (ndims == 2) {
				/* 64-bit product: int multiplication could overflow (UB) and wrap to nelts */
				if (dims[0] >= 0 && dims[1] >= 0 &&
						(gint64)dims[0] * (gint64)dims[1] == (gint64)nelts) {
					struct rspamd_lua_tensor *t = lua_newtensor (L, ndims, dims, false, true);

					if (nelts > 0) {
						memcpy (t->data, data + hdr_len, nelts *
								sizeof (rspamd_tensor_num_t));
					}
				}
				else {
					return luaL_error (L, "invalid argument: bad dims: %d x %d != %d",
							dims[0], dims[1], nelts);
				}
			}
			else {
				return luaL_error (L, "invalid argument: bad ndims: %d", ndims);
			}
		}
		else {
			return luaL_error (L, "invalid size: %d, %d required, %d elts", (int)sz,
					(int)(nelts * sizeof (rspamd_tensor_num_t) + sizeof (int) * 4),
					nelts);
		}
	}
	else {
		return luaL_error (L, "invalid arguments");
	}

	return 1;
}
/***
 * @method #tensor
 * Length metamethod (__len).
 * @return {number,number} dim[0] for a 1D tensor; dim[0] and dim[1] for a 2D one
 */
static gint
lua_tensor_len (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);

	if (t == NULL) {
		return luaL_error (L, "invalid arguments");
	}

	/* The main dimension always comes first */
	lua_pushinteger (L, t->dim[0]);

	if (t->ndims == 1) {
		return 1;
	}

	lua_pushinteger (L, t->dim[1]);

	return 2;
}
/***
 * @method tensor:eugen()
 * Runs kad_ssyev_simple on a square NxN matrix and returns its N-element
 * output vector (presumably the eigenvalues, following LAPACK ssyev naming).
 * NOTE(review): the matrix data is passed to kad_ssyev_simple directly; if it
 * behaves like LAPACK ssyev the input is overwritten — confirm in kautodiff.
 * @return {tensor} 1D tensor of N values
 */
static gint
lua_tensor_eugen (lua_State *L)
{
	struct rspamd_lua_tensor *mat = lua_check_tensor (L, 1), *vals;

	if (mat == NULL) {
		return luaL_error (L, "invalid arguments");
	}

	gboolean square = (mat->ndims == 2 && mat->dim[0] == mat->dim[1]);

	if (!square) {
		return luaL_error (L, "expected square matrix NxN but got %dx%d",
				mat->dim[0], mat->dim[1]);
	}

	vals = lua_newtensor (L, 1, &mat->dim[0], true, true);

	if (kad_ssyev_simple (mat->dim[0], mat->data, vals->data)) {
		return 1;
	}

	/* Drop the unused result tensor before raising */
	lua_pop (L, 1);

	return luaL_error (L, "kad_ssyev_simple failed (no blas?)");
}
/*
* Arithmetic mean of the n elements starting at x, computed with Kahan
* (compensated) summation: c holds the rounding error of the previous
* addition and is subtracted from the next element before it is added to s.
* NOTE(review): for n == 0 this evaluates 0 / 0 (NaN if rspamd_tensor_num_t
* is an IEEE floating type, as its usage suggests — confirm).
* NOTE(review): aggressive FP optimisation flags (e.g. -ffast-math) may
* simplify the compensation away.
*/
static inline rspamd_tensor_num_t
mean_vec (rspamd_tensor_num_t *x, int n)
{
rspamd_tensor_num_t s = 0;
rspamd_tensor_num_t c = 0;
for (int i = 0; i < n; i ++) {
rspamd_tensor_num_t v = x[i];
/* Remove the error carried over from the previous addition */
rspamd_tensor_num_t y = v - c;
rspamd_tensor_num_t t = s + y;
/* (t - s) is what was actually added; subtracting y leaves the new error */
c = (t - s) - y;
s = t;
}
return s / (rspamd_tensor_num_t)n;
}
/***
 * @method tensor:mean()
 * 1D tensor: returns the mean of all elements as a number.
 * 2D tensor: returns a new 1D tensor with the mean of every row.
 * @return {number|tensor} mean value(s)
 */
static gint
lua_tensor_mean (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1);

	if (t == NULL) {
		return luaL_error (L, "invalid arguments");
	}

	if (t->ndims == 1) {
		lua_pushnumber (L, mean_vec (t->data, t->dim[0]));
		return 1;
	}

	gint nrows = t->dim[0], ncols = t->dim[1];
	struct rspamd_lua_tensor *row_means =
			lua_newtensor (L, 1, &t->dim[0], false, true);

	for (gint r = 0; r < nrows; r ++) {
		row_means->data[r] = mean_vec (&t->data[r * ncols], ncols);
	}

	return 1;
}
/***
 * @method tensor:transpose()
 * Returns a transposed copy.
 * A 1D tensor of N elements becomes a 2D 1xN tensor with the same data;
 * an NxM matrix becomes an MxN matrix.
 * @return {tensor} new owning tensor
 */
static gint
lua_tensor_transpose (lua_State *L)
{
	struct rspamd_lua_tensor *t = lua_check_tensor (L, 1), *res;
	int dims[2];

	if (t == NULL) {
		return luaL_error (L, "invalid arguments");
	}

	if (t->ndims == 1) {
		dims[0] = 1;
		dims[1] = t->dim[0];
		res = lua_newtensor (L, 2, dims, false, true);
		memcpy (res->data, t->data, t->dim[0] * sizeof (rspamd_tensor_num_t));

		return 1;
	}

	const int nrows = t->dim[0], ncols = t->dim[1];
	/* Source rows are copied in strips of this height for cache locality */
	static const int strip = 32;

	dims[0] = ncols;
	dims[1] = nrows;
	res = lua_newtensor (L, 2, dims, false, true);

	for (int first = 0; first < nrows; first += strip) {
		const int last = (first + strip < nrows) ? first + strip : nrows;

		for (int col = 0; col < ncols; col ++) {
			for (int row = first; row < last; row ++) {
				res->data[col * nrows + row] = t->data[row * ncols + col];
			}
		}
	}

	return 1;
}
/*
* Loader for the "rspamd_tensor" module: pushes a new table filled with the
* functions from rspamd_tensor_f (load, new, fromtable) and returns it.
*/
static gint
lua_load_tensor (lua_State * L)
{
lua_newtable (L);
luaL_register (L, NULL, rspamd_tensor_f);
return 1;
}
/*
* Module initialisation: creates the TENSOR_CLASS metatable from
* rspamd_tensor_m and registers lua_load_tensor as the loader for
* "rspamd_tensor" (presumably via package.preload; see rspamd_lua_add_preload).
* Leaves the Lua stack empty.
*/
void luaopen_tensor (lua_State *L)
{
/* Metatables */
rspamd_lua_new_class (L, TENSOR_CLASS, rspamd_tensor_m);
lua_pop (L, 1); /* The class metatable is not needed on the stack */
rspamd_lua_add_preload (L, "rspamd_tensor", lua_load_tensor);
lua_settop (L, 0);
}
/*
 * Returns the tensor userdata at stack index `pos`.
 * Raises a Lua argument error ("'tensor' expected") when the value is not
 * a TENSOR_CLASS userdata.
 */
struct rspamd_lua_tensor *
lua_check_tensor (lua_State *L, int pos)
{
	void *ud = rspamd_lua_check_udata (L, pos, TENSOR_CLASS);

	luaL_argcheck (L, ud != NULL, pos, "'tensor' expected");

	/* A NULL pointer converts to NULL, so no conditional is required */
	return (struct rspamd_lua_tensor *)ud;
}
|