# GNU Enterprise Common - PostgreSQL DB Driver - Schema Introspection
#
# Copyright 2001-2005 Free Software Foundation
#
# This file is part of GNU Enterprise
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# $Id: Introspection.py 6851 2005-01-03 20:59:28Z jcater $

# Names exported by "from ... import *": only the introspection class below
__all__ = ['Introspection']

import string

from gnue.common.datasources import GIntrospection

# =============================================================================
# This class implements schema introspection for PostgreSQL backends
# =============================================================================

class Introspection (GIntrospection.Introspection):
  """
  Schema introspection for PostgreSQL backends.

  Relations, their columns, defaults and indices are read from the system
  catalogs pg_class, pg_attribute, pg_type, pg_attrdef and pg_index.
  """

  # list of the types of Schema objects this driver provides
  types = [ ('view' , _('Views') , 1),
            ('table', _('Tables'), 1) ]

  # native PostgreSQL type names reported with datatype 'number'
  _NUMBER_TYPES = ('numeric', 'float4', 'float8', 'money', 'bool', 'int8',
                   'int2', 'int4')

  # native PostgreSQL type names reported with datatype 'date'
  _DATE_TYPES = ('date', 'time', 'timetz', 'timestamp', 'timestamptz',
                 'abstime', 'reltime')


  # ---------------------------------------------------------------------------
  # Find a schema element by name and/or type
  # ---------------------------------------------------------------------------

  def find (self, name = None, type = None):
    """
    This function searches the schema for an element by name and/or type. If no
    name and no type is given, all elements will be retrieved.

    @param name: look for an element with this name (compared in lower case)
    @param type: look for an element with this type ('table', 'view',
        'sources' or None for both tables and views)
    @return: A sequence of schema instances, one per element found, or None if
        no element could be found (including an unsupported type).
    """

    reltypes = []
    if type in ('table', 'sources', None):
      reltypes.append ("'r'")
    if type in ('view', 'sources', None):
      reltypes.append ("'v'")

    # An unsupported type cannot match anything; bail out instead of sending
    # an invalid "relkind in ()" clause to the backend.
    if not reltypes:
      return None

    if name is not None:
      # Double embedded quotes so the name cannot terminate the SQL literal
      cond = [u"relname = '%s'" % name.lower ().replace ("'", "''")]
    else:
      # Hide the system catalogs. A regular expression is used because '_'
      # is a wildcard in LIKE, which would also hide relations like 'pgroups'
      cond = [u"relname !~ '^pg_'"]

    cond.append (u"relkind in (%s)" % ",".join (reltypes))

    cmd = u"SELECT oid, relname, relkind FROM pg_class WHERE %s " \
           "ORDER BY relname" % " AND ".join (cond)

    result = []
    cursor = self._connection.makecursor (cmd)

    try:
      for rs in cursor.fetchall ():
        attrs = {'id'        : rs [0],
                 'name'      : rs [1],
                 'type'      : rs [2] == 'v' and 'view' or 'table',
                 'indices'   : self.__getIndices (rs [0]),
                 'primarykey': None}

        # The field list of the primary index (if any) is the primary key
        if attrs ['indices'] is not None:
          for index in attrs ['indices'].values ():
            if index ['primary']:
              attrs ['primarykey'] = index ['fields']
              break

        result.append ( \
          GIntrospection.Schema (attrs, getChildSchema = self._getChildSchema))

    finally:
      cursor.close ()

    return len (result) and result or None


  # ---------------------------------------------------------------------------
  # Get all fields of a relation/view
  # ---------------------------------------------------------------------------

  def _getChildSchema (self, parent):
    """
    This function returns a list of all child elements (fields) of a given
    parent relation, ordered by their attribute number.

    @param parent: schema object instance whose child elements should be
        fetched; its 'id' is the pg_class oid of the relation.
    @return: sequence of schema instances, one per element found
    """

    result = []

    # Negative attnums are system columns (oid, ctid, ...) and are skipped
    cmd = u"SELECT attrelid, attname, t.oid, t.typname, attnotnull, " \
           "atthasdef, atttypmod, attnum, attlen " \
           "FROM pg_attribute, pg_type t " \
           "WHERE attrelid = %s AND attnum >= 0 AND t.oid = atttypid " \
           "ORDER BY attnum" % parent.id

    cursor = self._connection.makecursor (cmd)

    try:
      for rs in cursor.fetchall ():
        typname = rs [3]
        attrs = {'id'        : "%s.%s" % (rs [0], rs [7]),
                 'name'      : rs [1],
                 'type'      : 'field',
                 'nativetype': typname,
                 # A NOT NULL column with a default need not be supplied
                 'required'  : rs [4] and not rs [5]}

        if typname in self._NUMBER_TYPES:
          attrs ['datatype'] = 'number'
        elif typname in self._DATE_TYPES:
          attrs ['datatype'] = 'date'
        else:
          attrs ['datatype'] = 'text'

        self.__setLength (attrs, typname, rs [8], rs [6])

        # If attribute has default values, we fetch them too
        if rs [5]:
          self.__setDefault (attrs, parent.id, rs [7])

        result.append (GIntrospection.Schema (attrs))

    finally:
      cursor.close ()

    return result


  # ---------------------------------------------------------------------------
  # Derive the length (and precision) of a field
  # ---------------------------------------------------------------------------

  def __setLength (self, attrs, typname, attlen, atttypmod):
    """
    Store the 'length' (and for numeric columns the 'precision') of a field
    in its attribute dictionary.

    @param attrs: attribute dictionary of the field; updated in place
    @param typname: native type name of the field
    @param attlen: pg_attribute.attlen; positive for fixed-size types
    @param atttypmod: pg_attribute.atttypmod; -1 if the type has no modifier
    """

    if attlen > 0:
      attrs ['length'] = attlen

    elif atttypmod != -1:
      if typname == 'numeric':
        # numeric(p,s) is stored as ((p << 16) | s) + 4
        modifier = atttypmod - 4
        attrs ['length']    = (modifier >> 16) & 0xffff
        attrs ['precision'] = modifier & 0xffff

      elif typname in ('bit', 'varbit'):
        # bit(n) / varbit(n) store n directly
        attrs ['length'] = atttypmod

      else:
        # char(n) / varchar(n) store n + 4 (the varlena header size)
        attrs ['length'] = atttypmod - 4


  # ---------------------------------------------------------------------------
  # Fetch and classify the default value of a field
  # ---------------------------------------------------------------------------

  def __setDefault (self, attrs, relid, attnum):
    """
    Look up the default expression of a field and store it as 'defaulttype'
    ('sequence', 'system' or 'constant') and 'defaultval' in its attribute
    dictionary. Nothing is stored if no pg_attrdef row is found.

    @param attrs: attribute dictionary of the field; updated in place
    @param relid: pg_class oid of the relation
    @param attnum: attribute number of the field
    """

    cmd = u"SELECT adsrc FROM pg_attrdef " \
           "WHERE adrelid = %s AND adnum = %s" % (relid, attnum)
    cursor = self._connection.makecursor (cmd)

    try:
      defrs = cursor.fetchone ()
    finally:
      cursor.close ()

    if not defrs:
      return

    default = defrs [0]

    if default [:8] == 'nextval(':
      # e.g. nextval('seq_name'::regclass) -> seq_name
      attrs ['defaulttype'] = 'sequence'
      attrs ['defaultval']  = default.split ("'") [1]

    elif default == 'now()':
      attrs ['defaulttype'] = 'system'
      attrs ['defaultval']  = 'timestamp'

    else:
      # Strip a trailing type cast such as 'abc'::character varying
      attrs ['defaulttype'] = 'constant'
      attrs ['defaultval']  = default.split ("::") [0]


  # ---------------------------------------------------------------------------
  # Get a dictionary of all indices defined for a relation
  # ---------------------------------------------------------------------------

  def __getIndices (self, relid):
    """
    This function creates a dictionary with all indices of a given relation
    where the keys are the index names and the values are dictionaries
    describing the indices. Such a dictionary has the keys 'unique', 'primary'
    and 'fields', where 'unique' specifies whether the index is unique or not
    and 'primary' specifies whether the index is the primary key or not.
    'fields' holds a sequence with the names of all plain columns building the
    index, in index order; expression columns are omitted.

    @param relid: relation id of the table to fetch indices for
    @return: dictionary with indices or None if no indices were found
    """

    result = {}
    cmd    = u"SELECT c.relname, i.indisunique, i.indisprimary, i.indkey " \
              "FROM pg_index i, pg_class c " \
              "WHERE i.indrelid = %s AND c.oid = i.indexrelid" % relid

    cursor = self._connection.makecursor (cmd)

    try:
      for rs in cursor.fetchall ():
        # indkey is a space separated list of attribute numbers
        order = rs [3].split ()
        result [rs [0]] = {'unique' : rs [1],
                           'primary': rs [2],
                           'fields' : self.__getIndexFields (relid, order)}

    finally:
      cursor.close ()

    return len (result) and result or None


  # ---------------------------------------------------------------------------
  # Resolve the attribute numbers of an index into column names
  # ---------------------------------------------------------------------------

  def __getIndexFields (self, relid, order):
    """
    Translate the attribute numbers of an index into column names.

    @param relid: relation id of the indexed table
    @param order: sequence of attribute numbers (as strings) in index order
    @return: list of column names in index order; expression columns
        (attribute number 0) have no column name and are skipped
    """

    # Expression columns are recorded as attnum 0 and have no pg_attribute
    # row; looking them up would fail, so only real columns are queried.
    columns = [ix for ix in order if ix != '0']
    if not columns:
      return []

    cmd = u"SELECT attname, attnum FROM pg_attribute " \
           "WHERE attrelid = %s AND attnum in (%s)" \
          % (relid, ",".join (columns))

    cursor = self._connection.makecursor (cmd)

    try:
      names = {}
      for frs in cursor.fetchall ():
        names ["%s" % frs [1]] = frs [0]

    finally:
      cursor.close ()

    return [names [ix] for ix in columns if ix in names]

