Vararg The Second Class Citizen
Varargs ("...") [1] are not [first class] objects in Lua 5.1, and this leads to some limitations in expression. Some of these issues and their workarounds are given here.
Lua 5.1 vararg (...) handling is a bit limited. For example, it doesn't permit things like this:
    function tuple(...)
      return function() return ... end
    end
    --Gives error "cannot use '...' outside a vararg function near '...'"
(Some comments on that are in LuaList:2007-03/msg00249.html.)
You might want to use such a function to temporarily store away the return values of a function call, do some other work, and then retrieve those stored return values again. The following function would use this hypothetical tuple function to add trace statements around a given function:
    --Wraps a function with trace statements.
    function trace(f)
      return function(...)
        print("begin", f)
        local result = tuple(f(...))
        print("end", f)
        return result()
      end
    end
    test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
    print("returns:", test(2,3,nil))
    -- Desired Output:
    -- begin function: 0x687350
    -- calc 2 3 nil
    -- end function: 0x687350
    -- returns: 5 nil
Still, there are ways to achieve this in Lua.
{...} and unpack
You could instead implement trace with the table constructor {...} and unpack:
    --Wraps a function with trace statements.
    function trace(f)
      return function(...)
        print("begin", f)
        local result = {f(...)}
        print("end", f)
        return unpack(result)
      end
    end
    test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
    print("returns:", test(2,3,nil))
    -- Output:
    -- begin function: 0x6869d0
    -- calc 2 3 nil
    -- end function: 0x6869d0
    -- returns: 5
Unfortunately, it misses the nil return value: nils cannot be stored explicitly in tables, and in particular {...} does not preserve information about trailing nils (this is further discussed in StoringNilsInTables).
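To see the trailing-nil problem concretely, here is a small sketch (the helper names count_args and count_table are hypothetical) comparing select('#', ...), which counts every argument including trailing nils, with the length of the table {...}. When the only nil is the last element, the table has a single border, so #t is guaranteed to be 1 below; with nils in the middle, # may return any border and should not be relied on.

```lua
-- Counts every value in the vararg, trailing nils included.
local function count_args(...)
  return select('#', ...)
end

-- Counts the values that survive in {...}: the length operator
-- returns a border, so the trailing nil is not counted.
local function count_table(...)
  local t = {...}
  return #t
end

print(count_args(5, nil))   -- 2
print(count_table(5, nil))  -- 1
```

This is exactly what happens to the (5, nil) returned by the traced function above: the count of 2 is lost once the values go into a plain table.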
{...} and unpack with n
The following improvement on the previous solution properly handles nils:
    function pack2(...) return {n=select('#', ...), ...} end
    function unpack2(t) return unpack(t, 1, t.n) end
    --Wraps a function with trace statements.
    function trace(f)
      return function(...)
        print("begin", f)
        local result = pack2(f(...))
        print("end", f)
        return unpack2(result)
      end
    end
    test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
    print("returns:", test(2,3,nil))
    -- Output:
    -- begin function: 0x6869d0
    -- calc 2 3 nil
    -- end function: 0x6869d0
    -- returns: 5 nil
A variant noted by Shirik is:
    local function tuple(...)
      local n = select('#', ...)
      local t = {...}
      return function() return unpack(t, 1, n) end
    end
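As a sketch, Shirik's tuple can be dropped into the hypothetical trace from the beginning of this page, which then produces the desired output. The `local unpack = table.unpack or unpack` line is an addition so the example also runs on Lua 5.2 and later, where unpack moved into the table library.

```lua
-- Lua 5.1 has a global unpack; Lua 5.2+ moved it to table.unpack.
local unpack = table.unpack or unpack

-- Shirik's variant: remember the count so trailing nils survive.
local function tuple(...)
  local n = select('#', ...)
  local t = {...}
  return function() return unpack(t, 1, n) end
end

-- The trace function from the start of this page, now runnable.
local function trace(f)
  return function(...)
    print("begin", f)
    local result = tuple(f(...))
    print("end", f)
    return result()
  end
end

local test = trace(function(x, y, z)
  print("calc", x, y, z)
  return x + y, z
end)
print("returns:", test(2, 3, nil))
```

The only state the closure keeps is the table and the count n, so the explicit count is what lets unpack reproduce the trailing nil.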
nil Placeholders
The following approach swaps the nils with placeholders that can be stored in tables. It's probably less efficient here, but the approach might be usable elsewhere.
    local NIL = {} -- placeholder value for nil, storable in table.
    function pack2(...)
      local n = select('#', ...)
      local t = {}
      for i = 1,n do
        local v = select(i, ...)
        t[i] = (v == nil) and NIL or v
      end
      return t
    end
    function unpack2(t) --caution: modifies t
      if #t == 0 then
        return
      else
        local v = table.remove(t, 1)
        if v == NIL then v = nil end
        return v, unpack2(t)
      end
    end
    --Wraps a function with trace statements.
    function trace(f)
      return function(...)
        print("begin", f)
        local result = pack2(f(...))
        print("end", f)
        return unpack2(result)
      end
    end
    test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
    print("returns:", test(2,3,nil))
    -- Output:
    -- begin function: 0x687350
    -- calc 2 3 nil
    -- end function: 0x687350
    -- returns: 5 nil
Here are more efficient implementations of pack2 and unpack2:
    local NIL = {} -- placeholder value for nil, storable in table.
    function pack2(...)
      local n = select('#', ...)
      local t = {...}
      for i = 1,n do
        if t[i] == nil then t[i] = NIL end
      end
      return t
    end
    function unpack2(t, k, n)
      k = k or 1
      n = n or #t
      if k > n then return end
      local v = t[k]
      if v == NIL then v = nil end
      return v, unpack2(t, k + 1, n)
    end
See also StoringNilsInTables.
Tables can be avoided if we use continuation-passing style (CPS) ([Wikipedia]), as below. We could expect this to be a bit more efficient.
    function trace(f)
      local helper = function(...)
        print("end", f)
        return ...
      end
      return function(...)
        print("begin", f)
        return helper(f(...))
      end
    end
    test = trace(function(x,y,z) print("calc", x,y,z); return x+y, z end)
    print("returns:", test(2,3,nil))
    -- Output:
    -- begin function: 0x686b10
    -- calc 2 3 nil
    -- end function: 0x686b10
    -- returns: 5 nil
The CPS approach was also used in the RiciLake's string split function (LuaList:2006-12/msg00414.html).
Another approach is code generation, which compiles a separate constructor for each tuple size. There is some initial overhead building the constructors, but the constructors themselves can be optimally implemented. The tuple function used previously can be implemented as follows:
    local function build_constructor(n)
      local t = {}; for i = 1,n do t[i] = "a" .. i end
      local arglist = table.concat(t, ',')
      local src = "return function(" .. arglist ..
                  ") return function() return " .. arglist .. " end end"
      return assert(loadstring(src))()
    end
    function tuple(...)
      local construct = build_constructor(select('#', ...))
      return construct(...)
    end
To avoid the overhead of code generation on each store, we can memoize the build_constructor function (for background see [Wikipedia:Memoization] and FuncTables).
    function Memoize(fn)
      return setmetatable({}, {
        __index = function(t, k) local val = fn(k); t[k] = val; return val end,
        __call  = function(t, k) return t[k] end
      })
    end
    build_constructor = Memoize(build_constructor)
A more complete example of tuples implemented via code generation is in FunctionalTuples.
The code building/memoization technique and the above Memoize function are based on some previous related examples by RiciLake such as [Re: The Curry Challenge].
Note also that if your wrapped function raises exceptions, you would want to pcall it as well (LuaList:2007-02/msg00165.html).
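A minimal sketch of what that might look like, combining the CPS helper shown earlier with pcall (the helper name finish is hypothetical): the wrapper prints "end" whether or not f raises, and then re-raises the original error value. One trade-off of this design is that the re-raised error no longer carries the original traceback.

```lua
-- trace() variant that still prints "end" when f raises an error,
-- then re-raises that error to the caller.
local function trace(f)
  -- Receives pcall's results: ok flag, then f's returns or the error.
  local function finish(ok, ...)
    print("end", f)
    if not ok then
      error((...), 0)  -- level 0: re-raise the error value unchanged
    end
    return ...
  end
  return function(...)
    print("begin", f)
    return finish(pcall(f, ...))
  end
end

local safe = trace(function(x) return x * 2, nil end)
local a, b = safe(21)                   -- a == 42, b == nil
local bad = trace(function() error("boom", 0) end)
local ok, err = pcall(bad)              -- ok == false, err == "boom"
```

Because finish is called in tail position with pcall's results, the return values (including trailing nils) pass through unchanged, just as in the CPS version.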
The following approach is purely functional (no tables) and avoids code generation. It's not necessarily the most efficient as it creates a function per tuple element.
    function helper(n, first, ...)
      if n == 1 then
        return function() return first end
      else
        local rest = helper(n-1, ...)
        return function() return first, rest() end
      end
    end
    function tuple(...)
      local n = select('#', ...)
      return (n == 0) and function() end or helper(n, ...)
    end
    -- TEST
    local function join(...)
      local t = {n=select('#', ...), ...}
      for i=1,t.n do t[i] = tostring(t[i]) end
      return table.concat(t, ",")
    end
    local t = tuple()
    assert(join(t()) == "")
    t = tuple(2,3,nil,4,nil)
    assert(join(t()) == "2,3,nil,4,nil")
    print "done"
Another idea is to use coroutines:
    do
      local function helper(...)
        coroutine.yield()
        return ...
      end
      function pack2(...)
        local o = coroutine.create(helper)
        coroutine.resume(o, ...)
        return o
      end
      function unpack2(o)
        return select(2, coroutine.resume(o))
      end
    end
A similar suggestion was posted in LuaList:2007-02/msg00142.html. That can be inefficient, though: RiciLake notes that a minimal coroutine occupies slightly more than 1k plus malloc overhead (on FreeBSD it totals close to 2k), and the largest part is the stack, which defaults to 45 slots at 12 or 16 bytes each.
It is not necessary to create a new coroutine on each call. The following approach is rather efficient, and the recursion uses a tail call:
    local yield = coroutine.yield
    local resume = coroutine.resume
    local function helper(...)
      yield()
      return helper(yield(...))
    end
    local function make_stack() return coroutine.create(helper) end
    -- Example
    local stack = make_stack()
    local function trace(f)
      return function(...)
        print("begin", f)
        resume(stack, f(...))
        print("end", f)
        return select(2, resume(stack))
      end
    end
Tuples can be implemented in C as a closure containing the tuple elements as upvalues. This is demonstrated in Section 27.3 of Programming In Lua, 2nd Ed [2].
The speeds of the above solutions are compared in the following benchmark.
    -- Avoid global table accesses in benchmark.
    local time = os.time
    local unpack = unpack
    local select = select
    -- Benchmarks function func using chunks of nbase iterations for duration
    -- seconds in ntrials trials.
    local function bench(duration, nbase, ntrials, func, ...)
      assert(nbase % 10 == 0)
      local nloops = nbase/10
      local ts = {}
      for k=1,ntrials do
        local t1, t2 = time()
        local nloops2 = 0
        repeat
          for j=1,nloops do
            func(...) func(...) func(...) func(...) func(...)
            func(...) func(...) func(...) func(...) func(...)
          end
          t2 = time()
          nloops2 = nloops2 + 1
        until t2 - t1 >= duration
        local t = (t2-t1) / (nbase * nloops2)
        ts[k] = t
      end
      return unpack(ts)
    end
    local function print_bench(name, duration, nbase, ntrials, func, ...)
      local fmt = (" %0.1e"):rep(ntrials)
      print(string.format("%25s:" .. fmt, name,
                          bench(duration, nbase, ntrials, func, ...) ))
    end
    -- Test all methods.
    local function test_suite(duration, nbase, ntrials)
      print("name" .. (", t"):rep(ntrials) .. " (times in sec)")
      do -- This is a base-line.
        local function trace(f)
          return function(...) return f(...) end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("(control)", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local function trace(f)
          local function helper(...)
            return ...
          end
          return function(...)
            return helper(f(...))
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("CPS", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local yield = coroutine.yield
        local resume = coroutine.resume
        local function helper(...)
          yield(); return helper(yield(...))
        end
        local function make_stack() return coroutine.create(helper) end
        local stack = make_stack()
        local function trace(f)
          return function(...)
            resume(stack, f(...))
            return select(2, resume(stack))
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("Coroutine", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local function trace(f)
          return function(...)
            local t = {f(...)}
            return unpack(t)
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("{...} and unpack", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local function trace(f)
          return function(...)
            -- Note: n counts the wrapper's arguments, not f's results;
            -- the two happen to match (5) in this benchmark.
            local n = select('#', ...)
            local t = {f(...)}
            return unpack(t, 1, n)
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("{...} and unpack with n", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local NIL = {}
        local function pack2(...)
          local n = select('#', ...)
          local t = {...}
          for i=1,n do
            local v = t[i]
            if t[i] == nil then t[i] = NIL end
          end
          return t
        end
        local function unpack2(t)
          local n = #t
          for i=1,n do
            local v = t[i]
            if t[i] == NIL then t[i] = nil end
          end
          return unpack(t, 1, n)
        end
        local function trace(f)
          return function(...)
            local t = pack2(f(...))
            return unpack2(t)
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("nil Placeholder", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        -- This is a simplified version of Code Generation for comparison.
        local function tuple(a1,a2,a3,a4,a5)
          return function() return a1,a2,a3,a4,a5 end
        end
        local function trace(f)
          return function(...)
            local t = tuple(f(...))
            return t()
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("Closure", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local function build_constructor(n)
          local t = {}; for i = 1,n do t[i] = "a" .. i end
          local arglist = table.concat(t, ',')
          local src = "return function(" .. arglist ..
                      ") return function() return " .. arglist .. " end end"
          return assert(loadstring(src))()
        end
        local cache = {}
        local function tuple(...)
          local n = select('#', ...)
          local construct = cache[n]
          if not construct then
            construct = build_constructor(n)
            cache[n] = construct
          end
          return construct(...)
        end
        local function trace(f)
          return function(...)
            local t = tuple(f(...))
            return t()
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("Code Generation", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      do
        local function helper(n, first, ...)
          if n == 1 then
            return function() return first end
          else
            local rest = helper(n-1, ...)
            return function() return first, rest() end
          end
        end
        local function tuple(...)
          local n = select('#', ...)
          return (n == 0) and function() end or helper(n, ...)
        end
        local function trace(f)
          return function(...)
            local t = tuple(f(...))
            return t()
          end
        end
        local f = trace(function() return 11,12,13,14,15 end)
        print_bench("Functional, Recursive", duration, nbase, ntrials, f, 1,2,3,4,5)
      end
      -- NOTE: Upvalues in C Closure not benchmarked here.
      print "done"
    end
    test_suite(10, 1000000, 3)
    test_suite(10, 1000000, 1) -- recheck
Results:
    (Pentium4/3GHz)
    name (times in sec)        t         t         t
    (control):                 3.8e-007  3.8e-007  4.0e-007
    CPS:                       5.6e-007  6.3e-007  5.9e-007
    Coroutine:                 1.7e-006  1.7e-006  1.7e-006
    {...} and unpack:          2.2e-006  2.2e-006  2.4e-006
    {...} and unpack with n:   2.5e-006  2.5e-006  2.5e-006
    nil Placeholder:           5.0e-006  4.7e-006  4.7e-006
    Closure:                   5.0e-006  5.0e-006  5.0e-006
    Code Generation:           5.5e-006  5.5e-006  5.5e-006
    Functional, Recursive:     1.3e-005  1.3e-005  1.3e-005
    done
CPS is the fastest, followed by coroutines (both operate on the stack). Tables take a bit more time than the coroutine approach, though coroutines could be even faster if we didn't have the select on the resume. Closures are a few times slower still (including when generalized with code generation), and up to an order of magnitude slower when generalized with the Functional, Recursive approach.
For a tuple size of 1, we get:
    name (times in sec)        t         t         t
    (control):                 2.9e-007  2.8e-007  2.7e-007
    CPS:                       4.3e-007  4.3e-007  4.3e-007
    Coroutine:                 1.4e-006  1.4e-006  1.4e-006
    {...} and unpack:          2.0e-006  2.2e-006  2.2e-006
    {...} and unpack with n:   2.4e-006  2.5e-006  2.4e-006
    nil Placeholder:           3.3e-006  3.3e-006  3.3e-006
    Closure:                   2.0e-006  2.0e-006  2.0e-006
    Code Generation:           2.2e-006  2.5e-006  2.2e-006
    Functional, Recursive:     2.5e-006  2.4e-006  2.2e-006
    done
For a tuple size of 20, we get:
    name (times in sec)        t         t         t
    (control):                 8.3e-007  9.1e-007  9.1e-007
    CPS:                       1.3e-006  1.3e-006  1.1e-006
    Coroutine:                 2.7e-006  2.7e-006  2.7e-006
    {...} and unpack:          3.0e-006  3.2e-006  3.0e-006
    {...} and unpack with n:   3.7e-006  3.3e-006  3.7e-006
    nil Placeholder:           1.0e-005  1.0e-005  1.0e-005
    Closure:                   1.8e-005  1.8e-005  1.8e-005
    Code Generation:           1.9e-005  1.8e-005  1.9e-005
    Functional, Recursive:     5.7e-005  5.7e-005  5.8e-005
    done
Notice that the times for the table construction methods vary relatively little with tuple size (due to the initial overhead of constructing a table). In contrast, the run times of the closure-based methods vary much more with tuple size.
Problem: given two variable-length lists (e.g. the return values of two functions, f and g, that each return multiple values), combine them into a single list.
This can be a problem because Lua discards all but the first return value of a function call unless the call is the last item in an expression list:
    local function f() return 1,2,3 end
    local function g() return 4,5,6 end
    print(f(), g()) -- prints 1 4 5 6
Besides the obvious solutions of converting the lists into objects such as tables (via the vararg-saving methods above), there are ways to do this with only function calls.
The following combines lists recursively by prepending only one element at a time and delaying evaluation of one of the lists.
    local function helper(f, n, a, ...)
      if n == 0 then return f() end
      return a, helper(f, n-1, ...)
    end
    local function combine(f, ...)
      local n = select('#', ...)
      return helper(f, n, ...)
    end
    -- TEST
    local function join(...)
      local t = {n=select('#', ...), ...}
      for i=1,t.n do t[i] = tostring(t[i]) end
      return table.concat(t, ",")
    end
    local function f0() return end
    local function f1() return 1 end
    local function g1() return 2 end
    local function f3() return 1,2,3 end
    local function g3() return 4,5,6 end
    assert(join(combine(f0, f0())) == "")
    assert(join(combine(f0, f1())) == "1")
    assert(join(combine(f1, f0())) == "1")
    assert(join(combine(g1, f1())) == "1,2")
    assert(join(combine(g3, f3())) == "1,2,3,4,5,6")
    print "done"
Problem: Return a list consisting of the first N elements in another list.
The select function allows selecting the last N elements in a list, but there is no built-in function for selecting the first N elements.
    local function helper(n, a, ...)
      if n == 0 then return end
      return a, helper(n-1, ...)
    end
    local function first(k, ...)
      return helper(k, ...)
    end
    -- TEST
    local function join(...)
      local t = {n=select('#', ...), ...}
      for i=1,t.n do t[i] = tostring(t[i]) end
      return table.concat(t, ",")
    end
    local function f0() return end
    local function f1() return 1 end
    local function f8() return 1,2,3,4,5,6,7,8 end
    assert(join(first(0, f0())) == "")
    assert(join(first(0, f1())) == "")
    assert(join(first(1, f1())) == "1")
    assert(join(first(0, f8())) == "")
    assert(join(first(1, f8())) == "1")
    assert(join(first(2, f8())) == "1,2")
    assert(join(first(8, f8())) == "1,2,3,4,5,6,7,8")
    print "done"
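For comparison, the "last N" case that select already covers can be wrapped in a small helper. This is a sketch (the name last is hypothetical) that passes only positive indices to select, so it does not depend on negative-index support:

```lua
-- Return the last k values of the vararg (all of them if k >= count,
-- none if k <= 0). Trailing nils are counted like any other value.
local function last(k, ...)
  local n = select('#', ...)
  if k <= 0 then return end
  if k >= n then return ... end
  return select(n - k + 1, ...)
end

print(last(2, 1, 2, 3))   -- prints 2 and 3
```

Unlike the recursive first, this needs no helper recursion, because select(i, ...) already returns everything from position i onward.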
Note: if the number of elements is fixed, the solution is easier:
    local function firstthree(a,b,c) return a,b,c end
    assert(join(firstthree(f8())) == "1,2,3") -- TEST
Code generation approaches can be based on this.
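As a sketch of that idea, the code below generates (and caches) one fixed-arity selector per k, in the spirit of build_constructor above. The names first_gen, selector and compile are hypothetical; compile falls back from loadstring (Lua 5.1) to load (Lua 5.2+), both of which accept a source string.

```lua
-- loadstring compiles a source string in Lua 5.1; Lua 5.2+ uses load.
local compile = loadstring or load

local cache = {}  -- k -> generated selector function

-- Return a function equivalent to
--   function(a1, ..., ak) return a1, ..., ak end
-- generating it on first use. Extra arguments are dropped and
-- missing ones come back as nil.
local function selector(k)
  local fn = cache[k]
  if not fn then
    local names = {}
    for i = 1, k do names[i] = "a" .. i end
    local arglist = table.concat(names, ",")
    local src = "return function(" .. arglist .. ") return " .. arglist .. " end"
    fn = assert(compile(src))()
    cache[k] = fn
  end
  return fn
end

-- Return the first k values of the vararg.
local function first_gen(k, ...)
  return selector(k)(...)
end
```

After the first call for a given k, each call costs only a table lookup plus one plain function call, which is the same trade-off as the memoized tuple constructors above.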
Problem: Append one element to a list.
Note that prepending one element to a list is simple: a, ... (or {a, ...} when building a table).
    local function helper(a, n, b, ...)
      if n == 0 then
        return a
      else
        return b, helper(a, n-1, ...)
      end
    end
    local function append(a, ...)
      return helper(a, select('#', ...), ...)
    end
Note: if the number of elements is fixed, the solution is easier:
    local function append3(e, a, b, c) return a, b, c, e end
Problem: Reverse a list.
    local function helper(n, a, ...)
      if n > 0 then
        return append(a, helper(n-1, ...))
      end
    end
    local function reverse(...)
      return helper(select('#', ...), ...)
    end
Note: if the number of elements is fixed, the solution is easier:
    local function reverse3(a,b,c) return c,b,a end
Problem: Implement the map [3] function over a list.
    local function helper(f, n, a, ...)
      if n > 0 then
        return f(a), helper(f, n-1, ...)
      end
    end
    local function map(f, ...)
      return helper(f, select('#', ...), ...)
    end
Problem: Implement the filter [4] function over a list.
    local function helper(f, n, a, ...)
      if n > 0 then
        if f(a) then
          return a, helper(f, n-1, ...)
        else
          return helper(f, n-1, ...)
        end
      end
    end
    local function grep(f, ...)
      return helper(f, select('#', ...), ...)
    end
Problem: Iterate over all elements in the vararg.
    for n=1,select('#',...) do
      local e = select(n,...)
      -- something with e
    end
If your vararg contains no nil elements (ipairs stops at the first nil), you can also use the following:
    for _, e in ipairs({...}) do
      -- something with e
    end
If you wish to use an iterator function without creating a garbage table every time, you can use the following:
    do
      local i, t, l = 0, {}
      local function iter(...)
        i = i + 1
        if i > l then return end
        return i, t[i]
      end
      function vararg(...)
        i = 0
        l = select("#", ...)
        for n = 1, l do
          t[n] = select(n, ...)
        end
        for n = l+1, #t do
          t[n] = nil
        end
        return iter
      end
    end
    for i, v in vararg(1, "a", false, nil) do print(i, v) end -- test
    -- Output:
    -- 1 1
    -- 2 a
    -- 3 false
    -- 4 nil