/*  job_computedifferentialequations.cpp
 *
 *  Copyright (C) 2010-2012 Andreas von Manteuffel
 *  Copyright (C) 2010-2012 Cedric Studerus
 *
 *  This file is part of the package Reduze 2.
 *  It is distributed under the GNU General Public License version 3
 *  (see the file GPL-3.0.txt or http://www.gnu.org/licenses/gpl-3.0.txt).
 */

#include "job_computedifferentialequations.h"
#include "functions.h"
#include "files.h"
#include "filedata.h"
#include "ginacutils.h"
#include "streamutils.h"
#include "kinematics.h"
#include "integralfamily.h"
#include "equation.h"
#include "yamlutils.h"
#include "sector.h"
#include "identitygenerator.h"

using namespace std;
using namespace GiNaC;

namespace Reduze {

// Make ComputeDifferentialEquations known to the JobFactory: the job type is
// registered as a side effect of constructing this file-local proxy object.
namespace {
JobProxy<ComputeDifferentialEquations> compute_deq_job_proxy;
}

/// Reports the files this job consumes and produces.
/// Input dependencies are collected by find_dependencies_reductions() into
/// `in` (presumably the reduction files needed later); the only output is
/// output_filename_, appended to `out`. No auxiliary jobs are created, so
/// `auxjobs` is left untouched.
/// @return always true
bool ComputeDifferentialEquations::find_dependencies(
		const set<string>& outothers,//
		list<string>& in, list<string>& out, list<Job*>& auxjobs) {
	find_dependencies_reductions(outothers, in);
	out.insert(out.end(), output_filename_);
	return true;
}

/// This job writes no manually handled options: the emitter `os` is left
/// unchanged.
void ComputeDifferentialEquations::print_manual_options(YAML::Emitter& os) const {
	(void) os; // intentionally unused
}

/// Post-read validation hook. The YAML `node` itself is not inspected here;
/// the only check is that an output file name has been configured.
/// @throws std::runtime_error if output_filename_ is empty
void ComputeDifferentialEquations::read_manual_options(const YAML::Node& node) {
	const bool has_output_file = !output_filename_.empty();
	if (has_output_file)
		return;
	throw runtime_error("output file undefined");
}

/// Short human-readable description of this job type.
std::string ComputeDifferentialEquations::get_description() const {
	return "compute differential equations";
}

/// Removes from `in` every term whose integral does not lie in sector `sec`.
/// Surviving terms are copied (via insert_end, in iteration order) into a
/// fresh LinearCombination whose terms are then swapped into `in`.
void remove_other_sectors(LinearCombination& in, const Sector& sec) {
	LinearCombination kept;
	LinearCombination::const_iterator term = in.begin();
	while (term != in.end()) {
		const bool same_sector = (term->first.get_sector() == sec);
		if (same_sector)
			kept.insert_end(term->first, term->second);
		++term;
	}
	in.swap_terms(kept);
}

// basic idea of the algorithm:
//   An integral I is a function of the kinematic invariants x_i
//      I = I(x_1,..,x_n) where x_i = s,t,...,m_1^2,...
//   Here we want to compute
//      (d I)/(d x_i)
//   for each x_i.
//
//   For this purpose we consider the "off-shell" integral J
//      J = J(s_1,..,s_m,x_1,..,x_n)
//   where the x_1,...,x_n denote _explicit_ occurrences of the
//   parameters (i.e. not implicitly via the momenta) and
//   s_l = p_1^2, p_1*p_2, ... are the scalar products of the momenta.
//   J is equal to I for "on-shell" values of s_l. We compute
//      (d I)/(d x_i) =
//           ((d J)/(d s_l) (d s_l)/(d x_i) + (d J)/(d x_i))_ONS
//   (summation over l implied). To get (d J)/(d s_l) we compute
//      p_j (d J)/(d p_k) = p_j (d s_l)/(d p_k) (d J)/(d s_l)
//   for all j,k and solve for (d J)/(d s_l) (over-determined system).
//
//   Subsequently the derivatives will typically be reduced with
//   available reductions, e.g. IBPs

void ComputeDifferentialEquations::run_serial() {
	LOG("Generating differential equations for integrals in " << integrals_filename_);
	OutFileLinearCombinations out(output_filename_.c_str());

	LOG("Reading integrals to compute DEQ for from " + integrals_filename_);
	std::list<INT> integrals;
	InFileINTs in(integrals_filename_);
	in.get_all(integrals);
	LOG("  found " << integrals.size() << " integrals");

	list<INT>::const_iterator integ;
	for (integ = integrals.begin(); integ != integrals.end(); ++integ) {
		LOG("\nProcessing " << *integ << " from sector " << integ->get_sector());

		const IntegralFamily* fam = integ->integralfamily();
		const Kinematics* kin = fam->kinematics();
		GenericIdentityGenerator idgen(fam);
		const exmap& sp2inv = kin->rules_sp_to_invariants();
		const lst& p = kin->independent_external_momenta();
		const lst& x = kin->kinematic_invariants();

		// prepare dsl/dxi = d(pa * pb |_ONS)/dxi  and  pj dsl/dpk
		LOG("Differentiating scalar products w.r.t. momenta...");
		size_t ns = (p.nops()*(p.nops() + 1))/2; // size of s={sl}
		exvector s(ns); // scalar products sl = pa*pb with a<=b
		matrix dsldxi(ns, x.nops()); // row: l, col: i
		matrix pjdsldpk(p.nops() * p.nops(), ns); // row: j*(#p) + k, col: l
		size_t l = -1;
		for (size_t a = 0; a < p.nops(); ++a) {
			for (size_t b = a; b < p.nops(); ++b) {
				++l;
				s[l] = ScalarProduct(p[a], p[b]).eval();
				ex slons = s[l].eval().subs(sp2inv);
				VERIFY(!slons.has(ScalarProduct(wild(),wild())));
				for (size_t i = 0 ; i < x.nops(); ++i) {
					ASSERT(is_a<symbol>(x[i]));
					dsldxi(l, i) = normal_form(slons.diff(ex_to<symbol>(x[i])));
				}
				for (size_t j = 0; j < p.nops(); ++j) {
					pjdsldpk(j * p.nops() + a, l) +=
							ScalarProduct(p[j], p[b]).eval().subs(sp2inv);
					pjdsldpk(j * p.nops() + b, l) +=
							ScalarProduct(p[j], p[a]).eval().subs(sp2inv);
				}
			}
		}

		// solve system for dJ/dsl
		LOG("Solving for derivatives w.r.t. invariants...");
		lst dJdsl_syms, pjdJdpk_syms; // symbols for system solving
		for (size_t l = 0; l < ns; ++l) {
			string var_str = to_safe_variable_name(to_string(s[l]));
			dJdsl_syms.append(symbol("dJd" + var_str));
		}
		lst eqns;
		for (size_t j = 0; j < p.nops(); ++j) {
			for (size_t k = 0; k < p.nops(); ++k) {
				symbol sym(to_string(p[j]) + "dJd" + to_string(p[k]));
				pjdJdpk_syms.append(sym);
				ex expr = 0;
				for (size_t l = 0; l < ns; ++l)
					expr += pjdsldpk(j * p.nops() + k, l) * dJdsl_syms[l];
				eqns.append(pjdJdpk_syms[j * p.nops() + k] == expr);
			}
		}
		unsigned opts = solve_algo::gauss; // avoid division problems with sqrt
		LOGX(" system to solve: " << eqns);
		ex sollst;
		exmap sol;
		try {
			sollst = lsolve_overdetermined(eqns, dJdsl_syms, opts);
			sol = lsolve_result_to_rules(sollst, dJdsl_syms);
		} catch (exception& e) {
			ERROR(string("can't solve for wanted derivatives:\n") + e.what());
		}

		// actually compute pj dJ/dpk and dJ/dxi
		LOG("Differentiating integral w.r.t. momenta");
		if (set_subsectors_to_zero_)
			LOG("NOTE: setting subsectors to zero !");
		vector<LinearCombination> pjdJdpk(p.nops() * p.nops());
		vector<LinearCombination> dJdxi(x.nops());
		for (size_t j = 0; j < p.nops(); ++j)
			for (size_t k = 0; k < p.nops(); ++k) {
				LinearCombination deriv = idgen.get_derivative_wrt_momentum(
						*integ, j, k);
				if (set_subsectors_to_zero_)
					remove_other_sectors(deriv, integ->get_sector());
				pjdJdpk[j * p.nops() + k] = deriv;
				stringstream ss;
				deriv.to_mma_stream(ss);
				LOGX(" " << p[j] << " dJ/d" << p[k] << " " << ss.str());
			}
		for (size_t i = 0; i < x.nops(); ++i) {
			LinearCombination deriv = idgen.get_derivative_wrt_parameter(*integ,
					ex_to < symbol > (x[i]));
			if (set_subsectors_to_zero_)
				remove_other_sectors(deriv, integ->get_sector());
			dJdxi[i] = deriv;
			stringstream ss;
			deriv.to_mma_stream(ss);
			LOGX(" dJ/d" << x[i] << " " << ss.str());
		}

		// inserting linear combinations of integrals into dJdsl
		LOG("Inserting linear combinations of integrals");
		vector<LinearCombination> dJdsl(ns);
		for (size_t l = 0 ; l < ns ; ++l) {
			VERIFY(sol.find(dJdsl_syms[l]) != sol.end());
			ex todo = sol[dJdsl_syms[l]].expand();
			ex simple = 0;
			LinearCombination deriv;
			for (size_t j = 0; j < p.nops(); ++j) {
				for (size_t k = 0; k < p.nops(); ++k) {
					ex sym = pjdJdpk_syms[j * p.nops() + k];
					ex coeff = todo.coeff(sym, 1);
					ex nd = coeff.numer_denom();
					coeff = factor(nd[0])/factor(nd[1]);
					const LinearCombination& d = pjdJdpk[j * p.nops() + k];
					LinearCombination::const_iterator t;
					for (t = d.begin(); t != d.end(); ++t)
						deriv.insert(t->first, t->second * coeff, true /*norm*/);
					todo -= coeff * sym;
					simple += coeff*sym;
				}
			}
			todo = todo.normal().expand();
			if (!todo.is_zero())
				ERROR("encountered non-linear or inhomogenous terms w.r.t. pjdJdpk:\n" << todo);
			LOGX("  " << dJdsl_syms[l] << " = " << simple);
			dJdsl[l] = deriv;
		}

		// assemble derivatives: dI/dxi = ( (dJ/dsl) (dsl/dxi) + dJ/dxi )_ONS
		LOG("Assembling derivatives w.r.t. kinematic invariants");
		stringstream ss;
		integ->to_mma_stream(ss);
		string integ_str = to_safe_variable_name(ss.str());
		LinearCombination chk; // sum_i  xi*massdim(xi)*dI/dxi as scaling check
		chk.set_name("scaling" + integ_str);
		for (size_t i = 0 ; i < x.nops(); ++i) {
			ex scalxi = x[i] * kin->find_mass_dimension(x[i]);
			LinearCombination dIdxi;
			string var_str = to_safe_variable_name(to_string(x[i]));
            dIdxi.set_name("d" + integ_str + "d" + var_str);
        	LinearCombination::const_iterator t;
            for (size_t l = 0 ; l < ns ; ++l)
            	for (t = dJdsl[l].begin(); t != dJdsl[l].end(); ++t) {
            		dIdxi.insert(t->first, t->second*dsldxi(l,i), true);
            		chk.insert(t->first, scalxi * t->second*dsldxi(l,i), true);
            	}
        	for (t = dJdxi[i].begin(); t != dJdxi[i].end(); ++t) {
        		dIdxi.insert(t->first, t->second, true);
        		chk.insert(t->first,  scalxi * t->second, true);
        	}
            out << dIdxi;
		}
		out << chk;
	}
	out.finalize();
}

/// returns (q * d/dk f) for two momenta q and k
/** k must enter only via ScalarProducts, not Propagators **/
/** tested code, currently not needed any more:
static ex derivative_contracted(const ex& f, const symbol& q, const symbol& k) {
	ASSERT(!f.has(Propagator(wild(1), wild(2))));
	ASSERT(!f.has(Propagator(wild(1), wild(2), wild(3))));
	// search ScalarProducts in f and introduce abbreviations for them
	exset found;
	f.find(ScalarProduct(wild(1), wild(2)), found);
	exmap sp2tmp, tmp2sp;
	for (exset::const_iterator s = found.begin(); s != found.end(); ++s) {
		symbol tmpvar;
		sp2tmp[*s] = tmpvar;
		tmp2sp[tmpvar] = *s;
	}
	ex foftmp = f.subs(sp2tmp);
	if (foftmp.has(k))
		ERROR("can't take derivative of " << foftmp << " w.r.t. " << k);
	// take (d f)/(d pk) = (d f)/(d scalarproduct_i) (d scalarproduct_i)/(d pk)
	ex deriv = 0;
	exmap::const_iterator sv;
	for (sv = sp2tmp.begin(); sv != sp2tmp.end(); ++sv) {
		if (!is_a<ScalarProduct> (sv->first) || //
				!is_a<symbol> (sv->second))
			ERROR("failed to substitute scalar products");
		ScalarProduct sprod = ex_to<ScalarProduct> (sv->first);
		ex lhs = sprod.op(0);
		ex rhs = sprod.op(1);
		symbol var = ex_to<symbol> (sv->second);
		deriv += foftmp.diff(var) * (//
				          ScalarProduct(q * lhs.diff(k), rhs)//
						+ ScalarProduct(lhs, q * rhs.diff(k))//
				);
	}
	deriv = deriv.subs(tmp2sp);
	return deriv;
}
*/

/// compute pj dI/dpk where masses are functions of the momenta
/* tested, currently unused:
static LinearCombination compute_pj_dI_dpk(const INT& i, const symbol& pj,
		const symbol& pk) {
	const IntegralFamily* ic = i.integralfamily();
	const exmap inv2sp = ic->kinematics()->find_rules_invariants_to_sp();
	ex f = i.get_integrand();
	exset found;
	f.find(Propagator(wild(1), wild(2)), found);
	f.find(Propagator(wild(1), wild(2), wild(3)), found);
	ex fexpl = f;
	for (exset::const_iterator s = found.begin(); s != found.end(); ++s) {
		ASSERT(is_a<Propagator>(*s));
		const Propagator& prop = ex_to<Propagator> (*s);
		fexpl = fexpl.subs(prop == prop.to_explicite_form().subs(inv2sp));
	}
	ex deriv = derivative_contracted(fexpl, pj, pk);
	const exmap& sp2inv = ic->kinematics()->rules_sp_to_invariants();
	const exmap& sp2prop = ic->rules_sp_to_prop();
	deriv = deriv.expand().subs(sp2inv).subs(sp2prop).expand();
	LinearCombination lc;
	propagator_products_to_integrals(deriv, lc, ic);
	return lc;
}
*/


} // namespace Reduze

