You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

init.c 1.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. #include "TH.h"
  2. #include "luaT.h"
  3. #ifdef _OPENMP
  4. #include "omp.h"
  5. #endif
  6. #include "error.h"
  7. #include "hash_map.h"
  8. #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
  9. #define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)
  10. #define nn_(NAME) TH_CONCAT_3(nn_, Real, NAME)
  11. #include "generic/LogitBoostCriterion.c"
  12. #include "THGenerateFloatTypes.h"
  13. #include "generic/DFD.c"
  14. #include "THGenerateFloatTypes.h"
  15. #include "generic/S2D.c"
  16. #include "THGenerateFloatTypes.h"
  17. #include "generic/CartTree.c"
  18. #include "THGenerateFloatTypes.h"
  19. #include "GBDT_common.h"
  20. #include "generic/GBDT.c"
  21. #include "THGenerateFloatTypes.h"
  22. static const struct luaL_Reg decisiontree_hash_map_routines[] = {
  23. {"__gc", hash_map_gc_lua},
  24. {"retain", hash_map_retain_lua},
  25. {"metatablename", hash_map_metatablename_lua},
  26. {"clear", hash_map_clear_lua},
  27. {"put", hash_map_put_lua},
  28. {"fill", hash_map_fill_lua},
  29. {"adjustCounter", hash_map_adjust_counter_lua},
  30. {"getCounter", hash_map_get_counter_lua},
  31. {"setCounter", hash_map_set_counter_lua},
  32. {"get", hash_map_get_lua},
  33. {"getInplace", hash_map_get_inplace_lua},
  34. {"del", hash_map_del_lua},
  35. {"size", hash_map_size_lua},
  36. {"safe", hash_map_autolock_on_lua},
  37. {"unsafe", hash_map_autolock_off_lua},
  38. {"toTensors", hash_map_to_tensor_lua},
  39. {"new", hash_map_init_lua},
  40. {NULL, NULL}
  41. };
  42. DLL_EXPORT int luaopen_libdecisiontree(lua_State *L)
  43. {
  44. // HashMap
  45. luaL_newmetatable(L, "dt.HashMap");
  46. lua_pushstring(L, "__index");
  47. lua_pushvalue(L, -2);
  48. lua_settable(L, -3);
  49. luaT_setfuncs(L, decisiontree_hash_map_routines, 0);
  50. nn_FloatLogitBoostCriterion_init(L);
  51. nn_DoubleLogitBoostCriterion_init(L);
  52. nn_FloatDFD_init(L);
  53. nn_DoubleDFD_init(L);
  54. nn_FloatS2D_init(L);
  55. nn_DoubleS2D_init(L);
  56. nn_FloatCT_init(L);
  57. nn_DoubleCT_init(L);
  58. nn_FloatGBDT_init(L);
  59. nn_DoubleGBDT_init(L);
  60. return 1;
  61. }