#include "swi.h"
#include "data.h"

#include <odindata/statistics.h>
#include <odindata/linalg.h>


// Collects the integer voxel offsets (k,j,i) of every position inside a
// sphere of radius 'radius_pixels' around the origin, i.e. every offset with
// sqrt(k*k+j*j+i*i) <= radius_pixels. The origin (0,0,0) itself is excluded.
// Offsets are emitted with k varying slowest and i fastest.
// A negative radius yields an empty list.
STD_vector<TinyVector<int,3> > neighb_indices(int radius_pixels) {
  STD_vector<TinyVector<int,3> > offsets;

  const int extent=radius_pixels;

  for(int dk=-extent; dk<=extent; ++dk) {
    for(int dj=-extent; dj<=extent; ++dj) {
      for(int di=-extent; di<=extent; ++di) {

        // skip the central voxel
        const bool is_center=(dk==0 && dj==0 && di==0);
        if(is_center) continue;

        // skip corners of the bounding cube that lie outside the sphere
        const float dist=sqrt(float(dk*dk+dj*dj+di*di));
        if(dist>radius_pixels) continue;

        offsets.push_back(TinyVector<int,3>(dk,dj,di));
      }
    }
  }

  return offsets;
}



///////////////////////////////////////////////////////////////////////


/**
 * Replaces the rank-3 complex data in 'rd' by a local phase-contrast image and
 * forwards the result to the next reconstruction step.
 *
 * For every voxel whose complete spherical neighbourhood (radius 'pixelradius',
 * see neighb_indices) lies inside the array:
 *  - the phase difference of each neighbour relative to the centre voxel is
 *    wrapped into [-PII,PII];
 *  - a linear 3D model (three spatial gradients plus a constant offset) is
 *    fitted to these differences with solve_linear, each row weighted by the
 *    neighbour's magnitude;
 *  - the output value is median(neighbour magnitudes) * |fitted offset|.
 * Voxels closer to the border than 'pixelradius' are set to zero.
 *
 * @param rd          reconstruction data; its rank-3 complex data is replaced
 * @param controller  passed on unchanged to execute_next_step
 * @return            the result of execute_next_step(rd,controller)
 */
bool RecoSwi::process(RecoData& rd, RecoController& controller) {
  Log<Reco> odinlog(c_label(),"process");

  ComplexData<3>& data=rd.data(Rank<3>());
  TinyVector<int,3> shape=data.shape();

  Data<float,3> pha(phase(data));
  Data<float,3> magn(cabs(data));

  const int pixelradius=3; // best


  // Create linear mask for valid pixels: a voxel is valid only if it is at
  // least 'pixelradius' voxels away from every border, so that all offsets
  // returned by neighb_indices stay inside the array.
  Data<char,1> mask(magn.size()); mask=1;
  for(int i=0; i<magn.size(); i++) {
    TinyVector<int,3> index=magn.create_index(i);
    for(int j=0; j<3; j++) {
      if(index(j)<pixelradius || index(j)>shape(j)-pixelradius-1) mask(i)=0;
    }
  }


  STD_vector<TinyVector<int,3> > neighboffset=neighb_indices(pixelradius);
  const int numof_neigb=neighboffset.size();

  // Work buffers, allocated once and reused for every voxel
  Data<float,1> medensemble(numof_neigb); // neighbour magnitudes (for the median)
  Data<float,2> matrix(numof_neigb,4);    // weighted design matrix: 3 offsets + constant
  Data<float,1> phavec(numof_neigb);      // weighted, wrapped phase differences
  Data<float,1> xvec(4); // will contain 3 spatial indices and phase offset

  // Evaluate local phase contrast
  Data<float,3> swiarr(shape); swiarr=0.0;
  for(int i=0; i<magn.size(); i++) {
    if(!mask(i)) continue;

    TinyVector<int,3> index=magn.create_index(i);

    const float centpha=pha(index);

    for(int j=0; j<numof_neigb; j++) {
      TinyVector<int,3> indexoffset=neighboffset[j];
      TinyVector<int,3> neighbindex=index+indexoffset;

      const float neighbmagn=magn(neighbindex);
      medensemble(j)=neighbmagn;

      // Subtract center phase and wrap the difference into [-PII,PII]
      float phadiff=pha(neighbindex)-centpha;
      while(phadiff>PII)  phadiff-=2.0*PII;
      while(phadiff<-PII) phadiff+=2.0*PII;

      // Magnitude-weighted row of the least-squares system
      //   phadiff ~ g0*offset(0) + g1*offset(1) + g2*offset(2) + phase_offset
      const float weight=neighbmagn;
      matrix(j,0)=weight*indexoffset(0);
      matrix(j,1)=weight*indexoffset(1);
      matrix(j,2)=weight*indexoffset(2);
      matrix(j,3)=weight;
      phavec(j)=weight*phadiff;
    }

    // Linear 3D fit using neighborhood.
    // NOTE(review): the third argument (0.05) is presumably a
    // regularization/singular-value threshold of solve_linear - confirm in linalg.h
    xvec=solve_linear(matrix, phavec, 0.05);

    const float med=median(medensemble);

    // Use phase offset/difference at central spatial position
    swiarr(index)=med*fabs(xvec(3));
  }

  // Presumably float2real wraps the real-valued result into complex data
  // (zero imaginary part) - confirm against ODIN's data utilities.
  data.reference(float2real(swiarr));

  return execute_next_step(rd,controller);
}

