#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>

#include <rumba/arghandler.h>
#include <rumba/manifoldmatrix.h>
#include <rumba/matrixio.h>
#include <rumba/log.h>

#include "../matrixutils/multutils.h"
#include "../matrixutils/t_dist.h"

using namespace RUMBA;

// Print the command-line synopsis to stderr.
// Lists every option registered in myArgs (including --epsilon and --lazy,
// which were previously accepted but undocumented).
void usage()
{
	std::cerr << "Usage: glmfit\n\t-i infile \n\t--design file \n\t[--filter file]\n\t[--contrasts file --resid file [--tmap file] [--pmap file]]\n\t[--beta file]\n\t[--epsilon file]\n\t[--lazy]\n";
}


// USER-SPECIFIED COMMAND-LINE ARGUMENTS

// Option table passed to ArgHandler in main().
// Each entry gives: long option name, value kind, short option letter, and
// a fourth value that looks like a default ("" / false).  The optional fifth
// `true`/`false` appears to be a "required" flag (only --design, which the
// usage text lists as mandatory, passes true) -- NOTE(review): confirm
// against RUMBA::Argument's constructor.
// main() reads ALPHA entries into std::string and the FLAG entry into bool.
// The positional input file is not listed here; main() registers it with
// ArgHandler::setRequiredDefaultArg("infile").
Argument myArgs [] = {
	Argument ( "design", RUMBA::ALPHA, 'd', "", true ),      // design matrix G
	Argument ( "filter", RUMBA::ALPHA, 'f', "", false ),     // optional convolution filter K
	Argument ( "contrasts", RUMBA::ALPHA, 'c', "", false ),  // contrast matrix, one contrast per row
	Argument ( "tmap", RUMBA::ALPHA, 't', "", false ),       // output t-map, one frame per contrast
	Argument ( "pmap", RUMBA::ALPHA, 'p', "", false ),       // output p-map derived from the t-map
	Argument ( "beta", RUMBA::ALPHA, 'b', "", false ),       // regression-weight file (written, or read with --lazy)
	Argument ( "resid", RUMBA::ALPHA, 'r', "", false ),      // residual file (written, or read with --lazy)
	Argument ( "lazy", RUMBA::FLAG, 'z', false ),            // reuse existing beta/resid files instead of computing them
	// Loaded into main()'s epsilonSquared. Unlike the ALPHA entries above it
	// passes no fifth argument and relies on the constructor's default --
	// TODO confirm that default means "not required".
	Argument ( "epsilon", RUMBA::ALPHA, 'e', "" ),
	Argument()  // default-constructed entry; appears to terminate the table
};

// Return true when `name` can be opened for reading.
//
// main() uses this to refuse to overwrite existing output files.  A file that
// exists but cannot be opened (e.g. no read permission) reports false, and an
// empty name reports false.
//
// The stream is queried with is_open(): since C++11 the stream's conversion
// to bool is `explicit`, so returning the stream object directly from a
// bool-returning function no longer compiles.
bool check_exists(const std::string& name)
{
	std::ifstream probe(name.c_str());
	return probe.is_open();
}

//ManifoldMatrix getRow ( const ManifoldMatrix&, int );
//ManifoldMatrix makeMatrix ( Manifold<double>, bool = false );

// Per-voxel standard deviation of the contrast in row `rownum` of M.
//
// For contrast row vector c this returns, at every voxel i,
//     sqrt( eps[i] * (c * varb * c')(0,0) ).
// In main(), eps holds the per-voxel residual sum of squares (from pixelNorm
// or the --epsilon file) and varb the scaled covariance of the estimates.
//
// M       contrast matrix, one contrast per row.
// varb    square matrix whose size must match the number of columns of M.
// eps     single-frame volume; the result has the same width/height/depth.
// rownum  index of the contrast row to evaluate.
Manifold<double> 
contrast_sd
(
 const ManifoldMatrix& M, 
 const ManifoldMatrix& varb, 
 const Manifold<double>& eps, 
 int rownum
)
{
	Manifold<double> result (intPoint(eps.width(),eps.height(),eps.depth(),1));
	ManifoldMatrix c = getRow ( M, rownum );

	// c * varb * c' is a 1x1 matrix that does not depend on the voxel, so it
	// is formed once instead of building several temporary matrices per voxel.
	const double contrastVariance = (c * varb * c.transpose()).element(0,0);

	for ( int i = 0; i < eps.pixels(); ++i )
		result[i] = std::sqrt( eps[i] * contrastVariance );
	return result;
}


// Sum of squares of all values stored for voxel `pixel` in f.
//
// Visits element indices pixel, pixel + f->pixels(), pixel + 2*f->pixels(),
// ... while below f->size(); presumably one index per timepoint (frame-major
// layout) -- NOTE(review): confirm against BaseManifold.
// Despite the name this is the SQUARED Euclidean norm: no square root is
// taken (main() stores the result in epsilonSquared).
inline double pixelNorm(BaseManifold* f, int pixel)
{
	const int stride = f->pixels();
	const int total = f->size();
	double sumOfSquares = 0;

	for ( int idx = pixel; idx < total; idx += stride )
	{
		const double value = f->getElementDouble(idx);
		sumOfSquares += value * value;
	}
	return sumOfSquares;
}

// Left pseudo-inverse (G'G)^-1 G' of a design matrix G.
// Equals the Moore-Penrose pseudo-inverse when G has full column rank;
// otherwise invert() is asked to invert a singular G'G.
RUMBA::ManifoldMatrix pinv(const RUMBA::ManifoldMatrix& G)
{
	RUMBA::ManifoldMatrix Gt = G.transpose();
	return invert( Gt * G ) * Gt;
}


// glmfit: fit a general linear model voxel by voxel.
//
// Reads the data (-i), the design matrix (--design) and an optional temporal
// filter (--filter). It writes the regression weights to --beta and, on
// request, the residuals (--resid) and a t-map with one frame per row of
// --contrasts (--tmap) plus the matching p-map (--pmap). With --lazy the beta
// and residual files are opened as existing inputs instead of being computed.
// Exits with 0 on success and 1 on any error.
int main(int argc,char** argv)
{
	RUMBA::Log log ( "MAIN" );
	std::string infile , betafile , designfile , filterfile , contrastsfile , tmapfile , pmapfile , residfile, epsilonfile;

	ManifoldMatrix G,V,K,R,pinvG, invGtG, GinvGtG, contrasts, varb;
	Manifold<double> epsilonSquared;
	ManifoldFile* r = 0;
	ManifoldFile* T = 0;
	ManifoldFile* B = 0;
	ManifoldFile* DataFile = 0;
	bool lazy = false;   // must have a defined value before it is ever read
	int status = 0;      // process exit code

// HANDLING COMMAND-LINE ARGUMENTS IN A TRY-CATCH BLOCK
	// Every parse error is fatal: carrying on would run the fit with empty
	// file names.
	ArgHandler::setRequiredDefaultArg("infile");

	try {
		ArgHandler argh ( argc , argv, myArgs );
		if ( argh.arg("help") )
		{
			usage();
			exit(0);
		}
		argh.arg ( "infile" , infile );
		argh.arg ( "design" , designfile );
		argh.arg ( "filter" , filterfile );
		argh.arg ( "contrasts" , contrastsfile );
		argh.arg ( "beta" , betafile );
		argh.arg ( "tmap" , tmapfile );
		argh.arg ( "pmap" , pmapfile );
		argh.arg ( "resid" , residfile );
		argh.arg ( "epsilon", epsilonfile );
		lazy = argh.arg("lazy");
	}
	catch ( RUMBA::InvalidArgumentException& s)
	{
		std::cerr << "Invalid argument: " << s.error() << std::endl;
		usage();
		exit(1);
	}
	catch (RUMBA::DuplicateArgumentException& s)
	{
		std::cerr << "Duplicate argument: " << s.error() << std::endl;
		usage();
		exit(1);
	}
	catch (RUMBA::MissingArgumentException& s)
	{
		std::cerr << "Missing argument: " << s.error() << std::endl;
		usage();
		exit(1);
	}
	catch (RUMBA::ArgHandlerException& s)
	{
		std::cerr << "Error: " << s.error() << std::endl;
		usage();
		exit(1);
	}
	catch (Exception& s)
	{
		std::cerr << "Exception:" << s.error() << std::endl;
		usage();
		exit(1);
	}

	try {

// LOADING THE DATA (X), DESIGN (G), AND POSSIBLE CONVOLUTION FILTER (K)

		DataFile = ManifoldFile::construct ( infile.c_str(), std::ios::in );
		if (!DataFile)
		{
			throw RUMBA::Exception(
				std::string ( "Couldn't open file ")
				+ infile + " for reading.\n");
		}

		// Per-voxel residual sum of squares: zeros unless supplied via
		// --epsilon; overwritten from the residuals below when they exist.
		if (epsilonfile.empty())
		{
			epsilonSquared = Manifold<double> ( intPoint(
				DataFile->width(),DataFile->height(),DataFile->depth(),1)
			);
		}
		else
		{
			epsilonSquared = Manifold<double> ( epsilonfile.c_str() );
		}

		G = manifoldMatrixReadHack(designfile.c_str());	// load design matrix
		if ( G.rows() == 0 || G.cols() == 0 )
			throw RUMBA::Exception(std::string("Couldn't successfully open ")
				+ designfile + " for input");
		V = identityMatrix( DataFile->timepoints() );
		K = identityMatrix( DataFile->timepoints() );
		// check design matrix dimensions
		if ( G.rows() != DataFile->timepoints() )
		{
			throw RUMBA::Exception (
				"The design matrix has incompatible dimensions with the data\n"
			);
		}

		if ( !filterfile.empty() )
		{
			K = manifoldMatrixReadHack(filterfile.c_str());	// load convolution matrix
			if (K.rows()!=K.cols() || K.cols() != G.rows() )
			{
				throw RUMBA::Exception (
					"The filter file has incompatible dimensions with the data"
				);
			}
			G = K * G;
			V = K * K.transpose();
		}


// ESTIMATING THE REGRESSION WEIGHTS, b

		invGtG = invert( G.transpose() * G );
		pinvG = invGtG * G.transpose();

		// Refuse to overwrite existing outputs unless --lazy asks to reuse them.
		if ( !lazy && check_exists (betafile))
			throw RUMBA::Exception("Beta file already exists." );
		if ( !lazy && check_exists (residfile))
			throw RUMBA::Exception("Residual file already exists." );

		if (!lazy)
		{
			B = ManifoldFile::construct ( 
				betafile.c_str(), "float64", 
				intPoint
				( DataFile->width(), 
				  DataFile->height(), 
				  DataFile->depth(), 
				  G.cols() ) 
				);
			if (!B)
			{
				throw RUMBA::Exception(std::string(
					"Couldn't open beta file ") + betafile + " for writing"
				);
			}
			multiply ( pinvG * K, DataFile, B );
		}
		else
			B = ManifoldFile::construct ( betafile.c_str());

// ASSESSING THE FIT, IF DESIRED

		if ( ! ( tmapfile.empty() && pmapfile.empty() && residfile.empty() ))
		{
			GinvGtG = G * invGtG;
			R = identityMatrix(DataFile->timepoints()) - GinvGtG * G.transpose(); // residual forming matrix

			if (!residfile.empty())
			{
				if (lazy)
					r = ManifoldFile::construct(residfile.c_str());
				else
					r = ManifoldFile::construct(residfile.c_str(), 
							"float64", DataFile->extent());

				if (!r)
				{
					throw RUMBA::Exception(std::string(
						"Couldn't open residuals file ") + residfile
						+ (lazy ? " for reading" : " for writing")
					);
				}
				if (!lazy)
					multiply ( R * K, DataFile, r );
			}

			if (residfile.empty() && ! pmapfile.empty() )
			{
				throw RUMBA::Exception("Pmap requires residuals");
			}

			if ( !contrastsfile.empty() && !tmapfile.empty() )
			{
				contrasts = manifoldMatrixReadHack( contrastsfile.c_str() );
				// Each contrast row weights the G.cols() regression weights.
				if ( contrasts.rows() == 0 || contrasts.cols() != G.cols() )
				{
					throw RUMBA::Exception(std::string("The contrasts file ")
						+ contrastsfile
						+ " has incompatible dimensions with the design matrix");
				}
				// Without residuals or --epsilon every standard error is 0.
				if ( !r && epsilonfile.empty() )
					throw RUMBA::Exception("Tmap requires residuals or an epsilon file");

				ManifoldMatrix RV = R * V;
				const double trRV = trace(RV);

				if (r)
				{
					r->setCacheStrategy(RUMBA::CACHE_TIME);
					for ( int i = 0; i < r->pixels(); ++i )
						epsilonSquared[i]=pixelNorm(r,i);
				}
				// pinv(G) V pinv(G)' / tr(RV). GinvGtG equals pinv(G)' because
				// (G'G)^-1 is symmetric. This does not depend on the residual
				// data, so it is also needed when epsilonSquared came from
				// --epsilon.
				varb = (1/trRV * pinvG * V * GinvGtG);

				if (!B)
				{
					throw RUMBA::Exception(std::string(
						"Couldn't open beta file ") + betafile + " for reading"
					);
				}

				T = ManifoldFile::construct ( tmapfile.c_str(), "float64", 
					intPoint( DataFile->width(), DataFile->height(), DataFile->depth(), contrasts.rows() ) 
				);
				if (!T)
				{
					throw RUMBA::Exception(std::string(
						"Couldn't open t-map file ") + tmapfile + " for writing"
					);
				}

				// Contrast estimates c*b, one frame per contrast ...
				multiply ( contrasts,  B,  T);

				// ... then divided voxel-wise by their standard deviation.
				for ( int i=0 ; i < contrasts.rows(); ++i )
				{
					Manifold<double> tmp = contrast_sd ( contrasts, varb, epsilonSquared, i );
					BaseManifold* tmpMF = T->get ( 0,0,0,i, T->width(), T->height(), T->depth(), 1 );
					for ( int j = 0; j < T->pixels(); ++j ) 
					{
						tmpMF->setElementDouble ( j, tmpMF->getElementDouble ( j ) / tmp[j] );
					}
					T->put ( tmpMF, intPoint(0,0,0,i) );
					delete tmpMF;
				}

				if ( !pmapfile.empty() ) 
				{
					// Effective degrees of freedom tr(RV)^2 / tr(RVRV).
					const double nu = trRV * trRV / trace( RV * RV );
					int i = 0;
					Manifold<double> p (T->extent());
					for ( double * it = p.begin(); it != p.end(); ++it )
						*it = RUMBA::t_to_p(T->getElementDouble(i++), nu);
					p.save( pmapfile.c_str() );
				}
			}
		}

	}
	catch ( RUMBA::Exception& e) 
	{ 
		std::cerr << e.error() << std::endl;
		status = 1;
	}

	// Single cleanup path for both success and failure.
	delete r;
	delete T;
	delete B;
	delete DataFile;
	return status;
}


