You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

WorkPool.lua 4.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. local dt = require "decisiontree._env"
  2. -- Utility to simplify construction of a pool of daemon threads with which to execute tasks in parallel.
  3. local WorkPool = torch.class("dt.WorkPool", dt)
  -- Constructor. Records the requested number of worker threads (16 when
  -- `nThread` is nil/false), validates it, then spawns the workers through
  -- `self:initialize()`.
  --   nThread : optional positive number of daemon threads.
  function WorkPool:__init(nThread)
     if nThread then
        self.nThread = nThread
     else
        self.nThread = 16
     end
     local count = self.nThread
     assert(torch.type(count) == 'number')
     assert(count > 0)
     self:initialize()
  end
  -- Creates the queues and spawns the worker threads.
  --
  -- One "main" work queue is opened under a fresh temporary name
  -- (`self.queuename`), plus one private queue per worker named
  -- `<queuename>/<i>` (kept in `self.queues[i]`). `self.nThread` workers are
  -- then started with `ipc.map`; each worker loops reading messages of the form
  -- `{taskname, task}` from its current queue (initially the main queue):
  --
  --   {}                  : (taskname == nil) leave the loop, ending the worker.
  --   'storeKeyValue'     : store[task.key] = task.value; acks `{taskname}`.
  --   'storeKeysValues'   : copies every key/value of `task` into the store; acks.
  --   'require'           : _G[task.varname] = require(task.libname); acks.
  --   'storeReset'        : empties the store; acks on the main queue.
  --   'echo'              : writes `{taskname, task}` back on the main queue.
  --   'readWorkerQueue'   : switch to reading the worker's private queue.
  --   'readMainQueue'     : switch back to reading the main queue.
  --   'execute'           : runs `task.func(store, task.args, myId)` when `task`
  --                         is a table, or `task(store, myId)` when it is a
  --                         function, and writes `{taskname, results...}` on the
  --                         current queue.
  --
  -- The three store/require tasks assert that the worker is currently reading
  -- its private queue. Any other taskname raises an error inside the worker.
  function WorkPool:initialize()
     local ipc = require 'libipc'
     self.queuename = os.tmpname()
     self.queue = ipc.workqueue(self.queuename)
     self.queues = {}
     for i=1,self.nThread do
        self.queues[i] = ipc.workqueue(self.queuename.."/"..i)
     end

     -- spawn thread workers
     -- `myId` is the trailing argument after the two passed below; presumably the
     -- 1-based worker index appended by ipc.map (asserted non-nil) -- confirm
     -- against the libipc documentation.
     ipc.map(self.nThread, function(queuename, nThread, myId)
        assert(nThread)
        assert(myId)
        local ipc = require 'libipc'

        -- Open the queues by name (the main thread already created them)
        local mainqueue = ipc.workqueue(queuename)
        local workqueue = ipc.workqueue(queuename.."/"..myId)

        -- Both must be local: the original declared `args` here but assigned to
        -- `task`, which silently created a global in the worker's Lua state.
        local taskname, task
        local store = {}
        local queue = mainqueue

        repeat
           local msg = queue:read()
           assert(torch.type(msg) == 'table')
           taskname, task = unpack(msg)
           if taskname == nil then
              -- empty message: termination request
              break
           elseif torch.type(taskname) ~= 'string' then
              error("Expecting taskname string. Got "..torch.type(taskname))
           elseif taskname == 'storeKeyValue' then
              assert(torch.type(task) == 'table')
              assert(queue == workqueue)
              store[task.key] = task.value
              queue:write({taskname})
           elseif taskname == 'storeKeysValues' then
              assert(torch.type(task) == 'table')
              assert(queue == workqueue)
              for key,value in pairs(task) do
                 store[key] = value
              end
              queue:write({taskname})
           elseif taskname == 'require' then
              assert(torch.type(task) == 'table')
              assert(torch.type(task.libname) == 'string')
              assert(torch.type(task.varname) == 'string')
              _G[task.varname] = require(task.libname)
              assert(queue == workqueue)
              queue:write({taskname})
           elseif taskname == 'storeReset' then
              store = {}
              mainqueue:write({taskname})
           elseif taskname == 'echo' then
              mainqueue:write({taskname, task})
           elseif taskname == 'readWorkerQueue' then
              queue = workqueue
           elseif taskname == 'readMainQueue' then
              queue = mainqueue
           elseif taskname == 'execute' then
              if torch.type(task) == 'table' then
                 assert(task.func and task.args)
                 -- the call is the last item, so all of its results are sent
                 queue:write({taskname, task.func(store, task.args, myId)})
              else
                 assert(torch.type(task) == 'function')
                 queue:write({taskname, task(store, myId)})
              end
           else
              error("Unknown taskname: "..taskname)
           end
        until taskname == nil
     end, self.queuename, self.nThread)
  end
  -- Asks every daemon thread to stop by posting one empty message per worker
  -- on the main queue (an empty message has no taskname, which ends a worker's
  -- read loop).
  function WorkPool:terminate()
     local mainqueue = self.queue
     for _ = 1, self.nThread do
        mainqueue:write({})
     end
  end
  -- Sends the same task to every worker thread through its private queue and
  -- waits until all of them have acknowledged it. Used to update the per-worker
  -- data store (or run a task on every worker).
  --   taskname : one of 'storeKeyValue', 'storeKeysValues', 'require', 'execute'.
  --   task     : table or function payload delivered with the taskname.
  --   upval    : when truthy the message is sent with `writeup` instead of `write`.
  function WorkPool:_update(taskname, task, upval)
     assert(torch.type(taskname) == 'string')
     local moses = require 'moses'
     local allowed = {'storeKeyValue','storeKeysValues','require','execute'}
     assert(moses.contains(allowed, taskname))
     local tasktype = torch.type(task)
     assert(tasktype == 'table' or tasktype == 'function')

     local nWorker = self.nThread

     -- step 1: one 'readWorkerQueue' per worker on the main queue, so that each
     -- worker switches to its individual queue
     for _ = 1, nWorker do
        self.queue:write({'readWorkerQueue'})
     end

     -- step 2: deliver the task on every individual worker queue
     for w = 1, nWorker do
        local workerqueue = self.queues[w]
        if upval then
           workerqueue:writeup({taskname, task})
        else
           workerqueue:write({taskname, task})
        end
     end

     -- TODO use ipc.mutex:barrier(nThread+1)
     -- step 3 (barrier): read one reply from each worker queue and check that it
     -- carries the taskname, i.e. that the worker completed the task
     for w = 1, nWorker do
        local reply = self.queues[w]:read()
        assert(reply[1] == taskname)
     end

     -- step 4: send the workers back to the main queue
     for w = 1, nWorker do
        self.queues[w]:write({'readMainQueue'})
     end
  end
  -- Broadcasts `task` under `taskname` to every worker via `_update`, using the
  -- plain `write` path (upval flag false).
  function WorkPool:update(taskname, task)
     local upval = false
     return self:_update(taskname, task, upval)
  end
  -- Same as `update`, but asks `_update` to send the message with `writeup`
  -- (upval flag true).
  function WorkPool:updateup(taskname, task)
     local upval = true
     return self:_update(taskname, task, upval)
  end
  -- Posts a single `{taskname, task}` message on the main queue, to be picked up
  -- by whichever worker reads it first.
  --   taskname : string naming the task. 'storeKeyValue' and 'storeKeysValues'
  --              are rejected here: workers assert that those arrive on their
  --              private queue, so they must go through `update`/`updateup`.
  --   task     : payload forwarded as-is.
  function WorkPool:write(taskname, task)
     assert(torch.type(taskname) == 'string')
     -- The original used `or`, which is true for every taskname and therefore
     -- never rejected anything; both names must be excluded, hence `and`.
     assert(taskname ~= 'storeKeyValue' and taskname ~= 'storeKeysValues')
     self.queue:write({taskname, task})
  end
  -- Same as `write`, but posts the `{taskname, task}` message on the main queue
  -- with the queue's `writeup` method.
  --   taskname : string naming the task. 'storeKeyValue' and 'storeKeysValues'
  --              are rejected here: workers assert that those arrive on their
  --              private queue, so they must go through `update`/`updateup`.
  --   task     : payload forwarded as-is.
  function WorkPool:writeup(taskname, task)
     assert(torch.type(taskname) == 'string')
     -- The original used `or`, which is true for every taskname and therefore
     -- never rejected anything; both names must be excluded, hence `and`.
     assert(taskname ~= 'storeKeyValue' and taskname ~= 'storeKeysValues')
     self.queue:writeup({taskname, task})
  end
  -- Reads one reply from the main queue and returns its elements unpacked.
  -- The reply must be a table whose first element is the taskname string.
  --   returns : taskname followed by the remaining elements of the reply.
  function WorkPool:read()
     local res = self.queue:read()
     assert(torch.type(res) == 'table')
     -- The closing parenthesis was misplaced in the original
     -- (`torch.type(res[1] == 'string')`), which took the type of a boolean and
     -- so always passed; check the type of the first element itself.
     assert(torch.type(res[1]) == 'string')
     return unpack(res)
  end