
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from jaxlib.mlir.dialects._ods_common import _cext as _ods_cext
from jaxlib.mlir.ir import register_attribute_builder
_ods_ir = _ods_cext.ir

class Dimension(IntEnum):
    """A dimension, either 'x', 'y', or 'z'."""

    x = 0
    y = 1
    z = 2

    def __str__(self):
        """Return the lowercase spelling of this dimension ("x", "y" or "z")."""
        spelling = {
            Dimension.x: "x",
            Dimension.y: "y",
            Dimension.z: "z",
        }.get(self)
        # Defensive: every member is in the table, so this should never fire.
        if spelling is None:
            raise ValueError("Unknown Dimension enum entry.")
        return spelling



@register_attribute_builder("MosaicGPU_Dimension")
def _mosaicgpu_dimension(x, context):
    """Attribute builder registered for ``MosaicGPU_Dimension``.

    Converts *x* with ``int()`` and wraps it in a 32-bit signless
    ``IntegerAttr`` created in *context*. *x* is presumably a ``Dimension``
    member or a plain int; confirm against callers.
    """
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))

class SwizzlingMode(IntEnum):
    """What swizzling to use for a memory access."""

    # NOTE(review): values look like byte counts (kNoSwizzle is 16, not 0);
    # confirm the intended meaning against the dialect definition.
    kNoSwizzle = 16
    k32ByteSwizzle = 32
    k64ByteSwizzle = 64
    k128ByteSwizzle = 128

    def __str__(self):
        """Return the member's spelling, e.g. "k64ByteSwizzle"."""
        spelling = {
            SwizzlingMode.kNoSwizzle: "kNoSwizzle",
            SwizzlingMode.k32ByteSwizzle: "k32ByteSwizzle",
            SwizzlingMode.k64ByteSwizzle: "k64ByteSwizzle",
            SwizzlingMode.k128ByteSwizzle: "k128ByteSwizzle",
        }.get(self)
        # Defensive: every member is in the table, so this should never fire.
        if spelling is None:
            raise ValueError("Unknown SwizzlingMode enum entry.")
        return spelling



@register_attribute_builder("MosaicGPU_SwizzlingMode")
def _mosaicgpu_swizzlingmode(x, context):
    """Attribute builder registered for ``MosaicGPU_SwizzlingMode``.

    Converts *x* with ``int()`` and wraps it in a 32-bit signless
    ``IntegerAttr`` created in *context*. *x* is presumably a
    ``SwizzlingMode`` member or a plain int; confirm against callers.
    """
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))

class TMAReduction(IntEnum):
    """Reduction operation for TMA."""

    Add = 0
    Min = 1
    Max = 2
    Inc = 3
    Dec = 4
    And = 5
    Or = 6
    Xor = 7
    Umin = 8
    Umax = 9
    Smin = 10
    Smax = 11

    def __str__(self):
        """Return the lowercase spelling of this reduction, e.g. "umin"."""
        spelling = {
            TMAReduction.Add: "add",
            TMAReduction.Min: "min",
            TMAReduction.Max: "max",
            TMAReduction.Inc: "inc",
            TMAReduction.Dec: "dec",
            TMAReduction.And: "and",
            TMAReduction.Or: "or",
            TMAReduction.Xor: "xor",
            TMAReduction.Umin: "umin",
            TMAReduction.Umax: "umax",
            TMAReduction.Smin: "smin",
            TMAReduction.Smax: "smax",
        }.get(self)
        # Defensive: every member is in the table, so this should never fire.
        if spelling is None:
            raise ValueError("Unknown TMAReduction enum entry.")
        return spelling



@register_attribute_builder("MosaicGPU_TMAReduction")
def _mosaicgpu_tmareduction(x, context):
    """Attribute builder registered for ``MosaicGPU_TMAReduction``.

    Converts *x* with ``int()`` and wraps it in a 32-bit signless
    ``IntegerAttr`` created in *context*. *x* is presumably a
    ``TMAReduction`` member or a plain int; confirm against callers.
    """
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))

