123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445 |
- #include "utils.h"
- #include "hash_map.h"
- #include "internal_hash_map.h"
- #include <pthread.h>
-
- hash_map_t hash_map_init(void) {
- return kh_init(long);
- }
-
- void hash_map_destroy(hash_map_t h_) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- kh_destroy(long, h);
- }
-
- void hash_map_clear(hash_map_t h_) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- kh_clear(long, h);
- }
-
- int hash_map_put(hash_map_t h_, long key, long val) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- int ret;
- khiter_t k = kh_put(long, h, key, &ret);
- ret = (ret >= 0);
- if (ret)
- kh_value(h, k) = val;
- return ret;
- }
-
- int hash_map_put_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) {
- long *keys = THLongTensor_data(keys_);
- long *vals = THLongTensor_data(vals_);
- long size = get_tensor_size(keys_, Long);
- for (long i = 0; i < size; i++)
- if (!hash_map_put(h_, keys[i], vals[i]))
- return 0;
- return 1;
- }
-
- int hash_map_fill(hash_map_t h_, long key, long *counter) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- khiter_t k = kh_get(long, h, key);
- if (k == kh_end(h))
- return hash_map_put(h_, key, ++(*counter));
- return 1;
- }
-
- int hash_map_fill_tensor(hash_map_t h_, THLongTensor *keys_, long *counter) {
- long *keys = THLongTensor_data(keys_);
- long size = get_tensor_size(keys_, Long);
- for (long i = 0; i < size; i++)
- if (!hash_map_fill(h_, keys[i], counter))
- return 0;
- return 1;
- }
-
- int hash_map_get(hash_map_t h_, long key, long* val) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- khiter_t k = kh_get(long, h, key);
- if (k == kh_end(h))
- return 0;
- *val = kh_value(h, k);
- return 1;
- }
-
- void hash_map_get_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_, THByteTensor *mask_) {
- long *keys = THLongTensor_data(keys_);
- long *vals = THLongTensor_data(vals_);;
- unsigned char *mask = THByteTensor_data(mask_);
- long size = get_tensor_size(keys_, Long);
- for (long i = 0; i < size; i++)
- mask[i] = hash_map_get(h_, keys[i], &vals[i]);
- }
-
- void hash_map_del(hash_map_t h_, long key) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- khiter_t k = kh_get(long, h, key);
- if (k != kh_end(h))
- kh_del(long, h, k);
- }
-
- void hash_map_del_tensor(hash_map_t h_, THLongTensor *keys_) {
- long *keys = THLongTensor_data(keys_);
- long size = get_tensor_size(keys_, Long);
- for (long i = 0; i < size; i++)
- hash_map_del(h_, keys[i]);
- }
-
- size_t hash_map_size(hash_map_t h_) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- return kh_size(h);
- }
-
- void hash_map_to_tensor(hash_map_t h_, THLongTensor *keys_, THLongTensor *vals_) {
- internal_hash_map_t h = (internal_hash_map_t) h_;
- long *keys = THLongTensor_data(keys_);
- long *vals = THLongTensor_data(vals_);
- long key, val, i = 0;
- kh_foreach(h, key, val, {
- keys[i] = key;
- vals[i] = val;
- i++;
- });
- }
-
- static void autolock(hash_map_lua_t *h) {
- if (h->autolock) {
- pthread_mutex_lock(&h->mutex);
- }
- }
-
- static void autounlock(hash_map_lua_t *h) {
- if (h->autolock) {
- pthread_mutex_unlock(&h->mutex);
- }
- }
-
- int hash_map_autolock_on_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- h->autolock = 1;
- return 0;
- }
-
- int hash_map_autolock_off_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- h->autolock = 0;
- return 0;
- }
-
- int hash_map_init_lua(lua_State *L) {
- hash_map_lua_t **hp = (hash_map_lua_t**)lua_newuserdata(L, sizeof(hash_map_lua_t*));
- *hp = (hash_map_lua_t*)malloc(sizeof(hash_map_lua_t));
- hash_map_lua_t *h = *hp;
- h->refcount = 1;
- h->counter = 0;
- h->autolock = 0;
- h->h = hash_map_init();
-
- pthread_mutexattr_t mutex_attr;
- pthread_mutexattr_init(&mutex_attr);
- pthread_mutexattr_settype(&mutex_attr, PTHREAD_MUTEX_RECURSIVE);
- pthread_mutex_init(&h->mutex, &mutex_attr);
-
- luaL_getmetatable(L, "dt.HashMap");
- lua_setmetatable(L, -2);
- return 1;
- }
-
- int hash_map_gc_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- if (THAtomicDecrementRef(&h->refcount)) {
- pthread_mutex_destroy(&h->mutex);
- hash_map_destroy(h->h);
- free(h);
- }
- return 0;
- }
-
- int hash_map_retain_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- THAtomicIncrementRef(&h->refcount);
- return 0;
- }
-
- int hash_map_metatablename_lua(lua_State *L) {
- lua_pushstring(L, "dt.HashMap");
- return 1;
- }
-
- int hash_map_clear_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- autolock(h);
- hash_map_clear(h->h);
- autounlock(h);
- return 0;
- }
-
- int hash_map_put_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- int ret;
- #if LUA_VERSION_NUM <= 501
- #define lua_isinteger lua_isnumber
- #endif
- if (lua_isinteger(L, 2)) {
- if (!lua_isinteger(L, 3))
- return LUA_HANDLE_ERROR_STR(L, "second parameter is not a number");
- long key = lua_tointeger(L, 2);
- long val = lua_tointeger(L, 3);
- autolock(h);
- ret = hash_map_put(h->h, key, val);
- autounlock(h);
- }
- else {
- THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
- THLongTensor *vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
- check_tensor(L, keys, THLongTensor);
- check_tensor(L, vals, THLongTensor);
- check_tensors(L, keys, vals);
- autolock(h);
- ret = hash_map_put_tensor(h->h, keys, vals);
- autounlock(h);
- }
- if (!ret)
- return LUA_HANDLE_ERROR_STR(L, "failed to put into hash map");
- return 0;
- }
-
- int hash_map_fill_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- int ret;
- if (lua_isinteger(L, 2)) {
- long key = lua_tointeger(L, 2);
- autolock(h);
- ret = hash_map_fill(h->h, key, &h->counter);
- autounlock(h);
- }
- else {
- THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
- check_tensor(L, keys, THLongTensor);
- autolock(h);
- ret = hash_map_fill_tensor(h->h, keys, &h->counter);
- autounlock(h);
- }
- if (!ret)
- return LUA_HANDLE_ERROR_STR(L, "failed to fill into hash map");
- return 0;
- }
-
- int hash_map_adjust_counter_lua(lua_State *L) {
- hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
- internal_hash_map_t h = (internal_hash_map_t) h_->h;
-
- long val;
- kh_foreach_value(h, val, {
- if (val >= h_->counter)
- h_->counter = val;
- });
- return 0;
- }
-
- int hash_map_set_counter_lua(lua_State *L) {
- hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
- h_->counter = lua_tointeger(L, 2);
- return 0;
- }
-
- int hash_map_get_counter_lua(lua_State *L) {
- hash_map_lua_t *h_ = *(hash_map_lua_t**)lua_touserdata(L, 1);
- lua_pushinteger(L, h_->counter);
- return 1;
- }
-
- static int hash_map_get_tensor_lua(lua_State *L, hash_map_lua_t *h, int inplace) {
- THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
- check_tensor(L, keys, THLongTensor);
- THLongTensor *vals = inplace ? keys : NULL;
- THByteTensor *mask = NULL;
-
- int maskIdx = inplace ? 3 : 4;
-
- if (!inplace) {
- if (lua_gettop(L) < 3) {
- vals = THLongTensor_new();
- } else {
- vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
- check_tensor(L, vals, THLongTensor);
- }
- }
-
- if (lua_gettop(L) < maskIdx) {
- mask = THByteTensor_new();
- } else {
- mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
- check_tensor(L, mask, THByteTensor);
- }
-
- int n_dim = THLongTensor_nDimension(keys);
- THLongStorage *st = THLongStorage_newWithSize1(n_dim);
- for (int i = 0; i < n_dim; i++) {
- THLongStorage_set(st, i, THLongTensor_size(keys, i));
- }
- THByteTensor_resize(mask, st, NULL);
- if (!inplace) THLongTensor_resize(vals, st, NULL);
- THLongStorage_free(st);
-
- autolock(h);
- hash_map_get_tensor(h->h, keys, vals, mask);
- autounlock(h);
-
- if (!inplace && lua_gettop(L) < 3)
- luaT_pushudata(L, vals, "torch.LongTensor");
- if (lua_gettop(L) < maskIdx)
- luaT_pushudata(L, mask, "torch.ByteTensor");
-
- return 2;
- }
-
- static int hash_map_get_table_lua(lua_State *L, hash_map_lua_t *h, int inplace) {
- const int kidx = 2;
- const int vidx = inplace ? 2 : 3;
- const int midx = inplace ? 3 : 4;
- const int narg = lua_gettop(L);
-
- if (inplace) {
- if (narg < 3) {
- LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace requires two arguments.");
- }
- } else {
- if (narg < 4) {
- LUA_HANDLE_ERROR_STR(L, "HashMap.get requires three arguments.");
- }
- }
-
- int count = push_table_contents(L, kidx);
- verify_push_table_contents(L, vidx, count);
- verify_push_table_contents(L, midx, count);
-
- THLongTensor *keys;
- THLongTensor *vals;
- THByteTensor *mask;
- for (int i = count - 1; i >= 0; i--) {
- int maskIdx = i - count;
- int valIdx = maskIdx - count;
- int keyIdx = inplace ? valIdx : (valIdx - count);
-
- keys = (THLongTensor *)luaT_checkudata(L, keyIdx, "torch.LongTensor");
- check_tensor(L, keys, THLongTensor);
- if (inplace) {
- vals = keys;
- } else {
- vals = (THLongTensor *)luaT_checkudata(L, valIdx, "torch.LongTensor");
- }
- mask = (THByteTensor *)luaT_checkudata(L, maskIdx, "torch.ByteTensor");
-
- int n_dim = THLongTensor_nDimension(keys);
- THLongStorage *st = THLongStorage_newWithSize1(n_dim);
- for (int i = 0; i < n_dim; i++) {
- THLongStorage_set(st, i, THLongTensor_size(keys, i));
- }
- THByteTensor_resize(mask, st, NULL);
- THLongTensor_resize(vals, st, NULL);
- THLongStorage_free(st);
-
- autolock(h);
- hash_map_get_tensor(h->h, keys, vals, mask);
- autounlock(h);
- }
- lua_pop(L, (narg - 1) * count);
- return 2;
- }
-
- int hash_map_get_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- if (lua_isinteger(L, 2)) {
- long key = lua_tointeger(L, 2);
- long val;
- autolock(h);
- int ret = hash_map_get(h->h, key, &val);
- autounlock(h);
- if (ret) {
- lua_pushinteger(L, val);
- lua_pushinteger(L, 1);
- }
- else {
- lua_pushinteger(L, 0);
- lua_pushinteger(L, 0);
- }
- } else if (lua_istable(L, 2)) {
- return hash_map_get_table_lua(L, h, 0);
- } else {
- return hash_map_get_tensor_lua(L, h, 0);
- }
- return 2;
- }
-
- int hash_map_get_inplace_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- if (lua_isinteger(L, 2)) {
- LUA_HANDLE_ERROR_STR(L, "HashMap.getInplace does not support integer arguments.");
- } else if (lua_istable(L, 2)) {
- return hash_map_get_table_lua(L, h, 1);
- } else {
- return hash_map_get_tensor_lua(L, h, 1);
- }
- return 2;
- }
-
- int hash_map_del_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- if (lua_isinteger(L, 2)) {
- long key = lua_tointeger(L, 2);
- autolock(h);
- hash_map_del(h->h, key);
- autounlock(h);
- }
- else {
- THLongTensor *keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
- autolock(h);
- hash_map_del_tensor(h->h, keys);
- autounlock(h);
- }
- return 0;
- }
-
- int hash_map_size_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- long size = hash_map_size(h->h);
- lua_pushinteger(L, size);
- return 1;
- }
-
- int hash_map_to_tensor_lua(lua_State *L) {
- hash_map_lua_t *h = *(hash_map_lua_t**)lua_touserdata(L, 1);
- THLongTensor *keys, *vals;
-
- if (lua_gettop(L) < 2) {
- keys = THLongTensor_new();
- }
- else {
- keys = (THLongTensor *)luaT_checkudata(L, 2, "torch.LongTensor");
- check_tensor(L, keys, THLongTensor);
- }
-
- if (lua_gettop(L) < 3) {
- vals = THLongTensor_new();
- }
- else {
- vals = (THLongTensor *)luaT_checkudata(L, 3, "torch.LongTensor");
- check_tensor(L, vals, THLongTensor);
- }
-
- size_t size = hash_map_size(h->h);
- THLongTensor_resize1d(keys, size);
- THLongTensor_resize1d(vals, size);
-
- autolock(h);
- hash_map_to_tensor(h->h, keys, vals);
- autounlock(h);
-
- if (lua_gettop(L) < 2)
- luaT_pushudata(L, keys, "torch.LongTensor");
- if (lua_gettop(L) < 3)
- luaT_pushudata(L, vals, "torch.LongTensor");
- return 2;
- }
|