# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable
import functools
from typing import Any

from jax._src import ad_util
from jax._src import api_util
from jax._src import core
from jax._src import custom_api_util
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.tree_util import (tree_flatten, tree_leaves, tree_map,
                                tree_structure, treedef_tuple, tree_unflatten,
                                PyTreeDef)


# Register this file with the exclusion lists kept by source_info_util and
# traceback_util.
# NOTE(review): presumably this keeps frames originating in this file out of
# user-facing source locations and filtered tracebacks -- confirm against
# those modules.
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)


# Throughout this module `map` and `zip` refer to util.safe_map and
# util.safe_zip; the builtins remain reachable as `unsafe_map`/`unsafe_zip`.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip


### bespoke linear_util and api_util deviations

class StoreEqual(lu.Store):
  """A store that tolerates repeated writes as long as the value is unchanged.

  The first call to `store` records the value. Any later call must supply a
  value that compares equal to the recorded one; otherwise
  `lu.StoreException` is raised. Reading is inherited from `lu.Store`.
  """

  def store(self, val):
    """Records `val`, rejecting it if it differs from an earlier value."""
    already_written = self._val is not lu._EMPTY_STORE_VALUE
    if already_written and val != self._val:
      raise lu.StoreException(
          f"Store assignment mismatch, from {self._val} to {val}")
    self._val = val

@util.curry
def transformation_with_aux(
    gen, fun: lu.WrappedFun, *gen_static_args) -> tuple[lu.WrappedFun, Any]:
  """Wraps `fun` with `gen` and exposes an auxiliary output through a thunk.

  Args:
    gen: the transformation handed to `fun.wrap`.
    fun: the wrapped function to transform.
    *gen_static_args: static arguments forwarded to `fun.wrap` alongside `gen`.

  Returns:
    A pair of the newly wrapped function and a zero-argument callable that
    reads the value held in the `StoreEqual` store passed to `fun.wrap`.
  """
  aux_store = StoreEqual()
  wrapped = fun.wrap(gen, gen_static_args, aux_store)

  def out_thunk():
    return aux_store.val

  return wrapped, out_thunk

# Local variant of `flatten_fun_nokwargs`, rebuilt on this module's
# `transformation_with_aux` from the first argument already bound on
# api_util's version.
# NOTE(review): presumably `.args[0]` is the underlying transformation
# function of api_util.flatten_fun_nokwargs -- confirm against api_util.
flatten_fun_nokwargs = transformation_with_aux(
    api_util.flatten_fun_nokwargs.args[0])


### api

@custom_api_util.register_custom_decorator_type
class custom_transpose:
  fun: Callable
  transpose: Callable | None = None

  def __init__(self, fun: Callable):
    functools.update_wrapper(self, fun)
    self.fun = fun

  __getattr__ = custom_api_util.forward_attr

  def def_transpose(self, transpose: Callable):
    self.transpose = transpose
    return transpose

  @traceback_util.api_boundary
  def __call__(self, out_types, res_arg, lin_arg):
    _, res_tree = tree_flatten(res_arg)
    _, lin_tree = tree_flatten(lin_arg)
    args_flat, in_tree = tree_flatten((res_arg, lin_arg))

    # TODO(frostig,mattjj): check that out_trees match
    # TODO(frostig,mattjj): could, and should, we avoid flattening
    # self.fun at this point?

    flat_fun, out_tree2 = flatten_fun_nokwargs(
        lu.wrap_init(
            self.fun,
            debug_info=api_util.debug_info("custom_transpose fun", self.fun,
                                           (res_arg, lin_arg), {})),
        in_tree)
    out_types_flat, out_tree = tree_flatten(out_types)
    transpose_wrapped = lu.wrap_init(
        self.transpose,
        debug_info=api_util.debug_info("custom_transpose transpose_fun",
                                       self.transpose,
                                       (res_arg, out_types), {})                             )
    out_flat = custom_transpose_p.bind(flat_fun, *args_flat,
                                       transpose=transpose_wrapped,
                                       out_types=tuple(out_types_flat),
                                       lin_tree=lin_tree,
                                       res_tree=res_tree,
                                       out_tree=out_tree)
    return tree_unflatten(out_tree, out_flat)


### utils

def tree_fill(x, treedef):
  """Returns a pytree with structure `treedef` in which every leaf is `x`."""
  leaves = [x for _ in range(treedef.num_leaves)]
  return tree_unflatten(treedef, leaves)

def tree_fill_like(x, tree):
  """Returns a pytree shaped like `tree` in which every leaf is `x`."""
  treedef = tree_structure(tree)
  return tree_fill(x, treedef)

def tree_broadcast(full_treedef, tree, is_leaf=None):
  """Expands `tree` so that it has the structure `full_treedef`.

  Each leaf of `tree` (as determined by `is_leaf`) is replicated across the
  matching subtree of a tree built from `full_treedef`.
  """
  template = tree_fill(0, full_treedef)

  def replicate(leaf, subtree):
    return tree_fill_like(leaf, subtree)

  return tree_map(replicate, tree, template, is_leaf=is_leaf)

def is_treedef_prefix(entire, prefix):
  """Reports whether treedef `prefix` is a tree prefix of treedef `entire`.

  Both treedefs are turned into concrete trees, and `tree_map` is run with the
  prefix tree first; a `ValueError` from that call is taken to mean that the
  structures do not prefix-match.
  """
  entire_tree = tree_fill(0, entire)
  prefix_tree = tree_fill(0, prefix)
  try:
    tree_map(lambda x, y: x, prefix_tree, entire_tree)
  except ValueError:
    return False
  else:
    return True

def rule_name(rule):
  """Returns `rule.__name__`, or a placeholder when `rule` has no such attribute."""
  try:
    return rule.__name__
  except AttributeError:
    return '<unnamed transpose rule>'

def check_transpose_rule_trees(rule: lu.WrappedFun,
                               lin_tree: PyTreeDef,
                               rule_out_tree: PyTreeDef):
  """Validates the output structure of a custom transpose rule.

  Args:
    rule: the wrapped transpose rule; its `debug_info`, when present, supplies
      the source description used in the error message.
    lin_tree: tree structure of the primal function's linear inputs.
    rule_out_tree: tree structure of the value returned by the rule.

  Raises:
    TypeError: if `rule_out_tree` is not a tree prefix of `lin_tree`.
  """
  if is_treedef_prefix(lin_tree, rule_out_tree):
    return

  if rule.debug_info:
    rule_name = rule.debug_info.func_src_info
  else:
    rule_name = "<unknown>"
  raise TypeError(
      'structure of custom transpose rule\'s output does not prefix-match '
      'structure of primal function\'s linear inputs under '
      f'custom transpose rule ({rule_name}).\n'
      f'Transpose rule output: {rule_out_tree}\n'
      f'Linear primal inputs: {lin_tree}')

def make_transpose_from_thunk(thunk: Callable,
                              lin_tree: PyTreeDef) -> lu.WrappedFun:
  """Builds a wrapped transpose function from a jaxpr-producing thunk.

  The thunk is invoked immediately and must return a `(jaxpr, consts)` pair.

  Args:
    thunk: zero-argument callable returning the transpose jaxpr and its
      constants.
    lin_tree: tree structure used to unflatten the cotangents produced by
      evaluating the jaxpr.

  Returns:
    An `lu.WrappedFun` taking `(res_arg, ct_out)` and carrying the jaxpr's
    debug info.
  """
  open_jaxpr, transpose_consts = thunk()
  # The constants are passed as leading arguments on every call below.
  # NOTE(review): presumably convert_constvars_jaxpr turns the jaxpr's
  # constvars into leading inputs -- confirm against partial_eval.
  closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(open_jaxpr), ())

  def transpose(res_arg, ct_out):
    flat_inputs = tree_leaves((res_arg, ct_out))
    evaluate = core.jaxpr_as_fun(closed_jaxpr)
    flat_cts = evaluate(*transpose_consts, *flat_inputs)
    return tree_unflatten(lin_tree, flat_cts)

  return lu.wrap_init(transpose, debug_info=closed_jaxpr.jaxpr.debug_info)


### custom_transpose primitive and rules

class CustomTransposePrimitive(core.Primitive):
  """Primitive whose first bound argument is the callable to evaluate.

  It produces multiple results and is flagged as neither a call primitive nor
  a map primitive.
  """
  call_primitive = False
  map_primitive = False
  multiple_results = True

  def bind(self, *args, **params):
    """Forwards directly to `_true_bind`."""
    # NOTE(review): presumably this bypasses argument preprocessing in
    # core.Primitive.bind, since the first argument is a callable rather than
    # a value -- confirm against core.Primitive.
    return self._true_bind(*args, **params)

  def bind_with_trace(self, trace, call_args, params):
    """Splits off the leading callable and dispatches to the trace."""
    call = call_args[0]
    tracers = call_args[1:]
    return trace.process_custom_transpose(self, call, tracers, **params)

  # TODO(frostig,mattjj): consider keeping `call` as a named parameter
  # instead of following this "call primitive" convention.
  def get_bind_params(self, params):
    """Converts equation params into `([call], bind_params)`.

    When `params` holds `call_jaxpr`, the jaxpr and `transpose_jaxpr_thunk`
    are turned back into wrapped functions. Otherwise the `call` entry is
    simply moved out of the params.
    """
    new_params: dict[str, Any] = dict(params)

    if 'call_jaxpr' not in params:
      assert 'transpose' in params
      call = new_params.pop("call")
      return [call], new_params

    assert 'transpose_jaxpr_thunk' in params
    thunk = new_params.pop('transpose_jaxpr_thunk')
    new_params['transpose'] = make_transpose_from_thunk(
        thunk, new_params['lin_tree'])
    call_jaxpr: core.ClosedJaxpr = new_params.pop('call_jaxpr')
    call = lu.wrap_init(core.jaxpr_as_fun(call_jaxpr),
                        debug_info=call_jaxpr.jaxpr.debug_info)
    return [call], new_params


# TODO(frostig,mattjj): reinstate checks
def custom_transpose_typecheck(_, *in_atoms, out_types, **params):
  """Typecheck rule: reports `out_types` as the output types, with no effects.

  The inputs and all other parameters are currently ignored.
  """
  del params, in_atoms  # unused until the checks are reinstated
  effects = core.no_effects
  return out_types, effects


def custom_transpose_transpose_rule(
    cts, *args, out_types, res_tree, lin_tree, out_tree, **params):
  """Transpose rule for `custom_transpose_p`: runs the user's transpose rule.

  Args:
    cts: flat output cotangents. `ad_util.Zero` entries are replaced by
      `ad_util.zeros_like_aval` of their aval before the rule is called.
    *args: flat inputs, unflattened here as `(res_arg, lin_arg)` using
      `res_tree` and `lin_tree`.
    out_types: not used by this rule.
    res_tree: tree structure of the residual argument.
    lin_tree: tree structure of the linear argument.
    out_tree: tree structure used to unflatten `cts`.
    **params: either `transpose_jaxpr_thunk` (alongside `call_jaxpr`) or
      `transpose` (alongside `call`).

  Returns:
    A list with one `None` per residual leaf, followed by the flattened
    cotangents for the linear inputs.
  """
  transpose: lu.WrappedFun
  if 'transpose_jaxpr_thunk' not in params:
    assert 'call' in params
    transpose = params['transpose']
  else:
    assert 'call_jaxpr' in params
    transpose = make_transpose_from_thunk(
        params['transpose_jaxpr_thunk'], lin_tree)

  call_in_tree = treedef_tuple((res_tree, lin_tree))

  # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
  # to which we are transposing (via `ad.is_undefined_primal`).
  # Consider passing this information to the custom transpose rule?

  res_arg, lin_arg = tree_unflatten(call_in_tree, args)
  assert not any(ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

  def instantiate(ct):
    # Symbolic zeros become concrete zeros of the same aval.
    if type(ct) is ad_util.Zero:
      return ad_util.zeros_like_aval(ct.aval)
    return ct

  ct_out = tree_unflatten(out_tree, [instantiate(ct) for ct in cts])
  ct_lin = transpose.call_wrapped(res_arg, ct_out)
  check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
  ct_lin = tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None)

  # When the transpose returns None, we treat that as a Zero, except when the
  # input is also None. In that case, the cotangent corresponding to that input
  # should be dropped.
  zero = object()

  def mark_zero(l, ct):
    if ct is None and l is not None:
      return zero
    return ct

  ct_lin = tree_map(mark_zero, lin_arg, ct_lin,
                    is_leaf=ad.is_undefined_primal)

  res_cts = [None] * res_tree.num_leaves
  lin_cts = [None if ct is zero else ct for ct in tree_leaves(ct_lin)]
  return res_cts + lin_cts


def custom_transpose_lowering(*args, call_jaxpr, **params):
  """Evaluates `call_jaxpr` on `args`; all other parameters are ignored."""
  del params  # unused
  call = core.jaxpr_as_fun(call_jaxpr)
  return call(*args)


# The primitive instance bound by `custom_transpose.__call__`.
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
# Hook up the rules defined above: typechecking, transposition, and lowering
# (the lowering rule is derived from `custom_transpose_lowering` via
# `mlir.lower_fun`).
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
    custom_transpose_p,
    mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
# NOTE(review): the effect of registering as an "initial style" primitive is
# defined in pxla and not visible here -- confirm there.
pxla.register_initial_style_primitive(custom_transpose_p)
