You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

utils.h 1.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. #include "error.h"
  /*
   * check_tensors(L, a, b): make the enclosing function return
   * LUA_HANDLE_ERROR_STR(...) unless tensors `a` and `b` have the same number
   * of dimensions (nDimension) and the same size along every dimension.
   *
   * - `a` and `b` are evaluated several times (rank check plus once per
   *   dimension), so pass plain variables, not expressions with side effects.
   * - The error path executes `return LUA_HANDLE_ERROR_STR(...)`, so the macro
   *   must be used inside a function whose return type accepts that value
   *   (presumably an int-returning lua_CFunction -- confirm against error.h).
   * - The loop index uses a macro-prefixed, non-reserved name so it cannot
   *   collide with implementation-reserved identifiers (names containing "__").
   */
  #define check_tensors(L, a, b) \
    do { \
      if ((a)->nDimension != (b)->nDimension) { \
        return LUA_HANDLE_ERROR_STR((L), "different tensor dimensions"); \
      } \
      for (int check_tensors_i_ = 0; check_tensors_i_ < (a)->nDimension; check_tensors_i_++) { \
        if ((a)->size[check_tensors_i_] != (b)->size[check_tensors_i_]) { \
          return LUA_HANDLE_ERROR_STR((L), "different tensor sizes"); \
        } \
      } \
    } while (0)
  /*
   * check_tensor(L, t, type): return a Lua error from the enclosing function
   * when tensor `t` is not contiguous.
   *
   * `type` is token-pasted directly in front of `_isContiguous`, so callers
   * pass the full tensor-type prefix (e.g. something like THFloatTensor --
   * confirm against callers), not the short name that get_tensor() and
   * get_tensor_size() expect.
   */
  #define check_tensor(state, tensor, prefix) \
    do { \
      if (prefix##_isContiguous(tensor) == 0) { \
        return LUA_HANDLE_ERROR_STR((state), "tensor should be contiguous"); \
      } \
    } while (0)
  /* get_tensor_size(t, type): element count of tensor `t`, obtained from
   * TH<type>Tensor_nElement (e.g. type Float pastes to THFloatTensor_nElement). */
  #define get_tensor_size(tensor, type) (TH##type##Tensor_nElement((tensor)))
  /*
   * get_tensor(L, idx, type): fetch stack slot `idx` as a "torch.<type>Tensor"
   * userdata via luaT_checkudata and cast it to TH<type>Tensor *.
   * luaT_checkudata presumably raises a Lua error on a type mismatch (luaT
   * convention -- confirm).
   *
   * The whole expansion is parenthesized so postfix operators bind to the
   * cast result: without the outer parentheses, get_tensor(L, 1, Float)->size
   * would apply `->` to the void * returned by luaT_checkudata, which does not
   * compile.
   */
  #define get_tensor(L, idx, type) \
    ((TH##type##Tensor *)luaT_checkudata((L), (idx), "torch." #type "Tensor"))
  19. static int push_table_contents(lua_State *L, int arg)
  20. {
  21. int size = 0;
  22. while(1) {
  23. lua_checkstack(L, 1);
  24. lua_rawgeti(L, arg, size + 1);
  25. if (lua_isnil(L, -1)) {
  26. lua_pop(L, 1);
  27. break;
  28. }
  29. size++;
  30. }
  31. return size;
  32. }
  /*
   * verify_push_table_contents(L, idx, count): push the array part of the
   * table at stack index `idx` (via push_table_contents) and require that
   * exactly `count` values were pushed.
   *
   * On a mismatch, the pushed values are popped and the enclosing function
   * returns LUA_HANDLE_ERROR_STR(...), the same `return` convention used by
   * check_tensors and check_tensor. This means the macro must be used inside a
   * function whose return type accepts that value (presumably an int-returning
   * lua_CFunction -- confirm). Without the `return`, a non-raising error
   * handler would let the caller continue as if `count` values were on the
   * stack.
   *
   * `count` is parenthesized so that compound expressions (e.g. a ?: b) compare
   * correctly. `L` is evaluated up to three times, so pass a plain variable.
   * The temporary uses a non-reserved, macro-prefixed name.
   */
  #define verify_push_table_contents(L, idx, count) \
    do { \
      int verify_push_count_ = push_table_contents((L), (idx)); \
      if (verify_push_count_ != (count)) { \
        lua_pop((L), verify_push_count_); \
        return LUA_HANDLE_ERROR_STR((L), "Table sizes do not match"); \
      } \
    } while (0)