#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
#include "Matrix.h"
#include "Graph.h"

// Snap a value that is expected to be (numerically) an integer to that
// integer. Deviations of 0.1 or more indicate accumulated error and trip
// the assert.
double round_check( double x )
{
    const double nearest = round( x );
    assert( fabs( nearest - x ) < 0.1 );
    return nearest;
}

// Greatest common divisor of two integer-valued doubles (Euclid's algorithm).
// The remainder uses round() instead of truncation, so intermediate values
// may be negative; only the magnitude of the result is returned.
// gcd(x, 0) == |x| and gcd(0, 0) == 0.
double gcd( double x, double y ) 
{
    assert( x==round(x) && y==round(y) );
    double a = x;
    double b = y;
    for ( ; b != 0.; ) {
        const double r = a - b*round(a/b);
        a = b;
        b = r;
    }
    return fabs(a);
}


// Create an empty matrix with no storage. operator= allocates storage on
// demand, so an empty matrix is a valid assignment target.
Matrix::Matrix() {
    data = 0;
    // Give the remaining members defined values. Otherwise copying from an
    // empty matrix (operator= reads m.size, m.mstate, m.det and then
    // dereferences m.data for size rows) reads uninitialized memory and can
    // dereference the null data pointer. With size == 0 no row is touched.
    size = 0;
    mstate = INIT;
    det = 0.;
}


// Create an n x n symmetric matrix in state INIT with every entry zero.
// Only the lower triangle (j <= i) is stored and initialized.
Matrix::Matrix( u32 n ) {
    data = (Data*)malloc( sizeof( Data ) );
    size = n;
    mstate = INIT;
    // det is only computed by chol(). Give it a defined value until then so
    // that print() and copying (operator= reads m.det) never read an
    // uninitialized double.
    det = 0.;
    for( u32 i = 0; i < n; ++i  )
        for( u32 j = 0; j <= i ; ++j )
            (*data)[i][j] = 0.;
}

// Release the matrix storage. free() accepts a null pointer, so an empty
// (default-constructed) matrix is also safe to destroy.
Matrix::~Matrix()
{
    Data* storage = data;
    data = 0;
    free( storage );
}

// Copy the scalar state (size, state, determinant) and the stored lower
// triangle of m into this matrix. Storage is allocated on first use when
// this matrix has none yet.
Matrix& Matrix::operator=( Matrix& m ) {
    size   = m.size;
    mstate = m.mstate;
    det    = m.det;
    if ( !data )
        data = (Data*)malloc( sizeof( Data ) );
    for( u32 row = 0; row < size; ++row ) {
        const u32 row_len = row + 1;   // lower triangle: columns 0..row
        for( u32 col = 0; col < row_len; ++col )
            (*data)[row][col] = (*m.data)[row][col];
    }
    return *this;
}

// Copy constructor. Storage is reserved up front, so operator= only copies
// into it and never allocates.
Matrix::Matrix( Matrix& m ) 
    : data( (Data*)malloc( sizeof( Data ) ) )
{
    *this = m;
}

// Build the Laplacian (Kirchhoff) matrix of graph g with node 0 removed.
// The matrix has one row/column per node 1..num_nodes-1; select() discards
// any access involving node 0.
//
// Each edge (n0, n1) adds +1 to both diagonal entries and -1 to the shared
// off-diagonal entry:
//  - Parallel edges therefore accumulate, matching the rank-one update that
//    add_edge()/add_edge_ref() perform.
//  - A self-loop contributes nothing and is skipped.
Matrix::Matrix( Graph& g )
{
    assert( g.num_nodes >= 1 );   // otherwise size below would wrap around
    data = (Data*)malloc( sizeof( Data ) );
    size = g.num_nodes - 1;
    mstate = INIT;
    det = 0.;                     // computed later by chol()
    for( u32 i = 0; i < size; ++i  )
        for( u32 j = 0; j <= i ; ++j )
            (*data)[i][j] = 0.;
    for (Edge* e = g.edge_list; e; e=e->next ) {
        u32 n0 = e->dedge_0.node_num;
        u32 n1 = e->dedge_1.node_num;
        if ( n0 == n1 )
            continue;             // (e_n - e_n) == 0: no change to the Laplacian
        select( n0, n0 ) += 1.;
        select( n1, n1 ) += 1.;
        // Accumulate rather than overwrite, so that every parallel edge
        // counts, consistent with the degree counting on the diagonal.
        select( n0, n1 ) -= 1.;
    }
}

// Bilinear form (e_n0 - e_n1)^T * A * (e_n2 - e_n3) over the stored
// symmetric matrix A. Any index equal to 0 contributes zero (see select()).
double Matrix::transfer_impedance( u32 n0, u32 n1, u32 n2, u32 n3 ) {
    double acc = select( n0, n2 ) + select( n1, n3 );
    acc -= select( n0, n3 );
    acc -= select( n1, n2 );
    return acc;
}

// If "this" holds the resistance matrix of a graph (state INVERSE, i.e. det
// times the inverse of the reduced Laplacian), and a new edge is added to the
// graph connecting nodes n0 and n1, this function updates the matrix to that
// of the new graph including the added edge, and updates det.
//
// This is a reference implementation of "add_edge()". It is a simple,
// unoptimized version that illustrates the mathematics; add_edge() computes
// the same update but has been hand-optimized for speed.
void Matrix::add_edge_ref( u32 n0, u32 n1 )
{
    const double det_before = det;
    const double det_after = det_before + transfer_impedance(n0, n1, n0, n1);
    Data updated;
    assert( mstate == INVERSE );

    // Rank-one update, evaluated straight from the definition:
    //   new[i][j] = (det_after/det_before) * old[i][j] - t(i)*t(j)/det_before
    // where t(k) = transfer_impedance(n0, n1, k+1, 0).
    for( u32 row = 0; row < size; ++row ) {
        for( u32 col = 0; col <= row; ++col ) {
            double scaled = det_after/det_before*(*data)[row][col];
            double correction = transfer_impedance(n0,n1,row+1,0)*transfer_impedance(n0,n1,col+1,0)/det_before;
            updated[row][col] = round_check( scaled - correction );
        }
    }

    // Commit only once every entry has been computed from the old values.
    for( u32 row = 0; row < size; ++row )
        for( u32 col = 0; col <= row; ++col )
            (*data)[row][col] = updated[row][col];
    det = det_after;
}



// Element (a,b) of the stored symmetric matrix, 0-based, folded onto the
// stored lower triangle.
#define M(a,b) ( (a)>(b) ? (*data)[a][b] : (*data)[b][a] )

// Fast, in-place version of add_edge_ref(). It adds an edge between nodes n0
// and n1 to the graph whose matrix is held in "this" (state INVERSE, i.e. det
// times the inverse of the reduced Laplacian), and updates det.
//
// With t[i] = transfer_impedance(n0, n1, i+1, 0) (node 0 contributes zero):
//   new[i][j] = (new_det/old_det) * old[i][j] - t[i]*t[j]/old_det
//   new_det   = old_det + transfer_impedance(n0, n1, n0, n1)
// t is captured before the loop, so the lower triangle can be overwritten
// in place.
void Matrix::add_edge( u32 n0, u32 n1 )
{
    // Capacity of the on-stack column buffer below.
    const u32 kMaxSize = 32;

    assert( mstate == INVERSE );
    assert( size <= kMaxSize );          // t[] would overflow otherwise
    assert( n0 <= size && n1 <= size );  // valid node numbers are 0..size

    // A self-loop leaves the Laplacian unchanged, so the matrix and det stay
    // as they are. This also avoids indexing row (0 - 1) when n0 == n1 == 0.
    if ( n0 == n1 )
        return;

    double old_det = det;
    assert( old_det != 0. );             // the update divides by old_det
    double new_det = old_det + transfer_impedance(n0, n1, n0, n1);
    double new_det_over_old_det = new_det / old_det;
    double one_over_old_det = 1. / old_det;

    // Column of the stored matrix along (e_n0 - e_n1). Node 0 has no row.
    double t[kMaxSize];
    for ( u32 i = 0; i < size; ++i ) {
        double c0 = ( n0 != 0 ) ? M( n0-1, i ) : 0.;
        double c1 = ( n1 != 0 ) ? M( n1-1, i ) : 0.;
        t[i] = c0 - c1;
    }

    for( u32 i = 0; i < size; ++i  ) {
        for( u32 j = 0; j <= i ; ++j ) {
            (*data)[i][j] = round_check( new_det_over_old_det * (*data)[i][j] - t[i] * t[j] * one_over_old_det );
        }
    }
    det = new_det;
}


void Matrix::print()
{
    for( u32 i = 0; i < size; ++i  ) {
        for( u32 j = 0; j < size ; ++j ) {
            printf("%10g  ", select( i+1, j+1 ) );
        }
        printf("\n");
    }
    printf( "det = %g\n\n", det );
}




// Reference to entry (i, j) of the symmetric matrix, using 1-based node
// numbers. Node 0 has no stored row/column: for i == 0 or j == 0 a scratch
// cell holding 0 is returned. It is reset on every call, so writes through
// it are discarded.
double& Matrix::select( u32 i, u32 j ) {
    static double zero;
    zero = 0.;
    if ( i != 0 && j != 0 ) {
        // Only the lower triangle is stored; fold (i, j) onto it.
        const u32 hi = ( i >= j ) ? i : j;
        const u32 lo = ( i >= j ) ? j : i;
        return (*data)[hi-1][lo-1];
    }
    return zero;
}

void Matrix::chol( )
{
    // Compute, in place, the Cholesky decomposition of the stored matrix A:
    // L is lower triangular and A = L * tr(L). L overwrites the stored lower
    // triangle column by column; column j only reads columns k < j, which
    // already hold final L values.
    //
    // As a by-product, det is set to det(A) = product of the squared diagonal
    // entries of L (the pivots s below). It is rounded to the nearest integer
    // by round_check, which asserts it was already within 0.1 of one.
    //
    // NOTE(review): there is no guard against a non-positive pivot s. If A is
    // singular (s == 0 before the last column), the divisions below produce
    // inf/nan, and round_check(det) is then expected to trip its assert. A
    // negative s makes sqrt() return nan. Confirm that callers only pass
    // positive-definite matrices.
    det = 1.;
    for ( u32 j = 0; j< size; ++j ) {

        // compute r[j][j]:  L[j][j] = sqrt( A[j][j] - sum_{k<j} L[j][k]^2 )
        double s = 0.;
        for ( u32 k = 0; k<j; ++k ) {
            double x = (*data)[j][k];
            s = s + x*x;
        }
        s = (*data)[j][j] - s;
        det *= s;   // det(A) = prod_j L[j][j]^2 = prod_j s
        (*data)[j][j] = sqrt( s );

        for ( u32 i = j+1; i < size; ++i ) {
            // compute r[i][j]:  L[i][j] = ( A[i][j] - sum_{k<j} L[i][k]*L[j][k] ) / L[j][j]
            double s = 0.;
            for ( u32 k = 0; k < j; ++k ) {
                s += (*data)[i][k] * (*data)[j][k];
            }
            (*data)[i][j] = ( (*data)[i][j] - s ) / (*data)[j][j];
        }
    }
    det = round_check( det );
}


void Matrix::inverse_chol( )
{
    // Replace the lower-triangular Cholesky factor L (as produced by chol())
    // with its inverse C = L^{-1}, in place, row by row:
    //   C[i][i] = 1 / L[i][i]
    //   C[i][j] = -C[i][i] * sum_{k=j}^{i-1} L[i][k] * C[k][j]     (j < i)
    // Within row i, j runs upward. So when C[i][j] is computed, the entries
    // L[i][k] for j <= k < i have not been overwritten yet, while rows k < i
    // already hold C. This statement order is what makes the in-place update
    // valid.
    // Assumes every diagonal entry of L is nonzero (it is divided by).
    for ( u32 i = 0; i<size; ++i ) {
        (*data)[i][i] = 1./(*data)[i][i];
        for ( u32 j = 0; j < i; ++j ) {
            double s = 0.;
            for ( u32 k = j; k < i; ++k ) {
                s += (*data)[i][k]*(*data)[k][j];
            }
            (*data)[i][j] = -s*(*data)[i][i];
        }
    }
}

void Matrix::chol2inverse( )
{
    // convert a cholesky decomposed matrix to
    // the inverse of the original matrix multiplied by the det.
    //
    // Input: the lower triangle holds L with A = L * tr(L), and det holds
    // det(A) (both as left by chol()).
    // Output: the lower triangle holds det * A^{-1}, each entry rounded with
    // round_check (which asserts it was within 0.1 of an integer).
    // mstate is not changed here.
    //
    // Steps: C = L^{-1} via inverse_chol(); then A^{-1} = tr(C) * C, i.e.
    //   A^{-1}[i][j] = sum_{k>=i} C[k][i] * C[k][j]     (i >= j; C is lower triangular)
    // This is computed in place column by column. The write to [i][j] is never
    // read again: within column j, later iterations read only rows > i, and
    // later columns j' > j never read column j.

    inverse_chol();
    // now compute Ct*C
    for ( u32 j = 0; j < size; ++j ) {
        for ( u32 i = j; i < size; ++i ) {
            double sum = 0.;
            for ( u32 k = i; k < size; ++ k ) {
                sum += (*data)[k][i] * (*data)[k][j];
            }
            (*data)[i][j] = sum;
        }
    }
    // scale by det to get the (integer-valued) adjugate
    for ( u32 j = 0; j < size; ++j ) {
        for ( u32 i = j; i < size; ++i ) {
            (*data)[i][j] = round_check( det*(*data)[i][j] );
        }
    }
}


// Return the determinant of the matrix. The matrix is factored with chol()
// on first request (state INIT -> CHOL). In any later state det was already
// computed and is returned as-is.
double Matrix::determinant()
{
    if ( mstate != INIT )
        return det;
    chol();
    mstate = CHOL;
    return det;
}

// Bring the matrix to state INVERSE (det * inverse, see chol2inverse()),
// performing only the steps not yet done: INIT -> CHOL -> INVERSE.
// Returns det.
double Matrix::invert()
{
    switch ( mstate ) {
    case INIT:
        chol();
        // fall through: the freshly factored matrix still needs inverting
    case CHOL:
        chol2inverse();
        break;
    default:
        break;   // already inverted
    }
    mstate = INVERSE;
    return det;
}






