diff options
Diffstat (limited to 'contrib/torch/decisiontree/utils.h')
-rw-r--r-- | contrib/torch/decisiontree/utils.h | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/contrib/torch/decisiontree/utils.h b/contrib/torch/decisiontree/utils.h new file mode 100644 index 000000000..8a0196a58 --- /dev/null +++ b/contrib/torch/decisiontree/utils.h @@ -0,0 +1,45 @@ +#include "error.h" + +#define check_tensors(L, a, b) \ + do { \ + if ((a)->nDimension != (b)->nDimension) \ + return LUA_HANDLE_ERROR_STR((L), "different tensor dimensions"); \ + for (int __local__var = 0; __local__var < (a)->nDimension; __local__var++) \ + if ((a)->size[__local__var] != (b)->size[__local__var]) \ + return LUA_HANDLE_ERROR_STR((L), "different tensor sizes"); \ + } while (0) + +#define check_tensor(L, t, type) \ + do { \ + if (!type##_isContiguous(t)) \ + return LUA_HANDLE_ERROR_STR((L), "tensor should be contiguous"); \ + } while (0) + +#define get_tensor_size(t, type) \ + (TH##type##Tensor_nElement(t)) + +#define get_tensor(L, idx, type) \ + (TH##type##Tensor *)luaT_checkudata(L, idx, "torch." #type "Tensor") + +static int push_table_contents(lua_State *L, int arg) +{ + int size = 0; + while(1) { + lua_checkstack(L, 1); + lua_rawgeti(L, arg, size + 1); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + break; + } + size++; + } + return size; +} + +#define verify_push_table_contents(L, idx, count) do { \ + int __tmp_count = push_table_contents(L, idx); \ + if (__tmp_count != count) { \ + lua_pop(L, __tmp_count); \ + LUA_HANDLE_ERROR_STR(L, "Table sizes do not match"); \ + } \ + } while(0) |