# Copyright 2017 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An XLA client in Python."""

from __future__ import annotations

import atexit
from collections.abc import Mapping
import contextlib
import enum
import logging
import os
import threading
from typing import Any, Protocol, Union

from jaxlib import _jax as _xla

# Note this module does *not* depend on any Python protocol buffers. The XLA
# Python bindings are currently packaged both as part of jaxlib and as part
# of TensorFlow. If we use protocol buffers here, then importing both jaxlib
# and TensorFlow may fail with duplicate protocol buffer message definitions.

# Most functions are snake_case for consistency with other modules, some
# method names are CamelCase for consistency with XLA.
# pylint: disable=invalid-name

# Pylint has false positives for type annotations.
# pylint: disable=invalid-sequence-index

# Re-export `ifrt_programs` from the C++ extension module.
ifrt_programs = _xla.ifrt_programs

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.jaxlib_extension_version.
#
# Please suffix the version number with a brief description of your change
# in a comment. The goal here is to force a merge conflict if two changes
# attempt to grab the same version number.
_version = 391  # ResultHandler.pre_wrap

# An internal increasing version number for protecting jaxlib code against
# ifrt changes. The number itself lives in xla/python/version.h.
# In JAX, reference this via jax._src.lib.ifrt_version.
_ifrt_version = _xla.ifrt_version_number

# Maps platform names as passed by callers to the platform names used by XLA.
# The registration helpers in this module look names up here and fall back to
# the name they were given when it is not a key.
xla_platform_names = {
    'cpu': 'Host',
    'gpu': 'CUDA',
}

logger = logging.getLogger(__name__)

# Type of option mappings (e.g. PJRT plugin options): option name to a string,
# int, list of ints, float, or bool.
_NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]]


def make_cpu_client(
    asynchronous=True,
    distributed_client=None,
    node_id=0,
    num_nodes=1,
    collectives=None,
    num_devices=None,
    get_local_topology_timeout_minutes=None,
    get_global_topology_timeout_minutes=None,
    transfer_server_factory=None,
) -> Client:
  """Creates a CPU client through `_xla.get_tfrt_cpu_client`.

  The CPU custom call handler and custom type handler are installed first
  (a no-op for a handler that is already registered). Every argument is then
  forwarded unchanged, by keyword, to `_xla.get_tfrt_cpu_client`.

  Returns:
    Whatever `_xla.get_tfrt_cpu_client` returns.
  """
  # Install the CPU handlers before the client is constructed.
  register_custom_call_handler('cpu', _xla.register_custom_call_target)
  register_custom_type_handler('cpu', _xla.register_custom_type)

  client_kwargs = {
      'asynchronous': asynchronous,
      'distributed_client': distributed_client,
      'node_id': node_id,
      'num_nodes': num_nodes,
      'collectives': collectives,
      'num_devices': num_devices,
      'get_local_topology_timeout_minutes': get_local_topology_timeout_minutes,
      'get_global_topology_timeout_minutes': (
          get_global_topology_timeout_minutes
      ),
      'transfer_server_factory': transfer_server_factory,
  }
  return _xla.get_tfrt_cpu_client(**client_kwargs)


# Re-exports from the C++ extension for working with device topologies.
DeviceTopology = _xla.DeviceTopology
get_topology_for_devices = _xla.get_topology_for_devices


def make_tfrt_tpu_c_api_device_topology(
    topology_name: str = '', **kwargs
) -> DeviceTopology:
  """Creates the default PJRT C API TopologyDescription for 'tpu'.

  Args:
    topology_name: name of the topology, passed through to the extension.
    **kwargs: extra options; they are handed to the extension as a new dict.

  Returns:
    The result of `_xla.get_default_c_api_topology`.
  """
  topology_options = dict(kwargs)
  return _xla.get_default_c_api_topology(
      'tpu', topology_name, topology_options
  )


def make_c_api_device_topology(
    c_api: Any, topology_name: str = '', **kwargs
) -> DeviceTopology:
  """Creates a PJRT C API TopologyDescription for the given C API object.

  Args:
    c_api: the C API object, passed through to the extension.
    topology_name: name of the topology, passed through to the extension.
    **kwargs: extra options; they are handed to the extension as a new dict.

  Returns:
    The result of `_xla.get_c_api_topology`.
  """
  topology_options = dict(kwargs)
  return _xla.get_c_api_topology(c_api, topology_name, topology_options)


def pjrt_plugin_loaded(plugin_name: str) -> bool:
  """Returns whether the extension reports the PJRT plugin as loaded.

  Thin wrapper around `_xla.pjrt_plugin_loaded`.

  Args:
    plugin_name: the name of the PJRT plugin.
  """
  return _xla.pjrt_plugin_loaded(plugin_name)


def load_pjrt_plugin_dynamically(plugin_name: str, library_path: str) -> Any:
  """Loads a PJRT plugin from a library path.

  Calls `_xla.load_pjrt_plugin` with `c_api=None` and returns its result.

  Args:
    plugin_name: the name of the PJRT plugin.
    library_path: path of the plugin library, passed through to the extension.
  """
  return _xla.load_pjrt_plugin(plugin_name, library_path, c_api=None)


def load_pjrt_plugin_with_c_api(plugin_name: str, c_api: Any) -> None:
  """Loads a PJRT plugin from a C API object instead of a library path.

  Calls `_xla.load_pjrt_plugin` with no library path; its result is discarded.

  Args:
    plugin_name: the name of the PJRT plugin.
    c_api: the C API object, passed through to the extension.
  """
  _xla.load_pjrt_plugin(plugin_name, None, c_api)


def pjrt_plugin_initialized(plugin_name: str) -> bool:
  """Returns whether the extension reports the PJRT plugin as initialized.

  Thin wrapper around `_xla.pjrt_plugin_initialized`.

  Args:
    plugin_name: the name of the PJRT plugin.
  """
  return _xla.pjrt_plugin_initialized(plugin_name)


def initialize_pjrt_plugin(plugin_name: str) -> None:
  """Initializes a PJRT plugin.

  The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
  static linking) before this method is called.

  Args:
    plugin_name: the name of the PJRT plugin.
  """
  _xla.initialize_pjrt_plugin(plugin_name)


def make_c_api_client(
    plugin_name: str,
    options: _NameValueMapping | None = None,
    distributed_client: _xla.DistributedRuntimeClient | None = None,
    transfer_server_factory: _xla.TransferServerInterfaceFactory | None = None,
):
  """Builds a PJRT C API client for a PJRT plugin.

  load_pjrt_plugin_dynamically must have been called once with the same
  plugin_name before this function is called.

  Args:
     plugin_name: the name of the PJRT plugin.
     options: extra platform-specific options; None is treated as an empty
       dict.
     distributed_client: distributed client.
     transfer_server_factory: passed through to the extension unchanged.

  Returns:
     A PJRT C API client for plugin_name.
  """
  effective_options = options if options is not None else {}
  return _xla.get_c_api_client(
      plugin_name,
      effective_options,
      distributed_client,
      transfer_server_factory,
  )


def generate_pjrt_gpu_plugin_options() -> _NameValueMapping:
  """Builds the PjRt GPU plugin options from environment variables.

  Environment variables read:
    XLA_PYTHON_CLIENT_ALLOCATOR: allocator name (case-insensitive); one of
      "default", "platform", "bfc" or "cuda_async". Defaults to "default".
    XLA_CLIENT_MEM_FRACTION: memory fraction, parsed as a float.
    XLA_PYTHON_CLIENT_MEM_FRACTION: deprecated spelling of the above; used
      only when the new variable is unset.
    XLA_PYTHON_CLIENT_PREALLOCATE: any non-empty value other than "false",
      "False" or "0" enables preallocation.
    XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB: integer number of MiB, stored
      in bytes.
    XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE: integer flag, default 0.
    XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT: integer flag, default 0.

  Returns:
    A dictionary of plugin options. Options backed by an unset (or empty)
    optional variable are omitted.

  Raises:
    ValueError: if both memory fraction variables are set, if the allocator
      name is not recognized, or if a numeric variable cannot be parsed.
  """
  allocator_kind = os.getenv('XLA_PYTHON_CLIENT_ALLOCATOR', 'default').lower()

  # Resolve the memory fraction, accepting the deprecated variable only when
  # the current one is absent.
  mem_fraction = os.getenv('XLA_CLIENT_MEM_FRACTION', '')
  legacy_mem_fraction = os.getenv('XLA_PYTHON_CLIENT_MEM_FRACTION', '')
  if legacy_mem_fraction and mem_fraction:
    raise ValueError(
        'XLA_CLIENT_MEM_FRACTION is specified together '
        'with XLA_PYTHON_CLIENT_MEM_FRACTION. '
        'Remove the latter one, it is deprecated.'
    )
  mem_fraction = mem_fraction or legacy_mem_fraction

  preallocate_flag = os.getenv('XLA_PYTHON_CLIENT_PREALLOCATE', '')
  collective_mem_mb = os.getenv('XLA_PYTHON_CLIENT_COLLECTIVE_MEM_SIZE_MB', '')

  known_allocators = ('default', 'platform', 'bfc', 'cuda_async')
  if allocator_kind not in known_allocators:
    raise ValueError(
        'XLA_PYTHON_CLIENT_ALLOCATOR env var must be "default", "platform", '
        f'"bfc", or "cuda_async", got "{allocator_kind}"'
    )

  plugin_options: dict[str, Any] = {
      'platform_name': 'cuda',
      'allocator': allocator_kind,
  }
  if mem_fraction:
    plugin_options['memory_fraction'] = float(mem_fraction)
  if preallocate_flag:
    plugin_options['preallocate'] = preallocate_flag not in (
        'false',
        'False',
        '0',
    )
  if collective_mem_mb:
    bytes_per_mib = 1 << 20
    plugin_options['collective_memory_size'] = (
        int(collective_mem_mb) * bytes_per_mib
    )

  # Integer-valued on/off switches that are always present in the result.
  integer_switches = (
      (
          'abort_collectives_on_failure',
          'XLA_PYTHON_CLIENT_ABORT_COLLECTIVES_ON_FAILURE',
      ),
      ('use_tfrt_gpu_client', 'XLA_PYTHON_CLIENT_USE_TFRT_GPU_CLIENT'),
  )
  for option_key, env_name in integer_switches:
    plugin_options[option_key] = bool(int(os.getenv(env_name, '0')))
  return plugin_options


# Re-export of the extension's PrimitiveType.
PrimitiveType = _xla.PrimitiveType

# Re-export of the extension's Shape class. Its docstring is replaced below
# with a Python-style sketch of the API that the C++ object exposes.
Shape = _xla.Shape
Shape.__doc__ = """
A Shape is an object defined in C++ that duck types like the following class:

class Shape:
  '''Represents an XLA shape.

  A shape is either an array shape, having rank-many integer
  dimensions and an element type (represented by a Numpy dtype), or it
  is a tuple shape, having a shape for every tuple component:

    type shape =
        TupleShape of shape list
      | ArrayShape of { dimensions: int list; element_type: dtype }
  '''

  @staticmethod
  def tuple_shape(tuple_shapes) -> Shape:
    "Construct a tuple shape."

  @staticmethod
  def array_shape(element_type, dimensions, minor_to_major=None) -> Shape:

  @staticmethod
  def from_pyval(pyval) -> Shape:
    "Returns a Shape that describes a tuple-tree of Numpy arrays."

  def __init__(self, str) -> Shape:
    "Parses a shape string."
  def __eq__(self, other: Shape) -> bool:
  def __ne__(self, other: Shape) -> bool:
  def __hash__(self):
  def __repr__(self):
  def is_tuple(self) -> bool:
  def is_array(self) -> bool:
  def tuple_shapes(self) -> [Shape]:
  def numpy_dtype(self) -> np.dtype:
    "Like element_type(), but returns dtype('O') for a tuple shape."
  def xla_element_type(self) -> PrimitiveType:
  def element_type(self) -> np.dtype:
  def dimensions(self) -> (int, int, ...):
  def rank(self) -> int:
  def with_major_to_minor_layout_if_absent(self) -> Shape:
    "Returns a copy with missing layouts set to major-to-minor."

  def to_serialized_proto(self) -> bytes:
    "Returns 'shape' as a serialized proto."
"""

# Re-export of the extension's ProgramShape class. Its docstring is replaced
# with a Python-style sketch of the API that the C++ object exposes.
ProgramShape = _xla.ProgramShape
ProgramShape.__doc__ = """
A ProgramShape is a C++ object that duck types like the following class.

class ProgramShape:
  def __init__(self, parameter_shapes, result_shape):
  def parameter_shapes(self) -> [Shape]:
  def result_shape(self) -> Shape:
  def __repr__(self):
"""

# Re-export of the extension's DeviceAssignment class. Its docstring is
# replaced with a sketch of the functions that the C++ object exposes.
DeviceAssignment = _xla.DeviceAssignment
DeviceAssignment.__doc__ = """
A DeviceAssignment is a C++ object with the following signature.

def create(assignment):
  '''Builds a device assignment.

   Args:
     assignment: a 2D numpy array of device ordinal integers, indexed by
       [replica][computation_in_replica].
   Returns:
     A device assignment.
  '''

def replica_count():
  '''Returns the number of replicas.'''
def computation_count():
  '''Returns the number of computations per replica.'''
"""

# Further re-exports from the C++ extension.
Device = _xla.Device
CompileOptions = _xla.CompileOptions

HostBufferSemantics = _xla.HostBufferSemantics

# An Executable is a C++ class that duck types with the following API:
# class Executable:
#   def local_devices(self) -> [Device]:
#   def execute(self, arguments : [Buffer]) -> Buffer:
#     """Execute on one replica with Buffer arguments and return value."""
#
#   def size_of_generated_code_in_bytes(self) -> int:
#     """Return generated binary size, or -1 if not known."""
#
#   def execute_sharded_on_local_devices(self, arguments: [[Buffer]])
#       -> [Buffer]:
#     """Execute on many replicas with Buffer arguments and return value.
#
#     Args:
#       arguments: A sequence of sequences of Buffers. The i'th element of each
#         sequence comprises the arguments for execution on the i'th local
#         device.
#
#     Returns:
#       A list of the computation's outputs as a list of Buffers for each
#       device.
#     """
#
# There are different implementations of Executable for different backends.


# Re-exports of client, array, executable, sharding and layout types from the
# C++ extension, so that callers can reach them through this module.
XlaComputation = _xla.XlaComputation
Client = _xla.Client
Memory = _xla.Memory
Array = _xla.Array
ArrayImpl = _xla.ArrayImpl
LoadedExecutable = _xla.LoadedExecutable
Executable = _xla.Executable
DeviceList = _xla.DeviceList
OpSharding = _xla.OpSharding
HloSharding = _xla.HloSharding
Sharding = _xla.Sharding
NamedSharding = _xla.NamedSharding
SingleDeviceSharding = _xla.SingleDeviceSharding
PmapSharding = _xla.PmapSharding
GSPMDSharding = _xla.GSPMDSharding
PjRtLayout = _xla.PjRtLayout
AutotuneCacheMode = _xla.AutotuneCacheMode


def LoadedExecutable_execute(self, arguments, device=None):
  """Runs `execute_sharded` and returns the first array of every result entry.

  Args:
    arguments: passed unchanged to `self.execute_sharded`.
    device: accepted but ignored.

  Returns:
    A list holding element 0 of each entry produced by
    `disassemble_into_single_device_arrays()` on the execution results.
  """
  del device  # Unused.
  sharded_results = self.execute_sharded(arguments)
  per_entry_arrays = sharded_results.disassemble_into_single_device_arrays()
  return [arrays[0] for arrays in per_entry_arrays]


def LoadedExecutable_execute_with_token(self, arguments, device=None):
  """Like the `execute` helper, but also returns the execution token.

  Args:
    arguments: passed unchanged to `self.execute_sharded`, which is called
      with `with_tokens=True`.
    device: accepted but ignored.

  Returns:
    A pair `(outputs, token)`: `outputs` is a list holding element 0 of each
    entry produced by `disassemble_into_single_device_arrays()`, and `token`
    is `consume_token().get_token(0)` of the execution results.
  """
  del device  # Unused.
  sharded_results = self.execute_sharded(arguments, with_tokens=True)
  # The arrays are extracted before the token is consumed.
  outputs = [
      arrays[0]
      for arrays in sharded_results.disassemble_into_single_device_arrays()
  ]
  token = sharded_results.consume_token().get_token(0)
  return (outputs, token)


# Install the Python helpers LoadedExecutable_execute and
# LoadedExecutable_execute_with_token as methods of the extension's
# LoadedExecutable class.
LoadedExecutable.execute = LoadedExecutable_execute  # type: ignore[method-assign]
LoadedExecutable.execute_with_token = LoadedExecutable_execute_with_token  # type: ignore[method-assign]


class CustomCallTargetTraits(enum.IntFlag):
  """Bit flags describing traits of a custom call target."""

  # No traits set.
  DEFAULT = 0
  # Calls to custom call are safe to trace into the command buffer. It means
  # that calls to custom call always launch exactly the same device operations
  # (can depend on attribute values) that can be captured and then replayed.
  #
  # Supported only for custom calls implemented with XLA FFI.
  COMMAND_BUFFER_COMPATIBLE = 1


class CustomCallHandler(Protocol):
  """Structural type of a callable that registers one custom call target.

  This module invokes handlers as
  `handler(name, fn, xla_platform_name, api_version, traits)`.
  """

  def __call__(
      self,
      name: str,
      fn: Any,
      platform: str,
      /,
      api_version: int = ...,
      traits: CustomCallTargetTraits = ...,
  ) -> None:
    ...


# Registered custom call handlers, keyed by XLA platform name.
_custom_callback_handler: dict[str, CustomCallHandler] = {}
# Custom calls registered before a handler exists for their platform. Key is
# xla_platform_name, value is a list of
# (function_name, function, api_version, traits) tuples.
_custom_callback: dict[
    str, list[tuple[str, Any, int, CustomCallTargetTraits]]
] = {}
# Held by the custom call registration functions while they read or update the
# two dicts above.
_custom_callback_lock = threading.Lock()


def register_custom_call_target(
    name: str,
    fn: Any,
    platform: str = 'cpu',
    api_version: int = 0,
    traits: CustomCallTargetTraits = CustomCallTargetTraits.DEFAULT,
) -> None:
  """Registers a custom call target.

  If a handler is already registered for the platform, the target is passed to
  it immediately. Otherwise the target is queued and handed to the handler
  when one is registered through register_custom_call_handler.

  Args:
    name: the name of the function.
    fn: a PyCapsule object containing the function pointer.
    platform: the target platform.
    api_version: the XLA FFI version to use. Supported versions are: 0 for the
      untyped FFI and 1 for the typed FFI.
    traits: custom call traits corresponding to XLA FFI handler traits.
  """
  # To support AMD GPUs, we need to have xla_platform_names["gpu"] == "ROCM"
  # Since that is hardcoded to CUDA, we are using the following as workaround.
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_callback_lock:
    try:
      handler = _custom_callback_handler[xla_platform_name]
    except KeyError:
      # No handler yet: remember the target for later replay.
      pending = _custom_callback.setdefault(xla_platform_name, [])
      pending.append((name, fn, api_version, traits))
    else:
      handler(name, fn, xla_platform_name, api_version, traits)


def register_custom_call_handler(
    platform: str, handler: CustomCallHandler
) -> None:
  """Installs a custom call handler and replays queued custom calls through it.

  If a custom call handler for the platform already exists, calling this
  function is a no-op and no new handler is registered.

  Args:
    platform: the target platform.
    handler: the function used to register a custom call.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_callback_lock:
    if xla_platform_name in _custom_callback_handler:
      logger.debug(
          'Custom call handler for %s is already register. Will not register a'
          ' new one',
          xla_platform_name,
      )
      return

    _custom_callback_handler[xla_platform_name] = handler
    if xla_platform_name not in _custom_callback:
      return
    # Hand over every target that was registered before the handler existed;
    # the queue entry is dropped only after all of them have been replayed.
    for queued in _custom_callback[xla_platform_name]:
      queued_name, queued_fn, queued_version, queued_traits = queued
      handler(
          queued_name,
          queued_fn,
          xla_platform_name,
          queued_version,
          queued_traits,
      )
    del _custom_callback[xla_platform_name]


class CustomTypeIdHandler(Protocol):
  """Structural type of a callable that registers one custom type id.

  This module invokes handlers as `handler(type_name, type_id)`; both
  parameters are positional-only.
  """

  def __call__(self, type_name: str, type_id: Any, /) -> None:
    ...


# Registered custom type id handlers, keyed by XLA platform name.
_custom_type_id_handler: dict[str, CustomTypeIdHandler] = {}
# Type ids registered before a handler exists for their platform. Key is
# xla_platform_name, value is a list of (type_name, type_id) pairs.
_custom_type_id: dict[str, list[tuple[str, Any]]] = {}
# Held by register_custom_type while it reads or updates the two dicts above.
_custom_type_id_lock = threading.Lock()


def register_custom_type(
    type_name: str,
    type_id: Any,
    platform: str = 'cpu',
) -> None:
  """Registers a custom type id for use with the FFI.

  If a type id handler is already registered for the platform, the type id is
  passed to it immediately. Otherwise it is queued and handed to the handler
  when one is registered through register_custom_type_handler.

  Args:
    type_name: a unique name for the type.
    type_id: a PyCapsule object containing a pointer to the ``ffi::TypeId``.
    platform: the target platform.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  with _custom_type_id_lock:
    try:
      handler = _custom_type_id_handler[xla_platform_name]
    except KeyError:
      # No handler yet: remember the type id for later replay.
      pending = _custom_type_id.setdefault(xla_platform_name, [])
      pending.append((type_name, type_id))
    else:
      handler(type_name, type_id)


def register_custom_type_handler(
    platform: str, handler: CustomTypeIdHandler
) -> None:
  """Registers a custom type id handler and replays queued type ids through it.

  If a custom type id handler for the platform already exists, calling this
  function is a no-op and no new handler is registered.

  Args:
    platform: the target platform.
    handler: the function used to register a custom type id; it is called as
      `handler(type_name, type_id)`.
  """
  xla_platform_name = xla_platform_names.get(platform, platform)
  # Take the same lock that register_custom_type() holds. The handler table
  # and the pending queue are shared with that function; guarding them with a
  # different lock would let the two interleave, so that a type id could be
  # appended to the queue after this function has already drained it and
  # would then never reach the handler.
  with _custom_type_id_lock:
    if xla_platform_name in _custom_type_id_handler:
      logger.debug(
          'Custom type id handler for %s is already register. Will not '
          'register a new one',
          xla_platform_name,
      )
      return
    _custom_type_id_handler[xla_platform_name] = handler
    if xla_platform_name in _custom_type_id:
      # Hand over every type id that was registered before the handler
      # existed, then drop the queue.
      for name, capsule in _custom_type_id[xla_platform_name]:
        handler(name, capsule)
      del _custom_type_id[xla_platform_name]


# Re-exports of custom call partitioning and sharding helpers from the C++
# extension.
register_custom_call_partitioner = _xla.register_custom_call_partitioner
encode_inspect_sharding_callback = _xla.encode_inspect_sharding_callback
hlo_sharding_util = _xla.hlo_sharding_util
register_custom_call_as_batch_partitionable = (
    _xla.register_custom_call_as_batch_partitionable
)


# Re-exports of the extension's traceback types.
Traceback = _xla.Traceback
Frame = _xla.Frame


@contextlib.contextmanager
def execution_stream_id(new_id: int):
  """Context manager that temporarily sets the current thread's execution_stream_id.

  The id that was in effect on entry is put back on exit, including when the
  body raises.

  Args:
    new_id: the execution stream id to use inside the `with` block.
  """
  previous_id = _xla.get_execution_stream_id()
  _xla.set_execution_stream_id(new_id)
  try:
    yield
  finally:
    # Always restore the id captured on entry.
    _xla.set_execution_stream_id(previous_id)


# The extension's JaxRuntimeError, exposed here under the name XlaRuntimeError.
XlaRuntimeError = _xla.JaxRuntimeError

# Perform one last garbage collection of deferred Python references. This is
# mostly to keep ASAN happy.
atexit.register(_xla.collect_garbage)

# Re-exports of array, transfer and layout helpers from the C++ extension.
array_result_handler = _xla.array_result_handler
batched_copy_array_to_devices_with_sharding = (
    _xla.batched_copy_array_to_devices_with_sharding
)
batched_device_put = _xla.batched_device_put
reorder_shards = _xla.reorder_shards
batched_block_until_ready = _xla.batched_block_until_ready
check_and_canonicalize_memory_kind = _xla.check_and_canonicalize_memory_kind
Layout = _xla.Layout
custom_call_targets = _xla.custom_call_targets
ArrayCopySemantics = _xla.ArrayCopySemantics