diff options
Diffstat (limited to 'contrib/torch/decisiontree/hash_map.c')
-rw-r--r-- | contrib/torch/decisiontree/hash_map.c | 445 |
1 file changed, 445 insertions, 0 deletions
diff --git a/contrib/torch/decisiontree/hash_map.c b/contrib/torch/decisiontree/hash_map.c new file mode 100644 index 000000000..2c679e207 --- /dev/null +++ b/contrib/torch/decisiontree/hash_map.c @@ -0,0 +1,445 @@ +#include "utils.h" +#include "hash_map.h" +#include "internal_hash_map.h" +#include <pthread.h> + +hash_map_t hash_map_init(void) { + return kh_init(long); +} + +void hash_map_destroy(hash_map_t h_) { + internal_hash_map_t h = (internal_hash_map_t) h_; + kh_destroy(long, h); +} + +void hash_map_clear(hash_map_t h_) { + internal_hash_map_t h = (internal_hash_map_t) h_; + kh_clear(long, h); +} + +int hash_map_put(hash_map_t h_, long key, long val) { + internal_hash_map_t h = (internal_hash_map_t) h_; + int ret; + khiter_t k = kh_put(long, h, key, &ret); + ret = (ret >= 0); + if (ret) + kh_value(h, k) = val; + return ret; +} + +int hash_map_put_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) { + long *keys = THLongTensor_data(keys_); + long *vals = THLongTensor_data(vals_); + long size = get_tensor_size(keys_, Long); + for (long i = 0; i < size; i++) + if (!hash_map_put(h_, keys[i], vals[i])) + return 0; + return 1; +} + +int hash_map_fill(hash_map_t h_, long key, long *counter) { + internal_hash_map_t h = (internal_hash_map_t) h_; + khiter_t k = kh_get(long, h, key); + if (k == kh_end(h)) + return hash_map_put(h_, key, ++(*counter)); + return 1; +} + +int hash_map_fill_tensor(hash_map_t h_, THLongTensor *keys_, long *counter) { + long *keys = THLongTensor_data(keys_); + long size = get_tensor_size(keys_, Long); + for (long i = 0; i < size; i++) + if (!hash_map_fill(h_, keys[i], counter)) + return 0; + return 1; +} + +int hash_map_get(hash_map_t h_, long key, long* val) { + internal_hash_map_t h = (internal_hash_map_t) h_; + khiter_t k = kh_get(long, h, key); + if (k == kh_end(h)) + return 0; + *val = kh_value(h, k); + return 1; +} + +void hash_map_get_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_) { + 
long *keys = THLongTensor_data(keys_); + long *vals = THLongTensor_data(vals_);; + unsigned char *mask = THByteTensor_data(mask_); + long size = get_tensor_size(keys_, Long); + for (long i = 0; i < size; i++) + mask[i] = hash_map_get(h_, keys[i], &vals[i]); +} + +void hash_map_del(hash_map_t h_, long key) { + internal_hash_map_t h = (internal_hash_map_t) h_; + khiter_t k = kh_get(long, h, key); + if (k != kh_end(h)) + kh_del(long, h, k); +} + +void hash_map_del_tensor(hash_map_t h_, THLongTensor *keys_) { + long *keys = THLongTensor_data(keys_); + long size = get_tensor_size(keys_, Long); + for (long i = 0; i < size; i++) + hash_map_del(h_, keys[i]); +} + +size_t hash_map_size(hash_map_t h_) { + internal_hash_map_t h = (internal_hash_map_t) h_; + return kh_size(h); +} + +void hash_map_to_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) { + internal_hash_map_t h = (internal_hash_map_t) h_; + long *keys = THLongTensor_data(keys_); + long *vals = THLongTensor_data(vals_); + long key, val, i = 0; + kh_foreach(h, key, val, { + keys[i] = key; + vals[i] = val; + i++; + }); +} + +static void autolock(hash_map_lua_t *h) { + if (h->autolock) { + pthread_mutex_lock(&h->mutex); + } +} + +static void autounlock(hash_map_lua_t *h) { + if (h->autolock) { + pthread_mutex_unlock(&h->mutex); + } +} + +int hash_map_autolock_on_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + h->autolock = 1; + return 0; +} + +int hash_map_autolock_off_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + h->autolock = 0; + return 0; +} + +int hash_map_init_lua(lua_State *L) { + hash_map_lua_t **hp = (hash_map_lua_t**)lua_newuserdata(L, sizeof(hash_map_lua_t*)); + *hp = (hash_map_lua_t*)malloc(sizeof(hash_map_lua_t)); + hash_map_lua_t *h = *hp; + h->refcount = 1; + h->counter = 0; + h->autolock = 0; + h->h = hash_map_init(); + + pthread_mutexattr_t mutex_attr; + pthread_mutexattr_init(&mutex_attr); + 
pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE); + pthread_mutex_init(&h->mutex, &mutex_attr); + + luaL_getmetatable(L, "dt.HashMap"); + lua_setmetatable(L, -2); + return 1; +} + +int hash_map_gc_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + if (THAtomicDecrementRef(&h->refcount)) { + pthread_mutex_destroy(&h->mutex); + hash_map_destroy(h->h); + free(h); + } + return 0; +} + +int hash_map_retain_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + THAtomicIncrementRef(&h->refcount); + return 0; +} + +int hash_map_metatablename_lua(lua_State *L) { + lua_pushstring(L, "dt.HashMap"); + return 1; +} + +int hash_map_clear_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + autolock(h); + hash_map_clear(h->h); + autounlock(h); + return 0; +} + +int hash_map_put_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + int ret; +#if LUA_VERSION_NUM <= 501 +#define lua_isinteger lua_isnumber +#endif + if (lua_isinteger(L, 2)) { + if (!lua_isinteger(L, 3)) + return LUA_HANDLE_ERROR_STR(L, "second parameter is not a number"); + long key = lua_tointeger(L, 2); + long val = lua_tointeger(L, 3); + autolock(h); + ret = hash_map_put(h->h, key, val); + autounlock(h); + } + else { + THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); + THLongTensor *vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor"); + check_tensor(L, keys, THLongTensor); + check_tensor(L, vals, THLongTensor); + check_tensors(L, keys, vals); + autolock(h); + ret = hash_map_put_tensor(h->h, keys, vals); + autounlock(h); + } + if (!ret) + return LUA_HANDLE_ERROR_STR(L, "failed to put into hash map"); + return 0; +} + +int hash_map_fill_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + int ret; + if (lua_isinteger(L, 2)) { + long key = lua_tointeger(L, 2); + autolock(h); + ret = hash_map_fill(h->h, 
key, &h->counter); + autounlock(h); + } + else { + THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); + check_tensor(L, keys, THLongTensor); + autolock(h); + ret = hash_map_fill_tensor(h->h, keys, &h->counter); + autounlock(h); + } + if (!ret) + return LUA_HANDLE_ERROR_STR(L, "failed to fill into hash map"); + return 0; +} + +int hash_map_adjust_counter_lua(lua_State *L) { + hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1); + internal_hash_map_t h = (internal_hash_map_t) h_->h; + + long val; + kh_foreach_value(h, val, { + if (val >= h_->counter) + h_->counter = val; + }); + return 0; +} + +int hash_map_set_counter_lua(lua_State *L) { + hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1); + h_->counter = lua_tointeger(L, 2); + return 0; +} + +int hash_map_get_counter_lua(lua_State *L) { + hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1); + lua_pushinteger(L, h_->counter); + return 1; +} + +static int hash_map_get_tensor_lua(lua_State *L, hash_map_lua_t *h, int inplace) { + THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); + check_tensor(L, keys, THLongTensor); + THLongTensor *vals = inplace ? keys : NULL; + THByteTensor *mask = NULL; + + int maskIdx = inplace ? 
3 : 4; + + if (!inplace) { + if (lua_gettop(L) < 3) { + vals = THLongTensor_new(); + } else { + vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor"); + check_tensor(L, vals, THLongTensor); + } + } + + if (lua_gettop(L) < maskIdx) { + mask = THByteTensor_new(); + } else { + mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor"); + check_tensor(L, mask, THByteTensor); + } + + int n_dim = THLongTensor_nDimension(keys); + THLongStorage *st = THLongStorage_newWithSize1(n_dim); + for (int i = 0; i < n_dim; i++) { + THLongStorage_set(st, i, THLongTensor_size(keys, i)); + } + THByteTensor_resize(mask, st, NULL); + if (!inplace) THLongTensor_resize(vals, st, NULL); + THLongStorage_free(st); + + autolock(h); + hash_map_get_tensor(h->h, keys, vals, mask); + autounlock(h); + + if (!inplace && lua_gettop(L) < 3) + luaT_pushudata(L, vals, "torch.LongTensor"); + if (lua_gettop(L) < maskIdx) + luaT_pushudata(L, mask, "torch.ByteTensor"); + + return 2; +} + +static int hash_map_get_table_lua(lua_State *L, hash_map_lua_t *h, int inplace) { + const int kidx = 2; + const int vidx = inplace ? 2 : 3; + const int midx = inplace ? 3 : 4; + const int narg = lua_gettop(L); + + if (inplace) { + if (narg < 3) { + LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace requires two arguments."); + } + } else { + if (narg < 4) { + LUA_HANDLE_ERROR_STR(L, "HashMap.get requires three arguments."); + } + } + + int count = push_table_contents(L, kidx); + verify_push_table_contents(L, vidx, count); + verify_push_table_contents(L, midx, count); + + THLongTensor *keys; + THLongTensor *vals; + THByteTensor *mask; + for (int i = count - 1; i >= 0; i--) { + int maskIdx = i - count; + int valIdx = maskIdx - count; + int keyIdx = inplace ? 
valIdx : (valIdx - count); + + keys = (THLongTensor *)luaT_checkudata(L, keyIdx, "torch.LongTensor"); + check_tensor(L, keys, THLongTensor); + if (inplace) { + vals = keys; + } else { + vals = (THLongTensor *)luaT_checkudata(L, valIdx, "torch.LongTensor"); + } + mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor"); + + int n_dim = THLongTensor_nDimension(keys); + THLongStorage *st = THLongStorage_newWithSize1(n_dim); + for (int i = 0; i < n_dim; i++) { + THLongStorage_set(st, i, THLongTensor_size(keys, i)); + } + THByteTensor_resize(mask, st, NULL); + THLongTensor_resize(vals, st, NULL); + THLongStorage_free(st); + + autolock(h); + hash_map_get_tensor(h->h, keys, vals, mask); + autounlock(h); + } + lua_pop(L, (narg - 1) * count); + return 2; +} + +int hash_map_get_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + if (lua_isinteger(L, 2)) { + long key = lua_tointeger(L, 2); + long val; + autolock(h); + int ret = hash_map_get(h->h, key, &val); + autounlock(h); + if (ret) { + lua_pushinteger(L, val); + lua_pushinteger(L, 1); + } + else { + lua_pushinteger(L, 0); + lua_pushinteger(L, 0); + } + } else if (lua_istable(L, 2)) { + return hash_map_get_table_lua(L, h, 0); + } else { + return hash_map_get_tensor_lua(L, h, 0); + } + return 2; +} + +int hash_map_get_inplace_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + if (lua_isinteger(L, 2)) { + LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace does not support integer arguments."); + } else if (lua_istable(L, 2)) { + return hash_map_get_table_lua(L, h, 1); + } else { + return hash_map_get_tensor_lua(L, h, 1); + } + return 2; +} + +int hash_map_del_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + if (lua_isinteger(L, 2)) { + long key = lua_tointeger(L, 2); + autolock(h); + hash_map_del(h->h, key); + autounlock(h); + } + else { + THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, 
"torch.LongTensor"); + autolock(h); + hash_map_del_tensor(h->h, keys); + autounlock(h); + } + return 0; +} + +int hash_map_size_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + long size = hash_map_size(h->h); + lua_pushinteger(L, size); + return 1; +} + +int hash_map_to_tensor_lua(lua_State *L) { + hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1); + THLongTensor *keys, *vals; + + if (lua_gettop(L) < 2) { + keys = THLongTensor_new(); + } + else { + keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor"); + check_tensor(L, keys, THLongTensor); + } + + if (lua_gettop(L) < 3) { + vals = THLongTensor_new(); + } + else { + vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor"); + check_tensor(L, vals, THLongTensor); + } + + size_t size = hash_map_size(h->h); + THLongTensor_resize1d(keys, size); + THLongTensor_resize1d(vals, size); + + autolock(h); + hash_map_to_tensor(h->h, keys, vals); + autounlock(h); + + if (lua_gettop(L) < 2) + luaT_pushudata(L, keys, "torch.LongTensor"); + if (lua_gettop(L) < 3) + luaT_pushudata(L, vals, "torch.LongTensor"); + return 2; +} |