#include <sys/param.h>
#include <string.h>

#include <asdutil.h>
#include <auth.h>

#include <asd.h>
#include <block.h>
#include <namespace.h>
#include <sample-convert.h>
#include <conjunction.h>
#include <latency.h>

#include "protocol-asd-impl.h"
#include "protocol-asd.h"
#include "idgen.h"

#include "protocol-esound.h"

#include "../sources/source-socket.h"
#include "../sinks/sink-socket.h"

/* Request handlers, dispatched through protocol_asd_handler_map below. */
gint protocol_asd_AUTHENTICATE(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_SERVER_VERSION(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_SERVER_INFO(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_STREAM_PLAY(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_LOCK(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_VOLUME_GET(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_VOLUME_SET(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_INFO_SOURCE(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_INFO_SINK(ProtocolAsdClient *client, ProtocolAsdRequest *request);
gint protocol_asd_LIST_SOURCES(ProtocolAsdClient *client, ProtocolAsdRequest *request);

/*
 * Dispatch table for incoming requests, searched by protocol_asd_find_handler().
 *
 * Columns, in order:
 *   - command code;
 *   - size of the fixed payload the handler reads after the request header
 *     (0 when there is none);
 *   - two flags. NOTE(review): the first looks like an "entry valid" flag and
 *     the second like "requires prior authentication" (it is FALSE only for
 *     AUTHENTICATE and SERVER_VERSION). Confirm against ProtocolAsdHandler
 *     in protocol-asd-impl.h;
 *   - handler function;
 *   - human-readable description.
 *
 * Several commands share one handler, which tells them apart through
 * request->command. The table ends with an entry whose proc is NULL.
 */
ProtocolAsdHandler protocol_asd_handler_map[] = {
    { PROTOCOL_ASD_COMMAND_AUTHENTICATE,   PROTOCOL_ASD_COOKIE_SIZE,       TRUE, 
      FALSE, protocol_asd_AUTHENTICATE,    "Authenticate"},
    { PROTOCOL_ASD_COMMAND_SERVER_VERSION, 0,                              TRUE, 
      FALSE, protocol_asd_SERVER_VERSION,  "Get server version"},
    { PROTOCOL_ASD_COMMAND_SERVER_INFO,    0,                              TRUE, 
      TRUE, protocol_asd_SERVER_INFO,      "Get server information"},

    { PROTOCOL_ASD_COMMAND_STREAM_PLAY,    sizeof(ProtocolAsdStreamQuery), TRUE,
      TRUE, protocol_asd_STREAM_PLAY,      "Play stream"},
    { PROTOCOL_ASD_COMMAND_STREAM_CAPTURE, sizeof(ProtocolAsdStreamQuery), TRUE, 
      TRUE, protocol_asd_STREAM_PLAY,      "Capture stream"},
    { PROTOCOL_ASD_COMMAND_STREAM_MONITOR, sizeof(ProtocolAsdStreamQuery), TRUE, 
      TRUE, protocol_asd_STREAM_PLAY,      "Monitor stream"},
    { PROTOCOL_ASD_COMMAND_STREAM_SOURCE,  sizeof(ProtocolAsdStreamQuery), TRUE, 
      TRUE, protocol_asd_STREAM_PLAY,      "Source stream"},
    { PROTOCOL_ASD_COMMAND_STREAM_SINK,    sizeof(ProtocolAsdStreamQuery), TRUE, 
      TRUE, protocol_asd_STREAM_PLAY,      "Sink stream"},

    { PROTOCOL_ASD_COMMAND_LOCK,           0,                              TRUE, 
      TRUE, protocol_asd_LOCK,             "Lock"},
    { PROTOCOL_ASD_COMMAND_UNLOCK,         0,                              TRUE, 
      TRUE, protocol_asd_LOCK,             "Unlock"},

    { PROTOCOL_ASD_COMMAND_VOLUME_GET,     sizeof(ProtocolAsdVolumeGetQuery), TRUE, 
      TRUE, protocol_asd_VOLUME_GET,       "Get volume"},
    { PROTOCOL_ASD_COMMAND_VOLUME_SET,     sizeof(ProtocolAsdVolumeSetQuery), TRUE, 
      TRUE, protocol_asd_VOLUME_SET,       "Set volume"},

    { PROTOCOL_ASD_COMMAND_INFO_SOURCE,    sizeof(ProtocolAsdInfoQuery),   TRUE, 
      TRUE, protocol_asd_INFO_SOURCE,      "Get source info"},
    { PROTOCOL_ASD_COMMAND_INFO_SINK,      sizeof(ProtocolAsdInfoQuery),   TRUE, 
      TRUE, protocol_asd_INFO_SINK,        "Get sink info"},
    /* LIST_SOURCES and LIST_SINKS share a handler; the descriptions must
       match the command, not the handler name. */
    { PROTOCOL_ASD_COMMAND_LIST_SOURCES,   0,                              TRUE, 
      TRUE, protocol_asd_LIST_SOURCES,     "List sources"},
    { PROTOCOL_ASD_COMMAND_LIST_SINKS,     0,                              TRUE, 
      TRUE, protocol_asd_LIST_SOURCES,     "List sinks"},

    /* Terminator: protocol_asd_find_handler() stops at proc == NULL. */
    { 0xFFFF,                              0,                              FALSE,
      FALSE, NULL,                         NULL}
};

/*
 * Look up the dispatch entry for a protocol command.
 *
 * Walks protocol_asd_handler_map up to its terminating entry (the first
 * entry whose proc is NULL). Returns a pointer into the table for the first
 * entry matching `command`, or NULL if the command is unknown.
 */
ProtocolAsdHandler* protocol_asd_find_handler(guint16 command)
{
    ProtocolAsdHandler *entry = protocol_asd_handler_map;

    while (entry->proc != NULL)
      {
        if (entry->command == command)
          return entry;
        entry++;
      }

    return NULL;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_AUTHENTICATE.
 *
 * Reads a PROTOCOL_ASD_COOKIE_SIZE-byte cookie from the client. While the
 * server is locked (protocol_asd_auth_locked), a cookie must be available
 * and must match auth_compare_cookie(); otherwise an AUTH_FAILURE error is
 * sent. When the check passes and the ack is written, the client is marked
 * authenticated.
 *
 * Returns 0 on success, 2 on a short read, rejected cookie or write failure.
 */
gint protocol_asd_AUTHENTICATE(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  guint8 cookie[PROTOCOL_ASD_COOKIE_SIZE];
  gboolean rejected;

  g_assert(client && request);

  if (atomic_read(client->fd, cookie, PROTOCOL_ASD_COOKIE_SIZE) != PROTOCOL_ASD_COOKIE_SIZE)
    return 2;

  /* Only a locked server checks the cookie at all. */
  rejected = protocol_asd_auth_locked && (!auth_cookie_available || !auth_compare_cookie(cookie));

  if (rejected)
    {
      protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_AUTH_FAILURE);
      return 2;
    }

  if (!protocol_asd_write_ack(client, request, 0))
    return 2;

  client->auth = TRUE;
  return 0;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_SERVER_VERSION: acknowledge the request and
 * send a ProtocolAsdServerVersionResponse carrying asd_version.
 *
 * Returns 0 on success, 2 if writing to the client failed.
 */
gint protocol_asd_SERVER_VERSION(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdServerVersionResponse version;
  g_assert(client && request);

  /* The raw struct bytes go to the peer, so zero them first. Copy at most
     size-1 bytes: strncpy() does not NUL-terminate when the source fills the
     field. */
  memset(&version, 0, sizeof(version));
  strncpy(version.version, asd_version, sizeof(version.version) - 1);

  if (!protocol_asd_write_ack(client, request, sizeof(version)))
    return 2;

  if (atomic_write(client->fd, &version, sizeof(version)) != sizeof(version))
    return 2;

  return 0;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_SERVER_INFO.
 *
 * Sends a ProtocolAsdServerInfoResponse containing:
 *   - default block size and sample type;
 *   - allocated block count;
 *   - lock state of the ASD and ESound protocols;
 *   - global average latency;
 *   - a "user@host/pid" identification string;
 *   - the server version.
 * Multi-byte fields are converted to little endian.
 *
 * Returns 0 on success, 2 if writing to the client failed.
 */
gint protocol_asd_SERVER_INFO(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdServerInfoResponse info;
  gchar hn[MAXHOSTNAMELEN] = "localhost";
  g_assert(client && request);

  /* The whole struct is written to the socket; zero it so padding and any
     unset bytes never leak stack contents to the peer. */
  memset(&info, 0, sizeof(info));

  info.block_size = GUINT32_TO_LE(default_block_size);
  info.sample_type = default_sample_type;
  sample_type_to_le(&info.sample_type);
  info.alloc_blocks = GUINT16_TO_LE(block_count);
  info.asd_auth_locked = protocol_asd_auth_locked;
  info.esound_auth_locked = protocol_esound_auth_locked;
  info.average_latency = GUINT32_TO_LE(latency_global_average());

  /* On failure the buffer contents are unspecified, so fall back to
     "localhost". A truncated name need not be NUL-terminated, so terminate
     it explicitly. */
  if (gethostname(hn, sizeof(hn)) != 0)
    g_snprintf(hn, sizeof(hn), "%s", "localhost");
  hn[sizeof(hn) - 1] = 0;

  g_snprintf(info.string, sizeof(info.string), "%s@%s/%i", g_get_user_name(), hn, main_pid);

  /* Leave room for the terminator; strncpy() will not add one on truncation. */
  strncpy(info.version, asd_version, sizeof(info.version) - 1);

  if (!protocol_asd_write_ack(client, request, sizeof(info)))
    return 2;

  if (atomic_write(client->fd, &info, sizeof(info)) != sizeof(info))
    return 2;

  return 0;
}

// protocol_asd_STREAM_CAPTURE
// protocol_asd_STREAM_MONITOR
gint protocol_asd_STREAM_PLAY(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdStreamQuery query;
  Sink *dev_sink = NULL;
  Source *dev_source = NULL;

  g_assert(client && request);
  
  if (atomic_read(client->fd, &query, sizeof(query)) == sizeof(query))
    {
      gchar t[256];
      gchar shortname[ASD_SHORTNAME_LENGTH];
          
      query.queue_length = GUINT32_FROM_LE(query.queue_length);
      query.queue_hold = GUINT32_FROM_LE(query.queue_hold);
      query.name[sizeof(query.name)-1] = 0;
      query.device[sizeof(query.device)-1] = 0;
      sample_type_from_le(&query.sample_type);
      volume_from_le(&query.volume);

      if (!sample_type_valid(&query.sample_type))
	{
	  protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_SAMPLE_TYPE_INVALID);
	  return 2;
	}

      if ((query.sample_type.bits != 8) && (query.sample_type.bits != 16))
        {
	  protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_SAMPLE_TYPE_NOT_SUPPORTED);
	  return 2;
        }

      g_snprintf(shortname, sizeof(shortname), "client%u", client->id);
      if (!namespace_register(shortname))
        {
          protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_COULD_NOT_REGISTER_NAME);
          return 2;
        }

      if ((request->command == PROTOCOL_ASD_COMMAND_STREAM_PLAY) || (request->command == PROTOCOL_ASD_COMMAND_STREAM_MONITOR)) // Play or monitor
        {
          if (!(dev_sink = (Sink*) namespace_lookup_with_type(query.device, NAMESPACE_SINK)))
            {
              protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_DEVICE_NOT_FOUND);
              namespace_unregister(shortname);
              return 2;
            }
        }
      else if (request->command == PROTOCOL_ASD_COMMAND_STREAM_CAPTURE) // Capture
        {
          if (!(dev_source = (Source*) namespace_lookup_with_type(query.device, NAMESPACE_SOURCE)))
            {
              protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_DEVICE_NOT_FOUND);
              namespace_unregister(shortname);
              return 2;
            }
        }

      pthread_cleanup_push(gc_ref_dec, dev_sink ? (gpointer) dev_sink : (gpointer) dev_source);
      pthread_cleanup_push(MAKE_CLEANUP_HANDLER(namespace_unregister), &shortname[0]);

      g_message("ASD: New client '%s' registered.", shortname);

      if (request->command == PROTOCOL_ASD_COMMAND_STREAM_PLAY || request->command == PROTOCOL_ASD_COMMAND_STREAM_CAPTURE || request->command == PROTOCOL_ASD_COMMAND_STREAM_MONITOR)
        g_message("ASD: Client wants to %s %s.", (request->command == PROTOCOL_ASD_COMMAND_STREAM_PLAY ? "play to" : (request->command == PROTOCOL_ASD_COMMAND_STREAM_MONITOR ? "monitor" : "capture from")), query.device);
      else
        g_message("ASD: Client wants a generic %s.", request->command == PROTOCOL_ASD_COMMAND_STREAM_SOURCE ?  "source" : "sink");

      sample_type_to_string(&query.sample_type, t, sizeof(t));
      g_message("ASD: Sample format is %s.", t);

      volume_to_string(&query.volume, t, sizeof(t));
      g_message("ASD: Volume is %s.", t);

      if (protocol_asd_write_ack(client, request, sizeof(ProtocolAsdStreamResponse)))
	{
	  ProtocolAsdStreamResponse response;

	  strncpy(response.shortname, shortname, sizeof(response.shortname));
	  response.block_size = GUINT32_TO_LE(sample_convert_length(&default_sample_type, &query.sample_type, default_block_size, FALSE));

	  if (atomic_write(client->fd, &response, sizeof(response)) == sizeof(response))
	    {
	      if ((request->command == PROTOCOL_ASD_COMMAND_STREAM_PLAY) || 
                  (request->command == PROTOCOL_ASD_COMMAND_STREAM_SOURCE)) // Play
		{
		  Source *s;
		  g_assert(s = source_socket_new(shortname, query.name, client->fd, &query.sample_type));

                  source_set_volume(s, &query.volume);

                  if (dev_sink)
                    {
                      link_source_sink(s, dev_sink, query.queue_length, query.queue_hold);
                      gc_ref_dec(dev_sink);
                    }

                  if (query.immediate_stop)
                    s->flags |= SOURCE_IMMEDIATE_STOP;

		  source_start(s);
		  gc_ref_dec(s);
		}
	      else
		{
		  Sink *s;
		  g_assert(s = sink_socket_new(shortname, query.name, client->fd, &query.sample_type));

                  sink_set_volume(s, &query.volume);

                  if (dev_source)
                    {
                      link_source_sink(dev_source, s, query.queue_length, query.queue_hold);
                      gc_ref_dec(dev_source);
                    }

                  if (dev_sink)
                    {
                      link_sink_sink(dev_sink, s, query.queue_length, query.queue_hold);
                      gc_ref_dec(dev_sink);
                    }

		  sink_start(s);
		  gc_ref_dec(s);
		}
	      
	      return 1;
	    }
      
	pthread_cleanup_pop(1);
	pthread_cleanup_pop(1);
      }
    }

  return 2;
}

/*
 * Handle both PROTOCOL_ASD_COMMAND_LOCK and PROTOCOL_ASD_COMMAND_UNLOCK.
 *
 * Sets protocol_asd_auth_locked to TRUE for LOCK and FALSE for anything
 * else (i.e. UNLOCK), logs the new state and acknowledges the request.
 *
 * Returns 0 on success, 2 if the acknowledgement could not be written.
 */
gint protocol_asd_LOCK(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  const gchar *state;

  g_assert(client && request);

  protocol_asd_auth_locked = (request->command == PROTOCOL_ASD_COMMAND_LOCK);

  if (protocol_asd_auth_locked)
    state = "locked";
  else
    state = "unlocked";

  g_message("ASD: Server %s.", state);

  return protocol_asd_write_ack(client, request, 0) ? 0 : 2;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_VOLUME_GET.
 *
 * Reads a ProtocolAsdVolumeGetQuery and replies with the volume of the
 * source or sink carrying that short name. Sources are looked up first, so
 * they win if both exist. The volume is sent in little-endian form.
 *
 * Returns 0 on success, 2 on a short read, unknown name or write failure.
 */
gint protocol_asd_VOLUME_GET(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdVolumeGetQuery query;
  ProtocolAsdVolumeGetResponse response;
  Source *src;
  Sink *snk;

  g_assert(client && request);

  if (atomic_read(client->fd, &query, sizeof(query)) != sizeof(query))
    return 2;

  /* Never trust the peer to NUL-terminate the name. */
  query.shortname[sizeof(query.shortname) - 1] = 0;

  src = (Source*) namespace_lookup_with_type(query.shortname, NAMESPACE_SOURCE);
  if (src != NULL)
    {
      response.volume = src->volume;
      gc_ref_dec(src);
    }
  else
    {
      snk = (Sink*) namespace_lookup_with_type(query.shortname, NAMESPACE_SINK);
      if (snk == NULL)
        {
          protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_DEVICE_NOT_FOUND);
          return 2;
        }
      response.volume = snk->volume;
      gc_ref_dec(snk);
    }

  volume_to_le(&response.volume);

  if (!protocol_asd_write_ack(client, request, sizeof(response)))
    return 2;

  return atomic_write(client->fd, &response, sizeof(response)) == sizeof(response) ? 0 : 2;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_VOLUME_SET.
 *
 * Reads a ProtocolAsdVolumeSetQuery (volume in little endian) and applies
 * the volume to the source or sink carrying that short name. Sources are
 * looked up first.
 *
 * Returns 0 after acknowledging, 2 on a short read, unknown name or write
 * failure.
 */
gint protocol_asd_VOLUME_SET(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdVolumeSetQuery query;
  Source *src;
  Sink *snk;

  g_assert(client && request);

  if (atomic_read(client->fd, &query, sizeof(query)) != sizeof(query))
    return 2;

  /* Never trust the peer to NUL-terminate the name. */
  query.shortname[sizeof(query.shortname) - 1] = 0;
  volume_from_le(&query.volume);

  src = (Source*) namespace_lookup_with_type(query.shortname, NAMESPACE_SOURCE);
  if (src != NULL)
    {
      source_set_volume(src, &query.volume);
      gc_ref_dec(src);
    }
  else
    {
      snk = (Sink*) namespace_lookup_with_type(query.shortname, NAMESPACE_SINK);
      if (snk == NULL)
        {
          protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_DEVICE_NOT_FOUND);
          return 2;
        }
      sink_set_volume(snk, &query.volume);
      gc_ref_dec(snk);
    }

  return protocol_asd_write_ack(client, request, 0) ? 0 : 2;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_INFO_SOURCE.
 *
 * Reads a ProtocolAsdInfoQuery and replies with a
 * ProtocolAsdInfoSourceResponse describing the source with that short name:
 *   - names and type;
 *   - sample type and volume;
 *   - mode and flags;
 *   - throughput and latency;
 *   - byte counter.
 * Multi-byte fields are sent in little endian.
 *
 * Returns 0 on success, 2 on a short read, unknown source or write failure.
 */
gint protocol_asd_INFO_SOURCE(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdInfoQuery query;
  ProtocolAsdInfoSourceResponse response;
  Source *source;

  g_assert(client && request);

  if (atomic_read(client->fd, &query, sizeof(query)) != sizeof(query))
    return 2;

  query.shortname[sizeof(query.shortname)-1] = 0;

  if (!(source = (Source*) namespace_lookup_with_type(query.shortname, NAMESPACE_SOURCE)))
    {
      protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_DEVICE_NOT_FOUND);
      return 2;
    }

  /* The struct is sent verbatim: zero it first, and copy at most size-1
     bytes so every string field stays NUL-terminated. */
  memset(&response, 0, sizeof(response));
  strncpy(response.shortname, source->shortname, sizeof(response.shortname) - 1);
  strncpy(response.name, source->name, sizeof(response.name) - 1);
  strncpy(response.type, source->type, sizeof(response.type) - 1);
  response.sample_type = source->sample_type;
  sample_type_to_le(&response.sample_type);
  response.volume = source->volume;
  volume_to_le(&response.volume);
  response.mode = source->mode;
  response.flags = GUINT32_TO_LE(source->flags);
  response.throughput = GUINT32_TO_LE(throughput_value(source->throughput));
  response.latency = GUINT32_TO_LE(latency_global_get(source->origin_id));
  response.byte_counter = GUINT32_TO_LE(source->throughput->byte_counter);

  /* Everything needed has been copied; release the lookup reference before
     doing socket I/O. */
  gc_ref_dec(source);

  if (!protocol_asd_write_ack(client, request, sizeof(response)))
    return 2;

  if (atomic_write(client->fd, &response, sizeof(response)) != sizeof(response))
    return 2;

  return 0;
}

/*
 * Handle PROTOCOL_ASD_COMMAND_INFO_SINK.
 *
 * Reads a ProtocolAsdInfoQuery and replies with a
 * ProtocolAsdInfoSinkResponse describing the sink with that short name:
 *   - names and type;
 *   - sample type and volume;
 *   - mode and flags;
 *   - throughput;
 *   - byte counter.
 * Multi-byte fields are sent in little endian.
 *
 * Returns 0 on success, 2 on a short read, unknown sink or write failure.
 */
gint protocol_asd_INFO_SINK(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  ProtocolAsdInfoQuery query;
  ProtocolAsdInfoSinkResponse response;
  Sink *sink;

  g_assert(client && request);

  if (atomic_read(client->fd, &query, sizeof(query)) != sizeof(query))
    return 2;

  query.shortname[sizeof(query.shortname)-1] = 0;

  if (!(sink = (Sink*) namespace_lookup_with_type(query.shortname, NAMESPACE_SINK)))
    {
      protocol_asd_write_error(client, request, PROTOCOL_ASD_ERROR_DEVICE_NOT_FOUND);
      return 2;
    }

  /* The struct is sent verbatim: zero it first, and copy at most size-1
     bytes so every string field stays NUL-terminated. */
  memset(&response, 0, sizeof(response));
  strncpy(response.shortname, sink->shortname, sizeof(response.shortname) - 1);
  strncpy(response.name, sink->name, sizeof(response.name) - 1);
  strncpy(response.type, sink->type, sizeof(response.type) - 1);
  response.sample_type = sink->sample_type;
  sample_type_to_le(&response.sample_type);
  response.volume = sink->volume;
  volume_to_le(&response.volume);
  response.mode = sink->mode;
  response.flags = GUINT32_TO_LE(sink->flags);
  response.throughput = GUINT32_TO_LE(throughput_value(sink->throughput));
  response.byte_counter = GUINT32_TO_LE(sink->throughput->byte_counter);

  /* Everything needed has been copied; release the lookup reference before
     doing socket I/O. */
  gc_ref_dec(sink);

  if (!protocol_asd_write_ack(client, request, sizeof(response)))
    return 2;

  if (atomic_write(client->fd, &response, sizeof(response)) != sizeof(response))
    return 2;

  return 0;
}

typedef struct {
  ProtocolAsdClient *client;
  ProtocolAsdRequest *request;
}_UserData;

static void _proc(gchar *name, gpointer data, gpointer ud)
{
  ProtocolAsdListResponse response;
  _UserData *userdata;

  g_assert(name && data && ud);
  userdata = (_UserData*) ud;

  g_assert(userdata->client && userdata->request);

  if (userdata->request->command == PROTOCOL_ASD_COMMAND_LIST_SOURCES)
    {
      Source *s = (Source*) data;
      strncpy(response.shortname, s->shortname, sizeof(response.shortname));
      strncpy(response.name, s->name, sizeof(response.name));
      strncpy(response.type, s->type, sizeof(response.type));
    }
  else
    {
      Sink *s = (Sink*) data;
      strncpy(response.shortname, s->shortname, sizeof(response.shortname));
      strncpy(response.name, s->name, sizeof(response.name));
      strncpy(response.type, s->type, sizeof(response.type));
    }

  atomic_write(userdata->client->fd, &response, sizeof(response));
}

gint protocol_asd_LIST_SOURCES(ProtocolAsdClient *client, ProtocolAsdRequest *request)
{
  gboolean source = (request->command == PROTOCOL_ASD_COMMAND_LIST_SOURCES);

  if (protocol_asd_write_ack(client, request, -1))
    {
      _UserData userdata;
      ProtocolAsdListResponse response;

      userdata.client = client;
      userdata.request = request;
      
      namespace_foreach(_proc, source ? NAMESPACE_SOURCE : NAMESPACE_SINK, &userdata);
      
      memset(&response, 0, sizeof(response));
      if (atomic_write(client->fd, &response, sizeof(response)) == sizeof(response))            
        return 0;
    }

  return 2;
}
