#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from ._scf_ops_gen import *
from ._scf_ops_gen import _Dialect
from .arith import constant

try:
    from ..ir import *
    from ._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        get_op_result_or_op_results as _get_op_result_or_op_results,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Tuple, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(ForOp):
    """Extends the generated SCF `for` op class with a convenience builder."""

    def __init__(
        self,
        lower_bound,
        upper_bound,
        step,
        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Builds an SCF `for` operation together with its entry block.

        Args:
          lower_bound: value used as the lower bound of the loop.
          upper_bound: value used as the upper bound of the loop.
          step: value used as the loop step.
          iter_args: loop-carried values, given either as a sequence of values
            or as an operation whose results are used. `None` means the loop
            carries no values.
          loc: forwarded unchanged to the generated builder.
          ip: forwarded unchanged to the generated builder.
        """
        carried = _get_op_results_or_values([] if iter_args is None else iter_args)
        # The op produces one result per loop-carried value, with the same type.
        carried_types = [value.type for value in carried]

        super().__init__(
            carried_types,
            lower_bound,
            upper_bound,
            step,
            carried,
            loc=loc,
            ip=ip,
        )

        # Entry block signature: one argument typed like the op's first operand
        # (the lower bound), followed by one argument per loop-carried value.
        entry_region = self.regions[0]
        entry_region.blocks.append(self.operands[0].type, *carried_types)

    @property
    def body(self):
        """The first block of the loop's only region, i.e. the loop body."""
        first_region = self.regions[0]
        return first_region.blocks[0]

    @property
    def induction_variable(self):
        """The first argument of the body block (the induction variable)."""
        block_args = self.body.arguments
        return block_args[0]

    @property
    def inner_iter_args(self):
        """The body block arguments that follow the induction variable.

        These are the loop-carried values as seen from inside the loop body;
        there is one per entry of the `iter_args` given at construction.
        """
        block_args = self.body.arguments
        return block_args[1:]


def _dispatch_index_op_fold_results(
    ofrs: Sequence[Union[Operation, OpView, Value, int]],
) -> Tuple[List[Value], List[int]]:
    """Python counterpart of `mlir::dispatchIndexOpFoldResults`.

    Splits a mixed sequence into its dynamic and static parts.

    Every entry that is an `Operation`, `OpView` or `Value` is resolved to a
    value and collected in the dynamic list; its slot in the static list is
    filled with `ShapedType.get_dynamic_size()`. Every other entry is copied
    into the static list as is.

    Returns:
      A `(dynamic, static)` pair. `static` has exactly one entry per input;
      `dynamic` has one entry per IR-object input, in input order.
    """
    ir_kinds = (Operation, OpView, Value)
    dynamic: List[Value] = []
    static: List[int] = []
    for entry in ofrs:
        if not isinstance(entry, ir_kinds):
            # Anything that is not an IR object is kept verbatim.
            static.append(entry)
            continue
        dynamic.append(_get_op_result_or_value(entry))
        # Placeholder marking that the real value is in the dynamic list.
        static.append(ShapedType.get_dynamic_size())
    return dynamic, static


@_ods_cext.register_operation(_Dialect, replace=True)
class ForallOp(ForallOp):
    """Extends the generated SCF `forall` op class with a convenience builder."""

    def __init__(
        self,
        lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
        upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
        steps: Sequence[Union[Value, int]],
        shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        mapping=None,
        loc=None,
        ip=None,
    ):
        """Builds an SCF `forall` operation together with its entry block.

        Args:
          lower_bounds: one lower bound per loop dimension.
          upper_bounds: one upper bound per loop dimension.
          steps: one step per loop dimension.
          shared_outs: additional values threaded through the loop, given
            either as a sequence of values or as an operation whose results
            are used. `None` means there are none.
          mapping: forwarded unchanged to the generated builder.
          loc: forwarded unchanged to the generated builder.
          ip: forwarded unchanged to the generated builder.

        The three bound sequences must have the same length. Entries that are
        IR objects become dynamic operands; all other entries are passed as
        static values.
        """
        assert (
            len(lower_bounds) == len(upper_bounds) == len(steps)
        ), "Mismatch in length of lower bounds, upper bounds, and steps"

        outs = _get_op_results_or_values([] if shared_outs is None else shared_outs)

        # Split each bound list into its dynamic operands and static values.
        lb_dynamic, lb_static = _dispatch_index_op_fold_results(lower_bounds)
        ub_dynamic, ub_static = _dispatch_index_op_fold_results(upper_bounds)
        step_dynamic, step_static = _dispatch_index_op_fold_results(steps)

        # The op produces one result per shared output, with the same type.
        out_types = [value.type for value in outs]
        super().__init__(
            out_types,
            lb_dynamic,
            ub_dynamic,
            step_dynamic,
            lb_static,
            ub_static,
            step_static,
            outs,
            mapping=mapping,
            loc=loc,
            ip=ip,
        )

        # Entry block signature: one `index` argument per loop dimension,
        # followed by one argument per shared output.
        index_types = [IndexType.get()] * len(lb_static)
        self.regions[0].blocks.append(*index_types, *out_types)

    @property
    def body(self) -> Block:
        """The first block of the loop's first region, i.e. the loop body."""
        first_region = self.regions[0]
        return first_region.blocks[0]

    @property
    def rank(self) -> int:
        """The number of loop dimensions (length of `staticLowerBound`)."""
        static_lower = self.staticLowerBound
        return len(static_lower)

    @property
    def induction_variables(self) -> BlockArgumentList:
        """The leading `rank` arguments of the body block."""
        n_ivs = self.rank
        return self.body.arguments[:n_ivs]

    @property
    def inner_iter_args(self) -> BlockArgumentList:
        """The body block arguments that follow the induction variables.

        There is one per entry of the `shared_outs` given at construction.
        """
        n_ivs = self.rank
        return self.body.arguments[n_ivs:]

    def terminator(self) -> InParallelOp:
        """Returns the `InParallelOp` ending the body, creating it if needed.

        If the last operation of the body is already an `InParallelOp`, it is
        returned. Otherwise (including when the body is empty) a new one is
        created with the insertion point set to the body block.
        """
        body_ops = self.body.operations
        with InsertionPoint(self.body):
            if body_ops:
                tail = body_ops[len(body_ops) - 1]
                if isinstance(tail, InParallelOp):
                    return tail
            return InParallelOp()


@_ods_cext.register_operation(_Dialect, replace=True)
class InParallelOp(InParallelOp):
    """Extends the generated SCF `forall.in_parallel` op class."""

    def __init__(self, loc=None, ip=None):
        """Builds the op and gives its region one empty block.

        `loc` and `ip` are forwarded unchanged to the generated builder.
        """
        super().__init__(loc=loc, ip=ip)
        # Create the single argument-less block that holds the op's contents.
        self.region.blocks.append()

    @property
    def block(self) -> Block:
        """The first block of the op's region."""
        return self.region.blocks[0]


@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
    """Specialization for the SCF if op class."""

    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
        """Creates an SCF `if` operation.

        - `cond` is a MLIR value of 'i1' type to determine which regions of code
          will be executed.
        - `results_` is an optional iterable forwarded (copied into a list) as
          the `results` argument of the generated builder; `None` means the
          operation has no results.
        - `hasElse` determines whether the if operation has the else branch.
        - `loc` and `ip` are forwarded unchanged to the generated builder.
        """
        # Copy so the caller's sequence is never shared with the builder.
        result_types = [] if results_ is None else list(results_)
        super().__init__(result_types, cond, loc=loc, ip=ip)

        # The `then` region always gets an (argument-less) block; the `else`
        # region only gets one when explicitly requested.
        self.regions[0].blocks.append()
        if hasElse:
            self.regions[1].blocks.append()

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation.

        This constructor only creates that block when `hasElse=True`.
        """
        return self.regions[1].blocks[0]


def for_(
    start,
    stop=None,
    step=None,
    iter_args: Optional[Sequence[Value]] = None,
    *,
    loc=None,
    ip=None,
):
    """Generator that builds an SCF `for` loop with `range`-like arguments.

    With a single positional argument, that argument is the upper bound and
    the lower bound is 0. A missing `step` defaults to 1. Python `int` bounds
    are materialized as `index`-typed constants; `float` bounds are rejected
    with `ValueError`; anything else is passed to `ForOp` unchanged.

    The generator yields exactly once, while the insertion point is set to the
    loop body:

    - no `iter_args`: the induction variable;
    - one iter arg: `(induction variable, inner iter arg, loop result)`;
    - several: `(induction variable, inner iter args tuple, loop results)`.
    """
    if stop is None:
        # Single-argument form: for_(n) iterates over [0, n).
        start, stop = 0, start
    if step is None:
        step = 1

    bounds = []
    for p in (start, stop, step):
        if isinstance(p, int):
            bounds.append(constant(IndexType.get(), p))
        elif isinstance(p, float):
            raise ValueError(f"{p=} must be int.")
        else:
            bounds.append(p)
    lower, upper, stride = bounds

    loop = ForOp(lower, upper, stride, iter_args, loc=loc, ip=ip)
    iv = loop.induction_variable
    carried = tuple(loop.inner_iter_args)
    with InsertionPoint(loop.body):
        if not carried:
            yield iv
        elif len(carried) == 1:
            # Unwrap the single carried value and result for convenience.
            yield iv, carried[0], loop.results[0]
        else:
            yield iv, carried, loop.results


@_ods_cext.register_operation(_Dialect, replace=True)
class IndexSwitchOp(IndexSwitchOp):
    # Reuse the documentation of the generated class this one extends.
    __doc__ = IndexSwitchOp.__doc__

    def __init__(
        self,
        results,
        arg,
        cases,
        case_body_builder=None,
        default_body_builder=None,
        loc=None,
        ip=None,
    ):
        """Builds the op, gives every region one empty block, and runs builders.

        Args:
          results: forwarded to the generated builder.
          arg: forwarded to the generated builder.
          cases: case values; converted with `DenseI64ArrayAttr.get`. One case
            region is requested per entry.
          case_body_builder: optional callable invoked once per case as
            `case_body_builder(op, i, op.cases[i])` with the insertion point
            set to the i-th case block.
          default_body_builder: optional callable invoked as
            `default_body_builder(op)` with the insertion point set to the
            default block. It runs before the case builders.
          loc: forwarded unchanged to the generated builder.
          ip: forwarded unchanged to the generated builder.
        """
        cases_attr = DenseI64ArrayAttr.get(cases)
        super().__init__(
            results,
            arg,
            cases_attr,
            num_caseRegions=len(cases_attr),
            loc=loc,
            ip=ip,
        )

        # Every region (default and cases alike) starts with one empty block.
        for region in self.regions:
            region.blocks.append()

        if default_body_builder is not None:
            with InsertionPoint(self.default_block):
                default_body_builder(self)

        if case_body_builder is None:
            return
        for i, _ in enumerate(cases_attr):
            with InsertionPoint(self.case_block(i)):
                case_body_builder(self, i, self.cases[i])

    @property
    def default_region(self) -> Region:
        """The first region of the op."""
        return self.regions[0]

    @property
    def default_block(self) -> Block:
        """The first block of the default region."""
        return self.default_region.blocks[0]

    @property
    def case_regions(self) -> Sequence[Region]:
        """All regions after the default region, one per case."""
        return self.regions[1:]

    def case_region(self, i: int) -> Region:
        """The region of the i-th case."""
        return self.case_regions[i]

    @property
    def case_blocks(self) -> Sequence[Block]:
        """The first block of every case region, as a list."""
        return [case.blocks[0] for case in self.case_regions]

    def case_block(self, i: int) -> Block:
        """The first block of the i-th case region."""
        return self.case_region(i).blocks[0]


def index_switch(
    results,
    arg,
    cases,
    case_body_builder=None,
    default_body_builder=None,
    loc=None,
    ip=None,
) -> Union[OpResult, OpResultList, IndexSwitchOp]:
    """Builds an `IndexSwitchOp` and returns its result(s) or the op itself.

    All arguments are passed by keyword to the `IndexSwitchOp` constructor.
    The created op is then handed to `_get_op_result_or_op_results`, whose
    return value is returned as is.
    """
    switch_op = IndexSwitchOp(
        results=results,
        arg=arg,
        cases=cases,
        case_body_builder=case_body_builder,
        default_body_builder=default_body_builder,
        loc=loc,
        ip=ip,
    )
    return _get_op_result_or_op_results(switch_op)
