You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

hash_map.c 12KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. #include "utils.h"
  2. #include "hash_map.h"
  3. #include "internal_hash_map.h"
  4. #include <pthread.h>
  5. hash_map_t hash_map_init(void) {
  6. return kh_init(long);
  7. }
  8. void hash_map_destroy(hash_map_t h_) {
  9. internal_hash_map_t h = (internal_hash_map_t) h_;
  10. kh_destroy(long, h);
  11. }
  12. void hash_map_clear(hash_map_t h_) {
  13. internal_hash_map_t h = (internal_hash_map_t) h_;
  14. kh_clear(long, h);
  15. }
  16. int hash_map_put(hash_map_t h_, long key, long val) {
  17. internal_hash_map_t h = (internal_hash_map_t) h_;
  18. int ret;
  19. khiter_t k = kh_put(long, h, key, &ret);
  20. ret = (ret >= 0);
  21. if (ret)
  22. kh_value(h, k) = val;
  23. return ret;
  24. }
  25. int hash_map_put_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) {
  26. long *keys = THLongTensor_data(keys_);
  27. long *vals = THLongTensor_data(vals_);
  28. long size = get_tensor_size(keys_, Long);
  29. for (long i = 0; i < size; i++)
  30. if (!hash_map_put(h_, keys[i], vals[i]))
  31. return 0;
  32. return 1;
  33. }
  34. int hash_map_fill(hash_map_t h_, long key, long *counter) {
  35. internal_hash_map_t h = (internal_hash_map_t) h_;
  36. khiter_t k = kh_get(long, h, key);
  37. if (k == kh_end(h))
  38. return hash_map_put(h_, key, ++(*counter));
  39. return 1;
  40. }
  41. int hash_map_fill_tensor(hash_map_t h_, THLongTensor *keys_, long *counter) {
  42. long *keys = THLongTensor_data(keys_);
  43. long size = get_tensor_size(keys_, Long);
  44. for (long i = 0; i < size; i++)
  45. if (!hash_map_fill(h_, keys[i], counter))
  46. return 0;
  47. return 1;
  48. }
  49. int hash_map_get(hash_map_t h_, long key, long* val) {
  50. internal_hash_map_t h = (internal_hash_map_t) h_;
  51. khiter_t k = kh_get(long, h, key);
  52. if (k == kh_end(h))
  53. return 0;
  54. *val = kh_value(h, k);
  55. return 1;
  56. }
  57. void hash_map_get_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_) {
  58. long *keys = THLongTensor_data(keys_);
  59. long *vals = THLongTensor_data(vals_);;
  60. unsigned char *mask = THByteTensor_data(mask_);
  61. long size = get_tensor_size(keys_, Long);
  62. for (long i = 0; i < size; i++)
  63. mask[i] = hash_map_get(h_, keys[i], &vals[i]);
  64. }
  65. void hash_map_del(hash_map_t h_, long key) {
  66. internal_hash_map_t h = (internal_hash_map_t) h_;
  67. khiter_t k = kh_get(long, h, key);
  68. if (k != kh_end(h))
  69. kh_del(long, h, k);
  70. }
  71. void hash_map_del_tensor(hash_map_t h_, THLongTensor *keys_) {
  72. long *keys = THLongTensor_data(keys_);
  73. long size = get_tensor_size(keys_, Long);
  74. for (long i = 0; i < size; i++)
  75. hash_map_del(h_, keys[i]);
  76. }
  77. size_t hash_map_size(hash_map_t h_) {
  78. internal_hash_map_t h = (internal_hash_map_t) h_;
  79. return kh_size(h);
  80. }
  81. void hash_map_to_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) {
  82. internal_hash_map_t h = (internal_hash_map_t) h_;
  83. long *keys = THLongTensor_data(keys_);
  84. long *vals = THLongTensor_data(vals_);
  85. long key, val, i = 0;
  86. kh_foreach(h, key, val, {
  87. keys[i] = key;
  88. vals[i] = val;
  89. i++;
  90. });
  91. }
  92. static void autolock(hash_map_lua_t *h) {
  93. if (h->autolock) {
  94. pthread_mutex_lock(&h->mutex);
  95. }
  96. }
  97. static void autounlock(hash_map_lua_t *h) {
  98. if (h->autolock) {
  99. pthread_mutex_unlock(&h->mutex);
  100. }
  101. }
  102. int hash_map_autolock_on_lua(lua_State *L) {
  103. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  104. h->autolock = 1;
  105. return 0;
  106. }
  107. int hash_map_autolock_off_lua(lua_State *L) {
  108. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  109. h->autolock = 0;
  110. return 0;
  111. }
  112. int hash_map_init_lua(lua_State *L) {
  113. hash_map_lua_t **hp = (hash_map_lua_t**)lua_newuserdata(L, sizeof(hash_map_lua_t*));
  114. *hp = (hash_map_lua_t*)malloc(sizeof(hash_map_lua_t));
  115. hash_map_lua_t *h = *hp;
  116. h->refcount = 1;
  117. h->counter = 0;
  118. h->autolock = 0;
  119. h->h = hash_map_init();
  120. pthread_mutexattr_t mutex_attr;
  121. pthread_mutexattr_init(&mutex_attr);
  122. pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE);
  123. pthread_mutex_init(&h->mutex, &mutex_attr);
  124. luaL_getmetatable(L, "dt.HashMap");
  125. lua_setmetatable(L, -2);
  126. return 1;
  127. }
  128. int hash_map_gc_lua(lua_State *L) {
  129. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  130. if (THAtomicDecrementRef(&h->refcount)) {
  131. pthread_mutex_destroy(&h->mutex);
  132. hash_map_destroy(h->h);
  133. free(h);
  134. }
  135. return 0;
  136. }
  137. int hash_map_retain_lua(lua_State *L) {
  138. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  139. THAtomicIncrementRef(&h->refcount);
  140. return 0;
  141. }
  142. int hash_map_metatablename_lua(lua_State *L) {
  143. lua_pushstring(L, "dt.HashMap");
  144. return 1;
  145. }
  146. int hash_map_clear_lua(lua_State *L) {
  147. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  148. autolock(h);
  149. hash_map_clear(h->h);
  150. autounlock(h);
  151. return 0;
  152. }
  153. int hash_map_put_lua(lua_State *L) {
  154. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  155. int ret;
  156. #if LUA_VERSION_NUM <= 501
  157. #define lua_isinteger lua_isnumber
  158. #endif
  159. if (lua_isinteger(L, 2)) {
  160. if (!lua_isinteger(L, 3))
  161. return LUA_HANDLE_ERROR_STR(L, "second parameter is not a number");
  162. long key = lua_tointeger(L, 2);
  163. long val = lua_tointeger(L, 3);
  164. autolock(h);
  165. ret = hash_map_put(h->h, key, val);
  166. autounlock(h);
  167. }
  168. else {
  169. THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
  170. THLongTensor *vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
  171. check_tensor(L, keys, THLongTensor);
  172. check_tensor(L, vals, THLongTensor);
  173. check_tensors(L, keys, vals);
  174. autolock(h);
  175. ret = hash_map_put_tensor(h->h, keys, vals);
  176. autounlock(h);
  177. }
  178. if (!ret)
  179. return LUA_HANDLE_ERROR_STR(L, "failed to put into hash map");
  180. return 0;
  181. }
  182. int hash_map_fill_lua(lua_State *L) {
  183. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  184. int ret;
  185. if (lua_isinteger(L, 2)) {
  186. long key = lua_tointeger(L, 2);
  187. autolock(h);
  188. ret = hash_map_fill(h->h, key, &h->counter);
  189. autounlock(h);
  190. }
  191. else {
  192. THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
  193. check_tensor(L, keys, THLongTensor);
  194. autolock(h);
  195. ret = hash_map_fill_tensor(h->h, keys, &h->counter);
  196. autounlock(h);
  197. }
  198. if (!ret)
  199. return LUA_HANDLE_ERROR_STR(L, "failed to fill into hash map");
  200. return 0;
  201. }
  202. int hash_map_adjust_counter_lua(lua_State *L) {
  203. hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
  204. internal_hash_map_t h = (internal_hash_map_t) h_->h;
  205. long val;
  206. kh_foreach_value(h, val, {
  207. if (val >= h_->counter)
  208. h_->counter = val;
  209. });
  210. return 0;
  211. }
  212. int hash_map_set_counter_lua(lua_State *L) {
  213. hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
  214. h_->counter = lua_tointeger(L, 2);
  215. return 0;
  216. }
  217. int hash_map_get_counter_lua(lua_State *L) {
  218. hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
  219. lua_pushinteger(L, h_->counter);
  220. return 1;
  221. }
  222. static int hash_map_get_tensor_lua(lua_State *L, hash_map_lua_t *h, int inplace) {
  223. THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
  224. check_tensor(L, keys, THLongTensor);
  225. THLongTensor *vals = inplace ? keys : NULL;
  226. THByteTensor *mask = NULL;
  227. int maskIdx = inplace ? 3 : 4;
  228. if (!inplace) {
  229. if (lua_gettop(L) < 3) {
  230. vals = THLongTensor_new();
  231. } else {
  232. vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
  233. check_tensor(L, vals, THLongTensor);
  234. }
  235. }
  236. if (lua_gettop(L) < maskIdx) {
  237. mask = THByteTensor_new();
  238. } else {
  239. mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
  240. check_tensor(L, mask, THByteTensor);
  241. }
  242. int n_dim = THLongTensor_nDimension(keys);
  243. THLongStorage *st = THLongStorage_newWithSize1(n_dim);
  244. for (int i = 0; i < n_dim; i++) {
  245. THLongStorage_set(st, i, THLongTensor_size(keys, i));
  246. }
  247. THByteTensor_resize(mask, st, NULL);
  248. if (!inplace) THLongTensor_resize(vals, st, NULL);
  249. THLongStorage_free(st);
  250. autolock(h);
  251. hash_map_get_tensor(h->h, keys, vals, mask);
  252. autounlock(h);
  253. if (!inplace && lua_gettop(L) < 3)
  254. luaT_pushudata(L, vals, "torch.LongTensor");
  255. if (lua_gettop(L) < maskIdx)
  256. luaT_pushudata(L, mask, "torch.ByteTensor");
  257. return 2;
  258. }
  259. static int hash_map_get_table_lua(lua_State *L, hash_map_lua_t *h, int inplace) {
  260. const int kidx = 2;
  261. const int vidx = inplace ? 2 : 3;
  262. const int midx = inplace ? 3 : 4;
  263. const int narg = lua_gettop(L);
  264. if (inplace) {
  265. if (narg < 3) {
  266. LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace requires two arguments.");
  267. }
  268. } else {
  269. if (narg < 4) {
  270. LUA_HANDLE_ERROR_STR(L, "HashMap.get requires three arguments.");
  271. }
  272. }
  273. int count = push_table_contents(L, kidx);
  274. verify_push_table_contents(L, vidx, count);
  275. verify_push_table_contents(L, midx, count);
  276. THLongTensor *keys;
  277. THLongTensor *vals;
  278. THByteTensor *mask;
  279. for (int i = count - 1; i >= 0; i--) {
  280. int maskIdx = i - count;
  281. int valIdx = maskIdx - count;
  282. int keyIdx = inplace ? valIdx : (valIdx - count);
  283. keys = (THLongTensor *)luaT_checkudata(L, keyIdx, "torch.LongTensor");
  284. check_tensor(L, keys, THLongTensor);
  285. if (inplace) {
  286. vals = keys;
  287. } else {
  288. vals = (THLongTensor *)luaT_checkudata(L, valIdx, "torch.LongTensor");
  289. }
  290. mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
  291. int n_dim = THLongTensor_nDimension(keys);
  292. THLongStorage *st = THLongStorage_newWithSize1(n_dim);
  293. for (int i = 0; i < n_dim; i++) {
  294. THLongStorage_set(st, i, THLongTensor_size(keys, i));
  295. }
  296. THByteTensor_resize(mask, st, NULL);
  297. THLongTensor_resize(vals, st, NULL);
  298. THLongStorage_free(st);
  299. autolock(h);
  300. hash_map_get_tensor(h->h, keys, vals, mask);
  301. autounlock(h);
  302. }
  303. lua_pop(L, (narg - 1) * count);
  304. return 2;
  305. }
  306. int hash_map_get_lua(lua_State *L) {
  307. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  308. if (lua_isinteger(L, 2)) {
  309. long key = lua_tointeger(L, 2);
  310. long val;
  311. autolock(h);
  312. int ret = hash_map_get(h->h, key, &val);
  313. autounlock(h);
  314. if (ret) {
  315. lua_pushinteger(L, val);
  316. lua_pushinteger(L, 1);
  317. }
  318. else {
  319. lua_pushinteger(L, 0);
  320. lua_pushinteger(L, 0);
  321. }
  322. } else if (lua_istable(L, 2)) {
  323. return hash_map_get_table_lua(L, h, 0);
  324. } else {
  325. return hash_map_get_tensor_lua(L, h, 0);
  326. }
  327. return 2;
  328. }
  329. int hash_map_get_inplace_lua(lua_State *L) {
  330. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  331. if (lua_isinteger(L, 2)) {
  332. LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace does not support integer arguments.");
  333. } else if (lua_istable(L, 2)) {
  334. return hash_map_get_table_lua(L, h, 1);
  335. } else {
  336. return hash_map_get_tensor_lua(L, h, 1);
  337. }
  338. return 2;
  339. }
  340. int hash_map_del_lua(lua_State *L) {
  341. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  342. if (lua_isinteger(L, 2)) {
  343. long key = lua_tointeger(L, 2);
  344. autolock(h);
  345. hash_map_del(h->h, key);
  346. autounlock(h);
  347. }
  348. else {
  349. THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
  350. autolock(h);
  351. hash_map_del_tensor(h->h, keys);
  352. autounlock(h);
  353. }
  354. return 0;
  355. }
  356. int hash_map_size_lua(lua_State *L) {
  357. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  358. long size = hash_map_size(h->h);
  359. lua_pushinteger(L, size);
  360. return 1;
  361. }
  362. int hash_map_to_tensor_lua(lua_State *L) {
  363. hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
  364. THLongTensor *keys, *vals;
  365. if (lua_gettop(L) < 2) {
  366. keys = THLongTensor_new();
  367. }
  368. else {
  369. keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
  370. check_tensor(L, keys, THLongTensor);
  371. }
  372. if (lua_gettop(L) < 3) {
  373. vals = THLongTensor_new();
  374. }
  375. else {
  376. vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
  377. check_tensor(L, vals, THLongTensor);
  378. }
  379. size_t size = hash_map_size(h->h);
  380. THLongTensor_resize1d(keys, size);
  381. THLongTensor_resize1d(vals, size);
  382. autolock(h);
  383. hash_map_to_tensor(h->h, keys, vals);
  384. autounlock(h);
  385. if (lua_gettop(L) < 2)
  386. luaT_pushudata(L, keys, "torch.LongTensor");
  387. if (lua_gettop(L) < 3)
  388. luaT_pushudata(L, vals, "torch.LongTensor");
  389. return 2;
  390. }