#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/GBDT.c"
#else

#include "GBDT_internal.h"
#include "GBDT_internal.c"
// note that each one of the functions to find the best split is a subset of the next.
// first we have one that can only evaluate a single feature, using the logic in Lua to control
// the features
// then we have one that can go over a shard of features, following the feature parallelism
// introduced by the Lua logic
// and finally we have one that performs the feature parallelism itself in the special case of
// dense tensors
// these functions are provided for completeness and to test against in case the logic is to be
// changed
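
// usage sketch (Lua; hypothetical call site, since these functions are registered on the
// torch.Tensor metatable under `nn` by nn_(GBDT_init) below, argument names are illustrative):
//   local splitInfo = grad.nn.GBDT_findBestFeatureSplit(exampleIds, dataset, featureId,
//       minLeafSize, grad, hess)
//   if splitInfo == nil then --[[ no valid split for this node ]] end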
// finds the best split for a given node and feature
static int nn_(gb_findBestFeatureSplit)(lua_State *L) {
  THLongTensor *exampleIds = luaT_checkudata(L, 1, "torch.LongTensor");
  const int dataset_index = 2;
  if (!lua_isnumber(L, 3))
    return LUA_HANDLE_ERROR_STR(L, "third argument should be an integer");
  long feature_id = lua_tointeger(L, 3);
  if (!lua_isnumber(L, 4))
    return LUA_HANDLE_ERROR_STR(L, "fourth argument should be an integer");
  long minLeafSize = lua_tointeger(L, 4);
  // since minLeafSize == 1 corresponds to each sample being in its own leaf, any value below it
  // doesn't make sense
  if (minLeafSize < 1)
    minLeafSize = 1;
  THTensor *grad = luaT_checkudata(L, 5, torch_Tensor);
  THTensor *hess = luaT_checkudata(L, 6, torch_Tensor);
  if (!THLongTensor_isContiguous(exampleIds))
    return LUA_HANDLE_ERROR_STR(L, "exampleIds has to be contiguous");
  if (!THTensor_(isContiguous)(grad))
    return LUA_HANDLE_ERROR_STR(L, "grad has to be contiguous");
  if (!THTensor_(isContiguous)(hess))
    return LUA_HANDLE_ERROR_STR(L, "hessian has to be contiguous");

  // initializes the static data
  nn_(GBInitialization) initialization_data;
  nn_(gb_initialize)(L, &initialization_data, exampleIds, grad, hess, dataset_index);

  // initializes the dynamic data
  GBRunData run_data;
  gb_create_run_data(&run_data, minLeafSize);

  // finds the best state possible for the split
  nn_(GBBestState) bs;
  nn_(gb_find_best_feature_split)(L, &initialization_data, &bs, feature_id, &run_data);
  lua_pop(L, lua_gettop(L) - initialization_data.splitInfo_index);

  // fills the table with the best split found so that the Lua logic above can do everything else;
  // if no valid state was found, returns nil
  if (bs.valid_state == 0) {
    lua_pop(L, 1);
    lua_pushnil(L);
  }
  else {
    nn_(gb_internal_split_info)(L, &bs, initialization_data.splitInfo_index);
  }

  gb_destroy_run_data(&run_data);
  return 1;
}
// finds the best split for a given node and shard of features
// this is more efficient than calling the previous function multiple times
static int nn_(gb_findBestSplit)(lua_State *L) {
  THLongTensor *exampleIds = luaT_checkudata(L, 1, "torch.LongTensor");
  const int dataset_index = 2;
  THLongTensor *feature_ids = luaT_checkudata(L, 3, "torch.LongTensor");
  if (!lua_isnumber(L, 4))
    return LUA_HANDLE_ERROR_STR(L, "fourth argument should be an integer");
  long minLeafSize = lua_tointeger(L, 4);
  // since minLeafSize == 1 corresponds to each sample being in its own leaf, any value below it
  // doesn't make sense
  if (minLeafSize < 1)
    minLeafSize = 1;
  if (!lua_isnumber(L, 5))
    return LUA_HANDLE_ERROR_STR(L, "fifth argument should be an integer");
  long shardId = lua_tointeger(L, 5);
  if (!lua_isnumber(L, 6))
    return LUA_HANDLE_ERROR_STR(L, "sixth argument should be an integer");
  long nShard = lua_tointeger(L, 6);
  THTensor *grad = luaT_checkudata(L, 7, torch_Tensor);
  THTensor *hess = luaT_checkudata(L, 8, torch_Tensor);
  if (!THLongTensor_isContiguous(exampleIds))
    return LUA_HANDLE_ERROR_STR(L, "exampleIds has to be contiguous");
  if (!THTensor_(isContiguous)(grad))
    return LUA_HANDLE_ERROR_STR(L, "grad has to be contiguous");
  if (!THTensor_(isContiguous)(hess))
    return LUA_HANDLE_ERROR_STR(L, "hessian has to be contiguous");

  // initializes the static data
  nn_(GBInitialization) initialization_data;
  nn_(gb_initialize)(L, &initialization_data, exampleIds, grad, hess, dataset_index);

  // initializes the dynamic data
  GBRunData run_data;
  gb_create_run_data(&run_data, minLeafSize);

  // initializes to evaluate all the features in this shard
  nn_(GBBestState) global_bs;
  global_bs.valid_state = 0;
  long n_features = THLongTensor_size(feature_ids, 0);
  if (!THLongTensor_isContiguous(feature_ids))
    return LUA_HANDLE_ERROR_STR(L, "feature_ids must be contiguous");
  long *feature_ids_data = THLongTensor_data(feature_ids);

  // for every feature
  for (long i = 0; i < n_features; i++) {
    long feature_id = feature_ids_data[i];
    // if we are responsible for it
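    // e.g. with nShard == 4, feature ids 0, 4, 8, ... land on shard 1, ids 1, 5, 9, ... on
    // shard 2, and so on; shardId is 1-based, matching Lua conventions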
    if (nShard <= 1 || (feature_id % nShard) + 1 == shardId) {
      // finds the best state possible for the split
      nn_(GBBestState) bs;
      nn_(gb_find_best_feature_split)(L, &initialization_data, &bs, feature_id, &run_data);
      // if it's valid and better than any found before, saves it
      if (bs.valid_state) {
        if (global_bs.valid_state == 0 || bs.gain < global_bs.gain) {
          global_bs = bs;
        }
      }
    }
  }
  lua_pop(L, lua_gettop(L) - initialization_data.splitInfo_index);

  // fills the table with the best split found so that the Lua logic above can do everything else;
  // if no valid state was found, returns nil
  if (global_bs.valid_state == 0) {
    lua_pop(L, 1);
    lua_pushnil(L);
  }
  else {
    nn_(gb_internal_split_info)(L, &global_bs, initialization_data.splitInfo_index);
  }

  gb_destroy_run_data(&run_data);
  return 1;
}
// all the info we have to pass to the worker threads so that they can do their jobs
// note that we do not pass the lua_State, since it isn't required: we perform direct C
// parallelism instead of relying on Lua-level parallelism as in the previous version
typedef struct {
  nn_(GBInitialization) *initialization_data;
  GBRunData *run_data;
  long *index;
  nn_(GBBestState) *global_bs;
  long n_features;
  long *feature_ids_data;
  pthread_mutex_t *mutex;
  THLongTensor *exampleIds;
  THTensor *input;
  THLongTensor **sorted_ids_per_feature;
} nn_(ThreadInfo);
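
// `index` above acts as a shared cursor into feature_ids_data: each worker, while holding the
// mutex, claims the next unprocessed feature, so the feature list behaves like a simple work
// queue (see thread_worker below)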
// loops over all the features in parallel and finds the best global split
static void* nn_(thread_worker)(void *arg) {
  nn_(ThreadInfo) *info = (nn_(ThreadInfo) *)arg;
  while (1) {
    pthread_mutex_lock(info->mutex);
    long index = (*info->index);
    (*info->index)++;
    pthread_mutex_unlock(info->mutex);
    if (index >= info->n_features)
      break;

    // performs part of steps (1) and (2) of gb_find_best_feature_split using pre-loaded data,
    // without having to access the Lua state
    long feature_id = info->feature_ids_data[index];
    THLongTensor *exampleIdsWithFeature_ret = info->exampleIds;
    THLongTensor *featureExampleIds = info->sorted_ids_per_feature[index];
    nn_(GBInitialization) *initialization_data = info->initialization_data;
    GBRunData *run_data = info->run_data;

    // performs steps (3) and (4) of gb_find_best_feature_split, since (1) and (2) were already
    // performed before
    nn_(GBBestState) bs;
    nn_(gb_internal_create)(initialization_data->grad, initialization_data->hess,
        exampleIdsWithFeature_ret, &bs.state);
    nn_(gb_internal_get_best_split_special)(&bs, featureExampleIds, run_data->exampleMap,
        info->input, run_data->minLeafSize, feature_id);

    // saves to the global state if it's better
    if (bs.valid_state) {
      pthread_mutex_lock(info->mutex);
      if (info->global_bs->valid_state == 0 || bs.gain < info->global_bs->gain) {
        (*info->global_bs) = bs;
      }
      pthread_mutex_unlock(info->mutex);
    }
  }
  return NULL;
}
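
// usage sketch (Lua; hypothetical call site, argument names are illustrative):
//   local splitInfo = grad.nn.GBDT_findBestSplitFP(exampleIds, dataset, featureIds,
//       minLeafSize, grad, hess, nThread)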
// finds the global best split by doing the feature parallelism directly in C
static int nn_(gb_findBestSplitFP)(lua_State *L) {
  THLongTensor *exampleIds = luaT_checkudata(L, 1, "torch.LongTensor");
  const int dataset_index = 2;
  THLongTensor *feature_ids = luaT_checkudata(L, 3, "torch.LongTensor");
  if (!lua_isnumber(L, 4))
    return LUA_HANDLE_ERROR_STR(L, "fourth argument should be an integer");
  long minLeafSize = lua_tointeger(L, 4);
  THTensor *grad = luaT_checkudata(L, 5, torch_Tensor);
  THTensor *hess = luaT_checkudata(L, 6, torch_Tensor);
  if (!lua_isnumber(L, 7))
    return LUA_HANDLE_ERROR_STR(L, "seventh argument should be an integer");
  long nThread = lua_tointeger(L, 7);
  if (!THLongTensor_isContiguous(exampleIds))
    return LUA_HANDLE_ERROR_STR(L, "exampleIds has to be contiguous");
  if (!THTensor_(isContiguous)(grad))
    return LUA_HANDLE_ERROR_STR(L, "grad has to be contiguous");
  if (!THTensor_(isContiguous)(hess))
    return LUA_HANDLE_ERROR_STR(L, "hessian has to be contiguous");

  pthread_mutex_t mutex;
  pthread_mutex_init(&mutex, NULL);

  // initializes the static data
  nn_(GBInitialization) initialization_data;
  nn_(gb_initialize)(L, &initialization_data, exampleIds, grad, hess, dataset_index);

  // initializes the dynamic data
  GBRunData run_data;
  gb_create_run_data(&run_data, minLeafSize);

  // initializes to evaluate all the features
  nn_(GBBestState) global_bs;
  global_bs.valid_state = 0;
  long n_features = THLongTensor_size(feature_ids, 0);
  if (!THLongTensor_isContiguous(feature_ids))
    return LUA_HANDLE_ERROR_STR(L, "feature_ids must be contiguous");
  long *feature_ids_data = THLongTensor_data(feature_ids);
  THTensor *input = luaT_checkudata(L, initialization_data.input_index, torch_Tensor);

  // performs step (1) of gb_find_best_feature_split so that we don't have to pass the Lua state
  THLongTensor *sorted_ids_per_feature[n_features];
  for (long i = 0; i < n_features; i++) {
    long feature_id = feature_ids_data[i];
    lua_pushvalue(L, initialization_data.getSortedFeature_index);
    lua_pushvalue(L, initialization_data.dataset_index);
    lua_pushinteger(L, feature_id);
    lua_call(L, 2, 1);
    THLongTensor *featureExampleIds = luaT_checkudata(L, -1, "torch.LongTensor");
    sorted_ids_per_feature[i] = featureExampleIds;
  }
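  // note: the sorted-id tensors returned by the calls above are left on the Lua stack, which
  // keeps them referenced (and hence alive) until the lua_pop after the threads are joined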
  // performs step (2) of gb_find_best_feature_split, since it's the same for all features when
  // the data is dense
  long exampleIds_size = THLongTensor_size(initialization_data.exampleIds, 0);
  long *exampleIds_data = THLongTensor_data(initialization_data.exampleIds);
  int ret;
  kh_resize(long, run_data.exampleMap, exampleIds_size*8);
  for (long i = 0; i < exampleIds_size; i++)
    kh_put(long, run_data.exampleMap, exampleIds_data[i], &ret);
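  // sizing the hash set at 8x the number of elements keeps its load factor low; this trades
  // memory for fewer collisions (a tuning choice, not something khash requires)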
  // saves the info for the threads
  long index = 0;
  nn_(ThreadInfo) info;
  info.initialization_data = &initialization_data;
  info.run_data = &run_data;
  info.index = &index;
  info.global_bs = &global_bs;
  info.n_features = n_features;
  info.feature_ids_data = feature_ids_data;
  info.mutex = &mutex;
  info.exampleIds = exampleIds;
  info.input = input;
  info.sorted_ids_per_feature = sorted_ids_per_feature;

  // lets the threads race over the features to find the minimum
  pthread_t threads[nThread];
  for (long i = 0; i < nThread; i++) {
    int ret = pthread_create(&threads[i], NULL, nn_(thread_worker), &info);
    if (ret)
      return LUA_HANDLE_ERROR_STR(L, "failed to create thread");
  }
  for (long i = 0; i < nThread; i++) {
    int ret = pthread_join(threads[i], NULL);
    if (ret)
      return LUA_HANDLE_ERROR_STR(L, "failed to join thread");
  }
  lua_pop(L, lua_gettop(L) - initialization_data.splitInfo_index);

  // fills the table with the best split found so that the Lua logic above can do everything else;
  // if no valid state was found, returns nil
  if (global_bs.valid_state == 0) {
    lua_pop(L, 1);
    lua_pushnil(L);
  }
  else {
    nn_(gb_internal_split_info)(L, &global_bs, initialization_data.splitInfo_index);
  }

  gb_destroy_run_data(&run_data);
  pthread_mutex_destroy(&mutex);
  return 1;
}
// performs an efficient branching of the current examples based on the split info provided
static int nn_(gb_branch)(lua_State *L) {
  if (!lua_istable(L, 1))
    return LUA_HANDLE_ERROR_STR(L, "first argument must be a table");
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THLongTensor *exampleIds = luaT_checkudata(L, 3, "torch.LongTensor");

  // gets direct access to the dataset
  long n_exampleIds = THLongTensor_size(exampleIds, 0);
  long *exampleIds_data = THLongTensor_data(exampleIds);
  long n_features = THTensor_(size)(input, 1);
  real *input_data = THTensor_(data)(input);

  // creates the tensors to be returned
  luaT_pushudata(L, THLongTensor_new(), "torch.LongTensor");
  luaT_pushudata(L, THLongTensor_new(), "torch.LongTensor");
  THLongTensor *leftExampleIds = luaT_checkudata(L, 4, "torch.LongTensor");
  THLongTensor *rightExampleIds = luaT_checkudata(L, 5, "torch.LongTensor");
  THLongTensor_resize1d(leftExampleIds, n_exampleIds);

  // gets direct access to the examples
  THLongTensor *splitExampleIds = leftExampleIds;
  long *splitExampleIds_data = THLongTensor_data(splitExampleIds);

  // gets the split info
  lua_pushstring(L, "splitId");
  lua_rawget(L, 1);
  const long splitId = lua_tointeger(L, -1);
  lua_pushstring(L, "splitValue");
  lua_rawget(L, 1);
  const real splitValue = lua_tonumber(L, -1);
  lua_pop(L, 2);

  long leftIdx = 0, rightIdx = 0;
  // goes over all the samples, dividing them between the two sides
  for (long i = 0; i < n_exampleIds; i++) {
    long exampleId = exampleIds_data[i];
    real val = input_data[(exampleId-1) * n_features + (splitId - 1)];
    if (val <= splitValue) {
      leftIdx++;
      splitExampleIds_data[leftIdx-1] = exampleId;
    }
    else {
      rightIdx++;
      splitExampleIds_data[n_exampleIds - rightIdx + 1 - 1] = exampleId;
    }
  }
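  // e.g. if exampleIds = {5, 2, 9, 7} and 5, 2 go left while 9, 7 go right, the shared buffer
  // ends up as {5, 2, 7, 9}: left ids fill it from the front, right ids from the back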
  // once done, the resulting tensors are just two views into the same buffer. this is more
  // efficient than filling two separate tensors: we don't know in advance where the split will
  // fall (how many examples go to each side), only that the two sizes sum to n_exampleIds
  THLongTensor_narrow(rightExampleIds, splitExampleIds, 0, n_exampleIds-rightIdx+1-1, rightIdx);
  THLongTensor_narrow(leftExampleIds, splitExampleIds, 0, 0, leftIdx);
  return 2;
}
static const struct luaL_Reg nn_(GBDT__) [] = {
  {"GBDT_findBestFeatureSplit", nn_(gb_findBestFeatureSplit)},
  {"GBDT_findBestSplit", nn_(gb_findBestSplit)},
  {"GBDT_findBestSplitFP", nn_(gb_findBestSplitFP)},
  {"GBDT_branch", nn_(gb_branch)},
  {NULL, NULL}
};
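
// once registered, the functions above are reached from Lua through the torch.Tensor metatable,
// e.g. (hypothetical call site):
//   local left, right = input.nn.GBDT_branch(splitInfo, input, exampleIds)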
static void nn_(GBDT_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nn_(GBDT__), "nn");
  lua_pop(L, 1);
}

#endif