// Text shown by CmdLine (registered via cmd.info()) at the top of the
// program's usage/help output.
const char *help =
  "MLPTorch (c) Trebolloc & Co 2001\n"
  "\n"
  "This program will train a MLP with tanh outputs for\n"
  "classification and linear outputs for regression\n";

#include "MLP.h"
#include "FileDataSet.h"
#include "ClassLLCriterion.h"
#include "MseCriterion.h"
#include "MseMeasurer.h"
#include "ClassMeasurer.h"
#include "TwoClassFormat.h"
#include "OneHotClassFormat.h"
#include "StochasticGradient.h"
#include "GMTrainer.h"
#include "CmdLine.h"
#include "string_utils.h"

using namespace Torch;

int main(int argc, char **argv)
{
  char *model_file, *test_model_file;
  char *valid_file;
  char *file;

  int n_inputs;
  int n_targets;
  int n_hu;

  int max_load;
  int max_load_valid;
  real accuracy;
  real learning_rate;
  real decay;
  int max_iter;
  bool regression;
  int k_fold;
  int the_seed;
  bool sigmoid_outputs;
  bool inputs_to_outputs;
  real weight_decay;
  bool one_hot;
  bool train_prob;

  char *dir_name;
  char *buffer;

  //=================== The command-line ==========================

  // Construct the command line
  CmdLine cmd;

  // Put the help line at the beginning
  cmd.info(help);

  // Ask for arguments
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("file", &file, "the train or test file");
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");
  cmd.addICmdArg("n_targets", &n_targets, "output dimension of the data");

  // Propose some options
  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");
  cmd.addBCmdOption("-rm", &regression, false, "regression mode");
  cmd.addBCmdOption("-sigmoid", &sigmoid_outputs, false, "sigmoid outputs instead of tanh (for classification)");
  cmd.addBCmdOption("-i_to_o", &inputs_to_outputs, false, "add a linear connection from inputs to outputs");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
  cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
  cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
  cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");
  cmd.addRCmdOption("-wd", &weight_decay, 0, "weight decay");
  cmd.addBCmdOption("-prob", &train_prob, false, "in classification, train by maximizing the likelihood");

  cmd.addText("\nMisc Options:");
  cmd.addBCmdOption("-one_hot", &one_hot, false, "one-hot encoding for two classes classification");
  cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
  cmd.addICmdOption("-Kfold", &k_fold, -1, "number of subsets for K-fold cross-validation");
  cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load for train");
  cmd.addICmdOption("-load_valid", &max_load_valid, -1, "max number of examples to load for valid");
  cmd.addSCmdOption("-valid", &valid_file, "", "validation file, if you want it");
  cmd.addSCmdOption("-sm", &model_file, "", "file to save the model");
  cmd.addSCmdOption("-test", &test_model_file, "", "model file to test");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");

  // Read the command line
  cmd.read(argc, argv);

  // If the user didn't give any random seed,
  // generate a random random seed...
  if(the_seed == -1)
    seed();
  else
    manual_seed((long)the_seed);

  //=================== Create the MLP... =========================
  MLP mlp(n_inputs, n_hu, n_targets);
  if(!regression)
  {
    if(train_prob && !regression)
      mlp.setBOption("log-softmax outputs", true);
    else
    {
      if(sigmoid_outputs)
        mlp.setBOption("sigmoid outputs", true);
      else
        mlp.setBOption("tanh outputs", true);
    }
  }
  mlp.setROption("weight decay", weight_decay);
  mlp.setBOption("inputs to outputs", inputs_to_outputs);

  // Initialize the MLP
  mlp.init();


  //=================== DataSets & Measurers... ===================

  // Create the training dataset (normalize inputs)
  FileDataSet data(file, n_inputs, n_targets, false, max_load);
  data.setBOption("normalize inputs", true);
  data.init();

  // The list of measurers...
  List *measurers = NULL;

  // The class format
  ClassFormat *class_format = NULL;
  if(!regression)
  {
    if( (n_targets == 1) && (!train_prob) && (!one_hot) )
      class_format = new TwoClassFormat(&data);
    else
      class_format = new OneHotClassFormat(&data);
  }

  // The validation set...
  FileDataSet *valid_data = NULL;
  MseMeasurer *valid_mse_meas = NULL;
  ClassMeasurer *valid_class_meas = NULL;

  // Create a validation set, if any
  if(strcmp(valid_file, ""))
  {
    // Load the validation set and normalize it with the
    // values in the train dataset
    valid_data = new FileDataSet(valid_file, n_inputs, n_targets, false, max_load_valid);
    valid_data->init();
    valid_data->normalizeUsingDataSet(&data);

    // Create a MSE measurer and an error class measurer
    // on the validation dataset (if we are not in regression)
    buffer = strConcat(2, dir_name, "/the_valid_mse");
    valid_mse_meas = new MseMeasurer(mlp.outputs, valid_data, buffer);
    valid_mse_meas->init();
    addToList(&measurers, 1, valid_mse_meas);
    free(buffer);

    if(!regression)
    {
      buffer = strConcat(2, dir_name, "/the_valid_class_err");
      valid_class_meas = new ClassMeasurer(mlp.outputs, valid_data, class_format, buffer);
      valid_class_meas->init();
      addToList(&measurers, 1, valid_class_meas);
      free(buffer);
    }
  }

  // Measurers on the training dataset
  buffer = strConcat(2, dir_name, "/the_mse");
  MseMeasurer *mse_meas = new MseMeasurer(mlp.outputs, &data, buffer);
  mse_meas->init();
  addToList(&measurers, 1, mse_meas);
  free(buffer);

  ClassMeasurer *class_meas = NULL;
  if(!regression)
  {
    buffer = strConcat(2, dir_name, "/the_class_err");
    class_meas = new ClassMeasurer(mlp.outputs, &data, class_format, buffer);
    class_meas->init();
    addToList(&measurers, 1, class_meas);
    free(buffer);
  }

  //=================== The Trainer ===============================
  
  // The criterion for the GMTrainer (MSE criterion or LL criterion)
  MseCriterion *mse = NULL;
  ClassLLCriterion *cllc = NULL;

  Criterion *the_criterion;
  if(train_prob && !regression)
  {
    cllc = new ClassLLCriterion(class_format);
    cllc->init();
    the_criterion = cllc;
  }
  else
  {
    mse = new MseCriterion(n_targets);
    mse->init();
    the_criterion = mse;
  }

  // The optimizer for the GMTrainer
  StochasticGradient opt;
  opt.setIOption("max iter", max_iter);
  opt.setROption("end accuracy", accuracy);
  opt.setROption("learning rate", learning_rate);
  opt.setROption("learning rate decay", decay);

  // The Gradient Machine Trainer
  GMTrainer trainer(&mlp, &data, the_criterion, &opt);

  //=================== Let's go... ===============================

  // Print the number of parameter of the MLP (just for fun)
  message("Number of parameters: %d", mlp.n_params);

  // If the user provides a previously trained model,
  // test it...
  if( strcmp(test_model_file, "") )
  {
    trainer.load(test_model_file);
    trainer.test(measurers);
  }

  // ...else...
  else
  {
    // If the user provides a number for the K-fold validation,
    // do a K-fold validation
    if(k_fold > 0)
      trainer.crossValidate(k_fold, NULL, measurers);

    // Else, train the model
    else
      trainer.train(measurers);

    // Save the model if the user provides a name for that
    if( strcmp(model_file, "") )
      trainer.save(model_file);
  }

  //=================== Quit... ===================================
  if(strcmp(valid_file, ""))
  {
    delete valid_data;
    delete valid_mse_meas;
    if(!regression)
      delete valid_class_meas;
  }

  delete mse_meas;
  if(!regression)
  {
    delete class_meas;
    delete class_format;
  }

  if(train_prob && !regression)
    delete cllc;
  else
    delete mse;

  freeList(&measurers);

  return(0);
}
