
    bi                         d dl Z d dlmZmZmZmZ d dlZd dlZddl	m
Z
mZ ddlmZmZ ddlmZ ddlmZmZmZ  e       rd dlZ	 	 dd	Z G d
 dee
      Zy)    N)ListOptionalTupleUnion   )ConfigMixinregister_to_config)	deprecateis_scipy_available)randn_tensor   )KarrasDiffusionSchedulersSchedulerMixinSchedulerOutputc           
      $   |dk(  rd }n|dk(  rd }nt        d|       g }t        |       D ]<  }|| z  }|dz   | z  }|j                  t        d ||       ||      z  z
  |             > t	        j
                  |t        j                        S )a  
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    cosinec                 f    t        j                  | dz   dz  t         j                  z  dz        dz  S )NgMb?gT㥛 ?r   )mathcospits    v/home/cdr/jupyterlab/.venv/lib/python3.12/site-packages/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.pyalpha_bar_fnz)betas_for_alpha_bar.<locals>.alpha_bar_fn;   s-    88QY%/$''9A=>!CC    expc                 2    t        j                  | dz        S )Ng      ()r   r   r   s    r   r   z)betas_for_alpha_bar.<locals>.alpha_bar_fn@   s    88AI&&r   z"Unsupported alpha_transform_type: r   dtype)
ValueErrorrangeappendmintorchtensorfloat32)num_diffusion_timestepsmax_betaalpha_transform_typer   betasit1t2s           r   betas_for_alpha_barr.   "   s    . x'	D 
	&	' =>R=STUUE*+ M((!e..S\"-R0@@@(KLM <<U]]33r   c            0          e Zd ZdZeD  cg c]  }|j
                   c}} ZdZeddddddd	d
dddddd
d
d
d
d
d e	d       dddfde
de	de	dedeeej                  ee	   f      de
dedede	de	dededed ed!ee   d"ee   d#ee   d$ee   d%ee	   d&e	d'ee   d(ed)e
f.d*       Zed+        ZdMd,e
d-eeej,                  f   fd.Zd/ej0                  d0ej0                  fd1Zd2 Zd3 Zd4ej0                  d0ej0                  fd5Zd4ej0                  d,e
d0ej0                  fd6Z	 dNd4ej0                  d,e
d7e	d8e	d0ej0                  f
d9Zdd:d;ej0                  d/ej0                  d0ej0                  fd<Zddd=d;ej0                  d/ej0                  d>eej0                     d0ej0                  fd?Z ddd=d@eej0                     d/ej0                  d>eej0                     d0ej0                  fdAZ!ddd=d@eej0                     d/ej0                  d>eej0                     d0ej0                  fdBZ"dC Z#	 	 	 dOd;ej0                  dDee
ej0                  f   d/ej0                  dEeej0                     dFed0ee$e%f   fdGZ&d/ej0                  d0ej0                  fdHZ'dIej0                  d>ej0                  dJejP                  d0ej0                  fdKZ)dL Z*yc c}} w )P"DPMSolverMultistepInverseScheduleru  
    `DPMSolverMultistepInverseScheduler` is the reverse scheduler of [`DPMSolverMultistepScheduler`].

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        solver_order (`int`, defaults to 2):
            The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
            sampling, and `solver_order=3` for unconditional sampling.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
            `algorithm_type="dpmsolver++"`.
        algorithm_type (`str`, defaults to `dpmsolver++`):
            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
            paper, and the `dpmsolver++` type implements the algorithms in the
            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
        solver_type (`str`, defaults to `midpoint`):
            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        euler_at_final (`bool`, defaults to `False`):
            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
            richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
            steps, but sometimes may result in blurring.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
        variance_type (`str`, *optional*):
            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
            contains the predicted Gaussian variance.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families.
    r   i  g-C6?g{Gz?linearNr   epsilonFgףp=
?      ?dpmsolver++midpointTinflinspacer   num_train_timesteps
beta_startbeta_endbeta_scheduletrained_betassolver_orderprediction_typethresholdingdynamic_thresholding_ratiosample_max_valuealgorithm_typesolver_typelower_order_finaleuler_at_finaluse_karras_sigmasuse_exponential_sigmasuse_beta_sigmasuse_flow_sigmas
flow_shiftlambda_min_clippedvariance_typetimestep_spacingsteps_offsetc                    | j                   j                  rt               st        d      t	        | j                   j                  | j                   j
                  | j                   j                  g      dkD  rt        d      |dv rd| d}t        dd|       |+t        j                  |t        j                  	      | _        n|d
k(  r-t        j                  |||t        j                  	      | _        nk|dk(  r6t        j                  |dz  |dz  |t        j                  	      dz  | _        n0|dk(  rt        |      | _        nt        | d| j                          d| j                  z
  | _        t        j$                  | j"                  d      | _        t        j(                  | j&                        | _        t        j(                  d| j&                  z
        | _        t        j.                  | j*                        t        j.                  | j,                        z
  | _        d| j&                  z
  | j&                  z  dz  | _        d| _        |dvr2|dk(  r| j7                  d       nt        | d| j                          |dvr1|dv r| j7                  d       nt        | d| j                          d | _        t;        j                  d|dz
  |t:        j                  	      j=                         }t        j>                  |      | _         d g|z  | _!        d| _"        d | _#        | j2                  jI                  d      | _        || _        || _        || _        y )Nz:Make sure to install scipy if you want to use beta sigmas.r   znOnly one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.	dpmsolversde-dpmsolverzalgorithm_type zn is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` insteadz+algorithm_types dpmsolver and sde-dpmsolver1.0.0r   r1   scaled_linear      ?r   squaredcos_cap_v2z is not implemented for r3   r   dim)rQ   r4   rR   sde-dpmsolver++deisr4   )rB   )r5   heun)logrhobh1bh2r5   )rC   cpu)%configrH   r   ImportErrorsumrG   rF   r    r
   r$   r%   r&   r*   r7   r.   NotImplementedError	__class__alphascumprodalphas_cumprodsqrtalpha_tsigma_tloglambda_tsigmasinit_noise_sigmar	   num_inference_stepsnpcopy
from_numpy	timestepsmodel_outputslower_order_nums_step_indexto)selfr8   r9   r:   r;   r<   r=   r>   r?   r@   rA   rB   rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   deprecation_messagers   s                             r   __init__z+DPMSolverMultistepInverseScheduler.__init__   s
   6 ;;&&/A/CZ[[++T[[-O-OQUQ\Q\QnQnopstt A  ;;$3N3C  Dr  #sCWNab$m5==IDJh&
H>QY^YfYfgDJo-
C3H[chcpcpquvvDJ11,-@ADJ%7OPTP^P^O_&`aaDJJ&#mmDKKQ?zz$"5"56zz!d&9&9"9:		$,,/%))DLL2IID///43F3FF3N !$ !aa'''}'E)^,<<TUYUcUcTd*eff2266''J'?)[M9QRVR`R`Qa*bcc $( KK#6#:<OWYWaWabggi	)))4"Vl2 !kknnU+!2&<#.r   c                     | j                   S )zg
        The index counter for current timestep. It will increase 1 after each scheduler step.
        )rv   rx   s    r   
step_indexz-DPMSolverMultistepInverseScheduler.step_index   s    
 r   ro   devicec           	      ^   t        j                  t        j                  | j                  dg      | j                  j
                        j                         }| j                  j                  dz
  |z
  | _        | j                  j                  dk(  rbt        j                  d| j                  |dz         j                         dd j                         j                  t        j                        }n^| j                  j                  dk(  r| j                  dz   |dz   z  }t        j                   d|dz         |z  j                         dd j                         j                  t        j                        }|| j                  j"                  z  }n| j                  j                  dk(  r| j                  j                  |z  }t        j                   | j                  dz   d|       j                         ddd   j                         j                  t        j                        }|dz  }n"t%        | j                  j                   d      t        j&                  d| j(                  z
  | j(                  z  d	z        }t        j*                  |      }| j                  j,                  r| j/                  ||
      }t        j&                  |D cg c]  }| j1                  ||       c}      j                         }|j                         j                  t        j                        }t        j2                  ||dd g      j                  t        j4                        }n| j                  j6                  r| j9                  ||
      }t        j&                  |D cg c]  }| j1                  ||       c}      }t        j2                  ||dd g      j                  t        j4                        }n?| j                  j:                  r| j=                  ||
      }t        j&                  |D cg c]  }| j1                  ||       c}      }t        j2                  ||dd g      j                  t        j4                        }n| j                  j>                  rt        j                  dd| j                  j                  z  |dz         }	d|	z
  }t        j                  | j                  j@                  |z  d| j                  j@                  dz
  |z  z   z        dd j                         }|| j                  j                  z  j                         }t        j2                  ||dd g      j                  t        j4                        }nt        jB                  |t        j                   dtE        |            |      }d| j(                  | j                     z
  | j(                  | j                     z  d	z  }
t        j2                  ||
gg      j                  t        j4                        }t        jF                  |      | _$        t        jJ                  |d      \  }}|t        jL                  |         }t        jF                  |      jO                  |t         j                        | _(        tE        |      | _)        dg| j                  jT                  z  | _+        d| _,        d| _-        | jH                  jO                  d      | _$        yc c}w c c}w c c}w )a  
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        r   r   r7   NleadingtrailingzY is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'.rU   )	in_sigmasro   r3   T)return_indexr~   r   r_   ).r$   searchsortedfliprl   r`   rK   itemr8   noisiest_timesteprM   rp   r7   roundrq   astypeint64arangerN   r    arrayrg   rk   rF   _convert_to_karras_sigma_to_tconcatenater&   rG   _convert_to_exponentialrH   _convert_to_betarI   rJ   interplenrr   rm   uniquesortrw   rs   ro   r=   rt   ru   rv   )rx   ro   r~   clipped_idxrs   
step_ratiorm   
log_sigmassigmare   	sigma_max_unique_indicess                r   set_timestepsz0DPMSolverMultistepInverseScheduler.set_timesteps   s    ((DMMA3)GIgIghmmo!%!@!@1!D{!R ;;'':5At557JQ7NOUUWX[Y[\aacjjkmkskst  [[))Y60014:MPQ:QRJ 1&9A&=>KRRTUXVXY^^`gghjhphpqI111I[[))Z788;NNJ 		$"8"81"<a*MSSUVZXZVZ[``biijljrjrsINI;;//0 1+ + 
 A 3 33t7J7JJsRSVVF^
;;((,,vSf,gFSY!Z%$"2"25*"E!Z[aacI!(//9I^^VVBC[$9:AA"**MF[[//11FXk1lFSY!Z%$"2"25*"E!Z[I^^VVBC[$9:AA"**MF[[((**VQd*eFSY!Z%$"2"25*"E!Z[I^^VVBC[$9:AA"**MF[[(([[A(G(G$GI\_`I`aF6\FWWT[[33f<T[[E[E[^_E_ciDi@ijklomopuuwF$++"A"AAGGII^^VVBC[$9:AA"**MFYYy"))As6{*CVLFT(()?)?@@DDWDWX\XnXnDooI ^^Vi[$9:AA"**MF&&v. IIidC>bggn56	)))477vU[[7Y#&y>  
KK$$% !"  kknnU+U "[
 "[ "[s   \ %\%:\*samplereturnc                 b   |j                   }|j                  ^}}}|t        j                  t        j                  fvr|j                         }|j                  ||t        j                  |      z        }|j                         }t        j                  || j                  j                  d      }t        j                  |d| j                  j                        }|j                  d      }t        j                  || |      |z  } |j                  ||g| }|j!                  |      }|S )a{  
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://huggingface.co/papers/2205.11487
        r   rW   )r#   max)r   shaper$   r&   float64floatreshaperp   prodabsquantiler`   r@   clamprA   	unsqueezerw   )rx   r   r   
batch_sizechannelsremaining_dims
abs_sampless           r   _threshold_samplez4DPMSolverMultistepInverseScheduler._threshold_sampleM  s     06-
H~66\\^F 
Hrww~7N,NOZZ\
NN:t{{'M'MSTUKK1$++66
 KKNVaR+a/
HF~F5!r   c                    t        j                  t        j                  |d            }||d d t         j                  f   z
  }t        j                  |dk\  d      j                  d      j                  |j                  d   dz
        }|dz   }||   }||   }||z
  ||z
  z  }	t        j                  |	dd      }	d|	z
  |z  |	|z  z   }
|
j                  |j                        }
|
S )Ng|=r   )axisr   )r   r   )	rp   rk   maximumnewaxiscumsumargmaxclipr   r   )rx   r   r   	log_sigmadistslow_idxhigh_idxlowhighwr   s              r   r   z.DPMSolverMultistepInverseScheduler._sigma_to_to  s    FF2::eU34	 Jq"**}55 ))UaZq188a8@EE*JZJZ[\J]`aJaEbQ;!(# 9_t,GGAq! UgH,IIekk"r   c                 r    | j                   j                  rd|z
  }|}||fS d|dz  dz   dz  z  }||z  }||fS )Nr   r   rU   )r`   rI   )rx   r   ri   rj   s       r   _sigma_to_alpha_sigma_tz:DPMSolverMultistepInverseScheduler._sigma_to_alpha_sigma_t  sW    ;;&&%iGG
  E1HqLS01GgoGr   r   c                    t        | j                  d      r| j                  j                  }nd}t        | j                  d      r| j                  j                  }nd}||n|d   j	                         }||n|d   j	                         }d}t        j                  dd|      }|d|z  z  }|d|z  z  }||||z
  z  z   |z  }	|	S )z6Constructs the noise schedule of Karras et al. (2022).	sigma_minNr   r   r   g      @r   )hasattrr`   r   r   r   rp   r7   )
rx   r   ro   r   r   rhorampmin_inv_rhomax_inv_rhorm   s
             r   r   z5DPMSolverMultistepInverseScheduler._convert_to_karras  s    
 4;;,--II4;;,--II!*!6IIbM<N<N<P	!*!6IIaL<M<M<O	{{1a!45AG,AG,k(A BBsJr   c                    t        | j                  d      r| j                  j                  }nd}t        | j                  d      r| j                  j                  }nd}||n|d   j	                         }||n|d   j	                         }t        j                  t        j                  t        j                  |      t        j                  |      |            }|S )z)Constructs an exponential noise schedule.r   Nr   r   r   )
r   r`   r   r   r   rp   r   r7   r   rk   )rx   r   ro   r   r   rm   s         r   r   z:DPMSolverMultistepInverseScheduler._convert_to_exponential  s    
 4;;,--II4;;,--II!*!6IIbM<N<N<P	!*!6IIaL<M<M<O	DHHY$7)9LNabcr   alphabetac           
      (   t        | j                  d      r| j                  j                  }nd}t        | j                  d      r| j                  j                  }nd}||n|d   j	                         }||n|d   j	                         }t        j                  dt        j                  dd|      z
  D cg c]-  }t        j                  j                  j                  |||      / c}D cg c]  }||||z
  z  z    c}      }	|	S c c}w c c}w )zJFrom "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)r   Nr   r   r   r   )r   r`   r   r   r   rp   r   r7   scipystatsr   ppf)
rx   r   ro   r   r   r   r   timestepr   rm   s
             r   r   z3DPMSolverMultistepInverseScheduler._convert_to_beta  s     4;;,--II4;;,--II!*!6IIbM<N<N<P	!*!6IIaL<M<M<O	
 %&Aq:M(N$N  KK$$((5$? SI	$9:;
 s   82D
/Dr   model_outputc                   t        |      dkD  r|d   n|j                  dd      }|t        |      dkD  r|d   }nt        d      |t        ddd       | j                  j
                  d	v rj| j                  j                  d
k(  r\| j                  j                  dv r|ddddf   }| j                  | j                     }| j                  |      \  }}|||z  z
  |z  }	n| j                  j                  dk(  r|}	n| j                  j                  dk(  r9| j                  | j                     }| j                  |      \  }}||z  ||z  z
  }	n^| j                  j                  dk(  r"| j                  | j                     }|||z  z
  }	n#t        d| j                  j                   d      | j                  j                  r| j                  |	      }	|	S | j                  j
                  dv rs| j                  j                  d
k(  r'| j                  j                  dv r|ddddf   }
n|}
n| j                  j                  dk(  r9| j                  | j                     }| j                  |      \  }}|||z  z
  |z  }
nu| j                  j                  dk(  r9| j                  | j                     }| j                  |      \  }}||z  ||z  z   }
n#t        d| j                  j                   d      | j                  j                  rT| j                  | j                     }| j                  |      \  }}|||
z  z
  |z  }	| j                  |	      }	|||	z  z
  |z  }
|
S y)a0  
        Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
        designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
        integral of the data prediction model.

        <Tip>

        The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
        prediction and data prediction models.

        </Tip>

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The converted model output.
        r   r   Nr   /missing `sample` as a required keyword argumentrs   rS   Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`)r4   rY   r2   )learnedlearned_range   r   v_predictionflow_predictionzprediction_type given as zn must be one of `epsilon`, `sample`, `v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler.rP   z[ must be one of `epsilon`, `sample`, or `v_prediction` for the DPMSolverMultistepScheduler.)r   popr    r
   r`   rB   r>   rL   rm   r}   r   r?   r   )rx   r   r   argskwargsr   r   ri   rj   x0_predr2   s              r   convert_model_outputz7DPMSolverMultistepInverseScheduler.convert_model_output  sb   : "$i!m47J1M>4y1}a !RSS Z ;;%%)KK{{**i7;;,,0LL#/2A2#6LDOO4#'#?#?#F !Gl$::gE,,8&,,>DOO4#'#?#?#F !F*W|-CC,,0AA++doo6 7\#99 /0K0K/L M` ` 
 {{''009N [[''+II{{**i7;;,,0LL*1bqb51G*G,,8DOO4#'#?#?#F !Gl$::gE,,>DOO4#'#?#?#F !L07V3CC /0K0K/L MK K 
 {{''DOO4#'#?#?#F !Gg$55@009!Gg$55@N9 Jr   r   noiser   c          	         t        |      dkD  r|d   n|j                  dd      }t        |      dkD  r|d   n|j                  dd      }|t        |      dkD  r|d   }nt        d      |t        dd	d
       |t        dd	d       | j                  | j
                  dz      | j                  | j
                     }	}| j                  |      \  }
}| j                  |	      \  }}	t        j                  |
      t        j                  |      z
  }t        j                  |      t        j                  |	      z
  }||z
  }| j                  j                  dk(  r*||	z  |z  |
t        j                  |       dz
  z  |z  z
  }|S | j                  j                  dk(  r)|
|z  |z  |t        j                  |      dz
  z  |z  z
  }|S | j                  j                  dk(  r||J ||	z  t        j                  |       z  |z  |
dt        j                  d|z        z
  z  |z  z   |t        j                  dt        j                  d|z        z
        z  |z  z   }|S | j                  j                  dk(  rc|J |
|z  |z  d|t        j                  |      dz
  z  z  |z  z
  |t        j                  t        j                  d|z        dz
        z  |z  z   }S )a  
        One step for the first-order DPMSolver (equivalent to DDIM).

        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   r   Nr   prev_timestepr   r   rs   rS   r   Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r4   r3   rQ   rY          rR          @r   r   r    r
   rm   r}   r   r$   rk   r`   rB   r   rh   )rx   r   r   r   r   r   r   r   rj   sigma_sri   alpha_srl   lambda_shx_ts                   r   dpm_solver_first_order_updatez@DPMSolverMultistepInverseScheduler.dpm_solver_first_order_updateI  s   * "$i!m47J1M#&t9q=QfjjRV6W>4y1}a !RSS Z $ ^  ;;t':;T[[=Y77@77@99W%		'(::99W%		'(::x;;%%6W$.'UYYr]S=P2QUa1aaC" 
! [['';6W$.'UYYq\C=O2PT`1``C 
 [[''+<<$$$7"UYYr]2f<a%))D1H"556,FGEJJsUYYrAv->'>??%GH  
 [[''?:$$$7"f,EIIaL3$678<GHEJJuyyQ'7#'=>>FG 
 
r   model_output_listc          	         t        |      dkD  r|d   n|j                  dd      }t        |      dkD  r|d   n|j                  dd      }|t        |      dkD  r|d   }nt        d      |t        ddd	       |t        ddd
       | j                  | j
                  dz      | j                  | j
                     | j                  | j
                  dz
     }
}	}| j                  |      \  }}| j                  |	      \  }}	| j                  |
      \  }}
t        j                  |      t        j                  |      z
  }t        j                  |      t        j                  |	      z
  }t        j                  |      t        j                  |
      z
  }|d   |d   }}||z
  ||z
  }}||z  }|d|z  ||z
  z  }}| j                  j                  dk(  r| j                  j                  dk(  rM||	z  |z  |t        j                  |       dz
  z  |z  z
  d|t        j                  |       dz
  z  z  |z  z
  }|S | j                  j                  dk(  rN||	z  |z  |t        j                  |       dz
  z  |z  z
  |t        j                  |       dz
  |z  dz   z  |z  z   }S | j                  j                  dk(  r| j                  j                  dk(  rK||z  |z  |t        j                  |      dz
  z  |z  z
  d|t        j                  |      dz
  z  z  |z  z
  }|S | j                  j                  dk(  rL||z  |z  |t        j                  |      dz
  z  |z  z
  |t        j                  |      dz
  |z  dz
  z  |z  z
  }S | j                  j                  dk(  rv|J | j                  j                  dk(  r||	z  t        j                  |       z  |z  |dt        j                  d|z        z
  z  |z  z   d|dt        j                  d|z        z
  z  z  |z  z   |t        j                  dt        j                  d|z        z
        z  |z  z   }|S | j                  j                  dk(  r||	z  t        j                  |       z  |z  |dt        j                  d|z        z
  z  |z  z   |dt        j                  d|z        z
  d|z  z  dz   z  |z  z   |t        j                  dt        j                  d|z        z
        z  |z  z   }S | j                  j                  dk(  r=|J | j                  j                  dk(  r||z  |z  d|t        j                  |      dz
  z  z  |z  z
  |t        j                  |      dz
  z  |z  z
  |t        j                  t        j                  d|z        dz
        z  |z  z   }|S | j                  j                  dk(  r||z  |z  d|t        j                  |      dz
  z  z  |z  z
  d|t        j                  |      dz
  |z  dz
  z  z  |z  z
  |t        j                  t        j                  d|z        dz
        z  |z  z   }S )a  
        One step for the second-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   timestep_listNr   r   r   r   rS   Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`r   r   r   r3   r4   r5   rU   r[   rQ   rY   r   rR   r   )r   r   r    r
   rm   r}   r   r$   rk   r`   rB   rC   r   rh   )rx   r   r   r   r   r   r   r   rj   sigma_s0sigma_s1ri   alpha_s0alpha_s1rl   	lambda_s0	lambda_s1m0m1r   h_0r0D0D1r   s                            r   (multistep_dpm_solver_second_order_updatezKDPMSolverMultistepInverseScheduler.multistep_dpm_solver_second_order_update  s   * $'t9q=QfjjRV6W#&t9q=QfjjRV6W>4y1}a !RSS$ ^ $ ^ KK!+,KK(KK!+, $  77@!99(C(!99(C(99W%		'(::IIh'%))H*==	IIh'%))H*==	"2&(9"(=BI%y9'<31WcBh27+B;;%%6{{&&*4x'61%))QB-#"56"<=W		1"(;<=BC r 
i ((F2x'61%))QB-#"56"<=599aR=3#6!";c"ABbHI f 
] [['';6{{&&*4x'61%))A,"45;<W		!s(:;<rAB V 
M ((F2x'61%))A,"45;<599Q<##5":S"@ARGH J 
A [[''+<<$$${{&&*4x'%))QB-76A!eiiq&9"9:b@AWEIIdQh,?(?@ABFG 

3261B+B CCeKL : 
/ ((F2x'%))QB-76A!eiiq&9"9:b@A34!8)<#<"JS"PQUWWX 

3261B+B CCeKL , 
! [[''?:$$${{&&*4x'61W		!s(:;<rAB%))A,"45;< 

599QU+;c+A BBUJK  
 ((F2x'61W		!s(:;<rABW1);q(@3(FGH2MN 

599QU+;c+A BBUJK  
r   c          	         t        |      dkD  r|d   n|j                  dd      }t        |      dkD  r|d   n|j                  dd      }|t        |      dkD  r|d   }nt        d      |t        ddd	       |t        ddd
       | j                  | j
                  dz      | j                  | j
                     | j                  | j
                  dz
     | j                  | j
                  dz
     f\  }}	}
}| j                  |      \  }}| j                  |	      \  }}	| j                  |
      \  }}
| j                  |      \  }}t        j                  |      t        j                  |      z
  }t        j                  |      t        j                  |	      z
  }t        j                  |      t        j                  |
      z
  }t        j                  |      t        j                  |      z
  }|d   |d   |d   }}}||z
  ||z
  ||z
  }}}||z  ||z  }}|}d|z  ||z
  z  d|z  ||z
  z  }}||||z   z  ||z
  z  z   }d||z   z  ||z
  z  } | j                  j                  dk(  r|||	z  |z  |t        j                  |       dz
  z  |z  z
  |t        j                  |       dz
  |z  dz   z  |z  z   |t        j                  |       dz
  |z   |dz  z  dz
  z  | z  z
  }!|!S | j                  j                  dk(  ry||z  |z  |t        j                  |      dz
  z  |z  z
  |t        j                  |      dz
  |z  dz
  z  |z  z
  |t        j                  |      dz
  |z
  |dz  z  dz
  z  | z  z
  }!|!S | j                  j                  dk(  r|J ||	z  t        j                  |       z  |z  |dt        j                  d|z        z
  z  |z  z   |dt        j                  d|z        z
  d|z  z  dz   z  |z  z   |dt        j                  d|z        z
  d|z  z
  d|z  dz  z  dz
  z  | z  z   |t        j                  dt        j                  d|z        z
        z  |z  z   }!!S )a  
        One step for the third-order multistep DPMSolver.

        Args:
            model_output_list (`List[torch.Tensor]`):
                The direct outputs from learned diffusion model at current and latter timesteps.
            sample (`torch.Tensor`):
                A current instance of a sample created by diffusion process.

        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        r   r   Nr   r   r   r   rS   r   r   r   r   r3   r4   rU   rQ   rY   r   r   r   )"rx   r   r   r   r   r   r   r   rj   r   r   sigma_s2ri   r   r   alpha_s2rl   r   r   	lambda_s2r   r   m2r   r   h_1r   r1r   D1_0D1_1r   D2r   s"                                     r   'multistep_dpm_solver_third_order_updatezJDPMSolverMultistepInverseScheduler.multistep_dpm_solver_third_order_update  s   , $'t9q=QfjjRV6W#&t9q=QfjjRV6W>4y1}a !RSS$ ^ $ ^ KK!+,KK(KK!+,KK!+,	1
-8X  77@!99(C(!99(C(!99(C(99W%		'(::IIh'%))H*==	IIh'%))H*==	IIh'%))H*==	&r*,=b,ACTUWCXB*I	,A9yCX3q#'BBh27+cBh27-CdR27^t44R"Wo$+.;;%%6 8#v-eiimc12b89uyy!}s2a7#=>"DE uyy!}s2Q6!Q$>DEKL . 
# [['';6 8#v-eiilS01R78uyy|c1Q6<=CD uyy|c1A5A=CDJK  
 [[''+<<$$$8#eiim3v=cEIIdQh$778B>?sUYYtax%88TAXFLMQSST sUYYtax%8837BsQwSTnTWZZ[_aab EJJsUYYrAv->'>??%G	H  
r   c                    t        |t        j                        r%|j                  | j                  j
                        }| j                  |k(  j                         }t        |      dk(  r t        | j                        dz
  }|| _	        y t        |      dkD  r|d   j                         }|| _	        y |d   j                         }|| _	        y )Nr   r   )

isinstancer$   Tensorrw   rs   r~   nonzeror   r   rv   )rx   r   index_candidatesr}   s       r   _init_step_indexz3DPMSolverMultistepInverseScheduler._init_step_indexj  s    h-{{4>>#8#89H NNh6??A A%T^^,q0J & !"Q&)!,113J & *!,113J%r   r   variance_noisereturn_dictc                    | j                   t        d      | j                  | j                  |       | j                  t	        | j
                        dz
  k(  xrH | j                  j                  xs0 | j                  j                  xr t	        | j
                        dk  }| j                  t	        | j
                        dz
  k(  xr0 | j                  j                  xr t	        | j
                        dk  }| j                  ||      }t        | j                  j                  dz
        D ]!  }	| j                  |	dz      | j                  |	<   # || j                  d<   | j                  j                  dv r0|.t        |j                  ||j                   |j"                  	      }
n| j                  j                  dv r|}
nd}
| j                  j                  dk(  s| j$                  dk  s|r| j'                  |||

      }nf| j                  j                  dk(  s| j$                  dk  s|r| j)                  | j                  ||

      }n| j+                  | j                  |      }| j$                  | j                  j                  k  r| xj$                  dz  c_        | xj,                  dz  c_        |s|fS t/        |      S )a  
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the multistep DPMSolver.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        NzaNumber of inference steps is 'None', you need to run 'set_timesteps' after creating the schedulerr      r   r   r   )rR   rY   )	generatorr~   r   r   )prev_sample)ro   r    r}   r  r   rs   r`   rE   rD   r   r!   r=   rt   rB   r   r   r~   r   ru   r   r   r  rv   r   )rx   r   r   r   r  r  r  rD   lower_order_secondr+   r   r  s               r   stepz'DPMSolverMultistepInverseScheduler.step}  s   D ##+s  ??"!!(+ "__DNN0Ca0GG 
KK&&f4;;+H+H+eSQUQ_Q_M`ceMe 	 __DNN 3a 77wT[[=Z=Zw_bcgcqcq_ruw_w 	 00f0Mt{{//!34 	>A$($6$6q1u$=Dq!	>!-2;;%%)MMR`Rh ""i@S@S[g[m[mE [[''+OO"EE;;##q(D,A,AA,EIZ<<\RX`e<fK[[%%*d.C.Ca.GK]GGHZHZciqvGwKFFtGYGYbhFiK  4;;#;#;;!!Q&! 	A>!;77r   c                     |S )a?  
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
         )rx   r   r   r   s       r   scale_model_inputz4DPMSolverMultistepInverseScheduler.scale_model_input  s	     r   original_samplesrs   c                 8   | j                   j                  |j                  |j                        }|j                  j                  dk(  rvt        j                  |      ra| j                  j                  |j                  t
        j                        }|j                  |j                  t
        j                        }n@| j                  j                  |j                        }|j                  |j                        }g }|D ]x  }||k(  j                         }t        |      dk(  rt        |      dz
  }	n5t        |      dkD  r|d   j                         }	n|d   j                         }	|j                  |	       z ||   j                         }
t        |
j                        t        |j                        k  r=|
j                  d      }
t        |
j                        t        |j                        k  r=| j!                  |
      \  }}||z  ||z  z   }|S )Nr   mpsr   r   r   r   )rm   rw   r~   r   typer$   is_floating_pointrs   r&   r  r   r   r"   flattenr   r   r   )rx   r  r   rs   rm   schedule_timestepsstep_indicesr   r  r}   r   ri   rj   noisy_sampless                 r   	add_noisez,DPMSolverMultistepInverseScheduler.add_noise  s    '7'>'>FVF\F\]""''50U5L5LY5W!%!2!23C3J3JRWR_R_!2!`!%5%<%<EMMRI!%!2!23C3J3J!K!%5%<%<=I! 	,H 2h >GGI#$) !34q8
%&*-a0557
-a0557

+	, |$,,.%++%5%;%;!<<OOB'E %++%5%;%;!<<  77>"22Wu_Dr   c                 .    | j                   j                  S )N)r`   r8   r|   s    r   __len__z*DPMSolverMultistepInverseScheduler.__len__  s    {{...r   )NN)333333?r*  )NNT)+__name__
__module____qualname____doc__r   name_compatiblesorderr	   r   intstrr   r   rp   ndarrayr   boolrz   propertyr}   r$   r~   r   r  r   r   r   r   r   r   r   r   r   r  r  r   r   r  r  	IntTensorr'  r)  ).0es   00r   r0   r0   N   s   DL %>>qAFF>LE $("%BF(",1"%+%"&$,116*/*/&)%*5\M'+ *1S/ S/ S/ 	S/
 S/  bjj$u+&= >?S/ S/ S/ S/ %*S/  S/ S/ S/  S/ S/  $D>!S/" !)#S/$ "$%S/& "$'S/( UO)S/* "+S/,  }-S/. /S/0 1S/ S/j    U, U,U3PUP\P\K\E] U,p  D0 ELL RWR^R^ 4 TW \a\h\h . dg<?HM[`	H  $	dlld 	d 
dV  $(,CllC 	C
 %C 
CT  $(,y-y 	y
 %y 
y@  $(,]-] 	]
 %] 
]~&0 15 P8llP8 U\\)*P8 	P8 !.P8 P8 
%	&P8f %,, !,,! ||! ??	!
 
!F/[ ?s   Kr0   )g+?r   )r   typingr   r   r   r   numpyrp   r$   configuration_utilsr   r	   utilsr
   r   utils.torch_utilsr   scheduling_utilsr   r   r   scipy.statsr   r.   r0   r  r   r   <module>rA     sN   "  / /   A 1 , X X  !)4Xu/ u/r   