// Copyright (C) 2002 Zbigniew Leyk (zbigniew.leyk@anu.edu.au)
//                and David E. Stewart (david.stewart@anu.edu.au)
//                and Ronan Collobert (collober@iro.umontreal.ca)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#include "mx_solve.h"
#include "mx_low_level.h"

namespace Torch {

/* Most matrix factorisation routines are in-situ unless otherwise specified */

/* mxUSolve -- back substitution on the upper triangular part of `matrix`.

   Solves U * out = b, where U is the upper triangle of the leading
   min(m, n) x min(m, n) square of `matrix`.  The off-diagonal terms are
   accumulated with mxIp__ (presumably an inner product; see mx_low_level).

   diag -- when non-zero, used in place of every diagonal entry of U
           (e.g. 1.0 for a unit-diagonal factor).  When 0.0, the stored
           diagonal is used and error() is raised if a pivot is too small
           relative to the current residual.

   Trailing zeros of b are copied straight into out as zeros and skipped.
   May be called in-situ (out == b). */
void mxUSolve(Mat * matrix, Vec * b, Vec * out, real diag)
{
  int dim = min(matrix->m, matrix->n);
  real **rows = matrix->ptr;
  real *rhs = b->ptr;
  real *x = out->ptr;
  const real tiny = 10.0 / INF;

  /* Zero tail of b => zero tail of the solution; find the last non-zero. */
  int last = dim - 1;
  while (last >= 0 && rhs[last] == 0.0)
  {
    x[last] = 0.0;
    last--;
  }

  /* Entries of x beyond `last` are zero, so each dot product only needs to
     run up to `last`. */
  for (int row = last; row >= 0; row--)
  {
    real acc = rhs[row];
    acc -= mxIp__(&rows[row][row + 1], &x[row + 1], last - row);

    if (diag != 0.0)
      x[row] = acc / diag;
    else if (fabs(rows[row][row]) <= tiny * fabs(acc))
      error("USolve: sorry, singular problem.");
    else
      x[row] = acc / rows[row][row];
  }
}

/* mxLSolve -- forward elimination on the lower triangular part of `matrix`.

   Solves L * out = b, where L is the lower triangle of the leading
   min(m, n) x min(m, n) square of `matrix`.  The off-diagonal terms are
   accumulated with mxIp__ (presumably an inner product; see mx_low_level).

   diag -- when non-zero, used in place of every diagonal entry of L.
           When 0.0, the stored diagonal is used and error() is raised if a
           pivot is too small relative to the current residual.

   Leading zeros of b are copied straight into out as zeros and skipped.
   May be called in-situ (out == b). */
void mxLSolve(Mat * matrix, Vec * b, Vec * out, real diag)
{
  int dim = min(matrix->m, matrix->n);
  real **rows = matrix->ptr;
  real *rhs = b->ptr;
  real *x = out->ptr;

  /* Zero head of b => zero head of the solution; find the first non-zero. */
  int first = 0;
  while (first < dim && rhs[first] == 0.0)
  {
    x[first] = 0.0;
    first++;
  }

  const real tiny = 10.0 / INF;

  /* Entries of x before `first` are zero, so each dot product starts
     at `first`. */
  for (int row = first; row < dim; row++)
  {
    real acc = rhs[row];
    acc -= mxIp__(&rows[row][first], &x[first], row - first);

    if (diag != 0.0)
      x[row] = acc / diag;
    else if (fabs(rows[row][row]) <= tiny * fabs(acc))
      error("LSolve: sorry, singular problem.");
    else
      x[row] = acc / rows[row][row];
  }
}


/* mxUTSolve -- forward elimination using the UPPER triangular part of `mat`.

   Column-oriented solve: after x[k] is fixed, row k of the upper triangle
   (entries k+1 .. dim-1) is scaled by -x[k] and accumulated into the
   remaining unknowns via mxRealMulAdd__.  Presuming that helper computes
   y += a * v, this solves transpose(U) * out = b.

   diag -- when non-zero, every diagonal entry is taken to be `diag`
           (applied as a multiplication by 1/diag).  When 0.0, the stored
           diagonal is used and error() is raised on a pivot that is too
           small relative to the current entry.

   When out != b, out is cleared over its full length (out->n) and the
   non-zero tail of b is copied into it first.  May be called in-situ. */
void mxUTSolve(Mat * mat, Vec * b, Vec * out, real diag)
{
  int dim = min(mat->m, mat->n);
  real **rows = mat->ptr;
  real *rhs = b->ptr;
  real *x = out->ptr;
  const real tiny = 10.0 / INF;

  /* Skip (and zero) the leading zeros of b. */
  int first = 0;
  while (first < dim && rhs[first] == 0.0)
  {
    x[first] = 0.0;
    first++;
  }

  /* Out-of-place call: start from a clean out holding b[first .. dim-1]. */
  if (b != out)
  {
    mxZero__(x, out->n);
    for (int j = first; j < dim; j++)
      x[j] = rhs[j];
  }

  const bool use_stored_diag = (diag == 0.0);
  const real inv_diag = use_stored_diag ? 0.0 : 1.0 / diag;

  for (int k = first; k < dim; k++)
  {
    if (use_stored_diag)
    {
      const real pivot = rows[k][k];
      if (fabs(pivot) <= tiny * fabs(x[k]))
        error("UTSolve: sorry, singular problem.");
      x[k] /= pivot;
    }
    else
    {
      x[k] *= inv_diag;
    }
    /* Eliminate x[k] from the unknowns still to be solved. */
    mxRealMulAdd__(&x[k + 1], &rows[k][k + 1], -x[k], dim - k - 1);
  }
}

/* mxDSolve -- solves D * x = b, where D is the diagonal of `mat`.

   mat -- its diagonal entries mat->ptr[i][i] form D.  Only the first
          min(mat->m, mat->n) of them exist.
   b   -- right-hand side.  b->n must not exceed min(mat->m, mat->n).
   x   -- output.  x->ptr[0 .. b->n-1] is written, so x must hold at least
          b->n entries.  May be the same vector as b (in-situ).

   Raises error() when b is longer than the diagonal of `mat`, or when a
   diagonal entry is (near-)zero relative to the matching entry of b. */
void mxDSolve(Mat * mat, Vec * b, Vec * x)
{
  /* Number of diagonal entries actually stored in `mat`. */
  int diag_len = min(mat->m, mat->n);
  int dim = b->n;

  /* Iterating over b->n alone would index mat->ptr[i][i] past the end of
     the matrix whenever b is longer than the diagonal. */
  if (dim > diag_len)
  {
    error("DSolve: b is longer than the diagonal of the matrix.");
    return;
  }

  const real tiny = 10.0 / INF;

  for (int i = 0; i < dim; i++)
  {
    const real d = mat->ptr[i][i];
    if (fabs(d) <= tiny * fabs(b->ptr[i]))
    {
      error("DSolve: sorry, singular problem.");
      return;
    }
    x->ptr[i] = b->ptr[i] / d;
  }
}

/* mxLTSolve -- back substitution using the LOWER triangular part of `mat`.

   Column-oriented solve, working from the last unknown back to the first.
   After x[k] is fixed, row k of the lower triangle (entries 0 .. k-1) is
   scaled by -x[k] and accumulated into the earlier unknowns via
   mxRealMulAdd__.  Presuming that helper computes y += a * v, this solves
   transpose(L) * out = b.

   diag -- when non-zero, every diagonal entry is taken to be `diag`
           (applied as a multiplication by 1/diag).  When 0.0, the stored
           diagonal is used and error() is raised on a pivot that is too
           small relative to the current entry.

   When out != b, out is cleared over its full length (out->n) and
   b[0 .. last non-zero] is copied into it first.  Can be in-situ but
   doesn't need to be. */
void mxLTSolve(Mat * mat, Vec * b, Vec * out, real diag)
{
  int dim = min(mat->m, mat->n);
  real **rows = mat->ptr;
  real *rhs = b->ptr;
  real *x = out->ptr;
  const real tiny = 10.0 / INF;

  /* Locate the last non-zero entry of b; everything after it stays zero. */
  int last = dim - 1;
  while (last >= 0 && rhs[last] == 0.0)
    last--;

  /* Out-of-place call: start from a clean out holding b[0 .. last]. */
  if (b != out)
  {
    mxZero__(x, out->n);
    for (int j = 0; j <= last; j++)
      x[j] = rhs[j];
  }

  const bool use_stored_diag = (diag == 0.0);
  const real inv_diag = use_stored_diag ? 0.0 : 1.0 / diag;

  for (int k = last; k >= 0; k--)
  {
    if (use_stored_diag)
    {
      const real pivot = rows[k][k];
      if (fabs(pivot) <= tiny * fabs(x[k]))
        error("LTSolve: sorry, singular problem.");
      x[k] /= pivot;
    }
    else
    {
      x[k] *= inv_diag;
    }
    /* Eliminate x[k] from the unknowns x[0 .. k-1]. */
    mxRealMulAdd__(x, rows[k], -x[k], k);
  }
}

}

