
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
# Shorthand for the IR module of the compiled MLIR extension; every Dialect,
# OpView, Attribute, Value and Type class below is reached through it.
_ods_ir = _ods_cext.ir
# By its name, registers this generated file to be excluded from tracebacks
# collected by the extension.
# NOTE(review): presumably so these builders don't show up in Python-derived
# op locations — confirm against the MLIR Python bindings.
_ods_cext.globals.register_traceback_file_exclusion(__file__)

import builtins
from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional


@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  """Python view of the ``mpmd`` dialect.

  Registered with the extension via ``register_dialect``. Every op class in
  this module attaches itself to it with ``register_operation(_Dialect)``.
  """
  DIALECT_NAMESPACE = "mpmd"

@_ods_cext.register_operation(_Dialect)
class AssignOp(_ods_ir.OpView):
  r"""
  Assigns a local tensor to a mesh as fully replicated within that mesh.

  This is a temporary op that is introduced when lowering jax ops, to move
  from local types to mesh types. These ops will be eliminated during import,
  when the inputs and results of the func op become mesh tensors.

  The mesh name of the result type should correspond to a mesh in the
  topology, and its global type should be identical to the operand type.

  The optional ``origin`` attribute records the origin of the mesh
  assignment, e.g. named_computation, mesh inference, etc.
  """

  OPERATION_NAME = "mpmd.assign"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, result, tensor, *, origin=None, loc=None, ip=None):
    """Build an ``mpmd.assign`` op.

    Args:
      result: Type of the single result (a mesh tensor type, per the op
        description).
      tensor: The operand value to assign to a mesh.
      origin: Optional origin. An ``Attribute`` is stored as-is; any other
        value is converted by the registered ``'StrAttr'`` builder when one
        exists (otherwise it is also stored as-is).
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    if origin is not None: attributes["origin"] = (origin if (
        isinstance(origin, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('StrAttr')) else
          _ods_ir.AttrBuilder.get('StrAttr')(origin, context=_ods_context))
    results = []
    results.append(result)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The ``tensor`` operand (operand 0)."""
    return self.operation.operands[0]

  @builtins.property
  def origin(self) -> _Optional[_ods_ir.StringAttr]:
    """The ``origin`` attribute, or ``None`` when it is not set."""
    if "origin" not in self.operation.attributes:
      return None
    return self.operation.attributes["origin"]

  @origin.setter
  def origin(self, value: _Optional[_ods_ir.StringAttr]):
    # Assigning None removes the attribute if it is present.
    if value is not None:
      self.operation.attributes["origin"] = value
    elif "origin" in self.operation.attributes:
      del self.operation.attributes["origin"]

  @origin.deleter
  def origin(self):
    # Unlike the setter, this does not check that the attribute is present.
    del self.operation.attributes["origin"]

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result 0)."""
    return self.operation.results[0]

def assign(result, tensor, *, origin=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Create an ``mpmd.assign`` op and hand back its single result value."""
  assign_op = AssignOp(result=result, tensor=tensor, origin=origin, loc=loc, ip=ip)
  return assign_op.result

@_ods_cext.register_operation(_Dialect)
class BroadcastOp(_ods_ir.OpView):
  r"""
  Allows for a tensor to be transferred (or replicated) in any mesh where it's
  used. Whenever transferred, the origin of the transfer is the current
  location of the operand.
  """

  OPERATION_NAME = "mpmd.broadcast"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, *, results=None, loc=None, ip=None):
    """Build an ``mpmd.broadcast`` op.

    Args:
      tensor: The operand value.
      results: Optional list of result types. Defaults to a single result
        with the operand's type.
      loc: Optional location, forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    # The result type defaults to the operand's type.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The ``tensor`` operand (operand 0)."""
    return self.operation.operands[0]

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result 0)."""
    return self.operation.results[0]

def broadcast(tensor, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Create an ``mpmd.broadcast`` op and hand back its single result value."""
  broadcast_op = BroadcastOp(tensor=tensor, results=results, loc=loc, ip=ip)
  return broadcast_op.result

@_ods_cext.register_operation(_Dialect)
class CallOp(_ods_ir.OpView):
  r"""
  A function call operation. Useful to wrap the body of loops in function
  declarations to reduce code size, for example.
  """

  OPERATION_NAME = "mpmd.call"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, result, tensors, callee, *, loc=None, ip=None):
    """Build an ``mpmd.call`` op.

    Args:
      result: Iterable of result types (any number, including zero).
      tensors: The call arguments (values, or anything accepted by
        ``_get_op_results_or_values``).
      callee: ``FlatSymbolRefAttr`` naming the called function, or a raw
        value converted by the registered ``'FlatSymbolRefAttr'`` builder
        when one exists.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(tensors))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["callee"] = (callee if (
    isinstance(callee, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
      _ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
    results = []
    results.extend(result)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensors(self) -> _ods_ir.OpOperandList:
    """All operands; ``tensors`` is the op's only (variadic) operand group."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def callee(self) -> _ods_ir.FlatSymbolRefAttr:
    """The mandatory ``callee`` symbol reference."""
    return self.operation.attributes["callee"]

  @callee.setter
  def callee(self, value: _ods_ir.FlatSymbolRefAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["callee"] = value

def call(result, tensors, callee, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallOp]:
  """Create an ``mpmd.call`` op.

  Returns the whole result list when the op has several results, the lone
  result when it has exactly one, and the op itself when it has none.
  """
  call_op = CallOp(result=result, tensors=tensors, callee=callee, loc=loc, ip=ip)
  op_results = call_op.results
  result_count = len(op_results)
  if result_count == 0:
    return call_op
  if result_count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class ForOp(_ods_ir.OpView):
  r"""
  Returns the result of executing a body function for a fixed number of
  iterations, with the iteration index available in the body.

  An optional unroll factor, that must divide the number of iterations,
  can be specified to unroll the body of the op by that factor, i.e. for
  unroll factor N, the body is replicated to create N copies and the number of
  iterations is divided by N. Each copy except the first uses
  the results of the previous copy instead of the block arguments, and the
  iteration index is multiplied by the unroll factor and incremented after
  every copy.

  A for op can accept and return any types, but the TypeID of these
  must be the same -- e.g. all tensor types or all MPMD mesh types etc. This
  allows us to use the op at various levels, sharing implementation and
  transformations.
  """

  OPERATION_NAME = "mpmd.for"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact), i.e. one body region.
  _ODS_REGIONS = (1, True)

  def __init__(self, results_, tensors, iterations, *, unroll_factor=None, loc=None, ip=None):
    """Build an ``mpmd.for`` op.

    The body region's block is not created here (``regions`` is left
    ``None``); it is reached afterwards through the ``region`` property.

    Args:
      results_: Iterable of result types.
      tensors: The loop-carried input values (or anything accepted by
        ``_get_op_results_or_values``).
      iterations: Iteration count. An ``Attribute`` is stored as-is; any
        other value is converted by the registered ``'UI32Attr'`` builder
        when one exists.
      unroll_factor: Optional unroll factor, converted the same way.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(tensors))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["iterations"] = (iterations if (
    isinstance(iterations, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('UI32Attr')) else
      _ods_ir.AttrBuilder.get('UI32Attr')(iterations, context=_ods_context))
    if unroll_factor is not None: attributes["unroll_factor"] = (unroll_factor if (
        isinstance(unroll_factor, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('UI32Attr')) else
          _ods_ir.AttrBuilder.get('UI32Attr')(unroll_factor, context=_ods_context))
    results = []
    results.extend(results_)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensors(self) -> _ods_ir.OpOperandList:
    """All operands; ``tensors`` is the op's only (variadic) operand group."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def iterations(self) -> _ods_ir.IntegerAttr:
    """The mandatory ``iterations`` attribute."""
    return self.operation.attributes["iterations"]

  @iterations.setter
  def iterations(self, value: _ods_ir.IntegerAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["iterations"] = value

  @builtins.property
  def unroll_factor(self) -> _Optional[_ods_ir.IntegerAttr]:
    """The ``unroll_factor`` attribute, or ``None`` when it is not set."""
    if "unroll_factor" not in self.operation.attributes:
      return None
    return self.operation.attributes["unroll_factor"]

  @unroll_factor.setter
  def unroll_factor(self, value: _Optional[_ods_ir.IntegerAttr]):
    # Assigning None removes the attribute if it is present.
    if value is not None:
      self.operation.attributes["unroll_factor"] = value
    elif "unroll_factor" in self.operation.attributes:
      del self.operation.attributes["unroll_factor"]

  @unroll_factor.deleter
  def unroll_factor(self):
    # Unlike the setter, this does not check that the attribute is present.
    del self.operation.attributes["unroll_factor"]

  @builtins.property
  def results_(self) -> _ods_ir.OpResultList:
    """All results (the op's only, variadic, result group)."""
    # Generated arithmetic: "- 1 + 1" leaves the full result count.
    _ods_variadic_group_length = len(self.operation.results) - 1 + 1
    return self.operation.results[0:0 + _ods_variadic_group_length]

  @builtins.property
  def region(self) -> _ods_ir.Region:
    """The body region (region 0)."""
    return self.regions[0]

def for_(results_, tensors, iterations, *, unroll_factor=None, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, ForOp]:
  """Create an ``mpmd.for`` op.

  Returns the whole result list when the op has several results, the lone
  result when it has exactly one, and the op itself when it has none.
  """
  loop_op = ForOp(
      results_=results_,
      tensors=tensors,
      iterations=iterations,
      unroll_factor=unroll_factor,
      loc=loc,
      ip=ip,
  )
  op_results = loop_op.results
  result_count = len(op_results)
  if result_count == 0:
    return loop_op
  if result_count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class FragmentCallOp(_ods_ir.OpView):
  r"""
  Represents a call to a function that holds an MPMD fragment body, i.e. a
  computation assigned to a specific mesh in an MPMD topology, that is
  intended to be executed as an individual SPMD program fragment.

  The mesh name of the fragment should correspond to a mesh in the topology of
  the enclosing function, and that mesh shape should match that of the callee.

  The origin lists the user named computations that contributed to this
  fragment call, e.g. through merging.

  The function input and result types of the callee must be the local tensor
  types of the corresponding mesh tensors of this op's operands and results
  respectively.

  Example:

  ```mlir
  %2 = mpmd.fragment_call<mesh="m1", origin=[]> @my_fragment(%0, %1) :
    (mesh_tensor<...>, mesh_tensor<...>) -> mesh_tensor<...>
  ```
  """

  OPERATION_NAME = "mpmd.fragment_call"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, result, tensors, origin, mesh_name, callee, *, loc=None, ip=None):
    """Build an ``mpmd.fragment_call`` op.

    Each attribute argument is stored as-is when it is already an
    ``Attribute`` (or when no builder is registered for its constraint);
    otherwise it is converted by the registered builder.

    Args:
      result: Iterable of result types.
      tensors: The call arguments (or anything accepted by
        ``_get_op_results_or_values``).
      origin: Origin list (the getter types it as ``ArrayAttr``). Its builder
        key ``'anonymous_830'`` is a tblgen-generated constraint name.
      mesh_name: Mesh name, via the ``'StrAttr'`` builder.
      callee: Called function symbol, via the ``'FlatSymbolRefAttr'`` builder.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(tensors))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["origin"] = (origin if (
    isinstance(origin, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('anonymous_830')) else
      _ods_ir.AttrBuilder.get('anonymous_830')(origin, context=_ods_context))
    attributes["mesh_name"] = (mesh_name if (
    isinstance(mesh_name, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('StrAttr')) else
      _ods_ir.AttrBuilder.get('StrAttr')(mesh_name, context=_ods_context))
    attributes["callee"] = (callee if (
    isinstance(callee, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
      _ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
    results = []
    results.extend(result)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensors(self) -> _ods_ir.OpOperandList:
    """All operands; ``tensors`` is the op's only (variadic) operand group."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def origin(self) -> _ods_ir.ArrayAttr:
    """The mandatory ``origin`` attribute."""
    return self.operation.attributes["origin"]

  @origin.setter
  def origin(self, value: _ods_ir.ArrayAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["origin"] = value

  @builtins.property
  def mesh_name(self) -> _ods_ir.StringAttr:
    """The mandatory ``mesh_name`` attribute."""
    return self.operation.attributes["mesh_name"]

  @mesh_name.setter
  def mesh_name(self, value: _ods_ir.StringAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["mesh_name"] = value

  @builtins.property
  def callee(self) -> _ods_ir.FlatSymbolRefAttr:
    """The mandatory ``callee`` symbol reference."""
    return self.operation.attributes["callee"]

  @callee.setter
  def callee(self, value: _ods_ir.FlatSymbolRefAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["callee"] = value

def fragment_call(result, tensors, origin, mesh_name, callee, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, FragmentCallOp]:
  """Create an ``mpmd.fragment_call`` op.

  Returns the whole result list when the op has several results, the lone
  result when it has exactly one, and the op itself when it has none.
  """
  call_op = FragmentCallOp(
      result=result,
      tensors=tensors,
      origin=origin,
      mesh_name=mesh_name,
      callee=callee,
      loc=loc,
      ip=ip,
  )
  op_results = call_op.results
  result_count = len(op_results)
  if result_count == 0:
    return call_op
  if result_count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class FragmentOp(_ods_ir.OpView):
  r"""
  Assigns a computation, i.e. a block of operations, to a specific mesh in an
  MPMD topology, that is intended to be executed as an individual SPMD program
  fragment.

  The fragment takes and returns only mesh tensors that are assigned to the
  same mesh as the fragment.

  The mesh name of the fragment should correspond to a mesh in the topology.

  The fragment includes a list of origins, i.e., metadata with information about
  the original named_computations that formed this fragment, and a stage_id
  defined _iff_ it is a user defined fragment, i.e., it has a non-empty list
  of origins. The optional in_shardings specifies the sharding of the
  block arguments of a fragment, which correspond to the operands.
  The optional out_shardings specifies the shardings of the results.

  The fragment's region shouldn't have any free variables, and the type of
  each block argument and returned value in the region is the global tensor
  type of the corresponding mesh tensor.
  """

  OPERATION_NAME = "mpmd.fragment"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact), i.e. one body region.
  _ODS_REGIONS = (1, True)

  def __init__(self, results_, inputs, origin, mesh_name, *, stage_id=None, in_shardings=None, out_shardings=None, loc=None, ip=None):
    """Build an ``mpmd.fragment`` op.

    The body region's block is not created here (``regions`` is left
    ``None``); it is reached afterwards through the ``region`` property.
    Each attribute argument is stored as-is when it is already an
    ``Attribute`` (or when no builder is registered for its constraint);
    otherwise it is converted by the registered builder.

    Args:
      results_: Iterable of result types.
      inputs: The fragment operands (or anything accepted by
        ``_get_op_results_or_values``).
      origin: Origin list (the getter types it as ``ArrayAttr``). Its builder
        key ``'anonymous_830'`` is a tblgen-generated constraint name.
      mesh_name: Mesh name, via the ``'StrAttr'`` builder.
      stage_id: Optional stage id, via the ``'I64Attr'`` builder.
      in_shardings: Optional operand shardings, via the
        ``'Sdy_TensorShardingPerValue'`` builder.
      out_shardings: Optional result shardings, via the same builder.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(inputs))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["origin"] = (origin if (
    isinstance(origin, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('anonymous_830')) else
      _ods_ir.AttrBuilder.get('anonymous_830')(origin, context=_ods_context))
    attributes["mesh_name"] = (mesh_name if (
    isinstance(mesh_name, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('StrAttr')) else
      _ods_ir.AttrBuilder.get('StrAttr')(mesh_name, context=_ods_context))
    if stage_id is not None: attributes["stage_id"] = (stage_id if (
        isinstance(stage_id, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('I64Attr')) else
          _ods_ir.AttrBuilder.get('I64Attr')(stage_id, context=_ods_context))
    if in_shardings is not None: attributes["in_shardings"] = (in_shardings if (
        isinstance(in_shardings, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
          _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(in_shardings, context=_ods_context))
    if out_shardings is not None: attributes["out_shardings"] = (out_shardings if (
        isinstance(out_shardings, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
          _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(out_shardings, context=_ods_context))
    results = []
    results.extend(results_)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def inputs(self) -> _ods_ir.OpOperandList:
    """All operands; ``inputs`` is the op's only (variadic) operand group."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def origin(self) -> _ods_ir.ArrayAttr:
    """The mandatory ``origin`` attribute."""
    return self.operation.attributes["origin"]

  @origin.setter
  def origin(self, value: _ods_ir.ArrayAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["origin"] = value

  @builtins.property
  def mesh_name(self) -> _ods_ir.StringAttr:
    """The mandatory ``mesh_name`` attribute."""
    return self.operation.attributes["mesh_name"]

  @mesh_name.setter
  def mesh_name(self, value: _ods_ir.StringAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["mesh_name"] = value

  @builtins.property
  def stage_id(self) -> _Optional[_ods_ir.IntegerAttr]:
    """The ``stage_id`` attribute, or ``None`` when it is not set."""
    if "stage_id" not in self.operation.attributes:
      return None
    return self.operation.attributes["stage_id"]

  @stage_id.setter
  def stage_id(self, value: _Optional[_ods_ir.IntegerAttr]):
    # Assigning None removes the attribute if it is present.
    if value is not None:
      self.operation.attributes["stage_id"] = value
    elif "stage_id" in self.operation.attributes:
      del self.operation.attributes["stage_id"]

  @stage_id.deleter
  def stage_id(self):
    # Unlike the setter, this does not check that the attribute is present.
    del self.operation.attributes["stage_id"]

  @builtins.property
  def in_shardings(self) -> _Optional[_ods_ir.Attribute]:
    """The ``in_shardings`` attribute, or ``None`` when it is not set."""
    if "in_shardings" not in self.operation.attributes:
      return None
    return self.operation.attributes["in_shardings"]

  @in_shardings.setter
  def in_shardings(self, value: _Optional[_ods_ir.Attribute]):
    # Assigning None removes the attribute if it is present.
    if value is not None:
      self.operation.attributes["in_shardings"] = value
    elif "in_shardings" in self.operation.attributes:
      del self.operation.attributes["in_shardings"]

  @in_shardings.deleter
  def in_shardings(self):
    # Unlike the setter, this does not check that the attribute is present.
    del self.operation.attributes["in_shardings"]

  @builtins.property
  def out_shardings(self) -> _Optional[_ods_ir.Attribute]:
    """The ``out_shardings`` attribute, or ``None`` when it is not set."""
    if "out_shardings" not in self.operation.attributes:
      return None
    return self.operation.attributes["out_shardings"]

  @out_shardings.setter
  def out_shardings(self, value: _Optional[_ods_ir.Attribute]):
    # Assigning None removes the attribute if it is present.
    if value is not None:
      self.operation.attributes["out_shardings"] = value
    elif "out_shardings" in self.operation.attributes:
      del self.operation.attributes["out_shardings"]

  @out_shardings.deleter
  def out_shardings(self):
    # Unlike the setter, this does not check that the attribute is present.
    del self.operation.attributes["out_shardings"]

  @builtins.property
  def results_(self) -> _ods_ir.OpResultList:
    """All results (the op's only, variadic, result group)."""
    # Generated arithmetic: "- 1 + 1" leaves the full result count.
    _ods_variadic_group_length = len(self.operation.results) - 1 + 1
    return self.operation.results[0:0 + _ods_variadic_group_length]

  @builtins.property
  def region(self) -> _ods_ir.Region:
    """The body region (region 0)."""
    return self.regions[0]

def fragment(results_, inputs, origin, mesh_name, *, stage_id=None, in_shardings=None, out_shardings=None, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, FragmentOp]:
  """Create an ``mpmd.fragment`` op.

  Returns the whole result list when the op has several results, the lone
  result when it has exactly one, and the op itself when it has none.
  """
  fragment_op = FragmentOp(
      results_=results_,
      inputs=inputs,
      origin=origin,
      mesh_name=mesh_name,
      stage_id=stage_id,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      loc=loc,
      ip=ip,
  )
  op_results = fragment_op.results
  result_count = len(op_results)
  if result_count == 0:
    return fragment_op
  if result_count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class NamedComputationOp(_ods_ir.OpView):
  r"""
  Groups a computation, i.e. a block of operations, and gives it a name and
  a transpose count via the UserOrigin attribute. This NamedComputation can be
  used to assign a mesh to the computation in MPMD or for optimizations.

  The transpose count (default=0) denotes whether the named computation has
  been produced by a certain number of JAX AD transpose transformations.

  The op's region shouldn't have any free variables, and the types of the
  block arguments and of the values returned in the region must be the same
  as the types of the op's inputs and results, respectively.
  """

  OPERATION_NAME = "mpmd.named_computation"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact), i.e. one body region.
  _ODS_REGIONS = (1, True)

  def __init__(self, results_, tensors, origin, *, loc=None, ip=None):
    """Build an ``mpmd.named_computation`` op.

    The body region's block is not created here (``regions`` is left
    ``None``); it is reached afterwards through the ``region`` property.

    Args:
      results_: Iterable of result types.
      tensors: The input values (or anything accepted by
        ``_get_op_results_or_values``).
      origin: The user origin. An ``Attribute`` is stored as-is; any other
        value is converted by the registered ``'Mpmd_UserOrigin'`` builder
        when one exists.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(tensors))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["origin"] = (origin if (
    isinstance(origin, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Mpmd_UserOrigin')) else
      _ods_ir.AttrBuilder.get('Mpmd_UserOrigin')(origin, context=_ods_context))
    results = []
    results.extend(results_)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensors(self) -> _ods_ir.OpOperandList:
    """All operands; ``tensors`` is the op's only (variadic) operand group."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def origin(self) -> _ods_ir.Attribute:
    """The mandatory ``origin`` (UserOrigin) attribute."""
    return self.operation.attributes["origin"]

  @origin.setter
  def origin(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["origin"] = value

  @builtins.property
  def results_(self) -> _ods_ir.OpResultList:
    """All results (the op's only, variadic, result group)."""
    # Generated arithmetic: "- 1 + 1" leaves the full result count.
    _ods_variadic_group_length = len(self.operation.results) - 1 + 1
    return self.operation.results[0:0 + _ods_variadic_group_length]

  @builtins.property
  def region(self) -> _ods_ir.Region:
    """The body region (region 0)."""
    return self.regions[0]

def named_computation(results_, tensors, origin, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, NamedComputationOp]:
  """Create an ``mpmd.named_computation`` op.

  Returns the whole result list when the op has several results, the lone
  result when it has exactly one, and the op itself when it has none.
  """
  computation_op = NamedComputationOp(
      results_=results_, tensors=tensors, origin=origin, loc=loc, ip=ip)
  op_results = computation_op.results
  result_count = len(op_results)
  if result_count == 0:
    return computation_op
  if result_count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class NamedTensorOp(_ods_ir.OpView):
  r"""
  An identity op that associates its result tensor with a given name.
  This NamedTensor can be used to assign a mesh to the tensor in MPMD.

  NOTE: this is different from TagOp in that TagOp is used for naming a tensor
  and can be used to partition that tensor. NamedTensorOp is for MPMD programs
  for tensors that may be explicitly assigned to meshes.
  """

  OPERATION_NAME = "mpmd.named_tensor"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, name, *, results=None, loc=None, ip=None):
    """Build an ``mpmd.named_tensor`` op.

    Args:
      tensor: The operand value to name.
      name: The name. An ``Attribute`` is stored as-is; any other value is
        converted by the registered ``'StrAttr'`` builder when one exists.
      results: Optional list of result types. Defaults to a single result
        with the operand's type.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["name"] = (name if (
    isinstance(name, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('StrAttr')) else
      _ods_ir.AttrBuilder.get('StrAttr')(name, context=_ods_context))
    # The result type defaults to the operand's type.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The ``tensor`` operand (operand 0)."""
    return self.operation.operands[0]

  @builtins.property
  def name(self) -> _ods_ir.StringAttr:
    """The mandatory ``name`` attribute."""
    return self.operation.attributes["name"]

  @name.setter
  def name(self, value: _ods_ir.StringAttr):
    # Mandatory attribute: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["name"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result 0)."""
    return self.operation.results[0]

def named_tensor(tensor, name, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Create an ``mpmd.named_tensor`` op and hand back its single result value."""
  named_op = NamedTensorOp(tensor=tensor, name=name, results=results, loc=loc, ip=ip)
  return named_op.result

@_ods_cext.register_operation(_Dialect)
class ReduceOp(_ods_ir.OpView):
  r"""
  Allows for a tensor to be reduced across different meshes, and then
  broadcast to wherever it needs to be used.
  """

  OPERATION_NAME = "mpmd.reduce"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, tensors, *, reduction=None, results=None, loc=None, ip=None):
    """Build an ``mpmd.reduce`` op.

    Args:
      tensors: The values to reduce (or anything accepted by
        ``_get_op_results_or_values``).
      reduction: Optional reduction kind. If omitted, no ``reduction`` entry
        is added here. An ``Attribute`` is stored as-is; any other value is
        converted by the registered ``'Mpmd_Reduction'`` builder when one
        exists.
      results: Optional list of result types. Defaults to a single result
        with the type of the first operand.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(tensors))
    _ods_context = _ods_get_default_loc_context(loc)
    if reduction is not None: attributes["reduction"] = (reduction if (
        isinstance(reduction, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('Mpmd_Reduction')) else
          _ods_ir.AttrBuilder.get('Mpmd_Reduction')(reduction, context=_ods_context))
    # The result type defaults to the first operand's type. NOTE(review): if
    # `tensors` is empty and `results` is omitted, operands[0] raises
    # IndexError here.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensors(self) -> _ods_ir.OpOperandList:
    """All operands; ``tensors`` is the op's only (variadic) operand group."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def reduction(self) -> _ods_ir.Attribute:
    """The ``reduction`` attribute.

    NOTE(review): ``__init__`` allows omitting it, yet this getter indexes
    the attribute unconditionally. Presumably it is default-valued in ODS
    and filled in when the op is created; otherwise this lookup fails.
    Confirm against the dialect definition.
    """
    return self.operation.attributes["reduction"]

  @reduction.setter
  def reduction(self, value: _ods_ir.Attribute):
    # Treated as mandatory: it cannot be removed by assigning None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["reduction"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result 0)."""
    return self.operation.results[0]

def reduce(tensors, *, reduction=None, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Create an ``mpmd.reduce`` op and hand back its single result value."""
  reduce_op = ReduceOp(tensors=tensors, reduction=reduction, results=results, loc=loc, ip=ip)
  return reduce_op.result

@_ods_cext.register_operation(_Dialect)
class ReturnOp(_ods_ir.OpView):
  """Python view of the ``mpmd.return`` op.

  It takes a variadic list of operand values (``results_``) and produces no
  results of its own.
  NOTE(review): presumably the terminator of the regions of ops such as
  ``mpmd.fragment`` / ``mpmd.for`` — confirm against the dialect definition.
  """
  OPERATION_NAME = "mpmd.return"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, results_, *, loc=None, ip=None):
    """Build an ``mpmd.return`` op.

    Args:
      results_: The values to return (or anything accepted by
        ``_get_op_results_or_values``); they become the op's operands.
      loc: Optional location, forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(results_))
    _ods_context = _ods_get_default_loc_context(loc)
    # The op itself has no results.
    results = []
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def results_(self) -> _ods_ir.OpOperandList:
    """All operands, i.e. the returned values (the only operand group)."""
    # Generated arithmetic: "- 1 + 1" leaves the full operand count.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

def return_(results_, *, loc=None, ip=None) -> ReturnOp:
  """Create an ``mpmd.return`` op; it has no results, so the op itself is returned."""
  return_op = ReturnOp(results_=results_, loc=loc, ip=ip)
  return return_op

@_ods_cext.register_operation(_Dialect)
class TransferOp(_ods_ir.OpView):
  r"""
  Transfers a distributed tensor from one mesh to another.

  The mesh names of the operand and result types should correspond to meshes
  in the topology, and their global types should be identical.
  """

  OPERATION_NAME = "mpmd.transfer"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, result, tensor, *, loc=None, ip=None):
    """Build an ``mpmd.transfer`` op.

    Args:
      result: Type of the single result (the destination mesh tensor type,
        per the op description).
      tensor: The operand value to transfer.
      loc: Optional location, forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    results = []
    results.append(result)
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The ``tensor`` operand (operand 0)."""
    return self.operation.operands[0]

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result 0)."""
    return self.operation.results[0]

def transfer(result, tensor, *, loc=None, ip=None) -> _ods_ir.OpResult:
  """Create an ``mpmd.transfer`` op and hand back its single result value."""
  transfer_op = TransferOp(result=result, tensor=tensor, loc=loc, ip=ip)
  return transfer_op.result

@_ods_cext.register_operation(_Dialect)
class UnassignOp(_ods_ir.OpView):
  r"""
  Unassigns a fully replicated tensor from a mesh.

  This is a temporary op that is introduced when lowering jax ops, to move
  from local types to mesh types. These ops will be eliminated during import,
  when the inputs and results of the func op become mesh tensors.

  The mesh name of the operand type should correspond to a mesh in the
  topology, and its global type should be identical to the result type.
  """

  OPERATION_NAME = "mpmd.unassign"

  # Region spec consumed by OpView.__init__. NOTE(review): presumably
  # (number of regions, whether that number is exact) — confirm.
  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, *, origin=None, results=None, loc=None, ip=None):
    """Build an ``mpmd.unassign`` op.

    Args:
      tensor: The mesh tensor operand to unassign.
      origin: Optional origin. An ``Attribute`` is stored as-is; any other
        value is converted by the registered ``'StrAttr'`` builder when one
        exists.
      results: Optional list of result types. Unlike the other builders in
        this module, ``None`` is forwarded unchanged to ``OpView.__init__``.
        NOTE(review): presumably the result type is then inferred by the C++
        op — confirm.
      loc: Optional location; used to pick the context for attribute
        builders and forwarded to ``OpView.__init__``.
      ip: Optional insertion point, forwarded to ``OpView.__init__``.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    if origin is not None: attributes["origin"] = (origin if (
        isinstance(origin, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('StrAttr')) else
          _ods_ir.AttrBuilder.get('StrAttr')(origin, context=_ods_context))
    _ods_successors = None
    # _ODS_OPERAND_SEGMENTS / _ODS_RESULT_SEGMENTS are not defined in this
    # class. NOTE(review): presumably inherited defaults from OpView.
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The ``tensor`` operand (operand 0)."""
    return self.operation.operands[0]

  @builtins.property
  def origin(self) -> _Optional[_ods_ir.StringAttr]:
    """The ``origin`` attribute, or ``None`` when it is not set."""
    if "origin" not in self.operation.attributes:
      return None
    return self.operation.attributes["origin"]

  @origin.setter
  def origin(self, value: _Optional[_ods_ir.StringAttr]):
    # Assigning None removes the attribute if it is present.
    if value is not None:
      self.operation.attributes["origin"] = value
    elif "origin" in self.operation.attributes:
      del self.operation.attributes["origin"]

  @origin.deleter
  def origin(self):
    # Unlike the setter, this does not check that the attribute is present.
    del self.operation.attributes["origin"]

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result 0)."""
    return self.operation.results[0]

def unassign(tensor, *, origin=None, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Create an ``mpmd.unassign`` op and hand back its single result value."""
  unassign_op = UnassignOp(tensor=tensor, origin=origin, results=results, loc=loc, ip=ip)
  return unassign_op.result
