/*
  MeCab -- Yet Another Part-of-Speech and Morphological Analyzer

  $Id: learner.cpp,v 1.13 2006/07/09 11:26:27 taku-ku Exp $;

  Copyright (C) 2001-2006 Taku Kudo <taku@chasen.org>
  Copyright (C) 2004-2006 Nippon Telegraph and Telephone Corporation

*/
#include <vector>
#include <string>
#include <fstream>
#include "param.h"
#include "common.h"
#include "lbfgs.h"
#include "thread.h"
#include "learner_tagger.h"
#include "freelist.h"
#include "feature_index.h"
#include "string_buffer.h"

namespace MeCab {

#ifdef MECAB_USE_THREAD
  // Worker that evaluates a strided subset of the training sentences:
  // x[start_i], x[start_i + thread_num], x[start_i + 2 * thread_num], ...
  // Results are left in the public members for the caller to sum up after
  // join(). run() is presumably invoked by the `thread` base class once
  // start() is called (TODO confirm against thread.h).
  class learner_thread: public thread {
  public:
    unsigned short start_i;      // index of the first sentence this worker visits
    unsigned short thread_num;   // stride between visited sentences
    size_t size;                 // number of entries in x
    size_t micro_p;              // counters updated by eval() for this worker
    size_t micro_r;
    size_t micro_c;
    size_t err;                  // sum of eval() return values
    double f;                    // sum of gradient() return values
    EncoderLearnerTagger **x;    // sentence taggers (not owned by this object)
    std::vector<double> expected;  // private buffer passed to gradient()

    void run() {
      // Clear every accumulator: the same object is started again on each
      // training iteration.
      f = 0.0;
      err = 0;
      micro_c = 0;
      micro_r = 0;
      micro_p = 0;
      std::fill(expected.begin(), expected.end(), 0.0);

      size_t idx = start_i;
      while (idx < size) {
        EncoderLearnerTagger *const tagger = x[idx];
        f += tagger->gradient(&expected[0]);
        err += tagger->eval(micro_c, micro_p, micro_r);
        idx += thread_num;
      }
    }
  };
#endif

  // Placeholder learner class: currently declared with no members or
  // behavior.
  class HMMLearner {
    // intentionally empty
  };

  class CRFLearner {
  public:
    static bool run(int argc, char **argv)
    {
      static const MeCab::Option long_options[] =
        {
          { "dicdir",   'd',  ".",     "DIR",    "set DIR as dicdir (default \".\" )" },
          { "cost",     'c',  "1.0",   "FLOAT",  "set FLOAT for cost C for constraints violatoin" },
          { "freq",     'f',  "1",     "INT",    "set the frequency cut-off (default 1)" },
          { "eta",      'e',  "0.001", "DIR",    "set FLOAT for tolerance of termination criterion" },
          { "thread",   'p',  "1",     "INT",    "number of threads (default 1)" },
          { "build",    'b',  0,  0,   "build binary model from text model"},
          { "text-only",'y',  0,  0,        "output text model only" },
          { "version",  'v',  0,   0,  "show the version and exit"  },
          { "help",     'h',  0,   0,  "show this help and exit."      },
          { 0, 0, 0, 0 }
        };

      Param param;

      if (! param.open(argc, argv, long_options)) {
        std::cout << param.what() << "\n\n" <<  COPYRIGHT
                  << "\ntry '--help' for more information." << std::endl;
        return -1;
      }

#define DCONF(file) create_filename(dicdir, std::string(file)).c_str()
      std::string dicdir = param.getProfileString("dicdir");
      CHECK_DIE(param.load(DCONF(DICRC)))
        << "no such file or directory: " << DCONF(DICRC);

#undef DCONF

      std::vector<std::string> files = param.rest_args();
      if (files.size() != 2) {
        std::cout << "Usage: " <<
          param.program_name() << " corpus model" << std::endl;
        return -1;
      }

      std::string ifile = files[0];
      std::string model = files[1];

      double C = param.getProfileFloat("cost");
      double eta = param.getProfileFloat("eta");
      bool text_only = param.getProfileFloat("text-only");
      bool build = param.getProfileFloat("build");
      size_t eval_size = param.getProfileInt("eval-size");
      size_t unk_eval_size = param.getProfileInt("unk-eval-size");
      size_t freq = param.getProfileInt("freq");
      size_t thread_num = param.getProfileInt("thread");

      if (build) {
        EncoderFeatureIndex feature_index_;
        CHECK_DIE(feature_index_.convert(ifile.c_str(), model.c_str()))
          << feature_index_.what();
        return 0;
      }

      EncoderFeatureIndex feature_index_;
      LearnerTokenizer tokenizer_;
      FreeList<LearnerPath> path_freelist_(PATH_FREELIST_SIZE);
      std::vector <double> expected_;
      std::vector <double> observed_;
      std::vector <double> alpha_;
      std::vector <EncoderLearnerTagger *> x_;

      std::cout.setf(std::ios::fixed, std::ios::floatfield);
      std::cout.precision(5);

      std::ifstream ifs(ifile.c_str());
      {
        CHECK_DIE(C > 0) << "cost parameter is out of range: " << C;
        CHECK_DIE(eta > 0) "eta is out of range: " << eta;
        CHECK_DIE(eval_size > 0) << "eval-size is out of range: " << eval_size;
        CHECK_DIE(unk_eval_size > 0) <<
          "unk-eval-size is out of range: " << unk_eval_size;
        CHECK_DIE(freq > 0) <<
          "freq is out of range: " << unk_eval_size;
        CHECK_DIE(thread_num > 0 && thread_num <= 512)
          << "# thread is invalid: " << thread_num;
        CHECK_DIE(tokenizer_.open(param)) << tokenizer_.what();
        CHECK_DIE(feature_index_.open(param)) << feature_index_.what();
        CHECK_DIE(ifs) << "no such file or directory: " << ifile;
      }

      std::cout << "reading corpus ..." << std::flush;

      while (ifs) {

        EncoderLearnerTagger *_x = new EncoderLearnerTagger();

        CHECK_DIE(_x->open(&tokenizer_, &path_freelist_,
                           &feature_index_,
                           eval_size,
                           unk_eval_size))
                             << _x->what();

        CHECK_DIE(_x->read(ifs, observed_)) << _x->what();

        if (! _x->empty()) x_.push_back(_x);
        else delete _x;

        if (x_.size() % 100 == 0)
          std::cout << x_.size() << "... " << std::flush;
      }

      // shrink vector
      feature_index_.shrink(freq, observed_);
      feature_index_.clearcache();

      int converge = 0;
      double old_f = 0.0;
      size_t psize = feature_index_.size();
      observed_.resize(psize);
      LBFGS lbfgs;

      alpha_.resize(psize);
      expected_.resize(psize);
      std::fill(alpha_.begin(), alpha_.end(), 0.0);
      lbfgs.init(static_cast<int>(psize), 5);

      feature_index_.set_alpha(&alpha_[0]);

      std::cout << std::endl;
      std::cout << "Number of sentences: " << x_.size() << std::endl;
      std::cout << "Number of features:  " << psize     << std::endl;
      std::cout << "eta:                 " << eta       << std::endl;
      std::cout << "freq:                " << freq      << std::endl;
#ifdef MECAB_USE_THREAD
      std::cout << "threads:             " << thread_num << std::endl;
#endif
      std::cout << "C(sigma^2):          " << C          << std::endl << std::endl;

#ifdef MECAB_USE_THREAD
      std::vector<learner_thread> thread;
      if (thread_num > 1) {
        thread.resize(thread_num);
        for (size_t i = 0; i < thread_num; ++i) {
          thread[i].start_i = i;
          thread[i].size = x_.size();
          thread[i].thread_num = thread_num;
          thread[i].x = &x_[0];
          thread[i].expected.resize(expected_.size());
        }
      }
#endif

      for (size_t itr = 0; ;  ++itr) {

        std::fill(expected_.begin(), expected_.end(), 0.0);

        double f = 0.0;
        size_t err = 0;
        size_t micro_p = 0;
        size_t micro_r = 0;
        size_t micro_c = 0;

#ifdef MECAB_USE_THREAD
        if (thread_num > 1) {
          for (size_t i = 0; i < thread_num; ++i)
            thread[i].start();

          for (size_t i = 0; i < thread_num; ++i)
            thread[i].join();

          for (size_t i = 0; i < thread_num; ++i) {
            f += thread[i].f;
            err += thread[i].err;
            micro_r += thread[i].micro_r;
            micro_p += thread[i].micro_p;
            micro_c += thread[i].micro_c;
            for (size_t k = 0; k < psize; ++k)
              expected_[k] += thread[i].expected[k];
          }
        } else
#endif
          {
            for (size_t i = 0; i < x_.size(); ++i) {
              f += x_[i]->gradient(&expected_[0]);
              err += x_[i]->eval(micro_c, micro_p, micro_r);
            }
          }

        double p = 1.0 * micro_c / micro_p;
        double r = 1.0 * micro_c / micro_r;
        double micro_f = 2 * p * r / (p + r);

        for (size_t i = 0; i < psize; ++i) {
          f += (alpha_[i] * alpha_[i]/(2.0 * C));
          expected_[i] = expected_[i] - observed_[i] + alpha_[i]/C;
        }

        double diff = (itr == 0 ? 1.0 : std::fabs(1.0 * (old_f - f) )/old_f);
        std::cout << "iter="    << itr
                  << " err="    << 1.0 * err/x_.size()
                  << " F="      << micro_f
                  << " target=" << f
                  << " diff="   << diff << std::endl;
        old_f = f;

        if (diff < eta) converge++; else converge = 0;
        if (converge == 3)  break; // 3 is ad-hoc

        int ret = lbfgs.optimize(&alpha_[0], &f, &expected_[0]);

        CHECK_DIE(ret > 0) << lbfgs.what();
      }

      std::cout << "\nDone! writing model file ... " << std::endl;

      std::string txtfile = model;
      txtfile += ".txt";

      CHECK_DIE(feature_index_.save(txtfile.c_str()))
        << feature_index_.what();

      if (! text_only) {
        CHECK_DIE(feature_index_.convert(txtfile.c_str(), model.c_str()))
          << feature_index_.what();
      }

      return 0;
    }
  };
}

// Entry point of the cost-training command: forwards argc/argv to
// MeCab::CRFLearner::run() and maps its bool result to an int status
// (true -> 1, false -> 0).
int mecab_cost_train(int argc, char **argv)
{
  const bool status = MeCab::CRFLearner::run(argc, argv);
  return status ? 1 : 0;
}
