# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
import math
import operator

import numpy as np

from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import fft as jnp_fft
from jax._src.numpy.util import (
    promote_dtypes_complex, promote_dtypes_inexact, ensure_arraylike)
from jax._src.util import canonicalize_axis, canonicalize_axis_tuple
from jax._src.typing import Array

def _W4(N: int, k: Array) -> Array:
  """Returns the complex twiddle factors ``exp(-0.5j * pi * k / N)``.

  ``N`` and ``k`` are first promoted to a common complex dtype, so the result
  is always complex.
  """
  size, index = promote_dtypes_complex(N, k)
  phase = -.5j * np.pi * index / size
  return jnp.exp(phase)

def _dct_interleave(x: Array, axis: int) -> Array:
  """Reorders ``x`` along ``axis``: even-indexed entries, then reversed odd ones."""
  even_entries = lax.slice_in_dim(x, None, None, stride=2, axis=axis)
  odd_entries = lax.slice_in_dim(x, 1, None, stride=2, axis=axis)
  odd_reversed = lax.rev(odd_entries, (axis,))
  return lax.concatenate([even_entries, odd_reversed], axis)

def _dct_ortho_norm(out: Array, axis: int) -> Array:
  """Divides ``out`` along ``axis`` by ``sqrt(factor * n)``.

  ``n`` is ``out.shape[axis]``; ``factor`` is 4 for the first entry along
  ``axis`` and 2 for every other entry. ``axis`` is compared against
  ``range(out.ndim)``, so it is expected to be non-negative.
  """
  n = out.shape[axis]
  first = lax.full((1,), 4, out.dtype)
  rest = lax.full((n - 1,), 2, out.dtype)
  scale = lax.concatenate([first, rest], 0)
  # Add singleton dimensions so ``scale`` broadcasts against ``out``.
  other_dims = [d for d in range(out.ndim) if d != axis]
  scale = lax.expand_dims(scale, other_dims)
  return out / lax.sqrt(scale * n)

# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)


def dct(x: Array, type: int = 2, n: int | None = None,
        axis: int = -1, norm: str | None = None) -> Array:
  """Computes the discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:
    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x))
    [[ 6.43  3.56 -2.86]
     [-1.75  1.55 -1.4 ]
     [ 1.33 -2.01 -0.82]]

    When ``n`` smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2, axis=0))
    [[ 3.09  4.4  -2.81]
     [ 2.41  2.62  0.76]]

    When ``n`` larger than ``x.shape[axis]`` and ``axis=1``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=4, axis=1))
    [[ 6.43  4.88  0.04 -3.3 ]
     [-1.75  0.73  1.01 -2.18]
     [ 1.33 -1.05 -2.34 -0.07]]
  """
  x = ensure_arraylike("idctn", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")

  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    x = lax.pad(x, jnp.array(0, x.dtype),
                [(0, n - x.shape[axis] if a == axis else 0, 0)
                 for a in range(x.ndim)])

  N = x.shape[axis]
  v = _dct_interleave(x, axis)
  V = jnp_fft.fft(v, axis=axis)
  k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
  out = V * _W4(N, k)
  out = 2 * out.real
  if norm == 'ortho':
    out = _dct_ortho_norm(out, axis)
  return out


def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
  """Type-2 DCT over exactly two axes of ``x`` using a single ``fftn`` call.

  Args:
    x: input array.
    axes: the two axes to transform; each is canonicalized against ``x.ndim``.
    norm: if ``'ortho'``, the ortho scaling is applied along both axes;
      otherwise the result is returned unscaled.

  Returns:
    The real-valued transform of ``x`` along the two axes.
  """
  axis1, axis2 = (canonicalize_axis(a, x.ndim) for a in axes)
  N1 = x.shape[axis1]
  N2 = x.shape[axis2]

  # Reorder along both axes, then take one two-dimensional FFT.
  reordered = _dct_interleave(_dct_interleave(x, axis1), axis2)
  V = jnp_fft.fftn(reordered, axes=axes)

  def frequency_index(size: int, axis: int) -> Array:
    # Index vector shaped to broadcast against V along ``axis`` only.
    other_dims = [d for d in range(x.ndim) if d != axis]
    return lax.expand_dims(jnp.arange(size, dtype=V.dtype), other_dims)

  k1 = frequency_index(N1, axis1)
  k2 = frequency_index(N2, axis2)

  w1 = _W4(N1, k1)
  w2_forward = _W4(N2, k2)
  w2_backward = _W4(N2, -k2)
  # V flipped along axis2 and rotated by one position along the same axis.
  V_mirrored = jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2)

  combined = w1 * (w2_forward * V + w2_backward * V_mirrored)
  result = 2 * combined.real
  if norm != 'ortho':
    return result
  return _dct_ortho_norm(_dct_ortho_norm(result, axis1), axis2)


def dctn(x: Array, type: int = 2,
         s: Sequence[int] | None=None,
         axes: Sequence[int] | None = None,
         norm: str | None = None) -> Array:
  """Computes the multidimensional discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If not
      specified, it will default to the shape of ``x`` along the specified ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:

    ``jax.scipy.fft.dctn`` computes the transform along both the axes by default
    when ``axes`` argument is ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x))
    [[ 12.01   6.2  -10.17]
     [  8.84   9.65  -3.54]
     [ 11.25  -1.54  -0.88]]

    When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
    and dimension along ``axis 1`` will be same as that of input.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2]))
    [[ 9.36 10.22 -8.53]
     [11.57  2.85 -2.06]]

    When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
    be ``2`` and dimension along ``axis 0`` will  be same as that of input.
    Also when ``axes=[1]``, transform will be computed only along ``axis 1``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2], axes=[1]))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
    [[  9.36  11.23   2.12 -10.97]
     [ 11.57   5.86  -1.37  -1.58]]
  """
  x = ensure_arraylike("idctn", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")

  if s is not None:
    try:
      s = list(s)
    except TypeError:
      assert not isinstance(s, Sequence)
      s = [operator.index(s)]

  axes = canonicalize_axis_tuple(axes, x.ndim)

  if len(axes) == 1:
    return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

  if s is not None:
    ns = dict(zip(axes, s))
    pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pads)

  if len(axes) == 2:
    return _dct2(x, axes=axes, norm=norm)

  # compose high-D DCTs from 2D and 1D DCTs:
  for axes_block in [axes[i:i+2] for i in range(0, len(axes), 2)]:
    x = dctn(x, axes=axes_block, norm=norm)
  return x


def idct(x: Array, type: int = 2, n: int | None = None,
         axis: int = -1, norm: str | None = None) -> Array:
  """Computes the inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...    print(jax.scipy.fft.idct(x))
    [[ 0.78  0.41 -0.39]
     [-0.12  0.31 -0.23]
     [ 0.17 -0.3  -0.11]]

    When ``n`` smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...    print(jax.scipy.fft.idct(x, n=2))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...    print(jax.scipy.fft.idct(x, n=2, axis=0))
    [[ 0.38  0.57 -0.45]
     [ 0.43  0.44  0.24]]

    When ``n`` larger than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...    print(jax.scipy.fft.idct(x, n=4, axis=0))
    [[ 0.1   0.38 -0.16]
     [ 0.28  0.18 -0.26]
     [ 0.3   0.15 -0.08]
     [ 0.13  0.3   0.29]]

    ``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dct``

    >>> x_dct = jax.scipy.fft.dct(x)
    >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
    Array(True, dtype=bool)
  """
  x = ensure_arraylike("idct", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")

  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    x = lax.pad(x, jnp.array(0, x.dtype),
                [(0, n - x.shape[axis] if a == axis else 0, 0)
                 for a in range(x.ndim)])
  N = x.shape[axis]
  x, = promote_dtypes_inexact(x)
  if norm is None or norm == 'backward':
    x = _dct_ortho_norm(x, axis)
  x = _dct_ortho_norm(x, axis)

  k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis])
  # everything is complex from here...
  w4 = _W4(N,k)
  x = x.astype(w4.dtype)
  x = x / (_W4(N, k))
  x = x * 2 * N

  x = jnp_fft.ifft(x, axis=axis)
  # convert back to reals..
  out = _dct_deinterleave(x.real, axis)
  return out


def idctn(x: Array, type: int = 2,
          s: Sequence[int] | None=None,
          axes: Sequence[int] | None = None,
          norm: str | None = None) -> Array:
  """Computes the multidimensional inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If not
      specified, it will default to the shape of ``x`` along the specified ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT

  Examples:

    ``jax.scipy.fft.idctn`` computes the transform along both the axes by default
    when ``axes`` argument is ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...    print(jax.scipy.fft.idctn(x))
    [[ 0.12  0.11 -0.15]
     [ 0.07  0.17 -0.03]
     [ 0.19 -0.07 -0.02]]

    When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
    and dimension along ``axis 1`` will be the same as that of input.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...  print(jax.scipy.fft.idctn(x, s=[2]))
    [[ 0.15  0.21 -0.18]
     [ 0.24 -0.01 -0.02]]

    When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
    be ``2`` and dimension along ``axis 0`` will  be same as that of input.
    Also when ``axes=[1]``, transform will be computed only along ``axis 1``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...  print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...  print(jax.scipy.fft.idctn(x, s=[2, 4]))
    [[ 0.1   0.18  0.07 -0.16]
     [ 0.2   0.06 -0.03 -0.01]]

    ``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dctn``

    >>> x_dctn = jax.scipy.fft.dctn(x)
    >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
    Array(True, dtype=bool)
  """
  x = ensure_arraylike("idctn", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")

  if s is not None:
    try:
      s = list(s)
    except TypeError:
      assert not isinstance(s, Sequence)
      s = [operator.index(s)]

  axes = canonicalize_axis_tuple(axes, x.ndim)

  if len(axes) == 1:
    return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

  if s is not None:
    ns = dict(zip(axes, s))
    pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pads)

  # compose high-D DCTs from 1D DCTs:
  for axis in axes:
    x = idct(x, axis=axis, norm=norm)
  return x


def _dct_deinterleave(x: Array, axis: int) -> Array:
  """Undoes the even/reversed-odd reordering along ``axis``.

  The first ``ceil(n / 2)`` entries along ``axis`` are written to the even
  positions of the result, and the remaining entries, reversed, to the odd
  positions. ``axis`` is compared against ``range(x.ndim)``, so it is
  expected to be non-negative.
  """
  split = math.ceil(x.shape[axis] / 2)

  def along_axis(axis_slice: slice) -> tuple[slice, ...]:
    # Full slices everywhere except ``axis``, which gets ``axis_slice``.
    return tuple(
        axis_slice if dim == axis else slice(None, None, None)
        for dim in range(len(x.shape)))

  leading = x[along_axis(slice(None, split, 1))]
  trailing_reversed = lax.rev(x[along_axis(slice(split, None, 1))], (axis,))

  result = jnp.zeros(x.shape, dtype=x.dtype)
  result = result.at[along_axis(slice(None, None, 2))].set(leading)
  result = result.at[along_axis(slice(1, None, 2))].set(trailing_reversed)
  return result
