Method Chaining Wrapper
(" test "):trim():repeatchars(2):upper() --> "TTEESSTT"
(function(x,y) return x*y end):curry(2) --> (function(y) return 2*y end)
We can do this with the debug library (debug.setmetatable) [5]. The downside is that each built-in type has a single, common metatable. Modifying this metatable causes a global side-effect, which is a potential source of conflict between independently maintained modules in a program. Functions in the debug library are often discouraged in regular code for good reason. Many people avoid injecting into these global metatables, while others find it too convenient to avoid [3][6][ExtensionProposal]. Some have even asked why objects of built-in types don't have their own metatables [7].
...
debug.setmetatable("", string_mt)
debug.setmetatable(function()end, function_mt)
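To make the global side-effect concrete, here is a minimal sketch for the function type. The name function_mt comes from the snippet above; the body of curry is an assumption written for illustration, not the author's definition.

```lua
-- Sketch: give every function value a "curry" method via the shared metatable.
local function_mt = { __index = {} }

-- Hypothetical curry: fixes the first argument of f.
function function_mt.__index.curry(f, x)
  return function(...) return f(x, ...) end
end

-- All functions share one metatable, so this affects the whole program.
debug.setmetatable(function() end, function_mt)

local double = (function(x, y) return x * y end):curry(2)
print(double(21))          -- 42
print(print.curry ~= nil)  -- true: unrelated functions see the method too
```

The last line is the point: a module that never asked for curry still observes it, which is the kind of conflict described above.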
We could instead use just standalone functions:
(repeatchars(trim("test"), 2)):upper()
curry(function(x,y) return x*y end, 2)
This is the simplest solution. Simple solutions are often good ones. Nevertheless, there can be a certain level of discordance with some operations being method calls and some being standalone global functions, along with the reordering that results.
One solution that avoids touching the global metatables is to wrap the object inside our own class, perform operations on the wrapper in a method call chain, and then unwrap the object.
Examples would look like this:
S" test ":trim():repeatchars(2):upper()() --> TTEESSTT
S" TEST ":trim():lower():find('e')() --> 2 2
The S function wraps the given object into a wrapper object. A chain of method calls on the wrapper object operates on the wrapped object in-place. Finally, the wrapper object is unwrapped with a function call ().
For functions that return a single value, an alternative way to unpack is to use the unary minus:
-S" test ":trim():repeatchars(2):upper() --> TTEESSTT
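Before the full implementation, the wrap / chain / unwrap idea can be sketched in a deliberately simplified form. The function name wrap is hypothetical, only a single value is tracked (no multiple return values), and a closure is allocated on every method lookup, which the implementation further down avoids.

```lua
-- Simplified sketch of a chaining wrapper (single values only).
-- "wrap" is a hypothetical name; "methods" is any table of functions.
local function wrap(value, methods)
  local w = {}
  return setmetatable(w, {
    -- w:name(...) first indexes w, then calls the result with w as self
    __index = function(_, name)
      local method = methods[name]
      return function(_, ...)
        value = method(value, ...)  -- apply to the wrapped value
        return w                    -- return the wrapper to continue the chain
      end
    end,
    __call = function() return value end,  -- unwrap with ()
    __unm  = function() return value end,  -- unwrap with unary minus
  })
end

print(wrap("AB", string):lower():reverse()())  -- ba
print(-wrap("AB", string):lower())             -- ab
```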
To define S in terms of a table of string functions stringx, we can use this code:
local stringx = {}
for k,v in pairs(string) do stringx[k] = v end
function stringx.trim(self)
  return self:match('^%s*(%S*)%s*$')
end
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    for i=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end
local S = buildchainwrapbuilder(stringx)
The buildchainwrapbuilder function is general and implements our design pattern:
-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).
-- version 20090430
local select = select
local setmetatable = setmetatable
local unpack = unpack
local rawget = rawget

-- http://lua-users.org/wiki/CodeGeneration
local function memoize(func)
  return setmetatable({}, {
    __index = function(self, k) local v = func(k); self[k] = v; return v end,
    __call = function(self, k) return self[k] end
  })
end

-- unique IDs (avoid name clashes with wrapped object)
local N = {}
local VALS = memoize(function() return {} end)
local VAL = VALS[1]
local PREV = {}

local function mypack(ow, ...)
  local n = select('#', ...)
  for i=1,n do ow[VALS[i]] = select(i, ...) end
  for i=n+1,ow[N] do ow[VALS[i]] = nil end
  ow[N] = n
end

local function myunpack(ow, i)
  i = i or 1
  if i <= ow[N] then
    return rawget(ow, VALS[i]), myunpack(ow, i+1)
  end
end

local function buildchainwrapbuilder(t)
  local mt = {}
  function mt:__index(k)
    local val = rawget(self, VAL)
    self[PREV] = val -- store in case of method call
    mypack(self, t[k])
    return self
  end
  function mt:__call(...)
    if (...) == self then -- method call
      local val = rawget(self, VAL)
      local prev = rawget(self, PREV)
      self[PREV] = nil
      mypack(self, val(prev, select(2,...)))
      return self
    else
      return myunpack(self, 1, self[N])
    end
  end
  function mt:__unm() return rawget(self, VAL) end

  local function build(o)
    return setmetatable({[VAL]=o,[N]=1}, mt)
  end
  return build
end

local function chainwrap(o, t)
  return buildchainwrapbuilder(t)(o)
end
Test suite:
-- simple examples
assert(-S"AA":lower() == "aa")
assert(-S"AB":lower():reverse() == "ba")
assert(-S" test ":trim():repeatchars(2):upper() == "TTEESSTT")
assert(S" test ":trim():repeatchars(2):upper()() == "TTEESSTT")

-- basics
assert(S""() == "")
assert(S"a"() == "a")
assert(-S"a" == "a")
assert(S(nil)() == nil)
assert(S"a":byte()() == 97)
local a,b,c = S"TEST":lower():find('e')()
assert(a==2 and b==2 and c==nil)
assert(-S"TEST":lower():find('e') == 2)

-- potentially tricky cases
assert(S"".__index() == nil)
assert(S"".__call() == nil)
assert(S""[1]() == nil)
stringx[1] = 'c'
assert(S"a"[1]() == 'c')
assert(S"a"[1]:upper()() == 'C')
stringx[1] = 'd'
assert(S"a"[1]() == 'd') -- uncached
assert(S"a".lower() == string.lower)

-- improve error messages?
--assert(S(nil):z() == nil)

print 'DONE'
The above implementation has these qualities and assumptions:
* The wrapper is built on the __call and __index operators, which also form the two halves of the method call. Operators like __len are not definable in 5.1 tables. True LuaVirtualization is not possible.
* It uses the __call operator () to unpack, which is the only operator allowing multiple return values. The code also supports unary minus as an alternative, which has the limitation of returning a single value (usual case) but perhaps has nicer syntax qualities (S and - together).
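The difference between the two unwrapping operators can be checked in isolation. The following sketch uses a throwaway metatable (not the wrapper above) whose __call and __unm both try to return two values; Lua adjusts the result of an arithmetic metamethod to one value, while a call keeps all of them.

```lua
-- Sketch: __call can return several values, __unm is truncated to one.
local mt = {
  __call = function() return 2, 2 end,
  __unm  = function() return 2, 2 end,
}
local w = setmetatable({}, mt)

local a, b = w()   -- call: both values survive
local c, d = -w    -- unary minus: only the first value survives

print(a, b)  -- 2  2
print(c, d)  -- 2  nil
```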
There are alternative ways we could have expressed the chaining:
S{" test ", "trim", {"repeatchars",2}, "upper"}
S(" test ", "trim | repeatchars(2) | upper")
but this looks less conventional. (Note: the second argument in the last line is point-free [4].)
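As an illustration of the first, table-based form, a minimal interpreter might look like the sketch below. The name apply_steps and the exact spec layout (value first, then method names or {name, args...} tables) are assumptions made for this example.

```lua
-- Sketch: run a list of steps against a value.
-- spec[1] is the starting value; each later entry is either a method name
-- or a table {name, arg1, arg2, ...}.
local unpack = table.unpack or unpack  -- Lua 5.1 and 5.2+ compatibility

local function apply_steps(methods, spec)
  local value = spec[1]
  for i = 2, #spec do
    local step = spec[i]
    if type(step) == "table" then
      value = methods[step[1]](value, unpack(step, 2))
    else
      value = methods[step](value)
    end
  end
  return value
end

print(apply_steps(string, {"Test", "lower", {"rep", 2}, "upper"}))  -- TESTTEST
```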
We could instead express the call chain like this:
chain(stringx):trim():repeatchars(5):upper()(' test ')
where the object operated on is placed at the very end. This reduces the chance of forgetting to unpack, and it allows separation and reuse:
f = chain(stringx):trim():repeatchars(5):upper()
print ( f(' test ') )
print ( f(' again ') )
There are various ways to implement this (functional, CodeGeneration, and VM). Here we take the last approach.
-- method call chaining, take #2
-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).
-- version 20090501

-- unique IDs to avoid name conflict
local OPS = {}
local INDEX = {}
local METHOD = {}

-- table insert, allowing trailing nils
local function myinsert(t, v)
  local n = t.n + 1; t.n = n
  t[n] = v
end

local function eval(ops, x)
  --print('DEBUG:', unpack(ops,1,ops.n))
  local t = ops.t
  local self = x
  local prev
  local n = ops.n
  local i=1; while i <= n do
    if ops[i] == INDEX then
      local k = ops[i+1]
      prev = x -- save in case of method call
      x = t[k]
      i = i + 2
    elseif ops[i] == METHOD then
      local narg = ops[i+1]
      x = x(prev, unpack(ops, i+2, i+1+narg))
      i = i + 2 + narg
    else
      assert(false)
    end
  end
  return x
end

local mt = {}
function mt:__index(k)
  local ops = self[OPS]
  myinsert(ops, INDEX)
  myinsert(ops, k)
  return self
end

function mt:__call(x, ...)
  local ops = self[OPS]
  if x == self then -- method call
    myinsert(ops, METHOD)
    local n = select('#', ...)
    myinsert(ops, n)
    for i=1,n do
      myinsert(ops, (select(i, ...)))
    end
    return self
  else
    return eval(ops, x)
  end
end

local function chain(t)
  return setmetatable({[OPS]={n=0,t=t}}, mt)
end
Rudimentary test code:
local stringx = {}
for k,v in pairs(string) do stringx[k] = v end
function stringx.trim(self)
  return self:match('^%s*(%S*)%s*$')
end
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    for i=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end

local C = chain
assert(C(stringx):trim():repeatchars(2):upper()(" test ") == 'TTEESSTT')
local f = C(stringx):trim():repeatchars(2):upper()
assert(f" test " == 'TTEESSTT')
assert(f" again " == 'AAGGAAIINN')
print 'DONE'
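For comparison, the functional approach mentioned earlier could be sketched by composing closures instead of recording an op list. This is an illustrative alternative, not the author's code; the name fchain is hypothetical. Each step returns a new chain object, so a shared prefix can be extended in different directions.

```lua
-- Sketch: functional variant of "object at the end" chaining.
local unpack = table.unpack or unpack  -- Lua 5.1 and 5.2+ compatibility

local function fchain(methods, f)
  f = f or function(x) return x end    -- start with the identity function
  return setmetatable({}, {
    __index = function(_, name)
      -- the returned function receives the chain object as self (ignored)
      return function(_, ...)
        local args, n = {...}, select('#', ...)
        return fchain(methods, function(x)
          return methods[name](f(x), unpack(args, 1, n))
        end)
      end
    end,
    __call = function(_, x) return f(x) end,  -- apply the composed function
  })
end

local shout = fchain(string):lower():rep(2):upper()
print(shout("Ab"))  -- ABAB

-- a prefix is not mutated by later steps
local prefix = fchain(string):lower()
local doubled = prefix:rep(2)
local upper = prefix:upper()
print(doubled("X"), upper("X"))  -- xx  X
```

The trade-off is one closure and one small table per step when the chain is built, in exchange for immutable, reusable chain objects.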
An alternate idea is to modify the string metatable so that the extensions to the string methods are only visible within a lexical scope. The following is not perfect (e.g. nested functions), but it is a start. Example:
-- test example libraries
local stringx = {}
function stringx.trim(self)
  return self:match('^%s*(%S*)%s*$')
end
local stringxx = {}
function stringxx.trim(self)
  return self:match('^%s?(.-)%s?$')
end

-- test example
function test2(s)
  assert(s.trim == nil)
  scoped_string_methods(stringxx)
  assert(s:trim() == ' 123 ')
end
function test(s)
  scoped_string_methods(stringx)
  assert(s:trim() == '123')
  test2(s)
  assert(s:trim() == '123')
end
local s = '  123  '
assert(s.trim == nil)
test(s)
assert(s.trim == nil)
print 'DONE'
The function scoped_string_methods assigns the given function table to the scope of the currently executing function. All string indexing within the scope goes through that given table.
The above uses this framework code:
-- framework
local mt = debug.getmetatable('')
local scope = {}
function mt.__index(s, k)
  local f = debug.getinfo(2, 'f').func
  return scope[f] and scope[f][k] or string[k]
end
local function scoped_string_methods(t)
  local f = debug.getinfo(2, 'f').func
  scope[f] = t
end
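The nested-function caveat follows from keying the scope table on the calling function value. The sketch below isolates that mechanism without touching the string metatable; register and lookup are hypothetical stand-ins for scoped_string_methods and the __index handler.

```lua
-- Sketch: scope is keyed by the function found at stack level 2 (the caller).
local scope = {}

local function register(t)
  scope[debug.getinfo(2, 'f').func] = t
end

local function lookup()
  return scope[debug.getinfo(2, 'f').func]
end

local function outer()
  register("outer-table")
  local here = lookup()          -- caller is outer: registered
  local function inner()
    local v = lookup()           -- caller is inner: a different function value
    return v                     -- (not a tail call, so the frame is kept)
  end
  local nested = inner()
  return here, nested
end

local here, nested = outer()
print(here, nested)  -- outer-table  nil
```

The inner closure is lexically inside outer but is a distinct function object, so it does not inherit the registration.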
We can do something similar to the above more robustly via MetaLua. An example is below.
-{extension "lexicalindex"}

-- test example libraries
local stringx = {}
function stringx.trim(self)
  return self:match('^%s*(%S*)%s*$')
end

local function f(o,k)
  if type(o) == 'string' then
    return stringx[k] or string[k]
  end
  return o[k]
end

local function test(s)
  assert(s.trim == nil)
  lexicalindex f
  assert(s.trim ~= nil)
  assert(s:trim():upper() == 'TEST')
end
local s = ' test '
assert(s.trim == nil)
test(s)
assert(s.trim == nil)

print 'DONE'
The syntax extension introduces a new keyword lexicalindex that specifies a function to be called whenever a value is to be indexed inside the current scope.
Here is what the corresponding plain Lua source looks like:
--- $ ./build/bin/metalua -S vs.lua
--- Source From "@vs.lua": ---
local function __li_invoke (__li_index, o, name, ...)
  return __li_index (o, name) (o, ...)
end

local stringx = { }

function stringx:trim ()
  return self:match "^%s*(%S*)%s*$"
end

local function f (o, k)
  if type (o) == "string" then
    return stringx[k] or string[k]
  end
  return o[k]
end

local function test (s)
  assert (s.trim == nil)
  local __li_index = f
  assert (__li_index (s, "trim") ~= nil)
  assert (__li_invoke (__li_index, __li_invoke (__li_index, s, "trim"), "upper" ) == "TEST")
end

local s = " test "

assert (s.trim == nil)

test (s)

assert (s.trim == nil)

print "DONE"
The lexicalindex Metalua extension is implemented as
-- lexical index in scope iff depth > 0
local depth = 0

-- transform indexing expressions
mlp.expr.transformers:add(function(ast)
  if depth > 0 then
    if ast.tag == 'Index' then
      return +{__li_index(-{ast[1]}, -{ast[2]})}
    elseif ast.tag == 'Invoke' then
      return `Call{`Id'__li_invoke', `Id'__li_index', unpack(ast)}
    end
  end
end)

-- monitor scoping depth
mlp.block.transformers:add(function(ast)
  for _,ast2 in ipairs(ast) do
    if ast2.is_lexicalindex then
      depth = depth - 1; break
    end
  end
end)

-- handle new "lexicalindex" statement
mlp.lexer:add'lexicalindex'
mlp.stat:add{'lexicalindex', mlp.expr, builder=function(x)
  local e = unpack(x)
  local ast_out = +{stat: local __li_index = -{e}}
  ast_out.is_lexicalindex = true
  depth = depth + 1
  return ast_out
end}

-- utility function
-- (note: o must be indexed exactly once to preserve behavior)
return +{block:
  local function __li_invoke(__li_index, o, name, ...)
    return __li_index(o, name)(o, ...)
  end
}