
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)

import builtins
from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional


@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  """Python handle for the dialect whose namespace is ``cf``.

  The ``register_dialect`` decorator hands this class to the bindings so the
  operation classes below can be registered against it.
  """

  DIALECT_NAMESPACE: str = "cf"

@_ods_cext.register_operation(_Dialect)
class AssertOp(_ods_ir.OpView):
  r"""
  Assert operation at runtime with single boolean operand and an error
  message attribute.
  If the argument is `true` this operation has no effect. Otherwise, the
  program execution will abort. The provided error message may be used by a
  runtime to propagate the error to the user.
  
  Example:
  
  ```mlir
  cf.assert %b, "Expected ... to be true"
  ```
  """

  OPERATION_NAME = "cf.assert"

  # Region spec handed to OpView.__init__ (presumably: region count, flag).
  _ODS_REGIONS = (0, True)

  def __init__(self, arg, msg, *, loc=None, ip=None):
    """Build a ``cf.assert`` operation.

    Args:
      arg: the single operand being asserted.
      msg: the message attribute; an ``Attribute`` is used as-is, anything
        else goes through the registered ``StrAttr`` builder when one exists.
      loc, ip: forwarded unchanged to ``OpView.__init__``.
    """
    context = _ods_get_default_loc_context(loc)
    # Only convert when the caller passed a raw value AND a builder exists;
    # otherwise the value is stored untouched.
    if not isinstance(msg, _ods_ir.Attribute) and _ods_ir.AttrBuilder.contains('StrAttr'):
      msg = _ods_ir.AttrBuilder.get('StrAttr')(msg, context=context)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"msg": msg},
        results=[],
        operands=[arg],
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def arg(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
    """The asserted value (operand 0)."""
    all_operands = self.operation.operands
    return all_operands[0]

  @builtins.property
  def msg(self) -> _ods_ir.StringAttr:
    """The ``msg`` attribute of the operation."""
    attrs = self.operation.attributes
    return attrs["msg"]

  @msg.setter
  def msg(self, value: _ods_ir.StringAttr):
    """Overwrite the mandatory ``msg`` attribute; ``None`` is rejected."""
    if value is not None:
      self.operation.attributes["msg"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

def assert_(arg, msg, *, loc=None, ip=None) -> AssertOp:
  """Functional spelling of ``AssertOp(arg, msg, loc=loc, ip=ip)``."""
  op = AssertOp(arg, msg, loc=loc, ip=ip)
  return op

@_ods_cext.register_operation(_Dialect)
class BranchOp(_ods_ir.OpView):
  r"""
  The `cf.br` operation represents a direct branch operation to a given
  block. The operands of this operation are forwarded to the successor block,
  and the number and type of the operands must match the arguments of the
  target block.
  
  Example:
  
  ```mlir
  ^bb2:
    %2 = call @someFn()
    cf.br ^bb3(%2 : tensor<*xf32>)
  ^bb3(%3: tensor<*xf32>):
  ```
  """

  OPERATION_NAME = "cf.br"

  # Region spec handed to OpView.__init__ (presumably: region count, flag).
  _ODS_REGIONS = (0, True)

  def __init__(self, destOperands, dest, *, loc=None, ip=None):
    """Build a ``cf.br`` operation.

    Args:
      destOperands: values (or op results) forwarded to ``dest``; they are
        flattened through ``_get_op_results_or_values``.
      dest: the single successor block.
      loc, ip: forwarded unchanged to ``OpView.__init__``.
    """
    forwarded = list(_get_op_results_or_values(destOperands))
    # The context is not needed to build this op, but the lookup is kept for
    # its side effects (presumably failing early when no context is active).
    _ods_get_default_loc_context(loc)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={},
        results=[],
        operands=forwarded,
        successors=[dest],
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def destOperands(self) -> _ods_ir.OpOperandList:
    """All operands: the op has a single variadic group spanning the list."""
    all_operands = self.operation.operands
    return all_operands[0:len(all_operands)]

def br(dest_operands, dest, *, loc=None, ip=None) -> BranchOp:
  """Functional spelling of ``BranchOp(dest_operands, dest, loc=loc, ip=ip)``."""
  op = BranchOp(dest_operands, dest, loc=loc, ip=ip)
  return op

@_ods_cext.register_operation(_Dialect)
class CondBranchOp(_ods_ir.OpView):
  r"""
  The `cf.cond_br` terminator operation represents a conditional branch on a
  boolean (1-bit integer) value. If the bit is set, then the first destination
  is jumped to; if it is false, the second destination is chosen. The count
  and types of operands must align with the arguments in the corresponding
  target blocks.
  
  The MLIR conditional branch operation is not allowed to target the entry
  block for a region. The two destinations of the conditional branch operation
  are allowed to be the same.
  
  The following example illustrates a function with a conditional branch
  operation that targets the same block.
  
  Example:
  
  ```mlir
  func.func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
    // Both targets are the same, operands differ
    cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
  
  ^bb1(%x : i32) :
    return %x : i32
  }
  ```
  """

  OPERATION_NAME = "cf.cond_br"

  # Operand groups: one condition, then two variadic groups (true, false).
  _ODS_OPERAND_SEGMENTS = [1,-1,-1,]

  # Region spec handed to OpView.__init__ (presumably: region count, flag).
  _ODS_REGIONS = (0, True)

  def __init__(self, condition, trueDestOperands, falseDestOperands, trueDest, falseDest, *, branch_weights=None, loc=None, ip=None):
    """Build a ``cf.cond_br`` operation.

    Args:
      condition: the branch condition operand.
      trueDestOperands, falseDestOperands: values forwarded to ``trueDest`` and
        ``falseDest`` respectively; each becomes one operand segment.
      trueDest, falseDest: the two successor blocks, in that order.
      branch_weights: optional; an ``Attribute`` is stored as-is, other values
        go through the registered ``DenseI32ArrayAttr`` builder if present.
      loc, ip: forwarded unchanged to ``OpView.__init__``.
    """
    # One entry per segment declared in _ODS_OPERAND_SEGMENTS, same order.
    segment_values = [
        condition,
        _get_op_results_or_values(trueDestOperands),
        _get_op_results_or_values(falseDestOperands),
    ]
    context = _ods_get_default_loc_context(loc)
    attrs = {}
    if branch_weights is not None:
      if not isinstance(branch_weights, _ods_ir.Attribute) and _ods_ir.AttrBuilder.contains('DenseI32ArrayAttr'):
        branch_weights = _ods_ir.AttrBuilder.get('DenseI32ArrayAttr')(branch_weights, context=context)
      attrs["branch_weights"] = branch_weights
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=[],
        operands=segment_values,
        successors=[trueDest, falseDest],
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def condition(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
    """The branch condition (sole member of operand segment 0)."""
    all_operands = self.operation.operands
    sizes = self.operation.attributes["operandSegmentSizes"]
    return _ods_segmented_accessor(all_operands, sizes, 0)[0]

  @builtins.property
  def trueDestOperands(self) -> _ods_ir.OpOperandList:
    """Operands forwarded to the true successor (operand segment 1)."""
    all_operands = self.operation.operands
    sizes = self.operation.attributes["operandSegmentSizes"]
    return _ods_segmented_accessor(all_operands, sizes, 1)

  @builtins.property
  def falseDestOperands(self) -> _ods_ir.OpOperandList:
    """Operands forwarded to the false successor (operand segment 2)."""
    all_operands = self.operation.operands
    sizes = self.operation.attributes["operandSegmentSizes"]
    return _ods_segmented_accessor(all_operands, sizes, 2)

  @builtins.property
  def branch_weights(self) -> _Optional[_ods_ir.DenseI32ArrayAttr]:
    """The ``branch_weights`` attribute, or ``None`` when it is not set."""
    attrs = self.operation.attributes
    return attrs["branch_weights"] if "branch_weights" in attrs else None

  @branch_weights.setter
  def branch_weights(self, value: _Optional[_ods_ir.DenseI32ArrayAttr]):
    """Set ``branch_weights``; assigning ``None`` removes it if present."""
    attrs = self.operation.attributes
    if value is None:
      if "branch_weights" in attrs:
        del attrs["branch_weights"]
    else:
      attrs["branch_weights"] = value

  @branch_weights.deleter
  def branch_weights(self):
    """Remove ``branch_weights`` from the attribute map."""
    attrs = self.operation.attributes
    del attrs["branch_weights"]

def cond_br(condition, true_dest_operands, false_dest_operands, true_dest, false_dest, *, branch_weights=None, loc=None, ip=None) -> CondBranchOp:
  """Functional spelling of the ``CondBranchOp`` constructor (same arguments)."""
  op = CondBranchOp(
      condition,
      true_dest_operands,
      false_dest_operands,
      true_dest,
      false_dest,
      branch_weights=branch_weights,
      loc=loc,
      ip=ip,
  )
  return op

@_ods_cext.register_operation(_Dialect)
class SwitchOp(_ods_ir.OpView):
  r"""
  The `cf.switch` terminator operation represents a switch on a signless integer
  value. If the flag matches one of the specified cases, then the
  corresponding destination is jumped to. If the flag does not match any of
  the cases, the default destination is jumped to. The count and types of
  operands must align with the arguments in the corresponding target blocks.
  
  Example:
  
  ```mlir
  cf.switch %flag : i32, [
    default: ^bb1(%a : i32),
    42: ^bb1(%b : i32),
    43: ^bb3(%c : i32)
  ]
  ```
  """

  OPERATION_NAME = "cf.switch"

  # Operand groups: one flag, then two variadic groups (default, cases).
  _ODS_OPERAND_SEGMENTS = [1,-1,-1,]

  # Region spec handed to OpView.__init__ (presumably: region count, flag).
  _ODS_REGIONS = (0, True)

  def __init__(self, flag, defaultOperands, caseOperands, case_operand_segments, defaultDestination, caseDestinations, *, case_values=None, loc=None, ip=None):
    """Build a ``cf.switch`` operation.

    Args:
      flag: the value being switched on.
      defaultOperands: values forwarded to ``defaultDestination``.
      caseOperands: values forwarded to the case successors, all in one
        operand segment.
      case_operand_segments: mandatory attribute; an ``Attribute`` is stored
        as-is, other values go through the ``DenseI32ArrayAttr`` builder if
        one is registered.
      defaultDestination: first successor.
      caseDestinations: iterable of the remaining successors, appended in order.
      case_values: optional; converted via the ``AnyIntElementsAttr`` builder
        under the same rule as ``case_operand_segments``.
      loc, ip: forwarded unchanged to ``OpView.__init__``.
    """
    # One entry per segment declared in _ODS_OPERAND_SEGMENTS, same order.
    segment_values = [
        flag,
        _get_op_results_or_values(defaultOperands),
        _get_op_results_or_values(caseOperands),
    ]
    context = _ods_get_default_loc_context(loc)
    attrs = {}
    if case_values is not None:
      if not isinstance(case_values, _ods_ir.Attribute) and _ods_ir.AttrBuilder.contains('AnyIntElementsAttr'):
        case_values = _ods_ir.AttrBuilder.get('AnyIntElementsAttr')(case_values, context=context)
      attrs["case_values"] = case_values
    if not isinstance(case_operand_segments, _ods_ir.Attribute) and _ods_ir.AttrBuilder.contains('DenseI32ArrayAttr'):
      case_operand_segments = _ods_ir.AttrBuilder.get('DenseI32ArrayAttr')(case_operand_segments, context=context)
    attrs["case_operand_segments"] = case_operand_segments
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=[],
        operands=segment_values,
        successors=[defaultDestination, *caseDestinations],
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def flag(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
    """The switched-on value (sole member of operand segment 0)."""
    all_operands = self.operation.operands
    sizes = self.operation.attributes["operandSegmentSizes"]
    return _ods_segmented_accessor(all_operands, sizes, 0)[0]

  @builtins.property
  def defaultOperands(self) -> _ods_ir.OpOperandList:
    """Operands forwarded to the default successor (operand segment 1)."""
    all_operands = self.operation.operands
    sizes = self.operation.attributes["operandSegmentSizes"]
    return _ods_segmented_accessor(all_operands, sizes, 1)

  @builtins.property
  def caseOperands(self) -> _ods_ir.OpOperandList:
    """Operands forwarded to the case successors (operand segment 2)."""
    all_operands = self.operation.operands
    sizes = self.operation.attributes["operandSegmentSizes"]
    return _ods_segmented_accessor(all_operands, sizes, 2)

  @builtins.property
  def case_values(self) -> _Optional[_ods_ir.DenseIntElementsAttr]:
    """The ``case_values`` attribute, or ``None`` when it is not set."""
    attrs = self.operation.attributes
    return attrs["case_values"] if "case_values" in attrs else None

  @case_values.setter
  def case_values(self, value: _Optional[_ods_ir.DenseIntElementsAttr]):
    """Set ``case_values``; assigning ``None`` removes it if present."""
    attrs = self.operation.attributes
    if value is None:
      if "case_values" in attrs:
        del attrs["case_values"]
    else:
      attrs["case_values"] = value

  @case_values.deleter
  def case_values(self):
    """Remove ``case_values`` from the attribute map."""
    attrs = self.operation.attributes
    del attrs["case_values"]

  @builtins.property
  def case_operand_segments(self) -> _ods_ir.DenseI32ArrayAttr:
    """The mandatory ``case_operand_segments`` attribute."""
    attrs = self.operation.attributes
    return attrs["case_operand_segments"]

  @case_operand_segments.setter
  def case_operand_segments(self, value: _ods_ir.DenseI32ArrayAttr):
    """Overwrite ``case_operand_segments``; ``None`` is rejected."""
    if value is not None:
      self.operation.attributes["case_operand_segments"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

def switch(flag, default_operands, case_operands, case_operand_segments, default_destination, case_destinations, *, case_values=None, loc=None, ip=None) -> SwitchOp:
  """Functional spelling of the ``SwitchOp`` constructor (same arguments)."""
  op = SwitchOp(
      flag,
      default_operands,
      case_operands,
      case_operand_segments,
      default_destination,
      case_destinations,
      case_values=case_values,
      loc=loc,
      ip=ip,
  )
  return op
