
    uki^                    ,   d Z ddlmZ ddlmZ ddlZddlZddlmZ ddl	Z
ddlmZ ddlmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlmZmZmZ ddlm Z  ddl!m"Z"m#Z#  e#d      Z$dZ%e$ G d d             Z&e$dd	 	 	 dd       Z'y)z#Tools to create numpy-style ufuncs.    )annotations)CallableN)Any)api)Array	ArrayLike	DTypeLike)control_flow)slicing)lax)indexing)	lax_numpy)	_moveaxis)check_arraylike_broadcast_to_where)	vectorize)canonicalize_axis
set_modulez	jax.numpyzBecause JAX arrays are immutable, jnp.ufunc.at() cannot operate inplace like
np.ufunc.at(). Instead, you can pass inplace=False and capture the result; e.g.
>>> arr = jnp.add.at(arr, ind, val, inplace=False)
c            	         e Zd ZdZddddddddd	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d#dZ ed       Z ed       Z ed       Z ed       Z	 ed	       Z
d$d
Zd%dZd&dZdddd'dZ ej                   dg      d        Z ej                   g d      	 	 	 	 d(	 	 	 	 	 	 	 	 	 	 	 d)d       Z	 	 	 d*	 	 	 	 	 	 	 d+dZ ej                   g d      	 	 d,	 	 	 d-d       Z	 	 d.	 	 	 d/dZ ej                   dgdg      d0dd	 	 	 d1d       Zd2dZ ej                   g d      	 	 d,	 	 	 	 	 d3d       Z	 	 d.	 	 	 d4d Z ej                   dg!      d5d"       Zy)6ufuncat
  Universal functions which operation element-by-element on arrays.

  JAX implementation of :class:`numpy.ufunc`.

  This is a class for JAX-backed implementations of NumPy's ufunc APIs.
  Most users will never need to instantiate :class:`ufunc`, but rather
  will use the pre-defined ufuncs in :mod:`jax.numpy`.

  For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`.

  Examples:
    Universal functions are functions that apply element-wise to broadcasted
    arrays, but they also come with a number of extra attributes and methods.

    As an example, consider the function :obj:`jax.numpy.add`. The object
    acts as a function that applies addition to broadcasted arrays in an
    element-wise manner:

    >>> x = jnp.array([1, 2, 3, 4, 5])
    >>> jnp.add(x, 1)
    Array([2, 3, 4, 5, 6], dtype=int32)

    Each :class:`ufunc` object includes a number of attributes that describe
    its behavior:

    >>> jnp.add.nin  # number of inputs
    2
    >>> jnp.add.nout  # number of outputs
    1
    >>> jnp.add.identity  # identity value, or None if no identity exists
    0

    Binary ufuncs like :obj:`jax.numpy.add` include  number of methods to
    apply the function to arrays in different manners.

    The :meth:`~ufunc.outer` method applies the function to the
    pair-wise outer-product of the input array values:

    >>> jnp.add.outer(x, x)
    Array([[ 2,  3,  4,  5,  6],
           [ 3,  4,  5,  6,  7],
           [ 4,  5,  6,  7,  8],
           [ 5,  6,  7,  8,  9],
           [ 6,  7,  8,  9, 10]], dtype=int32)

    The :meth:`ufunc.reduce` method performs a reduction over the array.
    For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``:

    >>> jnp.add.reduce(x)
    Array(15, dtype=int32)

    The :meth:`ufunc.accumulate` method performs a cumulative reduction
    over the array. For example, :meth:`jnp.add.accumulate` is equivalent
    to :func:`jax.numpy.cumulative_sum`:

    >>> jnp.add.accumulate(x)
    Array([ 1,  3,  6, 10, 15], dtype=int32)

    The :meth:`ufunc.at` method applies the function at particular indices in the
    array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`:

    >>> jnp.add.at(x, 0, 100, inplace=False)
    Array([101,   2,   3,   4,   5], dtype=int32)

    And the :meth:`ufunc.reduceat` method performs a number of ``reduce``
    operations between specified indices of an array; for ``jnp.add`` the
    operation is similar to :func:`jax.ops.segment_sum`:

    >>> jnp.add.reduceat(x, jnp.array([0, 2]))
    Array([ 3, 12], dtype=int32)

    In this case, the first element is ``x[0:2].sum()``, and the second element
    is ``x[2:].sum()``.
  N)namenargsidentitycallreduce
accumulateatreduceatc                  |j                   | _         |xs |j                  | _        |t        j                  |      t        j                  |      t        j                  |xs |      ||||	|
|d
| _        y )N)
funcninnoutr   r   r   r   r   r   r   )__doc____name__operatorindex_ufunc__static_props)selfr!   r"   r#   r   r   r   r   r   r   r   r   s               S/home/cdr/jupyterlab/.venv/lib/python3.12/site-packages/jax/_src/numpy/ufunc_api.py__init__zufunc.__init__|   sh     <<DL)DMMDM ^^C nnT"~~els+D    c                     | j                   d   S )Nr!   r(   r)   s    r*   <lambda>zufunc.<lambda>   s     3 3F ; r,   c                     | j                   d   S )Nr"   r.   r/   s    r*   r0   zufunc.<lambda>   s    d11%8 r,   c                     | j                   d   S )Nr#   r.   r/   s    r*   r0   zufunc.<lambda>   s    t226: r,   c                     | j                   d   S )Nr   r.   r/   s    r*   r0   zufunc.<lambda>   s     3 3G < r,   c                     | j                   d   S )Nr   r.   r/   s    r*   r0   zufunc.<lambda>   s    4#6#6z#B r,   c                    t        | j                  | j                  | j                  | j                  | j
                  | j                  f      S N)hash_funcr%   r   r"   r#   r   r/   s    r*   __hash__zufunc.__hash__   s;     T]]DMM499djj2 3 3r,   c                8   t        |t              xr | j                  | j                  | j                  | j
                  | j                  | j                  f|j                  |j                  |j                  |j
                  |j                  |j                  fk(  S r6   )
isinstancer   r8   r%   r   r"   r#   r   )r)   others     r*   __eq__zufunc.__eq__   sn    eU# Yzz4==$--499djjQ{{ENNENNEIIuzz5;;WXYr,   c                "    d| j                    dS )Nz<jnp.ufunc 'z'>)r%   r/   s    r*   __repr__zufunc.__repr__   s    $--++r,   )outwherec                   t        | j                  g|  |t        d|        |t        d|        | j                  d   xs | j                  } || S )Nout argument of zwhere argument of r   )r   r%   NotImplementedErrorr(   _call_vectorized)r)   r@   rA   argsr   s        r*   __call__zufunc.__call__   si    DMM)D)
"24& 9::"4TF ;<<v&?$*?*?D;r,   r)   )static_argnamesc                2     t        | j                        | S r6   )r   r8   )r)   rF   s     r*   rE   zufunc._call_vectorized   s     9TZZ $''r,   )r)   axisdtyper@   keepdimsr   c                \   t        | j                   d|       | j                  dk7  rt        d      | j                  dk7  rt        d      |t        d| j                   d      |t        | j                   d|       |}t        | j                   d|       | j                  |t        d| j                  d	      t        j                  |      t        k7  r!t        d
t        j                  |             | j                  d   xs | j                  } |||||||      S )aF  Reduction operation derived from a binary function.

    JAX implementation of :meth:`numpy.ufunc.reduce`.

    Args:
      a: Input array.
      axis: integer specifying the axis over which to reduce. default=0
      dtype: optionally specify the type of the output array.
      out: Unused by JAX
      keepdims: If True, reduced axes are left in the result with size 1.
        If False (default) then reduced axes are squeezed out.
      initial: int or array, Default=None. Initial value for the reduction.
      where: boolean mask, default=None. The elements to be used in the sum. Array
        should be broadcast compatible to the input.

    Returns:
      array containing the result of the reduction operation.

    Examples:
      Consider the following array:

      >>> x = jnp.array([[1, 2, 3],
      ...                [4, 5, 6]])

      :meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum`
      along ``axis=0``:

      >>> jnp.add.reduce(x)
      Array([5, 7, 9], dtype=int32)
      >>> x.sum(0)
      Array([5, 7, 9], dtype=int32)

      Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to
      :func:`jax.numpy.all`:

      >>> jnp.logical_and.reduce(x > 2)
      Array([False, False,  True], dtype=bool)
      >>> jnp.all(x > 2, axis=0)
      Array([False, False,  True], dtype=bool)

      Some reductions do not correspond to any built-in aggregation function;
      for example here is the reduction of :func:`jax.numpy.bitwise_or` along
      the first axis of ``x``:

      >>> jnp.bitwise_or.reduce(x, axis=1)
      Array([3, 7], dtype=int32)
    z.reduce   z'reduce only supported for binary ufuncs   z<reduce only supported for functions returning a single valuerC   z	.reduce()zreduction operation zP does not have an identity, so to use a where mask one has to specify 'initial'.z/where argument must have dtype=bool; got dtype=r   )rJ   rK   rL   initialrA   )r   r%   r"   
ValueErrorr#   rD   r   r   _dtypeboolr(   _reduce_via_scan)	r)   arJ   rK   r@   rL   rP   rA   r   s	            r*   r   zufunc.reduce   s2   h t}}oW-q1xx1}@AAyyA~UVV
"24==/ KLLw/9w/7		7?//@ AP P Q 	Q	E	d	"J3::V[K\J]^__  *Cd.C.CF!$ehW\]]r,   c                     j                   dk(  r j                  dk(  sJ t        j                        | j                  }Qt        j                   j                  t        j                        t        j                              j                  t        j                        t        |t              rt        fd|D              }t        d      |9|rdj                  z  }nd}j!                         j!                         d}nct#        |j                        }|r'g j                  d | dj                  |dz   d  }n$g j                  d | j                  |dz   d  }|dk7  rt%        |d      t%        |d      j                  d   dk(  r2|t'        d j(                   d	      t        j*                  ||      S  fd
}|d}	d   }
nd}	|}
t        t        j                  |
      j-                        j                  dd        }
t/        j0                  |	j                  d   ||
      }|r|j3                  |      }|S )NrN   rO   c              3  J   K   | ]  }t        |j                          y wr6   )r   ndim).0rU   arrs     r*   	<genexpr>z)ufunc._reduce_via_scan.<locals>.<genexpr>  s     @a$Q1@s    #ztuple of axes)rO    r   z'zero-size array to reduction operation z which has no identityc           	          ||    j                              S t        |     ||    j                              |      S r6   )astyper   )ivalrZ   rK   r)   rA   s     r*   body_funz(ufunc._reduce_via_scan.<locals>.body_fun,  sI    	CQu-..eAhS#a&--*> ?EEr,   )r"   r#   r   asarrayr   r   
eval_shaper8   _onerK   r   shaper;   tuplerD   rX   ravelr   r   rQ   r%   fullr^   r
   	fori_loopreshape)r)   rZ   rJ   rK   rL   rP   rA   final_shapera   start_indexstart_valueresults   `` `  `     r*   rT   zufunc._reduce_via_scan  s?    88q=TYY!^++
++c
Cg}nnTZZ#FLLeE399-e$@4@@d00		SXXoIIKc		dtSXX.d	C		%4(C!Cciiq	.BC@		%4(@399TAXY+?@ qyc4#c		%q)
yy|q	B4==/QghiiXXk7E22F kFkkkK 8 ? ? F		RSRTVK##K1xUF~~k*fMr,   )r)   rJ   rK   c                    | j                   dk7  rt        d      | j                  dk7  rt        d      |t        d| j                   d      | j
                  d   xs | j                  } ||||      S )	a  Accumulate operation derived from binary ufunc.

    JAX implementation of :func:`numpy.ufunc.accumulate`.

    Args:
      a: N-dimensional array over which to accumulate.
      axis: integer axis over which accumulation will be performed (default = 0)
      dtype: optionally specify the type of the output array.
      out: Unused by JAX

    Returns:
      An array containing the accumulated result.

    Examples:
      Consider the following array:

      >>> x = jnp.array([[1, 2, 3],
      ...                [4, 5, 6]])

      :meth:`jax.numpy.add.accumulate` is equivalent to
      :func:`jax.numpy.cumsum` along the specified axis:
      >>> jnp.add.accumulate(x, axis=1)
      Array([[ 1,  3,  6],
             [ 4,  9, 15]], dtype=int32)
      >>> jnp.cumsum(x, axis=1)
      Array([[ 1,  3,  6],
             [ 4,  9, 15]], dtype=int32)

      Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to
      :func:`jax.numpy.cumprod` along the specified axis:

      >>> jnp.multiply.accumulate(x, axis=1)
      Array([[  1,   2,   6],
             [  4,  20, 120]], dtype=int32)
      >>> jnp.cumprod(x, axis=1)
      Array([[  1,   2,   6],
             [  4,  20, 120]], dtype=int32)

      For other binary ufuncs, the accumulation is an operation not available
      via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate`
      is essentially a bitwise cumulative ``any``:

      >>> jnp.bitwise_or.accumulate(x, axis=1)
      Array([[1, 3, 3],
             [4, 5, 7]], dtype=int32)
    rN   z+accumulate only supported for binary ufuncsrO   z@accumulate only supported for functions returning a single valuerC   z.accumulate()r   rJ   rK   )r"   rQ   r#   rD   r%   r(   _accumulate_via_scan)r)   rU   rJ   rK   r@   r   s         r*   r   zufunc.accumulateA  s{    b xx1}DEEyyA~YZZ
"24==/ OPP$$\2Od6O6OJad%00r,   c                     j                   dk(  r j                  dk(  sJ t         j                   d       t	        j
                        Qt        j                   j                  t	        j                        t	        j                              j                  |t        |t              rt        d      t        |t        j                               }j"                  dk(  r!t	        j$                  j&                  d      S t)        |d       fd}t+        j,                  |dd   j/                        fd j&                  d         \  }}t)        |d|      S )NrN   rO   z.accumulatez'accumulate does not allow multiple axesr   c           
         | \  }}t        |dk(  d   j                         |j                        |   j                                    }|dz   |f|fS )Nr   rO   )r   r^   )carry_r_   xyrZ   rK   r)   s        r*   scan_funz,ufunc._accumulate_via_scan.<locals>.scan_fun  sZ    da
aQu-tAHHUOSV]]SXEY/Z
[a!eQZ]r,   )length)r"   r#   r   r%   r   rb   r   rc   r8   rd   rK   r;   rf   rQ   r   nprX   sizerh   re   r   r
   scanr^   )r)   rZ   rJ   rK   rx   ru   rn   s   `` `   r*   rq   zufunc._accumulate_via_scan{  s   88q=TYY!^++t}}o[137
++c
C}nnTZZ#FLLe|z$.@AAT2773<0D
xx1}XXciiE**
Cq
!C !!(QAe0D,EtTWT]T]^_T`aIAvVQ%%r,   inplace)static_argnumsrH   T)r}   c                  |rt        t              | j                  d   xs | j                  }|	 |||      S  ||||      S )a  Update elements of an array via the specified unary or binary ufunc.

    JAX implementation of :func:`numpy.ufunc.at`.

    Note:
      :meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable,
      so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX
      will return the updated value, but requires explicitly passing ``inplace=False``
      as a reminder of this difference.

    Args:
      a: N-dimensional array to update
      indices: index, slice, or tuple of indices and slices.
      b: array of values for binary ufunc updates.
      inplace: must be set to False to indicate that an updated copy will be returned.

    Returns:
      an updated copy of the input array.

    Examples:

      Add numbers to specified indices:

      >>> x = jnp.ones(10, dtype=int)
      >>> indices = jnp.array([2, 5, 7])
      >>> values = jnp.array([10, 20, 30])
      >>> jnp.add.at(x, indices, values, inplace=False)
      Array([ 1,  1, 11,  1,  1, 21,  1, 31,  1,  1], dtype=int32)

      This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method
      called this way:

      >>> x.at[indices].add(values)
      Array([ 1,  1, 11,  1,  1, 21,  1, 31,  1,  1], dtype=int32)
    r   )rD   _AT_INPLACE_WARNINGr(   _at_via_scan)r)   rU   indicesbr}   r   s         r*   r   zufunc.at  sM    L  344			T	"	7d&7&7BY2a>=Bq'1,==r,   c           	     n    t              dv sJ t         j                   d|g  t        j                   j
                  t        j                  |      gd D         j                  t        j                  |      j                        }t        fdD              t        j                        s|S D cg c](  }t        |t              rt!        j"                  |      * }}|xr t        j$                  | }|s@|j&                     j)                    |j&                     j+                         g       S rmt-        d   g |d   j"                  t        |      d        } |j.                  t1        j2                  |      gd   j"                  t        |      d   fD cg c].  }t        |t              r|nt-        ||      j5                         0 c} fd}	t7        j8                  |	d|fd t        d               \  }
}|
d   S c c}w c c}w )N>   r   rO   z.atc              3  F   K   | ]  }t        j                  |        y wr6   )r   rd   )rY   args     r*   r[   z%ufunc._at_via_scan.<locals>.<genexpr>  s     5Tchhsm5Ts   !c              3  f   K   | ](  }t        j                  |      j                         * y wr6   )r   rb   r^   )rY   r   rK   s     r*   r[   z%ufunc._at_via_scan.<locals>.<genexpr>  s%     @CS!((/@s   .1r   c                    | \  }t        fdD              }|j                  |   j                   |j                  |   j                         gfdD               }dz   |f|fS )Nc              3  L   K   | ]  }t        |t              r|n|     y wr6   )r;   slice)rY   indr_   s     r*   r[   z7ufunc._at_via_scan.<locals>.scan_fun.<locals>.<genexpr>  s#     OC/#SV;Os   !$c              3  (   K   | ]	  }|     y wr6   r\   )rY   r   r_   s     r*   r[   z7ufunc._at_via_scan.<locals>.scan_fun.<locals>.<genexpr>  s     /G3A/Gs   rO   )rf   r   setget)rt   rv   rU   idxr_   rF   r   r)   s       @r*   rx   z$ufunc._at_via_scan.<locals>.scan_fun  sb    daOwOOc
$$s)--QTT#Y]]_H/G$/GH
Ia!eQZ]r,   rO   )lenr   r%   r   rc   r8   r   rd   rK   rb   r^   rf   r   "eliminate_deprecated_list_indexingr;   r   rz   re   broadcast_shapesr   r   r   r   rj   mathprodrg   r
   r|   )r)   rU   r   rF   r_   shapesre   r   r   rx   rt   ru   rK   s   ` ``        @r*   r   zufunc._at_via_scan  s   t9t}}oS)14t4NN4::sxx{U5Tt5TU[[EAe$A@4@@D99'BGh#*Ga*Q2FbhhqkGFG4s++V4ETT']tADDM$5$5$7?$?@@$q'#HU#HT!W]]3u:;-G#HIcckk$))E*HT!W]]3u:;-GHJd_fgX[je,s-U2K2Q2Q2SSgG
   Aq64WQZIHE18O! H hs   H-H-3H2c                    | j                   dk7  rt        d      | j                  dk7  rt        d      |t        d| j                   d      | j
                  d   xs | j                  } |||||      S )	aX  Reduce an array between specified indices via a binary ufunc.

    JAX implementation of :meth:`numpy.ufunc.reduceat`

    Args:
      a: N-dimensional array to reduce
      indices: a 1-dimensional array of increasing integer values which encodes
        segments of the array to be reduced.
      axis: integer specifying the axis along which to reduce: default=0.
      dtype: optionally specify the dtype of the output array.
      out: unused by JAX
    Returns:
      An array containing the reduced values.

    Examples:
      The ``reduce`` method lets you efficiently compute reduction operations
      over array segments. For example:

      >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
      >>> indices = jnp.array([0, 2, 5])
      >>> jnp.add.reduce(x, indices)
      Array([ 3, 12, 21], dtype=int32)

      This is more-or-less equivalent to the following:

      >>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()])
      Array([ 3, 12, 21], dtype=int32)

      For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`.
      For example, :meth:`jax.add.reduceat` is similar to :func:`jax.ops.segment_sum`,
      although in this case the segments are defined via an array of segment ids:

      >>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2])
      >>> jax.ops.segment_sum(x, segments)
      Array([ 3, 12, 21], dtype=int32)
    rN   z)reduceat only supported for binary ufuncsrO   z>reduceat only supported for functions returning a single valuerC   z.reduceat()r   rp   )r"   rQ   r#   rD   r%   r(   _reduceat_via_scan)r)   rU   r   rJ   rK   r@   r   s          r*   r   zufunc.reduceat  s}    N xx1}BCCyyA~WXX
"24==/ MNN"":.I$2I2IHAwT77r,   c           
        	
 t         j                   d|       t        j                        t	        j
                  |      }t        |      dk(  sJ |d   }j                  dk(  rt        dj                        |j                  dk7  rt        d|j                        |j                  }t        t        t        f      rt        d      t        j                        t	        j                  |      }t        j                   t#        j$                  |j                           t        t'        j(                  t'        j*                  |j                                          }t-        j.                  |d|j                     dz
        
t-        j.                  |d|j                           		
 fd}t1        j2                  dj                     ||      S )	Nz	.reduceatrO   r   z7reduceat: a must have 1 or more dimension, got a.shape=z=reduceat: indices must be one-dimensional, got indices.shape=z(reduceat requires a single integer axis.rJ   c                    t        | kD  | k  z   |t        j                  t        j                  | d                  |      S )N)r   r   )r   r   taker   expand_dims)r_   r@   rU   rJ   ind_end	ind_startr)   s     r*   	loop_bodyz+ufunc._reduceat_via_scan.<locals>.loop_body   sD    Q]q7{3hmmAsq$/GdST r,   )r   r%   r   rb   r   r   r   rX   rQ   re   rK   r;   rf   listr   r   r   jnpappendrz   deletearanger   slice_in_dimr
   ri   )r)   rU   r   rJ   rK   	idx_tupler@   r   r   r   r   s   `` `     @@r*   r   zufunc._reduceat_via_scan  s   t}}oY/G<AA;;GDIy>QlGvv{Q
STT||qWHXYZZ}gge|z$6ABBT166*D
--7
.C
//#**Waggdm<#'		"))CHH2Et(L#MOC$$S!SYYt_q-@tLI""3399T?FG  !!!QWWT]IsCCr,   )r~   c                  | j                   dk7  rt        d      | j                  dk7  rt        d      t        | j                   d||       d } t        j                  t        j                  | d      d       ||       ||            } |j                  g t        j                  |      t        j                  |       S )	a  Apply the function to all pairs of values in ``A`` and ``B``.

    JAX implementation of :meth:`numpy.ufunc.outer`.

    Args:
      A: N-dimensional array
      B: N-dimensional array

    Returns:
      An array of shape `tuple(*A.shape, *B.shape)`

    Examples:
      A times-table for integers 1...10 created via
      :meth:`jax.numpy.multiply.outer`:

      >>> x = jnp.arange(1, 11)
      >>> print(jnp.multiply.outer(x, x))
      [[  1   2   3   4   5   6   7   8   9  10]
       [  2   4   6   8  10  12  14  16  18  20]
       [  3   6   9  12  15  18  21  24  27  30]
       [  4   8  12  16  20  24  28  32  36  40]
       [  5  10  15  20  25  30  35  40  45  50]
       [  6  12  18  24  30  36  42  48  54  60]
       [  7  14  21  28  35  42  49  56  63  70]
       [  8  16  24  32  40  48  56  64  72  80]
       [  9  18  27  36  45  54  63  72  81  90]
       [ 10  20  30  40  50  60  70  80  90 100]]

      For input arrays with ``N`` and ``M`` dimensions respectively, the output
      will have dimension ``N + M``:

      >>> x = jnp.ones((1, 3, 5))
      >>> y = jnp.ones((2, 4))
      >>> jnp.add.outer(x, y).shape
      (1, 3, 5, 2, 4)
    rN   z&outer only supported for binary ufuncsrO   z;outer only supported for functions returning a single valuez.outerc                V    t        j                  | t        j                  |       f      S r6   )r   rj   rz   r{   )As    r*   r0   zufunc.outer.<locals>.<lambda>Q  s    s{{1rwwqzm4 r,   )Nr   r   N)
r"   rQ   r#   r   r%   r   vmaprj   rz   re   )r)   r   B_ravelrn   s        r*   outerzufunc.outer&  s    L xx1}?@@yyA~TUUt}}oV,a34F;SXXchhtY/;F1IvayQF6>>5288A;5!55r,   )r"   intr#   r   r!   Callable[..., Any]r   z
str | Noner   
int | Noner   r   r   Callable[..., Any] | Noner   r   r   r   r   r   r   r   )returnr   )r<   r   r   rS   )r   str)rF   r   r@   NonerA   r   r   r   )r   NNFNN)rU   r   rJ   r   rK   DTypeLike | Noner@   r   rL   rS   rP   ArrayLike | NonerA   r   r   r   )r   NFNN)rZ   r   rJ   r   rK   r   rL   rS   rP   r   rA   r   r   r   )r   NN)
rU   r   rJ   r   rK   r   r@   r   r   r   r   )rZ   r   rJ   r   rK   r   r   r   r6   )
rU   r   r   r   r   r   r}   rS   r   r   )rU   r   r   r   rF   r   r   r   )rU   r   r   r   rJ   r   rK   r   r@   r   r   r   )
rU   r   r   r   rJ   r   rK   r   r   r   )r   r   r   r   r   r   )r%   
__module____qualname__r$   r+   propertyr8   r"   r#   r   r   r9   r=   r?   rG   r   jitrE   r   rT   r   rq   r   r   r   r   r   r\   r,   r*   r   r   0   s   IX #'#'#15377;/359" ! 	
 / 1 5 - 3< ;
<%89#	:	;$
<
=%BC(3Y
, 48t  377F8$( %( 377GH45'+SW'+D^$D^D^)-D^@PD^ %D^ 16D^ ID^L ^bKO15<!%<8H<.<:?<| 37745NR!7171&+71 671r >?59&"2&>C&. 3771#	{;)>)>)>"')> <)>V8 37745=>;?.8&.848.8DI.8 6.8` HI37D 0D<AD6 3771#,6 ,6r,   r   r   c                   t        | |||      S )aD  Create a JAX ufunc from an arbitrary JAX-compatible scalar function.

  Args:
    func : a callable that takes `nin` scalar arguments and returns `nout` outputs.
    nin: integer specifying the number of scalar inputs
    nout: integer specifying the number of scalar outputs
    identity: (optional) a scalar specifying the identity of the operation, if any.

  Returns:
    wrapped : jax.numpy.ufunc wrapper of func.

  Examples:
    Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`:

    >>> import operator
    >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)

    Now all the standard :class:`jax.numpy.ufunc` methods are available:

    >>> x = jnp.arange(4)
    >>> add(x, 10)
    Array([10, 11, 12, 13], dtype=int32)
    >>> add.outer(x, x)
    Array([[0, 1, 2, 3],
           [1, 2, 3, 4],
           [2, 3, 4, 5],
           [3, 4, 5, 6]], dtype=int32)
    >>> add.reduce(x)
    Array(6, dtype=int32)
    >>> add.accumulate(x)
    Array([0, 1, 3, 6], dtype=int32)
    >>> add.at(x, 1, 10, inplace=False)
    Array([ 0, 11,  2,  3], dtype=int32)
  r   )r   )r!   r"   r#   r   s       r*   
frompyfuncr   V  s    J 
tS$	22r,   )
r"   r   r#   r   r!   r   r   r   r   r   )(r$   
__future__r   collections.abcr   r   r&   typingr   numpyrz   jax._srcr   jax._src.typingr   r   r	   jax._src.laxr
   r   r   jax._src.numpyr   r   r   jax._src.numpy.reductionsr   jax._src.numpy.utilr   r   r   jax._src.numpy.vectorizer   jax._src.utilr   r   exportr   r   r   r\   r,   r*   <module>r      s    * " $      7 7 %    # + / F F . 7 
K	   b6 b6 b6J "&$3$3+0$3 $3r,   