"""
Copyright 2013 Steven Diamond, 2017 Akshay Agrawal, 2017 Robin Verschueren

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from cvxpy import problems
from cvxpy.expressions import cvxtypes
from cvxpy.expressions.expression import Expression
from cvxpy.reductions.inverse_data import InverseData
from cvxpy.reductions.reduction import Reduction
from cvxpy.reductions.solution import Solution


class Canonicalization(Reduction):
    """Recursively canonicalize each expression in a problem.

    This reduction recursively canonicalizes every expression tree in a
    problem, visiting each node. At every node, this reduction first
    canonicalizes its arguments; it then canonicalizes the node, using the
    canonicalized arguments.

    The attribute `canon_methods` is a dictionary
    mapping node types to functions that canonicalize them; the signature
    of these canonicalizing functions must be

        def canon_func(expr, canon_args) --> (new_expr, constraints)

    where `expr` is the `Expression` (node) to canonicalize, canon_args
    is a list of the canonicalized arguments of this expression,
    `new_expr` is a canonicalized expression, and `constraints` is a list
    of constraints introduced while canonicalizing `expr`.

    Attributes:
    ----------
        canon_methods : dict
            A dictionary mapping node types to canonicalization functions.
        problem : Problem
            A problem owned by this reduction.
    """
    def __init__(self, canon_methods, problem=None) -> None:
        super().__init__(problem=problem)
        self.canon_methods = canon_methods

    def apply(self, problem):
        """Recursively canonicalize the objective and every constraint.

        Args:
            problem: The problem whose objective and constraints are
                canonicalized.

        Returns:
            A tuple `(new_problem, inverse_data)`. `new_problem` is built
            from the canonicalized objective and the canonicalized
            constraints (including any constraints introduced along the
            way). `inverse_data.cons_id_map` maps the id of each original
            constraint to the id of its canonicalized counterpart.
        """
        inverse_data = InverseData(problem)

        # The constraints generated while canonicalizing the objective
        # seed the constraint list of the new problem.
        canon_objective, canon_constraints = self.canonicalize_tree(
            problem.objective)

        for constraint in problem.constraints:
            # canon_constr is the constraint re-expressed in terms of
            # its canonicalized arguments, and aux_constr are the constraints
            # generated while canonicalizing the arguments of the original
            # constraint
            canon_constr, aux_constr = self.canonicalize_tree(
                constraint)
            canon_constraints += aux_constr + [canon_constr]
            inverse_data.cons_id_map.update({constraint.id: canon_constr.id})

        new_problem = problems.problem.Problem(canon_objective,
                                               canon_constraints)
        return new_problem, inverse_data

    def invert(self, solution, inverse_data):
        """Map a solution of the canonicalized problem back to the original.

        Args:
            solution: Solution of the canonicalized problem.
            inverse_data: The inverse data returned by `apply`.

        Returns:
            A `Solution` with the same status, optimal value and attributes
            as `solution`. Its primal variables are restricted to the ids
            found in `inverse_data.id_map`; its dual variables are keyed by
            the ids of the original constraints.
        """
        # Keep only primal values whose ids are listed in id_map.
        pvars = {vid: solution.primal_vars[vid] for vid in inverse_data.id_map
                 if vid in solution.primal_vars}
        # cons_id_map goes original id -> canonical id; re-key each dual
        # value from the canonical id back to the original id.
        dvars = {orig_id: solution.dual_vars[vid]
                 for orig_id, vid in inverse_data.cons_id_map.items()
                 if vid in solution.dual_vars}

        return Solution(solution.status, solution.opt_val, pvars, dvars,
                        solution.attr)

    def canonicalize_tree(self, expr, canonicalize_params: bool = True):
        """Recursively canonicalize an Expression.

        Args:
            expr: Expression to canonicalize.
            canonicalize_params: Should constant subtrees
                containing parameters be canonicalized?

        Returns:
            canonicalized expression, constraints
        """
        # TODO don't copy affine expressions?
        # Exact type comparison: subclasses of the partial problem type
        # are not matched by this branch.
        if type(expr) == cvxtypes.partial_problem():
            # A partial problem wraps a problem in args[0]: canonicalize
            # that problem's objective expression and all its constraints.
            canon_expr, constrs = self.canonicalize_tree(
                expr.args[0].objective.expr,
                canonicalize_params=canonicalize_params)
            for constr in expr.args[0].constraints:
                canon_constr, aux_constr = self.canonicalize_tree(
                    constr,
                    canonicalize_params=canonicalize_params
                )
                constrs += [canon_constr] + aux_constr
        else:
            # Canonicalize the arguments first, then the node itself.
            canon_args = []
            constrs = []
            for arg in expr.args:
                canon_arg, c = self.canonicalize_tree(
                    arg,
                    canonicalize_params=canonicalize_params
                )
                canon_args.append(canon_arg)
                constrs += c
            canon_expr, c = self.canonicalize_expr(
                expr,
                canon_args,
                canonicalize_params=canonicalize_params
            )
            constrs += c
        return canon_expr, constrs

    def canonicalize_expr(
            self,
            expr: Expression,
            args: list,
            canonicalize_params: bool = True
        ):
        """Canonicalize an expression, w.r.t. canonicalized arguments.

        Args:
            expr: Expression to canonicalize.
            args: Arguments to the expression.
            canonicalize_params: Should constant subtrees
                containing parameters be canonicalized?

        Returns:
            canonicalized expression, constraints
        """
        # Constant expressions are returned unchanged. The exception is a
        # constant expression that contains parameters while
        # canonicalize_params is True: it is canonicalized like any other
        # node. Objects that are not Expressions are never skipped.
        #
        # The checks are ordered so that expr.parameters() is only
        # evaluated for constant expressions; its result cannot affect the
        # decision for non-constant ones.
        skip_canon = (
            isinstance(expr, Expression)
            and expr.is_constant()
            and not (canonicalize_params and expr.parameters())
        )

        if skip_canon:
            return expr, []
        if type(expr) in self.canon_methods:
            return self.canon_methods[type(expr)](expr, args)
        # No registered canonicalization: rebuild the node on top of the
        # canonicalized arguments.
        return expr.copy(args), []
