
from typing import Callable, List, Optional, Union

import PIL.Image
import torch
from transformers import XLMRobertaTokenizer

from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline
        >>> from diffusers.utils import load_image
        >>> import torch

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        ... )
        >>> pipe_prior.to("cuda")

        >>> prompt = "A red cartoon frog, 4k"
        >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)

        >>> pipe = KandinskyImg2ImgPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> init_image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/frog.png"
        ... )

        >>> image = pipe(
        ...     prompt,
        ...     image=init_image,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=100,
        ...     strength=0.2,
        ... ).images

        >>> image[0].save("red_frog.png")
        ```
"""


def get_new_h_w(h, w, scale_factor=8):
    # Map a pixel-space size to the latent-space size used by the UNet, rounding up so both
    # dims cover whole latent blocks (e.g. 768 px -> 96 latent cells for scale_factor=8).
    new_h = h // scale_factor**2
    if h % scale_factor**2 != 0:
        new_h += 1
    new_w = w // scale_factor**2
    if w % scale_factor**2 != 0:
        new_w += 1
    return new_h * scale_factor, new_w * scale_factor


class KandinskyImg2ImgPipeline(DiffusionPipeline):
    """
    Pipeline for image-to-image generation using Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        text_encoder ([`MultilingualCLIP`]):
            Frozen text-encoder.
        tokenizer ([`XLMRobertaTokenizer`]):
            Tokenizer of class [`XLMRobertaTokenizer`].
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `unet` to generate image latents.
        unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the image embedding.
        movq ([`VQModel`]):
            MoVQ image encoder and decoder
    ztext_encoder->unet->movqtext_encodermovq	tokenizerunet	schedulerc                 n   t         |           | j                  |||||       t        | dd       r/dt	        | j
                  j                  j                        dz
  z  nd| _        t        | dd       r | j
                  j                  j                  nd}t        | j                  |dd      | _        y )	N)r"   r$   r%   r&   r#   r#   r   r         bicubic)vae_scale_factorvae_latent_channelsresamplereducing_gap)super__init__register_modulesgetattrlenr#   configblock_out_channelsmovq_scale_factorlatent_channelsr	   image_processor)selfr"   r#   r$   r%   r&   movq_latent_channels	__class__s          r   r0   z!KandinskyImg2ImgPipeline.__init__w   s     	% 	 	
 DK4QWY]C^A#dii&&99:Q>?de 	 DK4QWY]C^tyy//??de0!33 4	 
r   c                     t        t        ||z        |      }t        ||z
  d      }| j                  j                  |d  }|||z
  fS )Nr   )minintmaxr&   	timesteps)r9   num_inference_stepsstrengthdeviceinit_timestept_startr@   s          r   get_timestepsz&KandinskyImg2ImgPipeline.get_timesteps   sS    C 3h >?ATU)M91=NN,,WX6	-777r   c                    |t        ||||      }n;|j                  |k7  rt        d|j                   d|       |j                  |      }||j                  z  }|j                  }t        ||||      }| j                  |||      }|S )N)	generatorrC   dtypezUnexpected latents shape, got z, expected )r   shape
ValueErrortoinit_noise_sigma	add_noise)	r9   latentslatent_timesteprJ   rI   rC   rH   r&   noises	            r   prepare_latentsz(KandinskyImg2ImgPipeline.prepare_latents   s    ?"5IfTYZG}}% #A'--P[\a[b!cddjj(GI666UieT..%Ar   Nc                 h   t        |t              rt        |      nd}| j                  |dddddd      }|j                  }| j                  |dd      j                  }	|	j
                  d	   |j
                  d	   k\  rt        j                  ||	      sj| j                  j                  |	d d | j                  j                  dz
  d	f         }
t        j                  d
| j                  j                   d|
        |j                  |      }|j                  j                  |      }| j                  ||      \  }}|j                  |d      }|j                  |d      }|j                  |d      }|r|dg|z  }nt!        |      t!        |      ur$t#        dt!        |       dt!        |       d      t        |t$              r|g}n1|t        |      k7  r!t'        d| dt        |       d| d| d	      |}| j                  |dddddd      }|j                  j                  |      }|j                  j                  |      }| j                  ||      \  }}|j
                  d   }|j)                  d|      }|j+                  ||z  |      }|j
                  d   }|j)                  d|d      }|j+                  ||z  |d	      }|j                  |d      }t        j,                  ||g      }t        j,                  ||g      }t        j,                  ||g      }|||fS )Nr   
max_lengthM   Tpt)paddingrT   
truncationreturn_attention_maskadd_special_tokensreturn_tensorslongest)rW   r[   z\The following part of your input was truncated because CLIP can only handle sequences up to z	 tokens: )	input_idsattention_maskr   dim z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)
isinstancelistr3   r$   r^   rJ   torchequalbatch_decodemodel_max_lengthloggerwarningrL   r_   r"   repeat_interleavetype	TypeErrorstrrK   repeatviewcat)r9   promptrC   num_images_per_promptdo_classifier_free_guidancenegative_prompt
batch_sizetext_inputstext_input_idsuntruncated_idsremoved_text	text_maskprompt_embedstext_encoder_hidden_statesuncond_tokensuncond_inputuncond_text_input_idsuncond_text_masknegative_prompt_embeds!uncond_text_encoder_hidden_statesseq_lens                        r   _encode_promptz'KandinskyImg2ImgPipeline._encode_prompt   s    %/vt$<S[!
nn "&# % 
 %....SW.Xbb  $(<(<R(@@UcetIu>>66q$..JiJilmJmprJrGr7stLNNNN334Il^M
 (**62..11&9	484E4E$Y 5F 5
11 &778MST7U%?%Q%QRgmn%Q%o"//0E1/M	&&!#z 1fT/%::UVZ[jVkUl mV~Q(  OS1!0 1s?33 )/)::J3K_J` ax/
| <33  !0>>$&*#'# * L %1$:$:$=$=f$E!+::==fEHLHYHY/@P IZ IE"$E -2215G%;%B%B1F[%\"%;%@%@NcAcel%m"7==a@G0Q0X0XYZ\qst0u-0Q0V0V22GR1-  0AABW]^A_ "II'=}&MNM).4UWq3r)s&		#3Y"?@I8)CCr   original_samplesrQ   r@   returnc                 ,   t        j                  dddt         j                        }d|z
  }t        j                  |d      }|j	                  |j
                  |j                        }|j	                  |j
                        }||   d	z  }|j                         }t        |j                        t        |j                        k  r=|j                  d
      }t        |j                        t        |j                        k  r=d||   z
  d	z  }|j                         }t        |j                        t        |j                        k  r=|j                  d
      }t        |j                        t        |j                        k  r=||z  ||z  z   }	|	S )Ng-C6?g{Gz?i  )rI         ?r   r`   )rC   rI   g      ?r]   r   )rf   linspacefloat32cumprodrL   rC   rI   flattenr3   rJ   	unsqueeze)
r9   r   rQ   r@   betasalphasalphas_cumprodsqrt_alpha_prodsqrt_one_minus_alpha_prodnoisy_sampless
             r   rN   z"KandinskyImg2ImgPipeline.add_noise  sr    vtTGuv15'**2B2I2IQaQgQg*hLL!1!8!89	(3s:)113/''(3/?/E/E+FF-77;O /''(3/?/E/E+FF &'	)B%Bs$J!$=$E$E$G!+112S9I9O9O5PP(A(K(KB(O% +112S9I9O9O5PP (*::=VY^=^^r   rs   imageimage_embedsnegative_image_embedsrv   heightwidthrA   rB   guidance_scalert   rH   output_typecallbackcallback_stepsreturn_dictc           
      
   t        |t              rd}n3t        |t              rt        |      }nt	        dt        |             | j                  }||z  }|
dkD  }| j                  |||||      \  }}}t        |t              rt        j                  |d      }t        |t              rt        j                  |d      }|rZ|j                  |d      }|j                  |d      }t        j                  ||gd      j                  |j                  |      }t        |t              s|g}t        d |D              s&t	        d|D cg c]  }t        |       c} d	      t        j                  |D cg c]  }| j                  j                  |||      ! c}d      }|j                  |j                  |      }| j                   j#                  |      d
   }|j                  |d      }| j$                  j'                  ||       | j)                  ||	|      \  }}t+        | j$                  j,                  j.                  |	z        dz
  }t        j0                  |g|z  |j                  |      }| j2                  j,                  j4                  }t7        ||| j8                        \  }}| j;                  ||||||f|j                  ||| j$                        }t=        | j?                  |            D ]  \  }}|rt        j                  |gdz        n|}||d}| j3                  ||||d      d   }|ro|jA                  |jB                  d   d      \  }} |jE                  d      \  }!}"| jE                  d      \  }}#|!|
|"|!z
  z  z   }t        j                  ||#gd      }tG        | j$                  j,                  d      r"| j$                  j,                  jH                  dv s#|jA                  |jB                  d   d      \  }}| j$                  jK                  ||||      jL                  }|,||z  dk(  r$|tO        | j$                  dd      z  }$ ||$||       tP        swtS        jT                           | j                   jW                  |d      d   }| jY                          |dvrt	        d|       | j                  j[                  ||      }|s|fS t]        |      S c c}w c c}w )a  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.Tensor`, `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, that will be used as the starting point for the
                process.
            image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The CLIP image embeddings for the text prompt, used to condition the image generation.
            negative_image_embeds (`torch.Tensor` or `List[torch.Tensor]`):
                The CLIP image embeddings for the negative text prompt, used to condition the image generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 100):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            strength (`float`, *optional*, defaults to 0.3):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                will be used as a starting point, adding more noise to it the larger the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
                be maximum and the denoising process will run for the full number of iterations specified in
                `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
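                For example, with `num_inference_steps=100` and `strength=0.3`, the latents are noised to roughly
                30% of the schedule and only the final 30 denoising steps are run.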
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
                (`np.array`) or `"pt"` (`torch.Tensor`).
            callback (`Callable`, *optional*):
                A function that is called every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        device = self._execution_device

        batch_size = batch_size * num_images_per_prompt

        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        if isinstance(image_embeds, list):
            image_embeds = torch.cat(image_embeds, dim=0)
        if isinstance(negative_image_embeds, list):
            negative_image_embeds = torch.cat(negative_image_embeds, dim=0)

        if do_classifier_free_guidance:
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to(
                dtype=prompt_embeds.dtype, device=device
            )

        if not isinstance(image, list):
            image = [image]
        if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
            )

        image = torch.cat([self.image_processor.preprocess(i, height=height, width=width) for i in image], dim=0)
        image = image.to(dtype=prompt_embeds.dtype, device=device)

        latents = self.movq.encode(image)["latents"]
        latents = latents.repeat_interleave(num_images_per_prompt, dim=0)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

        # the formula to calculate the timestep for add_noise is taken from the original Kandinsky repo
        latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2

        latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device)

        num_channels_latents = self.unet.config.in_channels

        height, width = get_new_h_w(height, width, self.movq_scale_factor)

        # create initial latent
        latents = self.prepare_latents(
            latents,
            latent_timestep,
            (batch_size, num_channels_latents, height, width),
            image_embeds.dtype,
            device,
            generator,
            self.scheduler,
        )

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds}
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=t,
                encoder_hidden_states=text_encoder_hidden_states,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )[0]

            if do_classifier_free_guidance:
                noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                _, variance_pred_text = variance_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)

            if not (
                hasattr(self.scheduler.config, "variance_type")
                and self.scheduler.config.variance_type in ["learned", "learned_range"]
            ):
                noise_pred, _ = noise_pred.split(latents.shape[1], dim=1)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred,
                t,
                latents,
                generator=generator,
            ).prev_sample

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

        # post-processing
        image = self.movq.decode(latents, force_not_quantize=True)["sample"]

        self.maybe_free_model_hooks()

        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")

        image = self.image_processor.postprocess(image, output_type)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)