summaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree/utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/torch/decisiontree/utils.h')
-rw-r--r--contrib/torch/decisiontree/utils.h45
1 files changed, 45 insertions, 0 deletions
diff --git a/contrib/torch/decisiontree/utils.h b/contrib/torch/decisiontree/utils.h
new file mode 100644
index 000000000..8a0196a58
--- /dev/null
+++ b/contrib/torch/decisiontree/utils.h
@@ -0,0 +1,45 @@
+#include "error.h"
+
+#define check_tensors(L, a, b) \
+ do { \
+ if ((a)->nDimension != (b)->nDimension) \
+ return LUA_HANDLE_ERROR_STR((L), "different tensor dimensions"); \
+ for (int __local__var = 0; __local__var < (a)->nDimension; __local__var++) \
+ if ((a)->size[__local__var] != (b)->size[__local__var]) \
+ return LUA_HANDLE_ERROR_STR((L), "different tensor sizes"); \
+ } while (0)
+
+#define check_tensor(L, t, type) \
+ do { \
+ if (!type##_isContiguous(t)) \
+ return LUA_HANDLE_ERROR_STR((L), "tensor should be contiguous"); \
+ } while (0)
+
+#define get_tensor_size(t, type) \
+ (TH##type##Tensor_nElement(t))
+
+#define get_tensor(L, idx, type) \
+ (TH##type##Tensor *)luaT_checkudata(L, idx, "torch." #type "Tensor")
+
+static int push_table_contents(lua_State *L, int arg)
+{
+ int size = 0;
+ while(1) {
+ lua_checkstack(L, 1);
+ lua_rawgeti(L, arg, size + 1);
+ if (lua_isnil(L, -1)) {
+ lua_pop(L, 1);
+ break;
+ }
+ size++;
+ }
+ return size;
+}
+
+#define verify_push_table_contents(L, idx, count) do { \
+ int __tmp_count = push_table_contents(L, idx); \
+ if (__tmp_count != count) { \
+ lua_pop(L, __tmp_count); \
+ LUA_HANDLE_ERROR_STR(L, "Table sizes do not match"); \
+ } \
+ } while(0)