# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import collections
from collections.abc import Callable, Hashable, Iterable, Sequence
import dataclasses
import difflib
import functools
from functools import partial
import operator as op
import textwrap
from typing import Any, TypeVar

from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.util import safe_zip, set_module
from jax._src.util import unzip2


# Decorator applied to the public functions and classes defined in this file.
# NOTE(review): `set_module` presumably rewrites ``__module__`` so decorated
# objects report ``jax.tree_util`` as their home -- confirm in jax._src.util.
export = set_module('jax.tree_util')

# NOTE(review): presumably hides frames originating in this file from filtered
# user-facing tracebacks -- confirm in jax._src.traceback_util.
traceback_util.register_exclusion(__file__)

# Generic type variables used in the signatures below.
T = TypeVar("T")
Typ = TypeVar("Typ", bound=type[Any])
H = TypeVar("H", bound=Hashable)

# Readability aliases: leaves and pytrees may be arbitrary Python objects, so
# these carry no extra static information beyond documenting intent.
Leaf = Any
PyTree = Any
PyTreeDef = pytree.PyTreeDef

# The registry consulted by every flatten/structure helper in this file.
default_registry = pytree.default_registry()
# Set __module__ and __name__, which allow this registry to be pickled by
# reference.
default_registry.__module__ = __name__
default_registry.__name__ = "default_registry"

# A copy of the default registry, where None is a leaf.
# It is built with ``enable_none=False``; every other built-in container
# (tuple, namedtuple, list, dict) stays enabled.
none_leaf_registry = pytree.PyTreeRegistry(
    enable_none=False, enable_tuple=True, enable_namedtuple=True,
    enable_list=True, enable_dict=True)
none_leaf_registry.__module__ = __name__
none_leaf_registry.__name__ = "none_leaf_registry"

# A special, internal pytree registry that includes everything in
# `default_registry`, plus internal Python-defined types that we want
# to teach the fast dispatch path ("C++ dispatch") how to flatten and
# unflatten. A key example is PRNG key arrays, which are currently a
# Python-defined class (in `jax._src.prng`). These ought to be a leaf
# node everywhere in the system (e.g. in Jaxpr), but we want to unpack
# and repack them across the fast dispatch boundary. If we were to
# skip registering such types here, the fast dispatch path would not
# know how to handle them as arguments. It would instead always
# indicate a "cache miss" and dispatch on the slow path.
dispatch_registry = pytree.PyTreeRegistry(
    enable_none=True, enable_tuple=True, enable_namedtuple=True,
    enable_list=True, enable_dict=True)
dispatch_registry.__module__ = __name__
dispatch_registry.__name__ = "dispatch_registry"


@export
def tree_flatten(tree: Any,
                 is_leaf: Callable[[Any], bool] | None = None
                 ) -> tuple[list[Leaf], PyTreeDef]:
  """Alias of :func:`jax.tree.flatten`."""
  return default_registry.flatten(tree, is_leaf)


@export
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
  """Alias of :func:`jax.tree.unflatten`.

  Rebuilds a pytree with the structure ``treedef`` from ``leaves``.
  """
  rebuilt = treedef.unflatten(leaves)
  return rebuilt


@export
def tree_leaves(tree: Any,
                is_leaf: Callable[[Any], bool] | None = None
                ) -> list[Leaf]:
  """Alias of :func:`jax.tree.leaves`."""
  return default_registry.flatten(tree, is_leaf)[0]


@export
def tree_structure(tree: Any,
                   is_leaf: None | (Callable[[Any],
                                              bool]) = None) -> PyTreeDef:
  """Alias of :func:`jax.tree.structure`."""
  return default_registry.flatten(tree, is_leaf)[1]


@export
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
  """Makes a tuple treedef from an iterable of child treedefs.

  Args:
    treedefs: iterable of PyTree structures

  Returns:
    a single treedef representing a tuple of the structures

  Examples:
    >>> import jax
    >>> x = [1, 2, 3]
    >>> y = {'a': 4, 'b': 5}
    >>> x_tree = jax.tree.structure(x)
    >>> y_tree = jax.tree.structure(y)
    >>> xy_tree = jax.tree_util.treedef_tuple([x_tree, y_tree])
    >>> xy_tree == jax.tree.structure((x, y))
    True

  See Also:
    - :func:`jax.tree_util.treedef_children`
  """
  return pytree.treedef_tuple(default_registry, list(treedefs))


@export
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
  """Returns the treedefs of the immediate children of ``treedef``.

  Args:
    treedef: a single PyTreeDef

  Returns:
    a list of PyTreeDefs representing the children of treedef.

  Examples:
    >>> import jax
    >>> x = [(1, 2), 3, {'a': 4}]
    >>> treedef = jax.tree.structure(x)
    >>> jax.tree_util.treedef_children(treedef)
    [PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
    >>> _ == [jax.tree.structure(vals) for vals in x]
    True

  See Also:
    - :func:`jax.tree_util.treedef_tuple`
  """
  child_defs = treedef.children()
  return child_defs


@export
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
  """Reports whether ``treedef`` consists of exactly one node.

  Args:
    treedef: tree to check

  Returns:
    True if treedef is a leaf (i.e. has a single node); False otherwise.

  Examples:
    >>> import jax
    >>> tree1 = jax.tree.structure(1)
    >>> jax.tree_util.treedef_is_leaf(tree1)
    True
    >>> tree2 = jax.tree.structure([1, 2])
    >>> jax.tree_util.treedef_is_leaf(tree2)
    False
  """
  node_count = treedef.num_nodes
  return node_count == 1


# treedef_is_strict_leaf is not exported.
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
  """Returns True only for a treedef with one node that is also one leaf.

  Unlike :func:`treedef_is_leaf`, a single node holding zero leaves does not
  count.
  """
  if treedef.num_nodes != 1:
    return False
  return treedef.num_leaves == 1


@export
def all_leaves(iterable: Iterable[Any],
               is_leaf: Callable[[Any], bool] | None = None) -> bool:
  """Tests whether all elements in the given iterable are all leaves.

  This function is useful in advanced cases, for example if a library allows
  arbitrary map operations on a flat iterable of leaves it may want to check
  if the result is still a flat iterable of leaves.

  Args:
    iterable: Iterable of leaves.

  Returns:
    A boolean indicating if all elements in the input are leaves.

  Examples:
    >>> import jax
    >>> tree = {"a": [1, 2, 3]}
    >>> assert all_leaves(jax.tree_util.tree_leaves(tree))
    >>> assert not all_leaves([tree])
  """
  if is_leaf is None:
    return pytree.all_leaves(default_registry, iterable)
  else:
    items = list(iterable)
    leaves = tree_leaves(items, is_leaf)
    return len(leaves) == len(items) and all(
        item is leaf for item, leaf in zip(items, leaves, strict=True)
    )


# Type variables for the pieces returned by user-supplied flatten functions:
# the iterable of children and the hashable auxiliary data.
_Children = TypeVar("_Children", bound=Iterable[Any])
_AuxData = TypeVar("_AuxData", bound=Hashable)
# A single entry of a key path; deliberately unconstrained (bound=Any).
KeyEntry = TypeVar("KeyEntry", bound=Any)
# (key, child) pair and an iterable of such pairs, as produced by key-aware
# flatten functions.
KeyLeafPair = tuple[KeyEntry, Any]
KeyLeafPairs = Iterable[KeyLeafPair]
# A key path is a tuple of key entries.
KeyPath = tuple[KeyEntry, ...]


@export
def register_pytree_node(
    nodetype: type[T],
    flatten_func: Callable[[T], tuple[_Children, _AuxData]],
    unflatten_func: Callable[[_AuxData, _Children], T],
    flatten_with_keys_func: (
        Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None
    ) = None,
) -> None:
  """Extends the set of types that are considered internal nodes in pytrees.

  See :ref:`example usage <pytrees>`.

  Args:
    nodetype: a Python type to register as a pytree.
    flatten_func: a function to be used during flattening, taking a value of
      type ``nodetype`` and returning a pair, with (1) an iterable for the
      children to be flattened recursively, and (2) some hashable auxiliary data
      to be stored in the treedef and to be passed to the ``unflatten_func``.
    unflatten_func: a function taking two arguments: the auxiliary data that was
      returned by ``flatten_func`` and stored in the treedef, and the
      unflattened children. The function should return an instance of
      ``nodetype``.

  See also:
    - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
    - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
    - :func:`~jax.tree_util.register_pytree_with_keys`
    - :func:`~jax.tree_util.register_pytree_node_class`
    - :func:`~jax.tree_util.register_pytree_with_keys_class`

  Examples:
    First we'll define a custom type:

    >>> class MyContainer:
    ...   def __init__(self, size):
    ...     self.x = jnp.zeros(size)
    ...     self.y = jnp.ones(size)
    ...     self.size = size

    If we try using this in a JIT-compiled function, we'll get an error because JAX
    does not yet know how to handle this type:

    >>> m = MyContainer(size=5)
    >>> def f(m):
    ...   return m.x + m.y + jnp.arange(m.size)
    >>> jax.jit(f)(m)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
      ...
    TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute

    In order to make our object recognized by JAX, we must register it as
    a pytree:

    >>> def flatten_func(obj):
    ...   children = (obj.x, obj.y)  # children must contain arrays & pytrees
    ...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
    ...   return (children, aux_data)
    ...
    >>> def unflatten_func(aux_data, children):
    ...   # Here we avoid `__init__` because it has extra logic we don't require:
    ...   obj = object.__new__(MyContainer)
    ...   obj.x, obj.y = children
    ...   obj.size, = aux_data
    ...   return obj
    ...
    >>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

    Now with this defined, we can use instances of this type in JIT-compiled functions.

    >>> jax.jit(f)(m)
    Array([1., 2., 3., 4., 5.], dtype=float32)
  """
  default_registry.register_node(
      nodetype, flatten_func, unflatten_func, flatten_with_keys_func
  )
  none_leaf_registry.register_node(
      nodetype, flatten_func, unflatten_func, flatten_with_keys_func
  )
  dispatch_registry.register_node(
      nodetype, flatten_func, unflatten_func, flatten_with_keys_func
  )
  _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)


@export
def register_pytree_node_class(cls: Typ) -> Typ:
  """Extends the set of types that are considered internal nodes in pytrees.

  This function is a thin wrapper around ``register_pytree_node``, and provides
  a class-oriented interface: the class's ``tree_flatten`` method and
  ``tree_unflatten`` attribute are used as the flatten and unflatten functions.

  Args:
    cls: a type to register as a pytree

  Returns:
    The input class ``cls`` is returned unchanged after being added to JAX's pytree
    registry. This return value allows ``register_pytree_node_class`` to be used as
    a decorator.

  See also:
    - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
    - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
    - :func:`~jax.tree_util.register_pytree_node`
    - :func:`~jax.tree_util.register_pytree_with_keys`
    - :func:`~jax.tree_util.register_pytree_with_keys_class`

  Examples:
    Here we'll define a custom container that will be compatible with :func:`jax.jit`
    and other JAX transformations:

    >>> import jax
    >>> @jax.tree_util.register_pytree_node_class
    ... class MyContainer:
    ...   def __init__(self, x, y):
    ...     self.x = x
    ...     self.y = y
    ...   def tree_flatten(self):
    ...     return ((self.x, self.y), None)
    ...   @classmethod
    ...   def tree_unflatten(cls, aux_data, children):
    ...     return cls(*children)
    ...
    >>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
    >>> def f(m):
    ...   return m.x + 2 * m.y
    >>> jax.jit(f)(m)
    Array([0., 2., 4., 6.], dtype=float32)
  """
  # methodcaller defers the lookup of `tree_flatten` to each instance.
  flatten = op.methodcaller("tree_flatten")
  unflatten = cls.tree_unflatten
  register_pytree_node(cls, flatten, unflatten)
  return cls


@export
def tree_map(f: Callable[..., Any],
             tree: Any,
             *rest: Any,
             is_leaf: Callable[[Any], bool] | None = None) -> Any:
  """Alias of :func:`jax.tree.map`."""
  leaves, treedef = tree_flatten(tree, is_leaf)
  all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))


@export
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
  """Builds an object with structure ``treedef`` from nested iterables.

  DEPRECATED: Use :func:`jax.tree.unflatten` instead.

  Args:
    treedef: the PyTreeDef structure to build.
    xs: nested iterables matching the arity as the treedef

  Returns:
    object with structure defined by treedef

  See Also:
    - :func:`jax.tree.unflatten`

  Examples:
    The recommended replacement takes a flat list of leaves:

    >>> import jax
    >>> tree = [(1, 2), {'a': 3, 'b': 4}]
    >>> treedef = jax.tree.structure(tree)
    >>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
    [(10, 11), {'a': 12, 'b': 13}]
  """
  built = treedef.from_iterable_tree(xs)
  return built


@export
def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
                   pytree_to_transpose: Any) -> Any:
  """Alias of :func:`jax.tree.transpose`."""
  flat, treedef = tree_flatten(pytree_to_transpose)
  if inner_treedef is None:
    inner_treedef = tree_structure(outer_treedef.flatten_up_to(pytree_to_transpose)[0])
  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  if treedef.num_leaves != (inner_size * outer_size):
    expected_treedef = outer_treedef.compose(inner_treedef)
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
  iter_flat = iter(flat)
  lol = [
      [next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)
  ]
  transposed_lol = zip(*lol)
  subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
  return tree_unflatten(inner_treedef, subtrees)


# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
# Each entry pairs a one-level flatten function (``to_iter``: node ->
# (children, aux)) with its inverse (``from_iter``: (aux, children) -> node).
_RegistryEntry = collections.namedtuple("_RegistryEntry", ["to_iter", "from_iter"])
_registry: dict[type[Any], _RegistryEntry] = {
    # Tuples and lists are their own children and carry no aux data.
    tuple: _RegistryEntry(lambda xs: (xs, None), lambda _, xs: tuple(xs)),
    list: _RegistryEntry(lambda xs: (xs, None), lambda _, xs: list(xs)),
    # Dict items are sorted by key; the unzipped pair is reversed so the values
    # come first (children) and the keys second (aux).
    # NOTE(review): assumes `unzip2` returns (keys, values) here -- confirm.
    dict: _RegistryEntry(lambda xs: unzip2(sorted(xs.items()))[::-1],
                         lambda keys, xs: dict(zip(keys, xs))),
    # None flattens to no children and is rebuilt as None.
    type(None): _RegistryEntry(lambda z: ((), None), lambda _, xs: None),
}


class Unspecified:
  """Marker type whose instances stand for "no value was supplied".

  Used as the default of optional arguments for which ``None`` could be a
  legitimate user value; callers detect it with ``isinstance``.
  """


@export
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: T | Unspecified = Unspecified(),
                is_leaf: Callable[[Any], bool] | None = None) -> T:
  """Alias of :func:`jax.tree.reduce`."""
  if isinstance(initializer, Unspecified):
    return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf))
  else:
    return functools.reduce(function, tree_leaves(tree, is_leaf=is_leaf), initializer)


def _parallel_reduce(
    sequence: list[T],
    operation: Callable[[T, T], T],
    identity: T | Unspecified = Unspecified(),
) -> T:
  length = len(sequence)
  if length == 0:
    if isinstance(identity, Unspecified):
      raise TypeError("Must specify identity for parallel reduction of empty sequence.")
    return identity
  elif length == 1:
    return sequence[0]
  else:
    index = length // 2
    a = _parallel_reduce(sequence[:index], operation, identity)
    b = _parallel_reduce(sequence[index:], operation, identity)
    return operation(a, b)


@export
def tree_reduce_associative(
    operation: Callable[[T, T], T],
    tree: Any,
    *,
    identity: T | Unspecified = Unspecified(),
    is_leaf: Callable[[Any], bool] | None = None,
) -> T:
  """Alias of :func:`jax.tree.reduce_associative`."""
  sequence = tree_leaves(tree, is_leaf=is_leaf)
  return _parallel_reduce(sequence, operation, identity)


@export
def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool:
  """Alias of :func:`jax.tree.all`."""
  return all(tree_leaves(tree, is_leaf=is_leaf))


class _HashableCallableShim:
  """Object that delegates __call__, __hash__, and __eq__ to another object.

  The wrapped callable is stored on ``self.fun``. Two shims compare equal when
  their wrapped objects do; a shim also compares equal to anything its wrapped
  object compares equal to.
  """

  def __init__(self, fun):
    self.fun = fun

  def __call__(self, *args, **kwargs):
    target = self.fun
    return target(*args, **kwargs)

  def __hash__(self):
    return hash(self.fun)

  def __eq__(self, other):
    # Unwrap another shim so that the underlying callables are compared.
    other_fun = other.fun if isinstance(other, _HashableCallableShim) else other
    return self.fun == other_fun

  def __repr__(self):
    return f'_HashableCallableShim({self.fun!r})'


@export
class Partial(functools.partial):
  """A version of functools.partial that works in pytrees.

  Use it for partial function evaluation in a way that is compatible with JAX's
  transformations, e.g., ``Partial(func, *args, **kwargs)``.

  (You need to explicitly opt-in to this behavior because we didn't want to give
  functools.partial different semantics than normal function closures.)

  For example, here is a basic usage of ``Partial`` in a manner similar to
  ``functools.partial``:

  >>> import jax.numpy as jnp
  >>> add_one = Partial(jnp.add, 1)
  >>> add_one(2)
  Array(3, dtype=int32, weak_type=True)

  Pytree compatibility means that the resulting partial function can be passed
  as an argument within transformed JAX functions, which is not possible with a
  standard ``functools.partial`` function:

  >>> from jax import jit
  >>> @jit
  ... def call_func(f, *args):
  ...   return f(*args)
  ...
  >>> call_func(add_one, 2)
  Array(3, dtype=int32, weak_type=True)

  Passing zero arguments to ``Partial`` effectively wraps the original function,
  making it a valid argument in JAX transformed functions:

  >>> call_func(Partial(jnp.add), 1, 2)
  Array(3, dtype=int32, weak_type=True)

  Had we passed ``jnp.add`` to ``call_func`` directly, it would have resulted in
  a ``TypeError``.

  Note that if the result of ``Partial`` is used in the context where the
  value is traced, it results in all bound arguments being traced when passed
  to the partially-evaluated function:

  >>> print_zero = Partial(print, 0)
  >>> print_zero()
  0
  >>> call_func(print_zero)  # doctest:+ELLIPSIS
  JitTracer(~int32[])
  """

  def __new__(klass, func, *args, **kw):
    # Anything that is not itself a functools.partial needs no special care.
    if not isinstance(func, functools.partial):
      return super().__new__(klass, func, *args, **kw)

    # In Python 3.10+, if func is itself a functools.partial instance,
    # functools.partial.__new__ would merge the arguments of this Partial
    # instance with the arguments of the func. We box func in a class that does
    # not (yet) have a `func` attribute to defeat this optimization, since we
    # care exactly which arguments are considered part of the pytree.
    shim = _HashableCallableShim(func)
    out = super().__new__(klass, shim, *args, **kw)
    # Only after construction does the shim gain the partial-like attributes
    # of the callable it wraps.
    shim.func = func.func
    shim.args = func.args
    shim.keywords = func.keywords
    return out


# Make `Partial` a pytree node: the bound positional arguments and keyword
# arguments are the children, while the wrapped callable is carried as the
# auxiliary data and handed back unchanged when the node is rebuilt.
register_pytree_node(
    Partial,
    lambda partial_: ((partial_.args, partial_.keywords), partial_.func),
    lambda func, xs: Partial(func, *xs[0], **xs[1]),
)


@export
def tree_broadcast(prefix_tree: Any, full_tree: Any,
                   is_leaf: Callable[[Any], bool] | None = None
                  ) -> Any:
  """Alias of :func:`jax.tree.broadcast`."""
  broadcast_leaves = broadcast_prefix(prefix_tree, full_tree, is_leaf=is_leaf)
  return tree_structure(full_tree).unflatten(broadcast_leaves)


# broadcast_prefix is not exported
def broadcast_prefix(prefix_tree: Any, full_tree: Any,
                     is_leaf: Callable[[Any], bool] | None = None
                     ) -> list[Any]:
  """Broadcasts tree prefix leaves into the full set of leaves for a given full tree.

    Args:
      prefix_tree: a pytree that is a tree prefix of full_tree.
      full_tree: a pytree with the structure to broadcast the prefix leaves into.
      is_leaf: an optionally specified function that will be called at each
        flattening step for prefix_tree. It should return a boolean, with true
        stopping the traversal and the whole subtree being treated as a leaf,
        and false indicating the flattening should traverse the current object.

    Returns:
      A list of leaves matching the expected count for the full tree,
      with the leaf of each prefix tree being duplicated to match the count of
      its corresponding subtree.
  """
  result = []
  num_leaves = lambda t: tree_structure(t).num_leaves
  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
  try:
    tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
  except ValueError:
      e, *_ = prefix_errors(prefix_tree, full_tree)
      raise e('broadcast_prefix prefix_tree') from None
  return result


# broadcast_flattened_prefix_with_treedef is not exported
def broadcast_flattened_prefix_with_treedef(
    prefix_leaves: list[Any],
    prefix_treedef: PyTreeDef,
    full_treedef: PyTreeDef,
) -> list[Any]:
  """Broadcasts tree prefix leaves into the full set of leaves for a given full treedef.

    Args:
      prefix_leaves: the leaves of a pytree that is a tree prefix
        of full_treedef.
      prefix_treedef: the PyTreeDef of a pytree that is a tree prefix of
        full_treedef.
      full_treedef: a PyTreeDef with the structure to broadcast the prefix
        leaves into.

    Returns:
      A list of leaves matching the expected count for the full tree,
      with each leaf of prefix tree being duplicated to match the count of
      its corresponding subtree.
  """
  # NOTE: At the moment, `broadcast_flattened_prefix_with_treedef` is only
  # called from `api_util.flatten_axes`, which replaces any raised exception
  # with its own exception and error message.  The errors raised from this
  # function should probably be improved before this function is used in
  # more places.
  #
  # TODO(jburnim): Merge `broadcast_prefix` with this function?
  # prefix_leaves, prefix_treedef = tree_flatten(prefix_tree, is_leaf)
  ret = []

  # TODO(jburnim): Should this traversal be done in C++?
  def _broadcast(broadcast_fn, leaf_start, leaf_end, prefix_treedef, treedef):
    if treedef_is_strict_leaf(prefix_treedef):
      # We have encountered a leaf in the prefix, so we repeat the prefix leaf
      # for each leaf in the corresponding part of the tree.
      assert (leaf_end - leaf_start) == 1
      ret.extend(prefix_leaves[leaf_start:leaf_end] * treedef.num_leaves)
      return

    if treedef_is_strict_leaf(treedef):
      raise ValueError('`prefix_treedef` is not a prefix of `full_treedef`')

    prefix_node_data = prefix_treedef.node_data()
    node_data = treedef.node_data()
    if prefix_node_data != node_data:
      raise ValueError(f'expected {node_data}, got {prefix_node_data}')

    prefix_i = leaf_start
    for prefix_child, tree_child in zip(
        prefix_treedef.children(), treedef.children(), strict=True):
      broadcast_fn(broadcast_fn, prefix_i, prefix_i + prefix_child.num_leaves,
                   prefix_child, tree_child,
      )
      prefix_i += prefix_child.num_leaves

  # Pass _broadcast as arg to avoid it being a free variable within its own
  # closure, which creates a reference cycle.
  _broadcast(_broadcast, 0, len(prefix_leaves), prefix_treedef, full_treedef)
  return ret


# flatten_one_level is not exported.
def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
  """Flattens the given pytree node by exactly one level.

  Args:
    tree: A valid pytree node, either built-in or registered via
      :func:`register_pytree_node` or related functions.

  Returns:
    A pair of the pytrees flattened children and its hashable metadata.

  Raises:
    ValueError: If the given pytree is not a built-in or registered container
    via ``register_pytree_node`` or ``register_pytree_with_keys``.

  Examples:
    >>> import jax
    >>> from jax._src.tree_util import flatten_one_level
    >>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}})
    >>> flattened
    ([1, 2], {'c': 3})
    >>> meta
    ('a', 'b')
  """
  flattened = default_registry.flatten_one_level(tree)
  if flattened is not None:
    return flattened
  # The registry answered None: `tree` is not a node it can flatten.
  raise ValueError(f"can't tree-flatten type: {type(tree)}")


# flatten_one_level_with_keys is not exported.
def flatten_one_level_with_keys(
    tree: Any,
) -> tuple[Iterable[KeyLeafPair], Hashable]:
  """Flatten the given pytree node by one level, with keys."""
  out = default_registry.flatten_one_level_with_keys(tree)  # type: ignore
  if out is None:
    raise ValueError(f"can't tree-flatten type: {type(tree)}")
  else:
    return out


# prefix_errors is not exported
def prefix_errors(prefix_tree: Any, full_tree: Any,
                  is_leaf: Callable[[Any], bool] | None = None,
                  ) -> list[Callable[[str], ValueError]]:
  return list(_prefix_error((), prefix_tree, full_tree, is_leaf))


# equality_errors is not exported
def equality_errors(
    tree1: Any, tree2: Any, is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[tuple[KeyPath, str, str, str]]:
  """Helper to describe structural differences between two pytrees.

  Args:
    tree1, tree2: pytrees known to have different structure.

  Usage:

    raise Exception(
        "Value 1 and value 2 must have the same pytree structure, but they have "
        "the following structural differences:\n" +
        ("\n".join(
           f"   - {keystr(path)} is a {thing1} in value 1 and a {thing2} in "
           f" value 2, so {explanation}.\n"
           for path, thing1, thing2, explanation
           in equality_errors(val1, val2))))
  """
  yield from _equality_errors((), tree1, tree2, is_leaf)

def equality_errors_pytreedef(
    tree1: PyTreeDef,
    tree2: PyTreeDef) -> Iterable[tuple[KeyPath, str, str, str]]:
  """Like `equality_errors` but invoked on PyTreeDef.

  Both treedefs are unflattened with a shared placeholder leaf and the
  resulting trees are compared with :func:`equality_errors`.
  """
  # TODO(mattjj): make equality_errors not print type name, avoid metaclass
  # The metaclass makes the placeholder's *type* print as "pytree leaf".
  leaf_meta = type("LeafMeta", (type,), dict(__repr__=lambda _: "pytree leaf"))
  leaf = leaf_meta("Leaf", (), {})()
  left = tree_unflatten(tree1, [leaf] * tree1.num_leaves)
  right = tree_unflatten(tree2, [leaf] * tree2.num_leaves)
  return equality_errors(left, right)

# TODO(mattjj): maybe share some logic with _prefix_error?
def _equality_errors(path, t1, t2, is_leaf):
  """Yields structural differences between ``t1`` and ``t2`` found under ``path``.

  Args:
    path: key path (a tuple of key entries) leading to ``t1``/``t2``.
    t1, t2: the two subtrees to compare.
    is_leaf: optional predicate used when deciding whether a subtree is a leaf.

  Yields:
    Tuples ``(key_path, thing1, thing2, explanation)``. At most one difference
    is reported per node; matching nodes are compared child by child.
  """
  # If both are leaves, this isn't a structure equality error.
  if (treedef_is_strict_leaf(tree_structure(t1, is_leaf=is_leaf)) and
      treedef_is_strict_leaf(tree_structure(t2, is_leaf=is_leaf))):
    return

  # The trees may disagree because they are different types:
  if type(t1) != type(t2):
    yield path, str(type(t1)), str(type(t2)), 'their Python types differ'
    return  # no more errors to find

  # Or they may disagree because their roots have different numbers or keys of
  # children (with special-case handling of list/tuple). The types are known
  # to be equal here, so checking `t1` alone is enough.
  if isinstance(t1, (list, tuple)):
    if len(t1) != len(t2):
      yield (path,
             f'{type(t1).__name__} of length {len(t1)}',
             f'{type(t2).__name__} of length {len(t2)}',
             'the lengths do not match')
      return  # no more errors to find
  t1_children, t1_meta = flatten_one_level(t1)
  t2_children, t2_meta = flatten_one_level(t2)
  t1_children = tuple(t1_children)
  t2_children = tuple(t2_children)
  t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
  if len(t1_children) != len(t2_children):
    # Best effort: the key diff is extra detail for the message only. Keys may
    # be unhashable or lack a `.key` attribute, in which case it is omitted.
    # Catch `Exception` rather than everything so that KeyboardInterrupt and
    # SystemExit still propagate.
    try:
      diff = ' '.join(repr(k.key) for k in
                      set(t1_keys).symmetric_difference(set(t2_keys)))
    except Exception:
      diff = ''
    yield (path,
           f'{type(t1)} with {len(t1_children)} child'
           f'{"ren" if len(t1_children) > 1 else ""}',
           f'{type(t2)} with {len(t2_children)} child'
           f'{"ren" if len(t2_children) > 1 else ""}',
           'the numbers of children do not match' +
           (diff and f', with the symmetric difference of key sets: {{{diff}}}')
           )
    return  # no more errors to find

  # Or they may disagree if their roots have different pytree metadata:
  if t1_meta != t2_meta:
    yield (path,
           f'{type(t1)} with pytree metadata {t1_meta}',
           f'{type(t2)} with pytree metadata {t2_meta}',
           'the pytree node metadata does not match')
    return  # no more errors to find

  # If the root types and numbers of children agree, there must be a mismatch in
  # a subtree, so recurse:
  assert t1_keys == t2_keys, \
      f"equal pytree nodes gave different tree keys: {t1_keys} and {t2_keys}"
  for k, c1, c2 in zip(t1_keys, t1_children, t2_children):
    yield from _equality_errors((*path, k), c1, c2, is_leaf)


# Key-entry classes re-exported from the `pytree` extension module. They are
# annotated as `Any`, so type checkers treat them as opaque.
SequenceKey: Any = pytree.SequenceKey
DictKey: Any = pytree.DictKey
GetAttrKey: Any = pytree.GetAttrKey
FlattenedIndexKey: Any = pytree.FlattenedIndexKey


@export
def keystr(keys: KeyPath, *, simple: bool = False, separator: str = '') -> str:
  """Helper to pretty-print a tuple of keys.

  Args:
    keys: A tuple of ``KeyEntry`` or any class that can be converted to string.
    simple: If True, use a simplified string representation for keys. The
      simple representation of keys will be more compact than the default, but
      is ambiguous in some cases (for example "0" might refer to the first item
      in a list or a dictionary key for the integer 0 or string "0").
    separator: The separator to use to join string representations of the keys.

  Returns:
    A string that joins all string representations of the keys.

  Examples:
    >>> import jax
    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']
    >>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
  """
  str_fn = _simple_entrystr if simple else str
  return separator.join(map(str_fn, keys))


def _simple_entrystr(key: KeyEntry) -> str:
  match key:
    case (
        SequenceKey(idx=key)
        | DictKey(key=key)
        | GetAttrKey(name=key)
        | FlattenedIndexKey(key=key)
    ):
      return str(key)
    case _:
      return str(key)


@export
def register_pytree_with_keys(
    nodetype: type[T],
    flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]],
    unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
    flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None,
):
  """Extends the set of types that are considered internal nodes in pytrees.

  This is a more powerful alternative to ``register_pytree_node`` that allows
  you to access each pytree leaf's key path when flattening and tree-mapping.

  Args:
    nodetype: a Python type to treat as an internal pytree node.
    flatten_with_keys: a function to be used during flattening, taking a value
      of type ``nodetype`` and returning a pair, with (1) an iterable for tuples
      of each key path and its child, and (2) some hashable auxiliary data to be
      stored in the treedef and to be passed to the ``unflatten_func``.
    unflatten_func: a function taking two arguments: the auxiliary data that was
      returned by ``flatten_func`` and stored in the treedef, and the
      unflattened children. The function should return an instance of
      ``nodetype``.
    flatten_func: an optional function similar to ``flatten_with_keys``, but
      returns only children and auxiliary data. It must return the children
      in the same order as ``flatten_with_keys``, and return the same aux data.
      This argument is optional and only needed for faster traversal when
      calling functions without keys like ``tree_map`` and ``tree_flatten``.

  Examples:
    First we'll define a custom type:

    >>> class MyContainer:
    ...   def __init__(self, size):
    ...     self.x = jnp.zeros(size)
    ...     self.y = jnp.ones(size)
    ...     self.size = size

    Now register it using a key-aware flatten function:

    >>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
    >>> def flatten_with_keys(obj):
    ...   children = [(GetAttrKey('x'), obj.x),
    ...               (GetAttrKey('y'), obj.y)]  # children must contain arrays & pytrees
    ...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
    ...   return children, aux_data
    ...
    >>> def unflatten(aux_data, children):
    ...   # Here we avoid `__init__` because it has extra logic we don't require:
    ...   obj = object.__new__(MyContainer)
    ...   obj.x, obj.y = children
    ...   obj.size, = aux_data
    ...   return obj
    ...
    >>> jax.tree_util.register_pytree_node(MyContainer, flatten_with_keys, unflatten)

    Now this can be used with functions like :func:`~jax.tree_util.tree_flatten_with_path`:

    >>> m = MyContainer(4)
    >>> leaves, treedef = jax.tree_util.tree_flatten_with_path(m)
  """
  if not flatten_func:
    def flatten_func_impl(tree):
      key_children, treedef = flatten_with_keys(tree)
      return [c for _, c in key_children], treedef
    flatten_func = flatten_func_impl

  register_pytree_node(
      nodetype, flatten_func, unflatten_func, flatten_with_keys
  )


@export
def register_pytree_with_keys_class(cls: Typ) -> Typ:
  """Extends the set of types that are considered internal nodes in pytrees.

  This function is similar to ``register_pytree_node_class``, but requires a
  class that defines how it could be flattened with keys.

  It is a thin, class-oriented wrapper around ``register_pytree_with_keys``:
  the class's ``tree_flatten_with_keys`` method is used as the key-aware
  flatten function and its ``tree_unflatten`` attribute as the unflatten
  function. If the class also defines ``tree_flatten``, that method is
  forwarded as the key-free flatten function; otherwise none is supplied.

  Args:
    cls: a type to register as a pytree

  Returns:
    The input class ``cls`` is returned unchanged after being added to JAX's pytree
    registry. This return value allows ``register_pytree_with_keys_class`` to be
    used as a decorator.

  See also:
    - :func:`~jax.tree_util.register_static`: simpler API for registering a static pytree.
    - :func:`~jax.tree_util.register_dataclass`: simpler API for registering a dataclass.
    - :func:`~jax.tree_util.register_pytree_node`
    - :func:`~jax.tree_util.register_pytree_with_keys`
    - :func:`~jax.tree_util.register_pytree_node_class`

  Examples:
    >>> from jax.tree_util import register_pytree_with_keys_class, GetAttrKey
    >>> @register_pytree_with_keys_class
    ... class Special:
    ...   def __init__(self, x, y):
    ...     self.x = x
    ...     self.y = y
    ...   def tree_flatten_with_keys(self):
    ...     return (((GetAttrKey('x'), self.x), (GetAttrKey('y'), self.y)), None)
    ...   @classmethod
    ...   def tree_unflatten(cls, aux_data, children):
    ...     return cls(*children)
  """
  # The key-free flatten is optional: only forward it when the class has one.
  key_free_flatten = None
  if hasattr(cls, "tree_flatten"):
    key_free_flatten = op.methodcaller("tree_flatten")

  keyed_flatten = op.methodcaller("tree_flatten_with_keys")
  register_pytree_with_keys(
      cls, keyed_flatten, cls.tree_unflatten, key_free_flatten
  )
  return cls


@export
def register_dataclass(
    nodetype: Typ,
    data_fields: Sequence[str] | None = None,
    meta_fields: Sequence[str] | None = None,
    drop_fields: Sequence[str] = (),
) -> Typ:
  """Extends the set of types that are considered internal nodes in pytrees.

  This differs from ``register_pytree_with_keys_class`` in that the C++
  registries use the optimized C++ dataclass builtin instead of the argument
  functions.

  See :ref:`pytrees-custom-pytree-nodes` for more information about registering pytrees.

  Args:
    nodetype: a Python type to treat as an internal pytree node. This is assumed
      to have the semantics of a :obj:`~dataclasses.dataclass`: namely, class
      attributes represent the whole of the object state, and can be passed
      as keywords to the class constructor to create a copy of the object.
      All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
    data_fields: data field names: these are attributes which will be treated as non-static
      when this pytree is passed to :func:`jax.jit`. ``data_fields`` is optional only if
      ``nodetype`` is a dataclass, in which case fields are assumed data fields unless
      marked via :func:`dataclasses.field` (see examples below).
      Data fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array`
      or :class:`numpy.ndarray`), scalars, or pytrees whose leaves are arrays or scalars.
      Note that ``None`` is a valid data field, as JAX recognizes this as an empty pytree.
    meta_fields: metadata field names: these are attributes which will be treated as
      :term:`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is
      optional only if ``nodetype`` is a dataclass, in which case individual fields can
      be marked static via :func:`dataclasses.field` (see examples below).
      Metadata fields *must* be static, hashable, immutable objects, as these objects
      are used to generate JIT cache keys. In particular, metadata fields cannot contain
      :class:`jax.Array` or :class:`numpy.ndarray` objects.
    drop_fields: names of dataclass fields with ``init=True`` that are deliberately
      left out of both ``data_fields`` and ``meta_fields``. These names are excluded
      from the field-consistency check below. They are not read when flattening and
      are not passed to the constructor when the object is rebuilt. Only consulted
      when ``nodetype`` is a dataclass.

  Returns:
    The input class ``nodetype`` is returned unchanged after being added to JAX's
    pytree registry, so that :func:`register_dataclass` can be used as a decorator.

  Raises:
    TypeError: if exactly one of ``data_fields`` and ``meta_fields`` is given, or if
      both are omitted and ``nodetype`` is not a dataclass.
    ValueError: if ``nodetype`` is a dataclass and the union of ``data_fields`` and
      ``meta_fields`` differs from its ``init=True`` fields minus ``drop_fields``.

  Examples:
    In JAX v0.4.35 or older, you must specify ``data_fields`` and ``meta_fields``
    in order to use this decorator:

    >>> import jax
    >>> from dataclasses import dataclass
    >>> from functools import partial
    ...
    >>> @partial(jax.tree_util.register_dataclass,
    ...          data_fields=['x', 'y'],
    ...          meta_fields=['op'])
    ... @dataclass
    ... class MyStruct:
    ...   x: jax.Array
    ...   y: jax.Array
    ...   op: str
    ...
    >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
    >>> m
    MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

    Starting in JAX v0.4.36, the ``data_fields`` and ``meta_fields`` arguments are optional
    for :func:`~dataclasses.dataclass` inputs, with fields defaulting to ``data_fields``
    unless marked as static using `static` metadata in :func:`dataclasses.field`.

    >>> import jax
    >>> from dataclasses import dataclass, field
    ...
    >>> @jax.tree_util.register_dataclass
    ... @dataclass
    ... class MyStruct:
    ...   x: jax.Array  # defaults to non-static data field
    ...   y: jax.Array  # defaults to non-static data field
    ...   op: str = field(metadata=dict(static=True))  # marked as static meta field.
    ...
    >>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
    >>> m
    MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

    Once this class is registered, it can be used with functions in :mod:`jax.tree` and
    :mod:`jax.tree_util`:

    >>> leaves, treedef = jax.tree.flatten(m)
    >>> leaves
    [Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
    >>> treedef
    PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
    >>> jax.tree.unflatten(treedef, leaves)
    MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')

    In particular, this registration allows ``m`` to be passed seamlessly through code
    wrapped in :func:`jax.jit` and other JAX transformations, with ``data_fields`` being
    treated as dynamic arguments, and ``meta_fields`` being treated as static arguments:

    >>> @jax.jit
    ... def compiled_func(m):
    ...   if m.op == 'add':
    ...     return m.x + m.y
    ...   else:
    ...     raise ValueError(f"{m.op=}")
    ...
    >>> compiled_func(m)
    Array([1., 2., 3.], dtype=float32)
  """
  if data_fields is None or meta_fields is None:
    if (data_fields is None) != (meta_fields is None):
      raise TypeError("register_dataclass: data_fields and meta_fields must both be specified"
                      f" when either is specified. Got {data_fields=} {meta_fields=}.")
    if not dataclasses.is_dataclass(nodetype):
      raise TypeError("register_dataclass: data_fields and meta_fields are required when"
                      f" nodetype is not a dataclass. Got {nodetype=}.")
    # Infer the split from the dataclass itself: a field is metadata only when
    # it carries ``metadata={'static': True}``; everything else is data.
    data_fields = [f.name for f in dataclasses.fields(nodetype)
                   if not f.metadata.get('static', False)]
    meta_fields = [f.name for f in dataclasses.fields(nodetype)
                   if f.metadata.get('static', False)]

  assert meta_fields is not None
  assert data_fields is not None

  # Store inputs as immutable tuples in this scope, because we close over them
  # for later evaluation. This prevents potentially confusing behavior if the
  # caller were to pass in lists that are later mutated.
  meta_fields = tuple(meta_fields)
  data_fields = tuple(data_fields)

  if dataclasses.is_dataclass(nodetype):
    init_fields = {f.name for f in dataclasses.fields(nodetype) if f.init}
    # Pass the sequence itself: ``difference_update`` takes iterables of
    # elements, so unpacking ``drop_fields`` would make it treat every field
    # name as an iterable of single characters and drop nothing useful.
    init_fields.difference_update(drop_fields)
    listed_fields = {*meta_fields, *data_fields}
    if listed_fields != init_fields:
      msg = (
          "data_fields and meta_fields must include all dataclass fields with"
          " ``init=True`` and only them."
      )
      if missing := init_fields - listed_fields:
        msg += (
            f" Missing fields: {missing}. Add them to drop_fields to suppress"
            " this error."
        )
      if unexpected := listed_fields - init_fields:
        msg += f" Unexpected fields: {unexpected}."
      raise ValueError(msg)

  def unflatten_func(meta, data):
    # Rebuild through the constructor using only the registered fields.
    meta_args = tuple(zip(meta_fields, meta))
    data_args = tuple(zip(data_fields, data))
    kwargs = dict(meta_args + data_args)
    return nodetype(**kwargs)

  def flatten_func(x):
    # Children (data) come first, auxiliary data (meta) second.
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple(getattr(x, name) for name in data_fields)
    return data, meta

  default_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
  none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
  dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
  _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
  return nodetype


# OrderedDict is registered with key-aware flattening: children are emitted in
# the dict's own key order, and the key tuple is stored as auxiliary data so
# the same order can be rebuilt when unflattening.
register_pytree_with_keys(
    collections.OrderedDict,
    lambda od: (
        tuple((DictKey(key), od[key]) for key in od.keys()),
        tuple(od.keys()),
    ),
    lambda stored_keys, children: collections.OrderedDict(
        safe_zip(stored_keys, children)),
)

def _flatten_defaultdict_with_keys(d):
  keys = tuple(sorted(d))
  return tuple((DictKey(k), d[k]) for k in keys), (d.default_factory, keys)

# defaultdict: the auxiliary data produced by the flatten function is the pair
# (default_factory, keys), which is all that is needed to rebuild the mapping
# around a new sequence of children.
register_pytree_with_keys(
    collections.defaultdict,
    _flatten_defaultdict_with_keys,
    lambda aux, children: collections.defaultdict(
        aux[0], safe_zip(aux[1], children)),
)


@export
def register_static(cls: type[H]) -> type[H]:
  """Registers `cls` as a pytree with no leaves.

  Instances are treated as static by :func:`jax.jit`, :func:`jax.pmap`, etc. This can
  be an alternative to labeling inputs as static using ``jit``'s ``static_argnums``
  and ``static_argnames`` kwargs, ``pmap``'s ``static_broadcasted_argnums``, etc.

  The registration flattens an instance to zero children and stores the
  instance itself as the node's auxiliary data; unflattening returns that
  stored instance.

  Args:
    cls: type to be registered as static. Must be hashable, as defined in
      https://docs.python.org/3/glossary.html#term-hashable.

  Returns:
    The input class ``cls`` is returned unchanged after being added to JAX's
    pytree registry. This allows ``register_static`` to be used as a decorator.

  Examples:
    >>> import jax
    >>> @jax.tree_util.register_static
    ... class StaticStr(str):
    ...   pass

    This static string can now be used directly in :func:`jax.jit`-compiled
    functions, without marking the variable static using ``static_argnums``:

    >>> @jax.jit
    ... def f(x, y, s):
    ...   return x + y if s == 'add' else x - y
    ...
    >>> f(1, 2, StaticStr('add'))
    Array(3, dtype=int32, weak_type=True)
  """
  def flatten_as_aux(obj):
    # No children: the whole object travels as auxiliary data.
    return (), obj

  def unflatten_from_aux(obj, _no_children):
    return obj

  register_pytree_with_keys(cls, flatten_as_aux, unflatten_from_aux)
  return cls


@export
def tree_flatten_with_path(
    tree: Any, is_leaf: Callable[..., bool] | None = None,
    is_leaf_takes_path: bool = False,
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
  """Alias of :func:`jax.tree.flatten_with_path`."""
  is_leaf_with_kp: Callable[[Any, Any], bool] | None = is_leaf
  if not is_leaf_takes_path and is_leaf is not None:
    is_leaf_with_kp = lambda _, x: is_leaf(x)
  return default_registry.flatten_with_path(tree, is_leaf_with_kp)


@export
def tree_leaves_with_path(
    tree: Any, is_leaf: Callable[..., bool] | None = None,
    is_leaf_takes_path: bool = False,
) -> list[tuple[KeyPath, Any]]:
  """Alias of :func:`jax.tree.leaves_with_path`."""
  return tree_flatten_with_path(tree, is_leaf, is_leaf_takes_path)[0]
generate_key_paths = tree_leaves_with_path


@export
def tree_map_with_path(
    f: Callable[..., Any],
    tree: Any,
    *rest: Any,
    is_leaf: Callable[..., bool] | None = None,
    is_leaf_takes_path: bool = False,
) -> Any:
  """Alias of :func:`jax.tree.map_with_path`."""
  keypath_leaves, treedef = tree_flatten_with_path(
      tree, is_leaf, is_leaf_takes_path
  )
  keypath_leaves = list(zip(*keypath_leaves))
  all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
  return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))


def _child_keys(pytree: Any) -> KeyPath:
  assert not treedef_is_strict_leaf(tree_structure(pytree))
  return tuple(k for k, _ in flatten_one_level_with_keys(pytree)[0])


def _prefix_error(
    key_path: KeyPath,
    prefix_tree: Any,
    full_tree: Any,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[Callable[[str], ValueError]]:
  # A leaf is a valid prefix of any tree:
  if treedef_is_strict_leaf(tree_structure(prefix_tree, is_leaf=is_leaf)):
    return

  # The subtrees may disagree because their roots are of different types:
  if type(prefix_tree) != type(full_tree):
    yield lambda name: ValueError(
      "pytree structure error: different types at key path\n"
      f"    {name}{keystr(key_path)}\n"
      f"At that key path, the prefix pytree {name} has a subtree of type\n"
      f"    {type(prefix_tree)}\n"
      f"but at the same key path the full pytree has a subtree of different type\n"
      f"    {type(full_tree)}.")
    return  # don't look for more errors in this subtree

  # Or they may disagree if their roots have different numbers or keys of
  # children. Because both prefix_tree and full_tree have the same type at this
  # point, and because prefix_tree is not a leaf, each can be flattened once:
  prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
  full_tree_children, full_tree_meta = flatten_one_level(full_tree)
  prefix_tree_children = tuple(prefix_tree_children)
  full_tree_children = tuple(full_tree_children)
  prefix_tree_keys = _child_keys(prefix_tree)
  full_tree_keys = _child_keys(full_tree)
  # First we check special case types (list and tuple, though if they were
  # pytrees we could check strings and sets here, basically Sequences) so that
  # we can report length disagreement rather than integer keys:
  if isinstance(prefix_tree, (list, tuple)):
    if len(prefix_tree) != len(full_tree):
      ty = type(prefix_tree)
      yield lambda name: ValueError(
          f"pytree structure error: different lengths of {ty.__name__} at key path\n"
          f"    {name}{keystr(key_path)}\n"
          f"At that key path, the prefix pytree {name} has a subtree of type "
          f"{ty.__name__} of length {len(prefix_tree)}, but the full pytree "
          f"has a subtree of the same type but of length {len(full_tree)}.")
      return  # don't look for more errors in this subtree
  else:
    # Next we handle the general case of checking child keys.
    try:
      diff = set(prefix_tree_keys).symmetric_difference(set(full_tree_keys))
    except:
      diff = None
    if len(prefix_tree_children) != len(full_tree_children):
      yield lambda name: ValueError(
        "pytree structure error: different numbers of pytree children at key path\n"
        f"    {name}{keystr(key_path)}\n"
        f"At that key path, the prefix pytree {name} has a subtree of type\n"
        f"    {type(prefix_tree)}\n"
        f"with {len(prefix_tree_children)} child keys\n"
        f"    {' '.join(str(k.key) for k in prefix_tree_keys)}\n"
        f"but at the same key path the full pytree has a subtree of the same "
        f"type but with {len(full_tree_children)} child keys\n"
        f"    {' '.join(str(k.key) for k in full_tree_keys)}\n"
        + ("" if diff is None else
           f"so the symmetric difference on key sets is\n"
           f"    {' '.join(str(k.key) for k in diff)}"))
      return  # don't look for more errors in this subtree

  # Or they may disagree if their roots have different pytree metadata:
  if prefix_tree_meta != full_tree_meta:
    prefix_tree_meta_str = str(prefix_tree_meta)
    full_tree_meta_str = str(full_tree_meta)
    metadata_diff = textwrap.indent(
        "\n".join(
            difflib.ndiff(prefix_tree_meta_str.splitlines(),
                          full_tree_meta_str.splitlines())),
        prefix="    ")
    yield lambda name: ValueError(
      "pytree structure error: different pytree metadata at key path\n"
      f"    {name}{keystr(key_path)}\n"
      f"At that key path, the prefix pytree {name} has a subtree of type\n"
      f"    {type(prefix_tree)}\n"
      f"with metadata\n"
      f"    {prefix_tree_meta_str}\n"
      f"but at the same key path the full pytree has a subtree of the same "
      f"type but with metadata\n"
      f"    {full_tree_meta_str}\n"
      f"so the diff in the metadata at these pytree nodes is\n"
      f"{metadata_diff}")
    return  # don't look for more errors in this subtree

  # If the root types and numbers of children agree, there must be an error
  # in a subtree, so recurse:
  assert prefix_tree_keys == full_tree_keys, \
    ("equal pytree nodes gave differing prefix_tree_keys: "
     f"{prefix_tree_keys} and {full_tree_keys}")
  for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
    yield from _prefix_error((*key_path, k), t1, t2)

# === flat tree ===

class FlatTree:
  """A FlatTree stores a treedef and a flat list of values. It's meant to be
  isomorphic to the corresponding pytree but we can map over it more easily.
  Compared to `tree_map`, FlatTree.map has these benefits:
    1. It doesn't touch user flatten/unflatten code (which shouldn't have side
       effects but sometimes does in practice).
    2. It can be faster, because it skips the recursive traversal.
    3. It actually obeys the functor rules. For example,
       `flat_tree.map(lambda x: (f(x), g(x))).unzip2()[0]` will give
       the same result as `flat_tree.map(f)`, whereas in the `tree_map` version
       the tuple-returning function would change the tree structure and `unzip`
       wouldn't be able to recover it.

  Attributes are ``tree`` (the ``PyTreeDef``) and ``vals`` (a tuple holding one
  value per leaf position). Instances compare and hash by ``(vals, tree)``.
  """
  def __init__(self, vals: Sequence, treedef: PyTreeDef):
    assert isinstance(treedef, pytree.PyTreeDef)
    self.tree = treedef
    # Stored as a tuple so the values can take part in __eq__ / __hash__.
    self.vals = tuple(vals)

  def map(self, f: Callable) -> FlatTree:
    """Applies ``f`` to every value, keeping the treedef."""
    return FlatTree([f(x) for x in self.vals], self.tree)

  def map2(self: FlatTree, f: Callable, t2: FlatTree) -> FlatTree:
    """Applies ``f`` to corresponding values of ``self`` and ``t2``."""
    n = len(self)
    assert len(t2) == n
    return FlatTree(
        [f(x1, x2) for x1, x2 in zip(self.vals, t2.vals)], self.tree)

  def map3(
      self: FlatTree, f: Callable, t2: FlatTree, t3: FlatTree) -> FlatTree:
    """Applies ``f`` to corresponding values of ``self``, ``t2`` and ``t3``."""
    n = len(self)
    assert len(t2) == n and len(t3) == n
    return FlatTree(
        [f(x1, x2, x3) for x1, x2, x3 in zip(self.vals, t2.vals, t3.vals)],
        self.tree)

  def zip(self, t2: FlatTree) -> FlatTree:
    """Pairs up corresponding values of ``self`` and ``t2``.

    This is the inverse of :meth:`unzip2`: the result keeps ``self``'s treedef
    and holds ``(x1, x2)`` tuples as its values.
    """
    assert len(t2) == len(self)
    # Inside a method body the name `zip` is the builtin, not this method.
    return FlatTree(zip(self.vals, t2.vals), self.tree)

  def unzip2(self: FlatTree) -> tuple[FlatTree, FlatTree]:
    """Splits a tree whose values are pairs into two trees of the same shape."""
    ys = []
    zs = []
    for y, z in self.vals:
      ys.append(y)
      zs.append(z)
    return FlatTree(ys, self.tree), FlatTree(zs, self.tree)

  # TODO: add zip3, unzip3 etc. as needed

  @staticmethod
  def pack(tree):
    """Combines a (possibly nested) tuple of FlatTrees into one FlatTree.

    A FlatTree is returned as is, a tuple is packed recursively, and an empty
    dict becomes the FlatTree of ``{}``.

    Raises:
      NotImplementedError: for any other input, including non-empty dicts.
    """
    # We could generalize this to arbitrary pytrees of FlatTree but tuples/dicts
    # are sufficient for now.
    if isinstance(tree, FlatTree):
      return tree
    if isinstance(tree, tuple):
      packed = [FlatTree.pack(child_tree) for child_tree in tree]
      vals = [v for child in packed for v in child.vals]
      return FlatTree(vals, treedef_tuple([child.tree for child in packed]))
    # only the empty dict case is handled for now
    if isinstance(tree, dict) and tree == {}:
      return FlatTree.flatten({})
    # Raise explicitly rather than `assert False`, which is stripped under -O
    # and would let this fall through and return None.
    raise NotImplementedError(
        "FlatTree.pack only supports FlatTree, tuple and empty dict inputs;"
        f" got {type(tree)}")

  def unpack(self: FlatTree) -> tuple[FlatTree, ...]:
    """Splits this tree into one FlatTree per child of the root node."""
    # TODO: this is O(N) not O(1) (with N as the number of leaves). If it
    # becomes a problem we can fix it with a fancier data tree.
    children = []
    offset = 0
    for tree in treedef_children(self.tree):
      new_offset = offset + tree.num_leaves
      children.append(FlatTree(self.vals[offset:new_offset], tree))
      offset = new_offset
    return tuple(children)

  @staticmethod
  def flatten(tree: PyTree) -> FlatTree:
    """Builds a FlatTree from a pytree via ``tree_flatten``."""
    return FlatTree(*tree_flatten(tree))

  def unflatten(self) -> PyTree:
    """Rebuilds the pytree from the stored treedef and values."""
    return tree_unflatten(self.tree, self.vals)

  def update_from_list(self, new_vals: list) -> FlatTree:
    """Returns a FlatTree with the same treedef and ``new_vals`` as values."""
    return FlatTree(new_vals, self.tree)

  def __len__(self):
    # Length is defined by the treedef, not by the stored values.
    return self.tree.num_leaves

  def __iter__(self):
    return iter(self.vals)

  def __eq__(self, other):
    return (isinstance(other, FlatTree)
            and self.vals == other.vals
            and self.tree == other.tree)

  def __hash__(self):
    return hash((self.vals, self.tree))
