/*
 * Copyright 2009 the original author or authors.
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.codenarc.rule.basic

import org.codehaus.groovy.ast.MethodNode
import org.codehaus.groovy.ast.stmt.ReturnStatement
import org.codenarc.rule.AbstractAstVisitor
import org.codenarc.rule.AbstractAstVisitorRule
import org.codenarc.util.AstUtil
import org.codehaus.groovy.ast.expr.*

/**
 * This rule detects when null is returned from a method that might return a
 * collection or map. Instead of null, return an empty collection or map.
 * 
 * @author Hamlet D'Arcy
 * @version $Revision: 602 $ - $Date: 2011-02-08 16:37:40 -0500 (Tue, 08 Feb 2011) $
 */
class ReturnsNullInsteadOfEmptyCollectionRule extends AbstractAstVisitorRule {
    // AST visitor class that performs the analysis for this rule
    Class astVisitorClass = ReturnsNullInsteadOfEmptyCollectionRuleAstVisitor

    // Identifier of this rule
    String name = 'ReturnsNullInsteadOfEmptyCollection'

    // Priority assigned to this rule
    int priority = 2
}

class ReturnsNullInsteadOfEmptyCollectionRuleAstVisitor extends AbstractAstVisitor {

    private static final String ERROR_MSG = 'Returning null from a method that might return a Collection or Map'

    /**
     * Declared return types that mark a method as collection-returning, checked in this order.
     */
    private static final List<Class> DECLARED_COLLECTION_TYPES = [Iterable, Map, List, Collection, ArrayList, Set, HashSet]

    /**
     * When the method looks like it returns a collection or map, scans its body for
     * null returns; then continues the regular traversal of the method.
     */
    def void visitMethodEx(MethodNode node) {
        boolean mustNotReturnNull = methodReturnsCollection(node)
        if (mustNotReturnNull) {
            node.code?.visit(newNullReturnTracker())
        }
        super.visitMethodEx(node)
    }

    /**
     * When the closure explicitly returns a collection or map somewhere, scans its body
     * for null returns; then continues the regular traversal of the closure.
     */
    def void handleClosure(ClosureExpression expression) {
        boolean mustNotReturnNull = closureReturnsCollection(expression)
        if (mustNotReturnNull) {
            expression.code?.visit(newNullReturnTracker())
        }
        super.visitClosureExpression(expression)
    }

    /** Builds a tracker that reports null returns against this visitor with ERROR_MSG. */
    private NullReturnTracker newNullReturnTracker() {
        new NullReturnTracker(parent: this, errorMessage: ERROR_MSG)
    }

    /**
     * A method counts as collection-returning if its declared return type implements one of
     * DECLARED_COLLECTION_TYPES, or, failing that, if any return statement in its body
     * yields a collection or map.
     */
    private static boolean methodReturnsCollection(MethodNode node) {
        boolean declaredAsCollection = DECLARED_COLLECTION_TYPES.any { Class candidate ->
            AstUtil.classNodeImplementsType(node.returnType, candidate)
        }
        // the body is only inspected when the declared type did not already decide it
        declaredAsCollection || codeReturnsCollection(node.code)
    }

    /** A closure counts as collection-returning if any return in its body yields a collection or map. */
    private static boolean closureReturnsCollection(ClosureExpression node) {
        codeReturnsCollection(node.code)
    }

    /** Visits the given code (may be null) and reports whether a collection-valued return was seen. */
    private static boolean codeReturnsCollection(code) {
        boolean sawCollectionReturn = false
        code?.visit(new CollectionReturnTracker(callbackFunction: { sawCollectionReturn = true }))
        sawCollectionReturn
    }
}

/**
 * Inspects every return statement it visits. It invokes {@code callbackFunction} (with no
 * arguments) once for each returned expression that produces a collection or map. That
 * means a list literal, a map literal, or a constructor call / cast whose type implements
 * one of COLLECTION_TYPES. Both branches of ternary expressions are examined.
 */
class CollectionReturnTracker extends AbstractAstVisitor {

    /** Types that mark a constructor call or cast as producing a collection or map. */
    private static final List<Class> COLLECTION_TYPES = [Map, Iterable, List, Collection, ArrayList, Set, HashSet]

    /** No-argument closure invoked once per collection-valued returned expression. */
    def callbackFunction

    def void visitReturnStatement(ReturnStatement statement) {
        expressionReturnsList(statement.expression)
        super.visitReturnStatement(statement)
    }

    /**
     * Walks the returned expression, descending into both branches of ternaries. It calls
     * the callback exactly once for each sub-expression that yields a collection or map.
     */
    private void expressionReturnsList(Expression expression) {
        def pending = [expression] as Stack  // explicit work list as alternative to recursion
        while (pending) {
            def current = pending.pop()
            if (isCollectionLiteral(current) || isCollectionConstruction(current)) {
                callbackFunction()
            } else if (current instanceof TernaryExpression) {
                pending.push(current.trueExpression)
                pending.push(current.falseExpression)
            }
        }
    }

    /** True for list or map literals such as {@code []} or {@code [:]}. */
    private static boolean isCollectionLiteral(expression) {
        expression instanceof ListExpression || expression instanceof MapExpression
    }

    /**
     * True for a constructor call or cast whose type implements any of COLLECTION_TYPES.
     * Uses {@code any} so a type matching several entries (e.g. ArrayList) counts once.
     */
    private static boolean isCollectionConstruction(expression) {
        if (!(expression instanceof ConstructorCallExpression || expression instanceof CastExpression)) {
            return false
        }
        COLLECTION_TYPES.any { Class candidate -> AstUtil.classNodeImplementsType(expression.type, candidate) }
    }
}