/*
 * canonicallabel.cpp
 *
 *  Created on: Jul 23, 2016
 *      Author: cedric
 */

#include "canonicallabel.h"
#include "functions.h"
#include "undirectedgraph.h"
#include "yamlutils.h"
#include <algorithm>

using namespace std;
using GiNaC::lst;

namespace Reduze {

bool mass_is_less::operator()(const GiNaC::ex& e1, const GiNaC::ex& e2) const {
	using namespace GiNaC;
	// Ordering used for masses: numeric values come first and are compared
	// by value; non-numeric expressions follow, compared by printed form.
	const bool e1_numeric = is_a<numeric>(e1);
	const bool e2_numeric = is_a<numeric>(e2);
	if (e1_numeric && e2_numeric)
		return ex_to<numeric>(e1) < ex_to<numeric>(e2);
	if (e1_numeric != e2_numeric)
		return e1_numeric; // exactly one is numeric: that one sorts first
	ostringstream printed1, printed2;
	printed1 << e1;
	printed2 << e2;
	return printed1.str() < printed2.str();
}

// canonical label

int CanonicalLabel::compare(const CanonicalLabel& other) const {
	// Three-way comparison: edge attributes decide first, then the adjacency
	// structure, then the node colorings. Returns -1, 0 or 1.
	int result = 0;
	if (edge_attributes_ != other.edge_attributes_) {
		result = (edge_attributes_ < other.edge_attributes_) ? -1 : 1;
	} else if (adjacencies_ != other.adjacencies_) {
		result = (adjacencies_ < other.adjacencies_) ? -1 : 1;
	} else if (node_colorings_ != other.node_colorings_) {
		const bool colorings_less = lexicographical_compare(
				node_colorings_.begin(), node_colorings_.end(),
				other.node_colorings_.begin(), other.node_colorings_.end());
		result = colorings_less ? -1 : 1;
	}
	return result;
}

// Relational operators, all expressed through the three-way compare().
bool CanonicalLabel::operator<(const CanonicalLabel& other) const {
	return compare(other) < 0;
}
bool CanonicalLabel::operator==(const CanonicalLabel& other) const {
	return compare(other) == 0;
}
bool CanonicalLabel::operator!=(const CanonicalLabel& other) const {
	return compare(other) != 0;
}

std::ostream& operator<<(std::ostream& os, const CanonicalLabel& label) {
	YAML::Emitter ye;
	ye << label;
	os << ye.c_str();
	return os;
}

// Emit the label as a YAML map with the keys "edge_attributes",
// "adjacencies" (a sequence with one flow-style entry per element of
// adjacencies_) and "node_colorings".
YAML::Emitter& operator<<(YAML::Emitter& ye, const CanonicalLabel& label) {
	using namespace YAML;
	ye << BeginMap << Key << "edge_attributes" << Value << Flow
			<< label.edge_attributes_;
	ye << Key << "adjacencies" << Value << BeginSeq;
	for (size_t row = 0; row != label.adjacencies_.size(); ++row)
		ye << Flow << label.adjacencies_[row];
	ye << EndSeq;
	ye << Key << "node_colorings" << Value << Flow << label.node_colorings_;
	ye << EndMap;
	return ye;
}

/* Read a label from a 3-element YAML map with the keys "edge_attributes",
 * "adjacencies" and "node_colorings" (the layout written by the YAML
 * operator<<). symbs is forwarded to edge_attributes::read.
 *
 * Throws runtime_error on malformed input. All parts are parsed into
 * temporaries first and only swapped into *this after everything has been
 * validated, so a failed read leaves the label unchanged.
 */
void CanonicalLabel::read(const YAML::Node& yn, const GiNaC::lst& symbs) {
	using namespace YAML;
	if (yn.Type() != NodeType::Map || yn.size() != 3)
		throw runtime_error("expected a 3-element map " + position_info(yn));

	// edge attributes
	const Node& edge_node = yn["edge_attributes"];
	if (edge_node.Type() != NodeType::Map)
		throw runtime_error("expected a map " + position_info(edge_node));
	edge_attributes ea;
	ea.read(edge_node, symbs);

	// adjacencies: a sequence of maps, each with unique integer keys
	const Node& ad_node = yn["adjacencies"];
	if (ad_node.Type() != NodeType::Sequence)
		throw runtime_error("expected a sequence " + position_info(ad_node));
	vector<map<int, colored_multi_edge> > ad;
	ad.reserve(ad_node.size());
	for (Iterator i = ad_node.begin(); i != ad_node.end(); ++i) {
		ad.push_back(map<int, colored_multi_edge>());
		const Node& map_node = *i;
		if (map_node.Type() != NodeType::Map)
			throw runtime_error("expected a map " + position_info(map_node));
		for (Iterator j = map_node.begin(); j != map_node.end(); ++j) {
			int key;
			j.first() >> key;
			colored_multi_edge cme(j.second());
			if (!ad.back().insert(make_pair(key, cme)).second)
				throw runtime_error(
						"node " + position_info(j.first())
								+ " has already been read.");
		}
	}

	// node colorings: empty, or one entry per adjacency entry
	const Node& col_node = yn["node_colorings"];
	if (col_node.Type() != NodeType::Sequence
			|| (col_node.size() != 0 && col_node.size() != ad.size()))
		throw runtime_error(
				"expected an empty or " + to_string(ad.size())
						+ "-element sequence " + position_info(col_node));
	vector<int> col;
	col_node >> col;

	// commit: nothing below can fail on malformed input
	edge_attributes_.swap(ea);
	adjacencies_.swap(ad);
	node_colorings_.swap(col);
}

// Exchange the contents of two labels member by member.
void CanonicalLabel::swap(CanonicalLabel& other) {
	this->edge_attributes_.swap(other.edge_attributes_);
	this->adjacencies_.swap(other.adjacencies_);
	this->node_colorings_.swap(other.node_colorings_);
}

// CanonicalRelabeling

std::map<int, int> CanonicalRelabeling::find_node_permutation(
		const std::vector<int>& from, const std::vector<int>& to) {
	VERIFY(from.size() == to.size());
	// Pair the entries position by position: from[i] -> to[i].
	// A repeated value in `from` keeps its last pairing.
	std::map<int, int> perm;
	std::vector<int>::const_iterator src = from.begin();
	std::vector<int>::const_iterator dst = to.begin();
	for (; src != from.end(); ++src, ++dst)
		perm[*src] = *dst;
	return perm;
}

std::map<int, int> //
CanonicalRelabeling::find_node_permutation(
		const CanonicalRelabeling& to) const {
	// Map each node of this relabeling to the node at the same position
	// in the other relabeling.
	const std::vector<int>& src = node_permutation_;
	const std::vector<int>& dst = to.node_permutation_;
	return CanonicalRelabeling::find_node_permutation(src, dst);
}
const std::vector<int>& CanonicalRelabeling::node_permutation() const {
	return node_permutation_;
}

// Exchange the node permutations of two relabelings.
void CanonicalRelabeling::swap(CanonicalRelabeling& other) {
	other.node_permutation_.swap(this->node_permutation_);
}

// Emit the node permutation as a flow-style YAML sequence.
YAML::Emitter& operator<<(YAML::Emitter& ye,
		const CanonicalRelabeling& relabel) {
	ye << YAML::Flow;
	ye << relabel.node_permutation_;
	return ye;
}
// Read the node permutation from a YAML sequence; throws runtime_error
// if the node is not a sequence.
void operator>>(const YAML::Node& n, CanonicalRelabeling& relabel) {
	const bool is_sequence = (n.Type() == YAML::NodeType::Sequence);
	if (!is_sequence)
		throw runtime_error("expected a sequence " + position_info(n));
	n >> relabel.node_permutation_;
}
std::ostream& operator<<(std::ostream& os, const CanonicalRelabeling& relabel) {
	YAML::Emitter ye;
	ye << relabel;
	os << ye.c_str();
	return os;
}

// The interface itself holds no state: nothing to set up or tear down here.
IFindCanonicalLabel::IFindCanonicalLabel() {
}

IFindCanonicalLabel::~IFindCanonicalLabel() {
}

int IFindCanonicalLabel::max_num_edge_colors() const {
	list<map<int, int> > coloring = get_edge_coloring().first;
	int num_c = 1;
	list<map<int, int> >::const_iterator l;
	for (l = coloring.begin(); l != coloring.end(); ++l) {
		ASSERT(l->size() == edges().size());
		num_c *= get_map_values(*l).size();
	}
	return num_c;
}

// Combine several colorings of the same keys into a single coloring: for
// each key the list of its colors (in coloring order) is collected and the
// result is passed through map_values_to_base_0 (presumably relabeling the
// distinct color lists as 0, 1, 2, ... -- confirm in its definition).
// An empty input yields an empty map.
std::map<int, int> combine_color_simple(std::list<map<int, int> > coloring) {
	if (coloring.empty())
		return map<int, int>();

	map<int, list<int> > color_lists;
	for (list<map<int, int> >::const_iterator c = coloring.begin();
			c != coloring.end(); ++c) {
		// all colorings must have the same number of keys
		ASSERT(c->size() == coloring.front().size());
		for (map<int, int>::const_iterator kv = c->begin(); kv != c->end();
				++kv)
			color_lists[kv->first].push_back(kv->second);
	}
	// ... and, in total, the very same keys
	ASSERT(color_lists.size() == coloring.front().size());
	return map_values_to_base_0(color_lists);
}

std::pair<CanonicalLabel, CanonicalRelabeling> IFindCanonicalLabel::find_canonical_label(
		bool perm_ext_nodes) const {

	const pair<list<map<int, int> >, edge_attributes> edge_coloring =
			get_edge_coloring();
	const edge_attributes& eas = edge_coloring.second;
	UndirectedGraph udgraph(nodes(), edges(), edge_coloring.first);

	list<map<int, int> > node_coloring = get_node_coloring(perm_ext_nodes);
	map<int, int> color_of_node = combine_color_simple(node_coloring);

	pair<vector<map<int, colored_multi_edge> >, vector<int> > ug_label;
	ug_label = udgraph.find_canonical_label(color_of_node);

	vector<int> colorings(color_of_node.size());
	for (unsigned int i = 0; i < color_of_node.size(); ++i)
		colorings[i] = color_of_node.at(ug_label.second[i]);

	CanonicalLabel label;
	label.adjacencies_.swap(ug_label.first);
	label.node_colorings_.swap(colorings);
	label.edge_attributes_ = eas;
	CanonicalRelabeling relabel;
	relabel.node_permutation_.swap(ug_label.second);
	return make_pair(label, relabel);
}

std::list<std::map<int, int> > IFindCanonicalLabel::find_node_symmetry_group(
		bool suppress_free_ext_node_perms) const {

	pair<list<map<int, int> >, edge_attributes> eces = get_edge_coloring();
	UndirectedGraph udgraph(nodes(), edges(), eces.first);
	map<int, int> color_of_node;
	list<map<int, int> > perms;
	set<int> no_free_permutations;
	if (suppress_free_ext_node_perms)
		no_free_permutations = external_nodes();
	// strict pruning by default set to false to allow for some more permutations of ext. nodes
	udgraph.find_node_symmetry_group(perms, color_of_node, no_free_permutations);
	return perms;
}

namespace {
/** The multi line graph L from the topology T is defined as the topology
 ** where each edge from T becomes a node in L. Two nodes in L are
 ** connected by one or two edges if the corresponding edges in T have
 ** one or two common endpoints. We choose multiple edges in L
 ** because it reduces the number of node permutations which are not
 ** valid in T.
 ** Single disconnected edges become isolated nodes. **/
void get_multi_line_graph(std::set<int>& lg_nodes,
		std::map<int, Edge>& lg_edges, const std::map<int, Edge>& t_edges) {
	int lg_edge_id = 0;
	map<int, Edge>::const_iterator mit1, mit2;
	for (mit1 = t_edges.begin(); mit1 != t_edges.end(); ++mit1) {
		const Edge& e1 = mit1->second;
		// edge becomes a node
		lg_nodes.insert(e1.id);
		if (e1.is_self_loop()) {
			// self-loop becomes a disconnected self-loop
			++lg_edge_id;
			lg_edges.insert(
					make_pair(lg_edge_id, Edge(e1.id, e1.id, lg_edge_id)));
			continue;
		}
		mit2 = mit1;
		for (++mit2; mit2 != t_edges.end(); ++mit2) {
			const Edge& e2 = mit2->second;
			if (e2.is_self_loop())
				continue;
			if (e1.from == e2.from || e1.from == e2.to) {
				++lg_edge_id;
				lg_edges.insert(
						make_pair(lg_edge_id, Edge(e1.id, e2.id, lg_edge_id)));
			}
			if (e1.to == e2.from || e1.to == e2.to) {
				++lg_edge_id;
				lg_edges.insert(
						make_pair(lg_edge_id, Edge(e1.id, e2.id, lg_edge_id)));
			}
		}
	}
}

// turn set1 into the intersection of set1 and set2
void intersect_with(std::set<int>& set1, const std::set<int>& set2) {
	// Drop, in place, every element of set1 that set2 does not contain.
	std::set<int>::iterator it = set1.begin();
	while (it != set1.end()) {
		if (set2.count(*it) == 0)
			set1.erase(it++); // post-increment: `it` stays valid after erase
		else
			++it;
	}
}

// True iff the map is a bijection of its key set onto itself, i.e. the set
// of values equals the set of keys.
bool is_a_permutation(const std::map<int, int>& perm) {
	std::set<int> image;
	for (std::map<int, int>::const_iterator kv = perm.begin();
			kv != perm.end(); ++kv)
		image.insert(kv->second);
	// repeated values shrink the image below the number of keys
	if (image.size() != perm.size())
		return false;
	// map keys are already sorted and unique: compare them position-wise
	std::map<int, int>::const_iterator key = perm.begin();
	for (std::set<int>::const_iterator v = image.begin(); v != image.end();
			++v, ++key)
		if (*v != key->first)
			return false;
	return true;
}

bool reconstruct_node_permutation(std::map<int, int>& node_permutation,
		const std::map<int, int>& edge_permutation,
		const std::set<int>& nodes_,
		const std::map<int, Edge>& edges_,
		const std::map<int, std::list<Edge> > edges_of_node) {

	/* An edge permutation is valid if there exists a node permutation P
	 * such that for every node N all adjacent edges of node N are
	 * mapped to adjacent edges of node P(N).
	 *
	 * For every node N: transform all adjacent edges and check the
	 * intersection of their endpoints.
	 * If the size of the intersection is:
	 *   0: No node permutation possible for node N, return false
	 *   1: A unique node permutation for node N has been found
	 *   2:	All edges of N transform to edges between two endpoints ep1, ep2.
	 *		The node N must be connected to a single opposite node M.
	 *		N and M are mapped to ep1, ep2, respecting the number of neighbors they have.
	 *      If all 4 nodes have the same number of neighbors choose
	 *        permutation N -> min(intersection) and M -> max(intersection)
	 */

	if (!is_a_permutation(edge_permutation))  {
		ABORT("The edge transformation is not a permutation");
	}

	map<int, int> node_map;

	set<int>::const_iterator node;
	for (node = nodes_.begin(); node != nodes_.end(); ++node) {
		if (node_map.find(*node) != node_map.end())
			continue;
		const list<Edge>& edges_from = edges_of_node.at(*node);

		// isolated nodes are mapped to itself
		if (edges_from.empty()) {
			node_map[*node] = *node;
			continue;
		}

		set<int> inter_nodes_to; // intersection of the endpoints of the mapped edges
		set<int> oppnodes; // nodes opposite of node
		list<Edge>::const_iterator e_from;
		for (e_from = edges_from.begin(); e_from != edges_from.end(); ++e_from) {
			const Edge& e_to = edges_.at(edge_permutation.at(e_from->id));
			set<int> endpoints;
			endpoints.insert(e_to.from);
			endpoints.insert(e_to.to);
			if (e_from == edges_from.begin())
				inter_nodes_to.swap(endpoints);
			else
				intersect_with(inter_nodes_to, endpoints);
			// test for case 0:
			if (inter_nodes_to.empty())
				return false;
			oppnodes.insert(e_from->opposite(*node));
		}
		ASSERT(!inter_nodes_to.empty() && inter_nodes_to.size() <= 2);

		// test for case 1:
		int ep1 = *inter_nodes_to.begin();
		int num_e_ep1 = edges_of_node.at(ep1).size();
		int num_e_node = edges_from.size();
		if (inter_nodes_to.size() == 1) {
			if (num_e_node != num_e_ep1)
				return false;
			node_map[*node] = ep1;
			continue;
		}

		// test for case 2:
		ASSERT(inter_nodes_to.size() == 2);
		int ep2 = *inter_nodes_to.rbegin();
		int num_e_ep2 = edges_of_node.at(ep2).size();
		if (oppnodes.size() != 1 || *node == *oppnodes.begin())
			return false;
		int opnode = *oppnodes.begin();
		int num_e_opnode = edges_of_node.at(opnode).size();
		map<int, int>::const_iterator m = node_map.find(opnode);
		if (m != node_map.end()) {
			if (m->second != ep1 && m->second != ep2)
				return false;
			node_map[*node] = ((m->second == ep1) ? ep2 : ep1);
		} else if (num_e_node != num_e_opnode) {
			if (num_e_ep1 != num_e_node && num_e_ep1 != num_e_opnode)
				return false;
			if (num_e_ep2 != num_e_node && num_e_ep2 != num_e_opnode)
				return false;
			node_map[*node] = ((num_e_ep1 == num_e_node) ? ep1 : ep2);
			node_map[opnode] = ((num_e_ep1 == num_e_node) ? ep2 : ep1);
		} else { // node and opnode have the same number of neighbours
			if (num_e_ep1 != num_e_node || num_e_ep2 != num_e_node)
				return false;
			node_map[std::min(*node, opnode)] = std::min(ep1, ep2);
			node_map[std::max(*node, opnode)] = std::max(ep1, ep2);
		}
	}
	if (!is_a_permutation(node_map))  {
		ABORT("The node transformation is not a permutation");
	}
	node_map.swap(node_permutation);
	return true;
}

std::map<int, std::list<Edge> > get_edges_of_node(const std::set<int>& nodes,
		const std::map<int, Edge>& edges) {
	map<int, list<Edge> > edges_of_node;
	for (set<int>::const_iterator n = nodes.begin(); n != nodes.end(); ++n)
		edges_of_node[*n];
	for (map<int, Edge>::const_iterator e = edges.begin(); e != edges.end();
			++e) {
		edges_of_node[e->second.from].push_back(e->second);
		if (e->second.to != e->second.from)
			edges_of_node[e->second.to].push_back(e->second);
	}
	return edges_of_node;
}
}

void IFindCanonicalLabel::find_edge_symmetry_group(
		std::list<std::map<int, std::pair<int, int> > >& edge_permutations,
		bool suppress_free_ext_node_perms) const {

	// find the node symmetry of the line graph
	// the nodes of the line graph are the edge ids of *this
	LOGX("  finding edge symmetry group of graph "/* << name()*/);

	const set<int>& t_nodes = nodes();
	const map<int, Edge>& t_edges = edges();
	const map<int, list<Edge> > t_edges_of_node = get_edges_of_node(t_nodes, t_edges);

	set<int> lg_nodes;
	map<int, Edge> lg_edges;
	get_multi_line_graph(lg_nodes, lg_edges, t_edges);
	ASSERT(lg_nodes.size() == t_edges.size());

	list<map<int, int> > dummy_lg_edge_coloring;
	UndirectedGraph udgraph(lg_nodes, lg_edges, dummy_lg_edge_coloring);
	set<int> no_free_permute;
	if (suppress_free_ext_node_perms) {
		// dont allow external edges to be permuted freely....
		list<Edge> el = external_edges();
		for (list<Edge>::const_iterator e = el.begin(); e != el.end(); ++e)
			no_free_permute.insert(e->id);
	}

	// this edge coloring is the line graphs node coloring
	const list<map<int, int> > t_edge_coloring = get_edge_coloring().first; // virtual
	map<int, int> lg_node_color = combine_color_simple(t_edge_coloring);
	ASSERT(lg_node_color.size() == lg_nodes.size());
	list<map<int, int> > edge_perm_list;
	udgraph.find_node_symmetry_group(edge_perm_list, lg_node_color,
			no_free_permute);
	const int num_edge_perms = edge_perm_list.size();

	// determine the sign of the edge permutation
	edge_permutations.clear();
	for (; !edge_perm_list.empty(); edge_perm_list.pop_front()) {
		const map<int, int>& edge_perm = edge_perm_list.front();
		map<int, int> node_perm;
		if (reconstruct_node_permutation(node_perm, edge_perm, t_nodes, t_edges,
				t_edges_of_node)) {
			map<int, pair<int, int> > edge_perm_sign;
			map<int, int>::const_iterator ep;
			for (ep = edge_perm.begin(); ep != edge_perm.end(); ++ep) {
				const int from1 = node_perm.at(t_edges.at(ep->first).from);
				const int from2 = t_edges.at(ep->second).from;
				int sign = ((from1 == from2) ? 1 : -1);
				edge_perm_sign[ep->first] = make_pair(sign, ep->second);
			}
			edge_permutations.push_back(map<int, pair<int, int> >());
			edge_permutations.back().swap(edge_perm_sign);
		}
	}
	const int num_node_perms = edge_permutations.size();
	LOGX("  found " << num_edge_perms << " edge permutation");
	LOGX("  found " << num_node_perms //
			<< " edge permutations with valid node permutation");
	if (num_edge_perms != num_node_perms)
		LOGX("    mismatch might come from star K_1_3 as subgraph");
}

} // namespace Reduze

