""" The registry of all extension points and extensions. """


# Enthought library imports.
from enthought.traits.api import Dict, HasPrivateTraits, Instance, Property


class ExtensionRegistry(HasPrivateTraits):
    """ The registry of all extension points and extensions.

    Extension points are registered as *classes*. Extensions are *instances*
    of those classes and are stored in one list per concrete extension class,
    in the order they were added. Listeners can be attached to an extension
    point class; they are called whenever an extension of exactly that class
    is added.

    """

    #### 'ExtensionRegistry' interface ########################################

    # The application that the registry is part of.
    application = Instance('enthought.envisage.core.application.Application')

    # All registered extension points (a dict that maps each registered
    # extension point class to itself).
    extension_points = Property(Dict)

    # All registered extensions (a dict that maps each concrete extension
    # class to the list of its registered instances).
    extensions = Property(Dict)

    #### Private interface ####################################################

    # fixme: This should be: Dict(Class, Class). See Dave Morrill.
    _extension_points = Dict

    # fixme: This should be: Dict(Class, List(Callable)). See Dave Morrill.
    _extension_listeners = Dict

    # fixme: This should be: Dict(Class, List(Instance(ExtensionPoint))). See
    # Dave Morrill.
    _extensions = Dict

    ###########################################################################
    # 'ExtensionRegistry' interface.
    ###########################################################################

    #### Properties ###########################################################

    def _get_extension_points(self):
        """ Returns all registered extension points.

        The returned value is the registry's internal dict (not a copy),
        mapping each registered extension point class to itself.

        """

        return self._extension_points

    def _get_extensions(self):
        """ Returns all registered extensions.

        The returned value is the registry's internal dict (not a copy),
        mapping each concrete extension class to the list of its instances.

        """

        return self._extensions

    #### Methods ##############################################################

    def add_extension_point(self, extension_point):
        """ Adds an extension point to the registry.

        The class is used as a dict key, so registering the same class more
        than once simply overwrites the existing entry.

        Parameters
        ----------
        extension_point : a *class* that is derived from **ExtensionPoint**
            The extension point being registered.

        """

        self._extension_points[extension_point] = extension_point

        return

    def add_extension(self, extension):
        """ Adds an extension to the registry.

        The extension is appended to the list kept for its concrete class.
        Every listener registered for exactly that class (see
        **add_extension_listener()**) is then called with the extension.

        Parameters
        ----------
        extension : an *instance* of a class derived from **ExtensionPoint**
            The extension being registered.

        """

        # We maintain a list of extensions per (concrete) extension class.
        extensions = self._extensions.setdefault(extension.__class__, [])
        extensions.append(extension)

        # Notify the listeners attached to this extension's exact class.
        #
        # NOTE(review): listeners registered on a *base* extension point class
        # are not notified here, whereas 'get_extensions' also matches
        # subclasses - confirm that this asymmetry is intended.
        listeners = self._extension_listeners.get(extension.__class__, [])
        for listener in listeners:
            listener(extension)

        return

    def add_extension_listener(self, extension_point, callable):
        """ Adds an extension listener to the registry.

        Listeners for the same extension point are called in the order in
        which they were registered. They are only called for extensions added
        *after* registration; existing extensions are not replayed.

        Parameters
        ----------
        extension_point : a class that is derived from **ExtensionPoint**
            The extension point for which a listener is being registered.
        callable : a callable method or function
            The callable to be invoked, with the new extension as its only
            argument, when extensions are added to the extension point.

        """

        listeners = self._extension_listeners.setdefault(extension_point, [])
        listeners.append(callable)

        return

    def get_extensions(self, extension_point, plugin_id=None, sort=False):
        """ Returns all extensions to the specified extension point.

        Extensions whose class is the extension point *or any subclass of it*
        are returned. Across different extension classes, the order follows
        the iteration order of the internal dict. Within one class, the order
        is the order in which the extensions were added.

        Parameters
        ----------
        extension_point : a class that is derived from **ExtensionPoint**
            The extension point whose extensions are retrieved.
        plugin_id : a plugin ID
            If specified, only this plugin's extensions to the extension point
            are returned.
        sort : Boolean
            If **True**, the returned list is sorted by the start order of the
            plugins that contributed the extensions.

        Returns
        -------
        A new list of extension instances.

        """

        extensions = []
        for klass, klass_extensions in self._extensions.items():
            if issubclass(klass, extension_point):
                extensions.extend(klass_extensions)

        # Filter by plugin ID if specified.
        if plugin_id is not None:
            extensions = [
                extension for extension in extensions
                if extension._definition_.id == plugin_id
            ]

        # Sort them by the start order of the plugin definitions that
        # contributed them.
        if sort:
            self._sort_by_start_order(extensions)

        return extensions

    def load_extensions(self, extension_point_id, plugin_id=None, sort=False):
        """ Returns a list of all contributions made to an extension point.

        The difference between this and **get_extensions()** is that this
        method ensures that the plugin that contributed each extension has
        been started. **application.start_plugin()** is called once per
        returned extension. A plugin that contributed several extensions is
        therefore passed more than once.

        Parameters
        ----------
        extension_point_id : a class that is derived from **ExtensionPoint**
            The extension point whose extensions are retrieved. Despite its
            name, this is the extension point *class*, not an ID. The name is
            kept for backward compatibility with keyword callers.
        plugin_id : a plugin ID
            If specified, only this plugin's extensions to the extension point
            are returned.
        sort : Boolean
            If **True**, the returned list is sorted by the start order of the
            plugins that contributed the extensions.

        """

        extensions = self.get_extensions(extension_point_id, plugin_id, sort)
        for extension in extensions:
            self.application.start_plugin(extension._definition_.id)

        return extensions

    ###########################################################################
    # Private interface.
    ###########################################################################

    def _sort_by_start_order(self, extensions):
        """ Sorts a list of extensions *in place* by plugin start order.

        Each extension is positioned by where the ID of its contributing
        plugin (``extension._definition_.id``) appears in the plugin
        activator's start order. The sort is stable, so extensions from the
        same plugin keep their relative order.

        ``start_order.index()`` raises ``ValueError`` if a plugin ID is
        missing (presumably ``start_order`` is a list - confirm).

        """

        # fixme: We might want to add this to the 'Application' interface!
        start_order = self.application.plugin_activator.start_order

        def by_start_order(extension):
            """ Returns the position of the extension's plugin in start order. """

            return start_order.index(extension._definition_.id)

        # A 'key' function replaces the old 'cmp'-style comparison function.
        # It works on every Python version from 2.4 onwards, whereas the
        # 'cmp' builtin and positional comparators do not exist in Python 3.
        # It also looks up each plugin's position once per extension, rather
        # than twice per comparison.
        #
        # Lists with fewer than two items are already sorted. Skipping them
        # also preserves the old behaviour: the comparison function was never
        # called for them, so no start-order lookup was made.
        if len(extensions) > 1:
            extensions.sort(key=by_start_order)

        return

#### EOF ######################################################################
