aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/torch/decisiontree/hash_map.c
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/torch/decisiontree/hash_map.c')
-rw-r--r--contrib/torch/decisiontree/hash_map.c445
1 files changed, 445 insertions, 0 deletions
diff --git a/contrib/torch/decisiontree/hash_map.c b/contrib/torch/decisiontree/hash_map.c
new file mode 100644
index 000000000..2c679e207
--- /dev/null
+++ b/contrib/torch/decisiontree/hash_map.c
@@ -0,0 +1,445 @@
+#include "utils.h"
+#include "hash_map.h"
+#include "internal_hash_map.h"
+#include <pthread.h>
+
+hash_map_t hash_map_init(void) {
+ return kh_init(long);
+}
+
+void hash_map_destroy(hash_map_t h_) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ kh_destroy(long, h);
+}
+
+void hash_map_clear(hash_map_t h_) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ kh_clear(long, h);
+}
+
+int hash_map_put(hash_map_t h_, long key, long val) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ int ret;
+ khiter_t k = kh_put(long, h, key, &ret);
+ ret = (ret >= 0);
+ if (ret)
+ kh_value(h, k) = val;
+ return ret;
+}
+
+int hash_map_put_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) {
+ long *keys = THLongTensor_data(keys_);
+ long *vals = THLongTensor_data(vals_);
+ long size = get_tensor_size(keys_, Long);
+ for (long i = 0; i < size; i++)
+ if (!hash_map_put(h_, keys[i], vals[i]))
+ return 0;
+ return 1;
+}
+
+int hash_map_fill(hash_map_t h_, long key, long *counter) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ khiter_t k = kh_get(long, h, key);
+ if (k == kh_end(h))
+ return hash_map_put(h_, key, ++(*counter));
+ return 1;
+}
+
+int hash_map_fill_tensor(hash_map_t h_, THLongTensor *keys_, long *counter) {
+ long *keys = THLongTensor_data(keys_);
+ long size = get_tensor_size(keys_, Long);
+ for (long i = 0; i < size; i++)
+ if (!hash_map_fill(h_, keys[i], counter))
+ return 0;
+ return 1;
+}
+
+int hash_map_get(hash_map_t h_, long key, long* val) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ khiter_t k = kh_get(long, h, key);
+ if (k == kh_end(h))
+ return 0;
+ *val = kh_value(h, k);
+ return 1;
+}
+
+void hash_map_get_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_) {
+ long *keys = THLongTensor_data(keys_);
+ long *vals = THLongTensor_data(vals_);;
+ unsigned char *mask = THByteTensor_data(mask_);
+ long size = get_tensor_size(keys_, Long);
+ for (long i = 0; i < size; i++)
+ mask[i] = hash_map_get(h_, keys[i], &vals[i]);
+}
+
+void hash_map_del(hash_map_t h_, long key) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ khiter_t k = kh_get(long, h, key);
+ if (k != kh_end(h))
+ kh_del(long, h, k);
+}
+
+void hash_map_del_tensor(hash_map_t h_, THLongTensor *keys_) {
+ long *keys = THLongTensor_data(keys_);
+ long size = get_tensor_size(keys_, Long);
+ for (long i = 0; i < size; i++)
+ hash_map_del(h_, keys[i]);
+}
+
+size_t hash_map_size(hash_map_t h_) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ return kh_size(h);
+}
+
+void hash_map_to_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) {
+ internal_hash_map_t h = (internal_hash_map_t) h_;
+ long *keys = THLongTensor_data(keys_);
+ long *vals = THLongTensor_data(vals_);
+ long key, val, i = 0;
+ kh_foreach(h, key, val, {
+ keys[i] = key;
+ vals[i] = val;
+ i++;
+ });
+}
+
+static void autolock(hash_map_lua_t *h) {
+ if (h->autolock) {
+ pthread_mutex_lock(&h->mutex);
+ }
+}
+
+static void autounlock(hash_map_lua_t *h) {
+ if (h->autolock) {
+ pthread_mutex_unlock(&h->mutex);
+ }
+}
+
+int hash_map_autolock_on_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ h->autolock = 1;
+ return 0;
+}
+
+int hash_map_autolock_off_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ h->autolock = 0;
+ return 0;
+}
+
+int hash_map_init_lua(lua_State *L) {
+ hash_map_lua_t **hp = (hash_map_lua_t**)lua_newuserdata(L, sizeof(hash_map_lua_t*));
+ *hp = (hash_map_lua_t*)malloc(sizeof(hash_map_lua_t));
+ hash_map_lua_t *h = *hp;
+ h->refcount = 1;
+ h->counter = 0;
+ h->autolock = 0;
+ h->h = hash_map_init();
+
+ pthread_mutexattr_t mutex_attr;
+ pthread_mutexattr_init(&mutex_attr);
+ pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE);
+ pthread_mutex_init(&h->mutex, &mutex_attr);
+
+ luaL_getmetatable(L, "dt.HashMap");
+ lua_setmetatable(L, -2);
+ return 1;
+}
+
+int hash_map_gc_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ if (THAtomicDecrementRef(&h->refcount)) {
+ pthread_mutex_destroy(&h->mutex);
+ hash_map_destroy(h->h);
+ free(h);
+ }
+ return 0;
+}
+
+int hash_map_retain_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ THAtomicIncrementRef(&h->refcount);
+ return 0;
+}
+
+int hash_map_metatablename_lua(lua_State *L) {
+ lua_pushstring(L, "dt.HashMap");
+ return 1;
+}
+
+int hash_map_clear_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ autolock(h);
+ hash_map_clear(h->h);
+ autounlock(h);
+ return 0;
+}
+
+int hash_map_put_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ int ret;
+#if LUA_VERSION_NUM <= 501
+#define lua_isinteger lua_isnumber
+#endif
+ if (lua_isinteger(L, 2)) {
+ if (!lua_isinteger(L, 3))
+ return LUA_HANDLE_ERROR_STR(L, "second parameter is not a number");
+ long key = lua_tointeger(L, 2);
+ long val = lua_tointeger(L, 3);
+ autolock(h);
+ ret = hash_map_put(h->h, key, val);
+ autounlock(h);
+ }
+ else {
+ THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
+ THLongTensor *vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
+ check_tensor(L, keys, THLongTensor);
+ check_tensor(L, vals, THLongTensor);
+ check_tensors(L, keys, vals);
+ autolock(h);
+ ret = hash_map_put_tensor(h->h, keys, vals);
+ autounlock(h);
+ }
+ if (!ret)
+ return LUA_HANDLE_ERROR_STR(L, "failed to put into hash map");
+ return 0;
+}
+
+int hash_map_fill_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ int ret;
+ if (lua_isinteger(L, 2)) {
+ long key = lua_tointeger(L, 2);
+ autolock(h);
+ ret = hash_map_fill(h->h, key, &h->counter);
+ autounlock(h);
+ }
+ else {
+ THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
+ check_tensor(L, keys, THLongTensor);
+ autolock(h);
+ ret = hash_map_fill_tensor(h->h, keys, &h->counter);
+ autounlock(h);
+ }
+ if (!ret)
+ return LUA_HANDLE_ERROR_STR(L, "failed to fill into hash map");
+ return 0;
+}
+
+int hash_map_adjust_counter_lua(lua_State *L) {
+ hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ internal_hash_map_t h = (internal_hash_map_t) h_->h;
+
+ long val;
+ kh_foreach_value(h, val, {
+ if (val >= h_->counter)
+ h_->counter = val;
+ });
+ return 0;
+}
+
+int hash_map_set_counter_lua(lua_State *L) {
+ hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ h_->counter = lua_tointeger(L, 2);
+ return 0;
+}
+
+int hash_map_get_counter_lua(lua_State *L) {
+ hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ lua_pushinteger(L, h_->counter);
+ return 1;
+}
+
+static int hash_map_get_tensor_lua(lua_State *L, hash_map_lua_t *h, int inplace) {
+ THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
+ check_tensor(L, keys, THLongTensor);
+ THLongTensor *vals = inplace ? keys : NULL;
+ THByteTensor *mask = NULL;
+
+ int maskIdx = inplace ? 3 : 4;
+
+ if (!inplace) {
+ if (lua_gettop(L) < 3) {
+ vals = THLongTensor_new();
+ } else {
+ vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
+ check_tensor(L, vals, THLongTensor);
+ }
+ }
+
+ if (lua_gettop(L) < maskIdx) {
+ mask = THByteTensor_new();
+ } else {
+ mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
+ check_tensor(L, mask, THByteTensor);
+ }
+
+ int n_dim = THLongTensor_nDimension(keys);
+ THLongStorage *st = THLongStorage_newWithSize1(n_dim);
+ for (int i = 0; i < n_dim; i++) {
+ THLongStorage_set(st, i, THLongTensor_size(keys, i));
+ }
+ THByteTensor_resize(mask, st, NULL);
+ if (!inplace) THLongTensor_resize(vals, st, NULL);
+ THLongStorage_free(st);
+
+ autolock(h);
+ hash_map_get_tensor(h->h, keys, vals, mask);
+ autounlock(h);
+
+ if (!inplace && lua_gettop(L) < 3)
+ luaT_pushudata(L, vals, "torch.LongTensor");
+ if (lua_gettop(L) < maskIdx)
+ luaT_pushudata(L, mask, "torch.ByteTensor");
+
+ return 2;
+}
+
+static int hash_map_get_table_lua(lua_State *L, hash_map_lua_t *h, int inplace) {
+ const int kidx = 2;
+ const int vidx = inplace ? 2 : 3;
+ const int midx = inplace ? 3 : 4;
+ const int narg = lua_gettop(L);
+
+ if (inplace) {
+ if (narg < 3) {
+ LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace requires two arguments.");
+ }
+ } else {
+ if (narg < 4) {
+ LUA_HANDLE_ERROR_STR(L, "HashMap.get requires three arguments.");
+ }
+ }
+
+ int count = push_table_contents(L, kidx);
+ verify_push_table_contents(L, vidx, count);
+ verify_push_table_contents(L, midx, count);
+
+ THLongTensor *keys;
+ THLongTensor *vals;
+ THByteTensor *mask;
+ for (int i = count - 1; i >= 0; i--) {
+ int maskIdx = i - count;
+ int valIdx = maskIdx - count;
+ int keyIdx = inplace ? valIdx : (valIdx - count);
+
+ keys = (THLongTensor *)luaT_checkudata(L, keyIdx, "torch.LongTensor");
+ check_tensor(L, keys, THLongTensor);
+ if (inplace) {
+ vals = keys;
+ } else {
+ vals = (THLongTensor *)luaT_checkudata(L, valIdx, "torch.LongTensor");
+ }
+ mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
+
+ int n_dim = THLongTensor_nDimension(keys);
+ THLongStorage *st = THLongStorage_newWithSize1(n_dim);
+ for (int i = 0; i < n_dim; i++) {
+ THLongStorage_set(st, i, THLongTensor_size(keys, i));
+ }
+ THByteTensor_resize(mask, st, NULL);
+ THLongTensor_resize(vals, st, NULL);
+ THLongStorage_free(st);
+
+ autolock(h);
+ hash_map_get_tensor(h->h, keys, vals, mask);
+ autounlock(h);
+ }
+ lua_pop(L, (narg - 1) * count);
+ return 2;
+}
+
+int hash_map_get_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ if (lua_isinteger(L, 2)) {
+ long key = lua_tointeger(L, 2);
+ long val;
+ autolock(h);
+ int ret = hash_map_get(h->h, key, &val);
+ autounlock(h);
+ if (ret) {
+ lua_pushinteger(L, val);
+ lua_pushinteger(L, 1);
+ }
+ else {
+ lua_pushinteger(L, 0);
+ lua_pushinteger(L, 0);
+ }
+ } else if (lua_istable(L, 2)) {
+ return hash_map_get_table_lua(L, h, 0);
+ } else {
+ return hash_map_get_tensor_lua(L, h, 0);
+ }
+ return 2;
+}
+
+int hash_map_get_inplace_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ if (lua_isinteger(L, 2)) {
+ LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace does not support integer arguments.");
+ } else if (lua_istable(L, 2)) {
+ return hash_map_get_table_lua(L, h, 1);
+ } else {
+ return hash_map_get_tensor_lua(L, h, 1);
+ }
+ return 2;
+}
+
+int hash_map_del_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ if (lua_isinteger(L, 2)) {
+ long key = lua_tointeger(L, 2);
+ autolock(h);
+ hash_map_del(h->h, key);
+ autounlock(h);
+ }
+ else {
+ THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
+ autolock(h);
+ hash_map_del_tensor(h->h, keys);
+ autounlock(h);
+ }
+ return 0;
+}
+
+int hash_map_size_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ long size = hash_map_size(h->h);
+ lua_pushinteger(L, size);
+ return 1;
+}
+
+int hash_map_to_tensor_lua(lua_State *L) {
+ hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
+ THLongTensor *keys, *vals;
+
+ if (lua_gettop(L) < 2) {
+ keys = THLongTensor_new();
+ }
+ else {
+ keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
+ check_tensor(L, keys, THLongTensor);
+ }
+
+ if (lua_gettop(L) < 3) {
+ vals = THLongTensor_new();
+ }
+ else {
+ vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
+ check_tensor(L, vals, THLongTensor);
+ }
+
+ size_t size = hash_map_size(h->h);
+ THLongTensor_resize1d(keys, size);
+ THLongTensor_resize1d(vals, size);
+
+ autolock(h);
+ hash_map_to_tensor(h->h, keys, vals);
+ autounlock(h);
+
+ if (lua_gettop(L) < 2)
+ luaT_pushudata(L, keys, "torch.LongTensor");
+ if (lua_gettop(L) < 3)
+ luaT_pushudata(L, vals, "torch.LongTensor");
+ return 2;
+}