#include "utils.h"
#include "hash_map.h"
#include "internal_hash_map.h"
#include <pthread.h>
#include <stdlib.h>

/* Lua 5.1 has no lua_isinteger; fall back to lua_isnumber there. Kept at file
 * scope because several bindings below use it. */
#if LUA_VERSION_NUM <= 501
#define lua_isinteger lua_isnumber
#endif

/* ------------------------------------------------------------------------
 * Plain C API over the khash `long -> long` map. Nothing in this section
 * locks; the Lua bindings further down take the (optional) mutex.
 * ---------------------------------------------------------------------- */

/* Allocate an empty map (NULL if kh_init could not allocate). */
hash_map_t hash_map_init(void)
{
    return kh_init(long);
}

/* Free the map and all of its storage. */
void hash_map_destroy(hash_map_t h_)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    kh_destroy(long, h);
}

/* Remove every entry, keeping the map itself alive. */
void hash_map_clear(hash_map_t h_)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    kh_clear(long, h);
}

/* Insert or overwrite `key -> val`.
 * Returns 1 on success, 0 when kh_put reports failure (negative `ret`). */
int hash_map_put(hash_map_t h_, long key, long val)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    int ret;
    khiter_t k = kh_put(long, h, key, &ret);
    if (ret < 0)
        return 0;
    kh_value(h, k) = val;
    return 1;
}

/* Insert keys[i] -> vals[i] for every element of `keys_`.
 * Returns 0 at the first failed insertion (earlier entries remain), else 1.
 * The caller must ensure `vals_` has at least as many elements as `keys_`
 * (the Lua binding runs check_tensor/check_tensors before calling this). */
int hash_map_put_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_)
{
    long *keys = THLongTensor_data(keys_);
    long *vals = THLongTensor_data(vals_);
    long size = get_tensor_size(keys_, Long);
    for (long i = 0; i < size; i++)
        if (!hash_map_put(h_, keys[i], vals[i]))
            return 0;
    return 1;
}

/* If `key` is absent, map it to (*counter + 1) and advance *counter.
 * Keys that are already present are left untouched.
 * Returns 0 only if the insertion failed; *counter is then left unchanged. */
int hash_map_fill(hash_map_t h_, long key, long *counter)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    if (kh_get(long, h, key) != kh_end(h))
        return 1;
    long next = *counter + 1;
    if (!hash_map_put(h_, key, next))
        return 0;
    *counter = next;
    return 1;
}

/* Apply hash_map_fill to every element of `keys_`; 0 on first failure. */
int hash_map_fill_tensor(hash_map_t h_, THLongTensor *keys_, long *counter)
{
    long *keys = THLongTensor_data(keys_);
    long size = get_tensor_size(keys_, Long);
    for (long i = 0; i < size; i++)
        if (!hash_map_fill(h_, keys[i], counter))
            return 0;
    return 1;
}

/* Look up `key`. On a hit, store the value in *val and return 1.
 * On a miss, return 0 and leave *val untouched. */
int hash_map_get(hash_map_t h_, long key, long *val)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    khiter_t k = kh_get(long, h, key);
    if (k == kh_end(h))
        return 0;
    *val = kh_value(h, k);
    return 1;
}

/* Element-wise lookup: mask[i] = 1 and vals[i] = value when keys[i] is
 * present; otherwise mask[i] = 0 and vals[i] is left as it was. `vals_` and
 * `mask_` must hold at least as many elements as `keys_`. `vals_` may alias
 * `keys_`, because each key is read before its slot is written. */
void hash_map_get_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_)
{
    long *keys = THLongTensor_data(keys_);
    long *vals = THLongTensor_data(vals_);
    unsigned char *mask = THByteTensor_data(mask_);
    long size = get_tensor_size(keys_, Long);
    for (long i = 0; i < size; i++)
        mask[i] = hash_map_get(h_, keys[i], &vals[i]);
}

/* Remove `key` if present; absent keys are ignored. */
void hash_map_del(hash_map_t h_, long key)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    khiter_t k = kh_get(long, h, key);
    if (k != kh_end(h))
        kh_del(long, h, k);
}

/* Remove every key contained in `keys_`. */
void hash_map_del_tensor(hash_map_t h_, THLongTensor *keys_)
{
    long *keys = THLongTensor_data(keys_);
    long size = get_tensor_size(keys_, Long);
    for (long i = 0; i < size; i++)
        hash_map_del(h_, keys[i]);
}

/* Number of live entries in the map. */
size_t hash_map_size(hash_map_t h_)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    return kh_size(h);
}

/* Dump all entries (in khash iteration order) into `keys_` / `vals_`.
 * Writes at most min(#keys_, #vals_) entries, so undersized tensors are
 * never overrun; size them with hash_map_size() to get every entry. */
void hash_map_to_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_)
{
    internal_hash_map_t h = (internal_hash_map_t)h_;
    long *keys = THLongTensor_data(keys_);
    long *vals = THLongTensor_data(vals_);
    long cap = get_tensor_size(keys_, Long);
    long vcap = get_tensor_size(vals_, Long);
    if (vcap < cap)
        cap = vcap;
    long key, val, i = 0;
    kh_foreach(h, key, val, {
        if (i < cap) {
            keys[i] = key;
            vals[i] = val;
            i++;
        }
    });
}

/* ------------------------------------------------------------------------
 * Lua bindings. Argument 1 is always a userdata holding a hash_map_lua_t*.
 * ---------------------------------------------------------------------- */

/* Lock the map's mutex only when autolocking is enabled.
 * NOTE(review): autolock() and autounlock() each re-read h->autolock. If the
 * flag is toggled while another thread is between the two calls, that thread
 * can unlock a mutex it never locked, or never release one it did lock.
 * Only toggle autolock while no other thread is using the map. */
static void autolock(hash_map_lua_t *h)
{
    if (h->autolock) {
        pthread_mutex_lock(&h->mutex);
    }
}

/* Counterpart of autolock(). */
static void autounlock(hash_map_lua_t *h)
{
    if (h->autolock) {
        pthread_mutex_unlock(&h->mutex);
    }
}

/* HashMap:autolockOn() - subsequent operations take the mutex. */
int hash_map_autolock_on_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    h->autolock = 1;
    return 0;
}

/* HashMap:autolockOff() - subsequent operations skip the mutex. */
int hash_map_autolock_off_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    h->autolock = 0;
    return 0;
}

/* Constructor: push a new "dt.HashMap" userdata wrapping a fresh map with
 * refcount 1, counter 0, autolock off and a recursive mutex. */
int hash_map_init_lua(lua_State *L)
{
    hash_map_lua_t **hp = (hash_map_lua_t **)lua_newuserdata(L, sizeof(hash_map_lua_t *));
    hash_map_lua_t *h = malloc(sizeof *h);
    *hp = h;
    /* Errors raised before lua_setmetatable are safe: without the metatable,
     * hash_map_gc_lua is never run on this half-built userdata. */
    if (!h)
        return LUA_HANDLE_ERROR_STR(L, "failed to allocate hash map");
    h->refcount = 1;
    h->counter = 0;
    h->autolock = 0;
    h->h = hash_map_init();
    if (!h->h) {
        free(h);
        *hp = NULL;
        return LUA_HANDLE_ERROR_STR(L, "failed to allocate hash map");
    }
    pthread_mutexattr_t mutex_attr;
    pthread_mutexattr_init(&mutex_attr);
    /* PTHREAD_MUTEX_RECURSIVE: the owning thread may lock it again without
     * deadlocking. */
    pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE);
    pthread_mutex_init(&h->mutex, &mutex_attr);
    pthread_mutexattr_destroy(&mutex_attr);
    luaL_getmetatable(L, "dt.HashMap");
    lua_setmetatable(L, -2);
    return 1;
}

/* __gc: drop this handle's reference. The last release destroys the mutex
 * and the map, then frees the wrapper. */
int hash_map_gc_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    if (THAtomicDecrementRef(&h->refcount)) {
        pthread_mutex_destroy(&h->mutex);
        hash_map_destroy(h->h);
        free(h);
    }
    return 0;
}

/* Take an extra reference, presumably so another Lua state/thread can share
 * the same underlying map (each holder's __gc releases one reference). */
int hash_map_retain_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    THAtomicIncrementRef(&h->refcount);
    return 0;
}

/* Return the metatable name used for HashMap userdata. */
int hash_map_metatablename_lua(lua_State *L)
{
    lua_pushstring(L, "dt.HashMap");
    return 1;
}

/* HashMap:clear() */
int hash_map_clear_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    autolock(h);
    hash_map_clear(h->h);
    autounlock(h);
    return 0;
}

/* HashMap:put(key, val) with integers, or HashMap:put(keys, vals) with two
 * torch.LongTensors. Raises a Lua error if an insertion fails. */
int hash_map_put_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    int ret;
    if (lua_isinteger(L, 2)) {
        if (!lua_isinteger(L, 3))
            return LUA_HANDLE_ERROR_STR(L, "second parameter is not a number");
        long key = lua_tointeger(L, 2);
        long val = lua_tointeger(L, 3);
        autolock(h);
        ret = hash_map_put(h->h, key, val);
        autounlock(h);
    } else {
        THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
        THLongTensor *vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
        check_tensor(L, keys, THLongTensor);
        check_tensor(L, vals, THLongTensor);
        check_tensors(L, keys, vals);
        autolock(h);
        ret = hash_map_put_tensor(h->h, keys, vals);
        autounlock(h);
    }
    if (!ret)
        return LUA_HANDLE_ERROR_STR(L, "failed to put into hash map");
    return 0;
}

/* HashMap:fill(key | keys): assign new counter values to unseen keys. */
int hash_map_fill_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    int ret;
    if (lua_isinteger(L, 2)) {
        long key = lua_tointeger(L, 2);
        autolock(h);
        ret = hash_map_fill(h->h, key, &h->counter);
        autounlock(h);
    } else {
        THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
        check_tensor(L, keys, THLongTensor);
        autolock(h);
        ret = hash_map_fill_tensor(h->h, keys, &h->counter);
        autounlock(h);
    }
    if (!ret)
        return LUA_HANDLE_ERROR_STR(L, "failed to fill into hash map");
    return 0;
}

/* HashMap:adjustCounter(): raise the counter to the largest stored value
 * (the counter is never lowered). The walk runs under the lock, so a
 * concurrent writer cannot rehash the table mid-iteration. */
int hash_map_adjust_counter_lua(lua_State *L)
{
    hash_map_lua_t *h_ = *(hash_map_lua_t **)lua_touserdata(L, 1);
    autolock(h_);
    internal_hash_map_t h = (internal_hash_map_t)h_->h;
    long val;
    kh_foreach_value(h, val, {
        if (val >= h_->counter)
            h_->counter = val;
    });
    autounlock(h_);
    return 0;
}

/* HashMap:setCounter(n) */
int hash_map_set_counter_lua(lua_State *L)
{
    hash_map_lua_t *h_ = *(hash_map_lua_t **)lua_touserdata(L, 1);
    long counter = lua_tointeger(L, 2);
    autolock(h_);
    h_->counter = counter;
    autounlock(h_);
    return 0;
}

/* HashMap:getCounter() -> integer */
int hash_map_get_counter_lua(lua_State *L)
{
    hash_map_lua_t *h_ = *(hash_map_lua_t **)lua_touserdata(L, 1);
    autolock(h_);
    long counter = h_->counter;
    autounlock(h_);
    lua_pushinteger(L, counter);
    return 1;
}

/* Resize `mask` (and `vals`, unless NULL) to exactly the shape of `keys`.
 * Uses newWithSize(n_dim), which allocates n_dim slots. newWithSize1(x)
 * would allocate one slot holding x, which is too small for
 * multi-dimensional keys. */
static void resize_outputs_like_keys(THLongTensor *keys, THLongTensor *vals, THByteTensor *mask)
{
    int n_dim = THLongTensor_nDimension(keys);
    THLongStorage *st = THLongStorage_newWithSize(n_dim);
    for (int d = 0; d < n_dim; d++)
        THLongStorage_set(st, d, THLongTensor_size(keys, d));
    THByteTensor_resize(mask, st, NULL);
    if (vals)
        THLongTensor_resize(vals, st, NULL);
    THLongStorage_free(st);
}

/* Tensor form of get/getInplace.
 *   get(keys [, vals [, mask]])  -> vals, mask
 *   getInplace(keys [, mask])    -> keys (hits overwritten in place), mask
 * Missing output tensors are allocated and pushed; supplied ones are
 * resized to keys' shape. Always returns 2 values. */
static int hash_map_get_tensor_lua(lua_State *L, hash_map_lua_t *h, int inplace)
{
    THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
    check_tensor(L, keys, THLongTensor);
    THLongTensor *vals = inplace ? keys : NULL;
    THByteTensor *mask = NULL;
    const int maskIdx = inplace ? 3 : 4;
    /* Decide up front which outputs are new. Pushing them later changes
     * lua_gettop, so it cannot be re-queried at that point. */
    const int new_vals = !inplace && lua_gettop(L) < 3;
    const int new_mask = lua_gettop(L) < maskIdx;

    if (!inplace) {
        if (new_vals) {
            vals = THLongTensor_new();
        } else {
            vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
            check_tensor(L, vals, THLongTensor);
        }
    }
    if (new_mask) {
        mask = THByteTensor_new();
    } else {
        mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
        check_tensor(L, mask, THByteTensor);
    }

    resize_outputs_like_keys(keys, inplace ? NULL : vals, mask);

    autolock(h);
    hash_map_get_tensor(h->h, keys, vals, mask);
    autounlock(h);

    if (new_vals)
        luaT_pushudata(L, vals, "torch.LongTensor");
    if (new_mask)
        luaT_pushudata(L, mask, "torch.ByteTensor");
    return 2;
}

/* Table form of get/getInplace: arguments are parallel tables of tensors.
 *   get({keys...}, {vals...}, {masks...})  -> vals table, masks table
 *   getInplace({keys...}, {masks...})      -> keys table, masks table */
static int hash_map_get_table_lua(lua_State *L, hash_map_lua_t *h, int inplace)
{
    const int kidx = 2;
    const int vidx = inplace ? 2 : 3;
    const int midx = inplace ? 3 : 4;
    const int narg = lua_gettop(L);
    if (inplace) {
        if (narg < 3)
            return LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace requires two arguments.");
    } else if (narg < 4) {
        return LUA_HANDLE_ERROR_STR(L, "HashMap.get requires three arguments.");
    }

    /* Push the contents of the key, value and mask tables, in that order. In
     * in-place mode the key table is pushed twice (kidx == vidx). The index
     * arithmetic below assumes each of the three calls pushes exactly `count`
     * values, so 3 * count values sit above the arguments. */
    int count = push_table_contents(L, kidx);
    verify_push_table_contents(L, vidx, count);
    verify_push_table_contents(L, midx, count);

    for (int i = count - 1; i >= 0; i--) {
        const int maskIdx = i - count;
        const int valIdx = maskIdx - count;
        const int keyIdx = inplace ? valIdx : (valIdx - count);

        THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, keyIdx, "torch.LongTensor");
        check_tensor(L, keys, THLongTensor);
        THLongTensor *vals = keys;
        if (!inplace) {
            vals = (THLongTensor *)luaT_checkudata(L, valIdx, "torch.LongTensor");
            check_tensor(L, vals, THLongTensor);
        }
        THByteTensor *mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
        check_tensor(L, mask, THByteTensor);

        resize_outputs_like_keys(keys, inplace ? NULL : vals, mask);

        autolock(h);
        hash_map_get_tensor(h->h, keys, vals, mask);
        autounlock(h);
    }

    /* Drop everything pushed above, then return the output tables explicitly
     * so the result does not depend on how many arguments were passed. */
    lua_pop(L, 3 * count);
    lua_pushvalue(L, vidx);
    lua_pushvalue(L, midx);
    return 2;
}

/* HashMap:get(key) -> value, found (1/0). A miss returns (0, 0).
 * Table and tensor arguments are forwarded to the helpers above. */
int hash_map_get_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    if (lua_isinteger(L, 2)) {
        long key = lua_tointeger(L, 2);
        long val;
        autolock(h);
        int ret = hash_map_get(h->h, key, &val);
        autounlock(h);
        if (ret) {
            lua_pushinteger(L, val);
            lua_pushinteger(L, 1);
        } else {
            lua_pushinteger(L, 0);
            lua_pushinteger(L, 0);
        }
        return 2;
    }
    if (lua_istable(L, 2))
        return hash_map_get_table_lua(L, h, 0);
    return hash_map_get_tensor_lua(L, h, 0);
}

/* HashMap:getInplace(keys [, mask]) for tensors or tables of tensors.
 * Integer keys are rejected. */
int hash_map_get_inplace_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    if (lua_isinteger(L, 2))
        return LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace does not support integer arguments.");
    if (lua_istable(L, 2))
        return hash_map_get_table_lua(L, h, 1);
    return hash_map_get_tensor_lua(L, h, 1);
}

/* HashMap:del(key | keys) */
int hash_map_del_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    if (lua_isinteger(L, 2)) {
        long key = lua_tointeger(L, 2);
        autolock(h);
        hash_map_del(h->h, key);
        autounlock(h);
    } else {
        THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
        /* Same validation as every other tensor entry point: the raw data
         * pointer is walked linearly. */
        check_tensor(L, keys, THLongTensor);
        autolock(h);
        hash_map_del_tensor(h->h, keys);
        autounlock(h);
    }
    return 0;
}

/* HashMap:size() -> integer */
int hash_map_size_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    autolock(h);
    long size = hash_map_size(h->h);
    autounlock(h);
    lua_pushinteger(L, size);
    return 1;
}

/* HashMap:toTensor([keys [, vals]]) -> keys, vals
 * Fills 1-D LongTensors with every entry. Missing tensors are allocated. */
int hash_map_to_tensor_lua(lua_State *L)
{
    hash_map_lua_t *h = *(hash_map_lua_t **)lua_touserdata(L, 1);
    const int new_keys = lua_gettop(L) < 2;
    const int new_vals = lua_gettop(L) < 3;
    THLongTensor *keys, *vals;
    if (new_keys) {
        keys = THLongTensor_new();
    } else {
        keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
        check_tensor(L, keys, THLongTensor);
    }
    if (new_vals) {
        vals = THLongTensor_new();
    } else {
        vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
        check_tensor(L, vals, THLongTensor);
    }

    /* Measure, resize and fill under one lock. Otherwise a concurrent insert
     * could grow the map after it was measured but before it is copied. */
    autolock(h);
    size_t size = hash_map_size(h->h);
    THLongTensor_resize1d(keys, size);
    THLongTensor_resize1d(vals, size);
    hash_map_to_tensor(h->h, keys, vals);
    autounlock(h);

    if (new_keys)
        luaT_pushudata(L, keys, "torch.LongTensor");
    if (new_vals)
        luaT_pushudata(L, vals, "torch.LongTensor");
    return 2;
}