aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree/utils.h
blob: 8a0196a5891caeed4cd5dd8f8bb5df539e1104a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include "error.h"

/*
 * check_tensors(L, a, b): require tensors `a` and `b` to have the same shape.
 *
 * Compares nDimension first, then every size[] entry. On the first mismatch
 * the *enclosing function* returns the value of LUA_HANDLE_ERROR_STR(), so
 * the macro can only be expanded inside a function whose return type accepts
 * that value. `a` and `b` are evaluated several times; pass expressions
 * without side effects.
 */
#define check_tensors(L, a, b) \
   do { \
      int dim_idx_; \
      if ((a)->nDimension != (b)->nDimension) { \
         return LUA_HANDLE_ERROR_STR((L), "different tensor dimensions"); \
      } \
      for (dim_idx_ = 0; dim_idx_ < (a)->nDimension; ++dim_idx_) { \
         if ((a)->size[dim_idx_] != (b)->size[dim_idx_]) { \
            return LUA_HANDLE_ERROR_STR((L), "different tensor sizes"); \
         } \
      } \
   } while (0)

/*
 * check_tensor(L, t, type): require tensor `t` to be contiguous.
 *
 * `type` is pasted in front of `_isContiguous`, i.e. the check calls
 * <type>_isContiguous(t). When that call yields zero, the *enclosing
 * function* returns the value of LUA_HANDLE_ERROR_STR(); otherwise
 * execution continues after the macro.
 */
#define check_tensor(L, t, type) \
   do { \
      if (type##_isContiguous(t)) { \
         break; \
      } \
      return LUA_HANDLE_ERROR_STR((L), "tensor should be contiguous"); \
   } while (0)

/*
 * get_tensor_size(t, type): expands to TH<type>Tensor_nElement(t); for
 * example `type` = Float calls THFloatTensor_nElement(t). The result is
 * whatever that function returns (presumably the total number of elements
 * in `t` -- confirm against the TH headers).
 */
#define get_tensor_size(t, type) (TH##type##Tensor_nElement(t))

/*
 * get_tensor(L, idx, type): fetch the userdata at Lua stack index `idx` as a
 * `TH<type>Tensor *`, by calling luaT_checkudata() with the type name
 * "torch.<type>Tensor".
 *
 * The whole expansion is parenthesized so that postfix operators bind to the
 * cast result: without the outer parentheses, get_tensor(L, 1, Float)->size
 * would parse as a cast applied to luaT_checkudata(...)->size.
 */
#define get_tensor(L, idx, type)                                        \
    ((TH##type##Tensor *)luaT_checkudata((L), (idx), "torch." #type "Tensor"))

/*
 * Push the array part of the table at stack index `arg` onto the Lua stack.
 *
 * Entries t[1], t[2], ... are read with raw access and pushed in order until
 * the first nil entry, which is popped again.
 *
 * `arg` must be an absolute (positive) stack index or a pseudo-index: the
 * same index is reused after every push, so a top-relative negative index
 * would stop referring to the table once the first value has been pushed.
 *
 * Returns the number of values left on the stack by this call. If the stack
 * cannot be grown, the walk stops early and the count pushed so far is
 * returned, so the return value always matches what was actually pushed.
 */
static int push_table_contents(lua_State *L, int arg)
{
    int size = 0;

    /* lua_checkstack() returns 0 when it cannot provide the extra slot;
     * pushing anyway would write past the end of the Lua stack. */
    while (lua_checkstack(L, 1)) {
        lua_rawgeti(L, arg, size + 1);
        if (lua_isnil(L, -1)) {
            lua_pop(L, 1);
            break;
        }
        size++;
    }
    return size;
}

/*
 * verify_push_table_contents(L, idx, count): push the contents of the table
 * at stack index `idx` with push_table_contents() and check that exactly
 * `count` values were pushed. On a mismatch the pushed values are popped
 * again and LUA_HANDLE_ERROR_STR() is invoked with
 * "Table sizes do not match".
 *
 * Every argument is parenthesized in the expansion so that an expression
 * argument (for example a conditional `c ? a : b` as `count`) is compared
 * as a unit instead of being re-associated with the `!=`.
 */
#define verify_push_table_contents(L, idx, count)  do {                 \
        int pushed_count_ = push_table_contents((L), (idx));            \
        if (pushed_count_ != (count)) {                                 \
            lua_pop((L), pushed_count_);                                \
            LUA_HANDLE_ERROR_STR((L), "Table sizes do not match");      \
        }                                                               \
    } while(0)