# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import dataclasses
import functools
import itertools
import math
from typing import Any, Callable, Iterator, cast

from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import llvm
from jaxlib.mlir.dialects import memref
from jaxlib.mlir.dialects import nvvm
import numpy as np

from . import fragmented_array as fa
from . import mma_utils
from . import utils
from .launch_context import LaunchContext


TMEM_ROWS = 128
TMEM_MAX_COLS = 512
TCGEN05_SMEM_DESCRIPTOR_BIT = 1 << 46
LAYOUT = fa.TCGEN05_LAYOUT
TRANSPOSED_LAYOUT = fa.TCGEN05_TRANSPOSED_LAYOUT
ROW_LAYOUT = fa.TCGEN05_ROW_LAYOUT
COL_LAYOUT = fa.TCGEN05_COL_LAYOUT
TMEM_NATIVE_LAYOUT = fa.TMEM_NATIVE_LAYOUT


def create_instr_descriptor(
    m: int,
    n: int,
    acc_dtype,
    input_dtype,
    transpose_a: bool = False,
    transpose_b: bool = False,
    sparsity_selector: int | None = None,
) -> ir.Value:
  f16 = ir.F16Type.get()
  f32 = ir.F32Type.get()
  i32 = ir.IntegerType.get_signless(32)

  desc = 0
  if sparsity_selector is not None:
    assert 0 <= sparsity_selector < 3
    desc |= sparsity_selector
    desc |= 1 << 2  # Enable sparsity
  if acc_dtype == f16:
    d_type_val = 0
  elif acc_dtype == f32:
    d_type_val = 1
  elif acc_dtype == i32:
    d_type_val = 2
  else:
    raise NotImplementedError(f"Unsupported accumulator dtype: {acc_dtype}")
  desc |= (d_type_val << 4)  # D type, bits 4-5
  # Bit 6 is reserved
  if input_dtype == f16:
    assert acc_dtype in {f16, f32}
    ab_type_val = 0
  elif input_dtype == ir.BF16Type.get():
    assert acc_dtype == f32
    ab_type_val = 1
  elif input_dtype == ir.Float8E4M3FNType.get():
    assert acc_dtype in {f16, f32}
    ab_type_val = 0
  elif input_dtype == ir.Float8E5M2Type.get():
    assert acc_dtype in {f16, f32}
    ab_type_val = 1
  elif input_dtype == ir.IntegerType.get_signless(8):  # Only s8 for now.
    assert acc_dtype == i32
    ab_type_val = 1
  else:
    raise NotImplementedError(f"Unsupported input dtype: {input_dtype}")
  desc |= (ab_type_val << 7)   # A dtype, bits 7-9
  desc |= (ab_type_val << 10)  # B dtype, bits 10-12
  # We ignore negate bits 13-14
  desc |= transpose_a << 15  # Transpose A
  desc |= transpose_b << 16  # Transpose B
  if n % 8 or n > 256:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  desc |= (n >> 3) << 17  # N, bits 17-22
  # Bit 23 is reserved
  if m % 16 or m > 256:
    raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
  desc |= (m >> 4) << 24  # M >> 4, bits 24-28
  # Bit 29 is reserved
  # We ignore max shift under .ws, bits 30-31
  return arith.constant(ir.IntegerType.get_signless(32), desc)
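

# Example (illustrative, assuming an active MLIR context and insertion point):
# for a 128x256 f16 MMA accumulating into f32, the descriptor encodes d_type=1
# (f32), ab_type=0 (f16), N >> 3 in bits 17-22 and M >> 4 in bits 24-28:
#
#   i_desc = create_instr_descriptor(
#       m=128, n=256, acc_dtype=ir.F32Type.get(), input_dtype=ir.F16Type.get())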


def _create_scaled_instr_descriptor(
    get_input_encoding: Callable[[ir.Type], int],
    m: int,
    n: int,
    a_type: ir.Type,
    b_type: ir.Type,
    a_scale_idx: int,
    b_scale_idx: int,
    transpose_a: bool,
    transpose_b: bool,
    scale_type: ir.Type,
) -> ir.Value:
  desc = 0
  # Bits 0, 1 are reserved
  # We ignore sparsity (bit 2)
  # Bit 3 is reserved
  assert 0 <= b_scale_idx < 4
  desc |= b_scale_idx << 4  # B scale factor data ID, bits 4-5
  # Bit 6 is reserved
  desc |= get_input_encoding(a_type) << 7  # A dtype, bits 7-9
  desc |= get_input_encoding(b_type) << 10  # B dtype, bits 10-12
  # We ignore negate bits 13-14
  desc |= transpose_a << 15  # Transpose A
  desc |= transpose_b << 16  # Transpose B
  if n % 8 or n > 256:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  desc |= (n >> 3) << 17  # N, bits 17-22
  if scale_type == ir.Float8E8M0FNUType.get():
    scale_encoding = 1
  elif scale_type == ir.Float8E4M3FNType.get():
    scale_encoding = 0
  else:
    raise NotImplementedError(f"Unsupported scale type: {scale_type}")
  desc |= scale_encoding << 23  # Scale matrix type
  # Bits 24-26 are reserved
  if m % 128 or m > 256:
    raise ValueError(f"M must be a multiple of 16 and <= 256, got: {m}")
  desc |= (m >> 7) << 27  # M >> 7, bits 27-28
  desc |= a_scale_idx << 29  # A scale factor data ID, bits 29-30
  # Bit 31 is reserved
  return arith.constant(ir.IntegerType.get_signless(32), desc)


def create_scaled_f8f6f4_instr_descriptor(*args, **kwargs) -> ir.Value:
  def get_input_encoding(ty):
    if ty == ir.Float8E4M3FNType.get():
      return 0
    elif ty == ir.Float8E5M2Type.get():
      return 1
    else:
      raise NotImplementedError(f"Unsupported input dtype: {ty}")
  return _create_scaled_instr_descriptor(get_input_encoding, *args, **kwargs)


def create_scaled_f4_instr_descriptor(*args, **kwargs) -> ir.Value:
  def get_input_encoding(ty):
    if ty == ir.Float4E2M1FNType.get():
      return 1
    else:
      raise NotImplementedError(f"Unsupported input dtype: {ty}")
  return _create_scaled_instr_descriptor(get_input_encoding, *args, **kwargs)


def mma(
    d: TMEMRef,
    a: ir.Value | TMEMRef,
    b: ir.Value,
    *,
    a_swizzle: int = 128,
    b_swizzle: int = 128,
    a_scale: TMEMRef | None = None,
    b_scale: TMEMRef | None = None,
    a_sparse_metadata: TMEMRef | None = None,
    accumulate: ir.Value | bool = True,
    collective: bool = False,
) -> None:
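  """Issues tcgen05 MMA instructions computing D = A @ B.

  The result is accumulated into the existing contents of ``d`` when
  ``accumulate`` is set. ``d`` (and ``a``, when given as a ``TMEMRef``) lives in
  TMEM, while ``b`` (and ``a`` otherwise) must be a tiled memref in SMEM. When
  ``collective`` is set, the MMA is performed jointly by a cluster of 2 CTAs.
  """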
  if a_swizzle == 16 or b_swizzle == 16:
    raise NotImplementedError("No swizzle is not supported")
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  if isinstance(accumulate, bool):
    accumulate = arith.constant(ir.IntegerType.get_signless(1), accumulate)
  num_cta = 2 if collective else 1
  if (is_scaled := a_scale is not None) != (b_scale is not None):
    raise ValueError("Either none or both scales should be provided")
  is_sparse = a_sparse_metadata is not None
  if is_scaled and is_sparse:
    raise NotImplementedError("Block-scaled sparse matmuls unsupported")

  # Step 1. Establish the shape and element type of the operation.
  if not ir.MemRefType.isinstance(b.type):
    raise ValueError(f"B must be a memref, got: {b.type}")
  (k, n), element_type = mma_utils.tiled_memref_shape(b)
  if isinstance(a, TMEMRef):
    m, k2 = a.shape
    element_type2 = a.dtype
    if is_scaled:
      raise NotImplementedError(
          "A in TMEM unsupported for block-scaled matmuls"
      )
    if m != 128:
      raise NotImplementedError(f"Only M=128 is supported for MMA with A in TMEM, but got M={m}")
    # Watch out: this layout must be consistent with D's layout (up to packing).
    expected_packing = 32 // utils.bitwidth(element_type)
    expected_layout = _infer_tmem_layout(
        a.shape, collective, packing=expected_packing
    )
    if a.layout != expected_layout:
      raise ValueError(
          f"A layout mismatch: expected {expected_layout}, got {a.layout}"
      )
  else:
    if not ir.MemRefType.isinstance(a.type):
      raise ValueError(f"A must be a memref, got {a.type}")
    (m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
  if is_sparse:
    k2 *= 2
  if k != k2:
    raise ValueError(
        "MMA requires A and B to have the same contraction dimension (K),"
        f" got: {k2} and {k}"
    )
  if element_type != element_type2:
    raise ValueError(
        "MMA requires A and B to have the same element type, got:"
        f" {element_type2} and {element_type}"
    )
  if d.shape != (m, n * num_cta):
    raise ValueError(
        f"Accumulator shape mismatch: expected {(m, n * num_cta)}, got {d.shape}"
    )
  if m == 128:
    if d.layout != (expected_d_layout := tmem_default_layout(packing=1)):
      raise ValueError(
          f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}"
      )
    n_lane_groups = 1
  elif m == 64:
    if is_scaled:
      raise NotImplementedError("MMA with block scaling is not supported for M=64")
    if is_sparse:
      raise NotImplementedError("Sparse MMA not supported for M=64")
    # Watch out: this layout must be consistent with A's layout (up to packing).
    # 2CTA M=128 instruction uses a different TMEM layout than 1CTA M=64.
    expected_d_layout = _infer_tmem_layout(d.shape, collective, packing=1)
    if d.layout != expected_d_layout:
      raise ValueError(
          f"Accumulator layout mismatch: expected {expected_d_layout}, got {d.layout}"
      )
    if collective:
      n_lane_groups = 1
    else:
      n_lane_groups = 2
      # We can't split N into groups if we would partition it below the tile size.
      # TODO: We only need to check this if N is the minormost dim in B.
      if 8 * b_swizzle // utils.bitwidth(element_type) > n // n_lane_groups:
        raise ValueError(
            f"Swizzle={b_swizzle} is too big for MMA with M=64. Try"
            " lowering it."
        )
  else:
    raise ValueError(f"Only M=128 and M=64 are supported for MMA, but got M={m}")
  f32 = ir.F32Type.get()
  f16 = ir.F16Type.get()
  s32 = ir.IntegerType.get_signless(32)
  if element_type == f32 or element_type == ir.BF16Type.get():
    if element_type == f32 and is_sparse:
      raise NotImplementedError("Sparse MMA unsupported for f32")
    if is_scaled:
      raise ValueError(
          f"MMA with element type {element_type} does not support block scaling"
      )
    if d.dtype != f32:
      raise ValueError(
          f"MMA with element type {element_type} only supports accumulators"
          f" of type f32, but got: {d.dtype}"
      )
  elif element_type == f16:
    if is_scaled:
      raise ValueError(
          f"MMA with element type {element_type} does not support block scaling"
      )
    if d.dtype != f16 and d.dtype != f32:
      raise ValueError(
          f"MMA with element type {element_type} only supports accumulators of"
          f" type f32 or f16, but got: {d.dtype}"
      )
  elif any(
      t.isinstance(element_type)
      for t in {ir.Float8E5M2Type, ir.Float8E4M3FNType}
  ):
    if d.dtype != f16 and d.dtype != f32:
      raise ValueError(
          f"MMA with element type {element_type} only supports accumulators of"
          f" type f32 or f16, but got: {d.dtype}"
      )
    if is_scaled and d.dtype != f32:
      raise ValueError(
          f"Block-scaled MMA with element type {element_type} only supports f32"
          f" accumulators, but got: {d.dtype}"
      )
  elif any(
      t.isinstance(element_type) for t in {ir.Float4E2M1FNType}
  ):
    if is_sparse:
      raise NotImplementedError("Sparse MMA unsupported for f4e2m1fn")
    if not is_scaled:
      raise ValueError(
          f"MMA with element type {element_type} only supports block scaling"
      )
    if d.dtype != f32:
      raise ValueError(
          f"Block-scaled MMA with element type {element_type} only supports f32"
          f" accumulators, but got: {d.dtype}"
      )
  elif element_type == ir.IntegerType.get_signless(8):
    if is_scaled:
      raise ValueError(
          f"MMA with element type {element_type} does not support block scaling"
      )
    if d.dtype != s32:
      raise ValueError(
          "MMA with element type s8 only supports s32 accumulators, but got:"
          f" {d.dtype}"
      )
  else:
    raise NotImplementedError(f"Unsupported element type: {element_type}")

  # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
  # instructions must be issued in groups that are a multiple of swizzle.
  m_group_elems = m  # We have already verified M is supported above.
  k_group_elems = 8 * max(a_swizzle * (1 + is_sparse), b_swizzle) // utils.bitwidth(element_type)
  if is_sparse and k_group_elems < 64:
    # This is a limitation of the implementation below. We could relax it if we
    # ever need to support k=32.
    k_group_elems = 64
  scale_block: int | None = None
  if is_scaled:
    scale_block = 32 if a_scale.dtype == ir.Float8E8M0FNUType.get() else 16  # type: ignore
    k_group_elems = max(k_group_elems, 4 * scale_block)
  required_multiple = 16 if collective else 8
  mode_name = "2 CTA" if collective else "1 CTA"
  if d.dtype == s32:
    required_multiple *= 2
    mode_name += " integer"
  if n_lane_groups > 1:
    mode_name += f" with {n_lane_groups} lane groups"
  if (n // n_lane_groups) % required_multiple != 0:
    raise ValueError(
        f"In {mode_name} MMA, N must be a multiple of {required_multiple},"
        f" got N={n}"
    )
  if (is_sparse or is_scaled) and n.bit_count() != 1:
    raise NotImplementedError(
        "Only N that is power of 2 supported for sparse and block-scaled MMA,"
        f" but got N={n}"
    )
  if n > 256 and n.bit_count() != 1:
    raise NotImplementedError(f"The only supported N > 256, is 512, but got N={n}")
  # TODO: We could relax those constraints if we have multiple n_lane_groups,
  # since we will be unrolling the instructions anyway.
  if collective and n > 128:
    raise ValueError("Only N <= 128 are supported for collective MMA")
  elif n > 512:
    raise ValueError("Only N <= 512 are supported for MMA")
  n_group_elems = min(n // n_lane_groups, 256 // num_cta)
  if m % m_group_elems:
    raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
  if k % k_group_elems:
    raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
  if n % n_group_elems:
    raise ValueError(f"N must be a multiple of {n_group_elems}, got: {n}")
  m_groups = m // m_group_elems
  k_groups = k // k_group_elems
  n_groups = n // n_group_elems
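  # Illustrative example: a non-collective M=128, N=256, K=512 f16 MMA with
  # 128-byte swizzles has k_group_elems = 8 * 128 // 16 = 64 and
  # n_group_elems = min(256, 256) = 256, so Step 4 below issues k_groups = 8
  # instruction groups for the single (mi, ni) pair.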
  # TODO(apaszke): Require users to bitcast input refs to tf32 before MMA.
  mma_element_type = (
      ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
  )

  # Check that the shapes and element types are correct for block scaling.
  scale_element_type = None
  if is_scaled:
    assert m == 128  # Checked above.
    if n % 32:
      raise ValueError(
          f"MMA with block scaling requires N to be divisible by 32, got: {n}"
      )
    assert a_scale is not None and b_scale is not None
    scale_element_type = a_scale.dtype
    if (
        a_scale.dtype != ir.Float8E8M0FNUType.get()
        and a_scale.dtype != ir.Float8E4M3FNType.get()
    ):
      raise ValueError(
          f"A scale dtype mismatch: expected f8e8m0fnu or f8e4m3fn, got {a_scale.dtype}"
      )
    if b_scale.dtype != a_scale.dtype:
      raise ValueError(
          f"B scale dtype mismatch: expected {a_scale.dtype} (same as A), got"
          f" {b_scale.dtype}"
      )
    if a_scale.shape != (m, k // scale_block):
      raise ValueError(
          f"A scale shape mismatch: expected ({m}, {k // scale_block}), got"
          f" {a_scale.shape}"
      )
    if b_scale.shape != (n * num_cta, k // scale_block):
      raise ValueError(
          f"B scale shape mismatch: expected ({n}, {k // scale_block}), got"
          f" {b_scale.shape}"
      )
  if is_sparse:
    a_sparse_metadata = cast(TMEMRef, a_sparse_metadata)
    if n % 32:
      raise ValueError(f"Sparse MMA requires N to be divisible by 32, got: {n}")
    if a_sparse_metadata.shape != (m, k // 2):
      raise ValueError(
          f"A sparse metadata shape mismatch: expected {(m, k // 2)}, got"
          f" {a_sparse_metadata.shape}"
      )
    if a_sparse_metadata.dtype != ir.IntegerType.get_signless(2):
      raise ValueError(
          "A sparse metadata dtype mismatch: expected i2, got"
          f" {a_sparse_metadata.dtype}"
      )

  # Step 3. Compute the operand descriptors.
  if not isinstance(a, TMEMRef):
    # Both dense and sparse matmul consume A with a K bytewidth of 32, only
    # the group size is halved when it's sparse.
    (
        (a_desc_base, a_k_instr_strides),
        (a_m_group_stride, a_k_group_stride),
        a_fastest,
    ) = mma_utils.create_descriptor(
        a,
        swizzle=a_swizzle,
        group_size=(m_group_elems, k_group_elems // (1 + is_sparse)),
        logical_k_major=False,
        mma_bytewidth_k=32,
        split_const=True,
    )
  else:
    a_fastest = mma_utils.Dim.K
    a_k_instr_strides = None
    a_m_group_stride = a_k_group_stride = a_desc_base = None
  (
      (b_desc_base, b_k_instr_strides),
      (b_n_group_stride, b_k_group_stride),
      b_fastest,
  ) = mma_utils.create_descriptor(
      b,
      swizzle=b_swizzle,
      group_size=(k_group_elems, n_group_elems),
      logical_k_major=True,
      mma_bytewidth_k=64 if is_sparse else 32,
      split_const=True,
  )

  if is_scaled and utils.bitwidth(mma_element_type) == 4:
    if a_fastest != mma_utils.Dim.K:
      raise ValueError(
          "4-bit block scaled MMA only supports K-fastest operands, but A is M-fastest"
      )
    if b_fastest != mma_utils.Dim.K:
      raise ValueError(
          "4-bit block scaled MMA only supports K-fastest operands, but B is N-fastest"
      )
  if is_sparse:
    if b_swizzle == 32 and b_fastest == mma_utils.Dim.K:
      raise NotImplementedError(
          "B tiling too small. Increase swizzle or transpose the input."
      )

  # Step 4. Issue the instructions.
  true = arith.constant(ir.IntegerType.get_signless(1), 1)
  n_collective_group_elems = n_group_elems * num_cta
  n_col_groups = n_groups // n_lane_groups
  assert d.layout.base_tile_shape[0] % 4 == 0
  lanes_per_n_group = d.layout.base_tile_shape[0] // 4
  a_sparse_addr_base = a_sparse_metadata.address if is_sparse else None  # type: ignore
  a_scale_addr_base = a_scale.address if is_scaled else None  # type: ignore
  b_scale_addr_base = b_scale.address if is_scaled else None  # type: ignore
  for mi, ni, ki in np.ndindex(m_groups, n_groups, k_groups):
    if isinstance(a, TMEMRef):
      if m_groups != 1:
        raise NotImplementedError("A address calculation for multiple M tiles")
      a_k_group_elems = k_group_elems // (1 + is_sparse)
      a_mk = a.slice(slice(None), utils.ds(ki * a_k_group_elems, a_k_group_elems)).address
    else:
      a_offset = mi * a_m_group_stride + ki * a_k_group_stride
      a_mk = (a_desc_base[0], a_desc_base[1] + mma_utils.encode_addr(a_offset))
    b_offset = ni * b_n_group_stride + ki * b_k_group_stride
    b_nk = (b_desc_base[0], b_desc_base[1] + mma_utils.encode_addr(b_offset))
    if a_sparse_addr_base is not None:
      if n_groups != 1 or m_groups != 1:
        raise NotImplementedError("A sparse metadata address calculation for multiple tiles")
      assert k_group_elems % 32 == 0
      cols_per_k_group = k_group_elems // 32
      a_sparse_addr = arith.addi(a_sparse_addr_base, utils.c(ki * cols_per_k_group, i32))
    else:
      a_sparse_addr = None
    if a_scale_addr_base is not None and b_scale_addr_base is not None:
      if m_groups != 1:
        raise NotImplementedError("A scale address calculation for multiple M tiles")
      if n_groups != 1:
        raise NotImplementedError("B scale address calculation for multiple N tiles")
      assert scale_block is not None  # For type checkers.
      assert k_group_elems % (scale_block * 4) == 0
      assert m_group_elems % 32 == 0 and n_group_elems % 32 == 0
      k_scales_per_group = k_group_elems // (scale_block * 4)
      # A scales are sharded, B scales are replicated across CTAs.
      a_scale_addr = arith.addi(
          a_scale_addr_base,
          utils.c(ki * k_scales_per_group * m_group_elems // 32, i32),
      )
      b_scale_addr = arith.addi(
          b_scale_addr_base,
          utils.c(ki * k_scales_per_group * n_collective_group_elems // 32, i32)
      )
    else:
      a_scale_addr = b_scale_addr = None
    acc = accumulate if ki == 0 else true
    ni_lane_group, ni_col = ni // n_col_groups, ni % n_col_groups
    d_offset = (
        ((ni_lane_group * lanes_per_n_group) << 16)
        + ni_col * n_collective_group_elems
    )
    if m_groups != 1:
      raise NotImplementedError("D address calculation for multiple M tiles")
    _do_mma(
        arith.addi(d.address, arith.constant(i32, d_offset)),
        a_mk,
        b_nk,
        d_type=d.dtype,
        m=m_group_elems,
        n=n_group_elems,
        k=k_group_elems,
        collective=collective,
        a_transpose=a_fastest != mma_utils.Dim.K,
        b_transpose=b_fastest != mma_utils.Dim.K,
        a_k_strides=a_k_instr_strides,
        b_k_strides=b_k_instr_strides,
        a_scale_addr=a_scale_addr,
        b_scale_addr=b_scale_addr,
        a_sparse_addr=a_sparse_addr,
        accumulate=acc,
        element_type=mma_element_type,
        scale_element_type=scale_element_type,
    )


def _do_mma(
    d_addr: ir.Value,
    a_desc_or_addr: tuple[ir.Value, int] | ir.Value,  # TMEM address if a_k_stride is None
    b_desc: tuple[ir.Value, int],
    a_transpose: bool,
    b_transpose: bool,
    a_k_strides: tuple[tuple[int, ...], tuple[int, ...]] | None,
    b_k_strides: tuple[tuple[int, ...], tuple[int, ...]],
    a_scale_addr: ir.Value | None,
    b_scale_addr: ir.Value | None,
    a_sparse_addr: ir.Value | None,
    m: int,
    n: int,
    k: int,
    element_type: ir.Type,
    scale_element_type: ir.Type | None,
    d_type: ir.Type,
    accumulate: ir.Value,
    collective: bool,
) -> None:
  i1 = ir.IntegerType.get_signless(1)
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  a_k_idx_tiling, a_k_strides = a_k_strides or (None, None)
  b_k_idx_tiling, b_k_strides = b_k_strides
  assert all(s % 16 == 0 for s in itertools.chain(a_k_strides or (), b_k_strides))
  assert (a_scale_addr is None) == (b_scale_addr is None)
  is_scaled = a_scale_addr is not None
  is_sparse = a_sparse_addr is not None
  elem_bitwidth = utils.bitwidth(element_type)
  instr_k = (1 + is_sparse) * 8 * 32 // elem_bitwidth
  packing = 8 * 4 // elem_bitwidth

  scale_steps = None
  if is_scaled:
    assert not is_sparse
    if (ir.Float8E5M2Type.isinstance(element_type) or
        ir.Float8E4M3FNType.isinstance(element_type)):
      if scale_element_type != ir.Float8E8M0FNUType.get():
        raise ValueError(
            f"Scale element type mismatch: expected f8e8m0fnu, got {scale_element_type}"
        )
      kind = "mxf8f6f4.block_scale.scale_vec::1X"
      scale_steps = 4
      create_scaled_instr_descriptor = functools.partial(
          create_scaled_f8f6f4_instr_descriptor, scale_type=scale_element_type
      )
    elif ir.Float4E2M1FNType.isinstance(element_type):
      assert not a_transpose and not b_transpose
      create_scaled_instr_descriptor = functools.partial(
          create_scaled_f4_instr_descriptor,
          scale_type=scale_element_type,
      )
      if scale_element_type == ir.Float8E8M0FNUType.get():
        kind = "mxf4.block_scale.scale_vec::2X"
        scale_steps = 2
      elif scale_element_type == ir.Float8E4M3FNType.get():
        kind = "mxf4nvf4.block_scale.scale_vec::4X"
        scale_steps = 1
    else:
      raise NotImplementedError(f"Unsupported element type for block scaling: {element_type}")
    extra_ptx = "[$5], [$6], "
    extra_constraints = ",r,r"
  else:
    if ir.F16Type.isinstance(element_type) or ir.BF16Type.isinstance(element_type):
      kind = "f16"
    elif ir.Float8E5M2Type.isinstance(element_type):
      kind = "f8f6f4"
    elif ir.Float8E4M3FNType.isinstance(element_type):
      kind = "f8f6f4"
    elif ir.IntegerType.get_signless(8).isinstance(element_type):
      kind = "i8"
    else:
      raise NotImplementedError(f"Unsupported input element type: {element_type}")
    extra_constraints = extra_ptx = ""

    def create_scaled_instr_descriptor(*args):  # type: ignore
      raise NotImplementedError

  num_cta = 2 if collective else 1
  a_in_tmem = a_k_strides is None
  a_ptx = "[a_desc]" if a_in_tmem else "a_desc"
  sparse_mod = ".sp" if is_sparse else ""
  sparse_meta_ptx = "[$5], " if is_sparse else ""
  extra_constraints += ",r" if is_sparse else ""
  sparse_addr: tuple[Any, ...] = ()
  scales_addrs: tuple[Any, ...] = ()
  def _get_offset(idx: int, idx_tiling: tuple[int, ...], strides: tuple[int, ...]):
    assert len(idx_tiling) + 1 == len(strides)
    idxs = []
    for t in idx_tiling:
      idxs.append(idx // t)
      idx = idx % t
    idxs.append(idx)
    offset = sum(i * s for i, s in zip(idxs, strides, strict=True))
    return offset >> 4
  for k_step in range(k // instr_k):
    if is_scaled:
      assert scale_steps is not None
      assert not is_sparse
      scale_vec_width = 4 // scale_steps
      scale_id = (k_step % scale_steps) * scale_vec_width
      i_desc = create_scaled_instr_descriptor(
          m * num_cta, n * num_cta, element_type, element_type,
          scale_id, scale_id, a_transpose, b_transpose
      )
      assert m == 128
      assert n % 128 == 0
      # A scales are sharded, B scales are replicated across CTAs.
      a_scale_addr_offset = arith.constant(i32, k_step // scale_steps * 4)
      b_scale_addr_offset = arith.constant(i32, k_step // scale_steps * n // 32 * num_cta)
      scales_addrs = (
          arith.addi(a_scale_addr, a_scale_addr_offset),
          arith.addi(b_scale_addr, b_scale_addr_offset),
      )
    else:
      sp_selector = None
      if is_sparse:
        assert 32 <= instr_k <= 64
        selector_width = instr_k
        k_steps_for_col_inc = 64 // selector_width
        assert (k // instr_k) % k_steps_for_col_inc == 0
        sp_selector = k_step % k_steps_for_col_inc
        # If the K group is large, we need to increment the sparse metadata.
        # TODO(apaszke): At this point the purpose of this function is becoming
        # less clear, since we end up replicating address arithmetic that's
        # already there in the caller. We should unify them into a single loop.
        sparse_addr = (
            arith.addi(
                a_sparse_addr, utils.c(k_step // k_steps_for_col_inc * 2, i32)
            ),
        )
      i_desc = create_instr_descriptor(
          m * num_cta, n * num_cta, d_type, element_type, a_transpose, b_transpose, sparsity_selector=sp_selector
      )
    if a_in_tmem:
      cols_per_k_group = instr_k // packing // (1 + is_sparse)
      a_offset = k_step * cols_per_k_group
      assert isinstance(a_desc_or_addr, ir.Value)
      assert a_desc_or_addr.type == ir.IntegerType.get_signless(32)
      a_enc_addr_base = a_desc_or_addr
    else:
      assert a_k_idx_tiling is not None and a_k_strides is not None
      a_enc_addr_base, a_offset = a_desc_or_addr
      a_offset += _get_offset(k_step, a_k_idx_tiling, a_k_strides)
    b_enc_addr_base, b_offset = b_desc
    b_offset += _get_offset(k_step, b_k_idx_tiling, b_k_strides)
    a_offset_low, a_offset_high = a_offset & 0xFFFFFFFF, a_offset >> 32
    b_offset_low, b_offset_high = b_offset & 0xFFFFFFFF, b_offset >> 32
    llvm.inline_asm(
        ir.Type.parse("!llvm.void"),
        [d_addr, a_enc_addr_base, b_enc_addr_base, i_desc, accumulate, *scales_addrs, *sparse_addr],
        f"""{{
            .reg .b32 a_desc_low, a_desc_high, b_desc_low, b_desc_high;
            .reg {".b32" if a_in_tmem else ".b64"} a_desc;
            .reg .b64 b_desc;
            add.s32 a_desc_low, $1, {a_offset_low};
            add.s32 b_desc_low, $2, {b_offset_low};
            mov.b64 b_desc, {{b_desc_low, {b_offset_high}}};
            {"mov.b32 a_desc, a_desc_low;" if a_in_tmem else f"mov.b64 a_desc, {{a_desc_low, {a_offset_high}}};"}
            tcgen05.mma{sparse_mod}.cta_group::{num_cta}.kind::{kind} [$0], {a_ptx}, b_desc, {sparse_meta_ptx}$3, {extra_ptx}$4;
        }}""",
        "r,r,r,r,b" + extra_constraints,
        has_side_effects=True,
    )
    accumulate = arith.constant(i1, 1)


def commit_arrive(
    barrier: utils.BarrierRef | ir.Value,
    collective: bool = False,
    ctx: LaunchContext | None = None,
) -> None:
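  """Arrives on ``barrier`` once all previously issued MMAs have completed."""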
  if isinstance(barrier, utils.BarrierRef):
    barrier = barrier.get_ptr()
  elif barrier.type != ir.Type.parse("!llvm.ptr<3>"):
    raise ValueError(
        "barrier must be a Mosaic barrier or a SMEM pointer, got:"
        f" {barrier.type}"
    )
  if collective:
    if ctx is None:
      raise ValueError("ctx must be provided for collective barriers")
    # TODO(apaszke): This is just 0b11 shifted by the even CTA index.
    if ctx.cluster_size != (2, 1, 1):
      raise NotImplementedError("Collective arrivals only support (2, 1, 1)-shaped clusters")
    i16 = ir.IntegerType.get_signless(16)
    mask = arith.constant(i16, 3)
    nvvm.tcgen05_commit(
        barrier, group=nvvm.CTAGroupKind.CTA_2, multicast_mask=mask
    )
  else:
    nvvm.tcgen05_commit(barrier)


def tmem_alloc_exact_ncols(ncols: int, exact: bool) -> int:
  """Returns the exact number of columns to allocate in TMEM.

  The number of columns is rounded up to the nearest power of 2.

  Args:
    ncols: The number of columns to allocate.
    exact: If true, throws an error if the number of columns is not a power of 2
      and within [32, 512].
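
  For example, a request for 40 columns is rounded up to 64 when ``exact`` is
  False, and raises a ``ValueError`` when ``exact`` is True.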
  """
  if exact:
    if ncols.bit_count() != 1 or not 32 <= ncols <= 512:
      raise ValueError(f"ncols must be a power of 2 and within [32, 512], got: {ncols}")
  else:
    ncols = max(32, 1 << (ncols - 1).bit_length())
    if ncols > 512:
      raise ValueError(
          f"After rounding up, got {ncols} columns, exceeding the limit of 512"
      )
  return ncols


def tmem_alloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True) -> tuple[ir.Value, int]:
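  """Allocates TMEM columns, writing the TMEM address to ``tmem_addr`` in SMEM.

  ``ncols`` is adjusted as in ``tmem_alloc_exact_ncols``. Returns the
  ``nvvm.tcgen05_alloc`` result together with the actual number of columns.
  """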
  if ir.MemRefType.isinstance(tmem_addr.type):
    ref_ty = ir.MemRefType(tmem_addr.type)
    if ref_ty.element_type != ir.IntegerType.get_signless(32):
      raise ValueError(f"tmem_addr must be an i32 memref, got: {ref_ty}")
    if not utils.is_smem_ref(ref_ty):
      raise ValueError(f"tmem_addr must be in shared memory, got: {ref_ty}")
    if math.prod(ref_ty.shape) != 1:
      raise ValueError(f"tmem_addr must contain a single element, got: {ref_ty}")
    tmem_addr = utils.memref_ptr(tmem_addr, memory_space=3)
  elif tmem_addr.type != ir.Type.parse("!llvm.ptr<3>"):
    raise ValueError(f"tmem_addr must be an SMEM pointer or a memref, got: {tmem_addr.type}")
  ncols = tmem_alloc_exact_ncols(ncols, exact)
  group = nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1
  i32 = ir.IntegerType.get_signless(32)
  return nvvm.tcgen05_alloc(tmem_addr, utils.c(ncols, i32), group=group), ncols


def _tmem_addr_to_ptr(tmem_addr: ir.Value) -> ir.Value:
  assert tmem_addr.type == ir.IntegerType.get_signless(32)
  ptr_ty = ir.Type.parse("!llvm.ptr<6>")
  return llvm.inttoptr(ptr_ty, tmem_addr)


def tmem_dealloc(tmem_addr: ir.Value, ncols: int, collective: bool = False, exact: bool = True) -> None:
  if tmem_addr.type != ir.IntegerType.get_signless(32):
    raise ValueError(f"tmem_addr must be an i32, got: {tmem_addr.type}")
  ncols = tmem_alloc_exact_ncols(ncols, exact)
  group = nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1
  i32 = ir.IntegerType.get_signless(32)
  nvvm.tcgen05_dealloc(
      _tmem_addr_to_ptr(tmem_addr), utils.c(ncols, i32), group=group
  )


def tmem_relinquish_alloc_permit(collective: bool) -> None:
  group = nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1
  nvvm.tcgen05_relinquish_alloc_permit(group=group)

def _tmem_access_helper(shape, num) -> tuple[int, str]:
  if num.bit_count() != 1 or num > 128:
    raise ValueError(f"num must be a power of 2 and <= 128, got: {num}")
  match shape:
    case "32x32b":
      num_regs = 1
    case "16x128b":
      num_regs = 2
    case "16x256b":
      num_regs = 4
    case _:
      raise NotImplementedError(f"{shape=} is unsupported")
  num_regs *= num
  if num_regs > 255:
    raise ValueError(
        f"TMEM translation too big : {shape=} and {num=} involve"
        f" {num_regs} registers per-thread, which exceeds the limit of 255"
    )
  regs_vector = ",".join(f"${i}" for i in range(num_regs))
  regs_vector = "{" + regs_vector + "}"
  return num_regs, regs_vector


def _tmem_load(tmem_addr, shape, num, pack: bool):
  i32 = ir.IntegerType.get_signless(32)
  num_out_regs, regs_vector = _tmem_access_helper(shape, num)
  pack_mod = ".pack::16b" if pack else ""
  regs = llvm.inline_asm(
      ir.Type.parse(
          "!llvm.struct<(" + ",".join("i32" for _ in range(num_out_regs)) + ")>"
      ),
      [tmem_addr],
      f"tcgen05.ld.sync.aligned.{shape}.x{num}{pack_mod}.b32 {regs_vector}, [${num_out_regs}];",
      "=r," * num_out_regs + "r",
      has_side_effects=True,
  )
  return [llvm.extractvalue(i32, regs, [i]) for i in range(num_out_regs)]


def _tmem_store(tmem_addr, shape, num, regs, unpack: bool) -> None:
  num_out_regs, regs_vector = _tmem_access_helper(shape, num)
  pack_mod = ".unpack::16b" if unpack else ""
  llvm.inline_asm(
      ir.Type.parse("!llvm.void"),
      [*regs, tmem_addr],
      f"tcgen05.st.sync.aligned.{shape}.x{num}{pack_mod}.b32 [${num_out_regs}], {regs_vector};",
      "r," * num_out_regs + "r",
      has_side_effects=True,
  )


@dataclasses.dataclass(frozen=True)
class TMEMLayout(fa.TiledLayout):
  """Represents the way a shape is laid out in TMEM.

  The layout describes how the shape is split across the 128 rows (lanes) of
  TMEM. We reinterpret warp_dims as the partitioning of TMEM into 4 banks, each
  accessible from a single warp. The 32 lanes inside each bank are assigned
  consecutive elements from lane_dims. The data within each lane is linearized
  in row-major order, with each vector padded up to 32 bits (wider vectors are
  unsupported).
  """

  def check_type(self, shape: tuple[int, ...], bitwidth: int) -> None:
    if len(shape) != 2:
      raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
    if any(s % t for s, t in zip(shape, self.base_tile_shape)):
      raise ValueError(
          f"{shape} is not divisible into tiles of shape {self.base_tile_shape}"
      )
    if self.vector_length not in {1, fully_packed := 32 // bitwidth}:
      raise ValueError(
          f"For {bitwidth}-bit types, the vector length must be 1 or"
          f" {fully_packed} , but got: {self.vector_length}"
      )

  def cols_in_shape(self, shape: tuple[int, int], bitwidth: int) -> int:
    self.check_type(shape, bitwidth)
    replication_factor = 1
    for dim in self.warp_dims:
      if isinstance(dim, fa.Replicated):
        replication_factor *= dim.times
    for dim in self.lane_dims:
      if isinstance(dim, fa.Replicated):
        replication_factor *= dim.times
    return math.prod(shape) // TMEM_ROWS // self.vector_length * replication_factor

  def canonicalize(self) -> TMEMLayout:
    layout = super().canonicalize()
    return TMEMLayout(
        layout.tiling,
        layout.warp_dims,
        layout.lane_dims,
        layout.vector_dim,
        _check_canonical=False,
    )

  def as_tiled_layout(self) -> fa.TiledLayout:
    return fa.TiledLayout(
        self.tiling, self.warp_dims, self.lane_dims, self.vector_dim
    )


def _infer_tmem_load_registers_layout(
    tmem_layout: TMEMLayout, columns: int, packing: int
) -> fa.TiledLayout:
  if tmem_layout == tmem_default_layout(packing=packing):
    return LAYOUT
  if tmem_layout == tmem_half_lane_layout(columns, packing=packing):
    return fa.WGMMA_LAYOUT
  if tmem_layout == tmem_m64_collective_layout(columns, packing=packing):
    return fa_m64_collective_layout(columns)
  raise ValueError(f"TMEM layout {tmem_layout} is not supported")


def _infer_tmem_layout(shape: tuple[int, int], collective: bool, packing: int) -> TMEMLayout:
  if len(shape) != 2:
    raise ValueError(f"TMEM can only represent 2D shapes, got {shape}")
  if packing > 8 or packing.bit_count() != 1:
    raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}")
  if shape[1] % packing:
    raise ValueError(f"Minor dimension of shape must be divisible by packing, got: {shape}")
  if shape[0] == TMEM_ROWS:
    return tmem_default_layout(packing)
  elif shape[0] == TMEM_ROWS // 2:
    if collective:
      return tmem_m64_collective_layout(shape[1], packing)
    else:
      return tmem_half_lane_layout(shape[1], packing)
  else:
    raise ValueError(
        f"Unsupported shape: {shape}. TMEM references must have either"
        f" {TMEM_ROWS} or {TMEM_ROWS // 2} rows, but got {shape[0]}."
    )


def tmem_default_layout(packing: int = 1) -> TMEMLayout:
  """A TMEM layout used for 1CTA MMA with M=128 and 2CTA MMA with M=256."""
  if packing.bit_count() != 1:
    raise ValueError(f"Packing must be a power of 2, got: {packing}")
  return TMEMLayout(
      fa.Tiling(((TMEM_ROWS, packing), (fa.WARP_SIZE, packing))),
      warp_dims=(-4,),
      lane_dims=(-2,),
      vector_dim=-1,
  )


def tmem_half_lane_layout(columns, packing: int = 1) -> TMEMLayout:
  """A TMEM layout used for 1CTA MMA with M=64."""
  if packing > columns or packing.bit_count() != 1:
    raise ValueError(
        f"Packing must be a power of 2 no greater than the number of columns"
        f" ({columns}), got: {packing}"
    )
  if columns % 16:
    raise ValueError(f"Columns must be a multiple of 16, got: {columns}")
  return TMEMLayout(
      fa.Tiling((
          (TMEM_ROWS // 2, columns),
          (fa.WARP_SIZE // 2, columns // 2),
          (packing,),
      )),
      warp_dims=(-5,),
      lane_dims=(-4, -3),
      vector_dim=-1,
  )


def tmem_m64_collective_layout(columns: int, packing: int = 1) -> TMEMLayout:
  """A TMEM layout used for 2CTA MMA with M=128."""
  if packing > 8 or packing.bit_count() != 1:
    raise ValueError(f"Packing must be <= 8 and a power of 2, got: {packing}")
  if columns % 16:
    raise ValueError(f"Columns must be a multiple of 16, got: {columns}")
  return TMEMLayout(
      fa.Tiling((
          (TMEM_ROWS // 2, columns),
          (fa.WARP_SIZE, columns // 2),
          (packing,),
      )),
      warp_dims=(-4, -5,),
      lane_dims=(-3,),
      vector_dim=-1,
  )


def fa_m64_collective_layout(columns: int) -> fa.TiledLayout:
  """The register layout for transfers to/from tmem_m64_collective_layout."""
  if columns % 16:
    raise ValueError(f"Columns must be a multiple of 16, got: {columns}")
  return fa.TiledLayout(
      fa.Tiling((
          (TMEM_ROWS // 2, columns), (fa.WARP_SIZE, columns // 2), (8, 8), (2,)
      )),
      warp_dims=(-6, -7),
      lane_dims=(-3, -2),
      vector_dim=-1,
  )


def scales_layout() -> TMEMLayout:
  """A TMEM layout for A and B scales in .scale_vec::1X configuration.

  See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
  """
  return TMEMLayout(
      fa.Tiling(((TMEM_ROWS, 4), (TMEM_ROWS // 4, 1))),
      warp_dims=(fa.Replicated(times=4),),
      lane_dims=(-2,),
      vector_dim=-3,
  )


def sparse_meta_layout() -> TMEMLayout:
  """A TMEM layout for A sparsity metadata.

  See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-sparse-matrices-sparsity-selector-kind-tf32-m128-256
  """
  # TODO(apaszke): This does not really describe this layout and we can't do it
  # until we add support for multiple vector dims. Still, it's ok to do for now,
  # because we don't use TMEM layouts for any automatic transformations at the
  # moment and only ever compare it for equality.
  return TMEMLayout(
      fa.Tiling(((TMEM_ROWS, 16), (TMEM_ROWS // 4, 1), (16, 1), (8, 1))),
      warp_dims=(-8,),
      lane_dims=(-2, -4, -6),
      vector_dim=-7,
  )


@dataclasses.dataclass(frozen=True)
class TMEMRef:
  address: ir.Value
  shape: tuple[int, int]
  dtype: ir.Type
  layout: TMEMLayout

  @property
  def packing(self) -> int:
    return self.layout.vector_length

  def __post_init__(self):
    packed_bitwidth = utils.bitwidth(self.dtype) * self.packing
    if packed_bitwidth > 32:
      raise ValueError("Expected packed bitwidth to be <= 32, but got: "
                       f"{packed_bitwidth=}")

  @classmethod
  def from_alloc(
      cls,
      tmem_addr_ref: ir.Value,
      shape: tuple[int, int],
      dtype,
      collective: bool | None = None,
      layout: TMEMLayout | None = None,
  ) -> TMEMRef:
    i32 = ir.IntegerType.get_signless(32)
    if not ir.MemRefType.isinstance(tmem_addr_ref.type):
      raise ValueError(f"tmem_addr_ref must be a memref or a pointer, got: {tmem_addr_ref.type}")
    addr_ref_ty = ir.MemRefType(tmem_addr_ref.type)
    if not utils.is_smem_ref(addr_ref_ty):
      raise ValueError(f"tmem_addr_ref must be in shared memory, got: {addr_ref_ty}")
    if addr_ref_ty.element_type != i32:
      raise ValueError(f"tmem_addr_ref must be an i32 memref, got: {addr_ref_ty}")
    if math.prod(addr_ref_ty.shape) != 1:
      raise ValueError(f"tmem_addr_ref must contain a single element, got: {addr_ref_ty}")
    i0 = arith.ConstantOp.create_index(0)
    tmem_addr = memref.load(tmem_addr_ref, [i0] * addr_ref_ty.rank)
    if shape[0] < 32:
      raise ValueError(f"TMEM refs must have at least 32 rows, got: {shape[0]}")
    if layout is None:
      if collective is None:
        raise ValueError(
            "collective argument must be provided when TMEM layout is inferred"
        )
      layout = _infer_tmem_layout(shape, collective, packing=1)
    else:
      layout.check_type(shape, utils.bitwidth(dtype))
    # TODO: Do we have to do this??
    # warp_idx = utils.warp_idx(sync=False)
    # tmem_addr = arith.ori(tmem_addr, arith.shli(warp_idx, utils.c(21, i32)))
    return cls(tmem_addr, shape, dtype, layout)

  def slice(self, *idxs) -> TMEMRef:
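    """Returns a TMEMRef for a slice of this ref (only columns can be sliced)."""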
    i32 = ir.IntegerType.get_signless(32)
    base_idx, slice_shape, is_squeezed = utils.parse_indices(idxs, self.shape)
    if any(is_squeezed):
      raise ValueError("TMEM can only be sliced, not indexed")
    if base_idx == [0] * len(base_idx) and slice_shape == list(self.shape):
      return self  # Trivial slice
    if self.layout != tmem_default_layout(packing=self.packing):
      raise NotImplementedError(
          "Slicing only implemented for refs with standard layout, got:"
          f" {self.layout}"
      )
    if base_idx[0] != 0 or slice_shape[0] != TMEM_ROWS:
      raise NotImplementedError("TMEM cannot be sliced along rows")
    if slice_shape[1] % 8:
      raise NotImplementedError(
          "TMEM column slice length must be a multiple of 8. "
          f"Got {slice_shape[1]}."
      )
    col_idx = base_idx[1]
    if not isinstance(col_idx, ir.Value):
      col_idx = arith.constant(i32, col_idx)
    if col_idx.type == ir.IndexType.get():
      col_idx = arith.index_cast(i32, col_idx)
    if self.packing != 1:
      col_idx = arith.divui(col_idx, arith.constant(i32, self.packing))
    return TMEMRef(
        address=arith.addi(self.address, col_idx),
        shape=cast(tuple[int, int], tuple(slice_shape)),
        layout=self.layout,
        dtype=self.dtype,
    )

  def load(self, layout: fa.TiledLayout | None = None, is_signed: bool | None = None) -> fa.FragmentedArray:
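    """Loads the contents of TMEM into registers.

    Returns a FragmentedArray with layout ``layout`` (inferred from the TMEM
    layout when not specified).
    """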
    packing = self.packing
    if layout is None:
      layout = _infer_tmem_load_registers_layout(
          self.layout, self.shape[1], packing
      )
    bitwidth = utils.bitwidth(self.dtype)
    has_default_layout = self.layout == tmem_default_layout(packing=packing)
    regs_shape = layout.registers_shape(self.shape)
    if regs_shape[0] != 1:  # We'll need to issue multiple loads below.
      raise NotImplementedError("Loading multiple row tiles")
    if layout == LAYOUT and self.layout == tmem_default_layout(packing=packing):
      registers = _load_32xcols(
          self.address, self.shape[1], self.dtype, packing
      ).T.reshape(regs_shape)
    elif layout == self.layout.as_tiled_layout() and packing * bitwidth == 32:
      assert len(layout.base_tile_shape) == 2
      # We could allow replicated dims in the input, but we'd need to divide the
      # split factor computed below by the replication factor of the input.
      assert not any(isinstance(d, fa.Replicated) for d in layout.warp_dims)
      assert not any(isinstance(d, fa.Replicated) for d in layout.lane_dims)
      warp_split_factor = math.prod(
          d.times if isinstance(d, fa.Replicated) else 1
          for d in layout.remove_dimension(1).warp_dims
      )
      lane_split_factor = math.prod(
          d.times if isinstance(d, fa.Replicated) else 1
          for d in layout.remove_dimension(1).lane_dims
      )
      split_factor = warp_split_factor * lane_split_factor
      registers = _load_32xcols_native(
          self.address, self.shape[1] // split_factor, self.dtype, packing, packing
      ).reshape(regs_shape)
    # TODO(apaszke): Support the case where we have a long vector length in the
    # FA more generally, not just for 2x32b.
    # 16-bit types are special, because the store instruction can unpack them.
    elif layout == TMEM_NATIVE_LAYOUT and has_default_layout and (
        (bitwidth == 16 and packing == 1)
        or (bitwidth == 32 and layout.vector_length == 2)
    ):
      registers = _load_32xcols_native(
          self.address, self.shape[1], self.dtype, packing, TMEM_NATIVE_LAYOUT.vector_length
      ).reshape(regs_shape)
    elif layout == fa.WGMMA_LAYOUT and self.layout == tmem_half_lane_layout(self.shape[1], packing=packing):
      # Load half the columns, since they are folded over lanes.
      raw_registers = _load_32xcols(
          self.address, self.shape[1] // 2, self.dtype, packing
      )
      assert raw_registers.shape[0] == 4
      registers = np.concatenate([raw_registers[:2], raw_registers[2:]], axis=1)
      registers = registers.T.reshape(regs_shape)
    elif layout == fa_m64_collective_layout(self.shape[1]) and self.layout == tmem_m64_collective_layout(self.shape[1], packing=packing):
      regs_shape = layout.registers_shape(self.shape)
      # We take half the columns, because they are split over halves of TMEM.
      registers = _load_32xcols(
          self.address, self.shape[1] // 2, self.dtype, packing
      ).reshape(regs_shape)
    else:
      raise ValueError(
          f"Loads from TMEM layout {self.layout} to register layout"
          f" {layout} are not supported"
      )
    return fa.FragmentedArray(
        _registers=registers, _layout=layout, _is_signed=is_signed
    )

  def store(self, value: fa.FragmentedArray):
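    """Stores a FragmentedArray into TMEM.

    The shape, dtype and register layout of ``value`` must be compatible with
    this TMEM reference.
    """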
    if not isinstance(value, fa.FragmentedArray):
      raise TypeError(f"TMEM stores expect a FragmentedArray, got: {value}")
    if value.shape != self.shape:
      raise ValueError(
          f"Stored array has shape {value.shape}, but TMEM has shape"
          f" {self.shape}"
      )
    if value.mlir_dtype != self.dtype:
      raise ValueError(
          f"Stored array has dtype {value.mlir_dtype}, but TMEM has dtype"
          f" {self.dtype}"
      )
    if not isinstance(value.layout, fa.TiledLayout):
      raise TypeError(f"Stored array has layout {value.layout}, but TMEM stores expect a TiledLayout")
    packing = self.packing
    has_default_layout = self.layout == tmem_default_layout(packing=packing)
    bitwidth = utils.bitwidth(self.dtype)
    if value.layout == LAYOUT and has_default_layout:
      _store_32xcols(
          self.address, value.registers.T.reshape((4, -1)), packing
      )
    elif value.layout == self.layout.as_tiled_layout() and packing * bitwidth == 32:
      _store_32xcols_native(self.address, value.registers.reshape(-1), packing)
    # TODO(apaszke): Support the case where we have a long vector length in the
    # FA more generally, not just for 2x32b.
    # TODO(apaszke): Support a wider range of layouts when dealing with unpacking.
    # 16-bit types are special, because the store instruction can unpack them.
    elif value.layout == TMEM_NATIVE_LAYOUT and has_default_layout and (
        (bitwidth == 16 and packing == 1)
        or (bitwidth == 32 and value.layout.vector_length == 2)
    ):
      _store_32xcols_native(self.address, value.registers.reshape(-1), packing)
    elif (
        value.layout == fa.WGMMA_LAYOUT
        and self.layout == tmem_half_lane_layout(self.shape[1], packing=packing)
    ):
      registers = value.registers.T.reshape(2, -1)
      registers = np.concatenate(np.split(registers, 2, axis=1), axis=0)
      _store_32xcols(self.address, registers, packing)
    elif value.layout == fa_m64_collective_layout(
        self.shape[1]
    ) and self.layout == tmem_m64_collective_layout(
        self.shape[1], packing=packing
    ):
      _store_32xcols(self.address, value.registers.reshape(4, -1), packing)
    else:
      raise ValueError(
          f"Storing from register layout {value.layout} to TMEM layout"
          f" {self.layout} is not supported"
      )

  def _debug_print(self) -> None:
    i32 = ir.IntegerType.get_signless(32)
    num_cols = self.layout.cols_in_shape(self.shape, utils.bitwidth(self.dtype))
    lane = arith.remui(utils.thread_idx(), arith.constant(i32, utils.WARPGROUP_SIZE))
    for c in range(num_cols):
      ptr = _tmem_addr_to_ptr(arith.addi(self.address, arith.constant(i32, c)))
      val = nvvm.tcgen05_ld(i32, nvvm.Tcgen05LdStShape.SHAPE_32X32B, ptr)
      dtype_bitwidth = utils.bitwidth(self.dtype)
      full_packing = 32 // dtype_bitwidth
      if self.packing == 1:
        if dtype_bitwidth < 32:
          val = arith.trunci(ir.IntegerType.get_signless(dtype_bitwidth), val)
        val = utils.bitcast(val, self.dtype)
      elif self.packing == full_packing:
        val = utils.bitcast(val, ir.VectorType.get((full_packing,), self.dtype))
      else:
        raise NotImplementedError(f"Unsupported packing: {self.packing}")
      # TODO(apaszke): Make this print logical, not physical location.
      utils.debug_print(f"[{{}}, {c}]: {{}}", lane, val, uniform=False)


def _transfer_32xcols(
    base_addr: ir.Value,
    cols: int,
    atom_shape: tuple[int, int],
    tmem_packing: int,
    reg_packing: int,
) -> Iterator[tuple[ir.Value, int, int, slice]]:
  """Generates a sequence of parameters for a given TMEM read or write.

  Arguments:
    base_addr: The base address of the TMEM region.
    cols: The number of logical columns to transfer.
    atom_shape: The logical shape of the tile transferred by a warp in a single
      TMEM access.
    tmem_packing: Packing degree in TMEM. When packing is 1 but the data is
      16-bit, each transfer is expected to involve twice the number of physical
      columns.
    reg_packing: The number of elements that fit in a single 32-bit register.
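
  For example, transferring 128 columns of 16-bit data with a (16, 8) atom,
  tmem_packing=1 and reg_packing=2 yields two address/num pairs (one per group
  of 16 lanes), each covering all 16 column tiles with a single x16 instruction.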
  """
  i32 = ir.IntegerType.get_signless(32)
  atom_rows, atom_cols = atom_shape
  assert cols % atom_cols == 0
  total_num = cols // atom_cols
  regs_per_instr = atom_shape[0] * atom_shape[1] // (utils.WARP_SIZE * reg_packing)
  assert 32 % atom_rows == 0
  num_row_steps = 32 // atom_rows
  # We artificially lower the instr_num compared to its limits, because higher
  # values can lead to register spills.
  max_num = 1 << (total_num.bit_length() - 1)  # largest power of 2 <= total_num
  max_num = min(max_num, 32 // regs_per_instr)
  for lane_step in range(num_row_steps):
    addr_row = arith.addi(base_addr, utils.c((lane_step * atom_rows) << 16, i32))
    num_processed = 0
    instr_num = max_num
    while (remaining := total_num - num_processed) > 0:
      while instr_num > remaining:
        instr_num //= 2
      num_slice = slice(num_processed, num_processed + instr_num)
      addr_row_col = arith.addi(
          addr_row, utils.c(num_processed * atom_cols // tmem_packing, i32)
      )
      yield addr_row_col, instr_num, lane_step, num_slice
      num_processed += instr_num
    assert num_processed == total_num


def _store_32xcols(base_addr, vector_regs, tmem_packing) -> None:
  i32 = ir.IntegerType.get_signless(32)
  assert vector_regs.ndim == 2 and vector_regs.shape[0] == 4
  cols = vector_regs.shape[1] * 8

  reg_packing = 64 // utils.bitwidth(vector_regs.flat[0].type)
  if reg_packing == 1:
    store_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
    regs = np.empty((4, vector_regs.shape[1], 2), dtype=object)
    c0 = arith.constant(i32, 0)
    c1 = arith.constant(i32, 1)
    for idx, vreg in np.ndenumerate(vector_regs):
      regs[(*idx, 0)] = llvm.extractelement(vreg, c0)
      regs[(*idx, 1)] = llvm.extractelement(vreg, c1)
    regs = regs.reshape(2, 2, vector_regs.shape[1], 2).swapaxes(1, 2)
    # From a single lane's perspective, a num tile consists of a 2x2 block of
    # registers, with the minor dim traversing columns and the major dim being
    # 8 rows apart.
    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
    assert regs.shape[-2:] == (2, 2)
    assert tmem_packing == 1
    unpack = False
  elif reg_packing == 2:
    store_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
    # From a single lane's perspective, a num tile has 2 registers, 8 rows apart.
    # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
    regs = vector_regs.reshape(2, 2, vector_regs.shape[1]).swapaxes(1, 2)
    assert 1 <= tmem_packing <= 2
    unpack = tmem_packing == 1
  else:
    raise NotImplementedError(reg_packing)

  it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    regs_slice = regs[lane_step, num_slice].flat
    _tmem_store(addr_row_col, store_shape, instr_num, regs_slice, unpack)


def _store_32xcols_native(base_addr, vector_regs, tmem_packing) -> None:
  i32 = ir.IntegerType.get_signless(32)
  assert vector_regs.ndim == 1
  vec_ty = ir.VectorType(vector_regs.flat[0].type)
  [vector_length] = vec_ty.shape
  elt_bitwidth = utils.bitwidth(vec_ty.element_type)
  reg_packing = 32 // elt_bitwidth
  store_atom_shape = (32, reg_packing)
  # TODO(apaszke): More general register splitting code, not just 2x32b.
  if reg_packing == 1:
    if vector_length == 2:
      # Transform data such that each reg is 32 bits wide.
      regs = [None] * (len(vector_regs) * 2)
      c0 = arith.constant(i32, 0)
      c1 = arith.constant(i32, 1)
      for idx, vreg in enumerate(vector_regs):
        regs[2 * idx] = llvm.extractelement(vreg, c0)
        regs[2 * idx + 1] = llvm.extractelement(vreg, c1)
    else:
      regs = [utils.bitcast(r, i32) for r in vector_regs]
    assert tmem_packing == 1
    unpack = False
  elif reg_packing == 2:
    assert vector_length == 2
    # In this case, registers are already packed into 32-bit registers.
    regs = [utils.bitcast(r, i32) for r in vector_regs]
    if elt_bitwidth == 16:
      assert 1 <= tmem_packing <= 2
      unpack = tmem_packing == 1
    else:
      if tmem_packing == 1 and elt_bitwidth != 32:
        raise NotImplementedError(
            f"Unsupported TMEM packing {tmem_packing} for element bitwidth"
            f" {elt_bitwidth}"
        )
      assert tmem_packing == 32 // elt_bitwidth
      unpack = False
  else:
    if tmem_packing != reg_packing:
      raise NotImplementedError(
          f"Only {reg_packing} packing supported for bitwidth {elt_bitwidth},"
          f" but got TMEM packing of {tmem_packing}"
      )
    assert utils.bitwidth(vec_ty) == 32
    regs = [utils.bitcast(r, i32) for r in vector_regs]
    unpack = False
  cols = len(regs) * reg_packing
  it = _transfer_32xcols(base_addr, cols, store_atom_shape, tmem_packing, reg_packing)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    assert lane_step == 0
    regs_slice = regs[num_slice]
    _tmem_store(addr_row_col, "32x32b", instr_num, regs_slice, unpack)


def _load_32xcols(base_addr, cols, dtype, tmem_packing) -> np.ndarray:
  i32 = ir.IntegerType.get_signless(32)
  vec_ty = ir.VectorType.get((2,), dtype)
  reg_packing = 32 // utils.bitwidth(dtype)
  if reg_packing == 1:
    load_shape = "16x256b"  # 4 threads * 64 bits per vreg = 256 bits
    assert tmem_packing == 1
    pack = False
  elif reg_packing == 2:
    load_shape = "16x128b"  # 4 threads * 32 bits per vreg = 128 bits
    assert 1 <= tmem_packing <= 2
    pack = tmem_packing == 1
  else:
    raise NotImplementedError(reg_packing)

  vector_regs = np.ndarray((4, cols // 8), dtype=object)

  it = _transfer_32xcols(base_addr, cols, (16, 8), tmem_packing, reg_packing)
  c0 = arith.constant(i32, 0)
  c1 = arith.constant(i32, 1)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    regs = _tmem_load(addr_row_col, load_shape, instr_num, pack)
    row_slice = slice(lane_step * 2, (lane_step + 1) * 2)
    # This aliases the original array, so updates will be reflected there.
    vector_regs_update = vector_regs[row_slice, num_slice]
    assert vector_regs_update.shape == (2, instr_num), (vector_regs_update.shape, instr_num)
    if reg_packing == 1:
      regs = [llvm.bitcast(dtype, r) for r in regs]
      # From a single lane's perspective, a num tile consists of a 2x2 block of
      # registers, with the minor dim traversing columns and the major dim
      # being 8 rows apart.
      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16256b
      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2, 2).swapaxes(0, 1)
      undef = llvm.mlir_undef(vec_ty)
      assert regs.shape == (*vector_regs_update.shape, 2)
      for idx in np.ndindex(vector_regs_update.shape):
        high_undef = llvm.insertelement(undef, regs[(*idx, 0)], c0)
        vreg = llvm.insertelement(high_undef, regs[(*idx, 1)], c1)
        vector_regs_update[idx] = vreg
    else:
      assert reg_packing == 2
      regs = [llvm.bitcast(vec_ty, r) for r in regs]
      # From a single lane's perspective, a num tile has 2 registers, 8 rows
      # apart.
      # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-fragments-shape-16128b
      regs = np.asarray(regs, dtype=object).reshape(instr_num, 2).swapaxes(0, 1)
      vector_regs_update[...] = regs

  return vector_regs


def _load_32xcols_native(base_addr, cols, dtype, tmem_packing, vector_length) -> np.ndarray:
  i32 = ir.IntegerType.get_signless(32)
  vec_ty = ir.VectorType.get((vector_length,), dtype)
  reg_packing = 32 // utils.bitwidth(dtype)
  assert vector_length % reg_packing == 0
  load_shape = "32x32b"
  load_atom_shape = (32, reg_packing)
  if reg_packing == 2:
    assert 1 <= tmem_packing <= 2
    pack = tmem_packing == 1
  else:
    if tmem_packing != reg_packing:
      raise NotImplementedError(
          f"Only {reg_packing} supported for element type {dtype}, but got"
          f" TMEM packing of {tmem_packing}"
      )
    pack = False

  it = _transfer_32xcols(base_addr, cols, load_atom_shape, tmem_packing, reg_packing)
  c0 = arith.constant(i32, 0)
  c1 = arith.constant(i32, 1)
  regs = [None] * (cols // reg_packing)
  for addr_row_col, instr_num, lane_step, num_slice in it:
    assert lane_step == 0, lane_step
    instr_regs = _tmem_load(addr_row_col, load_shape, instr_num, pack)
    if reg_packing == 1 and vector_length == 2:
      regs[num_slice] = [llvm.bitcast(dtype, r) for r in instr_regs]
    else:
      regs[num_slice] = [utils.bitcast(r, vec_ty) for r in instr_regs]

  if reg_packing == 1 and vector_length == 2:
    vector_regs = np.ndarray((cols // 2,), dtype=object)
    undef = llvm.mlir_undef(vec_ty)
    for idx in range(vector_regs.size):
      high_undef = llvm.insertelement(undef, regs[2 * idx], c0)
      vreg = llvm.insertelement(high_undef, regs[2 * idx + 1], c1)
      vector_regs[idx] = vreg
  else:
    assert vector_length == reg_packing
    vector_regs = np.asarray(regs, dtype=object)

  return vector_regs


def commit_tmem() -> None:
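  """Waits for previously issued TMEM stores and synchronizes the warpgroup."""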
  nvvm.tcgen05_wait(nvvm.Tcgen05WaitKind.STORE)
  utils.warpgroup_barrier()


def wait_load_tmem() -> None:
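  """Waits for previously issued TMEM loads and synchronizes the warpgroup."""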
  nvvm.tcgen05_wait(nvvm.Tcgen05WaitKind.LOAD)
  utils.warpgroup_barrier()


def async_copy_scales_smem_to_tmem(
    smem_ref: ir.Value, tmem_ref: TMEMRef, collective: bool = False
) -> None:
  """Asynchronously copies the scale data from SMEM to TMEM.

  The result of the copy can be awaited by calling ``commit_arrive`` and waiting
  on the chosen ``Barrier``. However, if the TMEM reference is to be consumed
  by an MMA issued in the same thread, no additional synchronization is needed.

  At the moment the function requires ``smem_ref`` to be contiguous and have a
  shape of ``(MN // 128, K // 128, 32, 16)`` for 8-bit scales (here MN stands
  for the size of the non-contracting dimension which is M or N), matching the
  scale layout for .scale_vec::1X. See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x
  for more details. Note that we always put the non-contracting dimension first.
  If you have a (MN, K // 32) array of scales in JAX (where MN and K are
  divisible by 128), you can prepare it for use in the kernel this way, where
  ``mn`` and ``k`` stand for the sizes of the two dimensions of that array
  (i.e. MN and K // 32)::

      (scales.reshape(mn // 128, 4, 32, k // 4, 4)
             .transpose(0, 3, 2, 1, 4)
             .reshape(mn // 128, k // 4, 32, 16))

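  For example, with ``MN = 256`` and ``K = 512`` (a minimal sketch; the
  f8e4m3fn dtype is just one of the supported scale types), the scales array
  has shape ``(256, 16)`` and the reshaped array has shape ``(2, 4, 32, 16)``::

      import jax.numpy as jnp

      scales = jnp.zeros((256, 16), jnp.float8_e4m3fn)  # (MN, K // 32)
      smem_scales = (scales.reshape(2, 4, 32, 4, 4)
                           .transpose(0, 3, 2, 1, 4)
                           .reshape(2, 4, 32, 16))
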
  The TMEM ref is expected to have the logical shape of the scales
  ``(MN, K // 32)``, and the layout created by ``scales_layout()``.
  """
  i32 = ir.IntegerType.get_signless(32)
  smem_ty = ir.MemRefType(smem_ref.type)
  if (dtype := smem_ty.element_type) != tmem_ref.dtype:
    raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}")
  if dtype not in {ir.Float8E8M0FNUType.get(), ir.Float8E4M3FNType.get()}:
    raise NotImplementedError(f"Unsupported dtype: {dtype}, only f8e8m0fnu and f8e4m3fn are supported")
  if tmem_ref.shape[0] % TMEM_ROWS:
    raise ValueError(f"TMEM reference must have a multiple of {TMEM_ROWS} rows, but got {tmem_ref.shape[0]}")
  if tmem_ref.shape[1] % 4:
    raise ValueError(f"TMEM reference must have a multiple of 4 columns, but got {tmem_ref.shape[1]}")
  if tmem_ref.layout != scales_layout():
    raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported")
  smem_shape = tuple(smem_ty.shape)
  expected_smem_shape = (tmem_ref.shape[0] // TMEM_ROWS, tmem_ref.shape[1] // 4, 32, 16)
  if smem_shape != expected_smem_shape:
    raise NotImplementedError(
        f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM"
        f" ref shape {tmem_ref.shape}"
    )
  strides, _ = smem_ty.get_strides_and_offset()
  # TODO(apaszke): This should only matter for the two minor dims.
  if strides != utils.get_contiguous_strides(smem_shape):
    raise ValueError("Only copies from contiguous SMEM references are supported")
  mn_tile_stride, k_tile_stride = strides[:2]
  # One tile of scales has 128 bytes.
  if mn_tile_stride % 128 or k_tile_stride % 128:
    raise ValueError("Scale tile strides must be a multiple of 128")
  mn_tile_stride_i32 = mn_tile_stride // 4
  k_tile_stride_i32 = k_tile_stride // 4
  smem_base_ptr = utils.memref_ptr(smem_ref, 3)
  # TODO(apaszke): We'd need to figure out the TMEM layout otherwise, and MMA
  # doesn't support larger shapes anyway.
  if smem_shape[0] > 2:
    raise NotImplementedError("Only M/N up to 256 supported")
  for mn_tile, k_tile in np.ndindex(smem_shape[:2]):
    load_ptr = utils.getelementptr(
        smem_base_ptr,
        [mn_tile * mn_tile_stride_i32 + k_tile * k_tile_stride_i32],
        i32,
    )
    # NOTE: The tiles are MN-minor in TMEM, but MN-major (logically) in SMEM.
    store_addr = arith.addi(
        tmem_ref.address,
        arith.constant(i32, 4 * smem_shape[0] * k_tile + 4 * mn_tile),
    )
    # The "core matrix" here is the same as in MMA: 8x(16 bytes).
    desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None)
    nvvm.tcgen05_cp(
        nvvm.Tcgen05CpShape.SHAPE_32x128b,
        _tmem_addr_to_ptr(store_addr),
        desc,
        multicast=nvvm.Tcgen05CpMulticast.WARPX4,
        group=nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1
    )


def async_copy_sparse_metadata_smem_to_tmem(
    smem_ref: ir.Value, tmem_ref: TMEMRef, collective: bool = False
) -> None:
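  """Asynchronously copies MMA sparsity metadata from SMEM to TMEM.

  ``smem_ref`` must be a contiguous i2 reference of shape
  ``(rows // 128, cols // 64, 128, 64)``, where ``(rows, cols)`` is the shape
  of ``tmem_ref``, which in turn must use the layout created by
  ``sparse_meta_layout()``. Only ``rows == 128`` is supported for now. See
  ``async_copy_scales_smem_to_tmem`` for how the copy can be awaited.
  """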
  i8 = ir.IntegerType.get_signless(8)
  i32 = ir.IntegerType.get_signless(32)
  smem_ty = ir.MemRefType(smem_ref.type)
  if (dtype := smem_ty.element_type) != tmem_ref.dtype:
    raise ValueError(f"Incompatible dtypes: SMEM has {dtype}, TMEM has {tmem_ref.dtype}")
  if dtype != ir.IntegerType.get_signless(2):
    raise NotImplementedError(f"Unsupported dtype: {dtype}, only i2 supported")
  if tmem_ref.shape[0] % 128:
    raise ValueError(f"TMEM reference must have a multiple of 128 rows, but got {tmem_ref.shape[0]}")
  if tmem_ref.shape[1] % 64:
    raise ValueError(f"TMEM reference must have a multiple of 64 colums, but got {tmem_ref.shape[1]}")
  if tmem_ref.layout != sparse_meta_layout():
    raise ValueError(f"TMEM layout {tmem_ref.layout} is not supported")
  smem_shape = tuple(smem_ty.shape)
  expected_smem_shape = (tmem_ref.shape[0] // 128, tmem_ref.shape[1] // 64, 128, 64)
  if smem_shape != expected_smem_shape:
    raise NotImplementedError(
        f"SMEM has {smem_shape}, but expected {expected_smem_shape} for TMEM"
        f" ref shape {tmem_ref.shape}"
    )
  strides, _ = smem_ty.get_strides_and_offset()
  if strides != utils.get_contiguous_strides(smem_shape):
    raise ValueError("Only copies from contiguous SMEM references are supported")
  if expected_smem_shape[0] != 1:
    raise NotImplementedError("Only M=128 supported")
  k_tile_stride = strides[1]
  if k_tile_stride % 16:
    raise ValueError("K tile stride must be a multiple of 16")
  k_tile_byte_stride = k_tile_stride // 4
  smem_base_ptr = utils.memref_ptr(smem_ref, 3)
  for k_tile in range(expected_smem_shape[1]):
    load_ptr = utils.getelementptr(
        smem_base_ptr, [k_tile * k_tile_byte_stride], i8
    )
    store_ptr = arith.addi(tmem_ref.address, arith.constant(i32, 4 * k_tile))
    # The "core matrix" here is the same as in MMA: 8x(16 bytes).
    desc = mma_utils.encode_descriptor(load_ptr, 0, 8 * 16, swizzle=None)
    ptr = _tmem_addr_to_ptr(store_ptr)
    nvvm.tcgen05_cp(
        nvvm.Tcgen05CpShape.SHAPE_128x128b, ptr, desc,
        group=nvvm.CTAGroupKind.CTA_2 if collective else nvvm.CTAGroupKind.CTA_1
    )
