You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

fun.lua 28KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. ---
  2. --- Lua Fun - a high-performance functional programming library for LuaJIT
  3. ---
  4. --- Copyright (c) 2013-2017 Roman Tsisyk <roman@tsisyk.com>
  5. ---
  6. --- Distributed under the MIT/X11 License. See COPYING.md for more details.
  7. ---
  8. local exports = {}
  9. local methods = {}
  10. -- compatibility with Lua 5.1/5.2
  11. local unpack = rawget(table, "unpack") or unpack
  12. --------------------------------------------------------------------------------
  13. -- Tools
  14. --------------------------------------------------------------------------------
  15. local return_if_not_empty = function(state_x, ...)
  16. if state_x == nil then
  17. return nil
  18. end
  19. return ...
  20. end
  21. local call_if_not_empty = function(fun, state_x, ...)
  22. if state_x == nil then
  23. return nil
  24. end
  25. return state_x, fun(...)
  26. end
  27. local function deepcopy(orig) -- used by cycle()
  28. local orig_type = type(orig)
  29. local copy
  30. if orig_type == 'table' then
  31. copy = {}
  32. for orig_key, orig_value in next, orig, nil do
  33. copy[deepcopy(orig_key)] = deepcopy(orig_value)
  34. end
  35. else
  36. copy = orig
  37. end
  38. return copy
  39. end
  40. local iterator_mt = {
  41. -- usually called by for-in loop
  42. __call = function(self, param, state)
  43. return self.gen(param, state)
  44. end;
  45. __tostring = function(self)
  46. return '<generator>'
  47. end;
  48. -- add all exported methods
  49. __index = methods;
  50. }
  51. local wrap = function(gen, param, state)
  52. return setmetatable({
  53. gen = gen,
  54. param = param,
  55. state = state
  56. }, iterator_mt), param, state
  57. end
  58. exports.wrap = wrap
  59. local unwrap = function(self)
  60. return self.gen, self.param, self.state
  61. end
  62. methods.unwrap = unwrap
  63. --------------------------------------------------------------------------------
  64. -- Basic Functions
  65. --------------------------------------------------------------------------------
  66. local nil_gen = function(_param, _state)
  67. return nil
  68. end
  69. local string_gen = function(param, state)
  70. local state = state + 1
  71. if state > #param then
  72. return nil
  73. end
  74. local r = string.sub(param, state, state)
  75. return state, r
  76. end
  77. local ipairs_gen = ipairs({}) -- get the generating function from ipairs
  78. local pairs_gen = pairs({ a = 0 }) -- get the generating function from pairs
  79. local map_gen = function(tab, key)
  80. local value
  81. local key, value = pairs_gen(tab, key)
  82. return key, key, value
  83. end
  84. local rawiter = function(obj, param, state)
  85. assert(obj ~= nil, "invalid iterator")
  86. if type(obj) == "table" then
  87. local mt = getmetatable(obj);
  88. if mt ~= nil then
  89. if mt == iterator_mt then
  90. return obj.gen, obj.param, obj.state
  91. elseif mt.__ipairs ~= nil then
  92. return mt.__ipairs(obj)
  93. elseif mt.__pairs ~= nil then
  94. return mt.__pairs(obj)
  95. end
  96. end
  97. if #obj > 0 then
  98. -- array
  99. return ipairs(obj)
  100. else
  101. -- hash
  102. return map_gen, obj, nil
  103. end
  104. elseif (type(obj) == "function") then
  105. return obj, param, state
  106. elseif (type(obj) == "string") then
  107. if #obj == 0 then
  108. return nil_gen, nil, nil
  109. end
  110. return string_gen, obj, 0
  111. end
  112. error(string.format('object %s of type "%s" is not iterable',
  113. obj, type(obj)))
  114. end
  115. local iter = function(obj, param, state)
  116. return wrap(rawiter(obj, param, state))
  117. end
  118. exports.iter = iter
  119. local method0 = function(fun)
  120. return function(self)
  121. return fun(self.gen, self.param, self.state)
  122. end
  123. end
  124. local method1 = function(fun)
  125. return function(self, arg1)
  126. return fun(arg1, self.gen, self.param, self.state)
  127. end
  128. end
  129. local method2 = function(fun)
  130. return function(self, arg1, arg2)
  131. return fun(arg1, arg2, self.gen, self.param, self.state)
  132. end
  133. end
  134. local export0 = function(fun)
  135. return function(gen, param, state)
  136. return fun(rawiter(gen, param, state))
  137. end
  138. end
  139. local export1 = function(fun)
  140. return function(arg1, gen, param, state)
  141. return fun(arg1, rawiter(gen, param, state))
  142. end
  143. end
  144. local export2 = function(fun)
  145. return function(arg1, arg2, gen, param, state)
  146. return fun(arg1, arg2, rawiter(gen, param, state))
  147. end
  148. end
  149. local each = function(fun, gen, param, state)
  150. repeat
  151. state = call_if_not_empty(fun, gen(param, state))
  152. until state == nil
  153. end
  154. methods.each = method1(each)
  155. exports.each = export1(each)
  156. methods.for_each = methods.each
  157. exports.for_each = exports.each
  158. methods.foreach = methods.each
  159. exports.foreach = exports.each
  160. --------------------------------------------------------------------------------
  161. -- Generators
  162. --------------------------------------------------------------------------------
  163. local range_gen = function(param, state)
  164. local stop, step = param[1], param[2]
  165. local state = state + step
  166. if state > stop then
  167. return nil
  168. end
  169. return state, state
  170. end
  171. local range_rev_gen = function(param, state)
  172. local stop, step = param[1], param[2]
  173. local state = state + step
  174. if state < stop then
  175. return nil
  176. end
  177. return state, state
  178. end
  179. local range = function(start, stop, step)
  180. if step == nil then
  181. if stop == nil then
  182. if start == 0 then
  183. return nil_gen, nil, nil
  184. end
  185. stop = start
  186. start = stop > 0 and 1 or -1
  187. end
  188. step = start <= stop and 1 or -1
  189. end
  190. assert(type(start) == "number", "start must be a number")
  191. assert(type(stop) == "number", "stop must be a number")
  192. assert(type(step) == "number", "step must be a number")
  193. assert(step ~= 0, "step must not be zero")
  194. if (step > 0) then
  195. return wrap(range_gen, {stop, step}, start - step)
  196. elseif (step < 0) then
  197. return wrap(range_rev_gen, {stop, step}, start - step)
  198. end
  199. end
  200. exports.range = range
  201. local duplicate_table_gen = function(param_x, state_x)
  202. return state_x + 1, unpack(param_x)
  203. end
  204. local duplicate_fun_gen = function(param_x, state_x)
  205. return state_x + 1, param_x(state_x)
  206. end
  207. local duplicate_gen = function(param_x, state_x)
  208. return state_x + 1, param_x
  209. end
  210. local duplicate = function(...)
  211. if select('#', ...) <= 1 then
  212. return wrap(duplicate_gen, select(1, ...), 0)
  213. else
  214. return wrap(duplicate_table_gen, {...}, 0)
  215. end
  216. end
  217. exports.duplicate = duplicate
  218. exports.replicate = duplicate
  219. exports.xrepeat = duplicate
  220. local tabulate = function(fun)
  221. assert(type(fun) == "function")
  222. return wrap(duplicate_fun_gen, fun, 0)
  223. end
  224. exports.tabulate = tabulate
  225. local zeros = function()
  226. return wrap(duplicate_gen, 0, 0)
  227. end
  228. exports.zeros = zeros
  229. local ones = function()
  230. return wrap(duplicate_gen, 1, 0)
  231. end
  232. exports.ones = ones
  233. local rands_gen = function(param_x, _state_x)
  234. return 0, math.random(param_x[1], param_x[2])
  235. end
  236. local rands_nil_gen = function(_param_x, _state_x)
  237. return 0, math.random()
  238. end
  239. local rands = function(n, m)
  240. if n == nil and m == nil then
  241. return wrap(rands_nil_gen, 0, 0)
  242. end
  243. assert(type(n) == "number", "invalid first arg to rands")
  244. if m == nil then
  245. m = n
  246. n = 0
  247. else
  248. assert(type(m) == "number", "invalid second arg to rands")
  249. end
  250. assert(n < m, "empty interval")
  251. return wrap(rands_gen, {n, m - 1}, 0)
  252. end
  253. exports.rands = rands
  254. --------------------------------------------------------------------------------
  255. -- Slicing
  256. --------------------------------------------------------------------------------
  257. local nth = function(n, gen_x, param_x, state_x)
  258. assert(n > 0, "invalid first argument to nth")
  259. -- An optimization for arrays and strings
  260. if gen_x == ipairs_gen then
  261. return param_x[n]
  262. elseif gen_x == string_gen then
  263. if n <= #param_x then
  264. return string.sub(param_x, n, n)
  265. else
  266. return nil
  267. end
  268. end
  269. for i=1,n-1,1 do
  270. state_x = gen_x(param_x, state_x)
  271. if state_x == nil then
  272. return nil
  273. end
  274. end
  275. return return_if_not_empty(gen_x(param_x, state_x))
  276. end
  277. methods.nth = method1(nth)
  278. exports.nth = export1(nth)
  279. local head_call = function(state, ...)
  280. if state == nil then
  281. error("head: iterator is empty")
  282. end
  283. return ...
  284. end
  285. local head = function(gen, param, state)
  286. return head_call(gen(param, state))
  287. end
  288. methods.head = method0(head)
  289. exports.head = export0(head)
  290. exports.car = exports.head
  291. methods.car = methods.head
  292. local tail = function(gen, param, state)
  293. state = gen(param, state)
  294. if state == nil then
  295. return wrap(nil_gen, nil, nil)
  296. end
  297. return wrap(gen, param, state)
  298. end
  299. methods.tail = method0(tail)
  300. exports.tail = export0(tail)
  301. exports.cdr = exports.tail
  302. methods.cdr = methods.tail
  303. local take_n_gen_x = function(i, state_x, ...)
  304. if state_x == nil then
  305. return nil
  306. end
  307. return {i, state_x}, ...
  308. end
  309. local take_n_gen = function(param, state)
  310. local n, gen_x, param_x = param[1], param[2], param[3]
  311. local i, state_x = state[1], state[2]
  312. if i >= n then
  313. return nil
  314. end
  315. return take_n_gen_x(i + 1, gen_x(param_x, state_x))
  316. end
  317. local take_n = function(n, gen, param, state)
  318. assert(n >= 0, "invalid first argument to take_n")
  319. return wrap(take_n_gen, {n, gen, param}, {0, state})
  320. end
  321. methods.take_n = method1(take_n)
  322. exports.take_n = export1(take_n)
  323. local take_while_gen_x = function(fun, state_x, ...)
  324. if state_x == nil or not fun(...) then
  325. return nil
  326. end
  327. return state_x, ...
  328. end
  329. local take_while_gen = function(param, state_x)
  330. local fun, gen_x, param_x = param[1], param[2], param[3]
  331. return take_while_gen_x(fun, gen_x(param_x, state_x))
  332. end
  333. local take_while = function(fun, gen, param, state)
  334. assert(type(fun) == "function", "invalid first argument to take_while")
  335. return wrap(take_while_gen, {fun, gen, param}, state)
  336. end
  337. methods.take_while = method1(take_while)
  338. exports.take_while = export1(take_while)
  339. local take = function(n_or_fun, gen, param, state)
  340. if type(n_or_fun) == "number" then
  341. return take_n(n_or_fun, gen, param, state)
  342. else
  343. return take_while(n_or_fun, gen, param, state)
  344. end
  345. end
  346. methods.take = method1(take)
  347. exports.take = export1(take)
  348. local drop_n = function(n, gen, param, state)
  349. assert(n >= 0, "invalid first argument to drop_n")
  350. local i
  351. for i=1,n,1 do
  352. state = gen(param, state)
  353. if state == nil then
  354. return wrap(nil_gen, nil, nil)
  355. end
  356. end
  357. return wrap(gen, param, state)
  358. end
  359. methods.drop_n = method1(drop_n)
  360. exports.drop_n = export1(drop_n)
  361. local drop_while_x = function(fun, state_x, ...)
  362. if state_x == nil or not fun(...) then
  363. return state_x, false
  364. end
  365. return state_x, true, ...
  366. end
  367. local drop_while = function(fun, gen_x, param_x, state_x)
  368. assert(type(fun) == "function", "invalid first argument to drop_while")
  369. local cont, state_x_prev
  370. repeat
  371. state_x_prev = deepcopy(state_x)
  372. state_x, cont = drop_while_x(fun, gen_x(param_x, state_x))
  373. until not cont
  374. if state_x == nil then
  375. return wrap(nil_gen, nil, nil)
  376. end
  377. return wrap(gen_x, param_x, state_x_prev)
  378. end
  379. methods.drop_while = method1(drop_while)
  380. exports.drop_while = export1(drop_while)
  381. local drop = function(n_or_fun, gen_x, param_x, state_x)
  382. if type(n_or_fun) == "number" then
  383. return drop_n(n_or_fun, gen_x, param_x, state_x)
  384. else
  385. return drop_while(n_or_fun, gen_x, param_x, state_x)
  386. end
  387. end
  388. methods.drop = method1(drop)
  389. exports.drop = export1(drop)
  390. local split = function(n_or_fun, gen_x, param_x, state_x)
  391. return take(n_or_fun, gen_x, param_x, state_x),
  392. drop(n_or_fun, gen_x, param_x, state_x)
  393. end
  394. methods.split = method1(split)
  395. exports.split = export1(split)
  396. methods.split_at = methods.split
  397. exports.split_at = exports.split
  398. methods.span = methods.split
  399. exports.span = exports.split
  400. --------------------------------------------------------------------------------
  401. -- Indexing
  402. --------------------------------------------------------------------------------
  403. local index = function(x, gen, param, state)
  404. local i = 1
  405. for _k, r in gen, param, state do
  406. if r == x then
  407. return i
  408. end
  409. i = i + 1
  410. end
  411. return nil
  412. end
  413. methods.index = method1(index)
  414. exports.index = export1(index)
  415. methods.index_of = methods.index
  416. exports.index_of = exports.index
  417. methods.elem_index = methods.index
  418. exports.elem_index = exports.index
  419. local indexes_gen = function(param, state)
  420. local x, gen_x, param_x = param[1], param[2], param[3]
  421. local i, state_x = state[1], state[2]
  422. local r
  423. while true do
  424. state_x, r = gen_x(param_x, state_x)
  425. if state_x == nil then
  426. return nil
  427. end
  428. i = i + 1
  429. if r == x then
  430. return {i, state_x}, i
  431. end
  432. end
  433. end
  434. local indexes = function(x, gen, param, state)
  435. return wrap(indexes_gen, {x, gen, param}, {0, state})
  436. end
  437. methods.indexes = method1(indexes)
  438. exports.indexes = export1(indexes)
  439. methods.elem_indexes = methods.indexes
  440. exports.elem_indexes = exports.indexes
  441. methods.indices = methods.indexes
  442. exports.indices = exports.indexes
  443. methods.elem_indices = methods.indexes
  444. exports.elem_indices = exports.indexes
  445. --------------------------------------------------------------------------------
  446. -- Filtering
  447. --------------------------------------------------------------------------------
  448. local filter1_gen = function(fun, gen_x, param_x, state_x, a)
  449. while true do
  450. if state_x == nil or fun(a) then break; end
  451. state_x, a = gen_x(param_x, state_x)
  452. end
  453. return state_x, a
  454. end
  455. -- call each other
  456. local filterm_gen
  457. local filterm_gen_shrink = function(fun, gen_x, param_x, state_x)
  458. return filterm_gen(fun, gen_x, param_x, gen_x(param_x, state_x))
  459. end
  460. filterm_gen = function(fun, gen_x, param_x, state_x, ...)
  461. if state_x == nil then
  462. return nil
  463. end
  464. if fun(...) then
  465. return state_x, ...
  466. end
  467. return filterm_gen_shrink(fun, gen_x, param_x, state_x)
  468. end
  469. local filter_detect = function(fun, gen_x, param_x, state_x, ...)
  470. if select('#', ...) < 2 then
  471. return filter1_gen(fun, gen_x, param_x, state_x, ...)
  472. else
  473. return filterm_gen(fun, gen_x, param_x, state_x, ...)
  474. end
  475. end
  476. local filter_gen = function(param, state_x)
  477. local fun, gen_x, param_x = param[1], param[2], param[3]
  478. return filter_detect(fun, gen_x, param_x, gen_x(param_x, state_x))
  479. end
  480. local filter = function(fun, gen, param, state)
  481. return wrap(filter_gen, {fun, gen, param}, state)
  482. end
  483. methods.filter = method1(filter)
  484. exports.filter = export1(filter)
  485. methods.remove_if = methods.filter
  486. exports.remove_if = exports.filter
  487. local grep = function(fun_or_regexp, gen, param, state)
  488. local fun = fun_or_regexp
  489. if type(fun_or_regexp) == "string" then
  490. fun = function(x) return string.find(x, fun_or_regexp) ~= nil end
  491. end
  492. return filter(fun, gen, param, state)
  493. end
  494. methods.grep = method1(grep)
  495. exports.grep = export1(grep)
  496. local partition = function(fun, gen, param, state)
  497. local neg_fun = function(...)
  498. return not fun(...)
  499. end
  500. return filter(fun, gen, param, state),
  501. filter(neg_fun, gen, param, state)
  502. end
  503. methods.partition = method1(partition)
  504. exports.partition = export1(partition)
  505. --------------------------------------------------------------------------------
  506. -- Reducing
  507. --------------------------------------------------------------------------------
  508. local foldl_call = function(fun, start, state, ...)
  509. if state == nil then
  510. return nil, start
  511. end
  512. return state, fun(start, ...)
  513. end
  514. local foldl = function(fun, start, gen_x, param_x, state_x)
  515. while true do
  516. state_x, start = foldl_call(fun, start, gen_x(param_x, state_x))
  517. if state_x == nil then
  518. break;
  519. end
  520. end
  521. return start
  522. end
  523. methods.foldl = method2(foldl)
  524. exports.foldl = export2(foldl)
  525. methods.reduce = methods.foldl
  526. exports.reduce = exports.foldl
  527. local length = function(gen, param, state)
  528. if gen == ipairs_gen or gen == string_gen then
  529. return #param
  530. end
  531. local len = 0
  532. repeat
  533. state = gen(param, state)
  534. len = len + 1
  535. until state == nil
  536. return len - 1
  537. end
  538. methods.length = method0(length)
  539. exports.length = export0(length)
  540. local is_null = function(gen, param, state)
  541. return gen(param, deepcopy(state)) == nil
  542. end
  543. methods.is_null = method0(is_null)
  544. exports.is_null = export0(is_null)
  545. local is_prefix_of = function(iter_x, iter_y)
  546. local gen_x, param_x, state_x = iter(iter_x)
  547. local gen_y, param_y, state_y = iter(iter_y)
  548. local r_x, r_y
  549. for i=1,10,1 do
  550. state_x, r_x = gen_x(param_x, state_x)
  551. state_y, r_y = gen_y(param_y, state_y)
  552. if state_x == nil then
  553. return true
  554. end
  555. if state_y == nil or r_x ~= r_y then
  556. return false
  557. end
  558. end
  559. end
  560. methods.is_prefix_of = is_prefix_of
  561. exports.is_prefix_of = is_prefix_of
  562. local all = function(fun, gen_x, param_x, state_x)
  563. local r
  564. repeat
  565. state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x))
  566. until state_x == nil or not r
  567. return state_x == nil
  568. end
  569. methods.all = method1(all)
  570. exports.all = export1(all)
  571. methods.every = methods.all
  572. exports.every = exports.all
  573. local any = function(fun, gen_x, param_x, state_x)
  574. local r
  575. repeat
  576. state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x))
  577. until state_x == nil or r
  578. return not not r
  579. end
  580. methods.any = method1(any)
  581. exports.any = export1(any)
  582. methods.some = methods.any
  583. exports.some = exports.any
  584. local sum = function(gen, param, state)
  585. local s = 0
  586. local r = 0
  587. repeat
  588. s = s + r
  589. state, r = gen(param, state)
  590. until state == nil
  591. return s
  592. end
  593. methods.sum = method0(sum)
  594. exports.sum = export0(sum)
  595. local product = function(gen, param, state)
  596. local p = 1
  597. local r = 1
  598. repeat
  599. p = p * r
  600. state, r = gen(param, state)
  601. until state == nil
  602. return p
  603. end
  604. methods.product = method0(product)
  605. exports.product = export0(product)
  606. local min_cmp = function(m, n)
  607. if n < m then return n else return m end
  608. end
  609. local max_cmp = function(m, n)
  610. if n > m then return n else return m end
  611. end
  612. local min = function(gen, param, state)
  613. local state, m = gen(param, state)
  614. if state == nil then
  615. error("min: iterator is empty")
  616. end
  617. local cmp
  618. if type(m) == "number" then
  619. -- An optimization: use math.min for numbers
  620. cmp = math.min
  621. else
  622. cmp = min_cmp
  623. end
  624. for _, r in gen, param, state do
  625. m = cmp(m, r)
  626. end
  627. return m
  628. end
  629. methods.min = method0(min)
  630. exports.min = export0(min)
  631. methods.minimum = methods.min
  632. exports.minimum = exports.min
  633. local min_by = function(cmp, gen_x, param_x, state_x)
  634. local state_x, m = gen_x(param_x, state_x)
  635. if state_x == nil then
  636. error("min: iterator is empty")
  637. end
  638. for _, r in gen_x, param_x, state_x do
  639. m = cmp(m, r)
  640. end
  641. return m
  642. end
  643. methods.min_by = method1(min_by)
  644. exports.min_by = export1(min_by)
  645. methods.minimum_by = methods.min_by
  646. exports.minimum_by = exports.min_by
  647. local max = function(gen_x, param_x, state_x)
  648. local state_x, m = gen_x(param_x, state_x)
  649. if state_x == nil then
  650. error("max: iterator is empty")
  651. end
  652. local cmp
  653. if type(m) == "number" then
  654. -- An optimization: use math.max for numbers
  655. cmp = math.max
  656. else
  657. cmp = max_cmp
  658. end
  659. for _, r in gen_x, param_x, state_x do
  660. m = cmp(m, r)
  661. end
  662. return m
  663. end
  664. methods.max = method0(max)
  665. exports.max = export0(max)
  666. methods.maximum = methods.max
  667. exports.maximum = exports.max
  668. local max_by = function(cmp, gen_x, param_x, state_x)
  669. local state_x, m = gen_x(param_x, state_x)
  670. if state_x == nil then
  671. error("max: iterator is empty")
  672. end
  673. for _, r in gen_x, param_x, state_x do
  674. m = cmp(m, r)
  675. end
  676. return m
  677. end
  678. methods.max_by = method1(max_by)
  679. exports.max_by = export1(max_by)
  680. methods.maximum_by = methods.maximum_by
  681. exports.maximum_by = exports.maximum_by
  682. local totable = function(gen_x, param_x, state_x)
  683. local tab, key, val = {}
  684. while true do
  685. state_x, val = gen_x(param_x, state_x)
  686. if state_x == nil then
  687. break
  688. end
  689. table.insert(tab, val)
  690. end
  691. return tab
  692. end
  693. methods.totable = method0(totable)
  694. exports.totable = export0(totable)
  695. local tomap = function(gen_x, param_x, state_x)
  696. local tab, key, val = {}
  697. while true do
  698. state_x, key, val = gen_x(param_x, state_x)
  699. if state_x == nil then
  700. break
  701. end
  702. tab[key] = val
  703. end
  704. return tab
  705. end
  706. methods.tomap = method0(tomap)
  707. exports.tomap = export0(tomap)
  708. --------------------------------------------------------------------------------
  709. -- Transformations
  710. --------------------------------------------------------------------------------
  711. local map_gen = function(param, state)
  712. local gen_x, param_x, fun = param[1], param[2], param[3]
  713. return call_if_not_empty(fun, gen_x(param_x, state))
  714. end
  715. local map = function(fun, gen, param, state)
  716. return wrap(map_gen, {gen, param, fun}, state)
  717. end
  718. methods.map = method1(map)
  719. exports.map = export1(map)
  720. local enumerate_gen_call = function(state, i, state_x, ...)
  721. if state_x == nil then
  722. return nil
  723. end
  724. return {i + 1, state_x}, i, ...
  725. end
  726. local enumerate_gen = function(param, state)
  727. local gen_x, param_x = param[1], param[2]
  728. local i, state_x = state[1], state[2]
  729. return enumerate_gen_call(state, i, gen_x(param_x, state_x))
  730. end
  731. local enumerate = function(gen, param, state)
  732. return wrap(enumerate_gen, {gen, param}, {1, state})
  733. end
  734. methods.enumerate = method0(enumerate)
  735. exports.enumerate = export0(enumerate)
  736. local intersperse_call = function(i, state_x, ...)
  737. if state_x == nil then
  738. return nil
  739. end
  740. return {i + 1, state_x}, ...
  741. end
  742. local intersperse_gen = function(param, state)
  743. local x, gen_x, param_x = param[1], param[2], param[3]
  744. local i, state_x = state[1], state[2]
  745. if i % 2 == 1 then
  746. return {i + 1, state_x}, x
  747. else
  748. return intersperse_call(i, gen_x(param_x, state_x))
  749. end
  750. end
  751. -- TODO: interperse must not add x to the tail
  752. local intersperse = function(x, gen, param, state)
  753. return wrap(intersperse_gen, {x, gen, param}, {0, state})
  754. end
  755. methods.intersperse = method1(intersperse)
  756. exports.intersperse = export1(intersperse)
  757. --------------------------------------------------------------------------------
  758. -- Compositions
  759. --------------------------------------------------------------------------------
  760. local function zip_gen_r(param, state, state_new, ...)
  761. if #state_new == #param / 2 then
  762. return state_new, ...
  763. end
  764. local i = #state_new + 1
  765. local gen_x, param_x = param[2 * i - 1], param[2 * i]
  766. local state_x, r = gen_x(param_x, state[i])
  767. if state_x == nil then
  768. return nil
  769. end
  770. table.insert(state_new, state_x)
  771. return zip_gen_r(param, state, state_new, r, ...)
  772. end
  773. local zip_gen = function(param, state)
  774. return zip_gen_r(param, state, {})
  775. end
  776. -- A special hack for zip/chain to skip last two state, if a wrapped iterator
  777. -- has been passed
  778. local numargs = function(...)
  779. local n = select('#', ...)
  780. if n >= 3 then
  781. -- Fix last argument
  782. local it = select(n - 2, ...)
  783. if type(it) == 'table' and getmetatable(it) == iterator_mt and
  784. it.param == select(n - 1, ...) and it.state == select(n, ...) then
  785. return n - 2
  786. end
  787. end
  788. return n
  789. end
  790. local zip = function(...)
  791. local n = numargs(...)
  792. if n == 0 then
  793. return wrap(nil_gen, nil, nil)
  794. end
  795. local param = { [2 * n] = 0 }
  796. local state = { [n] = 0 }
  797. local i, gen_x, param_x, state_x
  798. for i=1,n,1 do
  799. local it = select(n - i + 1, ...)
  800. gen_x, param_x, state_x = rawiter(it)
  801. param[2 * i - 1] = gen_x
  802. param[2 * i] = param_x
  803. state[i] = state_x
  804. end
  805. return wrap(zip_gen, param, state)
  806. end
  807. methods.zip = zip
  808. exports.zip = zip
  809. local cycle_gen_call = function(param, state_x, ...)
  810. if state_x == nil then
  811. local gen_x, param_x, state_x0 = param[1], param[2], param[3]
  812. return gen_x(param_x, deepcopy(state_x0))
  813. end
  814. return state_x, ...
  815. end
  816. local cycle_gen = function(param, state_x)
  817. local gen_x, param_x, state_x0 = param[1], param[2], param[3]
  818. return cycle_gen_call(param, gen_x(param_x, state_x))
  819. end
  820. local cycle = function(gen, param, state)
  821. return wrap(cycle_gen, {gen, param, state}, deepcopy(state))
  822. end
  823. methods.cycle = method0(cycle)
  824. exports.cycle = export0(cycle)
  825. -- call each other
  826. local chain_gen_r1
  827. local chain_gen_r2 = function(param, state, state_x, ...)
  828. if state_x == nil then
  829. local i = state[1]
  830. i = i + 1
  831. if param[3 * i - 1] == nil then
  832. return nil
  833. end
  834. local state_x = param[3 * i]
  835. return chain_gen_r1(param, {i, state_x})
  836. end
  837. return {state[1], state_x}, ...
  838. end
  839. chain_gen_r1 = function(param, state)
  840. local i, state_x = state[1], state[2]
  841. local gen_x, param_x = param[3 * i - 2], param[3 * i - 1]
  842. return chain_gen_r2(param, state, gen_x(param_x, state[2]))
  843. end
  844. local chain = function(...)
  845. local n = numargs(...)
  846. if n == 0 then
  847. return wrap(nil_gen, nil, nil)
  848. end
  849. local param = { [3 * n] = 0 }
  850. local i, gen_x, param_x, state_x
  851. for i=1,n,1 do
  852. local elem = select(i, ...)
  853. gen_x, param_x, state_x = iter(elem)
  854. param[3 * i - 2] = gen_x
  855. param[3 * i - 1] = param_x
  856. param[3 * i] = state_x
  857. end
  858. return wrap(chain_gen_r1, param, {1, param[3]})
  859. end
  860. methods.chain = chain
  861. exports.chain = chain
  862. --------------------------------------------------------------------------------
  863. -- Operators
  864. --------------------------------------------------------------------------------
  865. local operator = {
  866. ----------------------------------------------------------------------------
  867. -- Comparison operators
  868. ----------------------------------------------------------------------------
  869. lt = function(a, b) return a < b end,
  870. le = function(a, b) return a <= b end,
  871. eq = function(a, b) return a == b end,
  872. ne = function(a, b) return a ~= b end,
  873. ge = function(a, b) return a >= b end,
  874. gt = function(a, b) return a > b end,
  875. ----------------------------------------------------------------------------
  876. -- Arithmetic operators
  877. ----------------------------------------------------------------------------
  878. add = function(a, b) return a + b end,
  879. div = function(a, b) return a / b end,
  880. floordiv = function(a, b) return math.floor(a/b) end,
  881. intdiv = function(a, b)
  882. local q = a / b
  883. if a >= 0 then return math.floor(q) else return math.ceil(q) end
  884. end,
  885. mod = function(a, b) return a % b end,
  886. mul = function(a, b) return a * b end,
  887. neq = function(a) return -a end,
  888. unm = function(a) return -a end, -- an alias
  889. pow = function(a, b) return a ^ b end,
  890. sub = function(a, b) return a - b end,
  891. truediv = function(a, b) return a / b end,
  892. ----------------------------------------------------------------------------
  893. -- String operators
  894. ----------------------------------------------------------------------------
  895. concat = function(a, b) return a..b end,
  896. len = function(a) return #a end,
  897. length = function(a) return #a end, -- an alias
  898. ----------------------------------------------------------------------------
  899. -- Logical operators
  900. ----------------------------------------------------------------------------
  901. land = function(a, b) return a and b end,
  902. lor = function(a, b) return a or b end,
  903. lnot = function(a) return not a end,
  904. truth = function(a) return not not a end,
  905. }
  906. exports.operator = operator
  907. methods.operator = operator
  908. exports.op = operator
  909. methods.op = operator
  910. --------------------------------------------------------------------------------
  911. -- module definitions
  912. --------------------------------------------------------------------------------
  913. -- a special syntax sugar to export all functions to the global table
  914. setmetatable(exports, {
  915. __call = function(t, override)
  916. for k, v in pairs(t) do
  917. if _G[k] ~= nil then
  918. local msg = 'function ' .. k .. ' already exists in global scope.'
  919. if override then
  920. _G[k] = v
  921. print('WARNING: ' .. msg .. ' Overwritten.')
  922. else
  923. print('NOTICE: ' .. msg .. ' Skipped.')
  924. end
  925. else
  926. _G[k] = v
  927. end
  928. end
  929. end,
  930. })
  931. return exports