/*+++++++++++++++++
  connect.c - read/write replacements with timeouts
  markus@mhoenicka.de 2-10-00

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   You should have received a copy of the GNU General Public License
   along with this program; if not, see <http://www.gnu.org/licenses/>

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/


/* Parts of this file are inspired considerably by a similar file in
   Wget (Copyright (C) 1995, 1996, 1997 Free Software Foundation,
   Inc.) */


#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>

#include <errno.h>
#include <string.h>

#include "connect.h"
#include "refdb.h"
#include "refstat.h"

/* Globals defined in another translation unit and consulted by the
   socket helpers in this file. */
extern int n_refdb_timeout; /* timeout in seconds for read/write on sockets; if 0, the helpers skip the select() wait */
extern int n_abort_connect; /* if 1, user tries to abort current connection; the helpers reset it to 0 when they act on it */

/* A run of NUL bytes. TERM_LEN bytes of it are sent by tiwrite() and
   send_status(), and tread() compares the tail of the received data
   against it to detect the end of a message. */
const char cs_term[5] = {'\0', '\0', '\0', '\0', '\0'}; /* client-server message terminator */

/* forward declarations of local functions */
static int select_fd (int fd, int maxtime, int writep);
static int get_numstatus(const char* status);


/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  select_fd(): Wait for file descriptor FD to be readable, MAXTIME
  being the timeout in seconds.  If WRITEP is non-zero, checks for FD
  being writable instead.  FD is also watched for exceptional
  conditions in either mode.

  static int select_fd returns the result of select(): a value > 0 if
                       FD is accessible, 0 for timeout and -1 for
                       error in select() (errno is set). Also returns
                       -1, with errno set to EBADF and without
                       calling select(), if FD is negative or too
                       large to be stored in an fd_set.

  int fd the file descriptor of a socket

  int maxtime the time in seconds that the function will wait before
      a timeout occurs

  int writep if non-zero, checks for fd being writable. if 0, checks
      for fd being readable

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
static int select_fd (int fd, int maxtime, int writep)
{
  fd_set fds, exceptfds;
  struct timeval timeout;

  /* FD_SET() has undefined behaviour for descriptors outside
     0..FD_SETSIZE-1, so reject them up front. The callers only retry
     on EINTR, hence EBADF is reported as a plain error. */
  if (fd < 0 || fd >= FD_SETSIZE) {
    errno = EBADF;
    return -1;
  }

  FD_ZERO (&fds);
  FD_SET (fd, &fds);
  FD_ZERO (&exceptfds);
  FD_SET (fd, &exceptfds);
  timeout.tv_sec = maxtime;
  timeout.tv_usec = 0;

  /* The same set is passed either as the read set or as the write
     set, depending on WRITEP; the other slot is left empty.
     HPUX reportedly warns here.  What is the correct incantation?  */
  return select (fd + 1, writep ? NULL : &fds, writep ? &fds : NULL,
                 &exceptfds, &timeout);
}


/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  iread(): Read up to LEN bytes from FD into BUF, calling read()
  repeatedly until LEN bytes have arrived. An interrupted read()
  (EINTR) is retried. If n_refdb_timeout is non-zero, every read() is
  preceded by a select() wait of at most n_refdb_timeout seconds.

  int iread the number of bytes read from fd. This is less than LEN
            if read() reports end-of-file or fails with an error
            other than EINTR. Returns -1 if the wait times out
            (errno is then set to ETIMEDOUT), if select() fails, or
            if n_abort_connect is set (it is reset to 0 in that case)

  int fd the file descriptor to read from

  char *buf a pointer to a character buffer which receives the data

  int len the number of bytes that iread will attempt to read from fd

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int iread (int fd, char *buf, int len) {
  int total = 0; /* bytes stored in BUF so far */

  while (len > 0) {
    int rc;

    if (n_refdb_timeout) {
      /* wait for data; restart the wait if a signal interrupted it,
         unless the user asked to abort */
      do {
        rc = select_fd (fd, n_refdb_timeout, 0);
      } while (rc == -1 && errno == EINTR && !n_abort_connect);

      if (n_abort_connect) {
        n_abort_connect = 0;
        return -1;
      }
      if (rc == 0) {
        /* select() does not set errno on a timeout, so do it here */
        errno = ETIMEDOUT;
        return -1;
      }
      if (rc < 0) {
        return -1;
      }
    }

    rc = read(fd, buf, len);

    if (rc == -1 && errno == EINTR) {
      /* interrupted: nothing was consumed, start over with the wait */
      continue;
    }
    if (rc <= 0) {
      /* end-of-file or hard error: report what we have so far */
      break;
    }

    total += rc;
    buf += rc;
    len -= rc;
  }

  return total;
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  tread(): Read up to LEN bytes from FD into BUF. This works like
  iread(), but after each successful read() it checks whether the
  data received so far ends with the message terminator (TERM_LEN
  bytes of cs_term). If so, the function returns immediately instead
  of waiting for more data, even if fewer bytes than requested were
  received.

  int tread the number of bytes read from fd. This is less than LEN
            if the terminator was seen, if read() reports
            end-of-file, or if read() fails with an error other than
            EINTR. Returns -1 if the wait times out (errno is then
            set to ETIMEDOUT), if select() fails, or if
            n_abort_connect is set (it is reset to 0 in that case)

  int fd the file descriptor to read from

  char *buf a pointer to a character buffer which receives the data

  int len the number of bytes that tread will attempt to read from fd

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int tread (int fd, char *buf, int len) {
  int total = 0; /* bytes stored in the caller's buffer so far */

  while (len > 0) {
    int rc;

    if (n_refdb_timeout) {
      /* wait for data; restart the wait if a signal interrupted it,
         unless the user asked to abort */
      do {
        rc = select_fd (fd, n_refdb_timeout, 0);
      } while (rc == -1 && errno == EINTR && !n_abort_connect);

      if (n_abort_connect) {
        n_abort_connect = 0;
        return -1;
      }
      if (rc == 0) {
        /* select() does not set errno on a timeout, so do it here */
        errno = ETIMEDOUT;
        return -1;
      }
      if (rc < 0) {
        return -1;
      }
    }

    rc = read(fd, buf, len);

    if (rc == -1 && errno == EINTR) {
      /* interrupted: nothing was consumed, start over with the wait */
      continue;
    }
    if (rc <= 0) {
      /* end-of-file or hard error: report what we have so far */
      break;
    }

    total += rc;

    /* BUF still points at the start of the chunk just received, so
       buf + rc is one past the last byte of all data read so far.
       The length test keeps the comparison inside the buffer. */
    if (total >= TERM_LEN
        && !memcmp(buf + rc - TERM_LEN, cs_term, TERM_LEN)) {
      return total; /* complete message, get back w/o timeout */
    }

    buf += rc;
    len -= rc;
  }

  return total;
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  iwrite(): Write LEN bytes from BUF to FD.  This is similar to
  iread().  It keeps calling write() until all of BUF has been
  written to FD, so callers needn't bother with checking that the
  return value equals to LEN.  Instead, you should simply check
  for -1. If n_refdb_timeout is non-zero, every write() is preceded
  by a select() wait of at most n_refdb_timeout seconds.

  int iwrite the number of bytes actually written to fd, or -1 if
             the wait times out (errno is then set to ETIMEDOUT), if
             select() or write() fails, or if n_abort_connect is set
             (it is reset to 0 in that case)

  int fd the file descriptor to write to

  const char *buf a pointer to a character buffer which holds the data

  int len the number of bytes that iwrite will attempt to write to fd

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int iwrite (int fd, const char *buf, int len) {
  int res = 0;
  int n_byte_written = 0;

  /* `write' may write less than LEN bytes, thus the outward loop
     keeps trying it until all was written, or an error occurred.  The
     inner loop is reserved for retrying after EINTR, and the
     innermost loop deals with the same during select().  */
  while (len > 0) {
    do {
      if (n_refdb_timeout) {
        do {
          res = select_fd (fd, n_refdb_timeout, 1);
        } while (res == -1 && errno == EINTR && !n_abort_connect);
        if (n_abort_connect) {
          n_abort_connect = 0;
          return -1;
        }
        if (res <= 0) {
          /* Set errno to ETIMEDOUT on timeout.  */
          if (res == 0)
            /* #### Potentially evil!  */
            errno = ETIMEDOUT;
          return -1;
        }
      }
      res = write (fd, buf, len); /* write some data */
    } while (res == -1 && errno == EINTR);

    if (res < 0) {
      /* hard write error: report it as such, regardless of how much
         was written before. The -1 result must never be added to
         the byte count. */
      return -1;
    }
    if (res == 0) {
      /* nothing was accepted; stop instead of spinning */
      break;
    }

    n_byte_written += res;
    buf += res;
    len -= res;
  }

  return n_byte_written;
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  tiwrite(): Write the NUL-terminated string BUF to FD and append the
             message terminator if requested. The number of bytes to
             send is taken from strlen(BUF); the string's own
             terminating NUL is not sent. Both parts are sent with
             iwrite().

  int tiwrite the total number of bytes written to fd (string plus
              terminator, if any), or -1 if either iwrite() call
              returned -1

  int fd the file descriptor to write to

  const char *buf a pointer to the NUL-terminated string to send

  int n_term if non-zero, send TERM_LEN bytes of cs_term after
             sending the buffer

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int tiwrite (int fd, const char *buf, int n_term) {
  int total = 0;
  size_t payload_len = strlen(buf);

  /* an empty string sends nothing, but may still be terminated */
  if (payload_len != 0) {
    int sent = iwrite(fd, buf, payload_len);

    if (sent == -1) {
      return -1;
    }
    total = sent;
  }

  if (n_term) {
    int sent = iwrite(fd, cs_term, TERM_LEN);

    if (sent == -1) {
      return -1;
    }
    total += sent;
  }

  return total;
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  get_trailz(): counts the trailing \0 in a buffer

  int get_trailz the number of consecutive \0 bytes found when
                 walking from buf[numbyte-1] towards buf[0], or -1
                 if buf is NULL. Returns 0 if numbyte is 0 or
                 negative.

  const char *buf a pointer to a buffer

  int numbyte offset at which the function should check the buffer
                (going from offset towards the start of the buffer)

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int get_trailz(const char* buf, int numbyte) {
  int numz = 0;
  int i = numbyte-1;

  if (!buf) {
    return -1;
  }

  /* The index must be tested before the byte is read, otherwise
     buf[-1] is accessed once the start of the buffer is passed. */
  while (i >= 0 && !buf[i]) {
    numz++;
    i--;
  }
  return numz;
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  send_status(): sends the status bytes that precede each message.
  The first STATUS_LEN bytes of the string returned by
  get_status_string() are written to FD, optionally followed by the
  message terminator.

  int send_status returns 0 if ok, 1 if either iwrite() call
                  returned -1

  int fd file descriptor

  int n_status status code

  int n_term if non-zero, send TERM_LEN bytes of cs_term after the
             status bytes

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int send_status(int fd, int n_status, int n_term) {
  const char *status_str = get_status_string(n_status);

  if (iwrite(fd, status_str, STATUS_LEN) == -1) {
    return 1;
  }

  /* the terminator is attempted only when requested */
  if (n_term && iwrite(fd, cs_term, TERM_LEN) == -1) {
    return 1;
  }

  return 0;
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  read_status(): reads the status bytes that precede each message
  and converts them to the numeric status code

  int read_status returns the numeric status looked up by
                  get_numstatus() for the STATUS_LEN bytes received.
                  Returns 1 if iread() did not deliver exactly
                  STATUS_LEN bytes; get_numstatus() also yields 1
                  for a status string it does not know.

  int fd file descriptor

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
int read_status(int fd) {
  /* one extra byte so the received status is always NUL-terminated;
     the size and the terminator position both follow STATUS_LEN
     instead of being hard-coded */
  char status[STATUS_LEN+1] = {0};

  if (iread(fd, status, STATUS_LEN) != STATUS_LEN) {
    return 1;
  }

  status[STATUS_LEN] = '\0';

  return get_numstatus(status);
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  get_numstatus(): retrieves the numeric status of a status string.
  The refstat table is searched up to, but not including, the entry
  whose n_status is 999.

  int get_numstatus returns the n_status of the first table entry
                    whose status string matches in its first
                    STATUS_LEN characters, or 1 (unspecified error)
                    if no entry matches

  const char* status ptr to the status string

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
static int get_numstatus(const char* status) {
  int idx = 0;

  while (refstat[idx].n_status != 999) {
    if (strncmp(refstat[idx].status, status, STATUS_LEN) == 0) {
      return refstat[idx].n_status;
    }
    idx++;
  }

  /* unknown status message - should never happen */
  return 1; /* unspecified error */
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  get_status_string(): retrieves the status string of a given status
  number. The refstat table is searched up to, but not including,
  the entry whose n_status is 999.

  const char* get_status_string returns the status string of the
                                first matching table entry, or "999"
                                if no entry matches

  int n_status status number

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
const char* get_status_string(int n_status) {
  int idx = 0;

  while (refstat[idx].n_status != 999) {
    if (refstat[idx].n_status == n_status) {
      return refstat[idx].status;
    }
    idx++;
  }

  /* unknown status number - should never happen */
  return "999";
}

/*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  get_status_msg(): retrieves the status message of a given status
  number. The refstat table is searched up to, but not including,
  the entry whose n_status is 999.

  const char* get_status_msg returns the msg of the first matching
                             table entry, or "undefined error" if
                             no entry matches

  int n_status status number

  ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*/
const char* get_status_msg(int n_status) {
  int idx;

  for (idx = 0; ; idx++) {
    if (refstat[idx].n_status == 999) {
      /* reached the end marker without a match */
      break;
    }
    if (refstat[idx].n_status == n_status) {
      return refstat[idx].msg;
    }
  }

  /* unknown status number - should never happen */
  return "undefined error";
}

