#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>

#include <rumba/matrixio.h>
#include <rumba/manifoldmatrix.h>
#include <rumba/arghandler.h>

#include "../matrixutils/normalize.h"
#include "../matrixutils/stats.h"


using namespace RUMBA;
using std::string;
using std::cerr;
using std::endl;
// Residual-forming ("annihilator") matrix for the design matrix X, whose rows
// are observations and whose columns are regressors:
//     R = I - X (X'X)^{-1} X'
// For a column vector y with X.rows() entries, R*y gives the residuals of the
// least-squares fit of y on the columns of X.
ManifoldMatrix residual(ManifoldMatrix& X)
{
	ManifoldMatrix Xt = X.transpose();
	ManifoldMatrix hat = X*invert(Xt*X)*Xt;   // projection onto span(X)
	return identityMatrix(X.rows()) - hat;
}

// Returns the sum of the squares of all elements of v, i.e. the squared
// Euclidean norm when v is a vector. v itself is left unchanged.
double sumsquare(const ManifoldMatrix& v)
{
	double total = 0;
	const double* cursor = v.begin();
	const double* const stop = v.end();
	while ( cursor != stop )
	{
		const double x = *cursor;
		total += x * x;
		++cursor;
	}
	return total;
}
	



// Linear detrending: for each pixel, replace its time series y (tp samples)
// with the residuals of the least-squares fit y ~ a*t + b, t = 0..tp-1.
//
// Data layout, as used by the indexing below: element i + j*px of a file is
// pixel i at time point j, where px = pixels().
//
// Throws RUMBA::Exception if the input has fewer than two time points
// (a line cannot be fit, and X'X would be singular).
void do_glm(ManifoldFile* in, ManifoldFile* out)
{
	in->setCacheStrategy(RUMBA::CACHE_TIME);
	out->setCacheStrategy(RUMBA::CACHE_TIME);

	const int px = in->pixels();
	const int tp = in->timepoints();

	if ( tp < 2 )
		throw RUMBA::Exception("Linear detrending needs at least two timepoints");

	// Design matrix: column 0 is the time index (slope), column 1 is constant
	// (intercept).
	ManifoldMatrix X(makeMatrix(tp,2));
	for ( int t = 0; t < X.rows(); ++t )
	{
		X.element(t,0) = t;
		X.element(t,1) = 1;
	}

	// R is the same for every pixel, so it is computed once.
	ManifoldMatrix R = residual(X);

	ManifoldMatrix v(makeMatrix(tp,1));
	for ( int i = 0; i < px; ++i )
	{
		for ( int j = 0; j < tp; ++j ) v.element(j,0) = (*in)[i+j*px];
		v = R*v;
		for ( int j = 0; j < tp; ++j ) (*out)[i+j*px] = v.element(j,0);
	}
}

// Moving-line ("nonlinear") detrending. For each pixel, a straight line is fit
// by least squares to a sliding window of `period` consecutive time points.
// Each output sample is the residual at the centre (offset period/2) of the
// window that starts period/2 points earlier.
//  - Samples before the first window's centre use residuals of the first window.
//  - Samples after the last window's centre use residuals of the last window.
// Every time point 0..tp-1 is written exactly once per pixel, except that
// index period/2 is written twice with the same value when period == tp.
//
// Data layout, as used by the indexing below: element i + j*px is pixel i at
// time point j.
//
// Throws RUMBA::Exception unless 2 <= period <= timepoints().
void do_moving_glm(ManifoldFile* in, ManifoldFile* out, int period)
{
	in->setCacheStrategy(RUMBA::CACHE_TIME);
	out->setCacheStrategy(RUMBA::CACHE_TIME);

	const int px = in->pixels();
	const int tp = in->timepoints();

	// A line needs at least two points, and a window cannot exceed the series.
	if ( period < 2 || period > tp )
		throw RUMBA::Exception("period must be at least 2 and at most the number of timepoints");

	const int half = period / 2;          // offset of a window's centre
	const int last_start = tp - period;   // first time point of the final window

	// Design matrix for one window: slope column and intercept column.
	ManifoldMatrix X(makeMatrix(period,2));
	for ( int k = 0; k < X.rows(); ++k )
	{
		X.element(k,0) = k;
		X.element(k,1) = 1;
	}
	ManifoldMatrix R = residual(X);

	ManifoldMatrix v(makeMatrix(period,1));

	for ( int i = 0; i < px; ++i )
	{
		// Leading window [0, period): supplies time points 0..half.
		for ( int k = 0; k < period; ++k )
			v.element(k,0) = (*in)[i+k*px];
		v = R*v;
		for ( int j = 0; j <= half; ++j )
			(*out)[i+j*px] = v.element(j,0);

		// Interior windows [start, start+period): each supplies only its centre.
		for ( int start = 1; start < last_start; ++start )
		{
			for ( int k = 0; k < period; ++k )
				v.element(k,0) = (*in)[i+(start+k)*px];
			v = R*v;
			(*out)[i+(start+half)*px] = v.element(half,0);
		}

		// Trailing window [last_start, tp): supplies its centre and everything
		// after it. Absolute time j corresponds to window element j - last_start.
		for ( int k = 0; k < period; ++k )
			v.element(k,0) = (*in)[i+(last_start+k)*px];
		v = R*v;
		for ( int j = last_start + half; j < tp; ++j )
			(*out)[i+j*px] = v.element(j - last_start,0);
	}
}

// For every time point i, reads the whole volume at i from `in`, applies
// volume_normalize(..., RUMBA::MEAN_NORMALIZE), and writes it to the same time
// point of `out`. The help text describes this as subtracting the volume mean.
// NOTE(review): confirm against volume_normalize's definition.
// The volume buffer is released even if normalization or writing throws; the
// exception is then propagated to the caller.
void do_mean ( ManifoldFile* in, ManifoldFile* out )
{
	// Extent of a single volume: the full spatial extent, one time point.
	intPoint ex = in->extent(); ex.t() = 1;
	const int tp = in->timepoints();

	for ( int i = 0; i < tp; ++i )
	{
		BaseManifold* N = in->get(intPoint(0,0,0,i),ex);
		try
		{
			volume_normalize(N, RUMBA::MEAN_NORMALIZE);
			out->put(N,intPoint(0,0,0,i));
		}
		catch (...)
		{
			delete N;   // don't leak the volume on the error path
			throw;
		}
		delete N;
	}
}

// Decorrelation: subtracts from every pixel's time series its orthogonal
// projection onto the mean time series (the mean over all pixels).
//
// The projection direction is computed as the per-time-point sum over pixels,
// normalised to unit length. Scaling does not change the direction, so this is
// the same as projecting onto the mean. If that sum is numerically zero
// (squared norm <= 1e-8), the direction is taken as the zero vector and the
// data is copied through unchanged.
//
// Data layout, as used by the indexing below: element i + j*px is pixel i at
// time point j.
void do_decorrelate( ManifoldFile* in, ManifoldFile* out )
{
	in->setCacheStrategy(RUMBA::CACHE_TIME);
	out->setCacheStrategy(RUMBA::CACHE_TIME);
	const double epsilon = 1e-8;
	const int px = in->pixels();
	const int tp = in->timepoints();
	ManifoldMatrix v(makeMatrix(tp,1));
	ManifoldMatrix w(makeMatrix(tp,1));

	// v(t) = sum over all pixels of the value at time t.
	std::fill(v.begin(),v.end(),0);
	for ( int i = 0; i < px; ++i )
		for ( int j = 0; j < tp; ++j )
			v.element(j,0) += static_cast<double>((*in)[ i + j*px ]);

	// Normalise v to unit length (or zero it if it has no usable direction).
	const double norm2 = sumsquare(v);
	if (norm2 > epsilon)
		v = v*(1/std::sqrt(norm2));
	else
		v = v*0;

	// For each pixel: w <- w - v (w . v), i.e. remove the component along v.
	for ( int i = 0; i < px; ++i )
	{
		for ( int j = 0; j < tp; ++j )
			w.element(j,0) = (*in)[ j*px + i ];

		w = w - v * (w.transpose()*v).element(0,0);

		for ( int j = 0; j < tp; ++j )
			(*out)[ j*px + i ] = w.element(j,0);
	}
}


// Extra command-line options accepted by this tool, passed to ArgHandler in main.
// (main also queries "help", "infile" and "outfile"; those are presumably
// supplied by ArgHandler itself. NOTE(review): confirm.)
//   type   (-t, alphabetic): detrending method, one of l, m, d, n; see
//                            help_message. "l" is presumably the default value.
//   period (-p, numeric)   : window length for method n; main reads it into a
//                            double and treats "not given" as -1.
// The default-constructed Argument() presumably terminates the list.
Argument myArgs[] = 
{
	Argument("type", RUMBA::ALPHA, 't', "l" ),
	Argument("period", RUMBA::NUMERIC, 'p' ),
	Argument()
};

// Prints the usage line and a description of each detrending type to stderr.
void help_message()
{
	std::cerr << 
	"Usage: detrend_transform -i infile -o outfile [[--type|-t] (l|m|d|n)]\n"
	"       [[--period|-p] number_of_images]\n\n"
	"Types of detrending:\n"
	"----------------------\n"
	"d: decorrelate; subtracts from each time series its orthogonal\n"
	"projection onto the mean time series\n\n"
	"l: linear; subtract the line of best fit from each time series.\n"
	"This is the default. \n\n"
	"m: mean; subtract the volume mean from each volume\n\n"
	"n: nonlinear: use moving regression lines to adjust each point. To use\n"
	"this method, it's necessary to use [-p|--period] number_of_images, to \n"
	"specify the number of points in a time series used to fit a moving line\n"
	<< std::endl;
}

// Entry point for detrend_transform.
//  - Parses the command line.
//  - Validates the options before any file is opened.
//  - Opens the input file and, once the options are known to be usable for
//    that input, creates the output file.
//  - Runs the selected detrending method.
// Returns 0 on success and 1 on any error. Both manifolds are released on
// every path.
int main(int argc, char** argv)
{
	ManifoldFile* inManifold = 0;
	ManifoldFile* outManifold = 0;
	std::string infile, outfile, type;	
	double period = -1;   // -1 means "-p not supplied"
	int status = 0;

	try 
	{
		ArgHandler argh(argc,argv,myArgs);	

		if ( argh.arg("help"))
		{	
			help_message();
			return 0;
		}
		argh.arg("infile",infile);
		argh.arg("outfile",outfile);
		argh.arg("type",type);
		if (argh.arg("period"))
			argh.arg("period",period);

		// Reject bad options before touching the file system, so a bad
		// command line never creates or truncates the output file.
		if ( type != "l" && type != "m" && type != "d" && type != "n" )
			throw RUMBA::Exception( string("Unknown type " + type ) );

		if (period < 0 && type == "n")
		{
			std::cerr << "Must supply -p to use nonlinear method\n";
			return 1;
		}

		// A line cannot be fit to fewer than two points.
		if (type == "n" && period < 2)
		{
			std::cerr << "Number of samples per line, " << period
				<< " must be at least 2" << std::endl;
			return 1;
		}

		inManifold = ManifoldFile::construct(infile.c_str());
		if (!inManifold) throw RUMBA::Exception ("Couldn't open input file");

		if ( type == "n" && period > inManifold->timepoints())
		{
			std::cerr << "Number of samples per line, " << period 
				<< " is greater than the number of timepoints" << std::endl;
			status = 1;
		}
		else
		{
			outManifold = ManifoldFile::construct
				(outfile.c_str(),std::ios::out,inManifold);
			if (!outManifold) throw RUMBA::Exception ("Couldn't open output file");

			if ( type == "l" )
				do_glm(inManifold,outManifold);
			else if ( type == "m" )
				do_mean(inManifold,outManifold);
			else if ( type == "d" )
				do_decorrelate(inManifold,outManifold);
			else // type == "n" (validated above)
				do_moving_glm(inManifold,outManifold, static_cast<int>(period));
		}
	}
	catch ( RUMBA::InvalidArgumentException& s)
	{
		std::cerr << "Invalid argument: " << s.error() << std::endl;
		status = 1;
	}
	catch (RUMBA::DuplicateArgumentException& s)
	{
		std::cerr << "Duplicate argument: " << s.error() << std::endl;
		status = 1;
	}
	catch (RUMBA::ArgHandlerException& s)
	{
		std::cerr << "Error: " << s.error() << std::endl;
		status = 1;
	}
	catch (RUMBA::Exception& e)
	{
		cerr << e.error() << endl;
		status = 1;
	}

	// Release the manifolds on both the success and the error paths.
	delete inManifold;
	delete outManifold;
	return status;
}
