# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.interpreters import batching as _src_batching

from jax._src.interpreters.batching import (
  axis_primitive_batchers as axis_primitive_batchers,
  bdim_at_front as bdim_at_front,
  broadcast as broadcast,
  defbroadcasting as defbroadcasting,
  defreducer as defreducer,
  defvectorized as defvectorized,
  fancy_primitive_batchers as fancy_primitive_batchers,
  not_mapped as not_mapped,
  primitive_batchers as primitive_batchers,
  register_vmappable as register_vmappable,
  unregister_vmappable as unregister_vmappable,
)


# Table of deprecated module attributes. Each key is the attribute name as
# accessed on this module, and each value is a pair of
# (deprecation message, underlying object from jax._src.interpreters.batching).
# The table is handed to deprecation_getattr further down to build the
# module-level __getattr__, so every key must exactly match the public name
# it is meant to serve.
_deprecations = {
  # Deprecated for JAX v0.7.1; finalize in JAX v0.9.0.
  "AxisSize": (
    "jax.interpreters.batching.AxisSize is deprecated.",
    _src_batching.AxisSize,
  ),
  "Array": (
    "jax.interpreters.batching.Array is deprecated. Use jax.Array directly.",
    _src_batching.Array,
  ),
  "BatchTrace": (
    "jax.interpreters.batching.BatchTrace is deprecated.",
    _src_batching.BatchTrace,
  ),
  "BatchTracer": (
    "jax.interpreters.batching.BatchTracer is deprecated.",
    _src_batching.BatchTracer,
  ),
  "BatchingRule": (
    "jax.interpreters.batching.BatchingRule is deprecated.",
    _src_batching.BatchingRule,
  ),
  "Elt": (
    "jax.interpreters.batching.Elt is deprecated.",
    _src_batching.Elt,
  ),
  "FromEltHandler": (
    "jax.interpreters.batching.FromEltHandler is deprecated.",
    _src_batching.FromEltHandler,
  ),
  "GetIdx": (
    "jax.interpreters.batching.GetIdx is deprecated.",
    _src_batching.GetIdx,
  ),
  "MakeIotaHandler": (
    "jax.interpreters.batching.MakeIotaHandler is deprecated.",
    _src_batching.MakeIotaHandler,
  ),
  "MapSpec": (
    "jax.interpreters.batching.MapSpec is deprecated.",
    _src_batching.MapSpec,
  ),
  "NotMapped": (
    "jax.interpreters.batching.NotMapped is deprecated.",
    _src_batching.NotMapped,
  ),
  "ToEltHandler": (
    "jax.interpreters.batching.ToEltHandler is deprecated.",
    _src_batching.ToEltHandler,
  ),
  "Vmappable": (
    "jax.interpreters.batching.Vmappable is deprecated.",
    _src_batching.Vmappable,
  ),
  # The key is "Zero" (not "Zeros") so that it matches the attribute named in
  # the message and the object it forwards to.
  "Zero": (
    "jax.interpreters.batching.Zero is deprecated. Use jax.interpreters.ad.Zero.",
    _src_batching.Zero,
  ),
  "ZeroIfMapped": (
    "jax.interpreters.batching.ZeroIfMapped is deprecated. It is an internal type.",
    _src_batching.ZeroIfMapped,
  ),
  "batch": (
    "jax.interpreters.batching.batch is deprecated. It is an internal API.",
    _src_batching.batch,
  ),
  "batch_custom_jvp_subtrace": (
    "jax.interpreters.batching.batch_custom_jvp_subtrace is deprecated. It is an internal API.",
    _src_batching.batch_custom_jvp_subtrace,
  ),
  "batch_custom_vjp_bwd": (
    "jax.interpreters.batching.batch_custom_vjp_bwd is deprecated. It is an internal API.",
    _src_batching.batch_custom_vjp_bwd,
  ),
  "batch_jaxpr": (
    "jax.interpreters.batching.batch_jaxpr is deprecated. It is an internal API.",
    _src_batching.batch_jaxpr,
  ),
  "batch_jaxpr_axes": (
    "jax.interpreters.batching.batch_jaxpr_axes is deprecated. It is an internal API.",
    _src_batching.batch_jaxpr_axes,
  ),
  "batch_subtrace": (
    "jax.interpreters.batching.batch_subtrace is deprecated. It is an internal API.",
    _src_batching.batch_subtrace,
  ),
  "broadcast_batcher": (
    "jax.interpreters.batching.broadcast_batcher is deprecated. It is an internal API.",
    _src_batching.broadcast_batcher,
  ),
  "flatten_fun_for_vmap": (
    "jax.interpreters.batching.flatten_fun_for_vmap is deprecated. It is an internal API.",
    _src_batching.flatten_fun_for_vmap,
  ),
  "from_elt": (
    "jax.interpreters.batching.from_elt is deprecated. It is an internal API.",
    _src_batching.from_elt,
  ),
  "from_elt_handlers": (
    "jax.interpreters.batching.from_elt_handlers is deprecated. It is an internal API.",
    _src_batching.from_elt_handlers,
  ),
  "is_vmappable": (
    "jax.interpreters.batching.is_vmappable is deprecated. It is an internal API.",
    _src_batching.is_vmappable,
  ),
  "make_iota": (
    "jax.interpreters.batching.make_iota is deprecated. It is an internal API.",
    _src_batching.make_iota,
  ),
  "make_iota_handlers": (
    "jax.interpreters.batching.make_iota_handlers is deprecated. It is an internal API.",
    _src_batching.make_iota_handlers,
  ),
  "matchaxis": (
    "jax.interpreters.batching.matchaxis is deprecated. It is an internal API.",
    _src_batching.matchaxis,
  ),
  "moveaxis": (
    "jax.interpreters.batching.moveaxis is deprecated. Use jax.numpy.moveaxis.",
    _src_batching.moveaxis,
  ),
  "reducer_batcher": (
    "jax.interpreters.batching.reducer_batcher is deprecated. It is an internal API.",
    _src_batching.reducer_batcher,
  ),
  "spec_types": (
    "jax.interpreters.batching.spec_types is deprecated. It is an internal API.",
    _src_batching.spec_types,
  ),
  "to_elt": (
    "jax.interpreters.batching.to_elt is deprecated. It is an internal API.",
    _src_batching.to_elt,
  ),
  "to_elt_handlers": (
    "jax.interpreters.batching.to_elt_handlers is deprecated. It is an internal API.",
    _src_batching.to_elt_handlers,
  ),
  "vectorized_batcher": (
    "jax.interpreters.batching.vectorized_batcher is deprecated. It is an internal API.",
    _src_batching.vectorized_batcher,
  ),
  "vmappables": (
    "jax.interpreters.batching.vmappables is deprecated. It is an internal API.",
    _src_batching.vmappables,
  ),
  "vtile": (
    "jax.interpreters.batching.vtile is deprecated. It is an internal API.",
    _src_batching.vtile,
  ),
  "zero_if_mapped": (
    "jax.interpreters.batching.zero_if_mapped is deprecated. It is an internal API.",
    _src_batching.zero_if_mapped,
  ),
}


# Two views of the deprecated names:
#   * Static type checkers take the TYPE_CHECKING branch and see each name as a
#     plain alias of the corresponding object in jax._src.interpreters.batching.
#   * At runtime the else branch runs instead: no aliases are bound, and
#     attribute access falls through to a module-level __getattr__ (PEP 562)
#     built from the _deprecations table.
import typing as _typing
if _typing.TYPE_CHECKING:
  Array = _src_batching.Array
  AxisSize = _src_batching.AxisSize
  BatchTrace = _src_batching.BatchTrace
  BatchTracer = _src_batching.BatchTracer
  BatchingRule = _src_batching.BatchingRule
  Elt = _src_batching.Elt
  FromEltHandler = _src_batching.FromEltHandler
  GetIdx = _src_batching.GetIdx
  MakeIotaHandler = _src_batching.MakeIotaHandler
  MapSpec = _src_batching.MapSpec
  NotMapped = _src_batching.NotMapped
  ToEltHandler = _src_batching.ToEltHandler
  Vmappable = _src_batching.Vmappable
  Zero = _src_batching.Zero
  ZeroIfMapped = _src_batching.ZeroIfMapped
  batch = _src_batching.batch
  batch_custom_jvp_subtrace = _src_batching.batch_custom_jvp_subtrace
  batch_custom_vjp_bwd = _src_batching.batch_custom_vjp_bwd
  batch_jaxpr = _src_batching.batch_jaxpr
  batch_jaxpr_axes = _src_batching.batch_jaxpr_axes
  batch_subtrace = _src_batching.batch_subtrace
  broadcast_batcher = _src_batching.broadcast_batcher
  flatten_fun_for_vmap = _src_batching.flatten_fun_for_vmap
  from_elt = _src_batching.from_elt
  from_elt_handlers = _src_batching.from_elt_handlers
  is_vmappable = _src_batching.is_vmappable
  make_iota = _src_batching.make_iota
  make_iota_handlers = _src_batching.make_iota_handlers
  matchaxis = _src_batching.matchaxis
  moveaxis = _src_batching.moveaxis
  reducer_batcher = _src_batching.reducer_batcher
  spec_types = _src_batching.spec_types
  to_elt = _src_batching.to_elt
  to_elt_handlers = _src_batching.to_elt_handlers
  vectorized_batcher = _src_batching.vectorized_batcher
  vmappables = _src_batching.vmappables
  vtile = _src_batching.vtile
  zero_if_mapped = _src_batching.zero_if_mapped
else:
  # Imported here rather than at the top of the file so the helper is only
  # bound on the runtime path and can be removed again right after use.
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  # Install the module-level __getattr__ produced from this module's name and
  # the _deprecations table.
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  # Remove the helper so it does not linger as a module attribute.
  del _deprecation_getattr
# Likewise remove the typing alias from the module namespace.
del _typing
