#include "design.h"
#include "contrast.h"

#include <rumba/arghandler.h>
#include <rumba/manifoldmatrix.h>
#include <rumba/matrixio.h>
#include <rumba/parse.h>


using RUMBA::ManifoldMatrix;
using RUMBA::Manifold;
using RUMBA::manifold_generator;
using RUMBA::makeMatrix;
using RUMBA::writeManifoldMatrix;
using RUMBA::Argument;
using RUMBA::intPoint;

using std::vector;
using std::string;

// Return the command-line synopsis printed for --help.
// Kept in sync with the myArgs table: -c is the short flag of --contrast (not
// --covariate), --contrast-file has no short flag, and --filter only takes
// effect in main() when --ffile is supplied as well.
std::string usage()
{
	return "design-matrix -i infile -o outfile [--header-row n]\n"
		"\t[--contrast-file filename --contrast|-c cont1 cont2 cont3 ... ]\n"
		"\t[--acquisitions-per-trial|-a n]\n"
		"\t[--covariate cov1 cov2 ... ]\n"
		"\t[--filter <filter_string> --ffile <filter output file> "
		"[--filterfactors ffactor1 ffactor2 ... ]]\n"
		"\t--factors factor1 factor2 factor3 ..... \n";
}

// Options understood by design-matrix. main() additionally queries "help",
// "infile" and "outfile", which are not listed here and are presumably
// supplied by RUMBA::ArgHandler itself.
// Each entry reads Argument(long name, value type, short flag ('\0' = none),
// default value?, flag, flag). NOTE(review): the 4th argument looks like the
// default value, and the second trailing boolean appears to mark multi-valued
// options (every entry with it set is read via argh.multiarg() in main). The
// meaning of the first trailing boolean (set for --contrast and --factors)
// should be confirmed against rumba/arghandler.h.
Argument myArgs[] =
{
	Argument("header-row", RUMBA::NUMERIC, '\0', 1),
	Argument("covariate", RUMBA::ALPHA, '\0', RUMBA::Splodge(), false, true ),
	Argument("contrast", RUMBA::ALPHA, 'c', "", true, true ),
	Argument("contrast-file", RUMBA::ALPHA, '\0', "" ),
	Argument("acquisitions-per-trial", RUMBA::NUMERIC, 'a', 1 ),
	Argument("factors", RUMBA::ALPHA, '\0', "", true, true ),
	Argument("filter", RUMBA::ALPHA, '\0' ),
	Argument("ffile", RUMBA::ALPHA, '\0' ),
	Argument("filterfactors", RUMBA::ALPHA, '\0', "", false, true ),
	Argument() // default-constructed entry; presumably the end-of-table sentinel for ArgHandler
};

bool find_helper ( const std::map<Treatment, double>& m, const Treatment& t )
{
	std::map<Treatment, double>::const_iterator it = m.find(t);
	return (it!=m.end() && fabs(it->second) > 1e-6 );
}

// Rewind `fin` and read forward until line number `header_row` (1-based) is
// the last line consumed, then return that line; the next getline() on `fin`
// yields the line after it. At least one line is always read, so any
// header_row <= 1 behaves like 1. Error flags are cleared before rewinding;
// if the file has fewer than header_row lines, `fin` is left in a failed
// state (main() checks `!fin` after the first call).
std::string cue(std::ifstream& fin, int header_row)
{
	fin.clear();
	fin.seekg(0, std::ios::beg);

	std::string line;
	std::getline(fin, line);
	for (int line_no = 2; line_no <= header_row; ++line_no)
		std::getline(fin, line);

	return line;
}

//foobar(factors, contrasts, count_treatments(factors)+1+cov.size(), cfile );

void generate_contrasts
( 
 const vector<Factor>& factors, 
 const vector<RUMBA::Splodge>& contrasts, 
 int size, string cfile 
 )
{
	if ( !cfile.empty() )	
	{
		ManifoldMatrix C = makeMatrix( contrasts.size(), size 	);
		std::fill(C.begin(),C.end(),0);
		for ( unsigned int i = 0; i < contrasts.size(); ++i )
		{
			Parser p(factors,contrasts[i].asString());
			C.put( i,0, (generateContrast(p.parse(),factors)));
		}
		manifoldMatrixWriteHack (C, cfile );
	}

}

// Function object that tests two doubles for equality within an absolute
// tolerance: |left - right| < tolerance. A default-constructed instance uses
// 1e-6, the same threshold find_helper applies to treatment weights.
struct approx_equal
{
	approx_equal() : tolerance_(1e-6) {}
	explicit approx_equal(double tolerance) : tolerance_(tolerance) {}

	// const so the predicate is callable through const objects/references,
	// as standard algorithms are allowed to invoke it.
	bool operator()(double left, double right) const
	{
		return fabs(left-right) < tolerance_;
	}

	double tolerance_;
};

// Return a new matrix holding only the rows of M whose mask entry is
// (within 1e-6) equal to 1, in their original order.
//   M    - source matrix (not modified)
//   mask - column vector with M.rows() rows; only column 0 is read
// The surviving rows are counted with the same approximate test used to
// select them. Previously an exact count of 1.0 over every element of the mask
// could disagree with the selection and let N.put() write past the end of N.
ManifoldMatrix squash_matrix( ManifoldMatrix& M, const ManifoldMatrix & mask )
{
	assert(mask.rows()==M.rows());

	approx_equal is_one;

	int rows = 0;
	for (unsigned int i = 0; i < mask.rows(); ++i )
		if (is_one(mask.element(i,0),1))
			++rows;

	ManifoldMatrix N = makeMatrix(rows, M.cols());
	for (unsigned int i = 0, k = 0; i < M.rows(); ++i )
		if (is_one(mask.element(i,0),1))
			N.put(k++,0,M.subMatrix(i,1,0,M.cols()));

	return N;
}


// Return a copy of M in which every row i is multiplied by mask(i,0).
// With the 0/1 mask that generate_filter produces, this zeroes the rows that
// were filtered out but keeps the matrix shape, so row indices still match M.
//   M    - source matrix (not modified)
//   mask - must have M.rows() rows; only column 0 is read
ManifoldMatrix mask_matrix( ManifoldMatrix& M, const ManifoldMatrix & mask )
{
	assert(mask.rows()==M.rows());

	ManifoldMatrix N = makeMatrix(M.rows(), M.cols());
	for (unsigned int i = 0; i < M.rows(); ++i )
	{
		// The row weight does not depend on j; read it once per row.
		const double weight = mask.element(i,0);
		for (unsigned int j = 0; j < M.cols(); ++j )
			N.element(i,j) = weight*M.element(i,j);
	}

	return N;
}

// Build and write a 0/1 filter with one entry per treatment.
//   treatments    - treatments in design-row order; entry i describes row i
//   factors       - factors the filter expression refers to
//   filter_string - expression handed to Parser; parse() yields a
//                   treatment -> weight map
//   ffile         - output file; when empty the filter goes through
//                   writeManifoldMatrix(F) instead
//   verbose       - dump the parsed weights and the treatment list to cout
// Entry i is 1 when find_helper() reports a non-negligible weight for
// treatments[i], otherwise 0. Returns the filter matrix.
ManifoldMatrix generate_filter
(
 const vector<Treatment>& treatments, 
 const vector<Factor>& factors, 
 string filter_string, 
 string ffile, 
 bool verbose = false
 )
{
	// Created with extent (1,1,1,n) and transposed so that it is addressed
	// as (i,0), one entry per treatment.
	ManifoldMatrix filter = 
		makeMatrix( Manifold<double>(intPoint(1,1,1,treatments.size())));
	filter = filter.transpose();
	std::fill(filter.begin(), filter.end(), 0);

	Parser parser(factors, filter_string);
	std::map<Treatment,double> weights = parser.parse();

	if (verbose)
	{
		cout << "--" << endl;
		for ( map<Treatment,double>::iterator w = weights.begin();
				w != weights.end(); ++w )
			cout << w->first << " " << w->second << endl;
		cout << "--" << endl;

		copy (treatments.begin(), treatments.end(),
			ostream_iterator<Treatment>(cout,"\n"));
		cout << "--" << endl;
	}

	const std::vector<Treatment>::size_type count = treatments.size();
	for ( std::vector<Treatment>::size_type row = 0; row < count; ++row )
		filter.element(row,0) = find_helper(weights, treatments[row]) ? 1 : 0;

	if (!ffile.empty())
		manifoldMatrixWriteHack(filter, ffile.c_str());
	else
		writeManifoldMatrix(filter);

	return filter;
}

// Equality predicate for RUMBA::Splodge: two splodges compare equal when
// their asString() representations are identical. Despite the name this is
// not an ordering, so it must not be used as a std::sort / std::map comparator.
class SplodgeComp
{
public: 
	// const so the predicate can be invoked through const objects/references.
	inline bool operator() 
		(const RUMBA::Splodge& left, const RUMBA::Splodge& right) const
	{
		return left.asString()==right.asString();
	}
};

// Extract the covariate values of one data line.
//   v    - 0-based field indices of the covariates, in output order (main()
//          obtains them from getCovariateIndices on the header row)
//   line - one tab-separated data line; fields are cleaned with RUMBA::strip
// Returns a one-row matrix (manifold extent (1,1,1,v.size())) whose column i
// holds field v[i] converted with stream_cast<double>; main() places it in
// one row of the design matrix.
// Throws RUMBA::Exception if an index does not name a field on this line.
RUMBA::ManifoldMatrix getCovariateRow
(
	const std::vector<int>& v, 
	const std::string& line
)
{
	std::list<string> fields = RUMBA::tokenizeCsvLine(line,'\t');
	vector<string> tokens(fields.begin(),fields.end());
	std::for_each (tokens.begin(),tokens.end(),RUMBA::strip );

	intPoint dims (1,1,1,v.size());
	RUMBA::ManifoldMatrix M = makeMatrix(RUMBA::Manifold<double>(dims));

	for ( std::vector<int>::size_type i = 0; i < v.size(); ++i )
	{
		const int col = v[i];
		// tokens[col] is valid only for 0 <= col < tokens.size(); the old test
		// (col > size) let col == size through and read past the end.
		if ( col < 0 || col >= static_cast<int>(tokens.size()) )
			throw RUMBA::Exception ("index out of range in getCovariateRow");
		// The matrix has a single row (main() puts it at row i, and
		// generate_filter transposes this same (1,1,1,n) shape to get a
		// column), so value i belongs in column i, not row i.
		M.element(0,i)=RUMBA::stream_cast<double>( tokens[col] );
	}
	return M;
}

// Map covariate names to their 0-based field positions in a header line.
//   indices - output; cleared first, then receives one index per entry of
//             cov, in the same order
//   cov     - covariate column names to look up (not modified)
//   s       - tab-separated header line; fields are cleaned with RUMBA::strip
// If a name occurs more than once, the first matching field is used.
// Throws RUMBA::Exception("Not found: <name>") for a name absent from s.
void getCovariateIndices
( 
 vector<int>& indices, 
 vector<RUMBA::Splodge>& cov, 
 const std::string& s
 )
{
	indices.clear();

	std::list<string> raw = RUMBA::tokenizeCsvLine(s,'\t');
	vector<string> columns(raw.begin(), raw.end());
	std::for_each(columns.begin(), columns.end(), RUMBA::strip);

	for ( vector<RUMBA::Splodge>::size_type k = 0; k < cov.size(); ++k )
	{
		const RUMBA::Splodge& wanted = cov[k];
		const vector<string>::iterator where =
			std::find(columns.begin(), columns.end(), wanted.asString());
		if (where == columns.end())
			throw RUMBA::Exception (string("Not found: ") + wanted.asString());
		indices.push_back(static_cast<int>(where - columns.begin()));
	}
}




int main(int argc, char** argv)
{

	string infile,outfile,cfile;
	string s;

	int header_row = 1;
	vector<string> headers;
	vector<string> filter_headers;
	std::ifstream fin;
	vector<Factor> factors;
	vector<Factor> filter_factors;
	vector<RUMBA::Splodge> cov;
	vector<RUMBA::Splodge> contrasts;
	vector<int> cov_indices;
	vector<Treatment> treatments;
	vector<Treatment> filter_treatments;
	ManifoldMatrix M;
	ManifoldMatrix N;
	int acquisitions_per_trial;
	std::string ffile,filter_string;

	bool verbose = false;

try{
		RUMBA::ArgHandler argh(argc,argv, myArgs);
//		argh.print();
		if (argh.arg("help"))
		{
			std::cerr<<usage()<<std::endl;
			return 0;
		}
		
		argh.arg("infile",infile);
		argh.arg("outfile",outfile);
		argh.arg("contrast-file",cfile);
		argh.arg("header-row",header_row);
		argh.arg("acquisitions-per-trial",acquisitions_per_trial);
		if (argh.arg("ffile"))
			argh.arg("ffile",ffile);

		std::transform ( 
				argh.multiarg("factors").begin(), 
				argh.multiarg("factors").end(), 
				std::back_inserter(headers),
				std::mem_fun_ref(&RUMBA::Splodge::asString) 
				);

		std::transform ( 
				argh.multiarg("filterfactors").begin(), 
				argh.multiarg("filterfactors").end(), 
				std::back_inserter(filter_headers),
				std::mem_fun_ref(&RUMBA::Splodge::asString) 
				);

		cov=argh.multiarg("covariate");
		contrasts=argh.multiarg("contrast");
		fin.open(infile.c_str());
		if (!fin)
		{
			std::cerr << "Couldn't open file " << infile << std::endl;
			exit(1);
		}
	
		s = cue(fin,header_row);
	
		if (!fin)
		{
			std::cerr << "Couldn't skip to header row " << header_row << std::endl;
			exit(1);
		}

		if (headers.empty())
			throw RUMBA::ArgHandlerException(
					"must supply the names of some factors"
					);
		process_header(factors, s, headers);
		process_header(filter_factors, s, filter_headers);
		while ( getline(fin,s))
		{
			process_row(factors,s);
			process_row(filter_factors,s);
		}

		s = cue (fin,header_row);
		treatments = readfile(fin,factors);
		s = cue (fin,header_row);
		filter_treatments = readfile(fin,filter_factors); // ?????
		M = design_matrix2(treatments,factors);
		N = makeMatrix(M.rows(), M.cols() + cov.size());
		N.put(0,0,M);
		if (!cov.empty())
		{
			s = cue (fin,header_row);
			getCovariateIndices(cov_indices,cov,s);
			for (int i = 0; getline(fin,s); ++i )
				N.put(i,M.cols(),getCovariateRow(cov_indices,s));
		}

		generate_contrasts
			(factors,contrasts,count_treatments(factors)+1+cov.size(),cfile);

		if (argh.arg("filter"))
			argh.arg("filter",filter_string);

		if (!filter_string.empty() && ! ffile.empty() )
		{
			ManifoldMatrix tmp = generate_filter 
				(filter_treatments, filter_factors, filter_string, ffile);
//			N = squash_matrix(N,tmp);
			N = mask_matrix(N,tmp);

		}

		N = repeatRows(N,acquisitions_per_trial);
		manifoldMatrixWriteHack (N,outfile);
	
		return 0;
}
catch ( RUMBA::InvalidArgumentException& s)
{
	std::cerr << "Invalid argument: " << s.error() << std::endl;
	exit(1);
}
catch (RUMBA::DuplicateArgumentException& s)
{
	std::cerr << "Duplicate argument: " << s.error() << std::endl;
	exit(1);
}
catch (RUMBA::ArgHandlerException& s)
{
	std::cerr << "Error: " << s.error() << std::endl;
	exit(1);
}
 catch (RUMBA::Exception& e)
{
	std::cerr << e.error() << std::endl; exit(1);
}
catch ( std::exception& e)
{
	std::cerr << e.what() << std::endl; exit(1);
}


}
