#include "ClusterModel.hpp"

#include "ClusterModelObserver.hpp"

#include <algorithm>
#include <cassert>
#include <ctime>
#include <limits>

namespace ublas = boost::numeric::ublas;

using namespace indii::tint;
using namespace indii::cluster;
using namespace std;

/**
 * Construct a model over the given image resource, with default parameters
 * and a fresh (not yet run) k-means clusterer.
 *
 * @param res Image resource that prepare() and the calc*() methods read from.
 */
ClusterModel::ClusterModel(ImageResource* res) : minClusterer(NULL),
    locked(false) {
  this->res = res;

  /* data preparation and clustering parameters */
  this->saturationThreshold = 16;
  this->maxPixels = 10000;
  this->reps = 8;
  this->k = 4;
  this->hard = true;

  /* decay and softness parameters, expressed as fractions of 255 */
  this->saturationDecay = 16.0f / 255.0f;
  this->centroidDecay = 0.0f;
  this->saturationSoftness = 32.0f / 255.0f;
  this->centroidSoftness = 0.0f;

  /* greyscale weights; identical to the result of setGreyscale(76, 150, 29)
   * because 76 + 150 + 29 == 255 */
  this->greyR = 76.0f / 255.0f;
  this->greyG = 150.0f / 255.0f;
  this->greyB = 29.0f / 255.0f;

  /* keep the colour space's lightness weights in step with greyR/G/B, as
   * setGreyscale() does; otherwise they may differ until the first
   * setGreyscale() call */
  cs.setLightness(greyR, greyG, greyB);

  /* per-cluster state, one entry for each of the k clusters */
  this->visible.resize(k, false);
  this->colours.resize(k, *wxWHITE);
  this->hues.resize(k, 0.0f);
  this->sats.resize(k, 0.0f);
  this->lights.resize(k, 0.0f);
  this->minClusterer = new KMeansClusterer<>(k, time(NULL));
}

/**
 * Destroy the model, releasing the clusterer it owns. The model allocates
 * minClusterer itself (constructor, setNumClusters(), cluster()).
 */
ClusterModel::~ClusterModel() {
  KMeansClusterer<>* const owned = minClusterer;
  delete owned;
}

void ClusterModel::setDefaults() {
  lock();
  setHard(true);
  setSaturationThreshold(16);
  setMaxPixels(10000);
  setNumRepetitions(8);
  setNumClusters(4);
  setSaturationDecay(16.0f / 255.0f);
  setCentroidDecay(0.0f);
  setSaturationSoftness(32.0f / 255.0f);
  setCentroidSoftness(0.0f);
  setGreyscale(76.0f, 150.0f, 29.0f);
  setForDialog();
  unlock();
  prepare();
  cluster();

  notifyAll();
  registry_t::iterator iter;
  for (iter = os.begin(); iter != os.end(); iter++) {
    static_cast<ClusterModelObserver*>(*iter)->notifyNumClustersChange();
  }  
}

/**
 * Reset the per-cluster state: hide every cluster and zero its hue,
 * saturation and lightness values. Observers are not notified.
 */
void ClusterModel::setForDialog() {
  for (unsigned n = 0; n < visible.size(); ++n) {
    visible[n] = false;
  }
  for (unsigned n = 0; n < hues.size(); ++n) {
    hues[n] = 0.0f;
  }
  for (unsigned n = 0; n < sats.size(); ++n) {
    sats[n] = 0.0f;
  }
  for (unsigned n = 0; n < lights.size(); ++n) {
    lights[n] = 0.0f;
  }
}

void ClusterModel::setNumClusters(const unsigned k) {
  if (this->k != k) {
    this->k = k;
    visible.resize(k, false);
    colours.resize(k, *wxWHITE);
    hues.resize(k, 0.0f);
    sats.resize(k, 0.0f);
    lights.resize(k, 0.0f);

    fill(visible.begin(), visible.end(), false);
    fill(sats.begin(), sats.end(), 0.0f);
    fill(hues.begin(), hues.end(), 0.0f);
    fill(lights.begin(), lights.end(), 0.0f);

    delete minClusterer;
    minClusterer = new KMeansClusterer<>(k, time(NULL));
    cluster();
  
    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyNumClustersChange();
    }
  }
}

void ClusterModel::setHard(const bool hard) {
  if (this->hard != hard) {
    this->hard = hard;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyIsHardChange();
    }
  }
}

void ClusterModel::setNumRepetitions(const unsigned reps) {
  if (this->reps != reps) {
    this->reps = reps;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyNumRepetitionsChange();
    }
  }
}

void ClusterModel::setSaturationThreshold(const unsigned char x) {
  if (this->saturationThreshold != x) {
    this->saturationThreshold = x;
    prepare();

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifySaturationThresholdChange();
    }
  }
}

void ClusterModel::setMaxPixels(const unsigned n) {
  if (this->maxPixels != n) {
    this->maxPixels = n;
    prepare();

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyMaxPixelsChange();
    }
  }
}

void ClusterModel::setSaturationDecay(const float x) {
  /* pre-condition */
  assert (x >= 0.0 && x <= 255.0f);

  if (this->saturationDecay != x) {
    this->saturationDecay = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifySaturationDecayChange();
    }
  }
}

void ClusterModel::setCentroidDecay(const float x) {
  /* pre-condition */
  assert (x >= 0.0 && x <= 255.0f);

  if (this->centroidDecay != x) {
    this->centroidDecay = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyCentroidDecayChange();
    }
  }
}

void ClusterModel::setSaturationSoftness(const float x) {
  /* pre-condition */
  assert (x >= 0.0 && x <= 255.0f);

  if (this->saturationSoftness != x) {
    this->saturationSoftness = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifySaturationSoftnessChange();
    }
  }
}

void ClusterModel::setCentroidSoftness(const float x) {
  /* pre-condition */
  assert (x >= 0.0 && x <= 255.0f);

  if (this->centroidSoftness != x) {
    this->centroidSoftness = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyCentroidSoftnessChange();
    }
  }
}

/**
 * Set the channel weights used for lightness. The weights are normalised to
 * sum to 1, stored in greyR/greyG/greyB, pushed into the colour space, and
 * observers are notified.
 *
 * If the weights sum to zero they cannot be normalised. In that case the
 * current weights are kept and nothing is notified, instead of storing
 * NaN/inf weights.
 *
 * @param r Red weight.
 * @param g Green weight.
 * @param b Blue weight.
 */
void ClusterModel::setGreyscale(const float r, const float g,
    const float b) {
  float total = r + g + b;
  if (total == 0.0f) {
    return;
  }

  greyR = r / total;
  greyG = g / total;
  greyB = b / total;
  cs.setLightness(greyR, greyG, greyB);

  /* notify observers */
  notifyAll();
  registry_t::iterator iter;
  for (iter = os.begin(); iter != os.end(); iter++) {
    static_cast<ClusterModelObserver*>(*iter)->notifyGreyscaleChange();
  }
}

void ClusterModel::setColour(const unsigned i, const wxColour& col) {
  /* pre-condition */
  assert (i < k);

  ClusterVector<>::type x(ClusterVector<>::N);
  x(0) = (float)col.Red();
  x(1) = (float)col.Green();
  x(2) = (float)col.Blue();
  PearsonDistance<>::prepare(x);

  colours[i] = col;
  minClusterer->setCentroid(i, x);
}

void ClusterModel::setHue(const unsigned i, const float x) {
  /* pre-condition */
  assert (i < k);
  assert (x >= 0.0f && x <= 6.0f);

  if (hues[i] != x) {
    hues[i] = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyHueChange(i);
    }
  }
}

void ClusterModel::setSat(const unsigned i, const float x) {
  /* pre-condition */
  assert (i < k);
  assert (x >= 0.0f && x <= 1.0f);

  if (sats[i] != x) {
    sats[i] = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifySatChange(i);
    }
  }
}

void ClusterModel::setLight(const unsigned i, const float x) {
  /* pre-condition */
  assert (i < k);
  assert (x >= 0.0f && x <= 1.0f);

  if (lights[i] != x) {
    lights[i] = x;

    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyLightChange(i);
    }
  }
}

void ClusterModel::show(const unsigned i, const bool on) {
  /* pre-condition */
  assert (i < k);

  if (visible[i] != on) {
    visible[i] = on;
    if (on) {
      sats[i] = 1.0f;
    } else {
      sats[i] = 0.0f;
    }
    
    /* notify observers */
    notifyAll();
    registry_t::iterator iter;
    for (iter = os.begin(); iter != os.end(); iter++) {
      static_cast<ClusterModelObserver*>(*iter)->notifyClusterChange(i);
    }
  }
}

/**
 * Show or hide every cluster, one show() call per cluster index.
 *
 * @param on True to show all clusters, false to hide them.
 */
void ClusterModel::showAll(const bool on) {
  for (unsigned idx = 0; idx < k; ++idx) {
    show(idx, on);
  }
}

/**
 * Per-cluster foreground calculation.
 *
 * Not implemented: this overload does nothing and leaves o unmodified. The
 * casts below only mark the parameters as intentionally unused.
 */
void ClusterModel::calcFg(const unsigned i, const wxRect& rect,
    const unsigned width, const unsigned height, wxImage& o) {
  (void)i;
  (void)rect;
  (void)width;
  (void)height;
  (void)o;
}

void ClusterModel::calcFg(const wxRect& rect,
    const unsigned width, const unsigned height, wxImage& o) {
  /* pre-conditions */
  assert (o.GetWidth() >= rect.width);
  assert (o.GetHeight() >= rect.height);

  wxImage img(res->calc(rect, width, height));

  #pragma omp parallel default(shared)
  {
    ColourSpace::rgb_t rgb(3);
    ColourSpace::hsl_t hsl(3);
    ClusterVector<>::type pixel(ClusterVector<>::N);
    //std::vector<float> ds(k);

    int x, y;
    int cluster;
    unsigned char r, g, b;
    float s, d, distance;

    #pragma omp for schedule(dynamic)
    for (y = 0; y < rect.height; y++) {
      for (x = 0; x < rect.width; x++) {
        r = img.GetRed(x, y);
        g = img.GetGreen(x, y);
        b = img.GetBlue(x, y);

        rgb(0) = r;
        rgb(1) = g;
        rgb(2) = b;
        cs.rgb2hsl(rgb, hsl);

        pixel(0) = r;
        pixel(1) = g;
        pixel(2) = b;
        
        /* cluster adjustments */
        PearsonDistance<>::prepare(pixel);
        //if (isHard()) {
          cluster = minClusterer->assign(pixel, &distance);

          //hsl(0) = fmod(6.0f + hsl(0) + 6.0f*hues[cluster], 6.0f);
          hsl(1) *= sats[cluster];
          //hsl(2) = min(255.0f, hsl(2) + 255.0f*lights[cluster]);
        //} else {
          //
        //}

        /* hue and saturation decay */
        d = std::pow(1.0f-0.5f*distance, 16);
        d = std::max(0.0f, std::min(d, 1.0f));
        d = threshold(d, centroidDecay, centroidSoftness);
        s = threshold(hsl(1), saturationDecay, saturationSoftness);
        hsl(1) *= s*d;

        /* convert back */
        cs.hsl2rgb(hsl, rgb);
        o.SetRGB(x, y, rgb(0), rgb(1), rgb(2));
      }
    }
  }
}

void ClusterModel::calcFg(const unsigned width, const unsigned height,
    wxImage& o) {

}

/**
 * Compute an alpha channel for cluster i over a region of the image.
 *
 * Pixels not assigned to cluster i get 0. The others get
 * 255 * s * d, rounded, where:
 * - d is a centroid-distance weight;
 * - s is a saturation weight;
 * - both are passed through threshold().
 *
 * @param i Cluster index, asserted < k.
 * @param rect Region of the image to process.
 * @param width Width passed to res->calc() along with rect.
 * @param height Height passed to res->calc() along with rect.
 * @param c Output channel, asserted to be rect.height x rect.width.
 */
void ClusterModel::calcAlpha(const unsigned i,
    const wxRect& rect, const unsigned width, const unsigned height,
    channel& c) {
  /* pre-conditions */
  assert (i < k);
  assert ((int)c.size1() == rect.height && (int)c.size2() == rect.width);

  wxImage img(res->calc(rect, width, height));

  #pragma omp parallel default(shared)
  {
    ColourSpace::rgb_t rgb(3);
    ColourSpace::hsl_t hsl(3);
    ClusterVector<>::type pixel(ClusterVector<>::N);

    int x, y;
    unsigned char r, g, b;
    float s, d, distance;

    #pragma omp for schedule(dynamic)
    for (y = 0; y < rect.height; y++) {
      for (x = 0; x < rect.width; x++) {
        r = img.GetRed(x, y);
        g = img.GetGreen(x, y);
        b = img.GetBlue(x, y);

        rgb(0) = r;
        rgb(1) = g;
        rgb(2) = b;
        cs.rgb2hsl(rgb, hsl);

        pixel(0) = r;
        pixel(1) = g;
        pixel(2) = b;

        /* cluster adjustments */
        PearsonDistance<>::prepare(pixel);
        if (minClusterer->assign(pixel, &distance) == i) {
          /* hue and saturation decay; d is clamped to [0,1] exactly as in
           * calcFg(). The clamp also maps a NaN d to 0, because every
           * comparison inside std::min/std::max is false for NaN, so NaN
           * never reaches uround() */
          d = std::pow(1.0f - 0.5f*distance, 16);
          d = std::max(0.0f, std::min(d, 1.0f));
          d = threshold(d, centroidDecay, centroidSoftness);
          s = threshold(hsl(1), saturationDecay, saturationSoftness);
          c(y,x) = ColourSpace::uround(255.0f*s*d);
        } else {
          c(y,x) = 0;
        }
      }
    }
  }
}

/**
 * Compute an alpha channel over a region of the image for all shown clusters.
 *
 * Pixels whose assigned cluster is not shown (per isShown()) get 0. The
 * others get 255 * s * d, rounded, where:
 * - d is a centroid-distance weight;
 * - s is a saturation weight;
 * - both are passed through threshold().
 *
 * @param rect Region of the image to process.
 * @param width Width passed to res->calc() along with rect.
 * @param height Height passed to res->calc() along with rect.
 * @param c Output channel, asserted to be rect.height x rect.width.
 */
void ClusterModel::calcAlpha(const wxRect& rect,
    const unsigned width, const unsigned height, channel& c) {
  /* pre-condition */
  assert ((int)c.size1() == rect.height && (int)c.size2() == rect.width);

  wxImage img(res->calc(rect, width, height));

  #pragma omp parallel default(shared)
  {
    ColourSpace::rgb_t rgb(3);
    ColourSpace::hsl_t hsl(3);
    ClusterVector<>::type pixel(ClusterVector<>::N);

    int x, y;
    unsigned char r, g, b;
    float s, d, distance;

    #pragma omp for schedule(dynamic)
    for (y = 0; y < rect.height; y++) {
      for (x = 0; x < rect.width; x++) {
        r = img.GetRed(x, y);
        g = img.GetGreen(x, y);
        b = img.GetBlue(x, y);

        rgb(0) = r;
        rgb(1) = g;
        rgb(2) = b;
        cs.rgb2hsl(rgb, hsl);

        pixel(0) = r;
        pixel(1) = g;
        pixel(2) = b;

        PearsonDistance<>::prepare(pixel);

        if (isShown(minClusterer->assign(pixel, &distance))) {
          /* hue and saturation decay; d is clamped to [0,1] exactly as in
           * calcFg(). The clamp also maps a NaN d to 0, because every
           * comparison inside std::min/std::max is false for NaN, so NaN
           * never reaches uround() */
          d = std::pow(1.0f - 0.5f*distance, 16);
          d = std::max(0.0f, std::min(d, 1.0f));
          d = threshold(d, centroidDecay, centroidSoftness);
          s = threshold(hsl(1), saturationDecay, saturationSoftness);
          c(y,x) = ColourSpace::uround(255.0f*s*d);
        } else {
          c(y,x) = 0;
        }
      }
    }
  }
}

void ClusterModel::calcAlpha(const unsigned width, const unsigned height,
    channel& c) {
  wxRect rect;
  rect.x = 0;
  rect.y = 0;
  rect.width = width;
  rect.height = height;

  calcAlpha(rect, width, height, c);
}

/**
 * Build a mask over a region of the image marking the pixels assigned to
 * cluster i.
 *
 * @param i Cluster index, asserted < k.
 * @param rect Region of the image to process.
 * @param width Width passed to res->calc() along with rect.
 * @param height Height passed to res->calc() along with rect.
 * @return A rect.height x rect.width sparse mask, true where the pixel
 *         belongs to cluster i.
 */
sparse_mask ClusterModel::calcMask(const unsigned i,
    const wxRect& rect, const unsigned width, const unsigned height) {
  /* pre-condition */
  assert (i < k);

  sparse_mask m(rect.height, rect.width);
  wxImage img(res->calc(rect, width, height));

  #pragma omp parallel default(shared)
  {
    ClusterVector<>::type pixel(ClusterVector<>::N);

    #pragma omp for schedule(dynamic)
    for (int row = 0; row < rect.height; row++) {
      for (int col = 0; col < rect.width; col++) {
        pixel(0) = static_cast<float>(img.GetRed(col, row));
        pixel(1) = static_cast<float>(img.GetGreen(col, row));
        pixel(2) = static_cast<float>(img.GetBlue(col, row));
        PearsonDistance<>::prepare(pixel);

        if (minClusterer->assign(pixel) != i) {
          continue;
        }

        /* inserting into a sparse matrix may rearrange its storage, so
         * writes are serialised */
        #pragma omp critical
        {
          m(row, col) = true;
        }
      }
    }
  }
  return m;
}

/**
 * Rebuild the clustering data set from the image. Does nothing while the
 * model is locked.
 *
 * If the image has more than maxPixels pixels, a scaled copy whose area is
 * about maxPixels is sampled instead. Every sampled pixel whose saturation is
 * at least saturationThreshold is Pearson-prepared and added to `data`.
 */
void ClusterModel::prepare() {
  if (locked) {
    return;
  }

  /* choose the image to sample */
  wxImage* original = res->get();
  wxImage* working;
  int originalWidth = original->GetWidth();
  int originalHeight = original->GetHeight();
  int workingWidth, workingHeight;

  if (originalWidth <= 0 || originalHeight <= 0) {
    /* empty image: nothing to sample (and maxPixels/0 would be NaN when
     * maxPixels is also 0) */
    data.clear();
    return;
  }

  /* pixel count in float so that very large images cannot overflow int */
  float pixels = static_cast<float>(originalWidth)*
      static_cast<float>(originalHeight);
  float factor = maxPixels / pixels;

  if (factor > 1.0) {
    /* image size doesn't exceed limit, cluster it directly */
    workingWidth = originalWidth;
    workingHeight = originalHeight;
    working = original;
  } else {
    /* image size exceeds limit, cluster scaled version; scaling both sides
     * by sqrt(factor) scales the area by factor */
    workingWidth = static_cast<int>(originalWidth*sqrt(factor));
    workingHeight = static_cast<int>(originalHeight*sqrt(factor));

    /* an elongated image (or maxPixels == 0) can truncate a side to 0;
     * always keep at least one row and one column */
    if (workingWidth < 1) {
      workingWidth = 1;
    }
    if (workingHeight < 1) {
      workingHeight = 1;
    }
    working = res->get(workingWidth, workingHeight);
  }

  data.clear();

  /* collect sufficiently saturated pixels */
  #pragma omp parallel default(shared)
  {
    ClusterVector<>::type pixel(ClusterVector<>::N);
    int x, y;
    unsigned char r, g, b, s;

    #pragma omp for schedule(dynamic)
    for (x = 0; x < workingWidth; x++) {
      for (y = 0; y < workingHeight; y++) {
        r = working->GetRed(x,y);
        g = working->GetGreen(x,y);
        b = working->GetBlue(x,y);

        /* saturation threshold */
        s = saturation(r,g,b);
        if (s >= saturationThreshold) {
          pixel(0) = r;
          pixel(1) = g;
          pixel(2) = b;
          PearsonDistance<>::prepare(pixel);
          #pragma omp critical
          {
            data.add(pixel);
          }
        }
      }
    }
  }
}

/**
 * Cluster the prepared data set. Does nothing while the model is locked.
 *
 * Runs `reps` k-means clusterings in parallel, each seeded with seed+i. The
 * clusterer with the lowest error replaces minClusterer. Afterwards, all
 * clusters are marked hidden and each cluster's display colour is set from
 * its centroid (127 * centroid + 128 per channel).
 */
void ClusterModel::cluster() {
  if (locked) {
    return;
  }
  const unsigned MAX_ITERS = 100;
  const unsigned seed = time(NULL);
  float minError = std::numeric_limits<float>::max();

  #pragma omp parallel shared(minError)
  {
    /* cluster */
    KMeansClusterer<>* clusterer;
    float error;
    int i;
    ClusterVector<>::type pixel(ClusterVector<>::N);

    #pragma omp for schedule(dynamic)
    for (i = 0; i < (int)reps; i++) {
      clusterer = new KMeansClusterer<>(k, seed+i);
      clusterer->cluster(data, MAX_ITERS);
      error = clusterer->getError();
      #pragma omp critical
      {
        if (error < minError) {
          delete minClusterer;
          minError = error;
          minClusterer = clusterer;
        } else {
          delete clusterer;
        }
      }
    }
    assert (minClusterer != NULL);

    /* cluster visibility, reset by a single thread. visible is resized with
     * a bool fill value, so it is presumably a bit-packed std::vector<bool>.
     * Its elements share storage words, so concurrent writes to different
     * indices can lose updates. The implicit barrier at the end of `single`
     * matches the barrier of the omp for it replaces. */
    #pragma omp single
    {
      for (unsigned j = 0; j < k; j++) {
        visible[j] = false;
      }
    }

    /* cluster colours */
    #pragma omp for
    for (i = 0; i < (int)k; i++) {
      pixel = 127.0*minClusterer->getCentroid(i) +
          ublas::scalar_vector<float>(3,128.0);
      colours[i].Set(pixel(0), pixel(1), pixel(2));
    }
  }
}

