"""

Copyright 2013 Steven Diamond, 2017 Robin Verschueren

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest

import numpy as np
import scipy.sparse as sp
from scipy.linalg import lstsq

import cvxpy as cp
import cvxpy.tests.solver_test_helpers as sths
from cvxpy import Maximize, Minimize, Parameter, Problem
from cvxpy.atoms import (
    QuadForm,
    abs,
    huber,
    matrix_frac,
    norm,
    power,
    quad_over_lin,
    sum,
    sum_squares,
)
from cvxpy.expressions.variable import Variable
from cvxpy.reductions.solvers.defines import INSTALLED_SOLVERS, QP_SOLVERS
from cvxpy.tests.base_test import BaseTest
from cvxpy.tests.solver_test_helpers import StandardTestLPs, StandardTestQPs


class TestQp(BaseTest):
    """Unit tests for solving quadratic programs (QPs) with the installed QP solvers.

    Most checks are plain helper methods that take a solver name and are
    driven once per usable solver by ``test_all_solvers``. The ``test_*``
    methods that target one specific solver do nothing (or assert the
    "not installed" error) when that solver is unavailable.
    """

    def setUp(self) -> None:
        """Create the shared variables and the list of usable QP solvers."""
        self.a = Variable(name='a')
        self.b = Variable(name='b')
        self.c = Variable(name='c')

        self.x = Variable(2, name='x')
        self.y = Variable(3, name='y')
        self.z = Variable(2, name='z')
        self.w = Variable(5, name='w')

        self.A = Variable((2, 2), name='A')
        self.B = Variable((2, 2), name='B')
        self.C = Variable((3, 2), name='C')

        self.slope = Variable(1, name='slope')
        self.offset = Variable(1, name='offset')
        self.quadratic_coeff = Variable(1, name='quadratic_coeff')

        T = 30
        self.position = Variable((2, T), name='position')
        self.velocity = Variable((2, T), name='velocity')
        self.force = Variable((2, T - 1), name='force')

        self.xs = Variable(80, name='xs')
        self.xsr = Variable(50, name='xsr')
        self.xef = Variable(80, name='xef')

        # Candidate solvers: every QP solver that is importable.
        self.solvers = [x for x in QP_SOLVERS if x in INSTALLED_SOLVERS]

        def is_mosek_available():
            """Check if MOSEK is installed and a license is available.

            Any exception raised while probing is treated as "unavailable".
            """
            if 'MOSEK' not in INSTALLED_SOLVERS:
                return False
            try:
                import mosek  # type: ignore
                env = mosek.Env()
                # Try to get license status (returns 0 if OK)
                status = env.getlicense()
                return status == mosek.rescode.ok
            except Exception:
                return False

        def is_knitro_available():
            """Check if KNITRO is installed and a license is available.

            Any exception raised while probing is treated as "unavailable".
            """
            if 'KNITRO' not in INSTALLED_SOLVERS:
                return False
            try:
                import knitro  # type: ignore
                # Try to create and delete a Knitro solver instance
                kc = knitro.KN_new()
                if kc is None:
                    return False
                knitro.KN_free(kc)
                return True
            except Exception:
                return False

        def is_xpress_available():
            """Check if XPRESS is installed and a license is available.

            Any exception raised while probing is treated as "unavailable".
            """
            if 'XPRESS' not in INSTALLED_SOLVERS:
                return False
            try:
                import xpress  # type: ignore
                env = xpress.env()
                status = env.getlicense()
                return status == 0
            except Exception:
                return False

        # Drop commercial solvers that are installed but cannot be licensed.
        if 'XPRESS' in self.solvers and not is_xpress_available():
            self.solvers.remove('XPRESS')
        if 'MOSEK' in self.solvers and not is_mosek_available():
            self.solvers.remove('MOSEK')
        if 'KNITRO' in self.solvers and not is_knitro_available():
            self.solvers.remove('KNITRO')

    def solve_QP(self, problem, solver_name):
        """Solve ``problem`` quietly with ``solver_name``; return the optimal value."""
        return problem.solve(solver=solver_name, verbose=False)

    def test_all_solvers(self) -> None:
        """Run every QP helper check against each usable solver."""
        for solver in self.solvers:
            # subTest labels any failure with the solver that produced it.
            with self.subTest(solver=solver):
                self.quad_over_lin(solver)
                self.power(solver)
                self.power_matrix(solver)
                self.square_affine(solver)
                self.quad_form(solver)
                self.affine_problem(solver)
                self.maximize_problem(solver)
                self.abs(solver)

                # Do we need the following functionality?
                # self.norm_2(solver)
                # self.mat_norm_2(solver)

                self.quad_form_coeff(solver)
                self.quad_form_bound(solver)
                self.regression_1(solver)
                self.regression_2(solver)
                self.rep_quad_form(solver)

                # slow tests:
                self.control(solver)
                self.sparse_system(solver)
                self.smooth_ridge(solver)
                self.huber_small(solver)
                self.huber(solver)
                self.equivalent_forms_1(solver)
                self.equivalent_forms_2(solver)
                self.equivalent_forms_3(solver)

    def quad_over_lin(self, solver) -> None:
        """quad_over_lin of an abs expression; checks primal and dual values."""
        p = Problem(Minimize(0.5 * quad_over_lin(abs(self.x-1), 1)),
                    [self.x <= -1])
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(np.array([-1., -1.]),
                                        var.value, places=4)
        for con in p.constraints:
            self.assertItemsAlmostEqual(np.array([2., 2.]),
                                        con.dual_value, places=4)

    def abs(self, solver) -> None:
        """A bound on abs() of an affine expression must still yield a QP."""
        u = Variable(2)
        constr = [abs(u[1] - u[0]) <= 100]
        prob = Problem(Minimize(sum_squares(u)), constr)
        self.assertTrue(prob.is_qp())
        result = prob.solve(solver=solver)
        self.assertAlmostEqual(result, 0)

    def power(self, solver) -> None:
        """Unconstrained sum of squares of a vector; optimum at zero."""
        p = Problem(Minimize(sum(power(self.x, 2))), [])
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual([0., 0.], var.value, places=4)

    def power_matrix(self, solver) -> None:
        """Elementwise squares of a shifted matrix variable; optimum at 3."""
        p = Problem(Minimize(sum(power(self.A - 3., 2))), [])
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual([3., 3., 3., 3.],
                                        var.value, places=4)

    def square_affine(self, solver) -> None:
        """Least squares; compared against scipy's lstsq solution."""
        A = np.random.randn(10, 2)
        b = np.random.randn(10)
        p = Problem(Minimize(sum_squares(A @ self.x - b)))
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(lstsq(A, b)[0].flatten(order='F'), var.value,
                                        places=1)

    def quad_form(self, solver) -> None:
        """Minimize x'Px - 2(Pz)'x, whose minimizer is z."""
        np.random.seed(0)
        A = np.random.randn(5, 5)
        z = np.random.randn(5)
        P = A.T.dot(A)
        q = -2*P.dot(z)
        p = Problem(Minimize(QuadForm(self.w, P) + q.T @ self.w))
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(z, var.value, places=4)

    def rep_quad_form(self, solver) -> None:
        """A problem where the quad_form term is used multiple times.
        """
        np.random.seed(0)
        A = np.random.randn(5, 5)
        z = np.random.randn(5)
        P = A.T.dot(A)
        q = -2*P.dot(z)
        qf = QuadForm(self.w, P)
        p = Problem(Minimize(0.5*qf + 0.5*qf + q.T @ self.w))
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(z, var.value, places=4)

    def affine_problem(self, solver) -> None:
        """LP with nonnegative data; minimizing sum(x) over x >= 0 gives 0."""
        A = np.random.randn(5, 2)
        A = np.maximum(A, 0)
        b = np.random.randn(5)
        b = np.maximum(b, 0)
        p = Problem(Minimize(sum(self.x)), [self.x >= 0, A @ self.x <= b])
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual([0., 0.], var.value, places=3)

    def maximize_problem(self, solver) -> None:
        """Same LP as affine_problem, posed as maximizing -sum(x)."""
        A = np.random.randn(5, 2)
        A = np.maximum(A, 0)
        b = np.random.randn(5)
        b = np.maximum(b, 0)
        p = Problem(Maximize(-sum(self.x)), [self.x >= 0, A @ self.x <= b])
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual([0., 0.], var.value, places=3)

    def norm_2(self, solver) -> None:
        """Minimize a vector 2-norm residual; compared against lstsq."""
        A = np.random.randn(10, 5)
        b = np.random.randn(10)
        p = Problem(Minimize(norm(A @ self.w - b, 2)))
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(lstsq(A, b)[0].flatten(order='F'), var.value,
                                        places=1)

    def mat_norm_2(self, solver) -> None:
        """Minimize a matrix 2-norm residual; compared against lstsq."""
        A = np.random.randn(5, 3)
        B = np.random.randn(5, 2)
        p = Problem(Minimize(norm(A @ self.C - B, 2)))
        # solve() returns the optimal value (a float), not a solution object,
        # so the primal solution must be read from the variable itself.
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(lstsq(A, B)[0],
                                        var.value, places=1)

    def quad_form_coeff(self, solver) -> None:
        """Minimize x'Px - 2(Pz)'x, whose minimizer is z."""
        np.random.seed(0)
        A = np.random.randn(5, 5)
        z = np.random.randn(5)
        P = A.T.dot(A)
        q = -2*P.dot(z)
        p = Problem(Minimize(QuadForm(self.w, P) + q.T @ self.w))
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(z, var.value, places=4)

    def quad_form_bound(self, solver) -> None:
        """Box-constrained QP with a known solution y_star."""
        P = np.array([[13, 12, -2], [12, 17, 6], [-2, 6, 12]])
        q = np.array([[-22], [-14.5], [13]])
        r = 1
        y_star = np.array([[1], [0.5], [-1]])
        p = Problem(Minimize(0.5*QuadForm(self.y, P) + q.T @ self.y + r),
                    [self.y >= -1, self.y <= 1])
        self.solve_QP(p, solver)
        for var in p.variables():
            self.assertItemsAlmostEqual(y_star, var.value, places=4)

    def regression_1(self, solver) -> None:
        """Linear fit to cubic data; checks the optimal residual."""
        np.random.seed(1)
        # Number of examples to use
        n = 100
        # Specify the true value of the variable
        true_coeffs = np.array([[2, -2, 0.5]]).T
        # Generate data
        x_data = np.random.rand(n) * 5
        x_data = np.atleast_2d(x_data)
        x_data_expanded = np.vstack([np.power(x_data, i)
                                     for i in range(1, 4)])
        x_data_expanded = np.atleast_2d(x_data_expanded)
        y_data = x_data_expanded.T.dot(true_coeffs) + 0.5 * np.random.rand(n, 1)
        y_data = np.atleast_2d(y_data)

        line = self.offset + x_data * self.slope
        residuals = line.T - y_data
        fit_error = sum_squares(residuals)
        p = Problem(Minimize(fit_error), [])
        self.solve_QP(p, solver)
        self.assertAlmostEqual(1171.60037715, p.value, places=4)

    def regression_2(self, solver) -> None:
        """Quadratic fit to cubic data; checks the optimal residual."""
        np.random.seed(1)
        # Number of examples to use
        n = 100
        # Specify the true value of the variable
        true_coeffs = np.array([2, -2, 0.5])
        # Generate data
        x_data = np.random.rand(n) * 5
        x_data_expanded = np.vstack([np.power(x_data, i)
                                     for i in range(1, 4)])
        y_data = x_data_expanded.T.dot(true_coeffs) + 0.5 * np.random.rand(n)

        quadratic = self.offset + x_data * self.slope + \
            self.quadratic_coeff*np.power(x_data, 2)
        residuals = quadratic.T - y_data
        fit_error = sum_squares(residuals)
        p = Problem(Minimize(fit_error), [])
        self.solve_QP(p, solver)

        self.assertAlmostEqual(139.225660756, p.value, places=4)

    def control(self, solver) -> None:
        """Minimum-force trajectory for a point mass under gravity and drag."""
        # Some constraints on our motion
        # The object should start from the origin, and end at rest
        initial_velocity = np.array([-20, 100])
        final_position = np.array([100, 100])
        T = 30  # The number of timesteps
        h = 0.1  # The time between time intervals
        mass = 1  # Mass of object
        drag = 0.1  # Drag on object
        g = np.array([0, -9.8])  # Gravity on object
        # Create a problem instance
        constraints = []
        # Add constraints on our variables
        for i in range(T - 1):
            constraints += [self.position[:, i + 1] == self.position[:, i] +
                            h * self.velocity[:, i]]
            acceleration = self.force[:, i]/mass + g - \
                drag * self.velocity[:, i]
            constraints += [self.velocity[:, i + 1] == self.velocity[:, i] +
                            h * acceleration]

        # Add position constraints
        constraints += [self.position[:, 0] == 0]
        constraints += [self.position[:, -1] == final_position]
        # Add velocity constraints
        constraints += [self.velocity[:, 0] == initial_velocity]
        constraints += [self.velocity[:, -1] == 0]
        # Solve the problem
        p = Problem(Minimize(.01 * sum_squares(self.force)), constraints)
        self.solve_QP(p, solver)
        self.assertAlmostEqual(1059.616, p.value, places=1)

    def sparse_system(self, solver) -> None:
        """Sparse least squares with x fixed at 0; optimum equals ||b||^2."""
        m = 100
        n = 80
        np.random.seed(1)
        density = 0.4
        A = sp.random_array((m, n), density=density)
        b = np.random.randn(m)

        p = Problem(Minimize(sum_squares(A @ self.xs - b)), [self.xs == 0])
        self.solve_QP(p, solver)
        self.assertAlmostEqual(b.T.dot(b), p.value, places=4)

    def smooth_ridge(self, solver) -> None:
        """Least squares plus a smoothing penalty; a zero-cost optimum exists."""
        np.random.seed(1)
        n = 50
        k = 20
        eta = 1

        A = np.ones((k, n))
        b = np.ones((k))
        obj = sum_squares(A @ self.xsr - b) + \
            eta*sum_squares(self.xsr[:-1]-self.xsr[1:])
        p = Problem(Minimize(obj), [])
        self.solve_QP(p, solver)
        self.assertAlmostEqual(0, p.value, places=4)

    def huber_small(self, solver) -> None:
        """Huber penalty with one component forced into the linear region."""
        # Solve the Huber regression problem
        x = Variable(3)
        objective = sum(huber(x))

        # Solve problem with QP
        p = Problem(Minimize(objective), [x[2] >= 3])
        self.solve_QP(p, solver)
        self.assertAlmostEqual(3, x.value[2], places=4)
        self.assertAlmostEqual(5, objective.value, places=4)

    def huber(self, solver) -> None:
        """Huber regression on fixed sparse data with seeded random noise."""
        # Seed locally: the expected values below depend on the random draws.
        # Previously they relied on an earlier helper having seeded with 1.
        np.random.seed(1)
        # Generate problem data
        n = 3
        m = 5
        data = [0.89, 0.39, 0.96, 0.34, 0.68, 0.18, 0.63, 0.42, 0.51, 0.66, 0.43, 0.77]
        indices = [0, 1, 2, 3, 4, 2, 3, 0, 1, 2, 3, 4]
        indptr = [0, 5, 7, 12]
        A = sp.csc_array((data, indices, indptr), shape=(m, n))
        x_true = np.random.randn(n) / np.sqrt(n)
        ind95 = (np.random.rand(m) < 0.95).astype(float)
        b = A.dot(x_true) + np.multiply(0.5*np.random.randn(m), ind95) \
            + np.multiply(10.*np.random.rand(m), 1. - ind95)

        # Solve the Huber regression problem
        x = Variable(n)
        objective = sum(huber(A @ x - b))

        # Solve problem with QP
        p = Problem(Minimize(objective))
        self.solve_QP(p, solver)
        self.assertAlmostEqual(1.452797819667, objective.value, places=3)
        self.assertItemsAlmostEqual(x.value,
                                    [1.20524645, -0.85271489, -0.50838494],
                                    places=3)

    def equivalent_forms_1(self, solver) -> None:
        """Equality-constrained least squares written as a sum of squares."""
        m = 100
        n = 80
        r = 70
        np.random.seed(1)
        A = np.random.randn(m, n)
        b = np.random.randn(m)
        G = np.random.randn(r, n)
        h = np.random.randn(r)

        obj1 = .1 * sum((A @ self.xef - b) ** 2)
        cons = [G @ self.xef == h]

        p1 = Problem(Minimize(obj1), cons)
        self.solve_QP(p1, solver)
        self.assertAlmostEqual(p1.value, 68.1119420108, places=4)

    def equivalent_forms_2(self, solver) -> None:
        """The same problem as equivalent_forms_1, written with QuadForm."""
        m = 100
        n = 80
        r = 70
        np.random.seed(1)
        A = np.random.randn(m, n)
        b = np.random.randn(m)
        G = np.random.randn(r, n)
        h = np.random.randn(r)

        # ||Ax-b||^2 = x^T (A^T A) x - 2(A^T b)^T x + ||b||^2
        P = np.dot(A.T, A)
        q = -2*np.dot(A.T, b)
        r = np.dot(b.T, b)

        obj2 = .1*(QuadForm(self.xef, P)+q.T @ self.xef+r)
        cons = [G @ self.xef == h]

        p2 = Problem(Minimize(obj2), cons)
        self.solve_QP(p2, solver)
        self.assertAlmostEqual(p2.value, 68.1119420108, places=4)

    def equivalent_forms_3(self, solver) -> None:
        """The same problem as equivalent_forms_1, written with matrix_frac."""
        m = 100
        n = 80
        r = 70
        np.random.seed(1)
        A = np.random.randn(m, n)
        b = np.random.randn(m)
        G = np.random.randn(r, n)
        h = np.random.randn(r)

        # ||Ax-b||^2 = x^T (A^T A) x - 2(A^T b)^T x + ||b||^2
        P = np.dot(A.T, A)
        q = -2*np.dot(A.T, b)
        r = np.dot(b.T, b)
        Pinv = np.linalg.inv(P)

        obj3 = .1 * (matrix_frac(self.xef, Pinv)+q.T @ self.xef+r)
        cons = [G @ self.xef == h]

        p3 = Problem(Minimize(obj3), cons)
        self.solve_QP(p3, solver)
        self.assertAlmostEqual(p3.value, 68.1119420108, places=4)

    def test_warm_start(self) -> None:
        """Test warm start.
        """
        m = 200
        n = 100
        np.random.seed(1)
        A = np.random.randn(m, n)
        b = Parameter(m)

        # Construct the problem.
        x = Variable(n)
        prob = Problem(Minimize(sum_squares(A @ x - b)))

        b.value = np.random.randn(m)
        result = prob.solve(solver="OSQP", warm_start=False)
        result2 = prob.solve(solver="OSQP", warm_start=True)
        self.assertAlmostEqual(result, result2)
        b.value = np.random.randn(m)
        result = prob.solve(solver="OSQP", warm_start=True)
        result2 = prob.solve(solver="OSQP", warm_start=False)
        self.assertAlmostEqual(result, result2)

    def test_gurobi_warmstart(self) -> None:
        """Test Gurobi warm start with a user provided point.
        """
        if cp.GUROBI in INSTALLED_SOLVERS:
            import gurobipy
            m = 4
            n = 3

            y = Variable(nonneg=True)
            X = Variable((m, n))
            X_vals = np.reshape(np.arange(m*n), (m, n))
            prob = Problem(Minimize(y**2 + cp.sum(X)), [X == X_vals])
            X.value = X_vals + 1
            prob.solve(solver=cp.GUROBI, warm_start=True)
            # Check that "start" value was set appropriately.
            model = prob.solver_stats.extra_stats
            model_x = model.getVars()
            assert gurobipy.GRB.UNDEFINED == model_x[0].start
            assert np.isclose(0, model_x[0].x)
            # Gurobi variables 1.. follow X in column-major order.
            for i in range(1, X.size + 1):
                row = (i - 1) % X.shape[0]
                col = (i - 1) // X.shape[0]
                assert X_vals[row, col] + 1 == model_x[i].start
                assert np.isclose(X.value[row, col], model_x[i].x)

    def test_xpress_warmstart(self) -> None:
        """Test XPRESS warm start with a user provided point.
        """
        if cp.XPRESS in INSTALLED_SOLVERS:
            m = 20
            n = 10
            np.random.seed(1)
            A = np.random.randn(m, n)
            b = Parameter(m)

            # Construct the problem.
            x = Variable(n, integer=True)
            prob = Problem(Minimize(sum_squares(A @ x - b)))

            b.value = np.random.randn(m)
            result = prob.solve(solver=cp.XPRESS, warm_start=False)
            result2 = prob.solve(solver=cp.XPRESS, warm_start=True)
            self.assertAlmostEqual(result, result2)
            x.value = x.value.astype(np.int64)

            xprime = Variable(n, integer=True)
            prob = Problem(Minimize(sum_squares(A @ xprime - b)))
            xprime.value = x.value
            result = prob.solve(solver=cp.XPRESS, warm_start=True)
            result2 = prob.solve(solver=cp.XPRESS, warm_start=False)
            self.assertAlmostEqual(result, result2)

    def test_highs_warmstart(self) -> None:
        """Test warm start.
        """
        if cp.HIGHS in INSTALLED_SOLVERS:
            m = 200
            n = 100
            np.random.seed(1)
            A = np.random.randn(m, n)
            b = Parameter(m)

            # Construct the problem.
            x = Variable(n)
            prob = Problem(Minimize(sum_squares(A @ x - b)))

            b.value = np.random.randn(m)
            result = prob.solve(solver=cp.HIGHS, warm_start=False)
            result2 = prob.solve(solver=cp.HIGHS, warm_start=True)
            self.assertAlmostEqual(result, result2)
            b.value = np.random.randn(m)
            result = prob.solve(solver=cp.HIGHS, warm_start=True)
            result2 = prob.solve(solver=cp.HIGHS, warm_start=False)
            self.assertAlmostEqual(result, result2)

    def test_highs_cvar(self) -> None:
        """Test problem with CVaR constraint from
        https://github.com/cvxpy/cvxpy/issues/2836
        """
        if cp.HIGHS in INSTALLED_SOLVERS:
            # Generate data
            num_stocks = 5
            num_samples = 25
            np.random.seed(1)
            pnl_samples = np.random.uniform(low=0.0, high=1.0, size=(num_samples, num_stocks))
            pnl_expected = pnl_samples.mean(axis=0)

            # Prepare to solve
            quantile = 0.05
            w = cp.Variable(num_stocks, nonneg=True)
            cvar = cp.cvar(pnl_samples @ w, 1 - quantile)
            pnl = w @ pnl_expected

            # Solve
            objective = cp.Maximize(pnl)
            constraints = [cvar <= 0.5]
            problem = cp.Problem(objective, constraints)
            problem.solve(
                solver=cp.HIGHS,
            )
            assert problem.status == cp.OPTIMAL

    def test_piqp_warmstart(self) -> None:
        """Test warm start.
        """
        if cp.PIQP in INSTALLED_SOLVERS:
            m = 200
            n = 100
            np.random.seed(1)
            A = np.random.randn(m, n)
            b = Parameter(m)

            # Construct the problem.
            x = Variable(n)
            prob = Problem(Minimize(sum_squares(A @ x - b)))

            b.value = np.random.randn(m)
            result = prob.solve(solver=cp.PIQP, warm_start=False)
            result2 = prob.solve(solver=cp.PIQP, warm_start=True)
            self.assertAlmostEqual(result, result2)
            b.value = np.random.randn(m)
            result = prob.solve(solver=cp.PIQP, warm_start=True)
            result2 = prob.solve(solver=cp.PIQP, warm_start=False)
            self.assertAlmostEqual(result, result2)

    def test_parametric(self) -> None:
        """Test solve parametric problem vs full problem"""
        x = Variable()
        a = 10
        #  b_vec = [-10, -2., 2., 3., 10.]
        b_vec = [-10, -2.]

        for solver in self.solvers:
            # subTest labels any failure with the solver that produced it.
            with self.subTest(solver=solver):
                # Solve from scratch with no parameters
                x_full = []
                obj_full = []
                for b in b_vec:
                    obj = Minimize(a * (x ** 2) + b * x)
                    constraints = [0 <= x, x <= 1]
                    prob = Problem(obj, constraints)
                    prob.solve(solver=solver)
                    x_full += [x.value]
                    obj_full += [prob.value]

                # Solve parametric
                x_param = []
                obj_param = []
                b = Parameter()
                obj = Minimize(a * (x ** 2) + b * x)
                constraints = [0 <= x, x <= 1]
                prob = Problem(obj, constraints)
                for b_value in b_vec:
                    b.value = b_value
                    prob.solve(solver=solver)
                    x_param += [x.value]
                    obj_param += [prob.value]

                for i in range(len(b_vec)):
                    self.assertItemsAlmostEqual(x_full[i], x_param[i], places=3)
                    self.assertAlmostEqual(obj_full[i], obj_param[i])

    def test_square_param(self) -> None:
        """Test issue arising with square plus parameter.
        """
        a = Parameter(value=1)
        b = Variable()

        obj = Minimize(b ** 2 + abs(a))
        prob = Problem(obj)
        prob.solve(solver="SCS")
        self.assertAlmostEqual(obj.value, 1.0)

    def test_gurobi_time_limit_no_solution(self) -> None:
        """Make sure that if Gurobi terminates due to a time limit before finding a solution:
            1) no error is raised,
            2) solver stats are returned.
            The test is skipped if something changes on Gurobi's side so that:
            - a solution is found despite a time limit of zero,
            - a different termination criteria is hit first.
        """
        from cvxpy import GUROBI
        if GUROBI in INSTALLED_SOLVERS:
            import gurobipy
            objective = Minimize(self.x[0])
            constraints = [self.x[0] >= 1]
            prob = Problem(objective, constraints)
            try:
                prob.solve(solver=GUROBI, TimeLimit=0.0)
            except Exception as e:
                self.fail("An exception %s is raised instead of returning a result." % e)

            extra_stats = None
            solver_stats = getattr(prob, "solver_stats", None)
            if solver_stats:
                extra_stats = getattr(solver_stats, "extra_stats", None)
            self.assertTrue(extra_stats, "Solver stats have not been returned.")

            nb_solutions = getattr(extra_stats, "SolCount", None)
            if nb_solutions:
                self.skipTest("Gurobi has found a solution, the test is not relevant anymore.")

            solver_status = getattr(extra_stats, "Status", None)
            if solver_status != gurobipy.GRB.TIME_LIMIT:
                self.skipTest("Gurobi terminated for a different reason than reaching time limit, "
                              "the test is not relevant anymore.")

        else:
            with self.assertRaises(Exception) as cm:
                prob = Problem(Minimize(norm(self.x, 1)), [self.x == 0])
                prob.solve(solver=GUROBI, TimeLimit=0)
            self.assertEqual(str(cm.exception), "The solver %s is not installed." % GUROBI)

    def test_gurobi_environment(self) -> None:
        """Tests that Gurobi environments can be passed to Model.
        Gurobi environments can include licensing and model parameter data.
        """
        from cvxpy import GUROBI
        if GUROBI in INSTALLED_SOLVERS:
            import gurobipy

            # Set a few parameters to random values close to their defaults
            params = {
                'MIPGap': np.random.random(),  # range {0, INFINITY}
                'AggFill': np.random.randint(10),  # range {-1, MAXINT}
                'PerturbValue': np.random.random(),  # range: {0, INFINITY}
            }

            # Create a custom environment and set some parameters
            custom_env = gurobipy.Env()
            for k, v in params.items():
                custom_env.setParam(k, v)

            # Testing QP Solver Interface
            sth = StandardTestLPs.test_lp_0(solver='GUROBI', env=custom_env)
            model = sth.prob.solver_stats.extra_stats
            for k, v in params.items():
                # https://www.gurobi.com/documentation/9.1/refman/py_model_getparaminfo.html
                # getParamInfo returns (name, type, value, min, max, default).
                _, _, p_val, _, _, _ = model.getParamInfo(k)
                self.assertEqual(v, p_val)

        else:
            with self.assertRaises(Exception) as cm:
                prob = Problem(Minimize(norm(self.x, 1)), [self.x == 0])
                prob.solve(solver=GUROBI, TimeLimit=0)
            self.assertEqual(str(cm.exception), "The solver %s is not installed." % GUROBI)


@unittest.skipUnless('MPAX' in INSTALLED_SOLVERS, 'MPAX is not installed.')
class TestMPAX(unittest.TestCase):
    """Run the standard LP/QP test problems through the MPAX solver."""

    def test_mpax_lp_0(self) -> None:
        StandardTestLPs.test_lp_0(solver='MPAX')

    def test_mpax_lp_1(self) -> None:
        StandardTestLPs.test_lp_1(solver='MPAX')

    def test_mpax_lp_2(self) -> None:
        StandardTestLPs.test_lp_2(solver='MPAX')

    def test_mpax_lp_3(self) -> None:
        """lp_3 must emit a warning and end as infeasible-or-unbounded."""
        sth = sths.lp_3()
        with self.assertWarns(Warning):
            sth.prob.solve(solver='MPAX')
        # Checked outside the context so assertWarns only covers the solve.
        self.assertEqual(sth.prob.status, cp.settings.INFEASIBLE_OR_UNBOUNDED)

    def test_mpax_lp_4(self) -> None:
        """lp_4 must emit a warning and end as infeasible-or-unbounded."""
        sth = sths.lp_4()
        with self.assertWarns(Warning):
            sth.prob.solve(solver='MPAX')
        # Checked outside the context so assertWarns only covers the solve.
        self.assertEqual(sth.prob.status, cp.settings.INFEASIBLE_OR_UNBOUNDED)

    def test_mpax_lp_5(self) -> None:
        StandardTestLPs.test_lp_5(solver='MPAX')

    def test_mpax_lp_6(self) -> None:
        StandardTestLPs.test_lp_6(solver='MPAX')

    def test_mpax_warmstart(self) -> None:
        """A small LP solves to -9 both cold and warm started."""
        x = cp.Variable(shape=(2,), name='x')
        objective = cp.Minimize(-4 * x[0] - 5 * x[1])
        constraints = [2 * x[0] + x[1] <= 3,
                       x[0] + 2 * x[1] <= 3,
                       x[0] >= 0,
                       x[1] >= 0]
        prob = cp.Problem(objective, constraints)
        result1 = prob.solve(solver='MPAX', warm_start=False)
        self.assertAlmostEqual(result1, -9, places=4)
        result2 = prob.solve(solver='MPAX', warm_start=True)
        self.assertAlmostEqual(result2, -9, places=4)

    def test_MPAX_qp_0(self) -> None:
        StandardTestQPs.test_qp_0(solver='MPAX')
