#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "matfuncs.h"

/*
 * m_gj_invert -- invert a square matrix by Gauss-Jordan elimination.
 *
 * The augmented matrix [a | I] is reduced with m_gausselim(). The entries
 * above each diagonal element are then eliminated, working from the last
 * column back to the first. Every row is scaled so its diagonal becomes 1,
 * and the right half of the augmented matrix is copied out as the inverse.
 *
 * inv  out-parameter.  On success *inv points to a newly allocated
 *      [n x n] matrix (from m_alloc()) holding the inverse of a.
 *      *inv is set to NULL when a is NULL or not square, when an
 *      allocation fails, or when a zero diagonal element is found after
 *      elimination (the matrix is treated as singular).
 * a    matrix to invert; it is only read, never modified.
 */
void
m_gj_invert(mat **inv, mat *a)
{
	mat *AI, *dummy, *LI = NULL, *result;
	unsigned long i, j;
	unsigned long siz;
	unsigned long targ;
	double c;

	/* Report failure by default; overwritten only once the inverse exists. */
	*inv = NULL;
	if (!a) {
		return;
	}
	if (a->rows != a->cols) {
		fprintf(stderr,
			"gauss-jordan inversion works for square matrices, "
			"not [%lux%lu] matrices\n",a->rows,a->cols);
		/* Everything below assumes a square matrix; stop here. */
		return;
	}
	siz = a->rows;

	/* Build [a | I]: left half is a, right half is the n x n identity. */
	AI = m_alloc(siz, siz * 2);
	if (!AI) {
		return;
	}
	for (i = 0; i < siz; i++) {
		for (j = 0; j < siz; j++) {
			AI->e[i][j] = a->e[i][j];
			AI->e[i][j + siz] = (i == j) ? 1 : 0;
		}
	}

	/*
	 * NOTE(review): assumes m_gausselim() stores an upper-triangular
	 * (row-echelon) form of AI in LI; its second output is not needed.
	 */
	m_gausselim(AI, &LI, &dummy);
	m_dealloc(AI);
	m_dealloc(dummy);
	if (!LI) {
		return;
	}

	/*
	 * A zero pivot would make the divisions below yield inf/NaN.
	 * If LI is upper triangular, the back-elimination step only adds
	 * multiples of lower rows to higher rows, so it never changes the
	 * diagonal, and one check here covers every division. Only exact
	 * zeros are caught; nearly singular input still yields a
	 * (poorly conditioned) result.
	 */
	for (targ = 0; targ < siz; targ++) {
		if (LI->e[targ][targ] == 0.0) {
			fprintf(stderr,
				"gauss-jordan inversion: matrix is singular "
				"(zero pivot at [%lu,%lu])\n", targ, targ);
			m_dealloc(LI);
			return;
		}
	}

	/* Clear the entries above each pivot, last column first. */
	for (targ = siz; targ-- > 0; ) {
		for (i = targ; i-- > 0; ) {
			c = LI->e[i][targ] / LI->e[targ][targ];
			/* row i += (-c) * row targ, zeroing LI[i][targ] */
			m_addrow(LI, -c, targ, i);
		}
	}

	/* Scale every row so the left half becomes the identity. */
	for (targ = 0; targ < siz; targ++) {
		m_rowmult(LI, targ, 1.0 / LI->e[targ][targ]);
	}

	/* The right half of LI now holds the inverse. */
	result = m_alloc(siz, siz);
	if (!result) {
		m_dealloc(LI);	/* do not leak the work matrix on failure */
		return;
	}
	for (i = 0; i < siz; i++) {
		for (j = 0; j < siz; j++) {
			result->e[i][j] = LI->e[i][j + siz];
		}
	}
	m_dealloc(LI);
	*inv = result;
}

/*
 * Command-line driver: read a matrix from the file named by argv[1]
 * ("-" reads standard input), invert it with m_gj_invert(), and print
 * the inverse with m_print().
 *
 * Returns EXIT_FAILURE when no argument is given, when the reader
 * returns NULL, or when m_gj_invert() yields NULL; EXIT_SUCCESS otherwise.
 */
int
main(int argc, char **argv)
{
	/* argv[0] may be missing (argc == 0); never pass NULL to %s. */
	const char *prog = (argc > 0 && argv[0]) ? argv[0] : "gj_invert";
	mat *a = NULL, *ai = NULL;

	/* argv[1] is dereferenced below and does not exist when argc < 2. */
	if (argc < 2) {
		fprintf(stderr, "usage: %s <matrix-file | ->\n", prog);
		return EXIT_FAILURE;
	}

	if (strcmp(argv[1], "-") == 0) {
		a = m_read_from_stream(stdin);
	} else {
		a = m_read_from_file(argv[1]);
	}
	if (!a) {
		fprintf(stderr, "%s: could not read a matrix from %s\n",
			prog, argv[1]);
		return EXIT_FAILURE;
	}

	m_gj_invert(&ai, a);
	if (!ai) {
		fprintf(stderr, "%s: matrix could not be inverted\n", prog);
		return EXIT_FAILURE;
	}
	m_print(ai);
	return EXIT_SUCCESS;
}
