Simple Fit

lua-users home
wiki

Here is a way to fit curves — e.g. straight lines, parabolas, and "exponential" (power-law, y = a * x^b) functions — using LuaMatrix.

Download the whole package from:

http://luaforge.net/projects/luamatrix/

The code and the method are very simple.

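First, put the x values into a design matrix A (one row per data point, e.g. { 1, x } for a straight line) and the y values into a column vector Y. Then form the normal equations $A^{T} A \, p = A^{T} Y$, concatenate $A^{T} A$ with $A^{T} Y$ into one augmented matrix, and solve it with the Gauss-Jordan method to get the fitted parameters p.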

The only extra step for exponential-type functions is to make them linear. That is easy: take the natural logarithm of both sides, for example:

$$y = a\,x^{b} \quad\xrightarrow{\;\ln\;}\quad \ln y = \ln a + b \ln x$$

Then fit.linear() can be applied to the transformed data (ln x, ln y). It returns ln(a) and b, so a is recovered as exp(ln(a)).

--///////////////////--
--// Curve Fitting //--
--///////////////////--

-- v 0.2

-- Lua 5.1 compatible

-- little add-on to the matrix module, to show some curve fitting

-- http://luaforge.net/projects/LuaMatrix
-- http://lua-users.org/wiki/SimpleFit

-- Licensed under the same terms as Lua itself.

-- requires matrix module
local matrix = require "matrix"

-- The Fit Table: module table that collects the public fitting functions.
local fit = {}

-- Note: all of these algorithms solve their equation systems with the
-- Gauss-Jordan method (the matrix module's dogauss).

-- `unpack` is a global in Lua 5.1 but lives in `table.unpack` from Lua 5.2 on;
-- pick whichever exists so the module also runs on newer interpreters.
local unpack_values = table.unpack or unpack

-- getresults ( mtx )
-- Solve an augmented linear system and return its solution.
-- mtx : matrix object with n rows and n+1 columns, where the last column is
--       the right-hand side; it must provide a dogauss() method, which is
--       called here and reduces mtx in place.
-- returns the n values of the last column after elimination, as multiple
-- return values (one per fitted parameter).
local function getresults( mtx )
   local rows = #mtx
   -- check rows > 0 first so an empty system reports this message instead
   -- of failing on #mtx[1] with "attempt to get length of a nil value"
   assert( rows > 0 and rows + 1 == #mtx[1], "Cannot calculate Results" )
   -- NOTE(review): the return value of dogauss() is ignored; if the matrix
   -- module reports a singular system through it, consider checking it here.
   mtx:dogauss()
   -- the solution is the last (right-hand side) column
   local cols = #mtx[1]
   local tres = {}
   for i = 1, rows do
      tres[i] = mtx[i][cols]
   end
   return unpack_values( tres, 1, rows )
end

-- fit.linear ( x_values, y_values )
-- Least-squares fit of a straight line.
-- model (  y = a + b * x  )
-- x_values = { x1,x2,x3,...,xn }
-- y_values = { y1,y2,y3,...,yn }   (same length as x_values, n >= 2)
-- Builds A with rows { 1, x_i } and solves the normal equations
-- (A^T A) p = A^T Y with Gauss-Jordan elimination. The input tables are
-- not modified.
-- returns a, b
function fit.linear( x_values,y_values )
   local n = #x_values
   -- a missing y value would otherwise become an empty row in Y, and extra
   -- y values would be silently ignored
   assert( n == #y_values,
      "fit.linear: x_values and y_values must have the same length" )
   -- with fewer points than parameters the normal equations are singular
   assert( n >= 2, "fit.linear: at least 2 points are required" )

   -- rows of the design matrix A and of the column vector Y
   local a_vals = {}
   local y_vals = {}
   for i = 1, n do
      a_vals[i] = { 1, x_values[i] }
      y_vals[i] = { y_values[i] }
   end

   local A = matrix:new( a_vals )
   local Y = matrix:new( y_vals )

   -- A^T is needed for both products; compute it once
   local AT = matrix.transpose( A )
   local ATA = matrix.mul( AT, A )
   local ATY = matrix.mul( AT, Y )

   -- augmented system [ A^T A | A^T Y ]
   return getresults( matrix.concath( ATA, ATY ) )
end

-- fit.parabola ( x_values, y_values )
-- Least-squares fit of a parabola.
-- model (  y = a + b * x + c * x² )
-- x_values = { x1,x2,x3,...,xn }
-- y_values = { y1,y2,y3,...,yn }   (same length as x_values, n >= 3)
-- Builds A with rows { 1, x_i, x_i^2 } and solves the normal equations
-- (A^T A) p = A^T Y with Gauss-Jordan elimination. The input tables are
-- not modified.
-- returns a, b, c
function fit.parabola( x_values,y_values )
   local n = #x_values
   -- a missing y value would otherwise become an empty row in Y, and extra
   -- y values would be silently ignored
   assert( n == #y_values,
      "fit.parabola: x_values and y_values must have the same length" )
   -- with fewer points than parameters the normal equations are singular
   assert( n >= 3, "fit.parabola: at least 3 points are required" )

   -- rows of the design matrix A and of the column vector Y
   local a_vals = {}
   local y_vals = {}
   for i = 1, n do
      local x = x_values[i]
      a_vals[i] = { 1, x, x*x }
      y_vals[i] = { y_values[i] }
   end

   local A = matrix:new( a_vals )
   local Y = matrix:new( y_vals )

   -- A^T is needed for both products; compute it once
   local AT = matrix.transpose( A )
   local ATA = matrix.mul( AT, A )
   local ATY = matrix.mul( AT, Y )

   -- augmented system [ A^T A | A^T Y ]
   return getresults( matrix.concath( ATA, ATY ) )
end

-- fit.exponential ( x_values, y_values )
-- Fit a power-law curve (historically called "exponential" here).
-- model (  y = a * x^b )
-- The model is linearised as ln(y) = ln(a) + b * ln(x) and solved with
-- fit.linear, so every x and y value must be strictly positive and both
-- tables must have the same length.
-- The caller's tables are left untouched; the logarithms go into new tables.
-- returns a, b
function fit.exponential( x_values,y_values )
   local n = #x_values
   assert( n == #y_values,
      "fit.exponential: x_values and y_values must have the same length" )

   local log_x, log_y = {}, {}
   for i = 1, n do
      local x, y = x_values[i], y_values[i]
      -- ln of a non-positive value is nan or -inf and would poison the fit
      assert( x > 0 and y > 0,
         "fit.exponential: x and y values must be positive" )
      log_x[i] = math.log( x )
      log_y[i] = math.log( y )
   end

   -- the linear fit yields the intercept ln(a) and the slope b
   local ln_a, b = fit.linear( log_x, log_y )

   return math.exp( ln_a ), b
end

-- Export the module table to whoever loads this file (require / dofile).
return fit

--///////////////--
--// chillcode //--
--///////////////--

Testcode:

-- Demo script for fit.lua: fits three small data sets and prints each
-- resulting model equation.

-- require fit
-- local fit = require "fit"
local fit = dofile( "fit.lua" )

-- Straight line:
-- x(i) = 2  | 3  | 4  | 5
-- y(i) = 5  | 9  | 15 | 21
-- model = y = a +  b * x
-- r(i) = y(i) - ( a + b * x(i) )
do
   print( "Fit a straight line " )
   local xs = { 2, 3, 4, 5 }
   local ys = { 5, 9, 15, 21 }
   local a, b = fit.linear( xs, ys )
   print( "=>    y = ( "..a.." )  +  ( "..b.." ) * x" )
end

-- Parabola: y = a + b * x + c * x^2
do
   print( "Fit a parabola " )
   local xs = { 0, 1, 2, 4, 6 }
   local ys = { 3, 1, 0, 1, 4 }
   local a, b, c = fit.parabola( xs, ys )
   print( "=>    y = ( "..a.." )  +  ( "..b.." ) * x  +  ( "..c.." ) * x²" )
end

-- Power law: y = a * x^b
do
   print( "Fit exponential" )
   local xs = { 1, 2, 3, 4, 5 }
   local ys = { 1, 3.1, 5.6, 9.1, 12.9 }
   local a, b = fit.exponential( xs, ys )
   print( "=>    y = ( "..a.." )  *  x^( "..b.." )" )
end


RecentChanges · preferences
edit · history
Last edited August 26, 2007 4:40 pm GMT (diff)