Curried Lua
You can implement curried functions in all languages that support functions as first-class objects. For example, there's a little [tutorial about curried JavaScript].
Here is a small Lua example of a curried function:
```lua
function sum(number)
  return function(anothernumber)
    return number + anothernumber
  end
end

local f = sum(5)
print(f(3)) --> 8
```
-- WalterCruz
Here is another, contributed by [GavinWraith], which takes a variable number of arguments, terminated by an empty call "()":
```lua
function addup(x)
  local sum = 0
  local function f(n)
    if type(n) == "number" then
      sum = sum + n
      return f
    else
      return sum
    end
  end
  return f(x)
end

print(addup (1) (2) (3) ()) --> 6
print(addup (4) (5) (6) ()) --> 15
```
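The terminate-with-"()" trick in addup is not specific to addition. As a minimal sketch, here is a hypothetical helper collect(op, init) (not part of the original page) that folds any binary function over the arguments until an empty call. Unlike addup, it treats a nil argument, rather than any non-number, as the terminator, so it also works for strings and other values.

```lua
-- hypothetical generalization of addup: fold op over the arguments until ()
function collect(op, init)
  return function (x)
    local acc = init            -- fresh accumulator for every chain
    local function f(n)
      if n ~= nil then
        acc = op(acc, n)
        return f                -- keep accepting arguments
      else
        return acc              -- empty call: return the result
      end
    end
    return f(x)
  end
end

mulup = collect(function (a, b) return a * b end, 1)
print(mulup (2) (3) (4) ()) --> 24
```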
Although these pre-curried functions are useful, what we would really like is a general-purpose function that can perform the curry operation on any other function. This is possible because Lua functions can be passed to, and returned from, a "higher-order function": a function that takes functions as arguments or returns them. The following curry function is an example of this; it curries a 2-argument function:
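```lua
function curry(f)
  return function (x)
    return function (y)
      return f(x, y)
    end
  end
end

-- math.pow was deprecated in Lua 5.3 and removed in 5.4, so use the ^ operator
-- (^ always produces a float, so Lua 5.3+ prints 16.0, 8.0, ...)
local function pow(x, y) return x ^ y end

powcurry = curry(pow)
print(powcurry (2) (4)) --> 16

pow2 = powcurry(2)
print(pow2(3)) --> 8
print(pow2(4)) --> 16
print(pow2(8)) --> 256
```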
Going from currying 2 arguments to currying n arguments is a bit more complicated. We need to store an indeterminate number of partial applications, and unfortunately Lua 5.1 offers no way to find out how many arguments a function expects: a Lua function can be called with any number of arguments, too many or too few, without raising an error. So we have to tell the curry function how many single-argument calls to accept before applying the collected arguments to the original function.
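To see the problem concretely, here is a small demonstration (the function name two is made up for illustration). Missing arguments arrive as nil and extra ones are silently discarded, so nothing about a call tells a generic curry function when it has collected enough arguments.

```lua
-- Lua does not check arity: missing arguments become nil, extra ones are dropped
function two(a, b)
  return a, b
end

print(two(1))              --> 1   nil
print(two(1, 2, 3))        --> 1   2
print(select("#", two(1))) --> 2   (two always returns exactly two values)
```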
(The n-argument curry function below is freely available from http://tinylittlelife.org/?p=249, which includes a full discussion of how to tackle this problem.)
```lua
-- curry(func, num_args) : take a function requiring a tuple for num_args arguments
-- and turn it into a series of 1-argument functions
-- e.g.: you have a function dosomething(a, b, c)
-- curried_dosomething = curry(dosomething, 3) -- we want to curry 3 arguments
-- curried_dosomething (a1) (b1) (c1) -- returns the result of dosomething(a1, b1, c1)
-- partial_dosomething1 = curried_dosomething (a_value) -- returns a function
-- partial_dosomething2 = partial_dosomething1 (b_value) -- returns a function
-- partial_dosomething2 (c_value) -- returns the result of dosomething(a_value, b_value, c_value)
function curry(func, num_args)
  -- currying 2-argument functions seems to be the most popular application
  num_args = num_args or 2

  -- no sense currying for 1 arg or less
  if num_args <= 1 then return func end

  -- helper takes an argtrace function, and number of arguments remaining to be applied
  local function curry_h(argtrace, n)
    if 0 == n then
      -- kick off argtrace, reverse argument list, and call the original function
      return func(reverse(argtrace()))
    else
      -- "push" argument (by building a wrapper function) and decrement n
      return function (onearg)
        return curry_h(function () return onearg, argtrace() end, n - 1)
      end
    end
  end

  -- push the terminal case of argtrace into the function first
  return curry_h(function () return end, num_args)
end

-- reverse(...) : take some tuple and return a tuple of elements in reverse order
--
-- e.g. "reverse(1,2,3)" returns 3,2,1
function reverse(...)
  --reverse args by building a function to do it, similar to the unpack() example
  local function reverse_h(acc, v, ...)
    if 0 == select('#', ...) then
      return v, acc()
    else
      return reverse_h(function () return v, acc() end, ...)
    end
  end

  -- initial acc is the end of the list
  return reverse_h(function () return end, ...)
end
```
The above code is Lua 5.1 compatible.
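The argtrace trick in curry_h can be tried in isolation. In this minimal sketch, push is a hypothetical helper (not from the original code) that does what curry_h does for each argument: it wraps the previous trace in a new closure that returns the new value followed by everything stored so far. Calling the trace yields the values newest first, which is why curry needs reverse, and because the values live in closures rather than in a table, nil arguments are preserved.

```lua
-- hypothetical helper: wrap the previous trace in a closure that prepends v
function push(trace, v)
  return function () return v, trace() end
end

-- the terminal case, as in curry_h: a function returning no values
empty = function () return end

stored = push(push(push(empty, "a"), nil), "c")
print(select("#", stored())) --> 3
print(stored())              --> c   nil   a
```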
Since Lua 5.2 (and LuaJIT 2.0), debug.getinfo can report how many fixed parameters a function declares (the nparams field, requested with the "u" option), so we can write a more practical curry function that mixes currying with partial application: each call may supply one or more of the remaining arguments. Here's the code:
```lua
-- Lua 5.2+ provides table.unpack; Lua 5.1 and LuaJIT have the global unpack
local unpack = table.unpack or unpack

function curry(func, num_args)
  -- debug.getinfo(f, "u").nparams needs Lua 5.2+ or LuaJIT 2.0
  num_args = num_args or debug.getinfo(func, "u").nparams
  if num_args < 2 then return func end

  -- helper keeps the arguments collected so far in args, with an explicit
  -- count n, so nil arguments and table arguments are passed through unchanged
  local function helper(args, n)
    if n >= num_args then
      return func(unpack(args, 1, n))
    end
    -- each call may supply one or more arguments (partial application)
    return function (...)
      -- copy into a new table so a partially applied function can be reused
      local new_args = {}
      for i = 1, n do new_args[i] = args[i] end
      for i = 1, select("#", ...) do new_args[n + i] = (select(i, ...)) end
      return helper(new_args, n + select("#", ...))
    end
  end
  return helper({}, 0)
end

function multiplyAndAdd (a, b, c) return a * b + c end

curried_multiplyAndAdd = curry(multiplyAndAdd)
multiplyBySevenAndAdd = curried_multiplyAndAdd(7)
multiplySevenByEightAndAdd_v1 = multiplyBySevenAndAdd(8)
multiplySevenByEightAndAdd_v2 = curried_multiplyAndAdd(7, 8)

assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v1(9))
assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v2(9))
assert(multiplyAndAdd(7, 8, 9) == multiplyBySevenAndAdd(8, 9))
assert(multiplyAndAdd(7, 8, 9) == curried_multiplyAndAdd(7, 8, 9))
```
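To check what the automatic arity detection in the code above actually sees, here is a small sketch (Lua 5.2 or later; the function names three and atLeastOne are made up) that prints the nparams and isvararg fields returned by debug.getinfo. Only fixed parameters are counted, and the reference manual states that C functions always report 0 parameters and are always vararg.

```lua
-- Lua 5.2+: the "u" option fills in the nups, nparams and isvararg fields
function three(a, b, c) return a + b + c end
function atLeastOne(a, ...) return a end

local info = debug.getinfo(three, "u")
print(info.nparams, info.isvararg) --> 3   false

info = debug.getinfo(atLeastOne, "u")
print(info.nparams, info.isvararg) --> 1   true

-- C functions report no fixed parameters
info = debug.getinfo(math.max, "u")
print(info.nparams, info.isvararg) --> 0   true
```

So with the curry above, curry(math.max) returns math.max unchanged (0 is less than 2); for C or vararg functions you would pass num_args explicitly, e.g. curry(math.max, 3) (1) (5) (2) collects three arguments and returns 5.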