/* $Id: rex.h,v 1.30 2001/10/29 23:05:50 ericp Exp $ */

#ifndef _SFSMISC_REX_H_
#define _SFSMISC_REX_H_ 1

#include "agentconn.h"

// Scratch result buffer for fire-and-forget RPCs (calls made with
// aclnt_cb_null below): the reply is written here and never inspected.
// Because it is `static` in a header, every translation unit that includes
// this file gets its own private copy.
static bool garbage_bool;

// Forward declaration: rexfd keeps a raw back-pointer to its channel.
class rexchannel;

/* Base class for one file descriptor multiplexed over a REX channel.
   The constructor (defined below) registers the object in its channel's fd
   table via rexchannel::insert_fd; the destructor sends REX_CLOSE for the
   fd.  The default virtuals implement a "null" fd that accepts and discards
   data. */
class rexfd : public virtual refcount
{
 protected:
  rexchannel *pch;      // owning channel (raw, non-refcounted back-pointer)
  ptr <aclnt> proxy;    // RPC client to the remote proxy, copied from pch
  u_int32_t channo;     // channel number, copied from pch

  int fd;               // channel-level fd number (index into pch's fd table)
  
 public:
  // these implement null fd behavior, so you'll probably want to override them
  virtual void abort ();
  virtual void data (svccb *sbp);

  // Default: refuse requests to create a new fd on this one.
  virtual void newfd (svccb *sbp) {sbp->replyref (false);}

  /* called when remote module exits */
  virtual void exited () {};
  
  rexfd (rexchannel *pch, int fd);
  virtual ~rexfd ();
};

/* rexfd backed by a local Unix file descriptor.  Data read from the local
   descriptor is forwarded to the remote side as REX_DATA payloads, and
   REX_DATA payloads received for this fd are written to the local
   descriptor.  The read and write directions are shut down independently;
   once both are down the object removes itself from its channel (see
   update_connstate). */
class unixfd : public rexfd
{
 protected:
  // Local descriptor being proxied; set to -1 once both directions are shut.
  int localfd;

 private:
  // Asynchronous I/O wrapper allocated on localfd in the constructor.
  ptr<aios> paios;

  bool weof;          // write direction (towards localfd) has been shut down
  bool reof;          // read direction (from localfd) has been shut down
  bool shutrdonexit;  // shut the read direction when the remote module exits

  cbv closecb;        // invoked from the destructor

  void update_connstate (int how, int error = 0);

  // Completion for a forwarded REX_DATA call.  If the remote side did not
  // report success, tear down both directions.  (*okp starts out false, so
  // presumably an RPC failure also lands here -- NOTE(review): confirm.)
  void datacb (ref<bool> okp, clnt_stat) {
    if (!*okp)
      update_connstate (SHUT_RDWR);
  }
  
  // readany callback for paios.  err != 0 aborts the fd; a NULL str means
  // EOF on localfd, which is signalled to the remote side with an empty
  // REX_DATA payload; otherwise the bytes are forwarded and the read re-armed.
  void rcb (const str data, int err) {
    //I think this gets called for write errors as well
    if (err) {
      abort ();
      return;
    }

    // Read side already shut down: drop anything that still arrives.
    if (reof)
      return;

    if (!data) {
      // EOF from the local descriptor: send an empty payload as EOF marker.
      rex_payload payarg;
      payarg.channel = channo;
      payarg.fd = fd;
      payarg.data.set ((char *)NULL, 0);
      
      proxy->call (REX_DATA, &payarg, &garbage_bool, aclnt_cb_null);
      
      update_connstate (SHUT_RD);
      return;
    }
    else {
      rex_payload arg;
      arg.channel = channo;
      arg.fd = fd;
      
      // NOFREE: the payload borrows data's buffer rather than owning it.
      arg.data.set (const_cast<char *> (data.cstr ()), 
		    data.len (), freemode::NOFREE);

      // The result lives in a refcounted bool so it outlives this frame;
      // mkref(this) keeps the unixfd alive until datacb runs.
      ref<bool> pres (New refcounted<bool> (false));
      proxy->call (REX_DATA, &arg, pres,
		   wrap (mkref (this), &unixfd::datacb, pres));
      
      // todo:  flow control
      paios->readany (wrap (this, &unixfd::rcb));
    }
  }

 public:

  // Remote request to create a new fd (argp->newfd) on this one: create a
  // socketpair, hand one end across localfd with paios->sendfd (presumably
  // localfd must be a Unix-domain socket for this to work -- TODO confirm),
  // and wrap the other end in a new unixfd on the same channel.
  virtual void newfd (svccb *sbp) {
    rexcb_newfd_arg *argp = sbp->template getarg<rexcb_newfd_arg> ();
    
    int s[2];
    
    if(socketpair(AF_UNIX, SOCK_STREAM, 0, s)) {
      warn << "error creating socketpair";
      sbp->replyref(false);
      return;
    }
    
    make_async (s[1]);
    make_async (s[0]);
    
    paios->sendfd (s[1]);
    
    // The new unixfd registers itself with pch in the rexfd constructor.
    vNew refcounted<unixfd> (pch, argp->newfd, s[0]);
        
    sbp->replyref (true);
  }
  
  // Incoming REX_DATA for this fd.  A non-empty payload is written to
  // localfd (replying false if the write side is already shut).  An empty
  // payload is the remote EOF marker: reply true and shut the write side
  // from the aios write callback rather than immediately.
  virtual void data (svccb *sbp) {
    assert (paios);
    
    rex_payload *argp = sbp->template getarg<rex_payload> ();
    
    if (argp->data.size () > 0) {
      if (weof) {
	sbp->replyref (false);
	return;
      }
      else {
	str data (argp->data.base (), argp->data.size ());
	paios << data;
	sbp->replyref (true);
      }
    }
    else {
      sbp->replyref (true);
      
      //we don't shutdown immediately to give data a chance to
      //asynchronously flush
      paios->setwcb (wrap (this, &unixfd::update_connstate, SHUT_WR));
    }
  }

  // Shut down both directions (removes this fd from its channel).
  virtual void abort () {
    update_connstate (SHUT_RDWR);
  }
      

  // Remote module exited: optionally shut the read direction.
  void exited () {
    if (shutrdonexit)
      update_connstate (SHUT_RD);
  }
  
  
  /* unixfd specific arguments:
       localfd: local file descriptor 
       noclose: will not use close or shutdown calls on the local file
                descriptor, useful for terminal descriptors, which must
		hang around so that raw mode can be disabled, etc.
                (implemented by working on a dup() of localfd; if dup
                fails a warning is printed and localfd itself is used)
       shutrdonexit: when the remote module exits, shutdown the read
                direction of the local file descriptor.  this isn't
                always done since not all file descriptors managed on
                the REX channel are necessarily connected to the remote
                module.
       closecb: called from the destructor.
   */
  unixfd (rexchannel *pch, int fd,
	  int localfd, bool noclose = false, bool shutrdonexit = false,
	  cbv closecb = cbv_null)
    : rexfd::rexfd (pch, fd),
      localfd (localfd),  weof (false), reof (false),
      shutrdonexit (shutrdonexit), closecb (closecb)
    {
      if (noclose) {
	int duplocalfd = dup (localfd);
	if (duplocalfd < 0)
	  warn ("failed to duplicate fd for noclose behavior (%m)\n");
	else
	  unixfd::localfd = duplocalfd;
      }
      paios = aios::alloc (unixfd::localfd);
      paios->readany (wrap (this, &unixfd::rcb));
    }

  // Flush pending output on the local descriptor, then run closecb.
  virtual ~unixfd ()
    {
      if (paios)
	paios->flush ();
      closecb ();
    }
};

class rexsession;

class rexchannel {

  vec <ptr<rexfd> >  vfds;
  int fdc;

  rexsession *sess;

 protected: 
  ptr <aclnt> proxy;
  u_int32_t channo;

  int initnfds;
  vec <str> command;
  
  friend class rexsession;

  virtual void quit () {
/*     warn << "--entering rexchannel::quit\n"; */
    rex_int_arg arg;
    arg.channel = channo;
    arg.val = 15;
    proxy->call (REX_KILL, &arg, &garbage_bool, aclnt_cb_null);
  }

  virtual void abort ();

  virtual void madechannel (int error) {};

  void
    channelinit (u_int32_t chnumber, ref <aclnt> proxyaclnt, int error)
    {
      proxy = proxyaclnt;
      channo = chnumber;
      madechannel (error);
    }
  
  virtual void data(svccb *sbp) {
    assert (sbp->prog () == REXCB_PROG && sbp->proc () == REXCB_DATA);
    rex_payload *dp = sbp->template getarg<rex_payload> ();
    assert (dp->channel == channo);
    if (dp->fd < 0 ||
	implicit_cast<size_t> (dp->fd) >= vfds.size () ||
	!vfds[dp->fd]) {
      warn ("payload fd %d out of range\ndata:%s\n", dp->fd,
	    dp->data.base ());
      sbp->replyref (false);
      return;
    }

    vfds[dp->fd]->data(sbp);
  }

  virtual void newfd (svccb *sbp) {
    assert (sbp->prog () == REXCB_PROG && sbp->proc () == REXCB_NEWFD);
    rexcb_newfd_arg *arg = sbp->template getarg<rexcb_newfd_arg> ();

    int fd = arg->fd;

    if (fd < 0 || implicit_cast<size_t> (fd) >= vfds.size () || !vfds[fd]) {
      warn ("newfd received on invalid fd %d at rexchannel::newfd\n", fd);
      sbp->replyref (false);
      return;
    }
      
    vfds[fd]->newfd (sbp);
  }

  virtual void exited () {
    for (size_t ix = 0; ix < vfds.size();  ix++) {
      if (!vfds[ix]) continue;
      vfds[ix]->exited();
    } 
  }

 public:

  void insert_fd (int fdn, ref<rexfd> rfd) {
    assert (fdn >= 0);

/*     warn << "--reached insert_fd\n"; */
    size_t oldsize = vfds.size ();
    size_t neededsize = fdn + 1;
    
    if (neededsize > oldsize) {
      vfds.setsize (neededsize);
      for (int ix = oldsize; implicit_cast <size_t> (ix) < neededsize; ix++)
	  vfds[ix] = NULL;
    }
    
    if (vfds[fdn]) {
      warn ("creating fd on busy fd %d at rexfd::rexfd, overwriting\n", fdn);
      assert (false);
    }
    
    vfds [fdn] = rfd;
    fdc++;
    
  }
  
  void remove_fd (int fdn);  

  int        get_initnfds () { return initnfds; }
  vec<str>   get_cmd      () { return command; }
  u_int32_t  get_channo   () { return channo; }
  ptr<aclnt> get_proxy    () { return proxy; }
      
  rexchannel (rexsession *sess, int initialfdcount, vec <str> command)
    : fdc (0), sess (sess), initnfds (initialfdcount),
      command(command) {
/*     warn << "--reached rexchannel: fdc = " << fdc << "\n"; */
  }

  virtual ~rexchannel () {
/*     warn << "--reached ~rexchannel\n"; */
  }
  
};


/* Client side of a REX session: one encrypted connection to a remote REX
   proxy over which any number of rexchannels are multiplexed.  Outgoing
   calls use `proxy` (rex_prog_1); callbacks from the proxy (rexcb_prog_1)
   arrive at rexcb_dispatch and are routed to the channel named in their
   argument. */
class rexsession {
  
  ptr<axprt_crypt> proxyxprt;
  ptr<asrv> rexserv;

  //todo : make this non-refcounted pointer
  // Open channels keyed by channel number.
  qhash<u_int32_t, ref <rexchannel> > channels;
  // Channels created successfully and not yet removed.
  int cchan;

  // Run when the last channel is removed or the callback stream fails.
  callback<void>::ptr endcb;
  
 public:
  ptr<aclnt> proxy;

 private:
  /* Dispatcher for REXCB_PROG calls from the proxy.  A NULL sbp is treated
     as a transport error: warn and run endcb. */
  void
    rexcb_dispatch (svccb *sbp)
    {
      if (!sbp) {
	warn << "rexcb_dispatch: error\n";
	if (endcb) endcb ();
	return;
      }
      
      switch (sbp->proc ()) {
	
      case REXCB_NULL:
	sbp->reply (NULL);
	break;
	
      case REXCB_EXIT:
	{
	  rex_int_arg *argp = sbp->template getarg<rex_int_arg> ();
	  rexchannel *chan = channels[argp->channel];

	  // Every call must be answered, otherwise the svccb leaks and the
	  // proxy never sees a reply.  Reply first (argp is owned by sbp and
	  // is not used afterwards), then notify the channel.
	  sbp->reply (NULL);
	  
	  if(chan)
	    chan->exited();
	  break;
	}
	
      case REXCB_DATA:
	{
	  rex_payload *argp = sbp->template getarg<rex_payload> ();
	  rexchannel *chan = channels[argp->channel];
	  
	  if(chan)
	    chan->data(sbp);
	  else
	    sbp->replyref (false);
	  break;
	}
	
      case REXCB_NEWFD:
	{
	  // Decode with the same type rexchannel::newfd and unixfd::newfd use.
	  rexcb_newfd_arg *argp = sbp->template getarg<rexcb_newfd_arg> ();
	  rexchannel *chan = channels[argp->channel];
	  if(chan)
	    chan->newfd(sbp);
	  else
	    sbp->replyref(false);
	  break;
	}
    
      default:
	sbp->reject (PROC_UNAVAIL);
	break;
      }
    }
  
  
  // Completion for REX_MKCHANNEL: on success register the channel under the
  // number the proxy assigned; in every case run channelinit so the channel
  // learns the outcome (error flag 1 on failure, 0 on success).
  void madechannel (ref<rex_mkchannel_res> resp, ref<rexchannel> newchan,
		    clnt_stat err) {
    if (err) {
      warn << "REX_MKCHANNEL RPC failed (" << err << ")\n";
      newchan->channelinit (0, proxy, 1);
    }
    else if (resp->err != SFS_OK) {
      warn << "REX_MKCHANNEL failed (" << int (resp->err) << ")\n";
      newchan->channelinit (0, proxy, 1);
    }
    else {
      cchan++;

      warn << "made channel: ";
      vec<str> command = newchan->get_cmd ();
      for (size_t i = 0; i < command.size (); i++)
	warnx << command[i] << " ";
      warnx << "\n";

      channels.insert (resp->resok->channel, newchan);
      newchan->channelinit (resp->resok->channel, proxy, 0);
    }
  }
    
  // Derive session keys for sequence number seqno: SHA-1 of the kcs/ksc key
  // data (with seqno filled in).  Optionally stores the hash of the
  // resulting sessinfo in *sidp and a copy of it in *sip; the local key
  // copies are zeroed before returning.
  void seq2sessinfo (u_int64_t seqno, sfs_hash *sidp, sfs_sessinfo *sip,
		     rex_sesskeydat *kcsdat, rex_sesskeydat *kscdat) {
    kcsdat->seqno = seqno;
    kscdat->seqno = seqno;
    
    sfs_sessinfo si;
    si.type = SFS_SESSINFO;
    si.kcs.setsize (sha1::hashsize);
    sha1_hashxdr (si.kcs.base (), *kcsdat, true);
    si.ksc.setsize (sha1::hashsize);
    sha1_hashxdr (si.ksc.base (), *kscdat, true);
    
    if (sidp)
      sha1_hashxdr (sidp->base (), si, true);
    if (sip)
      *sip = si;
    
    bzero (si.kcs.base (), si.kcs.size ());
    bzero (si.ksc.base (), si.ksc.size ());
  }

  // Completion for REXD_ATTACH.  Any failure is fatal.  On success the
  // connection is re-wrapped as an encrypted transport keyed from
  // *sessinfo (which is then wiped and freed), the proxy client and
  // callback server are created, and sessioncreatedcb runs.
  void attached (rexd_attach_res *resp, ptr<axprt_crypt> sessxprt,
		 sfs_sessinfo *sessinfo, cbv sessioncreatedcb, clnt_stat err) {

    if (err) {
      fatal << "FAILED (" << err << ")\n";
    }
    else if (*resp != SFS_OK) {
      // XXX
      fatal << "FAILED (attach err " << int (*resp) << ")\n";
    }
    delete resp;
    warn << "attached\n";
    
    proxyxprt = axprt_crypt::alloc (sessxprt->reclaim ());
    proxyxprt->encrypt (sessinfo->kcs.base (), sessinfo->kcs.size (),
			sessinfo->ksc.base (), sessinfo->ksc.size ());
    
    bzero (sessinfo->kcs.base (), sessinfo->kcs.size ());
    bzero (sessinfo->ksc.base (), sessinfo->ksc.size ());
    delete sessinfo;
    
    proxy = aclnt::alloc (proxyxprt, rex_prog_1);
    rexserv = asrv::alloc (proxyxprt, rexcb_prog_1,
			   wrap (this, &rexsession::rexcb_dispatch));
    
    sessioncreatedcb ();
  }

  

  // sfs_connect_path completion: fatal if the connection failed; otherwise
  // issue REXD_ATTACH with the session ids derived for seqno 0 and for
  // *rexseqno.  Takes ownership of (and frees) kcsdat, kscdat and rexseqno.
  void connected (rex_sesskeydat *kcsdat, rex_sesskeydat *kscdat,
		  sfs_seqno *rexseqno, cbv sessioncreatedcb, ptr<sfscon> sc,
		  str err) {
    if (!sc) {
      fatal << schost << ": FAILED (" << err << ")\n";
    }
    
    ptr <axprt_crypt> sessxprt = sc->x;
    ptr <aclnt> sessclnt = aclnt::alloc (sessxprt, rexd_prog_1);
    
    rexd_attach_arg arg;
    
    arg.seqno = *rexseqno;
    sfs_sessinfo *sessinfo = New sfs_sessinfo;
    
    seq2sessinfo (0, &arg.sessid, NULL, kcsdat, kscdat);
    seq2sessinfo (arg.seqno, &arg.newsessid, sessinfo, kcsdat, kscdat);
    
    //ECP comment: why doesn't agent just give us sessid,newsessid,sessinfo??
    
    rexd_attach_res *resp = New rexd_attach_res;
    sessclnt->call (REXD_ATTACH, &arg, resp, wrap (this,
						   &rexsession::attached,
						   resp, sessxprt, sessinfo,
						   sessioncreatedcb));
    
    delete kcsdat;
    delete kscdat;
    delete rexseqno;
  }

  // qhash traverse helpers: forward quit/abort to one channel.
  void quitcaller (const u_int32_t &chno, ptr<rexchannel> pchan) {
    pchan->quit ();
  }

  void abortcaller (const u_int32_t &chno, ptr<rexchannel> pchan) {
    pchan->abort ();
  }

 public:
  str schost;

  //use this one if you already have an encrypted transport connected to proxy
  rexsession (str schostname, ptr<axprt_crypt> proxyxprt)
    : proxyxprt (proxyxprt), cchan (0), endcb (NULL), schost (schostname)
    {
      proxy = aclnt::alloc (proxyxprt, rex_prog_1);
      rexserv = asrv::alloc (proxyxprt, rexcb_prog_1,
			     wrap (this, &rexsession::rexcb_dispatch));
    }

  
      
  // Obtain session key material from the local agent, then connect to
  // schostname; sessioncreatedcb runs once the session is attached.
  // Failure to reach the agent is fatal.
  rexsession (cbv sessioncreatedcb, str schostname, bool forwardagent)
    : cchan (0), endcb (NULL), schost(schostname)
    {
      ref <agentconn> aconn = New refcounted<agentconn> ();
      ptr<sfsagent_rex_res> ares = aconn->rex (schost, forwardagent);
      if (!ares || !ares->status)
	fatal << "could not connect to agent\n";
  
      rex_sesskeydat *kscdat = New rex_sesskeydat;
      rex_sesskeydat *kcsdat = New rex_sesskeydat;
      sfs_seqno *rexseqno = New sfs_seqno;
      	    
      kcsdat->type = SFS_KCS;
      kcsdat->cshare = ares->resok->kcs.kcs_share;
      kcsdat->sshare = ares->resok->kcs.ksc_share;
      kscdat->type = SFS_KSC;
      kscdat->cshare = ares->resok->ksc.kcs_share;
      kscdat->sshare = ares->resok->ksc.ksc_share;
      *rexseqno = ares->resok->seqno;
      sfs_connect_path (schostname, SFS_REX,
			wrap (this, &rexsession::connected, kcsdat, kscdat,
			      rexseqno, sessioncreatedcb),
			false);
    }

  ~rexsession ()
    {
    }

  // get's called when all channels close or we get EOF from proxy
  void setendcb (cbv endcb) { rexsession::endcb = endcb; }

  // Ask the proxy to create a channel running newchan's command with
  // newchan's initial fd count and the given environment.
  void makechannel (ref<rexchannel> newchan, rex_env env = rex_env ()) {
    rex_mkchannel_arg arg;

    vec<str> command = newchan->get_cmd ();
    arg.av.setsize (command.size ());
    for (size_t i = 0; i < command.size (); i++)
      arg.av[i] = command[i];
    arg.nfds = newchan->get_initnfds ();
    arg.env = env;
    
    ref<rex_mkchannel_res> resp = New refcounted<rex_mkchannel_res> ();
    proxy->call (REX_MKCHANNEL, &arg, resp, wrap (this,
						  &rexsession::madechannel,
						  resp, newchan));
  }

  // Forget channel channo; when it was the last one, run endcb (if set).
  void remove_chan (int channo) {
    channels.remove (channo);
    if (!--cchan) {
      if (endcb)
	endcb ();
    }
  }

  // informs all channels that client wants to quit.  default
  // behavior is to kill remote module.
  void quit () { channels.traverse (wrap (this, &rexsession::quitcaller)); }

  // calls the abort member function of every channel, which should blow
  // all the channels away
  void abort () {
    endcb = NULL;
    channels.traverse (wrap (this, &rexsession::abortcaller));
  }
};


// Destroying a rexfd tells the remote side that the fd is closed.  This is
// fire-and-forget: the reply lands in garbage_bool and is ignored.
inline
rexfd::~rexfd ()
{
  rex_int_arg closereq;
  closereq.val = fd;
  closereq.channel = channo;
  proxy->call (REX_CLOSE, &closereq, &garbage_bool, aclnt_cb_null);
}

// Bind the fd to its channel, cache the channel's proxy and number, and
// register the new object in the channel's fd table.  A negative fd number
// is fatal, because fd numbers index that table.
inline
rexfd::rexfd (rexchannel *pch, int fd)
  : pch (pch),
    proxy (pch->get_proxy ()),
    channo (pch->get_channo ()),
    fd (fd)
{
  if (fd < 0)
    fatal ("attempt to create negative fd: %d\n", fd);

  ref<rexfd> self = mkref (this);
  pch->insert_fd (fd, self);
}

inline void
rexfd::abort ()
{
  rex_payload payarg;
  payarg.channel = channo;
  payarg.fd = fd;
  payarg.data.set ((char *)NULL, 0);
  proxy->call (REX_DATA, &payarg, &garbage_bool, aclnt_cb_null);
  
  pch->remove_fd (fd); 
}

// Null-fd data handler: non-empty payloads are silently discarded.  An
// empty payload (remote EOF) is echoed back as EOF, and the fd is removed
// from its channel.  The call is always answered with true.
inline void
rexfd::data (svccb *sbp)
{
  rex_payload *incoming = sbp->template getarg<rex_payload> ();
  bool remote_eof = (incoming->data.size () == 0);

  if (remote_eof) {
    rex_payload eofmsg;
    eofmsg.fd = fd;
    eofmsg.channel = channo;
    eofmsg.data.set (static_cast<char *> (NULL), 0);
    proxy->call (REX_DATA, &eofmsg, &garbage_bool, aclnt_cb_null);

    // May destroy *this; only sbp is used afterwards.
    pch->remove_fd (fd);
  }

  sbp->replyref (true);
}

// Record that one direction (SHUT_RD / SHUT_WR) or both (anything else) of
// the local descriptor is finished.  Shutting the write side also sends EOF
// to the local peer.  Once both directions are down, the fd is detached
// from its channel, which may destroy *this.  The error argument is unused.
inline void
unixfd::update_connstate (int how, int)
{
  // Already fully shut down: ignore late notifications.
  if (localfd < 0)
    return;

  switch (how) {
  case SHUT_WR:
    weof = true;
    paios->sendeof ();
    break;
  case SHUT_RD:
    reof = true;
    break;
  default:
    weof = reof = true;
    break;
  }

  if (reof && weof) {
    localfd = -1;
    pch->remove_fd (fd);
  }
}

void
rexchannel::remove_fd (int fdn)
{
/*   warn << "--reached remove_fd (" << fdn << "), fdc = " << fdc << "\n"; */
  vfds[fdn] = NULL;
  if (!--fdc)
    sess->remove_chan (channo); 
}

// Abort every live fd on this channel.  Each abort is expected to end in
// remove_fd, which only NULLs slots and never shrinks vfds, so caching the
// size up front is safe.  Marked `inline` because this definition lives in
// a header (as do the other out-of-line definitions here).
// NOTE(review): aborting the last fd can reach rexsession::remove_chan; the
// caller must keep this channel referenced for the duration of the loop.
inline void
rexchannel::abort () {
  size_t lvfds = vfds.size ();
  for (size_t f = 0; f < lvfds; f++)
    if (vfds[f])
      vfds[f]->abort ();
}

#endif /* _SFSMISC_REX_H_ */
