/*

Copyright (c) 2014 Randy Gaul http://RandyGaul.net

This software is provided 'as-is', without any express or implied

warranty. In no event will the authors be held liable for any damages

arising from the use of this software.

Permission is granted to anyone to use this software for any purpose,

including commercial applications, and to alter it and redistribute it

freely, subject to the following restrictions:

1. The origin of this software must not be misrepresented; you must not

claim that you wrote the original software. If you use this software

in a product, an acknowledgment in the product documentation would be

appreciated but is not required.

2. Altered source versions must be plainly marked as such, and must not be

misrepresented as being the original software.

3. This notice may not be removed or altered from any source distribution.

Creator(s) : Randy Gaul

Creation Date : Sun Feb 09 15:15:22 2014

File Name : GaussianElimination.cpp

Purpose :

*/

#include <cstdio> // printf

#include <cstring> // memcpy

#include <malloc.h> // alloca

// Scalar type aliases used throughout this file.
typedef int   int32;
typedef float f32;

// Small tolerance used by Solve's singularity check.
static const f32 kEpsilon = 1e-6f;

// Three-component float vector.
// The union aliases the named components x, y, z with the array v[3], so
// v[0] is x, v[1] is y and v[2] is z. The array form is what gets passed to
// functions that take an f32 pointer plus a length.
// NOTE(review): an anonymous struct inside a union is a common compiler
// extension (GCC/Clang/MSVC accept it), not standard C++. Writing one union
// member and reading another is also formally undefined in C++, even though
// those compilers define it. Member order here IS the aliasing layout; do
// not reorder.
struct Vec3
{
	union
	{
		struct
		{
			f32 x;
			f32 y;
			f32 z;
		};

		f32 v[3];
	};
};

// 3x3 float matrix.
// The named elements mRC (row R, column C) alias the array v[9] in row-major
// order, so element (R, C) is v[R * 3 + C]. This matches the A[row * n + col]
// indexing used by Solve and PrintMatrix with n = 3.
// NOTE(review): the anonymous struct inside the union is a compiler
// extension, not standard C++. Member order defines the layout; do not
// reorder.
struct Mat3
{
	union
	{
		struct
		{
			f32 m00, m01, m02;
			f32 m10, m11, m12;
			f32 m20, m21, m22;
		};

		f32 v[9];
	};
};

// Absolute value of a float.
// Strictly negative inputs are negated. Every other input is returned as-is,
// including -0.0f and NaN, since neither compares less than 0.0f.
f32 Abs( f32 a )
{
	return (a < 0.0f) ? -a : a;
}

// Print an n by n row-major matrix to stdout.
// Each row goes on its own line with every element in a 6-wide, 2-decimal
// field. The matrix is followed by a blank line and a separator line.
void PrintMatrix( f32 *A, int32 n )
{
	// Row-major storage means the elements are visited in memory order.
	const f32 *cursor = A;

	for(int32 row = 0; row < n; ++row)
	{
		for(int32 col = 0; col < n; ++col, ++cursor)
			printf( "%6.2f", *cursor );
		printf( "\n" );
	}

	printf( "\n" );
	printf( " ----------------\n\n" );
}

// Print the first n elements of v to stdout, one per line.
// Each element uses a 6-wide, 2-decimal field, and a blank line follows.
void PrintVector( f32 *v, int32 n )
{
	int32 index = 0;
	while(index < n)
	{
		printf( "%6.2f\n", v[index] );
		++index;
	}
	printf( "\n" );
}

// Solve for x in the system Ax = b.
// A is an n by n matrix stored row-major, so element (row, col) is
// A[row * n + col]; b is an n vector. The method is Gaussian elimination
// with partial pivoting followed by back substitution.
//
// Will destroy the contents of A. On success A holds the reduced
// upper-triangular matrix with 1.0f on the diagonal, and the result x is
// returned in b.
//
// Returns false when, at some elimination step, every candidate pivot in
// the current column has magnitude below kEpsilon, i.e. A is singular or
// nearly so. A and b are then left partially reduced.
// n <= 0 is treated as an empty system: it returns true without touching
// A or b.
//
// Derived from: Essential Math for Games, 2nd Ed.
// pg. 116.
bool Solve( f32 *A, f32 *b, const int32 n )
{
	// Forward elimination: one pass per diagonal element
	for(int32 pivot = 0; pivot < n; ++pivot)
	{
		// Find the row M whose element in the current column has the
		// largest absolute value. Search the pivot row and every row below.
		int32 M = pivot;
		f32 best = Abs( A[pivot * n + pivot] );
		for(int32 row = pivot + 1; row < n; ++row)
		{
			f32 element = Abs( A[row * n + pivot] );
			if(element > best)
			{
				best = element;
				M = row;
			}
		}

		// If the largest pivot candidate is (nearly) zero there is no
		// unique solution. This must test the magnitude, not the index M.
		if(best < kEpsilon)
			return false;

		// If M is not the current pivot row, perform partial pivoting
		if(M != pivot)
		{
			// Swap rows M and pivot within A.
			// Columns left of the pivot were already eliminated (zero) in
			// both rows, so only columns [pivot, n) need to move.
			// Swapping element by element avoids alloca, whose storage is
			// only released when Solve returns and would otherwise
			// accumulate on every swap.
			for(int32 col = pivot; col < n; ++col)
			{
				f32 temp = A[M * n + col];
				A[M * n + col] = A[pivot * n + col];
				A[pivot * n + col] = temp;
			}

			// Swap elements in b
			f32 temp2 = b[M];
			b[M] = b[pivot];
			b[pivot] = temp2;
		}

		// Scale the pivot row by 1 / (pivot value) so the pivot becomes
		// 1.0f. The pivot is then assigned exactly to remove rounding error.
		f32 inv = 1.0f / A[n * pivot + pivot];
		for(int32 col = pivot; col < n; ++col)
			A[n * pivot + col] *= inv;
		b[pivot] *= inv;
		A[n * pivot + pivot] = 1.0f;

		// Zero the pivot column in all lower rows
		for(int32 row = pivot + 1; row < n; ++row)
		{
			// Subtract a multiple of the pivot row from the current row
			// such that the pivot column element becomes 0
			f32 factor = A[n * row + pivot];
			for(int32 col = pivot; col < n; ++col)
				A[n * row + col] -= factor * A[n * pivot + col];
			b[row] -= factor * b[pivot];
		}
	}

	// Back substitution.
	// The last row already reads x[n-1] = b[n-1]. Walk the remaining rows
	// upward, subtracting multiples of the variables already solved below;
	// the unit diagonal means no division is needed.
	// Starting at n - 2 and testing before each pass keeps n == 1 (and
	// n <= 0) in bounds.
	for(int32 i = n - 2; i >= 0; --i)
	{
		for(int32 col = i + 1; col < n; ++col)
			b[i] -= A[i * n + col] * b[col];
	}

	return true;
}

int main( int argc, char **argv )

{

Mat3 A;

A.m00 = 1.0f; A.m01 = 1.0f; A.m02 = 1.0f;

A.m10 = 0.0f; A.m11 = 2.0f; A.m12 = 5.0f;

A.m20 = 2.0f; A.m21 = 5.0f; A.m22 = -1.0f;

Vec3 b;

b.x = 6.0f;

b.y = -4.0f;

b.z = 27.0f;

printf( "Input:\n\n" );

PrintVector( b.v, 3 );

PrintMatrix( A.v, 3 );

bool success = Solve( A.v, b.v, 3 );

if(!success)

{

printf( "Singular matrix\n" );

return 0;

}

printf( "Result:\n\n" );

PrintVector( b.v, 3 );

PrintMatrix( A.v, 3 );

return 0;

}