# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from typing import Any

import numpy as np

from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import literals
from jax._src import tree_util
from jax._src import xla_bridge
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.sharding import Sharding
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P


export = util.set_module('jax.numpy')

# Try to locate the CUDA plugin extension; it is used below when ingesting
# objects that expose __cuda_array_interface__.
for pkg_name in ['jax_cuda13_plugin', 'jax_cuda12_plugin', 'jaxlib.cuda']:
  try:
    cuda_plugin_extension = importlib.import_module(
        f'{pkg_name}.cuda_plugin_extension'
    )
  except ImportError:
    cuda_plugin_extension = None  # type: ignore
  else:
    break


def _supports_buffer_protocol(obj):
  """Return True if ``obj`` exposes the Python buffer protocol."""
  try:
    memoryview(obj)
  except TypeError:
    return False
  else:
    return True


def _make_string_array(
    object: np.ndarray,
    dtype: DTypeLike | None = None,
    ndmin: int = 0,
    device: xc.Device | Sharding | None = None,
) -> Array:
  if not isinstance(object, np.ndarray):
    raise TypeError(
        "Currently, string arrays can only be made from NumPy"
        f" arrays. Got:  {type(object)}."
    )
  if dtype is not None and (
      dtypes.is_string_dtype(object.dtype) != dtypes.is_string_dtype(dtype)
  ):
    raise TypeError(
        f"Cannot make an array with dtype {dtype} from an object with dtype"
        f" {object.dtype}."
    )
  if ndmin > object.ndim:
    raise TypeError(
        f"ndmin {ndmin} cannot be greater than the number of dimensions"
        f" of the input ({object.ndim}) for string arrays."
    )

  # Just do a device_put since XLA does not support string as a data type.
  return api.device_put(x=object, device=device)


@export
def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
          order: str | None = "K", ndmin: int = 0,
          *, device: xc.Device | Sharding | None = None,
          out_sharding: NamedSharding | P | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.array`.

  Args:
    object: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with a ``__jax_array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optionally specify the dtype of the output array. If not
      specified it will be inferred from the input.
    copy: specify whether to force a copy of the input. Default: True.
    order: not implemented in JAX; only the default ``order="K"`` (or ``None``)
      is accepted.
    ndmin: integer specifying the minimum number of dimensions in the
      output array.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.
    out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or :class:`~jax.sharding.NamedSharding`
      representing the sharding of the created array (see `explicit sharding`_ for more details).
      This argument exists for consistency with other array creation routines across JAX.
      Specifying both ``out_sharding`` and ``device`` will result in an error.

  Returns:
    A JAX array constructed from the input.

  See also:
    - :func:`jax.numpy.asarray`: like `array`, but by default only copies
      when necessary.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.array(True)
    Array(True, dtype=bool)
    >>> jnp.array(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.array(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.array(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.array([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.array([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.array(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.array(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.array(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
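
    Leading dimensions of size 1 can be added with the ``ndmin`` argument:

    >>> jnp.array([1, 2, 3], ndmin=3).shape
    (1, 1, 3)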

  .. _explicit sharding: https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html
  """
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")

  # check if the given dtype is compatible with JAX
  if dtype is not None:
    dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "array")

  # Here we make a judgment call: we only return a weakly-typed array when the
  # input object itself is weakly typed. That ensures asarray(x) is a no-op
  # whenever x is weak, but avoids introducing weak types with something like
  # array([1, 2, 3])
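  # For example: jnp.asarray(42) is weakly typed because the Python int 42 is
  # itself weak, while jnp.array([1, 2, 3]) is not (see the docstring above).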
  weak_type = dtype is None and dtypes.is_weakly_typed(object)

  if device is None and out_sharding is None and isinstance(object, core.Tracer):
    sharding = object.aval.sharding
    sharding = None if sharding.mesh.empty else sharding
  else:
    sharding = util.choose_device_or_out_sharding(device, out_sharding, "jnp.array")

  # Use device_put to avoid a copy for ndarray inputs.
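  # (On CPU backends, device_put can often alias the NumPy buffer rather than
  # copying it, which is what copy=False requests; other backends still need to
  # transfer the data to their device.)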
  if (not copy and isinstance(object, np.ndarray) and
      (dtype is None or dtype == object.dtype) and (ndmin <= object.ndim) and
      device is None):
    if dtype is not None:
      # If there is an explicit dtype, we've already canonicalized things and
      # device_put should not canonicalize again.
      object = literals.TypedNdArray(object, weak_type=False)
    # Keep the output uncommitted.
    return api.device_put(object)

  # String arrays need separate handling because XLA does not support string
  # as a data type.
  if dtypes.is_string_dtype(dtype) or (
      hasattr(object, "dtype") and dtypes.is_string_dtype(object.dtype)
  ):
    return _make_string_array(
        object=object, dtype=dtype, ndmin=ndmin, device=device
    )

  # For Python scalar literals, call coerce_to_array to catch any overflow
  # errors. We don't use dtypes.is_python_scalar because we don't want this
  # triggering for traced values. We do this here because it matters whether or
  # not dtype is None. We don't assign the result because we want the raw object
  # to be used for type inference below.
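  # (Illustrative: a call like jnp.array(2**100) should raise an OverflowError
  # at this point rather than silently wrapping, since the value fits in no
  # default integer dtype.)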
  if isinstance(object, (bool, int, float, complex)):
    _ = dtypes.coerce_to_array(object, dtype)
  elif not isinstance(object, Array):
    # Check if object supports any of the data exchange protocols
    # (except dlpack, see data-apis/array-api#301). If it does,
    # consume the object as jax array and continue (but not return) so
    # that other array() arguments get processed against the input
    # object.
    #
    # Notice that data exchange protocols define dtype in the
    # corresponding data structures and it may not be available as
    # object.dtype. So, we'll resolve the protocols here before
    # evaluating object.dtype.
    if hasattr(object, '__jax_array__'):
      object = object.__jax_array__()
    elif hasattr(object, '__cuda_array_interface__'):
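      # A CuPy ndarray, for example, exposes __cuda_array_interface__; this
      # branch consumes the device buffer directly (without an explicit host
      # round trip), using the CUDA plugin extension, when available, to map
      # the raw device pointer to a device ordinal.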
      cai = object.__cuda_array_interface__
      backend = xla_bridge.get_backend("cuda")
      if cuda_plugin_extension is None:
        device_id = None
      else:
        device_id = cuda_plugin_extension.get_device_ordinal(cai["data"][0])
      object = xc._xla.cuda_array_interface_to_buffer(
          cai=cai, gpu_backend=backend, device_id=device_id)

  # To handle nested lists & tuples, flatten the tree and process each leaf.
  leaves, treedef = tree_util.tree_flatten(
      object, is_leaf=lambda x: not isinstance(x, (list, tuple)))
  if any(leaf is None for leaf in leaves):
    raise ValueError("None is not a valid value for jnp.array")
  leaves = [
      leaf
      if (leaf_jax_array := getattr(leaf, "__jax_array__", None)) is None
      else leaf_jax_array()
      for leaf in leaves
  ]
  if dtype is None:
    # Use lattice_result_type rather than result_type to avoid canonicalization.
    # Otherwise, weakly-typed inputs would have their dtypes canonicalized.
    try:
      dtype = (
          dtypes.lattice_result_type(*leaves)[0]
          if leaves
          else dtypes.default_float_dtype()
      )
    except TypeError:
      # This happens if, e.g. one of the entries is a memoryview object.
      # This is rare, so we only handle it if the normal path fails.
      leaves = [_convert_to_array_if_dtype_fails(leaf) for leaf in leaves]
      dtype = dtypes.lattice_result_type(*leaves)[0]

  object = treedef.unflatten(leaves)
  out: ArrayLike
  if all(not isinstance(leaf, Array) for leaf in leaves):
    # TODO(jakevdp): falling back to numpy here fails to raise on overflow for
    # lists containing large integers; see discussion in
    # https://github.com/jax-ml/jax/pull/6047. More correct would be to call
    # coerce_to_array on each leaf, but this may have performance implications.
    out = np.asarray(object, dtype=dtype)
  elif isinstance(object, Array):
    assert object.aval is not None
    out = lax._array_copy(object) if copy else object
  elif isinstance(object, (list, tuple)):
    if object:
      arrs = (array(elt, dtype=dtype, copy=False) for elt in object)
      arrays_out = [lax.expand_dims(arr, [0]) for arr in arrs]
      # lax.concatenate can be slow to compile for wide concatenations, so form a
      # tree of concatenations as a workaround especially for op-by-op mode.
      # (https://github.com/jax-ml/jax/issues/653).
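      # For example, with 1000 elements the first pass below produces
      # ceil(1000 / 16) = 63 partial concatenations, the second pass produces
      # 4, and the final concatenate joins those 4.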
      k = 16
      while len(arrays_out) > k:
        arrays_out = [lax.concatenate(arrays_out[i:i+k], 0)
                      for i in range(0, len(arrays_out), k)]
      out = lax.concatenate(arrays_out, 0)
    else:
      out = np.array([], dtype=dtype)
  elif _supports_buffer_protocol(object):
    object = memoryview(object)
    # TODO(jakevdp): update this once we support NumPy 2.0 semantics for the copy arg.
    out = np.array(object) if copy else np.asarray(object)
  else:
    raise TypeError(f"Unexpected input type for array: {type(object)}")
  out_array: Array = lax._convert_element_type(
      out, dtype, weak_type=weak_type, sharding=sharding)
  if ndmin > np.ndim(out_array):
    out_array = lax.expand_dims(out_array, range(ndmin - np.ndim(out_array)))
  return out_array


def _get_platform(
    device_or_sharding: xc.Device | Sharding | None | str) -> str:
  """Get device_or_sharding platform or look up config.default_device.value."""
  if isinstance(device_or_sharding, xc.Device):
    return device_or_sharding.platform
  elif isinstance(device_or_sharding, Sharding):
    return list(device_or_sharding.device_set)[0].platform
  elif isinstance(device_or_sharding, str):
    return device_or_sharding
  elif device_or_sharding is None:
    if config.default_device.value is None:
      return xla_bridge.default_backend()
    else:
      return _get_platform(config.default_device.value)
  else:
    raise ValueError(f"`{device_or_sharding = }` was passed to"
                     "`canonicalize_or_get_default_platform`, only xc.Device,"
                     " Sharding, None or str values are supported.")


def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:
  """Convert x to a NumPy array if dtypes.dtype(x) cannot determine its dtype."""
  try:
    dtypes.dtype(x)
  except TypeError:
    return np.asarray(x)
  else:
    return x


@export
def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None,
            *, copy: bool | None = None,
            device: xc.Device | Sharding | None = None,
            out_sharding: NamedSharding | P | None = None) -> Array:
  """Convert an object to a JAX array.

  JAX implementation of :func:`numpy.asarray`.

  Args:
    a: an object that is convertible to an array. This includes JAX
      arrays, NumPy arrays, Python scalars, Python collections like lists
      and tuples, objects with a ``__jax_array__`` method, and objects
      supporting the Python buffer protocol.
    dtype: optionally specify the dtype of the output array. If not
      specified it will be inferred from the input.
    order: not implemented in JAX; only ``order=None`` (the default) or
      ``order="K"`` is accepted.
    copy: optional boolean specifying the copy mode. If True, then always
      return a copy. If False, then error if a copy is necessary. Default is
      None, which will only copy when necessary.
    device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
      to which the created array will be committed.
    out_sharding: (optional) :class:`~jax.sharding.PartitionSpec` or
      :class:`~jax.sharding.NamedSharding` specifying the sharding of the
      created array. Specifying both ``out_sharding`` and ``device`` will
      result in an error.

  Returns:
    A JAX array constructed from the input.

  See also:
    - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`.
    - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object
      that implements the dlpack interface.
    - :func:`jax.numpy.frombuffer`: construct a JAX array from an object
      that implements the buffer interface.

  Examples:
    Constructing JAX arrays from Python scalars:

    >>> jnp.asarray(True)
    Array(True, dtype=bool)
    >>> jnp.asarray(42)
    Array(42, dtype=int32, weak_type=True)
    >>> jnp.asarray(3.5)
    Array(3.5, dtype=float32, weak_type=True)
    >>> jnp.asarray(1 + 1j)
    Array(1.+1.j, dtype=complex64, weak_type=True)

    Constructing JAX arrays from Python collections:

    >>> jnp.asarray([1, 2, 3])  # list of ints -> 1D array
    Array([1, 2, 3], dtype=int32)
    >>> jnp.asarray([(1, 2, 3), (4, 5, 6)])  # list of tuples of ints -> 2D array
    Array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)
    >>> jnp.asarray(range(5))
    Array([0, 1, 2, 3, 4], dtype=int32)

    Constructing JAX arrays from NumPy arrays:

    >>> jnp.asarray(np.linspace(0, 2, 5))
    Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32)

    Constructing a JAX array via the Python buffer interface, using Python's
    built-in :mod:`array` module.

    >>> from array import array
    >>> pybuffer = array('i', [2, 3, 5, 7])
    >>> jnp.asarray(pybuffer)
    Array([2, 3, 5, 7], dtype=int32)
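
    A target ``dtype`` may be specified explicitly:

    >>> jnp.asarray([1, 2, 3], dtype=jnp.float32)
    Array([1., 2., 3.], dtype=float32)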
  """
  # For copy=False, the array API specifies that we raise a ValueError if the
  # input supports the buffer protocol but a copy is required. Since array()
  # supports the buffer protocol via numpy, this is only the case when the
  # target platform is not 'cpu'.
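  # (Illustrative: on a GPU or TPU backend, jnp.asarray(np.arange(3), copy=False)
  # raises this ValueError, because the host buffer must be copied to the device.)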
  if (copy is False and not isinstance(a, Array)
      and _get_platform(device) != "cpu"
      and _supports_buffer_protocol(a)):
    raise ValueError(f"jnp.asarray: cannot convert object of type {type(a)} to JAX Array "
                     f"on platform={_get_platform(device)} with "
                     "copy=False. Consider using copy=None or copy=True instead.")
  if dtype is not None:
    dtype = dtypes.check_and_canonicalize_user_dtype(dtype, "asarray")
  return array(a, dtype=dtype, copy=bool(copy), order=order, device=device,
               out_sharding=out_sharding)
