#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/Tensor.c"
#else

#include "luaG.h"

/*
 * Lua bindings for THTensor (type-generic: compiled once per real type).
 * Conventions throughout this file:
 *   - Lua indices are 1-based; C/TH indices are 0-based, hence the
 *     pervasive "-1" on checked arguments and "+1" on values pushed back.
 *   - Methods returning a view allocate a new THTensor header sharing the
 *     original storage (newWithTensor + narrow/select/transpose/unfold).
 */

static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
                                                         THStorage **storage_, ptrdiff_t *storageOffset_, THLongStorage **size_, THLongStorage **stride_);

static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride,
                                            THLongStorage **size_, THLongStorage **stride_);

/* size(dim) -> number, or size() -> torch.LongStorage copy of all sizes */
static int torch_Tensor_(size)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  if(lua_isnumber(L,2))
  {
    int dim = luaL_checkint(L, 2)-1;
    THArgCheck(dim >= 0 && dim < tensor->nDimension, 2, "dimension %d out of range of %dD tensor",
               dim+1, THTensor_(nDimension)(tensor));
    luaT_pushlong(L, tensor->size[dim]);
  }
  else
  {
    THLongStorage *size = THTensor_(newSizeOf)(tensor);
    luaT_pushudata(L, size, "torch.LongStorage");
  }
  return 1;
}

/* elementSize() -> size in bytes of one element of this tensor type */
static int torch_Tensor_(elementSize)(lua_State *L)
{
  luaT_pushinteger(L, THStorage_(elementSize)());
  return 1;
}

/* stride(dim) -> number, or stride() -> torch.LongStorage copy of all strides */
static int torch_Tensor_(stride)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  if(lua_isnumber(L,2))
  {
    int dim = luaL_checkint(L, 2)-1;
    THArgCheck(dim >= 0 && dim < tensor->nDimension, 2, "dimension %d out of range of %dD tensor",
               dim+1, THTensor_(nDimension)(tensor));
    luaT_pushlong(L, tensor->stride[dim]);
  }
  else
  {
    THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension);
    memmove(storage->data, tensor->stride, sizeof(long)*tensor->nDimension);
    luaT_pushudata(L, storage, "torch.LongStorage");
  }
  return 1;
}

/* nDimension() / dim() -> number of dimensions */
static int torch_Tensor_(nDimension)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  luaT_pushinteger(L, tensor->nDimension);
  return 1;
}

/* storage() -> underlying storage (retained so Lua owns a reference), or nil */
static int torch_Tensor_(storage)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  if(tensor->storage)
  {
    THStorage_(retain)(tensor->storage);
    luaT_pushudata(L, tensor->storage, torch_Storage);
  }
  else
    lua_pushnil(L);
  return 1;
}

/* storageOffset() -> 1-based offset into the storage (Lua convention) */
static int torch_Tensor_(storageOffset)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  luaT_pushinteger(L, tensor->storageOffset+1);
  return 1;
}

/*
 * Constructor. Two forms:
 *   - torch.Tensor{...}: build from a (possibly nested) Lua table; sizes are
 *     inferred by walking first elements, then all leaves are read in
 *     row-major order using an odometer-style counter over the dimensions.
 *   - torch.Tensor([tensor|storage[,offset]][,sizes[,strides]]): delegates
 *     argument parsing to c_readTensorStorageSizeStride.
 */
static int torch_Tensor_(new)(lua_State *L)
{
  THTensor *tensor;
  ptrdiff_t storageOffset;
  THLongStorage *size, *stride;

  if(lua_type(L, 1) == LUA_TTABLE)
  {
    ptrdiff_t i, j;
    THLongStorage *counter;
    ptrdiff_t si = 0;
    int dimension = 0;
    int is_finished = 0;

    lua_settop(L, 1);
    size = THLongStorage_new();

    /* infer sizes by descending through the first element of each level */
    while( (lua_type(L, -1) == LUA_TTABLE) && (lua_objlen(L, -1) > 0) )
    {
      THLongStorage_resize(size, dimension+1);
      size->data[dimension] = lua_objlen(L, -1);
      dimension++;
      lua_rawgeti(L, -1, 1);
    }
    lua_pop(L, 1);

    counter = THLongStorage_newWithSize(size->size);
    THLongStorage_fill(counter, 0);

    tensor = THTensor_(newWithSize)(size, NULL);

    if(size->size == 0)
      is_finished = 1;

    while(!is_finished)
    {
      if(!lua_istable(L, -1))
      {
        THLongStorage_free(size);
        THLongStorage_free(counter);
        THTensor_(free)(tensor);
        THError("invalid tensor definition");
      }

      if(lua_objlen(L, -1) != size->data[size->size-1])
      {
        THLongStorage_free(size);
        THLongStorage_free(counter);
        THTensor_(free)(tensor);
        THError("invalid tensor sizes");
      }

      /* read one innermost row into the storage */
      for(i = 0; i < size->data[size->size-1]; i++)
      {
        lua_rawgeti(L, -1, i+1);
        if(!lua_isnumber(L, -1))
        {
          THLongStorage_free(size);
          THLongStorage_free(counter);
          THTensor_(free)(tensor);
          THError("invalid element (not a number)");
        }
        THStorage_(set)(THTensor_(storage)(tensor), si++, luaG_(checkreal)(L, -1));
        lua_pop(L, 1);
      }

      if(size->size == 1)
        break;

      /* advance the odometer over the outer dimensions and re-descend */
      for(i = size->size-2; i >= 0; i--)
      {
        if(++counter->data[i] == size->data[i])
        {
          if(i == 0)
          {
            is_finished = 1;
            break;
          }
          else
          {
            counter->data[i] = 0;
            lua_pop(L, 1);
          }
        }
        else
        {
          lua_pop(L, 1);
          for(j = i; j < size->size-1; j++)
          {
            if(!lua_istable(L, -1))
            {
              THLongStorage_free(size);
              THLongStorage_free(counter);
              THTensor_(free)(tensor);
              THError("invalid tensor definition");
            }
            if(lua_objlen(L, -1) != size->data[j])
            {
              THLongStorage_free(size);
              THLongStorage_free(counter);
              THTensor_(free)(tensor);
              THError("invalid tensor sizes");
            }
            lua_rawgeti(L, -1, counter->data[j]+1);
          }
          break;
        }
      }
    }
    THLongStorage_free(size);
    THLongStorage_free(counter);
  }
  else
  {
    THStorage *storage;

    torch_Tensor_(c_readTensorStorageSizeStride)(L, 1, 1, 1, 1, 1,
                                                 &storage, &storageOffset, &size, &stride);

    tensor = THTensor_(newWithStorage)(storage, storageOffset, size, stride);

    THLongStorage_free(size);
    THLongStorage_free(stride);
  }

  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/* set(src|storage,...): make self view the given storage/size/stride; returns self */
static int torch_Tensor_(set)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
  THStorage *storage;
  ptrdiff_t storageOffset;
  THLongStorage *size, *stride;

  torch_Tensor_(c_readTensorStorageSizeStride)(L, 2, 1, 1, 1, 1,
                                               &storage, &storageOffset, &size, &stride);

  THTensor_(setStorage)(self, storage, storageOffset, size, stride);

  THLongStorage_free(size);
  THLongStorage_free(stride);

  lua_settop(L, 1);
  return 1;
}

/* clone() -> deep copy with fresh storage */
static int torch_Tensor_(clone)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
  self = THTensor_(newClone)(self);
  luaT_pushudata(L, self, torch_Tensor);
  return 1;
}

/* contiguous() -> self-like tensor guaranteed contiguous (may share storage) */
static int torch_Tensor_(contiguous)(lua_State *L)
{
  THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
  self = THTensor_(newContiguous)(self);
  luaT_pushudata(L, self, torch_Tensor);
  return 1;
}

/* Resize */

/* resizeAs(src): resize self to src's sizes; returns self */
static int torch_Tensor_(resizeAs)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THTensor *src = luaT_checkudata(L, 2, torch_Tensor);
  THTensor_(resizeAs)(tensor, src);
  lua_settop(L, 1);
  return 1;
}

/* resize(sizes...): resize self in place; returns self */
static int torch_Tensor_(resize)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THLongStorage *size, *stride;

  torch_Tensor_(c_readSizeStride)(L, 2, 0, &size, &stride);

  THTensor_(resize)(tensor, size, stride);

  THLongStorage_free(size);
  THLongStorage_free(stride);

  lua_settop(L, 1);
  return 1;
}

/* narrow(dim, firstIndex, size) -> view; range checks delegated to THTensor_(narrow) */
static int torch_Tensor_(narrow)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  int dimension = luaL_checkint(L, 2)-1;
  long firstIndex = luaL_checklong(L, 3)-1;
  long size = luaL_checklong(L, 4);

/*  THArgCheck( (dimension >= 0) && (dimension < tensor->nDimension), 2, "out of range");
    THArgCheck( (firstIndex >= 0) && (firstIndex < tensor->size[dimension]), 3, "out of range");
    THArgCheck( (size > 0) && (firstIndex+size <= tensor->size[dimension]), 4, "out of range");
*/
  tensor = THTensor_(newWithTensor)(tensor);
  THTensor_(narrow)(tensor, NULL, dimension, firstIndex, size);
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/*
 * sub(d0s, d0e [, d1s, d1e [, d2s, d2e [, d3s, d3e]]]) -> view.
 * Negative indices count from the end of the dimension.
 * Fix vs. original: each nDimension check now precedes the size[d] accesses
 * used by the negative-index adjustment, so passing ranges for a missing
 * dimension raises the intended argument error instead of reading size[]
 * out of bounds.
 */
static int torch_Tensor_(sub)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1;

  THArgCheck(tensor->nDimension > 0, 2, "invalid dimension");
  d0s = luaL_checklong(L, 2)-1;
  d0e = luaL_checklong(L, 3)-1;
  if(d0s < 0)
    d0s += tensor->size[0]+1;
  if(d0e < 0)
    d0e += tensor->size[0]+1;
  THArgCheck(d0s >= 0 && d0s < tensor->size[0], 2, "out of range");
  THArgCheck(d0e >= 0 && d0e < tensor->size[0], 3, "out of range");
  THArgCheck(d0e >= d0s, 3, "end smaller than beginning");

  if(!lua_isnone(L, 4))
  {
    THArgCheck(tensor->nDimension > 1, 4, "invalid dimension");
    d1s = luaL_checklong(L, 4)-1;
    d1e = luaL_checklong(L, 5)-1;
    if(d1s < 0)
      d1s += tensor->size[1]+1;
    if(d1e < 0)
      d1e += tensor->size[1]+1;
    THArgCheck(d1s >= 0 && d1s < tensor->size[1], 4, "out of range");
    THArgCheck(d1e >= 0 && d1e < tensor->size[1], 5, "out of range");
    THArgCheck(d1e >= d1s, 5, "end smaller than beginning");

    if(!lua_isnone(L, 6))
    {
      THArgCheck(tensor->nDimension > 2, 6, "invalid dimension");
      d2s = luaL_checklong(L, 6)-1;
      d2e = luaL_checklong(L, 7)-1;
      if(d2s < 0)
        d2s += tensor->size[2]+1;
      if(d2e < 0)
        d2e += tensor->size[2]+1;
      THArgCheck(d2s >= 0 && d2s < tensor->size[2], 6, "out of range");
      THArgCheck(d2e >= 0 && d2e < tensor->size[2], 7, "out of range");
      THArgCheck(d2e >= d2s, 7, "end smaller than beginning");

      if(!lua_isnone(L, 8))
      {
        THArgCheck(tensor->nDimension > 3, 8, "invalid dimension");
        d3s = luaL_checklong(L, 8)-1;
        d3e = luaL_checklong(L, 9)-1;
        if(d3s < 0)
          d3s += tensor->size[3]+1;
        if(d3e < 0)
          d3e += tensor->size[3]+1;
        THArgCheck(d3s >= 0 && d3s < tensor->size[3], 8, "out of range");
        THArgCheck(d3e >= 0 && d3e < tensor->size[3], 9, "out of range");
        THArgCheck(d3e >= d3s, 9, "end smaller than beginning");
      }
    }
  }

  tensor = THTensor_(newWithTensor)(tensor);
  THTensor_(narrow)(tensor, NULL, 0, d0s, d0e-d0s+1);
  if(d1s >= 0)
    THTensor_(narrow)(tensor, NULL, 1, d1s, d1e-d1s+1);
  if(d2s >= 0)
    THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1);
  if(d3s >= 0)
    THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1);
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/* select(dim, index) -> (n-1)-D view; on a 1-D tensor returns the scalar value */
static int torch_Tensor_(select)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  int dimension = luaL_checkint(L, 2)-1;
  long sliceIndex = luaL_checklong(L, 3)-1;

/*   THArgCheck(src->nDimension > 1, 1, "cannot select on a vector");
     THArgCheck((dimension >= 0) && (dimension < src->nDimension), 2, "out of range");
     THArgCheck((sliceIndex >= 0) && (sliceIndex < src->size[dimension]), 3, "out of range");
*/
  if(tensor->nDimension > 1)
  {
    tensor = THTensor_(newWithTensor)(tensor);
    THTensor_(select)(tensor, NULL, dimension, sliceIndex);
    luaT_pushudata(L, tensor, torch_Tensor);
  }
  else
  {
    THArgCheck(tensor->nDimension == 1, 1, "empty Tensor");
    luaG_(pushreal)(L, THTensor_(get1d)(tensor, sliceIndex));
  }

  return 1;
}

#ifndef TH_REAL_IS_HALF

/* index(src, dim, longTensor) or index(dst, src, dim, longTensor) -> gathered rows */
static int torch_Tensor_(indexSelect)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor, *src;
  THLongTensor *index;
  int dim;
  if (narg == 3)
  {
    /* no destination given: allocate one and leave it on the stack */
    tensor = THTensor_(new)();
    src = luaT_checkudata(L, 1, torch_Tensor);
    dim = luaL_checkint(L, 2) - 1;
    index = luaT_checkudata(L, 3, "torch.LongTensor");
    luaT_pushudata(L,tensor,torch_Tensor);
  }
  else if(narg == 4)
  {
    src = luaT_checkudata(L, 2, torch_Tensor);
    dim = luaL_checkint(L, 3) - 1;
    index = luaT_checkudata(L, 4, "torch.LongTensor");
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError(torch_Tensor ", number, torch.LongTensor | " torch_Tensor ", " torch_Tensor ", number, torch.LongTensor expected");
    return 0;
  }

  THTensor_(indexSelect)(tensor,src,dim,index);

  return 1;
}

/* indexCopy(dim, longTensor, src): copy src rows into self at the given indices */
static int torch_Tensor_(indexCopy)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor, *src;
  THLongTensor *index;
  int dim;
  if(narg == 4)
  {
    dim = luaL_checkint(L, 2) - 1;
    index = luaT_checkudata(L, 3, "torch.LongTensor");
    src = luaT_checkudata(L, 4, torch_Tensor);
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError( torch_Tensor ", number, torch.LongTensor, " torch_Tensor " expected");
    return 0;
  }

  THTensor_(indexCopy)(tensor,dim,index,src);

  return 1;
}

/* indexAdd(dim, longTensor, src): accumulate src rows into self at the given indices */
static int torch_Tensor_(indexAdd)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor, *src;
  THLongTensor *index;
  int dim;
  if(narg == 4)
  {
    dim = luaL_checkint(L, 2) - 1;
    index = luaT_checkudata(L, 3, "torch.LongTensor");
    src = luaT_checkudata(L, 4, torch_Tensor);
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError( torch_Tensor ", number, torch.LongTensor, " torch_Tensor " expected");
    return 0;
  }

  THTensor_(indexAdd)(tensor,dim,index,src);

  return 1;
}

/* indexFill(dim, longTensor, value): fill the indexed rows of self with value */
static int torch_Tensor_(indexFill)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor;
  THLongTensor *index;
  real val;
  int dim;
  if(narg == 4)
  {
    dim = luaL_checkint(L, 2) - 1;
    index = luaT_checkudata(L, 3, "torch.LongTensor");
    val = luaG_(checkreal)(L, 4);
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError( torch_Tensor ", number, torch.LongTensor, number expected");
    return 0;
  }

  THTensor_(indexFill)(tensor,dim,index,val);

  return 1;
}

/* maskedSelect(src, byteMask) or maskedSelect(dst, src, byteMask) -> 1-D selection */
static int torch_Tensor_(maskedSelect)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor, *src;
  THByteTensor *mask;
  if (narg == 2)
  {
    /* no destination given: allocate one and leave it on the stack */
    tensor = THTensor_(new)();
    src = luaT_checkudata(L, 1, torch_Tensor);
    mask = luaT_checkudata(L, 2, "torch.ByteTensor");
    luaT_pushudata(L,tensor,torch_Tensor);
  }
  else if(narg == 3)
  {
    src = luaT_checkudata(L, 2, torch_Tensor);
    mask = luaT_checkudata(L, 3, "torch.ByteTensor");
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError( torch_Tensor ", torch.ByteTensor | " torch_Tensor ", " torch_Tensor ", torch.ByteTensor expected");
    return 0;
  }

  THTensor_(maskedSelect)(tensor,src,mask);

  return 1;
}

/* maskedCopy(byteMask, src): copy src values into self where mask is nonzero */
static int torch_Tensor_(maskedCopy)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor, *src;
  THByteTensor *mask;
  if(narg == 3)
  {
    mask = luaT_checkudata(L, 2, "torch.ByteTensor");
    src = luaT_checkudata(L, 3, torch_Tensor);
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError( torch_Tensor ", torch.ByteTensor, " torch_Tensor " expected");
    return 0;
  }

  THTensor_(maskedCopy)(tensor,mask,src);

  /* return destination */
  lua_pop(L, 2);

  return 1;
}

/* maskedFill(byteMask, value): fill self with value where mask is nonzero */
static int torch_Tensor_(maskedFill)(lua_State *L)
{
  int narg = lua_gettop(L);
  THTensor *tensor;
  THByteTensor *mask;
  real val;
  if(narg == 3)
  {
    mask = luaT_checkudata(L, 2, "torch.ByteTensor");
    val = luaG_(checkreal)(L, 3);
    tensor = luaT_checkudata(L,1,torch_Tensor);
  }
  else
  {
    THError( torch_Tensor ", torch.ByteTensor, number expected");
    return 0;
  }

  THTensor_(maskedFill)(tensor,mask,val);

  return 1;
}
#endif

/* transpose(dim1, dim2) -> view with the two dimensions swapped */
static int torch_Tensor_(transpose)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  int dimension1 = luaL_checkint(L, 2)-1;
  int dimension2 = luaL_checkint(L, 3)-1;

/*
  THArgCheck( (dimension1 >= 0) && (dimension1 < src->nDimension), 2, "out of range");
  THArgCheck( (dimension2 >= 0) && (dimension2 < src->nDimension), 3, "out of range");
*/
  tensor = THTensor_(newWithTensor)(tensor);
  THTensor_(transpose)(tensor, NULL, dimension1, dimension2);
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/* t() -> 2-D transpose view; errors unless the tensor is 2-D */
static int torch_Tensor_(t)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);

  THArgCheck(tensor->nDimension == 2, 1, "Tensor must have 2 dimensions");

  tensor = THTensor_(newWithTensor)(tensor);
  THTensor_(transpose)(tensor, NULL, 0, 1);
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/* unfold(dim, size, step) -> view of all slices of the given size/step */
static int torch_Tensor_(unfold)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  int dimension = luaL_checkint(L, 2)-1;
  long size = luaL_checklong(L, 3);
  long step = luaL_checklong(L, 4);

/*
  THArgCheck( (src->nDimension > 0), 1, "cannot unfold an empty tensor");
  THArgCheck(dimension < src->nDimension, 2, "out of range");
  THArgCheck(size <= src->size[dimension], 3, "out of range");
*/
  tensor = THTensor_(newWithTensor)(tensor);
  THTensor_(unfold)(tensor, NULL, dimension, size, step);
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/* is contiguous? [a bit like in TnXIterator] */
static int torch_Tensor_(isContiguous)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  lua_pushboolean(L, THTensor_(isContiguous)(tensor));
  return 1;
}

/* isSize(longStorage) -> boolean: do self's sizes match the given storage? */
static int torch_Tensor_(isSize)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THLongStorage *size = luaT_checkudata(L, 2, "torch.LongStorage");
  lua_pushboolean(L, THTensor_(isSize)(tensor, size));
  return 1;
}

/* isSameSizeAs(other) -> boolean */
static int torch_Tensor_(isSameSizeAs)(lua_State *L)
{
  THTensor *tensor1 = luaT_checkudata(L, 1, torch_Tensor);
  THTensor *tensor2 = luaT_checkudata(L, 2, torch_Tensor);
  lua_pushboolean(L, THTensor_(isSameSizeAs)(tensor1, tensor2));
  return 1;
}

/* isSetTo(other) -> boolean: same storage/offset/size/stride? */
static int torch_Tensor_(isSetTo)(lua_State *L)
{
  THTensor *tensor1 = luaT_checkudata(L, 1, torch_Tensor);
  THTensor *tensor2 = luaT_checkudata(L, 2, torch_Tensor);
  lua_pushboolean(L, THTensor_(isSetTo)(tensor1, tensor2));
  return 1;
}

/* nElement() -> total number of elements */
static int torch_Tensor_(nElement)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  luaT_pushinteger(L, THTensor_(nElement)(tensor));
  return 1;
}

/* copy(src): element-wise copy from any tensor type (with conversion); returns self */
static int torch_Tensor_(copy)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  void *src;
  if( (src = luaT_toudata(L, 2, torch_Tensor)) )
    THTensor_(copy)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
    THTensor_(copyByte)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) )
    THTensor_(copyChar)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) )
    THTensor_(copyShort)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) )
    THTensor_(copyInt)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) )
    THTensor_(copyLong)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
    THTensor_(copyFloat)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
    THTensor_(copyDouble)(tensor, src);
  else if( (src = luaT_toudata(L, 2, "torch.HalfTensor")) )
    THTensor_(copyHalf)(tensor, src);
  else
    luaL_typerror(L, 2, "torch.*Tensor");
  lua_settop(L, 1);
  return 1;
}

/*
 * __newindex__ metamethod: t[k] = v.
 * Handles k as number (row assignment / scalar set), LongStorage (element set),
 * table (mixed index/range assignment), or ByteTensor (masked fill/copy).
 * Pushes true when the write was handled here, false to fall back to rawset.
 */
static int torch_Tensor_(__newindex__)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THLongStorage *idx = NULL;
  THByteTensor *mask;

  if(lua_isnumber(L, 2))
  {
    void *src;
    long index = luaL_checklong(L,2)-1;
    THArgCheck(tensor->nDimension > 0, 1, "empty tensor");
    if (index < 0) index = tensor->size[0] + index + 1;

    if (lua_isnumber(L,3)) {
      real value = luaG_(checkreal)(L,3);
      if (tensor->nDimension == 1) {
        THArgCheck(index >= 0 && index < tensor->size[0], 2, "out of range");
        THStorage_(set)(tensor->storage, tensor->storageOffset+index*tensor->stride[0], value);
      } else {
#ifndef TH_REAL_IS_HALF
        tensor = THTensor_(newWithTensor)(tensor);
        THTensor_(narrow)(tensor, NULL, 0, index, 1);
        THTensor_(fill)(tensor, value);
        THTensor_(free)(tensor);
#else
        THError("fill on torch.HalfTensor not yet supported");
#endif
      }
    } else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copy)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyByte)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.CharTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyChar)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.ShortTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyShort)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.IntTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyInt)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.LongTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyLong)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.FloatTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyFloat)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyDouble)(tensor, src);
      THTensor_(free)(tensor);
    } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(narrow)(tensor, NULL, 0, index, 1);
      THTensor_(copyHalf)(tensor, src);
      THTensor_(free)(tensor);
    } else {
      luaL_typerror(L, 3, "torch.*Tensor");
    }
    lua_pushboolean(L, 1);
  }
  else if((idx = luaT_toudata(L, 2, "torch.LongStorage")))
  {
    ptrdiff_t index = THTensor_(storageOffset)(tensor);
    real value = luaG_(checkreal)(L,3);
    int dim;

    THArgCheck(idx->size == tensor->nDimension, 2, "invalid size");

    for(dim = 0; dim < idx->size; dim++)
    {
      long z = idx->data[dim]-1;
      if (z < 0) z = tensor->size[dim] + z + 1;
      THArgCheck((z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
      index += z*tensor->stride[dim];
    }

    THStorage_(set)(tensor->storage, index, value);
    lua_pushboolean(L, 1);
  }
  else if(lua_istable(L, 2))
  {
    int dim;
    int cdim = 0;
    int ndims;
    int done = 0;
    ndims = tensor->nDimension;
    THArgCheck(lua_objlen(L, 2) <= ndims, 2, "too many indices provided");
    tensor = THTensor_(newWithTensor)(tensor);
    /* apply each table entry: number -> select, {start[,end]} -> narrow */
    for(dim = 0; dim < ndims; dim++)
    {
      lua_rawgeti(L, 2, dim+1);
      if(lua_isnumber(L, -1))
      {
        long z = lua_tonumber(L, -1)-1;
        lua_pop(L, 1);
        if (z < 0) z = tensor->size[cdim] + z + 1;
        THArgCheck((z >= 0) && (z < tensor->size[cdim]), 2, "index out of bound");
        if(tensor->nDimension == 1) {
          real value = luaG_(checkreal)(L,3);
          done = 1;
          THStorage_(set)(tensor->storage, tensor->storageOffset+z*tensor->stride[0], value);
        } else {
          THTensor_(select)(tensor, NULL, cdim, z);
        }
      }
      else if (lua_istable(L, -1))
      {
        long start = 0;
        long end = tensor->size[cdim]-1;
        lua_rawgeti(L, -1, 1);
        if(lua_isnumber(L, -1)) {
          start = lua_tonumber(L, -1)-1;
          end = start;
        }
        lua_pop(L, 1);
        if (start < 0) start = tensor->size[cdim] + start + 1;
        THArgCheck((start >= 0) && (start < tensor->size[cdim]), 2, "start index out of bound");

        lua_rawgeti(L, -1, 2);
        if(lua_isnumber(L, -1)) {
          end = lua_tonumber(L, -1)-1;
        }
        lua_pop(L, 2);
        if (end < 0) end = tensor->size[cdim] + end + 1;
        THArgCheck((end >= 0) && (end < tensor->size[cdim]), 2, "end index out of bound");

        THArgCheck((end >= start), 2, "end index must be greater or equal to start index");

        THTensor_(narrow)(tensor, NULL, cdim++, start, end-start+1);
      }
      else
      {
        break;
      }
    }
    if(!done)
    {
      /* doing a copy */
      void *src;
      if (lua_isnumber(L,3))
      {
#ifndef TH_REAL_IS_HALF
        THTensor_(fill)(tensor, LUA_NUMBER_TO_REAL(lua_tonumber(L,3)));
#else
        THError("fill on torch.HalfTensor not yet supported");
#endif
      }
      else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) {
        THTensor_(copy)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) {
        THTensor_(copyByte)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.CharTensor")) ) {
        THTensor_(copyChar)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.ShortTensor")) ) {
        THTensor_(copyShort)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.IntTensor")) ) {
        THTensor_(copyInt)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.LongTensor")) ) {
        THTensor_(copyLong)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.FloatTensor")) ) {
        THTensor_(copyFloat)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) {
        THTensor_(copyDouble)(tensor, src);
      } else if( (src = luaT_toudata(L, 3, "torch.HalfTensor")) ) {
        THTensor_(copyHalf)(tensor, src);
      } else {
        luaL_typerror(L, 3, "torch.*Tensor");
      }
    }
    THTensor_(free)(tensor);
    lua_pushboolean(L, 1);
  }
  else if((mask = luaT_toudata(L, 2, "torch.ByteTensor")))
  {
#ifndef TH_REAL_IS_HALF
    THTensor *vals;
    if (lua_isnumber(L, 3))
    {
      THTensor_(maskedFill)(tensor, mask, luaG_(checkreal)(L,3));
    }
    else if((vals = luaT_toudata(L, 3, torch_Tensor)))
    {
      THTensor_(maskedCopy)(tensor, mask, vals);
    }
    else
    {
      THError("number or " torch_Tensor " expected");
    }
#else
    THError("ByteTensor indexing not yet supported with half types");
#endif
  }
  else
    lua_pushboolean(L, 0);

  return 1;
}

/*
 * __index__ metamethod: t[k].
 * Mirrors __newindex__'s dispatch for reads. Returns (value, true) when
 * handled here, or false alone to fall back to the method table.
 */
static int torch_Tensor_(__index__)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THLongStorage *idx = NULL;
  THByteTensor *mask;

  if(lua_isnumber(L, 2))
  {
    long index = luaL_checklong(L,2)-1;

    THArgCheck(tensor->nDimension > 0, 1, "empty tensor");
    if (index < 0) index = tensor->size[0] + index + 1;
    THArgCheck(index >= 0 && index < tensor->size[0], 2, "out of range");

    if(tensor->nDimension == 1)
    {
      luaG_(pushreal)(L, THStorage_(get)(tensor->storage, tensor->storageOffset+index*tensor->stride[0]));
    }
    else
    {
      tensor = THTensor_(newWithTensor)(tensor);
      THTensor_(select)(tensor, NULL, 0, index);
      luaT_pushudata(L, tensor, torch_Tensor);
    }
    lua_pushboolean(L, 1);
    return 2;
  }
  else if((idx = luaT_toudata(L, 2, "torch.LongStorage")))
  {
    ptrdiff_t index = THTensor_(storageOffset)(tensor);
    int dim;

    THArgCheck(idx->size == tensor->nDimension, 2, "invalid size");

    for(dim = 0; dim < idx->size; dim++)
    {
      long z = idx->data[dim]-1;
      if (z < 0) z = tensor->size[dim] + z + 1;
      THArgCheck((z >= 0) && (z < tensor->size[dim]), 2, "index out of bound");
      index += z*tensor->stride[dim];
    }
    luaG_(pushreal)(L, THStorage_(get)(THTensor_(storage)(tensor), index));
    lua_pushboolean(L, 1);
    return 2;
  }
  else if(lua_istable(L, 2))
  {
    int dim;
    int cdim = 0;
    int ndims;
    int done = 0;

    ndims = tensor->nDimension;
    THArgCheck(lua_objlen(L, 2) <= ndims, 2, "too many indices provided");
    tensor = THTensor_(newWithTensor)(tensor);

    /* apply each table entry: number -> select, {start[,end]} -> narrow */
    for(dim = 0; dim < ndims; dim++)
    {
      lua_rawgeti(L, 2, dim+1);
      if(lua_isnumber(L, -1))
      {
        long z = lua_tonumber(L, -1)-1;
        lua_pop(L, 1);
        if (z < 0) z = tensor->size[cdim] + z + 1;
        THArgCheck((z >= 0) && (z < tensor->size[cdim]), 2, "index out of bound");
        if(tensor->nDimension == 1) {
          done = 1;
          luaG_(pushreal)(L, THStorage_(get)(tensor->storage, tensor->storageOffset+z*tensor->stride[0]));
        } else {
          THTensor_(select)(tensor, NULL, cdim, z);
        }
      }
      else if (lua_istable(L, -1))
      {
        long start = 0;
        long end = tensor->size[cdim]-1;
        lua_rawgeti(L, -1, 1);
        if(lua_isnumber(L, -1)) {
          start = lua_tonumber(L, -1)-1;
          end = start;
        }
        lua_pop(L, 1);
        if (start < 0) start = tensor->size[cdim] + start + 1;
        THArgCheck((start >= 0) && (start < tensor->size[cdim]), 2, "start index out of bound");

        lua_rawgeti(L, -1, 2);
        if(lua_isnumber(L, -1)) {
          end = lua_tonumber(L, -1)-1;
        }
        lua_pop(L, 2);
        if (end < 0) end = tensor->size[cdim] + end + 1;
        THArgCheck((end >= 0) && (end < tensor->size[cdim]), 2, "end index out of bound");

        THArgCheck((end >= start), 2, "end index must be greater or equal to start index");

        THTensor_(narrow)(tensor, NULL, cdim++, start, end-start+1);
      }
      else
      {
        break;
      }
    }
    if(!done)
    {
      luaT_pushudata(L, tensor, torch_Tensor);
    }
    else
    {
      THTensor_(free)(tensor);
    }
    lua_pushboolean(L, 1);
    return 2;
  }
  else if((mask = luaT_toudata(L, 2, "torch.ByteTensor")))
  {
#ifndef TH_REAL_IS_HALF
    THTensor *vals = THTensor_(new)();
    THTensor_(maskedSelect)(vals, tensor, mask);
    luaT_pushudata(L, vals, torch_Tensor);
    lua_pushboolean(L, 1);
    return 2;
#else
    THError("ByteTensor based indexing not yet supported with half type");
    return 0;
#endif
  }
  else
  {
    lua_pushboolean(L, 0);
    return 1;
  }
}

/* retain(): bump the tensor's reference count */
static int torch_Tensor_(retain)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THTensor_(retain)(tensor);
  return 0;
}

/* free(): drop one reference (also used as the metatable destructor) */
static int torch_Tensor_(free)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THTensor_(free)(tensor);
  return 0;
}

/* helpful functions */

/*
 * Parse size/stride arguments starting at `index`.
 * Either a single torch.LongStorage (optionally followed by a stride
 * LongStorage when allowStride), or up to 8 numeric sizes (interleaved with
 * strides when allowStride). Unused slots are filled with -1.
 * On return the caller owns *size_ and *stride_ (both retained/allocated).
 */
static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowStride, THLongStorage **size_, THLongStorage **stride_)
{
  THLongStorage *size = NULL;
  THLongStorage *stride = NULL;

  if( (size = luaT_toudata(L, index, "torch.LongStorage")) )
  {
    if(!lua_isnoneornil(L, index+1))
    {
      if( (stride = luaT_toudata(L, index+1, "torch.LongStorage")) )
        THArgCheck(stride->size == size->size, index+1, "provided stride and size are inconsistent");
      else
        THArgCheck(0, index+1, "torch.LongStorage expected");
    }
    THLongStorage_retain(size);
    if(stride)
      THLongStorage_retain(stride);
  }
  else
  {
    int i;

    /* 8 is the historical maximum number of dimensions accepted here */
    size = THLongStorage_newWithSize(8);
    stride = THLongStorage_newWithSize(8);
    THLongStorage_fill(size, -1);
    THLongStorage_fill(stride, -1);

    if(allowStride)
    {
      for(i = 0; i < 8; i++)
      {
        if(lua_isnone(L, index+2*i))
          break;
        size->data[i] = luaL_checklong(L, index+2*i);

        if(lua_isnone(L, index+2*i+1))
          break;
        stride->data[i] = luaL_checklong(L, index+2*i+1);
      }
    }
    else
    {
      for(i = 0; i < 8; i++)
      {
        if(lua_isnone(L, index+i))
          break;
        size->data[i] = luaL_checklong(L, index+i);
      }
    }
  }

  *size_ = size;
  *stride_ = stride;
}

/*
 * Parse the full constructor/set argument list starting at `index`:
 * nothing (allowNone), a tensor (allowTensor), a storage with optional
 * offset/size/stride (allowStorage), or plain sizes/strides. Outputs are
 * written through the pointer parameters; *size_ / *stride_ are owned by
 * the caller (or NULL), *storage_ is borrowed (not retained here).
 */
static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index, int allowNone, int allowTensor, int allowStorage, int allowStride,
                                                         THStorage **storage_, ptrdiff_t *storageOffset_, THLongStorage **size_, THLongStorage **stride_)
{
  THTensor *src = NULL;
  THStorage *storage = NULL;

  int arg1Type = lua_type(L, index);

  if( allowNone && (arg1Type == LUA_TNONE) )
  {
    *storage_ = NULL;
    *storageOffset_ = 0;
    *size_ = NULL;
    *stride_ = NULL;
    return;
  }
  else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor)) )
  {
    *storage_ = src->storage;
    *storageOffset_ = src->storageOffset;
    *size_ = THTensor_(newSizeOf)(src);
    *stride_ = THTensor_(newStrideOf)(src);
    return;
  }
  else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage)) )
  {
    *storage_ = storage;
    if(lua_isnone(L, index+1))
    {
      *storageOffset_ = 0;
      *size_ = THLongStorage_newWithSize1(storage->size);
      *stride_ = THLongStorage_newWithSize1(1);
    }
    else
    {
      *storageOffset_ = luaL_checkinteger(L, index+1)-1;
      torch_Tensor_(c_readSizeStride)(L, index+2, allowStride, size_, stride_);
    }
    return;
  }
  else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, "torch.LongStorage")) )
  {
    *storage_ = NULL;
    *storageOffset_ = 0;
    torch_Tensor_(c_readSizeStride)(L, index, 0, size_, stride_);

    return;
  }

  *storage_ = NULL;
  *storageOffset_ = 0;
  if(allowTensor && allowStorage)
      THArgCheck(0, index, "expecting number or " torch_Tensor " or " torch_Storage );
  else if(allowTensor)
      THArgCheck(0, index, "expecting number or " torch_Tensor );
  else if(allowStorage)
      THArgCheck(0, index, "expecting number or " torch_Storage );
  else
      THArgCheck(0, index, "expecting number");
}

#ifndef TH_REAL_IS_HALF
/* apply(f): replace each element with f(element); f may return nil to keep it */
static int torch_Tensor_(apply)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  luaL_checktype(L, 2, LUA_TFUNCTION);
  lua_settop(L, 2);

  TH_TENSOR_APPLY(real, tensor,
                  lua_pushvalue(L, 2);
                  luaG_(pushreal)(L, *tensor_data);
                  lua_call(L, 1, 1);
                  if(lua_isnumber(L, 3))
                  {
                    *tensor_data = luaG_(checkreal)(L, 3);
                    lua_pop(L, 1);
                  }
                  else if(lua_isnil(L, 3))
                    lua_pop(L, 1);
                  else
                    THError("given function should return a number or nil"););

  lua_settop(L, 1);
  return 1;
}

/* map(src, f): self[i] = f(self[i], src[i]); f may return nil to keep self[i] */
static int torch_Tensor_(map)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THTensor *src = luaT_checkudata(L, 2, torch_Tensor);
  luaL_checktype(L, 3, LUA_TFUNCTION);
  lua_settop(L, 3);

  TH_TENSOR_APPLY2(real, tensor, real, src,
                  lua_pushvalue(L, 3);
                  luaG_(pushreal)(L, *tensor_data);
                  luaG_(pushreal)(L, *src_data);
                  lua_call(L, 2, 1);
                  if(lua_isnumber(L, 4))
                  {
                    *tensor_data = luaG_(checkreal)(L, 4);
                    lua_pop(L, 1);
                  }
                  else if(lua_isnil(L, 4))
                    lua_pop(L, 1);
                  else
                    THError("given function should return a number or nil"););

  lua_settop(L, 1);
  return 1;
}

/* map2(src1, src2, f): self[i] = f(self[i], src1[i], src2[i]); nil keeps self[i] */
static int torch_Tensor_(map2)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor);
  luaL_checktype(L, 4, LUA_TFUNCTION);
  lua_settop(L, 4);

  TH_TENSOR_APPLY3(real, tensor, real, src1, real, src2,
                  lua_pushvalue(L, 4);
                  luaG_(pushreal)(L, *tensor_data);
                  luaG_(pushreal)(L, *src1_data);
                  luaG_(pushreal)(L, *src2_data);
                  lua_call(L, 3, 1);
                  if(lua_isnumber(L, 5))
                  {
                    *tensor_data = luaG_(checkreal)(L, 5);
                    lua_pop(L, 1);
                  }
                  else if(lua_isnil(L, 5))
                    lua_pop(L, 1);
                  else
                    THError("given function should return a number or nil"););

  lua_settop(L, 1);
  return 1;
}
#endif

/* factory used by the serialization machinery to create an empty tensor */
static int torch_Tensor_(factory)(lua_State *L)
{
  THTensor *tensor = THTensor_(new)();
  luaT_pushudata(L, tensor, torch_Tensor);
  return 1;
}

/* write(file): serialize header (nDim, sizes, strides, 1-based offset) then the storage */
static int torch_Tensor_(write)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THFile *file = luaT_checkudata(L, 2, "torch.File");

  THFile_writeIntScalar(file, tensor->nDimension);
  THFile_writeLongRaw(file, tensor->size, tensor->nDimension);
  THFile_writeLongRaw(file, tensor->stride, tensor->nDimension);
  THFile_writeLongScalar(file, tensor->storageOffset+1); /* to respect Lua convention */

  lua_getfield(L, 2, "writeObject"); /* the method */
  lua_pushvalue(L, 2); /* the file */
  /* the storage */
  if(tensor->storage)
  {
    THStorage_(retain)(tensor->storage);
    luaT_pushudata(L, tensor->storage, torch_Storage);
  }
  else
    lua_pushnil(L);

  lua_call(L, 2, 0); /* call the method */

  return 0;
}

/* read(file): deserialize into self (inverse of write) */
static int torch_Tensor_(read)(lua_State *L)
{
  THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
  THFile *file = luaT_checkudata(L, 2, "torch.File");

  tensor->nDimension = THFile_readIntScalar(file);
  /* NOTE(review): if this tensor already owned size/stride/storage, the old
     buffers are not released before being overwritten here — presumably read
     is only ever called on a freshly-factory'd tensor; verify against the
     serialization caller before changing. */
  tensor->size = THAlloc(sizeof(long)*tensor->nDimension);
  tensor->stride = THAlloc(sizeof(long)*tensor->nDimension);
  THFile_readLongRaw(file, tensor->size, tensor->nDimension);
  THFile_readLongRaw(file, tensor->stride, tensor->nDimension);
  tensor->storageOffset = THFile_readLongScalar(file);
  tensor->storageOffset--;  /* to respect Lua convention */

  lua_getfield(L, 2, "readObject"); /* the method */
  lua_pushvalue(L, 2); /* the file */
  lua_call(L, 1, 1); /* call the method */

  tensor->storage = luaT_toudata(L, -1, torch_Storage);
  if(tensor->storage)
    THStorage_(retain)(tensor->storage);

  return 0;
}

/* Method table. Fix vs. original: "contiguous" was registered twice; the
   duplicate entry has been removed (same function either way). */
static const struct luaL_Reg torch_Tensor_(_) [] = {
  {"retain", torch_Tensor_(retain)},
  {"free", torch_Tensor_(free)},
  {"contiguous", torch_Tensor_(contiguous)},
  {"size", torch_Tensor_(size)},
  {"elementSize", torch_Tensor_(elementSize)},
  {"__len__", torch_Tensor_(size)},
  {"stride", torch_Tensor_(stride)},
  {"dim", torch_Tensor_(nDimension)},
  {"nDimension", torch_Tensor_(nDimension)},
  {"set", torch_Tensor_(set)},
  {"storage", torch_Tensor_(storage)},
  {"storageOffset", torch_Tensor_(storageOffset)},
  {"clone", torch_Tensor_(clone)},
  {"resizeAs", torch_Tensor_(resizeAs)},
  {"resize", torch_Tensor_(resize)},
  {"narrow", torch_Tensor_(narrow)},
  {"sub", torch_Tensor_(sub)},
  {"select", torch_Tensor_(select)},
#ifndef TH_REAL_IS_HALF
  {"index", torch_Tensor_(indexSelect)},
  {"indexCopy", torch_Tensor_(indexCopy)},
  {"indexAdd", torch_Tensor_(indexAdd)},
  {"indexFill", torch_Tensor_(indexFill)},
  {"maskedSelect", torch_Tensor_(maskedSelect)},
  {"maskedCopy", torch_Tensor_(maskedCopy)},
  {"maskedFill", torch_Tensor_(maskedFill)},
#endif
  {"transpose", torch_Tensor_(transpose)},
  {"t", torch_Tensor_(t)},
  {"unfold", torch_Tensor_(unfold)},
  {"isContiguous", torch_Tensor_(isContiguous)},
  {"isSameSizeAs", torch_Tensor_(isSameSizeAs)},
  {"isSetTo", torch_Tensor_(isSetTo)},
  {"isSize", torch_Tensor_(isSize)},
  {"nElement", torch_Tensor_(nElement)},
  {"copy", torch_Tensor_(copy)},
#ifndef TH_REAL_IS_HALF
  {"apply", torch_Tensor_(apply)},
  {"map", torch_Tensor_(map)},
  {"map2", torch_Tensor_(map2)},
#endif
  {"read", torch_Tensor_(read)},
  {"write", torch_Tensor_(write)},
  {"__index__", torch_Tensor_(__index__)},
  {"__newindex__", torch_Tensor_(__newindex__)},
  {NULL, NULL}
};

/* register the metatable, methods and (non-half) vector dispatch for this type */
void torch_Tensor_(init)(lua_State *L)
{
  luaT_newmetatable(L, torch_Tensor, NULL,
                    torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
  luaT_setfuncs(L, torch_Tensor_(_), 0);
  lua_pop(L, 1);
#ifndef TH_REAL_IS_HALF
  THVector_(vectorDispatchInit)();
#endif
}

#endif