aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--contrib/lua-fun/fun.lua513
1 files changed, 293 insertions, 220 deletions
diff --git a/contrib/lua-fun/fun.lua b/contrib/lua-fun/fun.lua
index 5d019d6a1..6a536d759 100644
--- a/contrib/lua-fun/fun.lua
+++ b/contrib/lua-fun/fun.lua
@@ -1,11 +1,14 @@
---
--- Lua Fun - a high-performance functional programming library for LuaJIT
---
---- Copyright (c) 2013 Roman Tsisyk <roman@tsisyk.com>
+--- Copyright (c) 2013-2014 Roman Tsisyk <roman@tsisyk.com>
---
--- Distributed under the MIT/X11 License. See COPYING.md for more details.
---
+local exports = {}
+local methods = {}
+
--------------------------------------------------------------------------------
-- Tools
--------------------------------------------------------------------------------
@@ -38,6 +41,32 @@ local function deepcopy(orig) -- used by cycle()
return copy
end
+local iterator_mt = {
+ -- usually called by for-in loop
+ __call = function(self, param, state)
+ return self.gen(param, state)
+ end;
+ __tostring = function(self)
+ return '<generator>'
+ end;
+ -- add all exported methods
+ __index = methods;
+}
+
+local wrap = function(gen, param, state)
+ return setmetatable({
+ gen = gen,
+ param = param,
+ state = state
+ }, iterator_mt), param, state
+end
+exports.wrap = wrap
+
+local unwrap = function(self)
+ return self.gen, self.param, self.state
+end
+methods.unwrap = unwrap
+
--------------------------------------------------------------------------------
-- Basic Functions
--------------------------------------------------------------------------------
@@ -62,16 +91,28 @@ local map_gen = function(tab, key)
return key, key, value
end
-local iter = function(obj, param, state)
+local rawiter = function(obj, param, state)
assert(obj ~= nil, "invalid iterator")
- if (type(obj) == "function") then
- return obj, param, state
- elseif (type(obj) == "table" or type(obj) == "userdata") then
+ if type(obj) == "table" then
+ local mt = getmetatable(obj);
+ if mt ~= nil then
+ if mt == iterator_mt then
+ return obj.gen, obj.param, obj.state
+ elseif mt.__ipairs ~= nil then
+ return mt.__ipairs(obj)
+ elseif mt.__pairs ~= nil then
+ return mt.__pairs(obj)
+ end
+ end
if #obj > 0 then
+ -- array
return ipairs(obj)
else
+ -- hash
return map_gen, obj, nil
end
+ elseif (type(obj) == "function") then
+ return obj, param, state
elseif (type(obj) == "string") then
if #obj == 0 then
return nil_gen, nil, nil
@@ -82,22 +123,58 @@ local iter = function(obj, param, state)
obj, type(obj)))
end
-local iter_tab = function(obj)
- if type(obj) == "function" then
- return obj, nil, nil
- elseif type(obj) == "table" and type(obj[1]) == "function" then
- return obj[1], obj[2], obj[3]
- else
- return iter(obj)
+local iter = function(obj, param, state)
+ return wrap(rawiter(obj, param, state))
+end
+exports.iter = iter
+
+local method0 = function(fun)
+ return function(self)
+ return fun(self.gen, self.param, self.state)
+ end
+end
+
+local method1 = function(fun)
+ return function(self, arg1)
+ return fun(arg1, self.gen, self.param, self.state)
+ end
+end
+
+local method2 = function(fun)
+ return function(self, arg1, arg2)
+ return fun(arg1, arg2, self.gen, self.param, self.state)
+ end
+end
+
+local export0 = function(fun)
+ return function(gen, param, state)
+ return fun(rawiter(gen, param, state))
+ end
+end
+
+local export1 = function(fun)
+ return function(arg1, gen, param, state)
+ return fun(arg1, rawiter(gen, param, state))
+ end
+end
+
+local export2 = function(fun)
+ return function(arg1, arg2, gen, param, state)
+ return fun(arg1, arg2, rawiter(gen, param, state))
end
end
local each = function(fun, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
repeat
- state_x = call_if_not_empty(fun, gen_x(param_x, state_x))
- until state_x == nil
+ state = call_if_not_empty(fun, gen(param, state))
+ until state == nil
end
+methods.each = method1(each)
+exports.each = export1(each)
+methods.for_each = methods.each
+exports.for_each = exports.each
+methods.foreach = methods.each
+exports.foreach = exports.each
--------------------------------------------------------------------------------
-- Generators
@@ -106,7 +183,7 @@ end
local range_gen = function(param, state)
local stop, step = param[1], param[2]
local state = state + step
- if state >= stop then
+ if state > stop then
return nil
end
return state, state
@@ -115,7 +192,7 @@ end
local range_rev_gen = function(param, state)
local stop, step = param[1], param[2]
local state = state + step
- if state <= stop then
+ if state < stop then
return nil
end
return state, state
@@ -123,11 +200,14 @@ end
local range = function(start, stop, step)
if step == nil then
- step = 1
if stop == nil then
+ if start == 0 then
+ return nil_gen, nil, nil
+ end
stop = start
- start = 0
+ start = stop > 0 and 1 or -1
end
+ step = start <= stop and 1 or -1
end
assert(type(start) == "number", "start must be a number")
@@ -136,11 +216,12 @@ local range = function(start, stop, step)
assert(step ~= 0, "step must not be zero")
if (step > 0) then
- return range_gen, {stop, step}, start - step
+ return wrap(range_gen, {stop, step}, start - step)
elseif (step < 0) then
- return range_rev_gen, {stop, step}, start - step
+ return wrap(range_rev_gen, {stop, step}, start - step)
end
end
+exports.range = range
local duplicate_table_gen = function(param_x, state_x)
return state_x + 1, unpack(param_x)
@@ -156,24 +237,30 @@ end
local duplicate = function(...)
if select('#', ...) <= 1 then
- return duplicate_gen, select(1, ...), 0
+ return wrap(duplicate_gen, select(1, ...), 0)
else
- return duplicate_table_gen, {...}, 0
+ return wrap(duplicate_table_gen, {...}, 0)
end
end
+exports.duplicate = duplicate
+exports.replicate = duplicate
+exports.xrepeat = duplicate
local tabulate = function(fun)
assert(type(fun) == "function")
- return duplicate_fun_gen, fun, 0
+ return wrap(duplicate_fun_gen, fun, 0)
end
+exports.tabulate = tabulate
local zeros = function()
- return duplicate_gen, 0, 0
+ return wrap(duplicate_gen, 0, 0)
end
+exports.zeros = zeros
local ones = function()
- return duplicate_gen, 1, 0
+ return wrap(duplicate_gen, 1, 0)
end
+exports.ones = ones
local rands_gen = function(param_x, _state_x)
return 0, math.random(param_x[1], param_x[2])
@@ -185,7 +272,7 @@ end
local rands = function(n, m)
if n == nil and m == nil then
- return rands_nil_gen, 0, 0
+ return wrap(rands_nil_gen, 0, 0)
end
assert(type(n) == "number", "invalid first arg to rands")
if m == nil then
@@ -195,16 +282,16 @@ local rands = function(n, m)
assert(type(m) == "number", "invalid second arg to rands")
end
assert(n < m, "empty interval")
- return rands_gen, {n, m - 1}, 0
+ return wrap(rands_gen, {n, m - 1}, 0)
end
+exports.rands = rands
--------------------------------------------------------------------------------
-- Slicing
--------------------------------------------------------------------------------
-local nth = function(n, gen, param, state)
+local nth = function(n, gen_x, param_x, state_x)
assert(n > 0, "invalid first argument to nth")
- local gen_x, param_x, state_x = iter(gen, param, state)
-- An optimization for arrays and strings
if gen_x == ipairs then
return param_x[n]
@@ -223,6 +310,8 @@ local nth = function(n, gen, param, state)
end
return return_if_not_empty(gen_x(param_x, state_x))
end
+methods.nth = method1(nth)
+exports.nth = export1(nth)
local head_call = function(state, ...)
if state == nil then
@@ -232,18 +321,24 @@ local head_call = function(state, ...)
end
local head = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return head_call(gen_x(param_x, state_x))
+ return head_call(gen(param, state))
end
+methods.head = method0(head)
+exports.head = export0(head)
+exports.car = exports.head
+methods.car = methods.head
local tail = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- state_x = gen_x(param_x, state_x)
- if state_x == nil then
- return nil_gen, nil, nil
+ state = gen(param, state)
+ if state == nil then
+ return wrap(nil_gen, nil, nil)
end
- return gen_x, param_x, state_x
+ return wrap(gen, param, state)
end
+methods.tail = method0(tail)
+exports.tail = export0(tail)
+exports.cdr = exports.tail
+methods.cdr = methods.tail
local take_n_gen_x = function(i, state_x, ...)
if state_x == nil then
@@ -263,9 +358,10 @@ end
local take_n = function(n, gen, param, state)
assert(n >= 0, "invalid first argument to take_n")
- local gen_x, param_x, state_x = iter(gen, param, state)
- return take_n_gen, {n, gen, param}, {0, state}
+ return wrap(take_n_gen, {n, gen, param}, {0, state})
end
+methods.take_n = method1(take_n)
+exports.take_n = export1(take_n)
local take_while_gen_x = function(fun, state_x, ...)
if state_x == nil or not fun(...) then
@@ -281,29 +377,34 @@ end
local take_while = function(fun, gen, param, state)
assert(type(fun) == "function", "invalid first argument to take_while")
- local gen_x, param_x, state_x = iter(gen, param, state)
- return take_while_gen, {fun, gen, param}, state
+ return wrap(take_while_gen, {fun, gen, param}, state)
end
+methods.take_while = method1(take_while)
+exports.take_while = export1(take_while)
local take = function(n_or_fun, gen, param, state)
if type(n_or_fun) == "number" then
return take_n(n_or_fun, gen, param, state)
- else
+ else
return take_while(n_or_fun, gen, param, state)
end
end
+methods.take = method1(take)
+exports.take = export1(take)
local drop_n = function(n, gen, param, state)
assert(n >= 0, "invalid first argument to drop_n")
- local gen_x, param_x, state_x = iter(gen, param, state)
+ local i
for i=1,n,1 do
- state_x = gen_x(param_x, state_x)
- if state_x == nil then
- return nil_gen, nil, nil
+ state = gen(param, state)
+ if state == nil then
+ return wrap(nil_gen, nil, nil)
end
end
- return gen_x, param_x, state_x
+ return wrap(gen, param, state)
end
+methods.drop_n = method1(drop_n)
+exports.drop_n = export1(drop_n)
local drop_while_x = function(fun, state_x, ...)
if state_x == nil or not fun(...) then
@@ -312,40 +413,49 @@ local drop_while_x = function(fun, state_x, ...)
return state_x, true, ...
end
-local drop_while = function(fun, gen, param, state)
+local drop_while = function(fun, gen_x, param_x, state_x)
assert(type(fun) == "function", "invalid first argument to drop_while")
- local gen_x, param_x, state_x = iter(gen, param, state)
local cont, state_x_prev
repeat
state_x_prev = deepcopy(state_x)
state_x, cont = drop_while_x(fun, gen_x(param_x, state_x))
until not cont
if state_x == nil then
- return nil_gen, nil, nil
+ return wrap(nil_gen, nil, nil)
end
- return gen_x, param_x, state_x_prev
+ return wrap(gen_x, param_x, state_x_prev)
end
+methods.drop_while = method1(drop_while)
+exports.drop_while = export1(drop_while)
-local drop = function(n_or_fun, gen, param, state)
+local drop = function(n_or_fun, gen_x, param_x, state_x)
if type(n_or_fun) == "number" then
- return drop_n(n_or_fun, gen, param, state)
- else
- return drop_while(n_or_fun, gen, param, state)
+ return drop_n(n_or_fun, gen_x, param_x, state_x)
+ else
+ return drop_while(n_or_fun, gen_x, param_x, state_x)
end
end
+methods.drop = method1(drop)
+exports.drop = export1(drop)
-local split = function(n_or_fun, gen, param, state)
- return {take(n_or_fun, gen, param, state)},
- {drop(n_or_fun, gen, param, state)}
+local split = function(n_or_fun, gen_x, param_x, state_x)
+ return take(n_or_fun, gen_x, param_x, state_x),
+ drop(n_or_fun, gen_x, param_x, state_x)
end
+methods.split = method1(split)
+exports.split = export1(split)
+methods.split_at = methods.split
+exports.split_at = exports.split
+methods.span = methods.split
+exports.span = exports.split
--------------------------------------------------------------------------------
-- Indexing
--------------------------------------------------------------------------------
-local index = function(x, gen, param, state)
+local index = function(x, gen, param, state)
local i = 1
- for _k, r in iter(gen, param, state) do
+ for _k, r in gen, param, state do
if r == x then
return i
end
@@ -353,6 +463,12 @@ local index = function(x, gen, param, state)
end
return nil
end
+methods.index = method1(index)
+exports.index = export1(index)
+methods.index_of = methods.index
+exports.index_of = exports.index
+methods.elem_index = methods.index
+exports.elem_index = exports.index
local indexes_gen = function(param, state)
local x, gen_x, param_x = param[1], param[2], param[3]
@@ -371,15 +487,16 @@ local indexes_gen = function(param, state)
end
local indexes = function(x, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return indexes_gen, {x, gen_x, param_x}, {0, state_x}
-end
-
--- TODO: undocumented
-local find = function(fun, gen, param, state)
- local gen_x, param_x, state_x = filter(fun, gen, param, state)
- return return_if_not_empty(gen_x(param_x, state_x))
+ return wrap(indexes_gen, {x, gen, param}, {0, state})
end
+methods.indexes = method1(indexes)
+exports.indexes = export1(indexes)
+methods.elem_indexes = methods.indexes
+exports.elem_indexes = exports.indexes
+methods.indices = methods.indexes
+exports.indices = exports.indexes
+methods.elem_indices = methods.indexes
+exports.elem_indices = exports.indexes
--------------------------------------------------------------------------------
-- Filtering
@@ -423,9 +540,12 @@ local filter_gen = function(param, state_x)
end
local filter = function(fun, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return filter_gen, {fun, gen_x, param_x}, state_x
+ return wrap(filter_gen, {fun, gen, param}, state)
end
+methods.filter = method1(filter)
+exports.filter = export1(filter)
+methods.remove_if = methods.filter
+exports.remove_if = exports.filter
local grep = function(fun_or_regexp, gen, param, state)
local fun = fun_or_regexp
@@ -434,15 +554,18 @@ local grep = function(fun_or_regexp, gen, param, state)
end
return filter(fun, gen, param, state)
end
+methods.grep = method1(grep)
+exports.grep = export1(grep)
local partition = function(fun, gen, param, state)
local neg_fun = function(...)
return not fun(...)
end
- local gen_x, param_x, state_x = iter(gen, param, state)
- return {filter(fun, gen_x, param_x, state_x)},
- {filter(neg_fun, gen_x, param_x, state_x)}
+ return filter(fun, gen, param, state),
+ filter(neg_fun, gen, param, state)
end
+methods.partition = method1(partition)
+exports.partition = export1(partition)
--------------------------------------------------------------------------------
-- Reducing
@@ -455,8 +578,7 @@ local foldl_call = function(fun, start, state, ...)
return state, fun(start, ...)
end
-local foldl = function(fun, start, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local foldl = function(fun, start, gen_x, param_x, state_x)
while true do
state_x, start = foldl_call(fun, start, gen_x(param_x, state_x))
if state_x == nil then
@@ -465,9 +587,12 @@ local foldl = function(fun, start, gen, param, state)
end
return start
end
+methods.foldl = method2(foldl)
+exports.foldl = export2(foldl)
+methods.reduce = methods.foldl
+exports.reduce = exports.foldl
local length = function(gen, param, state)
- local gen, param, state = iter(gen, param, state)
if gen == ipairs or gen == string_gen then
return #param
end
@@ -478,15 +603,18 @@ local length = function(gen, param, state)
until state == nil
return len - 1
end
+methods.length = method0(length)
+exports.length = export0(length)
local is_null = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return gen_x(param_x, deepcopy(state_x)) == nil
+ return gen(param, deepcopy(state)) == nil
end
+methods.is_null = method0(is_null)
+exports.is_null = export0(is_null)
local is_prefix_of = function(iter_x, iter_y)
- local gen_x, param_x, state_x = iter_tab(iter_x)
- local gen_y, param_y, state_y = iter_tab(iter_y)
+ local gen_x, param_x, state_x = iter(iter_x)
+ local gen_y, param_y, state_y = iter(iter_y)
local r_x, r_y
for i=1,10,1 do
@@ -500,27 +628,34 @@ local is_prefix_of = function(iter_x, iter_y)
end
end
end
+methods.is_prefix_of = is_prefix_of
+exports.is_prefix_of = is_prefix_of
-local all = function(fun, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local all = function(fun, gen_x, param_x, state_x)
local r
repeat
state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x))
until state_x == nil or not r
return state_x == nil
end
+methods.all = method1(all)
+exports.all = export1(all)
+methods.every = methods.all
+exports.every = exports.all
-local any = function(fun, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local any = function(fun, gen_x, param_x, state_x)
local r
repeat
state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x))
until state_x == nil or r
return not not r
end
+methods.any = method1(any)
+exports.any = export1(any)
+methods.some = methods.any
+exports.some = exports.any
local sum = function(gen, param, state)
- local gen, param, state = iter(gen, param, state)
local s = 0
local r = 0
repeat
@@ -529,9 +664,10 @@ local sum = function(gen, param, state)
until state == nil
return s
end
+methods.sum = method0(sum)
+exports.sum = export0(sum)
local product = function(gen, param, state)
- local gen, param, state = iter(gen, param, state)
local p = 1
local r = 1
repeat
@@ -540,6 +676,8 @@ local product = function(gen, param, state)
until state == nil
return p
end
+methods.product = method0(product)
+exports.product = export0(product)
local min_cmp = function(m, n)
if n < m then return n else return m end
@@ -550,7 +688,6 @@ local max_cmp = function(m, n)
end
local min = function(gen, param, state)
- local gen, param, state = iter(gen, param, state)
local state, m = gen(param, state)
if state == nil then
error("min: iterator is empty")
@@ -569,9 +706,12 @@ local min = function(gen, param, state)
end
return m
end
+methods.min = method0(min)
+exports.min = export0(min)
+methods.minimum = methods.min
+exports.minimum = exports.min
-local min_by = function(cmp, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local min_by = function(cmp, gen_x, param_x, state_x)
local state_x, m = gen_x(param_x, state_x)
if state_x == nil then
error("min: iterator is empty")
@@ -582,9 +722,12 @@ local min_by = function(cmp, gen, param, state)
end
return m
end
+methods.min_by = method1(min_by)
+exports.min_by = export1(min_by)
+methods.minimum_by = methods.min_by
+exports.minimum_by = exports.min_by
-local max = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local max = function(gen_x, param_x, state_x)
local state_x, m = gen_x(param_x, state_x)
if state_x == nil then
error("max: iterator is empty")
@@ -603,9 +746,12 @@ local max = function(gen, param, state)
end
return m
end
+methods.max = method0(max)
+exports.max = export0(max)
+methods.maximum = methods.max
+exports.maximum = exports.max
-local max_by = function(cmp, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local max_by = function(cmp, gen_x, param_x, state_x)
local state_x, m = gen_x(param_x, state_x)
if state_x == nil then
error("max: iterator is empty")
@@ -616,9 +762,12 @@ local max_by = function(cmp, gen, param, state)
end
return m
end
+methods.max_by = method1(max_by)
+exports.max_by = export1(max_by)
+methods.maximum_by = methods.maximum_by
+exports.maximum_by = exports.maximum_by
-local totable = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local totable = function(gen_x, param_x, state_x)
local tab, key, val = {}
while true do
state_x, val = gen_x(param_x, state_x)
@@ -629,9 +778,10 @@ local totable = function(gen, param, state)
end
return tab
end
+methods.totable = method0(totable)
+exports.totable = export0(totable)
-local tomap = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
+local tomap = function(gen_x, param_x, state_x)
local tab, key, val = {}
while true do
state_x, key, val = gen_x(param_x, state_x)
@@ -642,6 +792,8 @@ local tomap = function(gen, param, state)
end
return tab
end
+methods.tomap = method0(tomap)
+exports.tomap = export0(tomap)
--------------------------------------------------------------------------------
-- Transformations
@@ -653,9 +805,10 @@ local map_gen = function(param, state)
end
local map = function(fun, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return map_gen, {gen_x, param_x, fun}, state_x
+ return wrap(map_gen, {gen, param, fun}, state)
end
+methods.map = method1(map)
+exports.map = export1(map)
local enumerate_gen_call = function(state, i, state_x, ...)
if state_x == nil then
@@ -671,9 +824,10 @@ local enumerate_gen = function(param, state)
end
local enumerate = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return enumerate_gen, {gen_x, param_x}, {0, state_x}
+ return wrap(enumerate_gen, {gen, param}, {1, state})
end
+methods.enumerate = method0(enumerate)
+exports.enumerate = export0(enumerate)
local intersperse_call = function(i, state_x, ...)
if state_x == nil then
@@ -694,9 +848,10 @@ end
-- TODO: interperse must not add x to the tail
local intersperse = function(x, gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return intersperse_gen, {x, gen_x, param_x}, {0, state_x}
+ return wrap(intersperse_gen, {x, gen, param}, {0, state})
end
+methods.intersperse = method1(intersperse)
+exports.intersperse = export1(intersperse)
--------------------------------------------------------------------------------
-- Compositions
@@ -710,7 +865,6 @@ local function zip_gen_r(param, state, state_new, ...)
local i = #state_new + 1
local gen_x, param_x = param[2 * i - 1], param[2 * i]
local state_x, r = gen_x(param_x, state[i])
- -- print('i', i, 'state_x', state_x, 'r', r)
if state_x == nil then
return nil
end
@@ -722,25 +876,42 @@ local zip_gen = function(param, state)
return zip_gen_r(param, state, {})
end
-local zip = function(...)
+-- A special hack for zip/chain to skip last two state, if a wrapped iterator
+-- has been passed
+local numargs = function(...)
local n = select('#', ...)
+ if n >= 3 then
+ -- Fix last argument
+ local it = select(n - 2, ...)
+ if type(it) == 'table' and getmetatable(it) == iterator_mt and
+ it.param == select(n - 1, ...) and it.state == select(n, ...) then
+ return n - 2
+ end
+ end
+ return n
+end
+
+local zip = function(...)
+ local n = numargs(...)
if n == 0 then
- return nil_gen, nil, nil
+ return wrap(nil_gen, nil, nil)
end
local param = { [2 * n] = 0 }
local state = { [n] = 0 }
local i, gen_x, param_x, state_x
for i=1,n,1 do
- local elem = select(n - i + 1, ...)
- gen_x, param_x, state_x = iter_tab(elem)
+ local it = select(n - i + 1, ...)
+ gen_x, param_x, state_x = rawiter(it)
param[2 * i - 1] = gen_x
param[2 * i] = param_x
state[i] = state_x
end
- return zip_gen, param, state
+ return wrap(zip_gen, param, state)
end
+methods.zip = zip
+exports.zip = zip
local cycle_gen_call = function(param, state_x, ...)
if state_x == nil then
@@ -756,9 +927,10 @@ local cycle_gen = function(param, state_x)
end
local cycle = function(gen, param, state)
- local gen_x, param_x, state_x = iter(gen, param, state)
- return cycle_gen, {gen_x, param_x, state_x}, deepcopy(state_x)
+ return wrap(cycle_gen, {gen, param, state}, deepcopy(state))
end
+methods.cycle = method0(cycle)
+exports.cycle = export0(cycle)
-- call each other
local chain_gen_r1
@@ -782,23 +954,25 @@ chain_gen_r1 = function(param, state)
end
local chain = function(...)
- local n = select('#', ...)
+ local n = numargs(...)
if n == 0 then
- return nil_gen, nil, nil
+ return wrap(nil_gen, nil, nil)
end
local param = { [3 * n] = 0 }
local i, gen_x, param_x, state_x
for i=1,n,1 do
local elem = select(i, ...)
- gen_x, param_x, state_x = iter_tab(elem)
+ gen_x, param_x, state_x = iter(elem)
param[3 * i - 2] = gen_x
param[3 * i - 1] = param_x
param[3 * i] = state_x
end
- return chain_gen_r1, param, {1, param[3]}
+ return wrap(chain_gen_r1, param, {1, param[3]})
end
+methods.chain = chain
+exports.chain = chain
--------------------------------------------------------------------------------
-- Operators
@@ -848,116 +1022,15 @@ operator = {
lnot = function(a) return not a end,
truth = function(a) return not not a end,
}
+exports.operator = operator
+methods.operator = operator
+exports.op = operator
+methods.op = operator
--------------------------------------------------------------------------------
-- module definitions
--------------------------------------------------------------------------------
-local exports = {
- ----------------------------------------------------------------------------
- -- Basic
- ----------------------------------------------------------------------------
- iter = iter,
- each = each,
- for_each = each, -- an alias
- foreach = each, -- an alias
-
- ----------------------------------------------------------------------------
- -- Generators
- ----------------------------------------------------------------------------
- range = range,
- duplicate = duplicate,
- xrepeat = duplicate, -- an alias
- replicate = duplicate, -- an alias
- tabulate = tabulate,
- ones = ones,
- zeros = zeros,
- rands = rands,
-
- ----------------------------------------------------------------------------
- -- Slicing
- ----------------------------------------------------------------------------
- nth = nth,
- head = head,
- car = head, -- an alias
- tail = tail,
- cdr = tail, -- an alias
- take_n = take_n,
- take_while = take_while,
- take = take,
- drop_n = drop_n,
- drop_while = drop_while,
- drop = drop,
- split = split,
- split_at = split, -- an alias
- span = split, -- an alias
-
- ----------------------------------------------------------------------------
- -- Indexing
- ----------------------------------------------------------------------------
- index = index,
- index_of = index, -- an alias
- elem_index = index, -- an alias
- indexes = indexes,
- indices = indexes, -- an alias
- elem_indexes = indexes, -- an alias
- elem_indices = indexes, -- an alias
- find = find,
-
- ----------------------------------------------------------------------------
- -- Filtering
- ----------------------------------------------------------------------------
- filter = filter,
- remove_if = filter, -- an alias
- grep = grep,
- partition = partition,
-
- ----------------------------------------------------------------------------
- -- Reducing
- ----------------------------------------------------------------------------
- foldl = foldl,
- reduce = foldl, -- an alias
- length = length,
- is_null = is_null,
- is_prefix_of = is_prefix_of,
- all = all,
- every = all, -- an alias
- any = any,
- some = any, -- an alias
- sum = sum,
- product = product,
- min = min,
- minimum = min, -- an alias
- min_by = min_by,
- minimum_by = min_by, -- an alias
- max = max,
- maximum = max, -- an alias
- max_by = max_by,
- maximum_by = max_by, -- an alias
- totable = totable,
- tomap = tomap,
-
- ----------------------------------------------------------------------------
- -- Transformations
- ----------------------------------------------------------------------------
- map = map,
- enumerate = enumerate,
- intersperse = intersperse,
-
- ----------------------------------------------------------------------------
- -- Compositions
- ----------------------------------------------------------------------------
- zip = zip,
- cycle = cycle,
- chain = chain,
-
- ----------------------------------------------------------------------------
- -- Operators
- ----------------------------------------------------------------------------
- operator = operator,
- op = operator -- an alias
-}
-
-- a special syntax sugar to export all functions to the global table
setmetatable(exports, {
__call = function(t)