# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from jax._src import array
from jax._src import dtypes
from jax._src import xla_bridge
from jax._src.api import device_put
from jax._src.lax.lax import _array_copy
from jax._src.lib import _jax
from jax._src.lib import xla_client
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import scalar_types as jnp_types
from jax._src.sharding import Sharding
from jax._src.typing import Array, DLDeviceType, DTypeLike

import numpy as np


DLPACK_VERSION = (0, 8)
MIN_DLPACK_VERSION = (0, 5)

# A set of dtypes that dlpack supports.
# Note: Make sure to use a "type", not a dtype instance, when looking up this set
# because their hashes are different.
# For example,
# hash(jnp.float32) != hash(jnp.dtype(jnp.float32))
# hash(jnp.float32) == hash(jnp.dtype(jnp.float32).type)

# TODO(vanderplas): remove this set
SUPPORTED_DTYPES: frozenset[DTypeLike] = frozenset({
    jnp_types.int8, jnp_types.int16, jnp_types.int32, jnp_types.int64,
    jnp_types.uint8, jnp_types.uint16, jnp_types.uint32, jnp_types.uint64,
    jnp_types.float16, jnp_types.bfloat16, jnp_types.float32, jnp_types.float64,
    jnp_types.complex64, jnp_types.complex128, jnp_types.bool_})

SUPPORTED_DTYPES_SET: frozenset[np.dtype] = frozenset({np.dtype(dt) for dt in SUPPORTED_DTYPES})


def is_supported_dtype(dtype: DTypeLike) -> bool:
  """Check if dtype is supported by jax.dlpack."""
  if dtype is None:
    # NumPy will silently cast this to float64, which may be surprising.
    raise TypeError(f"Expected a string or dtype-like object; got {dtype=}")
  return np.dtype(dtype) in SUPPORTED_DTYPES_SET


def _to_dlpack(x: Array, stream: int | Any | None,
               src_device: _jax.Device | None = None,
               device: _jax.Device | None = None,
               copy: bool | None = None):

  if src_device is None:
    src_device, = x.devices()
  if device and (src_device is None or device != src_device):
    if copy is not None and not copy:
      raise ValueError(
        f"Specified {device=} which requires a copy since the source device "
        f"is {repr(src_device)}, however copy=False. Set copy=True or "
        "copy=None to perform the requested operation."
      )
    else:
      arr = device_put(x, device)
  else:
    arr = _array_copy(x) if copy else x
  return _jax.buffer_to_dlpack_managed_tensor(
    arr.addressable_data(0), stream=stream
  )


_DL_DEVICE_TO_PLATFORM = {
    DLDeviceType.kDLCPU: "cpu",
    DLDeviceType.kDLCUDA: "cuda",
    DLDeviceType.kDLROCM: "rocm",
}


def to_dlpack(x: Array, stream: int | Any | None = None,
              src_device: _jax.Device | None = None,
              dl_device: tuple[DLDeviceType, int] | None = None,
              max_version: tuple[int, int] | None = None,
              copy : bool | None = None):
  """Returns a DLPack tensor that encapsulates a :class:`~jax.Array` ``x``.

  Args:
    x: a :class:`~jax.Array`, on either CPU or GPU.
    stream: optional platform-dependent stream to wait on until the buffer is
      ready. This corresponds to the `stream` argument to ``__dlpack__``
      documented in https://dmlc.github.io/dlpack/latest/python_spec.html.
    src_device: either a CPU or GPU :class:`~jax.Device`.
    dl_device: a tuple of ``(dl_device_type, local_hardware_id)`` in DLPack
      format e.g. as produced by ``__dlpack_device__``.
    max_version: the maximum DLPack version that the consumer (i.e. caller of
      ``__dlpack__``) supports in the form of a 2-tuple of ``(major, minor)``.
      This function is not guaranteed to return a capsule of version
      ``max_version``.
    copy: a boolean indicating whether or not to copy the input. If
      ``copy=True`` then the function must always copy. When
      ``copy=False`` then the function must never copy, and must raise an error
      when a copy is deemed necessary. If ``copy=None`` then the function must
      avoid a copy if possible but may copy if needed.

  Returns:
    A DLPack PyCapsule object.

  Note:
    While JAX arrays are always immutable, ``DLPackManagedTensor`` buffers
    cannot be marked as immutable, and it is possible for processes external
    to JAX to mutate them in-place. If a DLPack buffer derived from a JAX array
    is mutated, it may lead to undefined behavior when using the associated JAX
    array. When JAX eventually supports ``DLManagedTensorVersioned``
    (DLPack 1.0), it will be possible to specify that a buffer is read-only.
  """
  if not isinstance(x, array.ArrayImpl):
    raise TypeError("Argument to to_dlpack must be a jax.Array, "
                    f"got {type(x)}")

  device = None
  dl_device_type, local_hardware_id = dl_device if dl_device else (None, None)
  if dl_device_type:
    try:
      dl_device_platform = _DL_DEVICE_TO_PLATFORM[dl_device_type]
      backend = xla_bridge.get_backend(dl_device_platform)
      device = backend.device_from_local_hardware_id(local_hardware_id)
    except KeyError:
      # https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
      # recommends using BufferError.
      raise BufferError(
          "The device specification passed to to_dlpack contains an"
          f" unsupported device type (DLDeviceType: {dl_device_type})"
      ) from None

  # As new versions are adopted over time, we can maintain some legacy paths
  # for compatibility mediated through the max_version parameter.
  # TODO(micky774): Deprecate default usage of DLPackManagedTensor when XLA
  # supports DLManagedTensorVersioned (DLPack version 1.0) and repurpose the
  # current _to_dlpack as a legacy path for (0,5) <= max_version < (1,0).
  if max_version is None or max_version >= DLPACK_VERSION:
    # Latest
    return _to_dlpack(
      x, stream=stream,
      src_device=src_device,
      device=device,
      copy=copy
    )
  elif max_version >= MIN_DLPACK_VERSION:
    # Oldest supported
    return _to_dlpack(
      x, stream=stream,
      src_device=src_device,
      device=device,
      copy=copy
    )
  else:
    raise BufferError(
      f"JAX does not support any version below {MIN_DLPACK_VERSION} but "
      f"version ({max_version}) was requested."
    )

def _check_device(device, dlpack_device, copy):
  if device and dlpack_device != device:
    if copy is not None and not copy:
      raise ValueError(
        f"Specified {device=} which requires a copy since the source device "
        f"is {repr(dlpack_device)}, however copy=False. Set copy=True or "
        "copy=None to perform the requested operation."
      )

def _place_array(_arr, device, dlpack_device, copy):
  if device and dlpack_device != device:
    return device_put(_arr, device)
  if copy:
    return jnp.array(_arr, copy=True)
  return _arr

def _is_tensorflow_tensor(external_array):
  t = type(external_array)
  return (
      t.__qualname__ == "EagerTensor"
      and t.__module__.endswith("tensorflow.python.framework.ops")
  )

def from_dlpack(external_array,
                device: _jax.Device | Sharding | None = None,
                copy: bool | None = None):
  """Returns a :class:`~jax.Array` representation of a DLPack tensor.

  The returned :class:`~jax.Array` shares memory with ``external_array`` if no
  device transfer or copy was requested.

  Args:
    external_array: An array object that has ``__dlpack__`` and
      ``__dlpack_device__`` methods.
    device: The (optional) :py:class:`Device`, representing the device on which
      the returned array should be placed. If given, then the result is
      committed to the device. If unspecified, the resulting array will be
      unpacked onto the same device it originated from. Setting ``device`` to a
      device different from the source of ``external_array`` will require a
      copy, meaning ``copy`` must be set to either ``True`` or ``None``.
    copy: An (optional) boolean, controlling whether or not a copy is performed.
      If ``copy=True`` then a copy is always performed, even if unpacked onto
      the same device. If ``copy=False`` then the copy is never performed and
      will raise an error if necessary. When ``copy=None`` then a copy may be
      performed if needed for a device transfer.

  Returns:
    A jax.Array

  Note:
    While JAX arrays are always immutable, dlpack buffers cannot be marked as
    immutable, and it is possible for processes external to JAX to mutate them
    in-place. If a jax Array is constructed from a dlpack buffer and the buffer
    is later modified in-place, it may lead to undefined behavior when using
    the associated JAX array.
  """
  if isinstance(device, Sharding):
    device_set = device.device_set
    if len(device_set) > 1:
      raise ValueError(
        "from_dlpack can only unpack a dlpack tensor onto a singular device, but "
        f"a Sharding with {len(device_set)} devices was provided."
      )
    device, = device_set
  if not hasattr(external_array, "__dlpack__") or not hasattr(external_array, "__dlpack_device__"):
    raise TypeError(
        "The array passed to from_dlpack must have __dlpack__ and __dlpack_device__ methods."
    )

  dl_device_type, device_id = external_array.__dlpack_device__()
  try:
    dl_device_platform = _DL_DEVICE_TO_PLATFORM[dl_device_type]
  except KeyError:
    raise TypeError(
        "Array passed to from_dlpack is on unsupported device type "
        f"(DLDeviceType: {dl_device_type}, array: {external_array}"
    ) from None

  backend = xla_bridge.get_backend(dl_device_platform)
  dlpack_device = backend.device_from_local_hardware_id(device_id)
  _check_device(device, dlpack_device, copy)
  if _is_tensorflow_tensor(external_array):
    # TensorFlow does not support stream=.
    stream = None
  else:
    try:
      stream = dlpack_device.get_stream_for_external_ready_events()
    except _jax.JaxRuntimeError as err:
      if "UNIMPLEMENTED" in str(err):
        stream = None
      else:
        raise
  dlpack = external_array.__dlpack__(stream=stream)

  try:
    arr = _jax.dlpack_managed_tensor_to_buffer(
      dlpack, dlpack_device, stream, copy)
  except xla_client.XlaRuntimeError as e:
    se = str(e)
    if "is not aligned to" in se:
      i = se.index("is not aligned to")
      raise ValueError(
        "Specified input which requires a copy since the source data "
        f"buffer {se[i:]} However copy=False. Set copy=True or "
        "copy=None to perform the requested operation."
      )
    else:
      raise
  # TODO(phawkins): when we are ready to support x64 arrays in
  # non-x64 mode, change the semantics to not canonicalize here.
  arr = jnp.asarray(arr, dtype=dtypes.canonicalize_dtype(arr.dtype))
  if copy:
    # copy was already handled by dlpack_managed_tensor_to_buffer.
    copy = None
  return _place_array(arr, device, dlpack_device, copy)
