##############################################################################
#
# Copyright (c) 2001 Zope Corporation and Contributors. All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL).  A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################

# Revision number taken from the RCS/CVS-style $Revision$ keyword: the
# slice drops the leading '$Revision: ' (11 chars) and the trailing ' $'.
__version__='$Revision: 1.3.6.1 $'[11:-2]

import sys
from traceback import format_exception_only
from types import TupleType

def getSyntaxError(source, mode):
    """Ask the builtin compiler to describe why *source* fails to parse.

    Used after the restricted parser has rejected *source*.  The result has
    the same 4-tuple shape as the compile_restricted_* functions:
    (None, error_lines, [], {}), where error_lines are the lines of
    traceback.format_exception_only() with trailing whitespace removed.
    If the builtin compiler unexpectedly accepts the source, error_lines is
    a single generic message instead.
    """
    try:
        compile(source, '<string>', mode)
    except SyntaxError:
        exc_value = sys.exc_info()[1]
        lines = []
        for text in format_exception_only(SyntaxError, exc_value):
            lines.append(text.rstrip())
        return None, lines, [], {}
    return None, ['Unknown parser error.'], [], {}

from parser import ParserError
from compiler_2_1.transformer import Transformer

def tryParsing(source, mode):
    """Parse *source* into an AST with the bundled compiler_2_1 Transformer.

    mode -- 'eval' parses a single expression; any other value parses a
            suite of statements.

    Returns a 2-tuple (tree, err):
      * on success, (ast_tree, None);
      * on a parse failure, (None, err), where err is the 4-tuple built by
        getSyntaxError(): (None, error_lines, [], {}).
    """
    if mode == 'eval':
        parser = Transformer().parseexpr
    else:
        parser = Transformer().parsesuite
    try:
        return parser(source), None
    except (ParserError, SyntaxError):
        # The parser module documents ParserError for validation failures
        # and the builtin SyntaxError for normal parse failures.  Treat
        # both as parse failures, so callers always receive the (None, err)
        # shape instead of an exception escaping from here.
        return None, getSyntaxError(source, mode)

import MutatingWalker
from RestrictionMutator import RestrictionMutator
from compiler_2_1 import ast, visitor, pycodegen

class RestrictedCodeGenerator:
    """Mixin for CodeGenerator that guards UNPACK_SEQUENCE bytecodes.

    The UNPACK_SEQUENCE opcode is not safe on its own, because it extracts
    elements from a sequence without using a safe iterator or making
    __getitem__ checks.

    This code generator does not remove UNPACK_SEQUENCE.  Wherever a
    sequence would be unpacked, it first passes the sequence through a call
    to the global _getiter_() and then unpacks the result.  That lets
    _getiter_ perform the appropriate security checks on each item.
    """

    # Replace the standard code generator for assignments to tuples
    # and lists.

    def _gen_safe_unpack_sequence(self, num):
        """Emit bytecode that unpacks `num` items via _getiter_().

        Expects the iterable to be on top of the stack, which is the case
        wherever UNPACK_SEQUENCE would normally be generated.
        """
        # We're at a place where UNPACK_SEQUENCE should be generated, to
        # unpack num items.  That's a security hole, since it exposes
        # individual items from an arbitrary iterable.  We don't remove
        # the UNPACK_SEQUENCE, but instead insert a call to our _getiter_()
        # wrapper first.  That applies security checks to each item as
        # it's delivered.  codegen is (just) a bit messy because the
        # iterable is already on the stack, so we have to do a stack swap
        # to get things in the right order.
        self.emit('LOAD_GLOBAL', '_getiter_')  # stack: seq, _getiter_
        self.emit('ROT_TWO')                   # stack: _getiter_, seq
        self.emit('CALL_FUNCTION', 1)          # stack: _getiter_(seq)
        self.emit('UNPACK_SEQUENCE', num)

    def _visitAssSequence(self, node):
        """Generate code for a tuple/list assignment target.

        For a store, emit the guarded unpack and then visit each element
        so it stores its item.  For a delete target (OP_DELETE), nothing
        is unpacked; each element is just visited.
        """
        if pycodegen.findOp(node) != 'OP_DELETE':
            self._gen_safe_unpack_sequence(len(node.nodes))
        for child in node.nodes:
            self.visit(child)

    visitAssTuple = _visitAssSequence
    visitAssList = _visitAssSequence

    # Called to generate code for unpacking nested tuple arguments of a
    # function (e.g. the (b, c) in "def f(a, (b, c))").

    def unpackSequence(self, tup):
        """Unpack the value on the stack into the names in `tup`.

        `tup` is a tuple whose items are either names (stored directly)
        or further tuples (unpacked recursively).  Every level goes through
        the guarded _getiter_() unpack.
        """
        self._gen_safe_unpack_sequence(len(tup))
        for elt in tup:
            if isinstance(elt, TupleType):
                self.unpackSequence(elt)
            else:
                self._nameOp('STORE', elt)

# Create variants of the standard code generators that mix in
# RestrictedCodeGenerator.  The various classes must be hooked
# together via an initClass() call.

class RNestedFunctionCodeGenerator(
        RestrictedCodeGenerator, pycodegen.NestedFunctionCodeGenerator):
    """Nested-scope function code generator with guarded unpacking.

    RestrictedCodeGenerator comes first in the bases, so its tuple/list
    assignment and argument-unpacking methods take precedence over
    pycodegen.NestedFunctionCodeGenerator's.

    NOTE(review): this class defines no initClass() of its own, so the
    generators used for functions nested *inside* a restricted function
    are whatever an inherited initClass() installs.  Confirm those are
    restricted generators too.
    """

class RNestedScopeModuleCodeGenerator(RestrictedCodeGenerator,
                                      pycodegen.NestedScopeModuleCodeGenerator):
    """Module code generator for restricted code with nested scopes.

    Combines RestrictedCodeGenerator's guarded unpacking with pycodegen's
    nested-scope module generator.  Functions defined in the module are
    generated by RNestedFunctionCodeGenerator.
    """

    def initClass(self):
        # Run the nested-scope mixin's hook first, then point FunctionGen
        # at the restricted function generator.
        # NOTE(review): only FunctionGen is replaced here; confirm that
        # other generator hooks (e.g. for class bodies) cannot reach an
        # unrestricted function generator.
        pycodegen.NestedScopeMixin.initClass(self)
        klass = self.__class__
        klass.FunctionGen = RNestedFunctionCodeGenerator

class RModuleCodeGenerator(
        RestrictedCodeGenerator, pycodegen.ModuleCodeGenerator):
    """Module code generator for restricted code without nested scopes.

    Combines RestrictedCodeGenerator's guarded unpacking with pycodegen's
    plain module generator, using the LGB scope mixin's initClass().
    """

    def initClass(self):
        # Run the LGB scope mixin's hook first, then install the
        # restricted function generator.
        # NOTE(review): this is the *nested-scope* restricted function
        # generator, even though the module itself uses LGB scoping.
        # Confirm that this mix is intended for nested_scopes=0.
        pycodegen.LGBScopeMixin.initClass(self)
        klass = self.__class__
        klass.FunctionGen = RNestedFunctionCodeGenerator

def compile_restricted_function(p, body, name, filename):
    '''Compile a restricted code object for a function.

    The function can be reconstituted using the 'new' module:

    new.function(<code>, <globals>)
    '''
    rm = RestrictionMutator()
    # Parse the parameters and body, then combine them.
    tree, err = tryParsing('def f(%s): pass' % p, 'exec')
    if err:
        if len(err) > 1:
            # Drop the first line of the error and adjust the next two.
            err[1].pop(0)
            err[1][0] = 'parameters: %s\n' % err[1][0][10:-8]
            err[1][1] = '  ' + err[1][1]
        return err
    f = tree.node.nodes[0]
    btree, err = tryParsing(body, 'exec')
    if err: return err
    f.code.nodes = btree.node.nodes
    f.name = name
    # Look for a docstring
    stmt1 = f.code.nodes[0]
    if (isinstance(stmt1, ast.Discard) and
        isinstance(stmt1.expr, ast.Const) and
        type(stmt1.expr.value) is type('')):
        f.doc = stmt1.expr.value
    MutatingWalker.walk(tree, rm)
    if rm.errors:
        return None, rm.errors, rm.warnings, rm.used_names
    gen = RNestedScopeModuleCodeGenerator(filename)
    visitor.walk(tree, gen)
    return gen.getCode(), (), rm.warnings, rm.used_names

def get_mutated_ast(s, filename='<string>'):
    '''Return an AST, mutated via RestrictionMutator.

    Parses *s* as a suite.  The returned dict always has the keys 'tree'
    and 'err', as returned by tryParsing().  When parsing succeeded, it
    also has 'rm': the RestrictionMutator that was walked over the tree,
    whose errors/warnings/used_names the caller inspects.  *filename* is
    accepted for signature symmetry but is not used here.
    '''
    tree, err = tryParsing(s, 'exec')
    info = {'tree': tree, 'err': err}
    if not err:
        mutator = RestrictionMutator()
        info['rm'] = mutator
        MutatingWalker.walk(tree, mutator)
    return info

def compile_restricted_exec(s, filename='<string>', nested_scopes=1):
    '''Compile a restricted code suite.

    Returns a 4-tuple (code, errors, warnings, used_names):
      * a parse failure returns the (None, error_lines, [], {}) tuple
        produced by the parser step, unchanged;
      * a restriction failure returns None plus the mutator's errors,
        warnings and used names;
      * success returns the module code object and () for errors.

    A true nested_scopes selects the nested-scope restricted module
    generator; a false value selects the plain one.
    '''
    info = get_mutated_ast(s, filename)
    if info['err']:
        return info['err']
    mutator = info['rm']
    if mutator.errors:
        return None, mutator.errors, mutator.warnings, mutator.used_names
    if nested_scopes:
        generator_class = RNestedScopeModuleCodeGenerator
    else:
        generator_class = RModuleCodeGenerator
    generator = generator_class(filename)
    visitor.walk(info['tree'], generator)
    return generator.getCode(), (), mutator.warnings, mutator.used_names

def compile_restricted_eval(s, filename='<string>', nested_scopes=1):
    '''Compile a restricted expression.

    The expression is compiled as the body of a synthetic function,
    "def f(): return \\" followed by a newline and *s*.  On success, the
    code object of that function is returned in place of the module code.

    Returns a 4-tuple (code, errors, warnings, used_names), with errors
    () on success and code None on failure.  nested_scopes selects the
    restricted module generator, as in compile_restricted_exec().
    '''
    ast_info = get_mutated_ast('def f(): return \\\n' + s, filename)
    err = ast_info['err']
    if err:
        # err is (None, error_lines, [], {}).  Discard the first line of
        # the syntax-error text (the location line), as before.
        lines = err[1]
        if len(lines) > 1:
            lines.pop(0)
        return err
    rm = ast_info['rm']
    if rm.errors:
        # Restriction errors are independent messages, not lines of one
        # traceback, so all of them are returned.
        # NOTE(review): any line numbers inside these messages count the
        # synthetic "def f()" line; confirm whether they need adjusting.
        return None, rm.errors, rm.warnings, rm.used_names
    if nested_scopes:
        gen = RNestedScopeModuleCodeGenerator(filename)
    else:
        gen = RModuleCodeGenerator(filename)
    visitor.walk(ast_info['tree'], gen)
    # Extract the code object representing the function body.
    return gen.getCode().co_consts[1], (), rm.warnings, rm.used_names

DEBUG = 0
def compile_restricted(source, filename, mode):
    '''Returns restricted compiled code. The signature of this
    function should match the signature of the builtin compile.

    mode must be 'exec' or 'eval'; any other value raises ValueError.
    If compilation reports errors, SyntaxError is raised with every
    error line joined by newlines.  Using only the first line would give
    just the '  File ..., line N' location (or the echoed source line)
    and omit the actual diagnosis.
    '''
    if DEBUG:
        from time import clock
        start = clock()

    if mode == 'eval':
        r = compile_restricted_eval(source, filename)
    elif mode == 'exec':
        r = compile_restricted_exec(source, filename)
    else:
        raise ValueError(
            "compile_restricted() arg 3 must be 'exec' or 'eval'")

    if DEBUG:
        end = clock()
        print('compile_restricted: %d ms for %s' % (
            (end - start) * 1000, repr(filename)))
    code, errors, warnings, used_names = r
    if errors:
        raise SyntaxError('\n'.join(errors))
    return code
