Symbolic Differentiation
The first step in symbolic algebra is defining a representation. Getting expressions into a suitable form is actually straightforward; no parsing of expressions is needed, since we have Lua to do that for us. The pl.func library does all the hard work: it redefines the arithmetic operations to work on placeholder expressions (PEs), which are Lua expressions involving dummy variables called placeholders. pl.func defines standard placeholders for arguments called _1, _2, etc., but the Var function will create new ones of our choosing:
    local utils = require 'pl.utils'
    utils.import 'pl.func'
    a,b,c,d = Var 'a,b,c,d'
    print(a+b+c+d)
Which will indeed print out the expression in a readable form. PE operator expressions are stored as combinations of tables looking like {op='+',x,y}, which have an associated metatable which defines the metamethods like __add and so forth. As a tree, with the usual left-to-right associativity of Lua operators, we get:
              +
             / \
            +   d
           / \
          +   c
         / \
        a   b
It is irritating to draw these diagrams, so a better notation is Lisp-style S-expressions:
1: (+ (+ (+ a b) c) d)
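To make the table representation concrete, here is a minimal sketch in plain Lua that mimics the idea without using pl.func. The names Node, var and show are hypothetical and chosen for illustration only; the real library's internals may well differ.

```lua
-- A toy stand-in for placeholder expressions; not the pl.func implementation.
local Node = {}   -- shared metatable (hypothetical name)

-- Build an operator node shaped like {op='+', x, y}.
local function binop (op)
    return function (x, y)
        return setmetatable({op = op, x, y}, Node)
    end
end

Node.__add = binop '+'
Node.__mul = binop '*'

-- Create a variable node; 'X' marks a placeholder, repr is its name.
function var (name)
    return setmetatable({op = 'X', repr = name}, Node)
end

-- Render a node as an S-expression.
function show (e)
    if getmetatable(e) ~= Node then return tostring(e) end
    if e.op == 'X' then return e.repr end
    local parts = {}
    for i = 1, #e do parts[i] = show(e[i]) end
    return '(' .. e.op .. ' ' .. table.concat(parts, ' ') .. ')'
end

local a, b, c, d = var 'a', var 'b', var 'c', var 'd'
print(show(a + b + c + d))   --> (+ (+ (+ a b) c) d)
```

Because Lua evaluates a+b+c+d as ((a+b)+c)+d, the nesting falls out of the metamethod calls with no parsing step.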
However, with the various manipulations we will perform, this left-nested form is not the only possible representation of a+b+c+d:
2: (+ a (+ b (+ c d)))
3: (+ (+ a b) (+ c d))
Now, experience shows that this leads to madness. Instead, it's easier to go for the canonical Lisp representation:
4: (+ a b c d)
Many operations become straightforward once this is in place; for instance, comparing with (+ a c b d) is just a matter of doing a 'compare with no order' on the arguments. Displaying PEs in this form is straightforward. isPE simply checks the expression to see if it is a placeholder expression, by looking at the metatable. PEs with op=='X' are placeholder variables, so the rest must be expression nodes.
    function sexpr (e)
        if isPE(e) then
            if e.op ~= 'X' then
                local args = tablex.imap(sexpr,e)
                return '('..e.op..' '..table.concat(args,' ')..')'
            else
                return e.repr
            end
        else
            return tostring(e)
        end
    end
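The 'compare with no order' idea can be sketched in plain Lua. The helper same_args below is hypothetical; it works on lists of strings (for example, the S-expression of each argument) rather than on real PEs, and simply compares occurrence counts.

```lua
-- Count how often each string occurs in a list.
local function counts (list)
    local n = {}
    for _, v in ipairs(list) do n[v] = (n[v] or 0) + 1 end
    return n
end

-- True if the two lists hold the same items, ignoring order.
function same_args (xs, ys)
    if #xs ~= #ys then return false end
    local cx, cy = counts(xs), counts(ys)
    for k, n in pairs(cx) do
        if cy[k] ~= n then return false end
    end
    return true
end

print(same_args({'a','b','c','d'}, {'a','c','b','d'}))  --> true
print(same_args({'a','b','b'}, {'a','a','b'}))          --> false
```

Counting rather than sorting means repeated arguments such as (+ x x y) are handled without needing an ordering on expressions.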
The first task is to balance the expressions, which converts representations 1-3 into 4.
    function balance (e)
        if isPE(e) and e.op ~= 'X' then
            local op,args = e.op
            if op == '+' or op == '*' then
                args = rcollect(e)
            else
                args = imap(balance,e)
            end
            for i = 1,#args do
                e[i] = args[i]
            end
        end
        return e
    end
For the non-commutative operators, the idea is just to balance all the subexpressions by mapping balance over the array part of the PE, which is the argument list. These are then copied back in-place. The non-trivial part is dealing with + and *, where it is necessary to collect all the arguments from expression trees looking like 1, 2 or 3 and convert them into the fourth form.
    function tcollect (op,e,ls)
        if isPE(e) and e.op == op then
            for i = 1,#e do
                tcollect(op,e[i],ls)
            end
        else
            ls:append(e)
            return
        end
    end
    function rcollect (e)
        local res = List()
        tcollect(e.op,e,res)
        return res
    end
This recursively goes down same-operator chains (the (+ (+ ...) mentioned earlier) and collects the arguments, flattening them into n-ary + or * expressions.
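The flattening step can be tried out without Penlight. The following is a sketch under simplifying assumptions: expressions are plain tables with strings as leaves, the names collect, flatten and show_flat are hypothetical, and a new node is returned instead of modifying the expression in place.

```lua
-- Expressions here are plain tables {op=..., arg1, arg2, ...}; leaves are strings.
-- Append to out every argument reached through a chain of the same operator.
local function collect (op, e, out)
    if type(e) == 'table' and e.op == op then
        for i = 1, #e do collect(op, e[i], out) end
    else
        out[#out + 1] = e
    end
end

-- Return a new n-ary node with nested same-operator children flattened.
function flatten (e)
    local out = {op = e.op}
    collect(e.op, e, out)
    return out
end

-- Render as an S-expression for checking.
function show_flat (e)
    if type(e) ~= 'table' then return tostring(e) end
    local parts = {}
    for i = 1, #e do parts[i] = show_flat(e[i]) end
    return '(' .. e.op .. ' ' .. table.concat(parts, ' ') .. ')'
end

local function add (x, y) return {op = '+', x, y} end

print(show_flat(flatten(add(add(add('a','b'),'c'),'d'))))  --> (+ a b c d)
print(show_flat(flatten(add('a',add('b',add('c','d'))))))  --> (+ a b c d)
print(show_flat(flatten(add(add('a','b'),add('c','d')))))  --> (+ a b c d)
```

All three binary shapes end up as the same n-ary node. Children with a different operator are appended untouched, just as in tcollect.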
Here is a useful function, which follows the same recursive pattern:
    -- does this PE contain a reference to x?
    function references (e,x)
        if isPE(e) then
            if e.op == 'X' then
                return x.repr == e.repr
            else
                return find_if(e,references,x)
            end
        else
            return false
        end
    end
Here are functions to create n-ary products and sums:
    function muli (args)
        return PE{op='*',unpack(args)}
    end
    function addi (args)
        return PE{op='+',unpack(args)}
    end
With this in place, the basic differentiation rules are not difficult. Firstly, only consider subexpressions which do contain the variable:
    function diff (e,x)
        if isPE(e) and references(e,x) then
            local op = e.op
            if op == 'X' then
                return 1
            else
                local a,b = e[1],e[2]
                if op == '+' then -- differentiation is linear
                    local args = imap(diff,e,x)
                    return balance(addi(args))
                elseif op == '*' then -- product rule
                    local res,d,ee = {}
                    for i = 1,#e do
                        d = fold(diff(e[i],x))
                        if d ~= 0 then
                            ee = {unpack(e)} -- make a copy
                            ee[i] = d
                            append(res,balance(muli(ee)))
                        end
                    end
                    if #res > 1 then
                        return addi(res)
                    else
                        return res[1]
                    end
                elseif op == '^' and isnumber(b) then -- power rule
                    -- assumes the base a is x itself; no chain rule is applied
                    return b*x^(b-1)
                end
            end
        else
            return 0
        end
    end
The derivative of a sum of expressions is the sum of the derivatives. Again, imap does the job of applying the function recursively over the subexpressions. After constructing the result, we re-balance for luck.
The product rule is given here in its general form, with an explicit check for terms which work out to zero - that is the job of fold, which will be discussed next.
(uvw...)' = u'vw... + uv'w... + uvw'... + ...
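As a worked instance of the general rule, take u = v = w = x, which is the x*x*x case that appears among the test cases for this code:

```latex
\begin{aligned}
(x \cdot x \cdot x)' &= 1 \cdot x \cdot x + x \cdot 1 \cdot x + x \cdot x \cdot 1 \\
                     &= 3x^2
\end{aligned}
```

Each term replaces exactly one factor by its derivative, which is what the loop over e[i] in diff does.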
And finally, the simple power rule. Note how the result can be expressed in a straightforward fashion, since all these operators are acting on PEs.
In fact, all these rules are certainly clearer if you use form 1, binary + and *! But then simplification becomes unbearable. And simplification ('folding') is the tricky one to get right. fold is a longish function, so I will deal with it in sections:
    local op = e.op
    local addmul = op == '*' or op == '+'
    -- first fold all arguments
    local args = imap(fold,e)
    if not addmul and not find_if(args,isPE) then
        -- no placeholders in these args, we can fold the expression.
        local opfn = optable[op]
        if opfn then
            return opfn(unpack(args))
        else
            return '?'
        end
    elseif addmul then
The first if is looking for a case where a subexpression has no symbols, i.e. it is something like 2*5 or 10^2; in this case, the constant can be completely folded. optable (defined in pl.operator) gives a mapping between the operator names and the actual function implementing them.
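Constant folding through an operator table can be sketched as follows. This is an illustration only: ops is a hand-written table standing in for optable, fold_constant is a hypothetical name, and it returns nil for cases it cannot fold, where the article's code returns '?' for an unknown operator.

```lua
-- A hand-written operator table; pl.operator's optable plays this role in the article.
local ops = {
    ['+'] = function (a, b) return a + b end,
    ['-'] = function (a, b) return a - b end,
    ['*'] = function (a, b) return a * b end,
    ['/'] = function (a, b) return a / b end,
    ['^'] = function (a, b) return a ^ b end,
}

-- Fold a binary node whose arguments are both plain numbers;
-- return nil when it cannot be folded.
function fold_constant (op, x, y)
    local fn = ops[op]
    if fn and type(x) == 'number' and type(y) == 'number' then
        return fn(x, y)
    end
    return nil
end

print(fold_constant('*', 2, 5))    --> 10
print(fold_constant('^', 10, 2))   -- 100 (shown as 100.0 on Lua 5.3 and later)
print(fold_constant('*', 2, 'x'))  --> nil
```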
    elseif op == '^' then
        if args[2] == 1 then return args[1] end -- identity
        if args[2] == 0 then return 1 end
    end
    return PE{op=op,unpack(args)}
This clause is clearing up expressions like x^1 and y^0 which naturally arise from the power rule in diff. Once args has been processed, the expression can be put together again.
The bulk of this routine handles the awkward twins, + and *.
    -- split the args into two classes, PE args and non-PE args.
    local classes = List.partition(args,isPE)
    local pe,npe = classes[true],classes[false]
List.partition takes a list and a function, which takes one argument and returns a single value. The result is a table where the keys are the returned values, and the values are lists of those elements where the function returned that value. So:
    List{1,2,3,4}:partition(function(x) return x > 2 end)
    --> {false={1,2},true={3,4}}
    List{'one',math.sin,10,20,{1,2}}:partition(type)
    --> {function={function: 00369110},string={one},number={10,20},table={{1,2}}}
(Mathematically, these lists are referred to as equivalence classes, and the result of partition would be called the quotient set.)
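The grouping behaviour described here is easy to reproduce in plain Lua. The function partition_by below is a hypothetical stand-in, not the Penlight implementation, and it assumes the classifying function never returns nil (nil cannot be a table key).

```lua
-- Group the items of a list by the value that fn returns for each item.
function partition_by (list, fn)
    local classes = {}
    for _, item in ipairs(list) do
        local key = fn(item)
        local class = classes[key]
        if not class then
            class = {}
            classes[key] = class
        end
        class[#class + 1] = item
    end
    return classes
end

local parts = partition_by({1, 2, 3, 4}, function (x) return x > 2 end)
print(table.concat(parts[false], ','))  --> 1,2
print(table.concat(parts[true], ','))   --> 3,4
```

With a predicate such as isPE the keys are just true and false, which is how the symbolic and non-symbolic arguments get separated.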
In this case, we want to separate the non-symbolic arguments from the symbolic arguments; order does not matter. The non-symbolic arguments npe can be folded into a constant. At this point, operator identity rules can kick in, so that (* 0 x) collapses to 0 and (+ 0 x) simplifies to just x.
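Here is a rough sketch of that step, assuming a simplified model in which an argument list holds numbers and strings, with strings standing in for symbolic arguments. The name fold_args is hypothetical and the code is not taken from the article's fold.

```lua
-- Simplify the argument list of an n-ary '+' or '*'.
-- Numbers are folded into one constant; strings stand in for symbolic arguments.
-- Returns a number, a single symbol, or a new argument list.
function fold_args (op, args)
    local identity = (op == '+') and 0 or 1
    local const, symbols = identity, {}
    for _, a in ipairs(args) do
        if type(a) == 'number' then
            if op == '+' then const = const + a else const = const * a end
        else
            symbols[#symbols + 1] = a
        end
    end
    -- (* 0 x ...) is just 0, so the whole product disappears.
    if op == '*' and const == 0 then return 0 end
    -- No symbols left: the expression is a plain constant.
    if #symbols == 0 then return const end
    -- Keep the constant only when it is not the identity: (+ 0 x) -> x.
    if const ~= identity then table.insert(symbols, 1, const) end
    if #symbols == 1 then return symbols[1] end
    return symbols
end

print(fold_args('+', {0, 'x'}))                        --> x
print(fold_args('*', {0, 'x'}))                        --> 0
print(table.concat(fold_args('*', {3, 2, 'x'}), ' '))  --> 6 x
```

Placing the folded constant first is an arbitrary choice in this sketch; since + and * are commutative, any position would do.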
The final simplification is replacing repeated values, so that (* x x) should become (^ x 2) and (+ x x x) should become (* x 3). count_map from pl.tablex will do the job. It is given a list-like table and a function which defines equivalence, and returns a map from the values to the number of their occurrences, so that count_map{'a','b','a'} is {a=2,b=1}.
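The counting idea, and how it leads to powers and multiples, can be sketched in plain Lua. Both count_items and collapse_repeats are hypothetical names; unlike count_map, this version takes no equivalence function and simply uses the items themselves as table keys, so it only suits strings and numbers.

```lua
-- Count occurrences of each item in a list (items are used directly as keys).
function count_items (list)
    local counts, order = {}, {}
    for _, item in ipairs(list) do
        if counts[item] == nil then
            counts[item] = 0
            order[#order + 1] = item   -- remember first-seen order
        end
        counts[item] = counts[item] + 1
    end
    return counts, order
end

-- Rewrite repeated symbols: under '*' as powers, under '+' as multiples.
-- Symbols are strings; the result is a list of S-expression strings.
function collapse_repeats (op, symbols)
    local counts, order = count_items(symbols)
    local out = {}
    for _, s in ipairs(order) do
        local n = counts[s]
        if n == 1 then
            out[#out + 1] = s
        elseif op == '*' then
            out[#out + 1] = '(^ ' .. s .. ' ' .. n .. ')'
        else
            out[#out + 1] = '(* ' .. s .. ' ' .. n .. ')'
        end
    end
    return out
end

print(table.concat(collapse_repeats('*', {'x', 'x'}), ' '))       --> (^ x 2)
print(table.concat(collapse_repeats('+', {'x', 'x', 'x'}), ' '))  --> (* x 3)
```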
Given this test function:
    function testdiff (e)
        balance(e)
        e = diff(e,x)
        balance(e)
        print('+ ',e)
        e = fold(e)
        print('- ',e)
    end
and these cases:
    testdiff(x^2+1)
    testdiff(3*x^2)
    testdiff(x^2 + 2*x^3)
    testdiff(x^2 + 2*a*x^3 + x^4)
    testdiff(2*a*x^3)
    testdiff(x*x*x)
we get these results, showing why something like fold is so necessary to process the result of diff.
    + 2 * x ^ 1 + 0
    - 2 * x
    + 3 * 2 * x
    - 6 * x
    + 2 * x ^ 1 + 2 * 3 * x ^ 2
    - 2 * x + 6 * x ^ 2
    + 2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3
    - 6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x
    + 2 * a * 3 * x ^ 2
    - 6 * a * x ^ 2
    + 1 * x * x + x * 1 * x + x * x * 1
    - x ^ 2 * 3
https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua
https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua