// Program banner; main() hands it to CmdLine::info() so it is shown with
// the command-line usage text.
const char *help = "\
hmm_speech (c) Trebolloc & Co 2001\n\
\n\
This program will train a HMM for a connected word speech recognition experiment\n";

#include "EditDistanceMeasurer.h"
#include "WordSegMeasurer.h"
#include "EMTrainer.h"
#include "ViterbiTrainer.h"
#include "DiagonalGMM.h"
#include "Kmeans.h"
#include "SpeechHMM.h"
#include "Dictionary.h"
#include "MatSeqDataSet.h"
#include "HtkSeqDataSet.h"
#include "CmdLine.h"
#include "NllMeasurer.h"

using namespace Torch;

// load a SpeechHMM that was saved in the HTK format
// assumes that it uses DiagonalGMMs
void load_htk_model(char* filename, SpeechHMM* shmm)
{
  char line[1000];
  char* values[1000];
  int n;
  FILE* f=fopen(filename,"r");
  if (!f)
    error("file %s cannot be opened",filename);
  // initialization
  for (int i=0;i<shmm->n_models;i++) {
    HMM* hmm = shmm->models[i];
    for (int j=1;j<hmm->n_states-1;j++) {
      DiagonalGMM* gmm = (DiagonalGMM*)hmm->states[j];
      for (int k=0;k<gmm->n_gaussians;k++) {
        gmm->log_weights[k] = LOG_ZERO;
        for (int l=0;l<gmm->n_observations;l++) {
          gmm->means[k][l] = 0;
          gmm->var[k][l] = 1;
        }
      }
    }
    for (int j=0;j<hmm->n_states;j++) {
      for (int k=0;k<hmm->n_states;k++) {
        hmm->log_transitions[k][j] = LOG_ZERO;
      }
    }
  }
  // reading the model
  int model = -1;
  int state = 0;
  int mixture = 0;
  real w;
  HMM* hmm = NULL;
  DiagonalGMM* gmm = NULL;
  fgets(line,1000,f);
  while (!feof(f)) {
    if (strstr(line,"~h")) {
      model++;
      hmm = shmm->models[model];
    } else if (strstr(line,"<STATE>")) {
      sscanf(line,"%*s %d",&state);
      gmm = (DiagonalGMM*)hmm->states[state-1];
    } else if (strstr(line,"<MIXTURE>")) {
      sscanf(line,"%*s %d %f",&mixture,&w);
      gmm->log_weights[mixture-1] = log(w);
    } else if (strstr(line,"<MEAN>")) {
      fgets(line,1000,f);
      values[0] = strtok(line," ");
      for (n=1;(values[n]=strtok(NULL," "));n++);
      for (int l=0;l<gmm->n_observations;l++) {
        gmm->means[mixture-1][l] = (real)atof(values[l]);;
      }
    } else if (strstr(line,"<VARIANCE>")) {
      fgets(line,1000,f);
      values[0] = strtok(line," ");
      for (n=1;(values[n]=strtok(NULL," "));n++);
      for (int l=0;l<gmm->n_observations;l++) {
        gmm->var[mixture-1][l] = (real)atof(values[l]);;
      }
    } else if (strstr(line,"<TRANSP>")) {
      for (int j=0;j<hmm->n_states;j++) {
        fgets(line,1000,f);
        values[0] = strtok(line," ");
        for (n=1;(values[n]=strtok(NULL," "));n++);
        for (int k=0;k<hmm->n_states;k++) {
          w = atof(values[k]);
          hmm->log_transitions[k][j] = w == 0 ? LOG_ZERO : log(w);
        }
      }
    }
    fgets(line,1000,f);
  }
  fclose(f);
}

// this function saves in HTK format a given SpeechHMM
void save_htk_model(char* filename, SpeechHMM* shmm, char**phonemes)
{
  FILE* f=fopen(filename,"w");
  if (!f)
    error("file %s cannot be opened",filename);
  fprintf(f,"~o\n");
  fprintf(f,"<STREAMINFO> 1 %d\n",shmm->n_observations);
  fprintf(f,"<VECSIZE> %d<NULLD><MFCC_D_A_O>\n",shmm->n_observations);
  for (int i=0;i<shmm->n_models;i++) {
    HMM* hmm = shmm->models[i];
    fprintf(f,"~h \"%s\"\n",phonemes[i]);
    fprintf(f,"<BEGINHMM>\n");
    fprintf(f,"<NUMSTATES> %d\n",hmm->n_states);
    for (int j=1;j<hmm->n_states-1;j++) {
      DiagonalGMM* gmm = (DiagonalGMM*)hmm->states[j];
      fprintf(f,"<STATE> %d\n",j+1);
      fprintf(f,"<NUMMIXES> %d\n",gmm->n_gaussians);
      for (int k=0;k<gmm->n_gaussians;k++) {
        fprintf(f,"<MIXTURE> %d %12.10e\n",k+1,exp(gmm->log_weights[k]));
        fprintf(f,"<MEAN> %d\n",gmm->n_observations);
        for (int l=0;l<gmm->n_observations;l++) {
          fprintf(f,"%12.10e ",gmm->means[k][l]);
        }
        fprintf(f,"\n");
        fprintf(f,"<VARIANCE> %d\n",gmm->n_observations);
        for (int l=0;l<gmm->n_observations;l++) {
          fprintf(f,"%12.10e ",gmm->var[k][l]);
        }
        fprintf(f,"\n");
      }
    }
    fprintf(f,"<TRANSP> %d\n",hmm->n_states);
    for (int j=0;j<hmm->n_states;j++) {
      for (int k=0;k<hmm->n_states;k++) {
        fprintf(f,"%12.10e ",hmm->log_transitions[k][j] != LOG_ZERO ? exp(hmm->log_transitions[k][j]) : 0);
      }
      fprintf(f,"\n");
    }
    fprintf(f,"<ENDHMM>\n");
  }
  fclose(f);
}

// Prepends the word index `sil_word` to the target sequence of every example
// of `data` that already has at least one target; examples without targets
// are left untouched.
void add_sil_to_targets(SeqDataSet* data, int sil_word)
{
  for (int e=0; e<data->n_examples; e++) {
    data->setExample(e);
    SeqExample* example = (SeqExample*)data->inputs->ptr;

    const int n_old = example->n_seqtargets;
    if (n_old <= 0)
      continue;

    // make room for one more target and shift the existing ones up by one
    example->seqtargets = (real**)xrealloc(example->seqtargets,sizeof(real*)*(n_old+1));
    memmove(example->seqtargets+1,example->seqtargets,sizeof(real*)*n_old);

    // the new first target is a freshly allocated one-element vector
    real* first = (real*)xalloc(sizeof(real));
    first[0] = sil_word;
    example->seqtargets[0] = first;
    example->n_seqtargets = n_old+1;
  }
}

// this function reads from a filename the list of legal phonemes
// and return it
char** read_phonemes(char* filename, int* n_phonemes)
{
  FILE *f=fopen(filename,"r");
  if (!f)
    error("file %s cannot be open",filename);
  fscanf(f,"%d",n_phonemes);
  char** phonemes = (char**)xalloc(sizeof(char*)* *n_phonemes);
  char word[100];
  for (int i=0;i<*n_phonemes;i++) {
    fscanf(f,"%s",word);
    phonemes[i] = (char*)xalloc(sizeof(char)*(strlen(word)+1));
    strcpy(phonemes[i],word);
  }
  fclose(f);
  return phonemes;
}

// this function read a file which contains a list of filenames, 
// transform them to add the path and the extension, and returns the
// given list, which should then contain the list of speech sequences
char** read_data(char* filename, int* n_data,char* data_dir, char* extension)
{
  // first find number of data files
  char command[300];
  sprintf(command,"wc -l %s",filename);
  FILE *f=popen(command,"r");
  if (!f)
    error("file %s cannot be open",filename);
  int n;
  fscanf(f,"%d",&n);
  *n_data = n;
  fclose(f);
  
  char** data = (char**)xalloc(sizeof(char*)* n);

  f=fopen(filename,"r");
  if (!f)
    error("file %s cannot be open",filename);
  for (int i=0;i<n;i++) {
    char word[300];
    fscanf(f,"%s",word);
    word[strlen(word)-1]='\0'; // strip last \"
    data[i] = (char*)xalloc(sizeof(char)*(strlen(word)+2+strlen(extension)+strlen(data_dir)));
    sprintf(data[i],"%s/%s.%s",data_dir,&word[1],extension);
  }
  fclose(f);
  return data;
}

// Program entry point.
//
// Steps, in order: parse the command line; read the phoneme list and the
// dictionary; build the word grammar (initial silence, any word, final
// silence); load the train, optional cross-validation and test datasets;
// derive per-dimension variance floors from the training data; build one
// left-right HMM of DiagonalGMM states per phoneme; wrap them in a SpeechHMM;
// then either load a saved model or initialize and train one; optionally
// save it; finally decode the cross-validation and test sets and release
// everything.
int main(int argc, char **argv)
{
  char *train_file;
  char *cv_file;
  char *test_file;
  char *target_train_file;
  char *target_cv_file;
  char *target_test_file;
  char *data_dir;
  char *train_alignment;
  int max_load_train;
  int max_load_test;
  int max_load_cv;
  int seed_value;
  real accuracy;
  real threshold;
  int max_iter_kmeans;
  int max_iter_hmm;
  char *dir_name;
  char *load_model;
  char *save_model;
  int n_gaussians;
  int n_states;
  real prior;
  char* phoneme_name;
  char* dict_name;
  int silence_word;
  char* silence_name;
  bool htk;
  bool big_endian;
  bool little_endian;
  bool viterbi;
  real word_entrance_penalty;
  int initial_aligned_training_iter;
  char* extension;
  bool isolated;
  bool save_htk;
  bool htk_model;


  CmdLine cmd;

  cmd.info(help);

  cmd.addText("\nArguments:");
  cmd.addSCmdArg("phoneme_name", &phoneme_name, "the list of phonemes file");
  cmd.addSCmdArg("dict_name", &dict_name, "the dictionary file");
  cmd.addSCmdArg("train_file", &train_file, "the train file list");
  cmd.addSCmdArg("target_train_file", &target_train_file, "the target_train file");
  cmd.addSCmdArg("test_file", &test_file, "the test file list");
  cmd.addSCmdArg("target_test_file", &target_test_file, "the target_test file");

  cmd.addText("\nOptions:");
  cmd.addSCmdOption("-silence", &silence_name,"sil", "name of silence word");
  cmd.addSCmdOption("-train_alignment", &train_alignment,"", "train_alignment file");
  cmd.addICmdOption("-initial_aligned_training_iter", &initial_aligned_training_iter,0, "initial aligned EM/Viterbi training iterations");
  cmd.addSCmdOption("-data_dir", &data_dir,".", "directory containing data");
  cmd.addSCmdOption("-cv_file", &cv_file, "","the cross-valid file list");
  cmd.addSCmdOption("-target_cv_file", &target_cv_file, "","the target_cv file");
  cmd.addSCmdOption("-extension", &extension, "mfcSC","extension of HTK files");
  cmd.addBCmdOption("-save_htk", &save_htk, false,"convert model to HTK format");
  cmd.addBCmdOption("-htk_model", &htk_model, false,"model file is in HTK format");

  cmd.addText("\nPhoneme Model Options:");
  cmd.addICmdOption("-n_gaussians", &n_gaussians, 10, "number of Gaussians");
  cmd.addICmdOption("-n_states", &n_states, 5, "number of states");
  cmd.addRCmdOption("-threshold", &threshold, 0.1, "relative var threshold");
  cmd.addRCmdOption("-prior", &prior, 0.001, "prior on the weights and transitions");
  cmd.addBCmdOption("-isolated", &isolated, false, "isolated word recognition only");

  cmd.addText("\nLearning Options:");
  cmd.addBCmdOption("-viterbi", &viterbi, false, "viterbi learning (else EM learning)");
  cmd.addICmdOption("-iterk", &max_iter_kmeans, 25, "max number of iterations of Kmeans");
  cmd.addICmdOption("-iter", &max_iter_hmm, 25, "max number of iterations of HMM");
  cmd.addRCmdOption("-e", &accuracy, 0.0001, "end accuracy");

  cmd.addText("\nMisc Options:");
  cmd.addBCmdOption("-big_endian", &big_endian, false, "load in big endian format");
  cmd.addBCmdOption("-little_endian", &little_endian, false, "load in little endian format");
  cmd.addBCmdOption("-htk", &htk, false, "use the HTK file format");
  cmd.addICmdOption("-load_train", &max_load_train, -1, "max number of train examples to load");
  cmd.addICmdOption("-load_test", &max_load_test, -1, "max number of test examples to load");
  cmd.addICmdOption("-load_cv", &max_load_cv, -1, "max number of cv examples to load");
  cmd.addICmdOption("-seed", &seed_value, -1, "initial seed for random generator");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
  cmd.addSCmdOption("-lm", &load_model, "", "start from given model file");
  cmd.addSCmdOption("-sm", &save_model, "", "save results into given model file");
  cmd.addRCmdOption("-word_entrance_penalty", &word_entrance_penalty, 0., "word entrance penalty");

  cmd.read(argc, argv);

  // -1 (the default) asks for a non-fixed seed
  if (seed_value == -1)
    seed();
  else
    manual_seed((long)seed_value);

  // read phoneme list and dictionary
  int n_phonemes;
  char** phonemes = read_phonemes(phoneme_name,&n_phonemes);
  Dictionary dict(dict_name,phonemes,n_phonemes);

  // find silence word
  silence_word = dict.findWord(silence_name);
  if (silence_word < 0)
    error("silence word %s is not in dictionary",silence_name);
  dict.silence_word = silence_word;
  // the first phoneme of the silence word is taken as the silence phoneme
  int silence_phoneme = dict.words[silence_word][0];

  // create grammar
  // node layout: 0 = initial state, 1 = initial silence,
  // 2 .. n_words = every non-silence word, n_words+1 = final silence,
  // n_words+2 = final state
  bool all_sentences_starts_with_silence = true;
  Grammar grammar(dict.n_words+3);
  grammar.words[0] = -1; // initial state
  grammar.words[1] = silence_word; // initial silence
  grammar.words[dict.n_words+1] = silence_word; // final silence
  grammar.words[dict.n_words+2] = -1; // final state
  int* gw = &grammar.words[2];
  for (int i=0;i<dict.n_words;i++) {
    if (i != silence_word)
      *gw++ = i;
  }
  // transitions[to][from]: start -> initial silence -> word -> final silence
  // -> end; word -> word is only allowed when not in isolated mode
  grammar.transitions[1][0] = true;
  for (int i=0;i<dict.n_words-1;i++) {
    grammar.transitions[i+2][1] = true;
    grammar.transitions[dict.n_words+1][i+2] = true;
    if (!isolated) {
      for (int j=0;j<dict.n_words-1;j++)
        grammar.transitions[j+2][i+2] = true;
    }
  }
  grammar.transitions[dict.n_words+2][dict.n_words+1] = true;

  // load datasets

  if (big_endian)
    setBigEndianMode();
  else if (little_endian)
    setLittleEndianMode();

  int n_train_data;
  char** train_files = read_data(train_file,&n_train_data, data_dir,extension);
  // remember how many names were allocated, for the final cleanup
  int real_n_train_data = n_train_data;
  // with several files the -load_train limit is applied to the number of
  // files; with a single file it is forwarded to the dataset instead
  if (n_train_data!=1) {
    n_train_data = max_load_train > 0 && max_load_train < n_train_data ? max_load_train : n_train_data;
    max_load_train = -1;
  }
  SeqDataSet *data;
  if (htk) {
    HtkSeqDataSet *hdata = new HtkSeqDataSet(train_files,n_train_data, max_load_train);
    hdata->setDictionary(&dict);
    hdata->setNPerFrame(125000);
    data = hdata;
  } else {
    data = new MatSeqDataSet(train_files, n_train_data, 0,-1,0,false, max_load_train);
  }
  data->init();
  data->readTargets(target_train_file);
  if (all_sentences_starts_with_silence)
    add_sil_to_targets(data,silence_word);
  if (strcmp(train_alignment,"")) {
    data->readAlignments(train_alignment);
  }
  int n_observations = data->n_observations;

  int n_cv_data = 0;
  int real_n_cv_data = 0;
  char** cv_files = NULL;
  SeqDataSet *cv_data = NULL;
  if (strcmp(cv_file,"")) {
    // keep the returned list: it is handed to the dataset below and its
    // entries are freed at the end of main
    cv_files = read_data(cv_file,&n_cv_data, data_dir,extension);
    real_n_cv_data = n_cv_data;
    if (n_cv_data!=1) {
      n_cv_data = max_load_cv > 0 && max_load_cv < n_cv_data ? max_load_cv : n_cv_data;
      max_load_cv = -1;
    }
    if (htk) {
      HtkSeqDataSet* hcv_data = new HtkSeqDataSet(cv_files,n_cv_data, max_load_cv);
      hcv_data->setDictionary(&dict);
      cv_data = hcv_data;
    } else {
      cv_data = new MatSeqDataSet(cv_files, n_cv_data, 0,-1,0,false, max_load_cv);
    }
    cv_data->init();
    cv_data->readTargets(target_cv_file);
    if (all_sentences_starts_with_silence)
      add_sil_to_targets(cv_data,silence_word);
  }

  int n_test_data;
  char** test_files = read_data(test_file,&n_test_data, data_dir,extension);
  int real_n_test_data = n_test_data;
  if (n_test_data!=1) {
    n_test_data = max_load_test > 0 && max_load_test < n_test_data ? max_load_test : n_test_data;
    max_load_test = -1;
  }
  SeqDataSet *test_data;
  if (htk) {
    HtkSeqDataSet *htest_data = new HtkSeqDataSet(test_files,n_test_data, max_load_test);
    htest_data->setDictionary(&dict);
    test_data = htest_data;
  } else {
    test_data = new MatSeqDataSet(test_files, n_test_data, 0,-1,0,false, max_load_test);
  }
  test_data->init();
  test_data->readTargets(target_test_file);
  if (all_sentences_starts_with_silence)
    add_sil_to_targets(test_data,silence_word);


  // compute the global variance of the training observations
  // and set the minimum variance (thresh) as a ratio of the global variance
  real* thresh = (real*)xalloc(n_observations*sizeof(real));
  for (int i=0;i<n_observations;i++) {
    real var_buff = 0;
    real mean_buff = 0;
    int n_frames = 0;
    for (int j=0;j<data->n_examples;j++) {
      data->setExample(j);
      SeqExample* ex = (SeqExample*)data->inputs->ptr;
      for (int k=0;k<ex->n_frames;k++) {
        data->setFrame(k);
        real z = ex->observations[data->current_frame][i];
        var_buff += z*z;
        mean_buff += z;
        n_frames++;
      }
    }
    // variance = E[z^2] - E[z]^2
    var_buff /= (real)n_frames;
    mean_buff /= (real)n_frames;
    var_buff -= mean_buff*mean_buff;
    if (var_buff <= 0) {
      warning("compute variance: column %d has a null variance. setting to 1",i);
      var_buff = 1;
    }
    thresh[i] = threshold * var_buff;
  }

  // create models for each phoneme
  // Each model is an HMM with left-right topology, where each state
  // is modelled with a DiagonalGMM which is initialized with a Kmeans
  // trained with EM.
  Kmeans*** kmeans = (Kmeans***)xalloc(sizeof(Kmeans**)*n_phonemes);
  DiagonalGMM*** gmm = (DiagonalGMM***)xalloc(sizeof(DiagonalGMM**)*n_phonemes);
  EMTrainer*** kmeans_trainer = (EMTrainer***)xalloc(sizeof(EMTrainer**)*n_phonemes);
  HMM** hmm = (HMM**)xalloc(sizeof(HMM*)*n_phonemes);
  real*** transitions = (real***)xalloc(n_phonemes*sizeof(real**));

  for (int i=0;i<n_phonemes;i++) {
    kmeans[i] = (Kmeans**)xalloc(sizeof(Kmeans*)*n_states);
    gmm[i] = (DiagonalGMM**)xalloc(sizeof(DiagonalGMM*)*n_states);
    kmeans_trainer[i] = (EMTrainer**)xalloc(sizeof(EMTrainer*)*n_states);
    // only the inner states get a distribution
    for (int j=1;j<n_states-1;j++) {
      kmeans[i][j] = new Kmeans(n_observations,n_gaussians,thresh,prior,data);
      kmeans[i][j]->init();
      kmeans_trainer[i][j] = new EMTrainer(kmeans[i][j],data);
      kmeans_trainer[i][j]->setROption("end accuracy", accuracy);
      kmeans_trainer[i][j]->setIOption("max iter", max_iter_kmeans);
      gmm[i][j] = new DiagonalGMM(n_observations,n_gaussians,thresh,prior);
      gmm[i][j]->setOption("initial kmeans trainer",&kmeans_trainer[i][j]);
      gmm[i][j]->init();
    }
    // the first and last states have no distribution
    gmm[i][0] = NULL;
    gmm[i][n_states-1] = NULL;

    // the transition table probability: left-right topology
    // (indexed transitions[i][to][from])
    transitions[i] = (real**)xalloc(n_states*sizeof(real*));
    for (int j=0;j<n_states;j++) {
      transitions[i][j] = (real*)xalloc(n_states*sizeof(real));
    }
    for (int j=0;j<n_states;j++) {
      for (int k=0;k<n_states;k++)
        transitions[i][j][k] = 0;
    }
    transitions[i][1][0] = 1;
    for (int j=1;j<n_states-1;j++) {
      transitions[i][j][j] = 0.5;
      transitions[i][j+1][j] = 0.5;
    }
    
    // the silence model is special: its first and last inner states are
    // additionally connected to each other in both directions
    if (i == silence_phoneme) {
      transitions[i][n_states-2][1] = 1./3.;
      transitions[i][1][1] = 1./3.;
      transitions[i][2][1] = 1./3.;
      transitions[i][1][n_states-2] = 1./3.;
      transitions[i][n_states-2][n_states-2] = 1./3.;
      transitions[i][n_states-1][n_states-2] = 1./3.;
    }

    hmm[i] = new HMM(n_states,(Distribution**)gmm[i],prior,data,transitions[i]);
    hmm[i]->init();
  }


  // eventually provide a trainer using alignment, which is used to initialize
  // each HMM with its own alignment when provided
  EMTrainer* model_trainer = NULL;
  if (initial_aligned_training_iter > 0) {
    if (viterbi) {
      model_trainer = (EMTrainer*)new ViterbiTrainer(hmm[0],data);
    } else {
      model_trainer = new EMTrainer(hmm[0],data);
    }
    model_trainer->setROption("end accuracy", accuracy);
    model_trainer->setIOption("max iter", initial_aligned_training_iter);
  }


  SpeechHMM shmm(n_phonemes,hmm,phonemes,&dict,&grammar,word_entrance_penalty,
    model_trainer);
  shmm.init();

  // if you just wanted to transform your saved model into an HTK model
  // NOTE(review): at this point no model has been loaded yet (the -lm
  // handling is further down), so this writes the freshly initialized
  // model - confirm that this is intended.
  // NOTE(review): exit status 1 is returned although the conversion
  // succeeded - confirm that callers expect it.
  if (save_htk) {
    char save_model_name[100];
    sprintf(save_model_name,"%s/%s",dir_name,save_model);
    save_htk_model(save_model_name,&shmm,phonemes);
    exit(1);
  }

  // these are the costs used in HTK
  shmm.edit_distance->setCosts(7,7,10);
  // these are the costs used in NIST
  //shmm.edit_distance->setCosts(33,33,40);
  // these are the basic costs 
  //shmm.edit_distance->setCosts(1,1,1);

  EMTrainer* trainer = NULL;
  if (viterbi) {
    trainer = (EMTrainer*)new ViterbiTrainer(&shmm,data);
  } else {
    trainer = new EMTrainer(&shmm,data);
  }
  trainer->setROption("end accuracy", accuracy);
  trainer->setIOption("max iter", max_iter_hmm);

  // provide a few measurers, to measure negative log likelihoods,
  // edit distance, and word segmentation for train, cv, and test data

  // measurers used during training (train and cv negative log likelihood)
  List *train_meas_hmm = NULL;
  char hmm_name[100];
  if (strcmp(load_model,""))
    sprintf(hmm_name,"%s/nll_train_load",dir_name);
  else
    sprintf(hmm_name,"%s/nll_train",dir_name);
  NllMeasurer nll_meas_hmm(shmm.outputs,data,hmm_name);
  nll_meas_hmm.init();
  addToList(&train_meas_hmm,1,&nll_meas_hmm);

  // measurers used by the final decoding (cv and test)
  List *cv_test_meas_hmm = NULL;
  char hmm_cv_name[100];
  char cv_ed_name[100];
  char cv_word_name[100];
  NllMeasurer *nll_meas_hmm_cv = NULL;
  EditDistanceMeasurer *ed_cv = NULL;
  WordSegMeasurer *word_cv = NULL;
  if (strcmp(cv_file,"")) {
    if (strcmp(load_model,""))
      sprintf(hmm_cv_name,"%s/nll_cv_load",dir_name);
    else
      sprintf(hmm_cv_name,"%s/nll_cv",dir_name);
    nll_meas_hmm_cv = new NllMeasurer(shmm.outputs,cv_data,hmm_cv_name);
    nll_meas_hmm_cv->init();
    addToList(&train_meas_hmm,1,nll_meas_hmm_cv);

    sprintf(cv_ed_name,"%s/ed_cv",dir_name);
    ed_cv = new EditDistanceMeasurer(shmm.edit_distance,cv_data,cv_ed_name,false);
    ed_cv->init();
    addToList(&cv_test_meas_hmm,1,ed_cv);

    sprintf(cv_word_name,"%s/word_cv",dir_name);
    word_cv = new WordSegMeasurer(&shmm,cv_data,cv_word_name,true);
    word_cv->init();
    addToList(&cv_test_meas_hmm,1,word_cv);
  }

  char test_ed_name[100];
  sprintf(test_ed_name,"%s/ed_test",dir_name);
  EditDistanceMeasurer ed_test(shmm.edit_distance,test_data,test_ed_name,false);
  ed_test.init();
  addToList(&cv_test_meas_hmm,1,&ed_test);

  char test_word_name[100];
  sprintf(test_word_name,"%s/word_test",dir_name);
  WordSegMeasurer word_test(&shmm,test_data,test_word_name,true);
  word_test.init();
  addToList(&cv_test_meas_hmm,1,&word_test);

  // measurers used when decoding the test set with the untrained model
  List *test_init_meas = NULL;
  char test_init_ed_name[100];
  sprintf(test_init_ed_name,"%s/ed_test_init",dir_name);
  EditDistanceMeasurer ed_test_init(shmm.edit_distance,test_data,test_init_ed_name,false);
  ed_test_init.init();
  addToList(&test_init_meas,1,&ed_test_init);

  char test_init_word_name[100];
  sprintf(test_init_word_name,"%s/word_test_init",dir_name);
  WordSegMeasurer word_test_init(&shmm,test_data,test_init_word_name);
  word_test_init.init();
  addToList(&test_init_meas,1,&word_test_init);

  // either we load the parameters of a saved model, or we initialize
  // the model, decode (to see how good is an initialized model), and train it.

  if (strcmp(load_model,"")) {
    char load_model_name[100];
    sprintf(load_model_name,"%s/%s",dir_name,load_model);
    if (htk_model) {
      // NOTE(review): the HTK path is used as given, without dir_name,
      // unlike the native format below - confirm that this is intended.
      load_htk_model(load_model,&shmm);
    } else {
      trainer->load(load_model_name);
    }
  } else {
    shmm.reset();
    if (strcmp(save_model,"")) {
      char save_model_name[100];
      sprintf(save_model_name,"%s/%s_init",dir_name,save_model);
      trainer->save(save_model_name);
    }
    trainer->decode(test_init_meas);
    trainer->train(train_meas_hmm);
  }

  // after training, we can save the model

  if (strcmp(save_model,"")) {
    char save_model_name[100];
    sprintf(save_model_name,"%s/%s",dir_name,save_model);
    trainer->save(save_model_name);
  }

  // then launch the decoding (using simple viterbi decoding)

  trainer->decode(cv_test_meas_hmm);

  // and delete everything!

  for (int j=0;j<n_phonemes;j++) {
    for (int i=1;i<n_states-1;i++) {
      delete kmeans[j][i];
      delete gmm[j][i];
      delete kmeans_trainer[j][i];
    }
    free(kmeans[j]);
    free(gmm[j]);
    free(kmeans_trainer[j]);
    delete hmm[j];
    for (int i=0;i<n_states;i++)
      free(transitions[j][i]);
    free(transitions[j]);
  }
  free(transitions);
  free(kmeans);
  free(gmm);
  free(kmeans_trainer);
  free(hmm);
  free(thresh);
  freeList(&train_meas_hmm);
  freeList(&cv_test_meas_hmm);
  freeList(&test_init_meas);
  if (strcmp(cv_file,"")) {
    delete nll_meas_hmm_cv;
    delete ed_cv;
    delete word_cv;
  }

  delete data;
  delete test_data;
  delete cv_data;
  delete trainer;

  // the real_n_* counts are the numbers of names read_data() allocated,
  // which can be larger than the numbers of files actually loaded
  for (int i=0;i<real_n_train_data;i++)
    free(train_files[i]);
  free(train_files);
  for (int i=0;i<real_n_cv_data;i++)
    free(cv_files[i]);
  free(cv_files);
  for (int i=0;i<real_n_test_data;i++)
    free(test_files[i]);
  free(test_files);
  for (int i=0;i<n_phonemes;i++)
    free(phonemes[i]);
  free(phonemes);
  return(0);
}
