Curried Memoization
For example, using weak_memoize_m_to_n from the module below:
    local mtest = weak_memoize_m_to_n(function(...) print 'exec' return ... end)
    print(mtest(nil, 2))  --> "exec", "nil,2"
    print(mtest(nil, 2))  --> "nil,2"
    collectgarbage()
    print(mtest(nil, 2))  --> "exec", "nil,2"
    print(mtest())        --> "exec", ""
    print(mtest(nil))     --> "exec", "nil"
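The design is equivalent to an argument-tree technique like the one in [memoize.lua]. However, the tree is built implicitly rather than explicitly. A call with M > 1 arguments is recursively reduced to the M-1 case by fixing its first argument, so each node of the tree is a memoized closure rather than a table.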
    local _ENV = setmetatable({}, {__index = _G})

    -- This code can be made more memory- and speed-efficient by
    -- defining catch() in C, but a table.unpack approach also works.
    function catch(...)
      local rvals = {...}
      local n = select('#', ...)
      return function() return table.unpack(rvals, 1, n) end
    end

    local weak_mt = {__mode = 'kv'}
    local function weak_table() return setmetatable({}, weak_mt) end
    local function strong_table() return {} end

    local null = {}
    local function arg2key(arg) return (arg == nil and null) or arg end

    -- Build a memoization function that can handle the 1-argument
    -- to n-return-value case.
    local function new_memoizer_1_to_n(newtable)
      return function(fun)
        local lookup = newtable()
        return function(arg)
          local k = arg2key(arg)
          local r = lookup[k]
          if r then return r() end
          r = catch(fun(arg))
          lookup[k] = r
          return r()
        end
      end
    end

    local function new_memoizer_m_to_n(newtable, memoize_1_to_n)
      -- Return a memoization of f that assumes m arguments.
      local function memoize_m_to_n(m, f)
        -- Handle the m == 0 case.
        if m == 0 then
          local memoized
          return function()
            if memoized then return memoized() end
            memoized = catch(f())
            return memoized()
          end
        end
        if m == 1 then return memoize_1_to_n(f) end
        local lookup = newtable()
        -- Handle the general m-to-n case, for m >= 2.
        return function(arg, ...)
          local k = arg2key(arg)
          local r = lookup[k]
          if r then return r(...) end
          -- Create a new (m-1)-argument memoizer that will handle
          -- this arg value in the future.
          r = memoize_m_to_n(m - 1, function(...) return f(arg, ...) end)
          lookup[k] = r
          return r(...)
        end
      end
      -- Return a memoizer that dispatches between the different m-argument cases.
      return function(f)
        local m_to_memoized = newtable()
        return function(...)
          local m = select('#', ...)
          local memoized = m_to_memoized[m]
          if memoized then return memoized(...) end
          memoized = memoize_m_to_n(m, f)
          m_to_memoized[m] = memoized
          return memoized(...)
        end
      end
    end

    weak_memoize_1_to_n = new_memoizer_1_to_n(weak_table)
    strong_memoize_1_to_n = new_memoizer_1_to_n(strong_table)
    weak_memoize_m_to_n = new_memoizer_m_to_n(weak_table, weak_memoize_1_to_n)
    strong_memoize_m_to_n = new_memoizer_m_to_n(strong_table, strong_memoize_1_to_n)

    return _ENV
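For comparison with the implicit tree built above, an explicit argument tree might look like the following. This is a hypothetical sketch only, not the code from [memoize.lua]. The names memoize_tree, NIL and RESULT are invented for illustration, and the sketch uses strong references only. It descends one level of nested tables per argument, keeps a separate root for each argument count, and caches the packed results at the leaf. It assumes Lua 5.2 or later for table.pack/table.unpack.

```lua
-- Hypothetical explicit argument-tree memoizer (strong references only).
local NIL = {}     -- key standing in for nil arguments
local RESULT = {}  -- key under which a leaf node caches packed results

local function memoize_tree(f)
  local roots = {}  -- roots[n]: tree for calls with exactly n arguments
  return function(...)
    local n = select('#', ...)
    local node = roots[n]
    if not node then node = {}; roots[n] = node end
    -- Descend one level per argument, creating nodes as needed.
    for i = 1, n do
      local key = select(i, ...)
      if key == nil then key = NIL end
      local child = node[key]
      if not child then child = {}; node[key] = child end
      node = child
    end
    local r = node[RESULT]
    if not r then
      r = table.pack(f(...))  -- table.pack records the value count in r.n
      node[RESULT] = r
    end
    return table.unpack(r, 1, r.n)
  end
end

local calls = 0
local add = memoize_tree(function(a, b)
  calls = calls + 1
  return (a or 0) + (b or 0)
end)
print(add(1, 2))  --> 3
print(add(1, 2))  --> 3   (served from the tree; f is not called)
print(calls)      --> 1
```

The curried version never materializes this tree as nested tables. Each `lookup` table maps one argument to a closure that memoizes the remaining arguments, which is what lets the weak variant drop whole subtrees at garbage collection time.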
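Memory use and performance can both be improved by implementing catch() with the C API. A C closure can hold at most 255 upvalues, so the version below, which spends one upvalue on the value count, can catch at most 254 values. Within that limit, however, C closures are lighter and faster data structures than Lua's generic tables.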
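    #include "lua.h"
    #include "lauxlib.h"

    /* MAXUPVAL is internal to the Lua core and not exported by lua.h;
       fall back to the 255-upvalue limit documented for lua_pushcclosure. */
    #ifndef MAXUPVAL
    #define MAXUPVAL 255
    #endif

    /* Upvalue 1 holds n1 = (number of caught values) + 1; upvalues 2..n1
       hold the caught values, which are pushed and returned. */
    static int throw_upvalues(lua_State *L) {
      int n1 = (int)lua_tointeger(L, lua_upvalueindex(1));
      luaL_checkstack(L, n1 - 1, "too many upvalues");
      for (int i = 2; i <= n1; i++) {
        lua_pushvalue(L, lua_upvalueindex(i));
      }
      return n1 - 1;
    }

    /* catch(...): store the count and the arguments as upvalues of a
       new C closure that returns the arguments when called. */
    static int catch_args(lua_State *L) {
      int n1 = lua_gettop(L) + 1;
      if (n1 > MAXUPVAL) {
        return luaL_error(L,
          "can't catch more than %d args. (catch() called with %d arguments).",
          MAXUPVAL - 1, n1 - 1);
      }
      lua_pushinteger(L, n1);   /* the count becomes upvalue 1 */
      lua_insert(L, 1);
      lua_pushcclosure(L, throw_upvalues, n1);
      return 1;
    }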
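Both versions of catch() record the number of caught values explicitly: select('#', ...) in Lua and n1 in C. They do not rely on the length of the captured list. This matters because a table built with {...} does not remember trailing nils, and the memoizer must replay results such as `nil, 2` or `1, nil, nil` with their exact arity. The sketch below contrasts the pure-Lua catch() with a hypothetical naive_catch(), a name invented here, that calls table.unpack(rvals) without an explicit count. It assumes Lua 5.2 or later.

```lua
-- Mirrors the pure-Lua catch() from the module: the count is stored
-- separately, so trailing nils are preserved on replay.
local function catch(...)
  local rvals = {...}
  local n = select('#', ...)  -- exact count, including trailing nils
  return function() return table.unpack(rvals, 1, n) end
end

-- Hypothetical naive variant: table.unpack(rvals) stops at a border
-- of the table (here #rvals == 1), so trailing nils are dropped.
local function naive_catch(...)
  local rvals = {...}
  return function() return table.unpack(rvals) end
end

print(select('#', catch(1, nil, nil)()))        --> 3
print(select('#', naive_catch(1, nil, nil)()))  --> 1
```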
Many other Lua memoization implementations are scattered around the web; the FuncTables page appears to be the de facto link hub on this wiki. The topic also comes up frequently on the Lua users mailing list, and a couple of nice string-serialization-based implementations have been posted in the archives [1], [2].