# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
import collections
import dataclasses
import functools
from typing import Any, Union

from jax._src.util import use_cpp_class, cache, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import sdy
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisType
from jax._src.partition_spec import PartitionSpec
from jax._src import sharding as JSharding
import numpy as np

# Variable-length tuple of ints (presumably one extent per array dimension).
Shape = tuple[int, ...]
# Device type re-exported from the XLA client module.
Device = xc.Device
# Variable-length tuple of slices (presumably one slice per array dimension).
Index = tuple[slice, ...]
# Ordered collection of devices; used below as the return type of
# `NamedSharding._device_assignment`.
XLADeviceAssignment = Sequence[Device]


class AUTO:
  """Placeholder sharding bound to a mesh, with every dimension left open.

  Only the mesh is stored. The Shardy form produced by `_to_sdy_sharding`
  contains no axes at all and marks each dimension as open.
  """

  def __init__(self, mesh: mesh_lib.Mesh):
    self.mesh = mesh

  def _to_sdy_sharding(self, ndim: int) -> SdyArray:
    """Returns an `SdyArray` with `ndim` empty, open dimension shardings."""
    open_dims = []
    for _ in range(ndim):
      open_dims.append(SdyDim(axes=[], is_open=True))
    return SdyArray(dim_shardings=open_dims,
                    mesh_shape=self.mesh.shape_tuple)

  @property
  def _device_assignment(self):
    """The mesh's flat tuple of devices."""
    mesh = self.mesh
    return mesh._flat_devices_tuple


class UnspecifiedValue:
  """Type of the `UNSPECIFIED` marker object defined right after the class."""

  def __repr__(self) -> str:
    text = "UnspecifiedValue"
    return text
# Module-level instance used as the "unspecified" marker.
UNSPECIFIED = UnspecifiedValue()


# Mesh axis names are not restricted to any particular type in this module.
MeshAxisName = Any

"""
ArrayMapping specifies how an ndarray should map to mesh axes.

Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.

For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
# Ordered mapping from mesh axis name to the positional array dimension it
# shards; insertion order is significant (see the note above).
ArrayMapping = collections.OrderedDict[MeshAxisName, int]
# What `get_array_mapping` may return: a real mapping, or the AUTO /
# UnspecifiedValue argument passed through unchanged.
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]


def _unpickle_named_sharding(mesh, spec, memory_kind, logical_device_ids):
  """Rebuilds a `NamedSharding` from the fields emitted by `__reduce__`."""
  keyword_fields = dict(memory_kind=memory_kind,
                        _logical_device_ids=logical_device_ids)
  return NamedSharding(mesh, spec, **keyword_fields)


@use_cpp_class(xc.NamedSharding)
class NamedSharding(JSharding.Sharding):
  r"""A :class:`NamedSharding` expresses sharding using named axes.

  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
  :class:`PartitionSpec` which describes how to shard an array across that
  mesh.

  A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
  where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.

  A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``,
  a mesh axis, or a tuple of mesh axes. Each element describes how an input
  dimension is partitioned across zero or more mesh dimensions. For example,
  ``PartitionSpec('x', 'y')`` says that the first dimension of data
  is sharded across ``x`` axis of the mesh, and the second dimension is sharded
  across ``y`` axis of the mesh.

  The `Distributed arrays and automatic parallelization`_
  and `Explicit Sharding`_ tutorials have more details and diagrams that
  explain how :class:`Mesh` and :class:`PartitionSpec` are used.

  Args:
    mesh: A :class:`jax.sharding.Mesh` object.
    spec: A :class:`jax.sharding.PartitionSpec` object.
    memory_kind: Optional string stored as the memory kind of this sharding.
      Defaults to ``None``.

  Examples:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)

  .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
  .. _Explicit Sharding:  https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html
  """

  # NOTE(review): the class is wrapped by `use_cpp_class` and several methods
  # by `use_cpp_method`; presumably the C++ implementation is used for those
  # when available -- confirm in jax._src.util before changing them.
  mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
  spec: PartitionSpec
  _memory_kind: str | None
  _logical_device_ids: tuple[int, ...] | None

  @use_cpp_method()
  def __init__(
      self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *,
      memory_kind: str | None = None, _logical_device_ids=None):
    """Stores the fields and validates `spec` against `mesh` via `check_pspec`."""
    self.mesh = mesh
    self.spec = spec
    self._memory_kind = memory_kind
    self._logical_device_ids = _logical_device_ids
    check_pspec(self.mesh, self.spec)

  def __repr__(self):
    # memory_kind and logical_device_ids are only shown when they are set.
    mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
    ldi = ('' if self._logical_device_ids is None else
           f', logical_device_ids={self._logical_device_ids}')
    mesh_repr = f"{str(self.mesh)}"
    return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'

  def __reduce__(self):
    """Pickle support: rebuilds through `_unpickle_named_sharding`."""
    return (_unpickle_named_sharding,
            (self.mesh, self.spec, self.memory_kind, self._logical_device_ids))

  @property
  def memory_kind(self) -> str | None:
    """The memory kind given at construction, or ``None``."""
    return self._memory_kind

  @use_cpp_method()
  def __hash__(self):
    # Computed once from (mesh, memory_kind, spec, logical device ids) and
    # memoized on the instance as `_hash`.
    if not hasattr(self, '_hash'):
      self._hash = hash(
          (self.mesh, self.memory_kind, self.spec, self._logical_device_ids))
    return self._hash

  @use_cpp_method()
  def __eq__(self, other):
    if not isinstance(other, NamedSharding):
      return False
    if self is other:
      return True
    # Compare spec, memory kind and logical device ids before the mesh.
    if (self.spec != other.spec
        or self.memory_kind != other.memory_kind
        or self._logical_device_ids != other._logical_device_ids):
      return False
    # Identity check first so the mesh `==` is skipped for a shared mesh.
    return self.mesh is other.mesh or self.mesh == other.mesh

  def check_compatible_aval(self, aval_shape: Shape) -> None:
    """Checks that the spec has no more entries than `aval_shape` has dims.

    Raises:
      ValueError: if `len(aval_shape) < len(self.spec)`. For rank-0 shapes the
        message additionally suggests using ``P()``.
    """
    if len(aval_shape) < len(self.spec):
      extra_msg = (' For scalars the PartitionSpec should be P()'
                   if len(aval_shape) == 0 else '')
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self.spec)}, but was applied to a value of rank "
          f"{len(aval_shape)}.{extra_msg}")

  @property
  def num_devices(self) -> int:
    """Number of devices, taken from `mesh.size`."""
    return self.mesh.size

  @property
  def device_set(self) -> set[Device]:
    """Set of devices of the mesh; raises ValueError for an `AbstractMesh`."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError(
          'device_set is not implemented for `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_set

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    """Flat tuple of mesh devices; raises ValueError for an `AbstractMesh`."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('_device_assignment is not implemented for'
                       ' `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_tuple

  @property
  def is_fully_addressable(self) -> bool:
    """Delegates to `_internal_device_list`; ValueError for an `AbstractMesh`."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('is_fully_addressable is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # return False if addressable_device_list is empty.
    # NOTE(review): `_internal_device_list` is not defined in this file;
    # presumably it comes from the base class or the C++ class -- confirm.
    return self._internal_device_list.is_fully_addressable  # type: ignore

  @property
  def _is_concrete(self) -> bool:
    """False when the mesh is an `AbstractMesh`, True otherwise."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      return False
    return True

  @property
  def addressable_devices(self) -> set[Device]:
    """The mesh's local device set; raises ValueError for an `AbstractMesh`."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('addressable_devices is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # Override addressable devices because there is a high chance that the mesh
    # across multiple NamedSharding objects will be the same.
    return self.mesh._local_devices_set

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    """True if the mesh has one device or the spec partitions nothing.

    The number of partitions is the product of the sizes of every mesh axis
    named in a positional entry of the spec.
    """
    if self.mesh.size == 1:
      return True
    array_mapping = get_array_mapping(self.spec)
    mesh_shape = self.mesh.shape
    num_partitions = 1
    for name in array_mapping:  # type: ignore
      num_partitions *= mesh_shape[name]
    return num_partitions == 1

  @functools.cached_property
  def replicated_axes(self) -> frozenset[MeshAxisName]:
    """Mesh axes not used by the spec's entries, `unreduced` or `reduced`."""
    flat_spec = frozenset(
        s for s in flatten_spec(self.spec)
        if s is not None and s is not PartitionSpec.UNCONSTRAINED)
    return frozenset(self.mesh.axis_names) - (
        flat_spec | self.spec.unreduced | self.spec.reduced)

  def with_memory_kind(self, kind: str) -> NamedSharding:
    """Returns a copy of this sharding with `memory_kind` set to `kind`."""
    return self.update(memory_kind=kind)

  def update(self, **kwargs) -> NamedSharding:
    """Returns a new `NamedSharding` with the given fields replaced.

    Recognized keys are ``spec``, ``mesh``, ``memory_kind`` and
    ``_logical_device_ids``; each defaults to the current value. A ``spec``
    that is not a `PartitionSpec` is unpacked into one. Any other key is left
    in `kwargs` and is not used.
    """
    spec = kwargs.pop("spec", self.spec)
    if not isinstance(spec, PartitionSpec):
      spec = PartitionSpec(*spec)
    return NamedSharding(
        mesh=kwargs.pop("mesh", self.mesh),
        spec=spec,
        memory_kind=kwargs.pop("memory_kind", self.memory_kind),
        _logical_device_ids=kwargs.pop("_logical_device_ids",
                                       self._logical_device_ids))

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    """Delegates to the module-level `named_sharding_to_xla_hlo_sharding`."""
    return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray:
    """Builds the `SdyArray` form of this sharding for `num_dimensions` dims."""
    # Start with every dimension empty and closed, then fill in from the spec.
    # Dimensions beyond len(self.spec) keep that default.
    dim_shardings = [SdyDim(axes=[], is_open=False)
                     for _ in range(num_dimensions)]
    for i, dim_spec in enumerate(self.spec):
      if dim_spec is PartitionSpec.UNCONSTRAINED:
        dim_shardings[i].is_open = True
      elif dim_spec is None:
        # Already empty and closed sharding.
        pass
      else:
        # Normalize a single axis name to a 1-tuple.
        dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,)
        dim_shardings[i].axes = dim_spec
    return SdyArray(mesh_shape=self.mesh.shape_tuple,
                    dim_shardings=dim_shardings,
                    logical_device_ids=self._logical_device_ids,
                    unreduced_axes=self.spec.unreduced)

# Make the class report `jax.sharding` as its module.
NamedSharding.__module__ = 'jax.sharding'

def flatten_spec(spec):
  """Flattens one level of tuple nesting in `spec` into a new list.

  Tuple entries contribute each of their elements in order; any other entry
  (including ``None``) is kept as a single element.
  """
  return [
      item
      for entry in spec
      for item in (entry if isinstance(entry, tuple) else (entry,))
  ]


def get_array_mapping(
    axis_resources: PartitionSpec | AUTO | UnspecifiedValue
) -> ArrayMappingOrAutoOrUnspecified:
  """Converts a `PartitionSpec` into an ordered mesh-axis -> dimension map.

  `AUTO` and `UnspecifiedValue` arguments are returned unchanged. Entries that
  are ``None`` or `PartitionSpec.UNCONSTRAINED` contribute nothing. Every mesh
  axis named at position `i` (alone or inside a tuple) is mapped to `i`.
  """
  if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
    return axis_resources
  mapping = collections.OrderedDict()
  for dim, entry in enumerate(axis_resources):
    if entry is None or entry is PartitionSpec.UNCONSTRAINED:
      continue
    names = entry if isinstance(entry, tuple) else (entry,)
    mapping.update((name, dim) for name in names)
  return mapping

@dataclasses.dataclass
class SdyDim:
  """Sharding of one array dimension: the mesh axes used and an open flag."""
  axes: Sequence[str]
  is_open: bool

  def build(self) -> sdy.DimensionShardingAttr:
    """Creates the `sdy.DimensionShardingAttr` for this dimension."""
    axis_refs = [sdy.AxisRefAttr.get(name) for name in self.axes]
    return sdy.DimensionShardingAttr.get(axis_refs,
                                         is_closed=not self.is_open)

  def __repr__(self):
    return 'SdyDim(' + self._custom_repr() + ')'

  def _custom_repr(self):
    """Brace-delimited axis list, with a trailing ``?`` when open."""
    pieces = [f"'{name}'" for name in self.axes]
    if self.is_open:
      pieces.append('?')
    return '{' + ', '.join(pieces) + '}'

def _get_axes(axes, mesh_shape):
  """Returns the members of `axes` as a tuple ordered as in `mesh_shape`.

  `mesh_shape` is iterated as (name, size) pairs. An empty `axes` yields ``()``
  without looking at `mesh_shape`.
  """
  if axes:
    assert mesh_shape is not None
    # Follow the mesh's own axis order so the result is deterministic and
    # doesn't hang in McJAX.
    in_mesh_order = [name for name, _size in mesh_shape if name in axes]
    return tuple(in_mesh_order)
  return ()

@dataclasses.dataclass(kw_only=True)
class SdyArray:
  mesh_shape: tuple[tuple[str, int], ...] | None
  dim_shardings: Sequence[SdyDim]
  logical_device_ids: tuple[int, ...] | None = None
  replicated_axes: tuple[str, ...] = ()
  unreduced_axes: frozenset[str] = frozenset()

  def build(self) -> sdy.TensorShardingAttr:
    if self.mesh_shape is None:
      mesh_attr = sdy.MeshAttr.get([])
    else:
      ldi = ([] if self.logical_device_ids is None else
             list(self.logical_device_ids))
      mesh_attr = sdy.MeshAttr.get(
          [sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape],
          ldi)

    replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape)
    unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape)
    return sdy.TensorShardingAttr.get(
        mesh_attr,
        [dim_sharding.build() for dim_sharding in self.dim_shardings],
        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes],
        unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes])

  def __repr__(self):
    dim_sharding_repr = ', '.join(
        d._custom_repr() for d in self.dim_shardings)
    device_id_repr = (f', device_ids={self.logical_device_ids}'
                      if self.logical_device_ids is not None else '')
    rar = (f', replicated_axes={self.replicated_axes}'
           if self.replicated_axes else '')
    return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})"


# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh):
  """Opens every dimension of `sdy_sharding` when the mesh has an Auto axis.

  If `mesh._any_axis_auto` is false, `sdy_sharding` is returned unchanged.
  Otherwise a new `SdyArray` is returned in which:

    * every dimension keeps its axes but is marked open;
    * `Explicit` mesh axes that are used neither by a dimension nor as an
      unreduced axis are listed as replicated, in mesh axis order;
    * `logical_device_ids` and `unreduced_axes` are carried over.

  Args:
    sdy_sharding: the sharding to adjust.
    mesh: mesh providing `_any_axis_auto`, `axis_names` and `_name_to_type`.

  Returns:
    The original `sdy_sharding` or a new, adjusted `SdyArray`.
  """
  if not mesh._any_axis_auto:
    return sdy_sharding

  dim_shardings, used_axes = [], []  # type: ignore
  for d in sdy_sharding.dim_shardings:
    dim_shardings.append(SdyDim(axes=d.axes, is_open=True))
    used_axes.extend(d.axes)

  # Unreduced axes must survive the rewrite and must not also be reported as
  # replicated, so they are excluded from the candidates below.
  unreduced_axes = sdy_sharding.unreduced_axes
  taken = set(used_axes) | set(unreduced_axes)
  # Walk the mesh's axis names (not a set) so the tuple order is deterministic.
  replicated_axes = tuple(
      name for name in mesh.axis_names
      if name not in taken
      and mesh._name_to_type[name] == mesh_lib.AxisType.Explicit)
  return SdyArray(mesh_shape=sdy_sharding.mesh_shape,
                  dim_shardings=dim_shardings,
                  logical_device_ids=sdy_sharding.logical_device_ids,
                  replicated_axes=replicated_axes,
                  unreduced_axes=unreduced_axes)


@cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
    self, num_dimensions: int) -> xc.HloSharding:
  """Converts a named sharding into an `xc.HloSharding`.

  Args:
    self: object providing `mesh`, `spec` and `_logical_device_ids` (called
      with a `NamedSharding` from `NamedSharding._to_xla_hlo_sharding`).
    num_dimensions: number of tile dimensions to emit before the subgroup
      dimensions.

  Returns:
    `HloSharding.replicate()` when no mesh axis is named by the spec and there
    are no manual or unreduced axes; otherwise an iota-tile sharding, or a
    subgroup sharding with explicit device ordering when
    `_logical_device_ids` is set.
  """
  mesh_shape = self.mesh.shape
  array_mapping = get_array_mapping(self.spec)
  mesh_axis_pos = {name: i for i, name in enumerate(self.mesh.axis_names)}

  # Maps a mesh axis position to the non-default subgroup type (MANUAL or
  # UNREDUCED) that it should get.
  special_axes = {}
  manual_axes = frozenset(self.mesh.manual_axes)
  if manual_axes:
    axis_names = self.mesh.axis_names
    for manual_axis in manual_axes:
      special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL

  unreduced_axes = self.spec.unreduced
  if unreduced_axes:
    axis_names = self.mesh.axis_names
    for u in unreduced_axes:
      special_axes[axis_names.index(u)] = xc.OpSharding.Type.UNREDUCED

  # (position, size) of every mesh axis that no positional entry of the spec
  # refers to.
  replicated_mesh_axes = []
  for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
    if axis_name not in array_mapping:  # type: ignore
      replicated_mesh_axes.append((i, axis_val))

  if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes:
    return xc.HloSharding.replicate()

  # Sharded axes first, visited in order of the array dimension they shard;
  # the sort is stable, so axes sharing a dimension keep their mapping order.
  mesh_permutation = []
  new_mesh_shape = [1] * num_dimensions
  for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]):  # type: ignore
    new_mesh_shape[pos] *= mesh_shape[name]
    mesh_permutation.append(mesh_axis_pos[name])

  # Remaining axes are grouped by subgroup type; each group becomes one
  # trailing tile dimension whose size is the product of its axis sizes.
  last_tile_dims = []
  if replicated_mesh_axes:
    axes_by_type: dict[Any, list[int]] = collections.defaultdict(list)
    size_by_type = collections.defaultdict(lambda: 1)  # type: ignore
    # Every manual/unreduced axis must be one the spec does not shard over.
    assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys()))
    for i, size in replicated_mesh_axes:
      ty = special_axes.get(i, xc.OpSharding.Type.REPLICATED)
      axes_by_type[ty].append(i)
      size_by_type[ty] *= size
    # Groups are emitted in order of the enum's `.value`.
    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
      last_tile_dims.append(ty)
      new_mesh_shape.append(size_by_type[ty])
      mesh_permutation.extend(axes)

  # Explanation of the parameters of `HloSharding.iota_tile`.
  # This is the HloShardingV2 format:
  #   * dims: How many ways each dimension is sharded.
  #       Replicated/Manual/Unreduced dims are added at the end.
  #   * reshape_dims: This is just the shape of the mesh.
  #   * transpose_perm: This is the order in which mesh axes in PartitionSpec
  #       appear relative to mesh.axis_names order.
  #   * subgroup_types: List of OpSharding types for the trailing dims. A type
  #       can be REPLICATED, MANUAL or UNREDUCED.
  # Let's see an example:
  #   Consider input_shape=(8, 4, 2, 2), mesh={'a': 2, 'b': 2, 'c': 2, 'd': 2}
  #   and partition_spec=P(None, ('d', 'b'), 'c').
  #   Arguments to iota_tile will be:
  #     dims = [1, 4, 2, 1, 2]  # 'a' is replicated hence `2` is at the end.
  #     reshape_dims = [2, 2, 2, 2]
  #     transpose_perm = [3, 1, 2, 0]  # 'a' is replicated hence 0 is at the end
  #     subgroup_types = [xc.OpSharding.Type.REPLICATED]
  dims = new_mesh_shape
  reshape_dims = self.mesh.axis_sizes
  if self._logical_device_ids is None:
    return xc.HloSharding.iota_tile(
        dims=dims, reshape_dims=reshape_dims, transpose_perm=mesh_permutation,
        subgroup_types=last_tile_dims)
  else:
    # Apply the same reshape/transpose/reshape to the explicit device ids that
    # the iota-tile arguments above describe.
    return xc.HloSharding.subgroup_with_device_ordering(
        np.asarray(self._logical_device_ids)
        .reshape(dims).reshape(reshape_dims).transpose(mesh_permutation)
        .reshape(dims), subgroup_types=last_tile_dims)


def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
  """Converts a mesh-axis -> dimension mapping into a `PartitionSpec`.

  The spec has one entry per dimension from 0 up to the largest mapped
  dimension: ``None`` when no axis maps there, the bare axis when exactly one
  does, and a tuple (in mapping order) when several do. An empty mapping
  yields ``PartitionSpec()``.
  """
  if not array_mapping:
    return PartitionSpec()
  axes_by_dim = collections.defaultdict(list)
  for mesh_axis, dim in array_mapping.items():
    axes_by_dim[dim].append(mesh_axis)
  entries = []
  for dim in range(max(axes_by_dim) + 1):
    names = axes_by_dim.get(dim, [])
    if not names:
      entries.append(None)
    elif len(names) == 1:
      entries.append(names[0])
    else:
      entries.append(tuple(names))
  return PartitionSpec(*entries)


@cache(max_size=128, trace_context_in_key=False)
def check_pspec(mesh, spec, _manual_axes=frozenset()):
  """Validates `spec` against `mesh`; raises if any of the checks fails.

  Runs, in order, the duplicate-axis check, the mesh-axis/axis-type check and
  the unreduced/reduced axes check. `_manual_axes` is not read by the body.
  """
  _check_unique_resources(pspec=spec, arg_name="NamedSharding spec", mesh=mesh)
  _check_mesh_resource_axis(mesh=mesh, pspec=spec)
  _check_mesh_unreduced(mesh=mesh, pspec=spec)

class DuplicateSpecError(Exception):
  """Error raised when a spec names the same mesh axis more than once.

  Keeps the message together with the mesh and the offending spec so handlers
  can inspect them.
  """

  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.message, self.mesh, self.pspec = message, mesh, pspec

  def __str__(self):
    return format(self.message)

def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
                            ) -> None:
  """Raises `DuplicateSpecError` if a mesh axis appears more than once.

  ``None`` and `PartitionSpec.UNCONSTRAINED` entries are skipped; tuple
  entries count each of their members. `arg_name` is only used in the
  message and `mesh` is only attached to the raised error.
  """
  usage = collections.Counter()
  for entry in pspec:
    if entry is None or entry is PartitionSpec.UNCONSTRAINED:
      continue
    usage.update(entry if isinstance(entry, tuple) else (entry,))
  # Counter keeps first-seen order, so repeated axes are reported in the
  # order they first appear in the spec.
  multiple_uses = [axis for axis, uses in usage.items() if uses > 1]
  if not multiple_uses:
    return
  raise DuplicateSpecError(
      message=(
          f'A single {arg_name} specification can map every mesh axis to at'
          f' most one positional dimension, but {pspec} has duplicate entries'
          f' for {mesh_lib.show_axes(multiple_uses)}'),
      mesh=mesh, pspec=pspec)

def check_pspec_mix_axis_type(mesh, pspec):
  """Rejects tuple entries that mix `Manual` axes with other axis types.

  Only tuple entries of `pspec` are inspected. A tuple whose axes all share
  one type is accepted, as is a mixed tuple with no `Manual` axis.

  Raises:
    ValueError: if a tuple has axes of differing types and at least one of
      them is `AxisType.Manual`.
  """
  for entry in (e for e in pspec if isinstance(e, tuple)):
    same_type = all(
        mesh._name_to_type[entry[0]] == mesh._name_to_type[name]
        for name in entry)
    if same_type:
      continue
    has_manual = any(
        mesh._name_to_type[name] == AxisType.Manual for name in entry)
    if not has_manual:
      continue
    shown_types = ", ".join(str(mesh._name_to_type[name]) for name in entry)
    raise ValueError(
        'Tuple subset of `PartitionSpec` cannot contain `Manual` mixed'
        f' with `Auto` or `Explicit`. Got pspec {pspec} and subset'
        f' {entry} with axis types:'
        f' ({shown_types})')

def _check_mesh_resource_axis(mesh, pspec):
  """Checks that `pspec` only names axes of `mesh` and fits its axis types.

  Raises:
    ValueError: if an axis named in `pspec` is not in `mesh.axis_names`, if
      `check_pspec_mix_axis_type` rejects the spec, or if the spec contains
      `P.UNCONSTRAINED` while no mesh axis type is `Auto`.
  """
  for entry in pspec:
    if entry is None or entry is PartitionSpec.UNCONSTRAINED:
      continue
    names = entry if isinstance(entry, tuple) else (entry,)
    for name in names:
      if name in mesh.axis_names:
        continue
      raise ValueError(
          f"Resource axis: {name} of {pspec} "
          f"is not found in mesh: {tuple(mesh.shape.keys())}.")
  check_pspec_mix_axis_type(mesh, pspec)
  # UNCONSTRAINED is only accepted when at least one mesh axis is Auto.
  if AxisType.Auto in mesh.axis_types:
    return
  if PartitionSpec.UNCONSTRAINED in pspec:
    raise ValueError(
        f'{pspec} cannot contain'
        ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
        f' axis_types: {mesh.axis_types}')

def _check_mesh_unreduced(mesh, pspec):
  """Validates the `unreduced` and `reduced` axes of `pspec` against `mesh`.

  All unreduced axes are checked first, then all reduced axes.

  Raises:
    ValueError: if an axis is not in `mesh.axis_names`, or if its mesh axis
      type is `AxisType.Auto`.
  """
  for axis in pspec.unreduced:
    known = axis in mesh.axis_names
    if not known:
      raise ValueError(
          f'Unreduced axes {axis} is not found in {mesh.axis_names=}. '
          f'Got {pspec=}')
    is_auto = mesh._name_to_type[axis] == AxisType.Auto
    if is_auto:
      raise ValueError(
          'Unreduced axes can only refer to mesh axes that are of type'
          f' `Explicit` or `Manual`. Got unreduced axes: {pspec.unreduced} and'
          f' mesh: {mesh}')

  for axis in pspec.reduced:
    known = axis in mesh.axis_names
    if not known:
      raise ValueError(
          f'Reduced axes {axis} is not found in {mesh.axis_names=}. '
          f'Got {pspec=}')
    is_auto = mesh._name_to_type[axis] == AxisType.Auto
    if is_auto:
      raise ValueError(
          'Reduced axes can only refer to mesh axes that are of type'
          f' `Explicit` or `Manual`. Got reduced axes: {pspec.reduced} and'
          f' mesh: {mesh}')
