#
# Copyright (C) 2006 Chris Halls <halls@debian.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

"""This module tests the Fetcher classes"""

import time, os, socket, signal, string, re, base64
from StringIO import StringIO

from twisted.internet import reactor, protocol, defer, error
from twisted.protocols import ftp
from twisted.cred import portal, checkers, credentials
from twisted.python import failure

from apt_proxy.apt_proxy_conf import apConfig
from apt_proxy.apt_proxy import Factory
from apt_proxy.misc import log
from apt_proxy.cache import CacheEntry
from apt_proxy.fetchers import HttpFetcher, FetcherHttpClient, FtpFetcher, Fetcher, \
     RsyncFetcher, DownloadQueue, DownloadQueuePerClient
from apt_proxy.test.test_apt_proxy import apTestHelper, FactoryTestHelper


# Base configuration shared by the fetcher tests: two HTTP backends
# (backend2 additionally goes through an HTTP proxy with user credentials)
# and one FTP backend.  Test helpers append extra [section]s to this string.
config1="""
[DEFAULT]
debug=all:9
port=9999
address=
cleanup_freq=off
max_versions=off

[backend1]
backends = http://localhost/nothing-really

[backend2]
backends = http://localhost/nothing-really
http_proxy = user1:password@test:1234

[ftp]
backends = ftp://localhost/nothing-really
"""

class FetcherHttpTest(FactoryTestHelper):
    """
    Tests for the request headers written by FetcherHttpClient
    """
    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        FactoryTestHelper.setUp(self, config1)
        backend = self.factory.getBackend('backend1')
        backendServer = backend.uris[0]
        httpFetcher = HttpFetcher(backendServer)
        httpFetcher.proxy = None # Would otherwise have been set by httpFetcher.connect
        self.connection = FetcherHttpClient(httpFetcher)
        # Capture everything the client writes so the headers can be inspected
        self.connection.transport = StringIO()
        def endHeaders():
            pass
        # Replace endHeaders with a no-op for these tests
        self.connection.endHeaders = endHeaders

    def testProxyHeader(self):
        "A backend configured with proxy credentials sends Proxy-Authorization"
        self.connection.proxy = self.factory.getBackend('backend2').config.http_proxy
        self.connection.download(None, '/a/b/c', None)
        fetcherHeaders = self.connection.transport.getvalue()
        # encodestring() appends a trailing newline; strip it
        authheader = "Proxy-Authorization: Basic " + base64.encodestring('user1:password')[:-1]
        # Escape the literal header: base64 output may contain regex
        # metacharacters such as '+'
        pattern = r'(?:\n|^)' + re.escape(authheader) + r'\r\n'
        self.assertNotEquals(re.search(pattern, fetcherHeaders), None)

class FetcherFtpInitTest(FactoryTestHelper):
    """
    Construction test for FtpFetcher; no FTP server is started
    """
    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        FactoryTestHelper.setUp(self, config1)

    def testInit(self):
        "Brief init test"
        server = self.factory.getBackend('ftp').uris[0]
        fetcher = FtpFetcher(server)

class FetcherFtpTestHelper(FactoryTestHelper):
    """
    Base class for tests needing a local FTP server and an FtpFetcher
    created for a backend that points at it
    """
    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        self.ftpserver = FtpServer()
        listenPort = self.ftpserver.start()
        # NOTE(review): the backend URI uses http:// although the server is
        # FTP - confirm whether ftp:// was intended
        extraConfig = "[test_ftp]\nbackends=http://127.0.0.1:%s" % (listenPort,)
        FactoryTestHelper.setUp(self, config1 + extraConfig)
        self.backend = self.factory.getBackend('test_ftp')
        self.backendServer = self.backend.uris[0]
        self.ftpFetcher = FtpFetcher(self.backendServer)
        self.ftpFetcher.debug = 1

    def tearDown(self):
        # Leftover deferreds (e.g. a pending connection) are deliberately
        # ignored.  Let pending FTP callbacks run before disconnecting.
        reactor.iterate(0.1)
        self.ftpFetcher.disconnect()
        self.ftpFetcher = None
        self.ftpserver.stop()
        self.ftpserver = None
        FactoryTestHelper.tearDown(self)
        # Give the FTP code a few reactor passes to shut down
        for i in range(3):
            reactor.iterate(0.1)

class FetcherFtpTest(FetcherFtpTestHelper):
    """
    Connection tests for FtpFetcher against the local FTP server
    """
    def setUp(self):
        FetcherFtpTestHelper.setUp(self)

    def testConnect(self):
        "Test connect"
        return self.ftpFetcher.connect()
    testConnect.timeout = 2

    def testConnectFail(self):
        "Test connect failure"
        self.ftpserver.stop()

        def unexpectedSuccess(result):
            raise RuntimeError("Connect should have failed")

        def expectedFailure(reason):
            reason.trap(error.ConnectionRefusedError)

        # Inverted expectations: reaching the errback is the passing case
        return self.ftpFetcher.connect().addCallbacks(unexpectedSuccess,
                                                      expectedFailure)
    testConnectFail.timeout = 2

class FetchersDummyFetcher:
    """
    Stand-in for the fetcher object passed to the protocol classes.

    The test's C{deferred} is fired when the awaited event (mtime or
    file-not-found) is reported, or when the anticipated error occurs;
    any other download failure errbacks it.
    """
    def __init__(self, deferred):
        self.deferred = deferred
        self.error_code = None # Anticipated error
        self.wait_for_mtime = False
        self.wait_for_not_found = False

    def download_failed(self, code, reason):
        """
        Called when a download fails.  Succeeds the deferred if C{code}
        equals the anticipated C{error_code}, otherwise fails it.
        """
        if self.error_code is not None and self.error_code == code:
            # Deferred.callback() requires a result argument
            self.deferred.callback(None)
        else:
            # errback(None) only works while an exception is being handled;
            # pass an explicit exception so the failure carries a reason
            self.deferred.errback(RuntimeError(
                "Unexpected download failure: %s %s" % (code, reason)))

    def server_mtime(self, time):
        "Called with the server file mtime; fires the deferred if awaited"
        if self.wait_for_mtime:
            self.deferred.callback(None)

    def file_not_found(self):
        "Called when the server lacks the file; fires the deferred if awaited"
        if self.wait_for_not_found:
            self.deferred.callback(None)

class FetcherFtpProtocolTest(FetcherFtpTestHelper):
    """
    Protocol-level FTP tests; outcomes are reported through a
    FetchersDummyFetcher whose deferred is returned to trial
    """
    def setUp(self):
        FetcherFtpTestHelper.setUp(self)
        self.resultCallback = defer.Deferred()
        dummy = FetchersDummyFetcher(self.resultCallback)
        dummy.backendServer = self.backendServer
        self.fetcher = dummy

    def tearDown(self):
        FetcherFtpTestHelper.tearDown(self)

    def testNotFound(self):
        "Test for file not found"
        self.ftpFetcher.connect().addCallback(self.NotFoundConnectCallback)
        return self.resultCallback
    testNotFound.timeout = 1

    def NotFoundConnectCallback(self, result):
        "Once connected, request notHereFile and wait for file_not_found"
        self.fetcher.wait_for_not_found = True
        self.ftpFetcher.download(self.fetcher, 'notHereFile', 0)

    def MtimeConnectCallback(self, result):
        "Once connected, request packages/Packages and wait for server_mtime"
        log.debug("connection made", 'FetcherFtpProtocolTest')
        self.fetcher.wait_for_mtime = True
        self.ftpFetcher.download(self.fetcher, 'packages/Packages', 0)

    def testMtime(self):
        "Test mtime request"
        # We don't want to get size afterwards
        self.ftpFetcher.ftpFetchSize = lambda: None
        self.ftpFetcher.connect().addCallback(self.MtimeConnectCallback)
        return self.resultCallback
    testMtime.timeout = 1

class FtpServer:
    """
    Anonymous FTP server on 127.0.0.1 serving the test data directory,
    used by the FtpFetcher tests
    """
    def start(self):
        """
        Start FTP server, serving test data
        
        @ret port number that server listens on
        
        This routine was hacked from twisted/tap/ftp.py
        """
        # NOTE(review): relative path, so this depends on the working directory
        root = '../test_data'
        f = ftp.FTPFactory()
        r = ftp.FTPRealm(root)
        f.tld = root
        p = portal.Portal(r)
        p.registerChecker(checkers.AllowAnonymousAccess(), credentials.IAnonymous)

        f.userAnonymous = 'anonymous'
        f.portal = p
        f.protocol = ftp.FTP

        # Port 0: let the OS choose a free port
        self.port = reactor.listenTCP(0, f, interface="127.0.0.1")
        portnum = self.port.getHost().port
        log.debug("Ftp server listening on port %s" %(portnum))
        self.factory = f
        return portnum

    def stop(self):
        """
        Stop listening and stop the FTP factory.

        Safe to call more than once (e.g. a test stops the server and
        tearDown stops it again); calls after the first do nothing.
        """
        port = getattr(self, 'port', None)
        if port is None:
            return
        self.port = None
        port.stopListening()
        self.factory.stopFactory()

class RsyncFetcherTest(FactoryTestHelper):
    """
    Tests for RsyncFetcher using a factory whose 'rsync' backend is
    rsync://127.0.0.1:0/test
    """

    rsync_config="""
[DEFAULT]
debug=all:9
port=9999
address=
cleanup_freq=off
max_versions=off

[rsync]
backends = rsync://127.0.0.1:0/test
"""

    class RsyncDummyFetcher:
        """
        Minimal fetcher stand-in passed to RsyncFetcher.download
        """
        def __init__(self, backend, backendServer):
            self.backend = backend
            self.backendServer = backendServer
            self.cacheEntry = backend.get_cache_entry("testdir/testfile.deb")
        def fetcher_internal_error(self, message):
            log.debug('fetcher_internal_error: %s' % (message))
        def server_mtime(self, time):
            pass

    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        FactoryTestHelper.setUp(self, self.rsync_config)
        self.backend = self.factory.getBackend('rsync')
        self.backendServer = self.backend.get_first_server()
        self.f = RsyncFetcher(self.backendServer)
    def tearDown(self):
        self.f.disconnect()
        # Chain up like the other test classes in this module, so whatever
        # FactoryTestHelper.setUp created is also cleaned up
        FactoryTestHelper.tearDown(self)
    def testRsyncInit(self):
        self.assertEquals(self.f.backendServer, self.backendServer)
    def testConnect(self):
        return self.f.connect() # connect returns a deferred that fires
    def testDownload(self):
        self.f.connect()
        dummyFetcher = self.RsyncDummyFetcher(self.backend, self.backendServer)
        self.f.download(dummyFetcher, 'test', time.time())

class RsyncServer(protocol.ProcessProtocol):
    """
    Starts an rsync daemon on localhost for testing
    """
    rsyncCommand = '/usr/bin/rsync'
    # Maximum number of seconds start() waits for the daemon to accept
    # connections before giving up
    startTimeout = 10

    def start(self):
        """
        Start rsync server, serving test data

        @ret port number that server listens on
        @raise RuntimeError: if the daemon does not accept connections
            within startTimeout seconds
        """
        self.rsync_dir = '../test_data'
        self.rsync_port = self._findFreePort()

        self.rsync_confpath = self.rsync_dir + os.sep + 'testrsync.conf'
        self.write_rsyncconf()

        args = (self.rsyncCommand, '--daemon', '--config=' + self.rsync_confpath, '--verbose', '--no-detach')
        self.rsyncProcess = reactor.spawnProcess(self, self.rsyncCommand, args, None,self.rsync_dir)

        self._waitForServer()

        log.debug("rsync server listening on port %s" %(self.rsync_port))
        return self.rsync_port

    def _findFreePort(self):
        "Return a TCP port number on 127.0.0.1 that is currently unused"
        # Bind to port 0 so the OS picks a free port, then release it for the
        # rsync daemon.  There is an unavoidable race until rsync binds it.
        s = socket.socket()
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(('127.0.0.1', 0))
            return s.getsockname()[1]
        finally:
            s.close()

    def _waitForServer(self):
        "Poll until the rsync daemon accepts TCP connections on rsync_port"
        deadline = time.time() + self.startTimeout
        while True:
            # Use a fresh socket for every attempt: a socket whose connect()
            # failed cannot portably be reused for another connect()
            s = socket.socket()
            try:
                s.connect(('127.0.0.1', self.rsync_port))
                connected = True
            except socket.error:
                connected = False
            s.close()
            if connected:
                return
            if time.time() > deadline:
                raise RuntimeError("rsync server did not start on port %s"
                                   % (self.rsync_port))
            # Let the reactor run so the spawned process can start up
            reactor.iterate(0.1)

    def stop(self):
        "Terminate the rsync daemon started by start(), if it is running"
        process = getattr(self, 'rsyncProcess', None)
        if process and process.pid:
            log.debug("killing rsync child pid " + 
                      str(process.pid), 'RsyncServer')
            process.loseConnection()
            try:
                os.kill(process.pid, signal.SIGTERM)
            except OSError:
                # The child already exited; nothing left to kill
                pass

    def write_rsyncconf(self):
        "Write an rsyncd config exporting rsync_dir as module [apt-proxy]"
        f = open(self.rsync_confpath, 'w')
        try:
            f.write("address = 127.0.0.1\n")
            f.write("port = %s\n" % (self.rsync_port))
            f.write("log file = %s\n" %(self.rsync_dir+os.sep+'testrsync.log'))
            f.write("[apt-proxy]\n")
            f.write("path = %s\n" %(self.rsync_dir))
            f.write("use chroot = false\n") # Can't chroot because daemon isn't root
        finally:
            f.close()
        
    def outReceived(self, data):
        "Data received from rsync process to stdout"
        for s in data.split('\n'):
            if s:
                log.debug('rsync: ' + s, 'RsyncServer')

    def errReceived(self, data):
        "Data received from rsync process to stderr"
        for s in data.split('\n'):
            if s:
                log.err('rsync error: ' + s, 'RsyncServer')

    def processEnded(self, status_object):
        "Log how the rsync process ended"
        if isinstance(status_object, failure.Failure):
            log.debug("rsync failure: %s" %(status_object)
                  ,'RsyncServer')
        else:
            log.debug("Status: %d" %(status_object.value.exitCode)
                      ,'RsyncServer')


class FetcherRsyncTestHelper(FactoryTestHelper):
    """
    Base class for tests needing a local rsync daemon and an RsyncFetcher
    created for a backend that points at its [apt-proxy] module
    """
    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        self.rsyncserver = RsyncServer()
        rsyncPort = self.rsyncserver.start()
        # NOTE(review): the backend URI uses http:// although an rsync daemon
        # serves it - confirm whether rsync:// was intended
        extraConfig = "[test_rsync]\nbackends=http://127.0.0.1:%s/apt-proxy" % (rsyncPort,)
        FactoryTestHelper.setUp(self, config1 + extraConfig)
        self.backend = self.factory.getBackend('test_rsync')
        self.backendServer = self.backend.uris[0]
        self.rsyncFetcher = RsyncFetcher(self.backendServer)
        self.rsyncFetcher.debug = 1

    def tearDown(self):
        # Leftover deferreds (e.g. a pending connection) are deliberately ignored
        self.rsyncFetcher.disconnect()
        self.rsyncFetcher = None
        self.rsyncserver.stop()
        self.rsyncserver = None
        FactoryTestHelper.tearDown(self)

class FetcherRsyncProtocolTest(FetcherRsyncTestHelper):
    """
    Protocol-level rsync tests; outcomes are reported through a
    FetchersDummyFetcher whose deferred is returned to trial
    """
    def setUp(self):
        FetcherRsyncTestHelper.setUp(self)
        self.resultCallback = defer.Deferred()
        dummy = FetchersDummyFetcher(self.resultCallback)
        dummy.backendServer = self.backendServer
        self.fetcher = dummy

    def tearDown(self):
        FetcherRsyncTestHelper.tearDown(self)

    def testNotFound(self):
        "Test for file not found"
        self.rsyncFetcher.connect().addCallback(self.NotFound2)
        return self.resultCallback
    testNotFound.timeout = 1

    def NotFound2(self, result):
        "Once connected, download a file expected to be absent"
        missing = 'notHereFile'
        self.fetcher.wait_for_not_found = True
        self.fetcher.cacheEntry = self.backend.get_cache_entry(missing)
        self.rsyncFetcher.download(self.fetcher, missing, 0)

# Minimal configuration for the download queue tests: a single HTTP backend
# named test_queue.
QueueConfig = """
[test_queue]
backends=http://server1/path1
"""

class DummyFetcher:
    """
    Fetcher stub for the download queue tests: connecting always succeeds
    and download() only marks the entry as downloading
    """
    def __init__(self, backend):
        self.backend = backend

    def connect(self):
        # The connection always succeeds immediately
        return defer.succeed(True)

    def disconnect(self):
        pass

    def download(self, fetcher, uri, mtime):
        fetcher.cacheEntry.state = CacheEntry.STATE_DOWNLOAD

    def server_mtime(self, time):
        pass

class DummyServer:
    "Backend server stub whose fetcher class is DummyFetcher"
    uri = 'dummy://'
    path = 'Dummy'
    fetcher = DummyFetcher

class DummyBackend:
    "Backend stub whose first server is a fresh DummyServer"
    name = 'Dummy'

    def get_first_server(self):
        server = DummyServer()
        return server
class DummyCacheEntry:
    """
    Class that provides basic CacheEntry information

    Paths are built as <backend>/<file> (cache_path) and
    <cache_dir>/<backend>/<file> (file_path); new entries start in
    STATE_NEW with no requests.
    """

    STATE_NEW = CacheEntry.STATE_NEW
    STATE_DOWNLOAD = CacheEntry.STATE_DOWNLOAD

    def __init__(self, cache_dir, backend, file):
        self.path = file
        self.filename = os.path.basename(file)
        self.cache_path = os.sep.join((backend, file))
        self.file_path = os.sep.join((cache_dir, self.cache_path))
        self.file_mtime = None
        self.requests = []
        self.state = self.STATE_NEW

    def get_request_mtime(self):
        # No client-supplied modification time in these tests
        return None

class DownloadQueueTest(FactoryTestHelper):
    """
    Tests for DownloadQueue using dummy cache entries and backends
    """

    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        FactoryTestHelper.setUp(self, QueueConfig)
        self.queue = DownloadQueue()
        self.backend = self.factory.getBackend('test_queue')

    def _makeEntry(self):
        "Return a dummy cache entry for test.deb with a DummyBackend attached"
        entry = DummyCacheEntry(self.cache_dir, 'test_queue', 'test.deb')
        entry.backend = DummyBackend()
        return entry

    def testInit(self):
        q = self.queue
        self.assertEquals(len(q.queue), 0)
        self.assertEquals(q.fetcher, None)
        self.assertEquals(q.activeFile, None)

    def testAddFile(self):
        entry = self._makeEntry()
        self.queue.addFile(entry)
        # The first file is made active rather than left waiting in the queue
        self.assertEquals(len(self.queue.queue), 0)
        self.assertEquals(self.queue.activeFile, entry)
        self.queue.stop() # Cancel timeout CB

    def testDownloadComplete(self):
        entry = self._makeEntry()
        self.queue.addFile(entry)
        self.assertEquals(self.queue.activeFile, entry)
        self.queue.downloadFinished([True, 'Test complete'])
        self.assertEquals(self.queue.activeFile, None)
        self.queue.stop() # Cancel timeout CB

class DummyRequest:
    """
    Client request stub that records whether it was finished and how many
    times streaming was started
    """
    def __init__(self, fileno=0):
        self.fileno = fileno
        self.finished = False
        self.streamed = 0

    def getFileno(self):
        return self.fileno

    def start_streaming(self, file_size, file_mtime):
        self.streamed += 1

    def finishCode(self, code, reason):
        self.finished = True

class DownloadQueuePerClientTest(FactoryTestHelper):
    """
    Tests for DownloadQueuePerClient, which keeps one download queue per
    client, keyed by the request's fileno
    """
    def setUp(self):
        """
        Set up a factory using the additional config given
        """
        FactoryTestHelper.setUp(self, QueueConfig)
        self.queue = DownloadQueuePerClient()
        self.backend = self.factory.getBackend('test_queue')

    def _makeEntry(self, filename, requests, backend=None):
        """
        Return a dummy cache entry for C{filename} owned by C{requests};
        a new DummyBackend is attached unless C{backend} is given
        """
        entry = DummyCacheEntry(self.cache_dir, 'test_queue', filename)
        entry.requests = requests
        if backend is None:
            backend = DummyBackend()
        entry.backend = backend
        return entry

    def _assertNoQueues(self, fileno):
        "Check that no client queues exist yet, in particular none for fileno"
        self.assertEquals(len(self.queue.queues.keys()), 0)
        self.assertNotEquals(fileno in self.queue.queues, True)

    def testInit(self):
        self.assertEquals(len(self.queue.queues), 0)

    def testSeparateFilesAndClients(self):
        "Different files from different clients get separate queues"
        req1 = DummyRequest(123)
        req2 = DummyRequest(234)
        entry1 = self._makeEntry('test1.deb', [req1])
        entry2 = self._makeEntry('test2.deb', [req2], entry1.backend)

        self._assertNoQueues(req1.fileno)
        self.queue.addFile(entry1)
        self.assertEquals(len(self.queue.queues.keys()), 1)
        self.assertEquals(self.queue.queues[req1.fileno].activeFile, entry1)

        self.queue.addFile(entry2)
        self.assertEquals(len(self.queue.queues.keys()), 2)
        self.assertEquals(self.queue.queues[req2.fileno].activeFile, entry2)

        self.queue.stop() # Cancel timeout CB

    def testSeparateFiles(self):
        "Two files from the same client share one queue"
        req1 = DummyRequest(123)
        req2 = DummyRequest(123)
        entry1 = self._makeEntry('test1.deb', [req1])
        entry2 = self._makeEntry('test2.deb', [req2], entry1.backend)

        self._assertNoQueues(req1.fileno)
        self.queue.addFile(entry1)
        self.assertEquals(len(self.queue.queues.keys()), 1)
        self.assertEquals(self.queue.queues[req1.fileno].activeFile, entry1)

        self.queue.addFile(entry2)
        self.assertEquals(len(self.queue.queues.keys()), 1)
        # entry2 waits in the same queue while entry1 stays active
        clientQueue = self.queue.queues[req2.fileno]
        self.assertEquals(clientQueue.activeFile, entry1)
        self.assertEquals(clientQueue.queue[0], entry2)

        self.queue.stop() # Cancel timeout CB

    def testSeparateClients(self):
        "Two clients requesting one file"
        req1 = DummyRequest(123)
        req2 = DummyRequest(234)
        entry = self._makeEntry('test1.deb', [req1])

        self._assertNoQueues(req1.fileno)
        self.queue.addFile(entry)
        self.assertEquals(len(self.queue.queues.keys()), 1)
        self.assertEquals(self.queue.queues[req1.fileno].activeFile, entry)

        # The same entry is now also requested by a second client.  It is
        # added to a second queue but immediately dequeued, because it is
        # already on the first client's queue.
        entry.requests.append(req2)
        self.queue.addFile(entry)
        self.assertEquals(len(self.queue.queues.keys()), 2)
        self.assertEquals(self.queue.queues[req2.fileno].activeFile, None)

        self.queue.stop() # Cancel timeout CB

    def testDownloadComplete(self):
        req = DummyRequest(678)
        entry = self._makeEntry('test.deb', [req])
        self.queue.addFile(entry)
        self.assertEquals(len(self.queue.queues.keys()), 1)
        self.queue.queues[req.fileno].closeFetcher()
        # Check that queue for this client has been removed
        self.assertEquals(len(self.queue.queues.keys()), 0)
