]> source.dussan.org Git - rspamd.git/commitdiff
Import the proper version of lua-functional.
authorVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 26 Feb 2015 14:06:01 +0000 (14:06 +0000)
committerVsevolod Stakhov <vsevolod@highsecure.ru>
Thu, 26 Feb 2015 14:06:01 +0000 (14:06 +0000)
contrib/lua-fun/fun.lua

index 5d019d6a1c8a9f02928db77cf15dfc5c1748fd41..6a536d7594909cea742061a31e0a585f0e5e5033 100644 (file)
@@ -1,11 +1,14 @@
 ---
 --- Lua Fun - a high-performance functional programming library for LuaJIT
 ---
---- Copyright (c) 2013 Roman Tsisyk <roman@tsisyk.com>
+--- Copyright (c) 2013-2014 Roman Tsisyk <roman@tsisyk.com>
 ---
 --- Distributed under the MIT/X11 License. See COPYING.md for more details.
 ---
 
+local exports = {}
+local methods = {}
+
 --------------------------------------------------------------------------------
 -- Tools
 --------------------------------------------------------------------------------
@@ -38,6 +41,32 @@ local function deepcopy(orig) -- used by cycle()
     return copy
 end
 
+local iterator_mt = {
+    -- usually called by for-in loop
+    __call = function(self, param, state)
+        return self.gen(param, state)
+    end;
+    __tostring = function(self)
+        return '<generator>'
+    end;
+    -- add all exported methods
+    __index = methods;
+}
+
+local wrap = function(gen, param, state)
+    return setmetatable({
+        gen = gen,
+        param = param,
+        state = state
+    }, iterator_mt), param, state
+end
+exports.wrap = wrap
+
+local unwrap = function(self)
+    return self.gen, self.param, self.state
+end
+methods.unwrap = unwrap
+
 --------------------------------------------------------------------------------
 -- Basic Functions
 --------------------------------------------------------------------------------
@@ -62,16 +91,28 @@ local map_gen = function(tab, key)
     return key, key, value
 end
 
-local iter = function(obj, param, state)
+local rawiter = function(obj, param, state)
     assert(obj ~= nil, "invalid iterator")
-    if (type(obj) == "function") then
-        return obj, param, state
-    elseif (type(obj) == "table" or type(obj) == "userdata") then
+    if type(obj) == "table" then
+        local mt = getmetatable(obj);
+        if mt ~= nil then
+            if mt == iterator_mt then
+                return obj.gen, obj.param, obj.state
+            elseif mt.__ipairs ~= nil then
+                return mt.__ipairs(obj)
+            elseif mt.__pairs ~= nil then
+                return mt.__pairs(obj)
+            end
+        end
         if #obj > 0 then
+            -- array
             return ipairs(obj)
         else
+            -- hash
             return map_gen, obj, nil
         end
+    elseif (type(obj) == "function") then
+        return obj, param, state
     elseif (type(obj) == "string") then
         if #obj == 0 then
             return nil_gen, nil, nil
@@ -82,22 +123,58 @@ local iter = function(obj, param, state)
           obj, type(obj)))
 end
 
-local iter_tab = function(obj)
-    if type(obj) == "function" then
-       return obj, nil, nil
-    elseif type(obj) == "table" and type(obj[1]) == "function" then
-        return obj[1], obj[2], obj[3]
-    else
-        return iter(obj)
+local iter = function(obj, param, state)
+    return wrap(rawiter(obj, param, state))
+end
+exports.iter = iter
+
+local method0 = function(fun)
+    return function(self)
+        return fun(self.gen, self.param, self.state)
+    end
+end
+
+local method1 = function(fun)
+    return function(self, arg1)
+        return fun(arg1, self.gen, self.param, self.state)
+    end
+end
+
+local method2 = function(fun)
+    return function(self, arg1, arg2)
+        return fun(arg1, arg2, self.gen, self.param, self.state)
+    end
+end
+
+local export0 = function(fun)
+    return function(gen, param, state)
+        return fun(rawiter(gen, param, state))
+    end
+end
+
+local export1 = function(fun)
+    return function(arg1, gen, param, state)
+        return fun(arg1, rawiter(gen, param, state))
+    end
+end
+
+local export2 = function(fun)
+    return function(arg1, arg2, gen, param, state)
+        return fun(arg1, arg2, rawiter(gen, param, state))
     end
 end
 
 local each = function(fun, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
     repeat
-        state_x = call_if_not_empty(fun, gen_x(param_x, state_x))
-    until state_x == nil
+        state = call_if_not_empty(fun, gen(param, state))
+    until state == nil
 end
+methods.each = method1(each)
+exports.each = export1(each)
+methods.for_each = methods.each
+exports.for_each = exports.each
+methods.foreach = methods.each
+exports.foreach = exports.each
 
 --------------------------------------------------------------------------------
 -- Generators
@@ -106,7 +183,7 @@ end
 local range_gen = function(param, state)
     local stop, step = param[1], param[2]
     local state = state + step
-    if state >= stop then
+    if state > stop then
         return nil
     end
     return state, state
@@ -115,7 +192,7 @@ end
 local range_rev_gen = function(param, state)
     local stop, step = param[1], param[2]
     local state = state + step
-    if state <= stop then
+    if state < stop then
         return nil
     end
     return state, state
@@ -123,11 +200,14 @@ end
 
 local range = function(start, stop, step)
     if step == nil then
-        step = 1
         if stop == nil then
+            if start == 0 then
+                return nil_gen, nil, nil
+            end
             stop = start
-            start = 0
+            start = stop > 0 and 1 or -1
         end
+        step = start <= stop and 1 or -1
     end
 
     assert(type(start) == "number", "start must be a number")
@@ -136,11 +216,12 @@ local range = function(start, stop, step)
     assert(step ~= 0, "step must not be zero")
 
     if (step > 0) then
-        return range_gen, {stop, step}, start - step
+        return wrap(range_gen, {stop, step}, start - step)
     elseif (step < 0) then
-        return range_rev_gen, {stop, step}, start - step
+        return wrap(range_rev_gen, {stop, step}, start - step)
     end
 end
+exports.range = range
 
 local duplicate_table_gen = function(param_x, state_x)
     return state_x + 1, unpack(param_x)
@@ -156,24 +237,30 @@ end
 
 local duplicate = function(...)
     if select('#', ...) <= 1 then
-        return duplicate_gen, select(1, ...), 0
+        return wrap(duplicate_gen, select(1, ...), 0)
     else
-        return duplicate_table_gen, {...}, 0 
+        return wrap(duplicate_table_gen, {...}, 0)
     end
 end
+exports.duplicate = duplicate
+exports.replicate = duplicate
+exports.xrepeat = duplicate
 
 local tabulate = function(fun)
     assert(type(fun) == "function")
-    return duplicate_fun_gen, fun, 0
+    return wrap(duplicate_fun_gen, fun, 0)
 end
+exports.tabulate = tabulate
 
 local zeros = function()
-    return duplicate_gen, 0, 0
+    return wrap(duplicate_gen, 0, 0)
 end
+exports.zeros = zeros
 
 local ones = function()
-    return duplicate_gen, 1, 0
+    return wrap(duplicate_gen, 1, 0)
 end
+exports.ones = ones
 
 local rands_gen = function(param_x, _state_x)
     return 0, math.random(param_x[1], param_x[2])
@@ -185,7 +272,7 @@ end
 
 local rands = function(n, m)
     if n == nil and m == nil then
-        return rands_nil_gen, 0, 0
+        return wrap(rands_nil_gen, 0, 0)
     end
     assert(type(n) == "number", "invalid first arg to rands")
     if m == nil then
@@ -195,16 +282,16 @@ local rands = function(n, m)
         assert(type(m) == "number", "invalid second arg to rands")
     end
     assert(n < m, "empty interval")
-    return rands_gen, {n, m - 1}, 0
+    return wrap(rands_gen, {n, m - 1}, 0)
 end
+exports.rands = rands
 
 --------------------------------------------------------------------------------
 -- Slicing
 --------------------------------------------------------------------------------
 
-local nth = function(n, gen, param, state)
+local nth = function(n, gen_x, param_x, state_x)
     assert(n > 0, "invalid first argument to nth")
-    local gen_x, param_x, state_x = iter(gen, param, state)
     -- An optimization for arrays and strings
     if gen_x == ipairs then
         return param_x[n]
@@ -223,6 +310,8 @@ local nth = function(n, gen, param, state)
     end
     return return_if_not_empty(gen_x(param_x, state_x))
 end
+methods.nth = method1(nth)
+exports.nth = export1(nth)
 
 local head_call = function(state, ...)
     if state == nil then
@@ -232,18 +321,24 @@ local head_call = function(state, ...)
 end
 
 local head = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return head_call(gen_x(param_x, state_x))
+    return head_call(gen(param, state))
 end
+methods.head = method0(head)
+exports.head = export0(head)
+exports.car = exports.head
+methods.car = methods.head
 
 local tail = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    state_x = gen_x(param_x, state_x)
-    if state_x == nil then
-        return nil_gen, nil, nil
+    state = gen(param, state)
+    if state == nil then
+        return wrap(nil_gen, nil, nil)
     end
-    return gen_x, param_x, state_x
+    return wrap(gen, param, state)
 end
+methods.tail = method0(tail)
+exports.tail = export0(tail)
+exports.cdr = exports.tail
+methods.cdr = methods.tail
 
 local take_n_gen_x = function(i, state_x, ...)
     if state_x == nil then
@@ -263,9 +358,10 @@ end
 
 local take_n = function(n, gen, param, state)
     assert(n >= 0, "invalid first argument to take_n")
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return take_n_gen, {n, gen, param}, {0, state}
+    return wrap(take_n_gen, {n, gen, param}, {0, state})
 end
+methods.take_n = method1(take_n)
+exports.take_n = export1(take_n)
 
 local take_while_gen_x = function(fun, state_x, ...)
     if state_x == nil or not fun(...) then
@@ -281,29 +377,34 @@ end
 
 local take_while = function(fun, gen, param, state)
     assert(type(fun) == "function", "invalid first argument to take_while")
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return take_while_gen, {fun, gen, param}, state
+    return wrap(take_while_gen, {fun, gen, param}, state)
 end
+methods.take_while = method1(take_while)
+exports.take_while = export1(take_while)
 
 local take = function(n_or_fun, gen, param, state)
     if type(n_or_fun) == "number" then
         return take_n(n_or_fun, gen, param, state)
-    else 
+    else
         return take_while(n_or_fun, gen, param, state)
     end
 end
+methods.take = method1(take)
+exports.take = export1(take)
 
 local drop_n = function(n, gen, param, state)
     assert(n >= 0, "invalid first argument to drop_n")
-    local gen_x, param_x, state_x = iter(gen, param, state)
+    local i
     for i=1,n,1 do
-        state_x = gen_x(param_x, state_x)
-        if state_x == nil then
-            return nil_gen, nil, nil
+        state = gen(param, state)
+        if state == nil then
+            return wrap(nil_gen, nil, nil)
         end
     end
-    return gen_x, param_x, state_x
+    return wrap(gen, param, state)
 end
+methods.drop_n = method1(drop_n)
+exports.drop_n = export1(drop_n)
 
 local drop_while_x = function(fun, state_x, ...)
     if state_x == nil or not fun(...) then
@@ -312,40 +413,49 @@ local drop_while_x = function(fun, state_x, ...)
     return state_x, true, ...
 end
 
-local drop_while = function(fun, gen, param, state)
+local drop_while = function(fun, gen_x, param_x, state_x)
     assert(type(fun) == "function", "invalid first argument to drop_while")
-    local gen_x, param_x, state_x = iter(gen, param, state)
     local cont, state_x_prev
     repeat
         state_x_prev = deepcopy(state_x)
         state_x, cont = drop_while_x(fun, gen_x(param_x, state_x))
     until not cont
     if state_x == nil then
-        return nil_gen, nil, nil
+        return wrap(nil_gen, nil, nil)
     end
-    return gen_x, param_x, state_x_prev
+    return wrap(gen_x, param_x, state_x_prev)
 end
+methods.drop_while = method1(drop_while)
+exports.drop_while = export1(drop_while)
 
-local drop = function(n_or_fun, gen, param, state)
+local drop = function(n_or_fun, gen_x, param_x, state_x)
     if type(n_or_fun) == "number" then
-        return drop_n(n_or_fun, gen, param, state)
-    else 
-        return drop_while(n_or_fun, gen, param, state)
+        return drop_n(n_or_fun, gen_x, param_x, state_x)
+    else
+        return drop_while(n_or_fun, gen_x, param_x, state_x)
     end
 end
+methods.drop = method1(drop)
+exports.drop = export1(drop)
 
-local split = function(n_or_fun, gen, param, state)
-    return {take(n_or_fun, gen, param, state)},
-           {drop(n_or_fun, gen, param, state)}
+local split = function(n_or_fun, gen_x, param_x, state_x)
+    return take(n_or_fun, gen_x, param_x, state_x),
+           drop(n_or_fun, gen_x, param_x, state_x)
 end
+methods.split = method1(split)
+exports.split = export1(split)
+methods.split_at = methods.split
+exports.split_at = exports.split
+methods.span = methods.split
+exports.span = exports.split
 
 --------------------------------------------------------------------------------
 -- Indexing
 --------------------------------------------------------------------------------
 
-local index = function(x, gen, param, state) 
+local index = function(x, gen, param, state)
     local i = 1
-    for _k, r in iter(gen, param, state) do
+    for _k, r in gen, param, state do
         if r == x then
             return i
         end
@@ -353,6 +463,12 @@ local index = function(x, gen, param, state)
     end
     return nil
 end
+methods.index = method1(index)
+exports.index = export1(index)
+methods.index_of = methods.index
+exports.index_of = exports.index
+methods.elem_index = methods.index
+exports.elem_index = exports.index
 
 local indexes_gen = function(param, state)
     local x, gen_x, param_x = param[1], param[2], param[3]
@@ -371,15 +487,16 @@ local indexes_gen = function(param, state)
 end
 
 local indexes = function(x, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return indexes_gen, {x, gen_x, param_x}, {0, state_x}
-end
-
--- TODO: undocumented
-local find = function(fun, gen, param, state)
-    local gen_x, param_x, state_x = filter(fun, gen, param, state)
-    return return_if_not_empty(gen_x(param_x, state_x))
+    return wrap(indexes_gen, {x, gen, param}, {0, state})
 end
+methods.indexes = method1(indexes)
+exports.indexes = export1(indexes)
+methods.elem_indexes = methods.indexes
+exports.elem_indexes = exports.indexes
+methods.indices = methods.indexes
+exports.indices = exports.indexes
+methods.elem_indices = methods.indexes
+exports.elem_indices = exports.indexes
 
 --------------------------------------------------------------------------------
 -- Filtering
@@ -423,9 +540,12 @@ local filter_gen = function(param, state_x)
 end
 
 local filter = function(fun, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return filter_gen, {fun, gen_x, param_x}, state_x
+    return wrap(filter_gen, {fun, gen, param}, state)
 end
+methods.filter = method1(filter)
+exports.filter = export1(filter)
+methods.remove_if = methods.filter
+exports.remove_if = exports.filter
 
 local grep = function(fun_or_regexp, gen, param, state)
     local fun = fun_or_regexp
@@ -434,15 +554,18 @@ local grep = function(fun_or_regexp, gen, param, state)
     end
     return filter(fun, gen, param, state)
 end
+methods.grep = method1(grep)
+exports.grep = export1(grep)
 
 local partition = function(fun, gen, param, state)
     local neg_fun = function(...)
         return not fun(...)
     end
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return {filter(fun, gen_x, param_x, state_x)},
-           {filter(neg_fun, gen_x, param_x, state_x)}
+    return filter(fun, gen, param, state),
+           filter(neg_fun, gen, param, state)
 end
+methods.partition = method1(partition)
+exports.partition = export1(partition)
 
 --------------------------------------------------------------------------------
 -- Reducing
@@ -455,8 +578,7 @@ local foldl_call = function(fun, start, state, ...)
     return state, fun(start, ...)
 end
 
-local foldl = function(fun, start, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local foldl = function(fun, start, gen_x, param_x, state_x)
     while true do
         state_x, start = foldl_call(fun, start, gen_x(param_x, state_x))
         if state_x == nil then
@@ -465,9 +587,12 @@ local foldl = function(fun, start, gen, param, state)
     end
     return start
 end
+methods.foldl = method2(foldl)
+exports.foldl = export2(foldl)
+methods.reduce = methods.foldl
+exports.reduce = exports.foldl
 
 local length = function(gen, param, state)
-    local gen, param, state = iter(gen, param, state)
     if gen == ipairs or gen == string_gen then
         return #param
     end
@@ -478,15 +603,18 @@ local length = function(gen, param, state)
     until state == nil
     return len - 1
 end
+methods.length = method0(length)
+exports.length = export0(length)
 
 local is_null = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return gen_x(param_x, deepcopy(state_x)) == nil
+    return gen(param, deepcopy(state)) == nil
 end
+methods.is_null = method0(is_null)
+exports.is_null = export0(is_null)
 
 local is_prefix_of = function(iter_x, iter_y)
-    local gen_x, param_x, state_x = iter_tab(iter_x)
-    local gen_y, param_y, state_y = iter_tab(iter_y)
+    local gen_x, param_x, state_x = iter(iter_x)
+    local gen_y, param_y, state_y = iter(iter_y)
 
     local r_x, r_y
     for i=1,10,1 do
@@ -500,27 +628,34 @@ local is_prefix_of = function(iter_x, iter_y)
         end
     end
 end
+methods.is_prefix_of = is_prefix_of
+exports.is_prefix_of = is_prefix_of
 
-local all = function(fun, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local all = function(fun, gen_x, param_x, state_x)
     local r
     repeat
         state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x))
     until state_x == nil or not r
     return state_x == nil
 end
+methods.all = method1(all)
+exports.all = export1(all)
+methods.every = methods.all
+exports.every = exports.all
 
-local any = function(fun, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local any = function(fun, gen_x, param_x, state_x)
     local r
     repeat
         state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x))
     until state_x == nil or r
     return not not r
 end
+methods.any = method1(any)
+exports.any = export1(any)
+methods.some = methods.any
+exports.some = exports.any
 
 local sum = function(gen, param, state)
-    local gen, param, state = iter(gen, param, state)
     local s = 0
     local r = 0
     repeat
@@ -529,9 +664,10 @@ local sum = function(gen, param, state)
     until state == nil
     return s
 end
+methods.sum = method0(sum)
+exports.sum = export0(sum)
 
 local product = function(gen, param, state)
-    local gen, param, state = iter(gen, param, state)
     local p = 1
     local r = 1
     repeat
@@ -540,6 +676,8 @@ local product = function(gen, param, state)
     until state == nil
     return p
 end
+methods.product = method0(product)
+exports.product = export0(product)
 
 local min_cmp = function(m, n)
     if n < m then return n else return m end
@@ -550,7 +688,6 @@ local max_cmp = function(m, n)
 end
 
 local min = function(gen, param, state)
-    local gen, param, state = iter(gen, param, state)
     local state, m = gen(param, state)
     if state == nil then
         error("min: iterator is empty")
@@ -569,9 +706,12 @@ local min = function(gen, param, state)
     end
     return m
 end
+methods.min = method0(min)
+exports.min = export0(min)
+methods.minimum = methods.min
+exports.minimum = exports.min
 
-local min_by = function(cmp, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local min_by = function(cmp, gen_x, param_x, state_x)
     local state_x, m = gen_x(param_x, state_x)
     if state_x == nil then
         error("min: iterator is empty")
@@ -582,9 +722,12 @@ local min_by = function(cmp, gen, param, state)
     end
     return m
 end
+methods.min_by = method1(min_by)
+exports.min_by = export1(min_by)
+methods.minimum_by = methods.min_by
+exports.minimum_by = exports.min_by
 
-local max = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local max = function(gen_x, param_x, state_x)
     local state_x, m = gen_x(param_x, state_x)
     if state_x == nil then
         error("max: iterator is empty")
@@ -603,9 +746,12 @@ local max = function(gen, param, state)
     end
     return m
 end
+methods.max = method0(max)
+exports.max = export0(max)
+methods.maximum = methods.max
+exports.maximum = exports.max
 
-local max_by = function(cmp, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local max_by = function(cmp, gen_x, param_x, state_x)
     local state_x, m = gen_x(param_x, state_x)
     if state_x == nil then
         error("max: iterator is empty")
@@ -616,9 +762,12 @@ local max_by = function(cmp, gen, param, state)
     end
     return m
 end
+methods.max_by = method1(max_by)
+exports.max_by = export1(max_by)
+methods.maximum_by = methods.maximum_by
+exports.maximum_by = exports.maximum_by
 
-local totable = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local totable = function(gen_x, param_x, state_x)
     local tab, key, val = {}
     while true do
         state_x, val = gen_x(param_x, state_x)
@@ -629,9 +778,10 @@ local totable = function(gen, param, state)
     end
     return tab
 end
+methods.totable = method0(totable)
+exports.totable = export0(totable)
 
-local tomap = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
+local tomap = function(gen_x, param_x, state_x)
     local tab, key, val = {}
     while true do
         state_x, key, val = gen_x(param_x, state_x)
@@ -642,6 +792,8 @@ local tomap = function(gen, param, state)
     end
     return tab
 end
+methods.tomap = method0(tomap)
+exports.tomap = export0(tomap)
 
 --------------------------------------------------------------------------------
 -- Transformations
@@ -653,9 +805,10 @@ local map_gen = function(param, state)
 end
 
 local map = function(fun, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return map_gen, {gen_x, param_x, fun}, state_x
+    return wrap(map_gen, {gen, param, fun}, state)
 end
+methods.map = method1(map)
+exports.map = export1(map)
 
 local enumerate_gen_call = function(state, i, state_x, ...)
     if state_x == nil then
@@ -671,9 +824,10 @@ local enumerate_gen = function(param, state)
 end
 
 local enumerate = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return enumerate_gen, {gen_x, param_x}, {0, state_x}
+    return wrap(enumerate_gen, {gen, param}, {1, state})
 end
+methods.enumerate = method0(enumerate)
+exports.enumerate = export0(enumerate)
 
 local intersperse_call = function(i, state_x, ...)
     if state_x == nil then
@@ -694,9 +848,10 @@ end
 
 -- TODO: interperse must not add x to the tail
 local intersperse = function(x, gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return intersperse_gen, {x, gen_x, param_x}, {0, state_x}
+    return wrap(intersperse_gen, {x, gen, param}, {0, state})
 end
+methods.intersperse = method1(intersperse)
+exports.intersperse = export1(intersperse)
 
 --------------------------------------------------------------------------------
 -- Compositions
@@ -710,7 +865,6 @@ local function zip_gen_r(param, state, state_new, ...)
     local i = #state_new + 1
     local gen_x, param_x = param[2 * i - 1], param[2 * i]
     local state_x, r = gen_x(param_x, state[i])
-    -- print('i', i, 'state_x', state_x, 'r', r)
     if state_x == nil then
         return nil
     end
@@ -722,25 +876,42 @@ local zip_gen = function(param, state)
     return zip_gen_r(param, state, {})
 end
 
-local zip = function(...)
+-- A special hack for zip/chain to skip last two state, if a wrapped iterator
+-- has been passed
+local numargs = function(...)
     local n = select('#', ...)
+    if n >= 3 then
+        -- Fix last argument
+        local it = select(n - 2, ...)
+        if type(it) == 'table' and getmetatable(it) == iterator_mt and
+           it.param == select(n - 1, ...) and it.state == select(n, ...) then
+            return n - 2
+        end
+    end
+    return n
+end
+
+local zip = function(...)
+    local n = numargs(...)
     if n == 0 then
-        return nil_gen, nil, nil
+        return wrap(nil_gen, nil, nil)
     end
     local param = { [2 * n] = 0 }
     local state = { [n] = 0 }
 
     local i, gen_x, param_x, state_x
     for i=1,n,1 do
-        local elem = select(n - i + 1, ...)
-        gen_x, param_x, state_x = iter_tab(elem)
+        local it = select(n - i + 1, ...)
+        gen_x, param_x, state_x = rawiter(it)
         param[2 * i - 1] = gen_x
         param[2 * i] = param_x
         state[i] = state_x
     end
 
-    return zip_gen, param, state
+    return wrap(zip_gen, param, state)
 end
+methods.zip = zip
+exports.zip = zip
 
 local cycle_gen_call = function(param, state_x, ...)
     if state_x == nil then
@@ -756,9 +927,10 @@ local cycle_gen = function(param, state_x)
 end
 
 local cycle = function(gen, param, state)
-    local gen_x, param_x, state_x = iter(gen, param, state)
-    return cycle_gen, {gen_x, param_x, state_x}, deepcopy(state_x)
+    return wrap(cycle_gen, {gen, param, state}, deepcopy(state))
 end
+methods.cycle = method0(cycle)
+exports.cycle = export0(cycle)
 
 -- call each other
 local chain_gen_r1
@@ -782,23 +954,25 @@ chain_gen_r1 = function(param, state)
 end
 
 local chain = function(...)
-    local n = select('#', ...)
+    local n = numargs(...)
     if n == 0 then
-        return nil_gen, nil, nil
+        return wrap(nil_gen, nil, nil)
     end
 
     local param = { [3 * n] = 0 }
     local i, gen_x, param_x, state_x
     for i=1,n,1 do
         local elem = select(i, ...)
-        gen_x, param_x, state_x = iter_tab(elem)
+        gen_x, param_x, state_x = iter(elem)
         param[3 * i - 2] = gen_x
         param[3 * i - 1] = param_x
         param[3 * i] = state_x
     end
 
-    return chain_gen_r1, param, {1, param[3]}
+    return wrap(chain_gen_r1, param, {1, param[3]})
 end
+methods.chain = chain
+exports.chain = chain
 
 --------------------------------------------------------------------------------
 -- Operators
@@ -848,116 +1022,15 @@ operator = {
     lnot = function(a) return not a end,
     truth = function(a) return not not a end,
 }
+exports.operator = operator
+methods.operator = operator
+exports.op = operator
+methods.op = operator
 
 --------------------------------------------------------------------------------
 -- module definitions
 --------------------------------------------------------------------------------
 
-local exports = {
-    ----------------------------------------------------------------------------
-    -- Basic
-    ----------------------------------------------------------------------------
-    iter = iter,
-    each = each,
-    for_each = each, -- an alias
-    foreach = each, -- an alias
-
-    ----------------------------------------------------------------------------
-    -- Generators
-    ----------------------------------------------------------------------------
-    range = range,
-    duplicate = duplicate,
-    xrepeat = duplicate, -- an alias
-    replicate = duplicate, -- an alias
-    tabulate = tabulate,
-    ones = ones,
-    zeros = zeros,
-    rands = rands,
-
-    ----------------------------------------------------------------------------
-    -- Slicing
-    ----------------------------------------------------------------------------
-    nth = nth,
-    head = head,
-    car = head, -- an alias
-    tail = tail,
-    cdr = tail, -- an alias
-    take_n = take_n,
-    take_while = take_while,
-    take = take,
-    drop_n = drop_n,
-    drop_while = drop_while,
-    drop = drop,
-    split = split,
-    split_at = split, -- an alias
-    span = split, -- an alias
-
-    ----------------------------------------------------------------------------
-    -- Indexing
-    ----------------------------------------------------------------------------
-    index = index,
-    index_of = index, -- an alias
-    elem_index = index, -- an alias
-    indexes = indexes,
-    indices = indexes, -- an alias
-    elem_indexes = indexes, -- an alias
-    elem_indices = indexes, -- an alias
-    find = find,
-
-    ----------------------------------------------------------------------------
-    -- Filtering
-    ----------------------------------------------------------------------------
-    filter = filter,
-    remove_if = filter, -- an alias
-    grep = grep,
-    partition = partition,
-
-    ----------------------------------------------------------------------------
-    -- Reducing
-    ----------------------------------------------------------------------------
-    foldl = foldl,
-    reduce = foldl, -- an alias
-    length = length,
-    is_null = is_null,
-    is_prefix_of = is_prefix_of,
-    all = all,
-    every = all, -- an alias
-    any = any,
-    some = any, -- an alias
-    sum = sum,
-    product = product,
-    min = min,
-    minimum = min, -- an alias
-    min_by = min_by,
-    minimum_by = min_by, -- an alias
-    max = max,
-    maximum = max, -- an alias
-    max_by = max_by,
-    maximum_by = max_by, -- an alias
-    totable = totable,
-    tomap = tomap,
-
-    ----------------------------------------------------------------------------
-    -- Transformations
-    ----------------------------------------------------------------------------
-    map = map,
-    enumerate = enumerate,
-    intersperse = intersperse,
-
-    ----------------------------------------------------------------------------
-    -- Compositions
-    ----------------------------------------------------------------------------
-    zip = zip,
-    cycle = cycle,
-    chain = chain,
-
-    ----------------------------------------------------------------------------
-    -- Operators
-    ----------------------------------------------------------------------------
-    operator = operator,
-    op = operator -- an alias
-}
-
 -- a special syntax sugar to export all functions to the global table
 setmetatable(exports, {
     __call = function(t)