
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
# Short alias for the IR submodule of the native extension; every op class in
# this module derives from `_ods_ir.OpView` and builds attributes through it.
_ods_ir = _ods_cext.ir
# NOTE(review): presumably keeps this generated file's frames out of
# traceback-derived MLIR locations/diagnostics — confirm against the bindings.
_ods_cext.globals.register_traceback_file_exclusion(__file__)

import builtins
from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional


@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  """Python-side registration of the `sdy` dialect.

  Op classes in this module pass this class to
  `_ods_cext.register_operation` to associate themselves with the dialect.
  """
  # Namespace prefix shared by every OPERATION_NAME below (e.g. "sdy.all_gather").
  DIALECT_NAMESPACE = "sdy"

@_ods_cext.register_operation(_Dialect)
class AllGatherOp(_ods_ir.OpView):
  r"""
  Gathers chunks of a tensor along axes specified in `gathering_axes`.
  
  The `gathering_axes` is a list of lists of axes. The outer list is over the
  dimensions of the tensor. Each inner list specifies the axes along which a
  separate gather should be performed on the respective dimension. It will be
  applied to the sharding of the operand (`tensor`) to obtain the sharding of
  the result (`out_sharding`).
  
  Note that `out_sharding` is not used to determine the sharding of the
  result. Instead, the sharding of the result is determined by the sharding of
  the operand and the `gathering_axes`, and `out_sharding` must match this
  inferred sharding.
  
  Example:
  ```mlir
  %1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
  %2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
  ```
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - Elements in `gathering_axes` must satisfy the constraints listed in
    `AxisRefListAttr`.
  - Applying `gathering_axes` to the operand sharding gets `out_sharding`.
  """

  OPERATION_NAME = "sdy.all_gather"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, gathering_axes, out_sharding, *, results=None, loc=None, ip=None):
    """Builds an `sdy.all_gather` operation.

    Args:
      tensor: The single operand.
      gathering_axes: An `_ods_ir.Attribute`, or a value converted by the
        registered `Sdy_ListOfAxisRefLists` attribute builder (passed through
        unchanged when no such builder is registered).
      out_sharding: Same as above, using the `Sdy_TensorSharding` builder.
      results: Result types; when None, the operand's type is reused.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    context = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes, and values with no registered builder, pass straight through.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=context)

    operand_values = [tensor]
    attrs = {}
    attrs["gathering_axes"] = to_attr(gathering_axes, 'Sdy_ListOfAxisRefLists')
    attrs["out_sharding"] = to_attr(out_sharding, 'Sdy_TensorSharding')
    if results is None:
      # The result has the same type as the operand.
      results = [operand_values[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=results,
        operands=operand_values,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The operand value (operand #0)."""
    return self.operation.operands[0]

  @builtins.property
  def gathering_axes(self) -> _ods_ir.Attribute:
    """The mandatory `gathering_axes` attribute."""
    return self.operation.attributes["gathering_axes"]

  @gathering_axes.setter
  def gathering_axes(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["gathering_axes"] = value

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory `out_sharding` attribute."""
    return self.operation.attributes["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result #0)."""
    return self.operation.results[0]

def all_gather(tensor, gathering_axes, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.all_gather` op and returns its single result."""
  op = AllGatherOp(tensor, gathering_axes, out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class AllReduceOp(_ods_ir.OpView):
  r"""
  Reduces chunks of a tensor along axes specified in `reduction_axes`.
  The order of `reduction_axes` is not important for the result, but can
  affect the order of the corresponding replica groups.
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - `reduction_axes` must satisfy the constraints listed in `AxisRefListAttr`.
  - `reduction_axes` must be sorted w.r.t. the mesh.
  - The operand sharding and `out_sharding` must have equivalent dimension
    shardings.
  - `reduction_axes` must not overlap with the operand dimension sharding and
    replicated axes (it can overlap with unreduced axes).
  - `reduction_axes` must not overlap with the unreduced axes of
    `out_sharding`. In other words, `out_sharding` must be replicated along
    `reduction_axes` (implicitly or explicitly).
  """

  OPERATION_NAME = "sdy.all_reduce"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, reduction_axes, out_sharding, *, results=None, loc=None, ip=None):
    """Builds an `sdy.all_reduce` operation.

    Args:
      tensor: The single operand.
      reduction_axes: An `_ods_ir.Attribute`, or a value converted by the
        registered `Sdy_AxisRefList` attribute builder (passed through
        unchanged when no such builder is registered).
      out_sharding: Same as above, using the `Sdy_TensorSharding` builder.
      results: Result types; when None, the operand's type is reused.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["reduction_axes"] = (reduction_axes if (
    isinstance(reduction_axes, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_AxisRefList')) else
      _ods_ir.AttrBuilder.get('Sdy_AxisRefList')(reduction_axes, context=_ods_context))
    attributes["out_sharding"] = (out_sharding if (
    isinstance(out_sharding, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(out_sharding, context=_ods_context))
    # The result has the same type as the operand unless told otherwise.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The operand value (operand #0)."""
    return self.operation.operands[0]

  @builtins.property
  def reduction_axes(self) -> _ods_ir.Attribute:
    """The mandatory `reduction_axes` attribute."""
    return self.operation.attributes["reduction_axes"]

  @reduction_axes.setter
  def reduction_axes(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["reduction_axes"] = value

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory `out_sharding` attribute."""
    return self.operation.attributes["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result #0)."""
    return self.operation.results[0]

def all_reduce(tensor, reduction_axes, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.all_reduce` op and returns its single result."""
  op = AllReduceOp(tensor, reduction_axes, out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class AllSliceOp(_ods_ir.OpView):
  r"""
  Slices chunks of a tensor along axes specified in `slicing_axes`. There is
  an algebraic duality between `sdy.all_slice` and `sdy.all_gather`.
  
  The `slicing_axes` is a list of lists of axes. The outer list is over the
  dimensions of the tensor. Each inner list specifies the axes along which a
  slice should be performed on the respective dimension. It will be applied to
  the sharding of the operand (`tensor`) to obtain the sharding of the result
  (`out_sharding`).
  
  Note that `out_sharding` is not used to determine the sharding of the
  result. Instead, the sharding of the result is determined by the sharding of
  the operand and the `slicing_axes`, and `out_sharding` must match this
  inferred sharding.
  
  Example:
  ```mlir
  %1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}, {}\]>]>} : tensor<8x8x8xf32>
  %2 = sdy.all_slice [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a", "b", "c"}, {}, {"d"}\]> : tensor<8x8x8xf32>
  ```
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - Elements in `slicing_axes` must satisfy the constraints listed in
    `AxisRefListAttr`.
  - Applying `slicing_axes` to the operand sharding gets `out_sharding`.
  """

  OPERATION_NAME = "sdy.all_slice"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, slicing_axes, out_sharding, *, results=None, loc=None, ip=None):
    """Builds an `sdy.all_slice` operation.

    Args:
      tensor: The single operand.
      slicing_axes: An `_ods_ir.Attribute`, or a value converted by the
        registered `Sdy_ListOfAxisRefLists` attribute builder (passed through
        unchanged when no such builder is registered).
      out_sharding: Same as above, using the `Sdy_TensorSharding` builder.
      results: Result types; when None, the operand's type is reused.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["slicing_axes"] = (slicing_axes if (
    isinstance(slicing_axes, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_ListOfAxisRefLists')) else
      _ods_ir.AttrBuilder.get('Sdy_ListOfAxisRefLists')(slicing_axes, context=_ods_context))
    attributes["out_sharding"] = (out_sharding if (
    isinstance(out_sharding, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(out_sharding, context=_ods_context))
    # The result has the same type as the operand unless told otherwise.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The operand value (operand #0)."""
    return self.operation.operands[0]

  @builtins.property
  def slicing_axes(self) -> _ods_ir.Attribute:
    """The mandatory `slicing_axes` attribute."""
    return self.operation.attributes["slicing_axes"]

  @slicing_axes.setter
  def slicing_axes(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["slicing_axes"] = value

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory `out_sharding` attribute."""
    return self.operation.attributes["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result #0)."""
    return self.operation.results[0]

def all_slice(tensor, slicing_axes, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.all_slice` op and returns its single result."""
  op = AllSliceOp(tensor, slicing_axes, out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class AllToAllOp(_ods_ir.OpView):
  r"""
  For each (axes, src_dim, tgt_dim) tuple in the parameter list, this
  operation slices chunks of a tensor along dimension `tgt_dim` and axes
  specified in `axes`, scatters those chunks along the axes, and concatenates
  them along dimension `src_dim`.
  
  This operation is essentially a combination of an all-gather along `src_dim`
  and `axes`, followed by an all-slice along `tgt_dim` and `axes`, i.e., a
  suffix of the axes sharding dimension `src_dim` on the input tensor is
  appended to the axes sharding dimension `tgt_dim` on the output tensor.
  
  The all-to-all will be applied to the sharding of the operand (`tensor`) to
  obtain the sharding of the result (`out_sharding`).
  
  Note that `out_sharding` is not used to determine the sharding of the
  result. Instead, the sharding of the result is determined by the sharding of
  the operand, `src_dim`, `tgt_dim`, and `axes`, and `out_sharding` must match
  this inferred sharding.
  
  Example:
  ```mlir
  %1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b"}, {"c"}, {}, {}\]>]>} : tensor<8x8x4x4xf32>
  %2 = sdy.all_to_all [{"b"}: 0->2, {"c"}: 1->3] %1 out_sharding=<@mesh, [{"a"}, {}, {"b"}, {"c"}\]> : tensor<8x8x4x4xf32>
  ```
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - The parameter list must not be empty.
  - For each parameter in `params`:
    - Elements in `axes` must satisfy the constraints of `AxisRefAttr`.
    - `src_dim` and `tgt_dim` must be valid dimensions (non-negative and less
    than rank of tensor).
    - Any `src_dim` or `tgt_dim` must be unique across all parameters.
    - `src_dim` must be sorted in ascending order across all parameters.
  - Moving `axes` from `src_dim` to `tgt_dim` in the operand sharding gets
    `out_sharding`.
  """

  OPERATION_NAME = "sdy.all_to_all"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, params, out_sharding, *, results=None, loc=None, ip=None):
    """Builds an `sdy.all_to_all` operation.

    Args:
      tensor: The single operand.
      params: An `_ods_ir.Attribute`, or a value converted by the registered
        `Sdy_AllToAllParamList` attribute builder (passed through unchanged
        when no such builder is registered).
      out_sharding: Same as above, using the `Sdy_TensorSharding` builder.
      results: Result types; when None, the operand's type is reused.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(tensor)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["params"] = (params if (
    isinstance(params, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_AllToAllParamList')) else
      _ods_ir.AttrBuilder.get('Sdy_AllToAllParamList')(params, context=_ods_context))
    attributes["out_sharding"] = (out_sharding if (
    isinstance(out_sharding, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(out_sharding, context=_ods_context))
    # The result has the same type as the operand unless told otherwise.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The operand value (operand #0)."""
    return self.operation.operands[0]

  @builtins.property
  def params(self) -> _ods_ir.Attribute:
    """The mandatory `params` attribute."""
    return self.operation.attributes["params"]

  @params.setter
  def params(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["params"] = value

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory `out_sharding` attribute."""
    return self.operation.attributes["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result #0)."""
    return self.operation.results[0]

def all_to_all(tensor, params, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.all_to_all` op and returns its single result."""
  op = AllToAllOp(tensor, params, out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class CollectivePermuteOp(_ods_ir.OpView):
  r"""
  Sends a chunk of the input tensor from each device to another to
  reorder/replace the axes that shard the tensor.
  
  A collective permute can transform the input sharding such that each
  dimension must be as sharded as it was before, i.e., it must be sharded
  along axes whose product of sizes matches that of the axes that previously
  sharded the tensor.
  
  This is useful for reordering axes in a single dimension or across different
  dimensions, and swapping sharded axes with replicated ones.
  
  In the below example, the sharded tensor size is `tensor<1x4x2xf32>`, and
  that is preserved by the collective permute.
  
  Example:
  ```mlir
  sdy.mesh @mesh = <["a"=2, "b"=2, "c"=4, "d"=2, "e"=2, "f"=2]>
  %1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "c"}, {"f"}, {"d", "e"}\]>]>} : tensor<8x8x8xf32>
  %2 = sdy.collective_permute %1 out_sharding=<@mesh, [{"c":(1)2, "b", "f"}, {"a"}, {"e", "d"}\]> : tensor<8x8x8xf32>
  ```
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - If input and output sharding have different meshes, then those meshes must
    have exactly the same axes and different order of device ids.
  - For each dimension, the product of sharding axis sizes in `out_sharding`
    must match that of the corresponding operand dimension sharding.
  """

  OPERATION_NAME = "sdy.collective_permute"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, out_sharding, *, results=None, loc=None, ip=None):
    """Builds an `sdy.collective_permute` operation.

    Args:
      tensor: The single operand.
      out_sharding: An `_ods_ir.Attribute`, or a value converted by the
        registered `Sdy_TensorSharding` attribute builder (passed through
        unchanged when no such builder is registered).
      results: Result types; when None, the operand's type is reused.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    operand_values = [tensor]
    context = _ods_get_default_loc_context(loc)
    builder_name = 'Sdy_TensorSharding'
    if isinstance(out_sharding, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains(builder_name):
      sharding_attr = out_sharding
    else:
      sharding_attr = _ods_ir.AttrBuilder.get(builder_name)(out_sharding, context=context)
    if results is None:
      # The result has the same type as the operand.
      results = [operand_values[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"out_sharding": sharding_attr},
        results=results,
        operands=operand_values,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The operand value (operand #0)."""
    return self.operation.operands[0]

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory `out_sharding` attribute."""
    return self.operation.attributes["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result #0)."""
    return self.operation.results[0]

def collective_permute(tensor, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.collective_permute` op and returns its single result."""
  op = CollectivePermuteOp(tensor, out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ConstantOp(_ods_ir.OpView):
  r"""
  Produces an `output` tensor from a constant `value`.
  
  See:
  https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
  
  NOTE: SDY defines its own constant op that isn't ConstantLike and doesn't
  have a folder, so that we'll be able to duplicate constants without any
  greedy pattern rewriter folding them back into a single constant. In this
  way, constants can be sharded differently for every use, and no propagation
  is done between constants (or constant expressions).
  
  Example:
  ```mlir
  %output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
  ```
  """

  OPERATION_NAME = "sdy.constant"

  _ODS_REGIONS = (0, True)

  def __init__(self, value, *, results=None, loc=None, ip=None):
    """Builds an `sdy.constant` operation.

    Args:
      value: An `_ods_ir.Attribute`, or a value converted by the registered
        `ElementsAttr` attribute builder (passed through unchanged when no
        such builder is registered).
      results: Forwarded unchanged to `OpView` (None is allowed).
        NOTE(review): presumably None lets the bindings infer the result
        type — confirm.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    context = _ods_get_default_loc_context(loc)
    builder_name = 'ElementsAttr'
    if isinstance(value, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains(builder_name):
      value_attr = value
    else:
      value_attr = _ods_ir.AttrBuilder.get(builder_name)(value, context=context)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"value": value_attr},
        results=results,
        operands=[],
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def value(self) -> _ods_ir.Attribute:
    """The mandatory `value` attribute holding the constant's elements."""
    return self.operation.attributes["value"]

  @value.setter
  def value(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["value"] = value

  @builtins.property
  def output(self) -> _ods_ir.OpResult[_ods_ir.RankedTensorType]:
    """The single result (result #0)."""
    return self.operation.results[0]

def constant(value, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.constant` op and returns its single result."""
  op = ConstantOp(value, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class DataFlowEdgeOp(_ods_ir.OpView):
  r"""
  A data flow edge of some op X defines a bridge between a set of sources
  (each is either an operand of X or an operand of X's block terminator) and
  a set of targets (each is either a result of X or a block argument of X),
  such that all sources and targets should be sharded in the same way.
  
  An op can have multiple data flow edges that are orthogonal to one another.
  
  For example:
  
  ```mlir
    y_0, ..., y_n = while (x_0, ..., x_n)
                    ((pred_arg_0,... , pred_arg_n) { ... })
                    ((body_arg_0,..., body_arg_n) {
                      ...
                      return return_value_0, ..., return_value_n
                    })
  ```
  
  This while op has n+1 data flow edges; the i-th data flow edge is between
  sources `x_i`, `return_value_i` and targets `y_i`, `pred_arg_i`,
  `body_arg_i`.
  
  An `sdy.data_flow_edge` takes as input the owner of an edge (can be
  any of the targets, but preferably an op result rather than a block
  argument), which shouldn't have any other uses. This op isn't pure because
  it can take an input that originally didn't have any uses.
  
  The `sdy.data_flow_edge` also holds an optional sharding for all targets of
  the edge, and that sharding should be updated instead of the targets'
  sharding (if it can be attached) during propagation. This is useful when an
  op has many edges, as it's much more efficient to:
  - propagate through each edge separately.
  - update the sharding of each edge separately instead of all targets at once
    (e.g. an op has a single immutable `TensorShardingPerValueAttr` for result
    shardings).
  - add each edge to the worklist separately when the sharding of a source has
    changed.
  
  Propagation will propagate shardings between all sources and targets of a
  `sdy.data_flow_edge` as if it was a regular op with the sources as operands
  and targets as results, and an identity `sdy.op_sharding_rule`. That means
  that forward propagation is from sources to targets and backwards
  propagation is from targets to sources.
  
  We don't allow the input of a `sdy.data_flow_edge` to be defined by an
  `SdyDialect` op, so we can assume that it's defined by an op that has
  unregistered `sdy.sharding` attribute.
  
  NOTE: it's NOT the responsibility of the `sdy.data_flow_edge` to link
  between sources and targets, it's simply attached to the owner of the edge.
  The op that this edge is bound to (while in the example above) is
  responsible for providing this information.
  """

  OPERATION_NAME = "sdy.data_flow_edge"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, *, sharding=None, results=None, loc=None, ip=None):
    """Builds an `sdy.data_flow_edge` operation.

    Args:
      input: The single operand (the owner of the edge).
      sharding: Optional. When not None, an `_ods_ir.Attribute` or a value
        converted by the registered `Sdy_TensorSharding` attribute builder
        (passed through unchanged when no such builder is registered). When
        None, no `sharding` attribute is set.
      results: Result types; when None, the operand's type is reused.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(input)
    _ods_context = _ods_get_default_loc_context(loc)
    if sharding is not None: attributes["sharding"] = (sharding if (
        isinstance(sharding, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')) else
          _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(sharding, context=_ods_context))
    # The result has the same type as the operand unless told otherwise.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def input(self) -> _ods_ir.Value:
    """The operand value (operand #0)."""
    return self.operation.operands[0]

  @builtins.property
  def sharding(self) -> _Optional[_ods_ir.Attribute]:
    """The optional `sharding` attribute, or None when it is absent."""
    if "sharding" not in self.operation.attributes:
      return None
    return self.operation.attributes["sharding"]

  @sharding.setter
  def sharding(self, value: _Optional[_ods_ir.Attribute]):
    # Assigning None removes the attribute (if present) instead of storing it.
    if value is not None:
      self.operation.attributes["sharding"] = value
    elif "sharding" in self.operation.attributes:
      del self.operation.attributes["sharding"]

  @sharding.deleter
  def sharding(self):
    # Unlike assigning None, this does not check for presence first.
    del self.operation.attributes["sharding"]

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The single result (result #0)."""
    return self.operation.results[0]

def data_flow_edge(input, *, sharding=None, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Creates an `sdy.data_flow_edge` op and returns its single result."""
  op = DataFlowEdgeOp(input, sharding=sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ManualComputationOp(_ods_ir.OpView):
  r"""
  Jump into a region written in terms of per-device local code with explicit
  collectives, where logical shapes match local per-device physical buffer
  shapes and collectives correspond exactly to physical cross-device
  communication.
  
  The body is local wrt the manual_axes. Propagation will occur through
  the body on any free axes - those not in the manual_axes list.
  
  Note that any unranked tensors are expected to have a sharding with rank 0,
  i.e. fully replicated.
  
  **Constraints:**
  - Elements in `in_shardings` and `out_shardings` must satisfy the constraints listed in `TensorShardingAttr`.
  - The number of global and local tensor inputs/outputs of the op region must match.
  - The manual axes must come before any free axes in each dim sharding.
  - The manual axes cannot introduce padding. Namely, the dimension size must be divisible by the corresponding manual axes size.
  - The global and local shapes of the op region's arguments/results must match.
  """

  OPERATION_NAME = "sdy.manual_computation"

  # One region (the body).
  _ODS_REGIONS = (1, True)

  def __init__(self, results_, tensors, in_shardings, out_shardings, manual_axes, *, loc=None, ip=None):
    """Builds an `sdy.manual_computation` operation.

    Args:
      results_: Iterable of result types (copied into a new list).
      tensors: Variadic operands; passed through `_get_op_results_or_values`.
      in_shardings: An `_ods_ir.Attribute`, or a value converted by the
        registered `Sdy_TensorShardingPerValue` attribute builder (passed
        through unchanged when no such builder is registered).
      out_shardings: Same as `in_shardings`.
      manual_axes: Same, using the `Sdy_ManualAxes` builder.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(tensors))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["in_shardings"] = (in_shardings if (
    isinstance(in_shardings, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(in_shardings, context=_ods_context))
    attributes["out_shardings"] = (out_shardings if (
    isinstance(out_shardings, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(out_shardings, context=_ods_context))
    attributes["manual_axes"] = (manual_axes if (
    isinstance(manual_axes, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_ManualAxes')) else
      _ods_ir.AttrBuilder.get('Sdy_ManualAxes')(manual_axes, context=_ods_context))
    results = []
    results.extend(results_)
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def tensors(self) -> _ods_ir.OpOperandList:
    """All operands; the variadic group is the op's only operand group."""
    all_operands = self.operation.operands
    return all_operands[0:len(all_operands)]

  @builtins.property
  def in_shardings(self) -> _ods_ir.Attribute:
    """The mandatory `in_shardings` attribute."""
    return self.operation.attributes["in_shardings"]

  @in_shardings.setter
  def in_shardings(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["in_shardings"] = value

  @builtins.property
  def out_shardings(self) -> _ods_ir.Attribute:
    """The mandatory `out_shardings` attribute."""
    return self.operation.attributes["out_shardings"]

  @out_shardings.setter
  def out_shardings(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["out_shardings"] = value

  @builtins.property
  def manual_axes(self) -> _ods_ir.Attribute:
    """The mandatory `manual_axes` attribute."""
    return self.operation.attributes["manual_axes"]

  @manual_axes.setter
  def manual_axes(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["manual_axes"] = value

  @builtins.property
  def results_(self) -> _ods_ir.OpResultList:
    """All results; the variadic group is the op's only result group."""
    all_results = self.operation.results
    return all_results[0:len(all_results)]

  @builtins.property
  def body(self) -> _ods_ir.Region:
    """The op's single region (region #0)."""
    return self.regions[0]

def manual_computation(results_, tensors, in_shardings, out_shardings, manual_axes, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, ManualComputationOp]:
  """Creates an `sdy.manual_computation` op.

  Returns the op itself when it has no results, the sole result when it has
  exactly one, and the full result list otherwise.
  """
  op = ManualComputationOp(results_, tensors, in_shardings, out_shardings, manual_axes, loc=loc, ip=ip)
  op_results = op.results
  count = len(op_results)
  if count == 0:
    return op
  if count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class MeshOp(_ods_ir.OpView):
  r"""
  Defines a new named mesh. All meshes in a module must have the same number
  of devices (except for meshes with a single device_id).
  The mesh is a `Symbol` operation that appears in the module's
  `SymbolTable` and can be referenced by its `name`.
  """

  OPERATION_NAME = "sdy.mesh"

  _ODS_REGIONS = (0, True)

  def __init__(self, sym_name, mesh, *, loc=None, ip=None):
    """Builds an `sdy.mesh` operation (no operands, no results).

    Args:
      sym_name: An `_ods_ir.Attribute`, or a value converted by the
        registered `SymbolNameAttr` attribute builder (passed through
        unchanged when no such builder is registered).
      mesh: Same, using the `Sdy_Mesh` builder.
      loc: Optional location, forwarded to `_ods_get_default_loc_context`
        and to the op.
      ip: Optional insertion point.
    """
    context = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes, and values with no registered builder, pass straight through.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=context)

    attrs = {}
    attrs["sym_name"] = to_attr(sym_name, 'SymbolNameAttr')
    attrs["mesh"] = to_attr(mesh, 'Sdy_Mesh')
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=[],
        operands=[],
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def sym_name(self) -> _ods_ir.StringAttr:
    """The mandatory `sym_name` attribute."""
    return self.operation.attributes["sym_name"]

  @sym_name.setter
  def sym_name(self, value: _ods_ir.StringAttr):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sym_name"] = value

  @builtins.property
  def mesh(self) -> _ods_ir.Attribute:
    """The mandatory `mesh` attribute."""
    return self.operation.attributes["mesh"]

  @mesh.setter
  def mesh(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["mesh"] = value

def mesh(sym_name, mesh, *, loc=None, ip=None) -> MeshOp:
  """Creates an `sdy.mesh` op and returns it (the op has no results)."""
  mesh_op = MeshOp(sym_name, mesh, loc=loc, ip=ip)
  return mesh_op

@_ods_cext.register_operation(_Dialect)
class NamedComputationOp(_ods_ir.OpView):
  r"""
  Groups a computation, i.e. a block of operations, and gives it a name.
  Propagation will flow in/out of the region as if everything was inlined.
  
  This can be used to handle propagating through call instructions to other
  functions. Any users of Shardy should write an import/export pass that
  converts their call ops to `sdy.named_computation` ops, duplicating/copying
  the body of the called function into the body of the `named_computation`.
  
  The type of each block argument and each returned value in the region must
  be the same as the type of the corresponding operand and result of the op.
  
  Example:
  
  ```mlir
  %1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
    sdy.return %arg1 : tensor<16x32xf32>
  } : (tensor<16x32xf32>) -> tensor<16x32xf32>
  ```
  """

  OPERATION_NAME = "sdy.named_computation"

  _ODS_REGIONS = (1, True)

  def __init__(self, result, name, operands_, *, in_shardings=None, out_shardings=None, loc=None, ip=None):
    """Create an ``sdy.named_computation`` op.

    Args:
      result: iterable of result types; copied into the op's result list.
      name: computation name; converted with the ``StrAttr`` attribute
        builder unless it is already an ``Attribute``.
      operands_: values passed to the op, normalized with
        ``_get_op_results_or_values``.
      in_shardings: optional ``Sdy_TensorShardingPerValue`` attribute (or a
        value that builder accepts); omitted when ``None``.
      out_shardings: same as ``in_shardings``, for the results.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.

    This constructor does not populate the ``body`` region (``regions`` is
    passed as ``None``); presumably callers add the block afterwards.
    """
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(operands_))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["name"] = (name if (
    isinstance(name, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('StrAttr')) else
      _ods_ir.AttrBuilder.get('StrAttr')(name, context=_ods_context))
    # Optional attributes are only attached when explicitly provided.
    if in_shardings is not None: attributes["in_shardings"] = (in_shardings if (
        isinstance(in_shardings, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
          _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(in_shardings, context=_ods_context))
    if out_shardings is not None: attributes["out_shardings"] = (out_shardings if (
        isinstance(out_shardings, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
          _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(out_shardings, context=_ods_context))
    results = []
    results.extend(result)
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def operands_(self) -> _ods_ir.OpOperandList:
    """All operands of the op (the single variadic operand group)."""
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]

  @builtins.property
  def name(self) -> _ods_ir.StringAttr:
    """The mandatory ``name`` attribute."""
    return self.operation.attributes["name"]

  @name.setter
  def name(self, value: _ods_ir.StringAttr):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["name"] = value

  @builtins.property
  def in_shardings(self) -> _Optional[_ods_ir.Attribute]:
    """The optional ``in_shardings`` attribute, or ``None`` when absent."""
    if "in_shardings" not in self.operation.attributes:
      return None
    return self.operation.attributes["in_shardings"]

  @in_shardings.setter
  def in_shardings(self, value: _Optional[_ods_ir.Attribute]):
    # Assigning None removes the attribute (if present) instead of storing it.
    if value is not None:
      self.operation.attributes["in_shardings"] = value
    elif "in_shardings" in self.operation.attributes:
      del self.operation.attributes["in_shardings"]

  @in_shardings.deleter
  def in_shardings(self):
    del self.operation.attributes["in_shardings"]

  @builtins.property
  def out_shardings(self) -> _Optional[_ods_ir.Attribute]:
    """The optional ``out_shardings`` attribute, or ``None`` when absent."""
    if "out_shardings" not in self.operation.attributes:
      return None
    return self.operation.attributes["out_shardings"]

  @out_shardings.setter
  def out_shardings(self, value: _Optional[_ods_ir.Attribute]):
    # Assigning None removes the attribute (if present) instead of storing it.
    if value is not None:
      self.operation.attributes["out_shardings"] = value
    elif "out_shardings" in self.operation.attributes:
      del self.operation.attributes["out_shardings"]

  @out_shardings.deleter
  def out_shardings(self):
    del self.operation.attributes["out_shardings"]

  @builtins.property
  def body(self) -> _ods_ir.Region:
    """The op's single region (index 0)."""
    return self.regions[0]

def named_computation(result, name, operands_, *, in_shardings=None, out_shardings=None, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, NamedComputationOp]:
  """Build an ``sdy.named_computation`` op.

  Returns the op itself when it has no results, the lone result value when it
  has exactly one, and the whole result list otherwise.
  """
  op = NamedComputationOp(result=result, name=name, operands_=operands_, in_shardings=in_shardings, out_shardings=out_shardings, loc=loc, ip=ip)
  op_results = op.results
  count = len(op_results)
  if count == 0:
    return op
  if count == 1:
    return op_results[0]
  return op_results

@_ods_cext.register_operation(_Dialect)
class PropagationBarrierOp(_ods_ir.OpView):
  r"""
  This op operates like an identity op, outputting the same value it took as
  input. But in terms of propagation, this will only allow propagation to flow
  through it in a certain direction.
  
  This prevents shardings from being propagated between the uses of the result
  of the barrier op and its operand.
  
  - `FORWARD` means shardings can only flow from the operand to the result.
  - `BACKWARD` means shardings can only flow from the result to the operand.
  - `NONE` means no sharding can propagate through this op.
  - Cannot specify `BOTH`, as this op would be redundant.
  """

  OPERATION_NAME = "sdy.propagation_barrier"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, allowed_direction, *, results=None, loc=None, ip=None):
    """Create an ``sdy.propagation_barrier`` op.

    Args:
      input: the operand value.
      allowed_direction: converted with the ``Sdy_PropagationDirection``
        attribute builder unless already an ``Attribute`` (or no such builder
        is registered).
      results: result types; when ``None`` the result takes ``input``'s type.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    ctx = _ods_get_default_loc_context(loc)
    builder_key = 'Sdy_PropagationDirection'
    direction_attr = allowed_direction
    if not isinstance(allowed_direction, _ods_ir.Attribute) and _ods_ir.AttrBuilder.contains(builder_key):
      direction_attr = _ods_ir.AttrBuilder.get(builder_key)(allowed_direction, context=ctx)
    if results is None:
      # Identity-like op: the result mirrors the operand's type.
      results = [input.type]
    super().__init__(
        self.OPERATION_NAME, self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
        attributes={"allowed_direction": direction_attr},
        results=results, operands=[input],
        successors=None, regions=None, loc=loc, ip=ip)

  @builtins.property
  def input(self) -> _ods_ir.Value[_ods_ir.RankedTensorType]:
    """The single operand."""
    operands = self.operation.operands
    return operands[0]

  @builtins.property
  def allowed_direction(self) -> _ods_ir.Attribute:
    """The mandatory ``allowed_direction`` attribute."""
    attrs = self.operation.attributes
    return attrs["allowed_direction"]

  @allowed_direction.setter
  def allowed_direction(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["allowed_direction"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def result(self) -> _ods_ir.OpResult[_ods_ir.RankedTensorType]:
    """The first (and only) result value."""
    op_results = self.operation.results
    return op_results[0]

def propagation_barrier(input, allowed_direction, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Build an ``sdy.propagation_barrier`` op and return its result value."""
  op = PropagationBarrierOp(input=input, allowed_direction=allowed_direction, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ReduceScatterOp(_ods_ir.OpView):
  r"""
  Reduces chunks of a tensor along axes specified in `reduce_scatter_axes`,
  and then scatters the result along the same axes. This operation is
  essentially a combination of an `sdy.all_reduce` followed by an
  `sdy.all_slice` along the same `reduce_scatter_axes`.
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - Elements in `reduce_scatter_axes` must satisfy the constraints listed in
    `AxisRefListAttr`.
  - Applying `reduce_scatter_axes` to the operand sharding gets
    `out_sharding`.
  """

  OPERATION_NAME = "sdy.reduce_scatter"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, reduce_scatter_axes, out_sharding, *, results=None, loc=None, ip=None):
    """Create an ``sdy.reduce_scatter`` op.

    Args:
      tensor: the operand value.
      reduce_scatter_axes: converted with the ``Sdy_ListOfAxisRefLists``
        attribute builder unless already an ``Attribute``.
      out_sharding: converted with the ``Sdy_TensorSharding`` attribute
        builder unless already an ``Attribute``.
      results: result types; when ``None`` the result takes ``tensor``'s type.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    ctx = _ods_get_default_loc_context(loc)

    def as_attr(value, builder_key):
      # Ready-made attributes, or kinds with no registered builder, pass through as-is.
      if isinstance(value, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains(builder_key):
        return value
      return _ods_ir.AttrBuilder.get(builder_key)(value, context=ctx)

    attrs = {
        "reduce_scatter_axes": as_attr(reduce_scatter_axes, 'Sdy_ListOfAxisRefLists'),
        "out_sharding": as_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    result_types = [tensor.type] if results is None else results
    super().__init__(
        self.OPERATION_NAME, self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
        attributes=attrs, results=result_types, operands=[tensor],
        successors=None, regions=None, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The single operand."""
    operands = self.operation.operands
    return operands[0]

  @builtins.property
  def reduce_scatter_axes(self) -> _ods_ir.Attribute:
    """The mandatory ``reduce_scatter_axes`` attribute."""
    attrs = self.operation.attributes
    return attrs["reduce_scatter_axes"]

  @reduce_scatter_axes.setter
  def reduce_scatter_axes(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["reduce_scatter_axes"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory ``out_sharding`` attribute."""
    attrs = self.operation.attributes
    return attrs["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["out_sharding"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The first (and only) result value."""
    op_results = self.operation.results
    return op_results[0]

def reduce_scatter(tensor, reduce_scatter_axes, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Build an ``sdy.reduce_scatter`` op and return its result value."""
  op = ReduceScatterOp(tensor=tensor, reduce_scatter_axes=reduce_scatter_axes, out_sharding=out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ReplicatedToUnreducedOp(_ods_ir.OpView):
  r"""
  The `axes` should be implicitly or explicitly replicated in the operand.
  This operation makes them unreduced in the result. We have the following
  relationship:
  
  all-reduce(replicated-to-unreduced(x, axes), axes) = x
  
  Example:
  ```mlir
  %1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b"}, {}, {}\], replicated={"c", "d"}, unreduced={"e"}>]>} : tensor<8x8x8xf32>
  %2 = sdy.replicated_to_unreduced {"a", "c", "f"} %1 out_sharding=<@mesh, [{"b"}, {}, {}\], replicated={"d"}, unreduced={"a", "c", "e", "f"}> : tensor<8x8x8xf32>
  ```
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - `axes` must satisfy the constraints listed in `AxisRefListAttr`.
  - `axes` must be sorted w.r.t. the mesh.
  - `axes` are not empty.
  - The input and output sharding must have the same dimension shardings.
  - `axes` must be implicitly or explicitly replicated in the operand sharding.
  - inUnreducedAxes + axes = outUnreducedAxes.
  """

  OPERATION_NAME = "sdy.replicated_to_unreduced"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, axes, out_sharding, *, results=None, loc=None, ip=None):
    """Create an ``sdy.replicated_to_unreduced`` op.

    Args:
      tensor: the operand value.
      axes: converted with the ``Sdy_AxisRefList`` attribute builder unless
        already an ``Attribute``.
      out_sharding: converted with the ``Sdy_TensorSharding`` attribute
        builder unless already an ``Attribute``.
      results: result types; when ``None`` the result takes ``tensor``'s type.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    ctx = _ods_get_default_loc_context(loc)

    def as_attr(value, builder_key):
      # Ready-made attributes, or kinds with no registered builder, pass through as-is.
      if isinstance(value, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains(builder_key):
        return value
      return _ods_ir.AttrBuilder.get(builder_key)(value, context=ctx)

    attrs = {
        "axes": as_attr(axes, 'Sdy_AxisRefList'),
        "out_sharding": as_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    result_types = [tensor.type] if results is None else results
    super().__init__(
        self.OPERATION_NAME, self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
        attributes=attrs, results=result_types, operands=[tensor],
        successors=None, regions=None, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The single operand."""
    operands = self.operation.operands
    return operands[0]

  @builtins.property
  def axes(self) -> _ods_ir.Attribute:
    """The mandatory ``axes`` attribute."""
    attrs = self.operation.attributes
    return attrs["axes"]

  @axes.setter
  def axes(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["axes"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory ``out_sharding`` attribute."""
    attrs = self.operation.attributes
    return attrs["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["out_sharding"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The first (and only) result value."""
    op_results = self.operation.results
    return op_results[0]

def replicated_to_unreduced(tensor, axes, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Build an ``sdy.replicated_to_unreduced`` op and return its result value."""
  op = ReplicatedToUnreducedOp(tensor=tensor, axes=axes, out_sharding=out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ReshardOp(_ods_ir.OpView):
  r"""
  Reshards the input tensor with the specified sharding, which is different
  from the input tensor's existing sharding.
  
  Both ShardingConstraintOp and ReshardOp attach a sharding to a tensor. Their
  lifespan is:
  1. Before sharding propagation, ShardingConstraintOp is added by users.
  2. Sharding propagation consumes ShardingConstraintOp. There is no
     ShardingConstraintOp in the results of sharding propagation. Instead,
     ReshardOp may be added if needed.
  3. A partitioner converts a ReshardOp into a collective op (or an identity
     op). There should be no ReshardOp in the results of the partitioner.
  
  TODO(b/331680067): Add a canonicalization pattern to remove redundant
  reshard ops.
  """

  OPERATION_NAME = "sdy.reshard"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, sharding, *, results=None, loc=None, ip=None):
    """Create an ``sdy.reshard`` op.

    Args:
      input: the operand value.
      sharding: converted with the ``Sdy_TensorSharding`` attribute builder
        unless already an ``Attribute`` (or no such builder is registered).
      results: result types; when ``None`` the result takes ``input``'s type.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(input)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["sharding"] = (sharding if (
    isinstance(sharding, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(sharding, context=_ods_context))
    # The result type defaults to the operand's type.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def input(self) -> _ods_ir.Value:
    """The single operand."""
    return self.operation.operands[0]

  @builtins.property
  def sharding(self) -> _ods_ir.Attribute:
    """The mandatory ``sharding`` attribute."""
    return self.operation.attributes["sharding"]

  @sharding.setter
  def sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The first (and only) result value."""
    return self.operation.results[0]

def reshard(input, sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Build an ``sdy.reshard`` op and return its result value."""
  op = ReshardOp(input=input, sharding=sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ReturnOp(_ods_ir.OpView):
  """The ``sdy.return`` op: forwards its operands out of an enclosing region.

  Used, for example, as the terminator of ``sdy.named_computation`` bodies.
  """

  OPERATION_NAME = "sdy.return"

  _ODS_REGIONS = (0, True)

  def __init__(self, results_, *, loc=None, ip=None):
    """Create an ``sdy.return`` op.

    Args:
      results_: values to return, normalized with ``_get_op_results_or_values``.
      loc: optional location.
      ip: optional insertion point.
    """
    returned = list(_get_op_results_or_values(results_))
    # The context is not needed here (there are no attributes), but the lookup
    # is kept so that this builder behaves like the others, e.g. if the lookup
    # raises when no location is available.
    _ods_get_default_loc_context(loc)
    super().__init__(
        self.OPERATION_NAME, self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
        attributes={}, results=[], operands=returned,
        successors=None, regions=None, loc=loc, ip=ip)

  @builtins.property
  def results_(self) -> _ods_ir.OpOperandList:
    """All operands of the op (the single variadic operand group)."""
    all_operands = self.operation.operands
    return all_operands[0:len(all_operands)]

def return_(results_, *, loc=None, ip=None) -> ReturnOp:
  """Build an ``sdy.return`` op and return the op itself (it has no results)."""
  terminator = ReturnOp(results_=results_, loc=loc, ip=ip)
  return terminator

@_ods_cext.register_operation(_Dialect)
class ShardedToUnreducedOp(_ods_ir.OpView):
  r"""
  The `axes` should be used to shard the operand. This operation makes them
  unreduced in the result. We have the following relationship:
  
  all-gather(x, axes) = all-reduce(sharded-to-unreduced(x, axes), axes), where
  all-gather, sharded-to-unreduced, all-reduce are applied on the same axes.
  
  Example:
  ```mlir
  %1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\], unreduced={"e"}>]>} : tensor<8x8x8xf32>
  %2 = sdy.sharded_to_unreduced [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\], unreduced={"b", "c", "d", "e"}> : tensor<8x8x8xf32>
  ```
  
  **Constraints:**
  - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
  - Elements in `axes` must satisfy the constraints listed in `AxisRefListAttr`.
  - Applying `axes` to the operand sharding gets `out_sharding`.
  """

  OPERATION_NAME = "sdy.sharded_to_unreduced"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, axes, out_sharding, *, results=None, loc=None, ip=None):
    """Create an ``sdy.sharded_to_unreduced`` op.

    Args:
      tensor: the operand value.
      axes: converted with the ``Sdy_ListOfAxisRefLists`` attribute builder
        unless already an ``Attribute``.
      out_sharding: converted with the ``Sdy_TensorSharding`` attribute
        builder unless already an ``Attribute``.
      results: result types; when ``None`` the result takes ``tensor``'s type.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    ctx = _ods_get_default_loc_context(loc)

    def as_attr(value, builder_key):
      # Ready-made attributes, or kinds with no registered builder, pass through as-is.
      if isinstance(value, _ods_ir.Attribute) or not _ods_ir.AttrBuilder.contains(builder_key):
        return value
      return _ods_ir.AttrBuilder.get(builder_key)(value, context=ctx)

    attrs = {
        "axes": as_attr(axes, 'Sdy_ListOfAxisRefLists'),
        "out_sharding": as_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    result_types = [tensor.type] if results is None else results
    super().__init__(
        self.OPERATION_NAME, self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
        attributes=attrs, results=result_types, operands=[tensor],
        successors=None, regions=None, loc=loc, ip=ip)

  @builtins.property
  def tensor(self) -> _ods_ir.Value:
    """The single operand."""
    operands = self.operation.operands
    return operands[0]

  @builtins.property
  def axes(self) -> _ods_ir.Attribute:
    """The mandatory ``axes`` attribute."""
    attrs = self.operation.attributes
    return attrs["axes"]

  @axes.setter
  def axes(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["axes"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def out_sharding(self) -> _ods_ir.Attribute:
    """The mandatory ``out_sharding`` attribute."""
    attrs = self.operation.attributes
    return attrs["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value: _ods_ir.Attribute):
    # Mandatory attribute: it may be replaced but never cleared with None.
    if value is not None:
      self.operation.attributes["out_sharding"] = value
      return
    raise ValueError("'None' not allowed as value for mandatory attributes")

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The first (and only) result value."""
    op_results = self.operation.results
    return op_results[0]

def sharded_to_unreduced(tensor, axes, out_sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Build an ``sdy.sharded_to_unreduced`` op and return its result value."""
  op = ShardedToUnreducedOp(tensor=tensor, axes=axes, out_sharding=out_sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ShardingConstraintOp(_ods_ir.OpView):
  r"""
  Attaches a sharding to an intermediate tensor (e.g. the result of a matmul)
  to indicate that this is how that tensor, or a subset of its uses, should be
  sharded.
  
  If the sharding has open dimensions and unconstrained axes, it means the
  tensor can be further sharded along the open dimensions.
  
  This op can either:
  - Have no uses (dangling) - which means the attached sharding is how the
    input tensor itself should be sharded.
  - Have uses - which means the attached sharding is how the uses of the
    sharding constraint op should be sharded, while other uses of the input
    tensor might have a different sharding (if the input tensor has no other
    uses then the behavior is the same as the no uses case).
  """

  OPERATION_NAME = "sdy.sharding_constraint"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, sharding, *, results=None, loc=None, ip=None):
    """Create an ``sdy.sharding_constraint`` op.

    Args:
      input: the operand value.
      sharding: converted with the ``Sdy_TensorSharding`` attribute builder
        unless already an ``Attribute`` (or no such builder is registered).
      results: result types; when ``None`` the result takes ``input``'s type.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(input)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["sharding"] = (sharding if (
    isinstance(sharding, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('Sdy_TensorSharding')) else
      _ods_ir.AttrBuilder.get('Sdy_TensorSharding')(sharding, context=_ods_context))
    # The result type defaults to the operand's type.
    if results is None: results = [operands[0].type] * 1
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def input(self) -> _ods_ir.Value:
    """The single operand."""
    return self.operation.operands[0]

  @builtins.property
  def sharding(self) -> _ods_ir.Attribute:
    """The mandatory ``sharding`` attribute."""
    return self.operation.attributes["sharding"]

  @sharding.setter
  def sharding(self, value: _ods_ir.Attribute):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sharding"] = value

  @builtins.property
  def result(self) -> _ods_ir.OpResult:
    """The first (and only) result value."""
    return self.operation.results[0]

def sharding_constraint(input, sharding, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
  """Build an ``sdy.sharding_constraint`` op and return its result value."""
  op = ShardingConstraintOp(input=input, sharding=sharding, results=results, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ShardingGroupOp(_ods_ir.OpView):
  r"""
  This op provides an interface to assign tensors to sharding groups (groups
  of tensors that will be enforced to have identical shardings).
  During propagation, as soon as one group element is sharded, all other
  members will be sharded in exactly the same way. This operation takes the
  argument group ID and returns no result, but instead modifies the internal
  sharding group representation to add the input tensor to the group with the
  given ID.
  """

  OPERATION_NAME = "sdy.sharding_group"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, group_id, *, results=None, loc=None, ip=None):
    """Create an ``sdy.sharding_group`` op.

    Args:
      input: the tensor value to add to the group.
      group_id: the group ID; converted with the ``I64Attr`` attribute builder
        unless already an ``Attribute`` (or no such builder is registered).
      results: forwarded unchanged to ``OpView.__init__``. The op is
        documented as producing no result, so this should presumably be left
        as ``None``.
      loc: optional location; also used to resolve the attribute context.
      ip: optional insertion point.
    """
    operands = []
    attributes = {}
    regions = None
    operands.append(input)
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["group_id"] = (group_id if (
    isinstance(group_id, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('I64Attr')) else
      _ods_ir.AttrBuilder.get('I64Attr')(group_id, context=_ods_context))
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

  @builtins.property
  def input(self) -> _ods_ir.Value[_ods_ir.RankedTensorType]:
    """The single operand (the tensor added to the group)."""
    return self.operation.operands[0]

  @builtins.property
  def group_id(self) -> _ods_ir.IntegerAttr:
    """The mandatory ``group_id`` attribute."""
    return self.operation.attributes["group_id"]

  @group_id.setter
  def group_id(self, value: _ods_ir.IntegerAttr):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["group_id"] = value

def sharding_group(input, group_id, *, results=None, loc=None, ip=None) -> ShardingGroupOp:
  """Build an ``sdy.sharding_group`` op and return the op itself."""
  op = ShardingGroupOp(input=input, group_id=group_id, results=results, loc=loc, ip=ip)
  return op
