# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains shared logic and abstractions for Pallas indexing ops."""

from __future__ import annotations

import dataclasses
from typing import Any, Union

from jax._src import core
from jax._src import pretty_printer as pp
from jax._src import tree_util
from jax._src.typing import Array
from jax._src.util import merge_lists
from jax._src.util import partition_list
import numpy as np


@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  """A slice with a start index and a size.

  Both the start index and the size can be either static, i.e. known at
  tracing and compilation time, or dynamic. The stride must be static.
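
  For example (illustrative), with ``i`` a traced integer index::

    Slice(0, 8)      # static: elements 0 through 7
    Slice(i, 8)      # dynamic start, static size
    Slice(0, 4, 2)   # strided: elements 0, 2, 4, 6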
  """

  start: int | Array
  size: int | Array
  stride: int = 1

  def __post_init__(self):
    if self.stride < 1:
      raise ValueError(f"`stride` must be >= 1, got {self.stride}.")

  @property
  def is_dynamic_start(self):
    return not core.is_dim(self.start)

  @property
  def is_dynamic_size(self):
    return not core.is_dim(self.size)

  def tree_flatten(self):
    # If `start`/`size` is statically known, we treat it as static aux data;
    # otherwise it is a pytree leaf.
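    # For example (illustrative): ``Slice(0, 8)`` flattens to leaves
    # ``(None, None)`` with aux data ``(0, 8, 1)``, while ``Slice(i, 8)`` with
    # a traced ``i`` flattens to leaves ``(i, None)`` and aux ``(None, 8, 1)``.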
    xs = ()
    data = ()
    xs += (self.start,) if self.is_dynamic_start else (None,)
    data += (None,) if self.is_dynamic_start else (self.start,)
    xs += (self.size,) if self.is_dynamic_size else (None,)
    data += (None,) if self.is_dynamic_size else (self.size,)
    data += (self.stride,)
    return xs, data

  @classmethod
  def tree_unflatten(cls, aux_data, children) -> Slice:
    start, size = (
        a if a is not None else b for a, b in zip(children, aux_data[:2])
    )
    return cls(start, size, aux_data[2])

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
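    # For example (illustrative): ``Slice.from_slice(slice(2, None, 2), 10)``
    # yields ``Slice(start=2, size=4, stride=2)``, i.e. elements 2, 4, 6, 8.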
    start, step, size = core.canonicalize_slice(slc, size)
    if step < 1:
      raise ValueError(f"slice must have a step >= 1 (found: {step})")
    return cls(start, size, step)


def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str:
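  # For example (illustrative): with ``dim=8``, ``Slice(2, 3)`` prints as
  # ``2:5`` and ``Slice(0, 8)`` prints as ``:``; the stride is not shown.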
  start, size = slc.start, slc.size
  if isinstance(start, core.Var):
    start_str = core.pp_var(start, context)
    size_str = (
        core.pp_var(size, context) if isinstance(size, core.Var) else str(size)
    )
    return f"{start_str}:{start_str}+{size_str}"
  else:
    start_str = str(start)
    if start == 0:
      start_str = ""
    if isinstance(size, core.Var):
      size_str = core.pp_var(size, context)
      if start_str:
        return f"{start_str}:{start_str}+{size_str}"
      else:
        return f":{size_str}"
    else:
      end = start + size
      end_str = "" if end == dim else str(end)
      return f"{start_str}:{end_str}"


def dslice(
    start: int | Array | None,
    size: int | Array | None = None,
    stride: int | None = None,
) -> slice | Slice:
  """Constructs a ``Slice`` from a start index and a size.

  ``dslice`` is similar to the builtin ``slice``, except that its second
  argument is a size rather than a stop index:

  * ``dslice(None)`` is ``:``
  * ``dslice(j)`` is ``:j``
  * ``dslice(i, j)`` is ``i:i+j``
  * ``dslice(i, j, stride)`` is ``i:i+j:stride``
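
  For example (illustrative), where ``x`` is a hypothetical array or ``Ref``:
  ``x[dslice(2, 4)]`` selects the same elements as ``x[2:6]``, and
  ``x[dslice(i, 4)]`` selects a window of four elements starting at a traced
  index ``i``.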
  """
  if start is None:
    return slice(None)
  if stride is None:
    stride = 1
  if not isinstance(stride, int):
    raise ValueError("Non-static stride in `dslice`")
  if size is None:
    if not isinstance(start, int):
      raise ValueError("Non-static size in `dslice`")
    return Slice(0, start, stride)
  return Slice(start, size, stride)


ds = dslice  # Handy alias


IntIndexer = Union[int, Array]
DimIndexer = Union[IntIndexer, Slice]

def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...],
                                                  tuple[Slice, ...],
                                                  tuple[IntIndexer, ...]]:
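  # For example (illustrative): for ``indices = (i, Slice(0, 4), j)``, where
  # ``i`` and ``j`` are integer indexers, this returns
  # ``((True, False, True), (Slice(0, 4),), (i, j))``.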
  # TODO(slebedev): Flip this to be ``is_slice_indexing`` and update callers.
  is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices]
  slice_indexers, int_indexers = partition_list(
      is_int_indexing, indexer.indices)
  return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers)  # type: ignore

def _maybe_concretize(x: Any):
  # This is roughly the same logic as core.concrete_or_error, but we avoid
  # calling that because constructing the ConcretizationTypeError can be
  # expensive as the size of the tracing context (i.e. the jaxpr) grows.
  return core.to_concrete_value(x)

@tree_util.register_pytree_node_class
@dataclasses.dataclass
class NDIndexer:
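  """An n-dimensional indexer for a ``Ref`` (or array) of the given ``shape``.

  ``indices`` holds one ``DimIndexer`` per dimension: either an integer (or
  integer array) indexer or a ``Slice``. ``int_indexer_shape`` is the shape
  that all integer indexers broadcast to. For example (illustrative),
  ``NDIndexer.from_indices_shape((2, dslice(0, 4)), (8, 8))`` builds an
  indexer roughly corresponding to the NumPy-style expression ``x[2, 0:4]``.
  """
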
  indices: tuple[DimIndexer, ...]
  shape: tuple[int, ...]
  int_indexer_shape: tuple[int | Array, ...]
  # Off by default to avoid doing validation during pytree operations.
  validate: bool = False

  def __post_init__(self):
    if len(self.indices) != len(self.shape):
      raise ValueError(
          f"`indices` must have the same length as the `Ref` shape: {self}."
      )
    if not self.validate:
      return
    # We validate integer indexing shapes here
    for idx, s in zip(self.indices, self.shape):
      if isinstance(idx, Slice):
        start = idx.start
        if (value := _maybe_concretize(start)) is not None:
          if value >= s:
            raise ValueError(f"Out-of-bounds slice: start={value}, dim={s}.")
          if (size := _maybe_concretize(idx.size)) is not None:
            if value + (size - 1) * idx.stride >= s:
              raise ValueError(
                  f"Out-of-bounds slice: start={value}, size={size},"
                  f" stride={idx.stride}, dim={s}."
              )
        continue
      # The shape of an integer indexer must be broadcastable to the
      # ``int_indexer_shape`` of the whole NDIndexer.
      from jax._src.state import types as state_types  # pytype: disable=import-error
      idx_shape = (
          idx.shape
          if isinstance(idx, state_types.TransformedRef)
          else core.get_aval(idx).shape
      )
      if not idx_shape:
        if (value := _maybe_concretize(idx)) is not None and value >= s:
          raise ValueError(f"Out-of-bounds indexer: idx={value}, dim={s}.")
        # A ()-shaped indexer always broadcasts, so there is nothing to check.
        continue
      # If we don't have a ()-shaped indexer, the rank must match
      # int_indexer_shape
      if len(idx_shape) != len(self.int_indexer_shape):
        raise ValueError(
            f"Indexer must have rank {len(self.int_indexer_shape)}: {idx=} vs."
            f" {self.int_indexer_shape=}"
        )
      # Here we check that the shapes broadcast.
      try:
        np.broadcast_shapes(idx_shape, self.int_indexer_shape)
      except ValueError as e:
        raise ValueError(
            f"Could not broadcast integer indexer: {idx=} vs."
            f" {self.int_indexer_shape=}"
        ) from e

  @property
  def is_dynamic_size(self):
    return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices)

  def tree_flatten(self):
    flat_idx, idx_tree = tree_util.tree_flatten(self.indices)
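    # A dynamic (traced) ``int_indexer_shape`` cannot be stored in the static
    # aux data, so in that case it travels with the pytree leaves instead.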
    if not all(isinstance(i, int) for i in self.int_indexer_shape):
      return (*flat_idx, self.int_indexer_shape), (idx_tree, self.shape)
    else:
      return flat_idx, (idx_tree, self.shape, self.int_indexer_shape)

  @classmethod
  def tree_unflatten(cls, data, flat_idx):
    if len(data) == 3:
      idx_tree, shape, int_indexer_shape = data
    else:
      # The ``int_indexer_shape`` is dynamic.
      idx_tree, shape = data
      *flat_idx, int_indexer_shape = flat_idx
    indices = tree_util.tree_unflatten(idx_tree, flat_idx)
    return cls(tuple(indices), shape, int_indexer_shape)

  @classmethod
  def from_indices_shape(cls, indices, shape) -> NDIndexer:
    if not isinstance(indices, tuple):
      # TODO(slebedev): Consider requiring `indices` to be a Sequence.
      indices = (indices,)

    if num_ellipsis := sum(idx is ... for idx in indices):
      if num_ellipsis > 1:
        raise ValueError("Only one ellipsis is supported.")
      # Expand ... so that `indices` has the same length as `shape`.
      ip = indices.index(...)
      indices = list(indices)
      indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1)
      indices = tuple(indices)
    if len(indices) > len(shape):
      raise ValueError("`indices` must not be longer than `shape`: "
                       f"{indices=}, {shape=}")
    elif len(indices) < len(shape):
      # Pad `indices` to have the same length as `shape`.
      indices = (*indices, *[slice(None)] * (len(shape) - len(indices)))
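
    # For example (illustrative): with ``shape=(2, 3, 4)``, both ``(0, ...)``
    # and ``(0,)`` expand to ``(0, slice(None), slice(None))`` at this point.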

    # Promote all builtin `slice`s to `Slice`.
    indices = tuple(
        Slice.from_slice(i, s) if isinstance(i, slice) else i
        for i, s in zip(indices, shape))

    is_slice_indexing = [isinstance(i, Slice) for i in indices]
    if all(is_slice_indexing):
      return cls(indices, shape, (), validate=True)

    other_indexers, slice_indexers = partition_list(is_slice_indexing, indices)
    validate = True

    # We treat refs differently from scalars and arrays, because refs can have
    # a dynamic shape, making it impossible to statically determine the
    # broadcasted shape in the presence of other non-slice indexers.
    from jax._src.state import types as state_types  # pytype: disable=import-error
    if ref_indexers := [
        i
        for i in other_indexers
        if isinstance(i, state_types.TransformedRef)
        or isinstance(core.get_aval(i), state_types.AbstractRef)
    ]:
      # TODO(slebedev): Consider pushing these checks to lowering time.
      if len(ref_indexers) > 1:
        raise NotImplementedError("Multiple Ref indexers are not supported")
      if len(ref_indexers) != len(other_indexers):
        raise NotImplementedError(
            "Ref cannot be mixed with other non-slice indexers"
        )
      [ref_indexer] = ref_indexers
      indexer_shape = ref_indexer.shape  # type: ignore
      try:
        core.canonicalize_shape(indexer_shape)
      except TypeError:
        validate = False  # The shape is dynamic.
    else:
      indexer_shapes = [core.get_aval(i).shape for i in other_indexers]
      try:
        indexer_shape = np.broadcast_shapes(*indexer_shapes)
      except ValueError as e:
        # Raise a nicer error than the NumPy one.
        raise ValueError(
            f"Cannot broadcast shapes for indexing: {indexer_shapes}"
        ) from e

      # Here we use the `broadcast_to` primitive instead of composing lax
      # primitives together because it is easier to lower in targets like
      # Triton/Mosaic.
      #
      # The local import avoids a circular dependency between primitives
      # and this module.
      from jax._src.state import primitives as sp  # pytype: disable=import-error
      other_indexers = [
          sp.broadcast_to(i, indexer_shape) for i in other_indexers  # type: ignore[arg-type]
      ]
      indices = tuple(
          merge_lists(is_slice_indexing, other_indexers, slice_indexers)
      )
    return cls(indices, shape, indexer_shape, validate)

  @classmethod
  def make_trivial_indexer(cls, shape: tuple[int, ...]) -> NDIndexer:
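    # Builds the identity indexer ``[:, ..., :]`` over ``shape``. For example
    # (illustrative), ``shape=(4, 8)`` yields indices
    # ``(Slice(0, 4), Slice(0, 8))``.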
    return NDIndexer.from_indices_shape(
        tuple(slice(0, e) for e in shape),
        shape,
    )

  def get_indexer_shape(self) -> tuple[int | Array, ...]:
    is_int_indexing, slice_indexers, _ = unpack_ndindexer(self)

    slice_shape = tuple(s.size for s in slice_indexers)
    int_indexers_contiguous = bool(
        np.all(np.diff(np.where(is_int_indexing)[0]) == 1)
    )
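    # This mirrors NumPy's advanced-indexing rules: non-contiguous integer
    # indexers contribute their broadcast dimensions at the front of the
    # result, while contiguous ones are inserted at the position of the first
    # integer indexer. For example (illustrative), for ``shape=(4, 5, 6)``
    # with indices ``(i, :, j)`` where ``i`` and ``j`` broadcast to ``(3,)``,
    # the result shape is ``(3, 5)``; for ``(:, i, j)`` it is ``(4, 3)``.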
    if not int_indexers_contiguous:
      return self.int_indexer_shape + slice_shape

    has_int_indexers = any(is_int_indexing)
    if has_int_indexers:
      pos = is_int_indexing.index(True)
      return slice_shape[:pos] + self.int_indexer_shape + slice_shape[pos:]

    return slice_shape

  def transform_shape(self, shape: None | tuple[int | Array, ...]) -> None | tuple[int | Array, ...]:
    del shape  # Unused
    return self.get_indexer_shape()

  def transform_dtype(self, dtype):
    return dtype

  def transform_sharding(self, sharding):
    # If there are no explicit axes, do nothing.
    if all(p is None for p in sharding.spec):
      return sharding
    # If there are explicit axes, we don't support changing the shape, so we
    # don't support int indexers and instead require all slices.
    if (self.int_indexer_shape or
        not all(isinstance(idx, Slice) for idx in self.indices)):
      raise TypeError("sharded ref (array reference) can only be indexed by "
                      "slices, not integers")
    # Moreover, only allow trivial slice(None) slices on explicitly sharded
    # axes. Then the sharding stays the same.
    _, slice_indexers, _ = unpack_ndindexer(self)
    for i, (d, sl, s) in enumerate(zip(self.shape, slice_indexers, sharding.spec)):
      if s is None: continue
      if not (type(sl.start)  is int and sl.start == 0 and
              type(sl.size)   is int and sl.size  == d and
              type(sl.stride) is int and sl.stride == 1):
        raise ValueError("sharded ref (array reference) can only be sliced "
                         f"along unsharded axes, but ref of shape {self.shape} "
                         f"was sliced on axis {i}, which is sharded like {s}")
    return sharding

  def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc:
    indices = []
    for idx, dim in zip(self.indices, self.shape):
      if isinstance(idx, Slice):
        indices.append(_pp_slice(context, dim, idx))
      else:
        indices.append(core.pp_var(idx, context, print_literal_dtype=False))  # type: ignore
    return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")])
