
    biI                         d dl Z d dlZd dlmZ d dlmZ d dlmZmZm	Z	m
Z
 d dlmZ d dlZd dlmZ d dlmZ d dlmZmZ d d	lmZ d d
lmZ ddlmZ ddlmZ ddlmZmZ  e       rd dl Z  ee!      Z" G d de      Z#y)    N)defaultdict)Path)AnyCallableOptionalUnion)warn)Accelerator)
get_logger)ProjectConfigurationset_seed)PyTorchModelHubMixin)is_wandb_available   )DDPOStableDiffusionPipeline   )AlignPropConfig)generate_model_cardget_comet_experiment_urlc                   ~    e Zd ZdZddgZ	 d dedeej                  e	e
   e	e   gej                  f   deg e	e
ef   f   ded	eeeeegef      f
d
Zd ZdedefdZd Zdej                  dedej                  fdZd Zd Zd Zd!dZd dee   fdZd Z fdZ	 	 	 d"dee
   dee
   dee
ee
   df   fdZ xZS )#AlignPropTrainera  
    The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is
    heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/ As of now only Stable Diffusion based
    pipelines are supported

    Attributes:
        config (`AlignPropConfig`):
            Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
            Reward function to be used
        prompt_function (`Callable[[], tuple[str, Any]]`):
            Function to generate prompts to guide model
        sd_pipeline (`DDPOStableDiffusionPipeline`):
            Stable Diffusion pipeline to be used for training.
        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
            Hook to be called to log images
    trl	alignpropNconfigreward_functionprompt_functionsd_pipelineimage_samples_hookc           
         |t        d       || _        || _        || _        || _        t        di | j                  j                  }| j                  j                  rt        j                  j                  t        j                  j                  | j                  j                              | j                  _        dt        j                  j                  | j                  j                        vrt        t        d t        j                  | j                  j                                    }t!        |      dk(  r"t#        d| j                  j                         t%        |D cg c]  }t'        |j)                  d      d         ! c}      }	t        j                  j+                  | j                  j                  d|	d          | j                  _        |	d   dz   |_        t/        d| j                  j0                  | j                  j2                  || j                  j4                  d	| j                  j6                  | _        |j0                  d uxr |j0                  d
k(  }
| j8                  j:                  rp| j8                  j=                  | j                  j>                  |
stA        |jC                               n|jC                         | j                  jD                         tF        jI                  d|        tK        | j                  jL                  d       || _'        | jN                  jQ                  d| j8                  jR                   ddd       | j8                  j2                  dk(  rtT        jV                  }n:| j8                  j2                  dk(  rtT        jX                  }ntT        jZ                  }| jN                  j\                  j_                  | j8                  j`                  |       | jN                  jb                  j_                  | j8                  j`                  |       | jN                  jd                  j_                  | j8                  j`                  |       | jN                  jg                         }| j8                  ji                  | jj                         | j8                  jm                  | jn                         | j                  jp                  r)dtT        jr                  jt                  jv                  _8        | jy                  t{        |t              s|j}                         n|      | _?        | jN                  jc                  | jN                  j                  | j                  j                  dgn| j                  j                  ddd| jN                  j                  j                        j                  j_                  | j8                  j`                              d   | _D        | jN                  j                  xs | j8                  j                  | _E        t        | jN                  d      rn| jN                  j                  rX| j8                  j                  || j~                        \  }| _?        t        t        d |j}                                     | _I        n3| j8                  j                  || j~                        \  | _I        | _?        |j                  rwtF        jI                  d|j                          | j8                  j                  |j                         t'        |j                  j)                  d      d         dz   | _K        y d| _K        y c c}w )Nz8No image_samples_hook provided; no images will be loggedcheckpoint_c                 
    d| v S )Nr     )xs    X/home/cdr/jupyterlab/.venv/lib/python3.12/site-packages/trl/trainer/alignprop_trainer.py<lambda>z+AlignPropTrainer.__init__.<locals>.<lambda>U   s    -1"4     r   zNo checkpoints found in _r   )log_withmixed_precisionproject_configgradient_accumulation_stepstensorboard)alignprop_trainer_config)r   init_kwargs
T)device_specificFTimestep)positiondisableleavedescdynamic_ncolsfp16bf16)dtype pt
max_lengthreturn_tensorspadding
truncationr=   use_lorac                     | j                   S N)requires_grad)ps    r$   r%   z+AlignPropTrainer.__init__.<locals>.<lambda>   s
    !// r&   zResuming from r"   )Lr	   	prompt_fn	reward_fnr   image_samples_callbackr   project_kwargsresume_fromospathnormpath
expanduserbasenamelistfilterlistdirlen
ValueErrorsortedintsplitjoin	iterationr
   r)   r*   !train_gradient_accumulation_stepsaccelerator_kwargsacceleratoris_main_processinit_trackerstracker_project_namedictto_dicttracker_kwargsloggerinfor   seedr   set_progress_bar_configis_local_main_processtorchfloat16bfloat16float32vaetodevicetext_encoderunetget_trainable_layersregister_save_state_pre_hook_save_model_hookregister_load_state_pre_hook_load_model_hook
allow_tf32backendscudamatmul_setup_optimizer
isinstance
parameters	optimizer	tokenizernegative_promptsmodel_max_length	input_idsneg_prompt_embedautocasthasattrrB   preparetrainable_layers
load_statefirst_epoch)selfr   r   r   r   r   accelerator_project_configcheckpointsr#   checkpoint_numbersis_using_tensorboardinference_dtyper   rq   s                 r$   __init__zAlignPropTrainer.__init__=   s    %KL((&8#%9%WDKK<V<V%W";;""&(gg&6&6rww7I7I$++JaJa7b&cDKK#BGG$4$4T[[5L5L$MM"4

4;;#:#:; {#q($'?@W@W?X%YZZ%+K,XqSb1A-B,X%Y"*,'',,KK++!"4R"8!9:+'
 8J"7MPQ7Q*4& 	
[[)) KK775 )-(U(U	
 kk,,	
  &d:_vR_?_++**00+ V^^5EF^^% KK66 +  	bM"!!48&00((>>> 	1 	
 ++v5#mmO--7#nnO#mmO 0 0 7 7O%%(()9)9)@)@(X  !1!1!8!8 P++@@B55d6K6KL55d6K6KL ;;!!48ENN&&1..1;<Ld1S'')Yi
 !% 0 0 = =&&44<$++B^B^#$++55FF '  i4++223!
 ! ((11NT5E5E5N5N4##Z0T5E5E5N5N#'#3#3#;#;<Ldnn#] D$.$(0I4??K\)]$^D!484D4D4L4LM]_c_m_m4n1D!4>KK.););(<=>''(:(:;"6#5#5#;#;C#@#DEID DA -Ys   *$^c                 D    | j                  |d   |d   |d         \  }}|S )Nimagespromptsprompt_metadata)rH   )r   prompt_image_pairsrewardreward_metadatas       r$   compute_rewardsz AlignPropTrainer.compute_rewards   s5    "&..x(*<Y*GI[\mIn#
 r&   epochglobal_stepc           
      b   t        t              }| j                  j                  j	                          t        | j                  j                        D ]<  }| j                  j                  | j                  j                        5  | j                         5  t        j                         5  | j                  | j                  j                        }| j                  |      }||d<   | j                  j!                  |      j#                         j%                         j'                         }| j)                  |      }| j                  j+                  |       | j                  j,                  rn| j                  j/                  t1        | j2                  t              s| j2                  j5                         n| j2                  | j                  j6                         | j8                  j;                          | j8                  j=                          ddd       ddd       ddd       |d   j?                  jA                                |d   j?                  |jC                                |d   j?                  jE                                ? | j                  j,                  r|jG                         D 	
ci c].  \  }	}
|	t        j@                  t        jH                  |
            0 }}	}
| j                  jK                  |d      }|jM                  d	|i       | j                  jO                  ||
       |dz  }t        t              }ntQ        d      | jR                  F|| j                  jT                  z  dk(  r*| jS                  || j                  jV                  d          |dk7  rL|| j                  jX                  z  dk(  r0| j                  jZ                  r| j                  j]                          |S # 1 sw Y   xY w# 1 sw Y   xY w# 1 sw Y   xY wc c}
}	w )a  
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
              and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.
        )
batch_sizerewardsNreward_mean
reward_stdlossmean)	reductionr   )stepr   zsOptimization step should have been performed by this point. Please check calculated gradient accumulation settings.r   )/r   rQ   r   rq   trainranger   r[   r]   
accumulater   ri   enable_grad_generate_samplestrain_batch_sizer   gatherdetachcpunumpycalculate_lossbackwardsync_gradientsclip_grad_norm_r|   r   r}   train_max_grad_normr~   r   	zero_gradappendr   stditemitemstensorreduceupdatelogrU   rI   log_image_freqtrackers	save_freqr^   
save_state)r   r   r   re   r'   r   r   rewards_visr   kvs              r$   r   zAlignPropTrainer.step   sa   " 4 ##%t{{DDE 	-A!!,,T-=-=-B-BC +T]]_ +V[VgVgVi +%)%;%;#{{;; &< &" ../AB07"9-"..55g>EEGKKMSSU**73  ))$/##22$$44)$*?*?F --88:!2277	 ##%((*1+ + +4 &&{'7'7'9:%%koo&78L		,;	-@ **?Czz|Ltq!Auzz%,,q/22LDL##**46*BDKK%()  K 81Kt$D F  &&2{T[[E_E_7_cd7d''(:KIYIYIbIbcdIefA:%$++"7"771<AQAQAaAa'')e+ + + + + +B MsC   PP,E&P	PP<3P+P	PPPP(	c                 ,    d|j                         z
  }|S )a(  
        Calculate the loss for a batch of an unpacked sample

        Args:
            rewards (torch.Tensor):
                Differentiable reward scalars for each generated image, shape: [batch_size]

        Returns:
            loss (torch.Tensor) (all of these are of shape (1,))
        g      $@)r   )r   r   r   s      r$   r   zAlignPropTrainer.calculate_loss  s     wnn&&r&   
advantages
clip_rangeratioc                     | |z  }| t        j                  |d|z
  d|z         z  }t        j                  t        j                  ||            S )Ng      ?)ri   clampr   maximum)r   r   r   r   unclipped_lossclipped_losss         r$   r   zAlignPropTrainer.loss  sV     %u,"{U[[**&
 

 zz%--EFFr&   c                 ~   | j                   j                  rdd l}|j                  j                  }nt
        j                  j                  } ||| j                   j                  | j                   j                  | j                   j                  f| j                   j                  | j                   j                        S )Nr   )lrbetasweight_decayeps)r   train_use_8bit_adambitsandbytesoptim	AdamW8bitri   AdamWtrain_learning_ratetrain_adam_beta1train_adam_beta2train_adam_weight_decaytrain_adam_epsilon)r   trainable_layers_parametersr   optimizer_clss       r$   r{   z!AlignPropTrainer._setup_optimizer*  s    ;;**(..88M!KK--M'{{..;;//1M1MN<<..
 	
r&   c                 ^    | j                   j                  |||       |j                          y rD   )r   save_checkpointpop)r   modelsweights
output_dirs       r$   rt   z!AlignPropTrainer._save_model_hook:  s#    ((*Er&   c                 \    | j                   j                  ||       |j                          y rD   )r   load_checkpointr   )r   r   	input_dirs      r$   rv   z!AlignPropTrainer._load_model_hook>  s!    ((;

r&   c                    i }| j                   j                  |dd      }|1t        t        |      D cg c]  }| j	                          c} \  }}nt        |      D cg c]  }i  }}| j
                  j                  |ddd| j
                  j                  j                        j                  j                  | j                  j                        }| j
                  j                  |      d   }	|r| j
                  j                  |	|| j                  j                  | j                  j                   | j                  j"                  | j                  j$                  | j                  j&                  | j                  j(                  d	      }
nS| j                  |	|| j                  j                  | j                  j                   | j                  j"                  d      }
|
j*                  }||d	<   ||d
<   ||d<   |S c c}w c c}w )a  
        Generate samples from the model

        Args:
            batch_size (int): Batch size to use for sampling
            with_grad (bool): Whether the generated RGBs should have gradients attached to it.

        Returns:
            prompt_image_pairs (dict[Any])
        r   r<   r=   Tr>   r   )	prompt_embedsnegative_prompt_embedsnum_inference_stepsguidance_scaleetatruncated_backprop_randtruncated_backprop_timesteptruncated_rand_backprop_minmaxoutput_type)r   r   r   r   r   r   r   r   r   )r   repeatzipr   rG   r   r   r   r   rn   r]   ro   rp   rgb_with_gradr   sample_num_stepssample_guidance_scale
sample_etar   r   r   r   )r   r   	with_gradr   r   sample_neg_prompt_embedsr'   r   
prompt_idsr   	sd_outputr   s               r$   r   z"AlignPropTrainer._generate_samplesB  s     #'#8#8#?#?
Aq#Q ?'*uZGX,Y!T^^-=,Y'Z$G_+0+<=ar=O=%%// ''11BB 0 
 )BBt''../ 	 ((55jA!D((66+'?$(KK$@$@#{{@@KK**(,(K(K,0KK,S,S/3{{/Y/Y  7 
I ((+'?$(KK$@$@#{{@@KK**  ) I !!'-8$(/9%0?,-!!U -Z=s   H 	Hepochsc                     d}|| j                   j                  }t        | j                  |      D ]  }| j	                  ||      } y)z>
        Train the model for a given number of epochs
        r   N)r   
num_epochsr   r   r   )r   r   r   r   s       r$   r   zAlignPropTrainer.train~  sI     >[[++F4++V4 	8E))E;7K	8r&   c                 Z    | j                   j                  |       | j                          y rD   )r   save_pretrainedcreate_model_card)r   save_directorys     r$   _save_pretrainedz!AlignPropTrainer._save_pretrained  s"    ((8 r&   c                    | j                   j                  *t        | j                   j                        j                  }n(| j                   j                  j                  d      d   }| j                  |       t        | !  ||       y )N/r(   )
model_name)	argshub_model_idr   r   namerX   r   super_save_checkpoint)r   modeltrialr  	__class__s       r$   r	  z!AlignPropTrainer._save_checkpoint  sl    99!!)dii22388J//55c:2>J*5 .r&   r  dataset_nametagsc                    | j                         syt        | j                  j                  d      r^t        j
                  j                  | j                  j                  j                        s!| j                  j                  j                  }nd}|t               }nt        |t              r|h}nt        |      }t        | j                  j                  d      r|j                  d       |j                  | j                         t        j                  d      }t!        ||| j"                  ||t%               r.t&        j(                  t&        j(                  j+                         ndt-               d|dd	      }|j/                  t        j
                  j1                  | j2                  j4                  d
             y)a  
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        N_name_or_pathunsloth_versionunslothaS          @article{prabhudesai2024aligning,
            title        = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
            author       = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
            year         = 2024,
            eprint       = {arXiv:2310.03739}
        }	AlignPropzCAligning Text-to-Image Diffusion Models with Reward Backpropagationz
2310.03739)
base_modelr  r  r  r  	wandb_url	comet_urltrainer_nametrainer_citationpaper_titlepaper_idz	README.md)is_world_process_zeror   r
  r   rL   rM   isdirr  setr|   straddr   
_tag_namestextwrapdedentr   r  r   wandbrunget_urlr   saverY   r  r   )r   r  r  r  r  citation
model_cards          r$   r   z"AlignPropTrainer.create_model_card  sG   " ))+4::$$o6rww}}TZZM^M^MlMl?m**88JJ <5Dc"6Dt9D4::$$&78HHYDOO$?? $  )!!**%-?-AeiiF[eii'')ae.0$%]!

 	TYY%9%9;GHr&   rD   )TN)NNN) __name__
__module____qualname____doc__r   r   r   ri   Tensortupler  r   r   r   r   r   rW   r   r   floatr   r{   rt   rv   r   r   r  r	  r   rQ   r   __classcell__)r  s   @r$   r   r   (   sd   $ %J HL~!~! "5<<sU3Z"H%,,"VW~! ""eCHo"56	~!
 1~! %XsCos.B%CD~!@H# HC HTGLLG G ||	G
 :"x8HSM 8!
/ %)&*,0	<ISM<I sm<I CcD()	<Ir&   r   )$rL   r!  collectionsr   pathlibr   typingr   r   r   r   warningsr	   ri   
accelerater
   accelerate.loggingr   accelerate.utilsr   r   huggingface_hubr   transformersr   r   r   alignprop_configr   utilsr   r   r#  r)  rd   r   r"   r&   r$   <module>r<     s]    
  #  1 1   " ) ; 0 + 0 - @ 	H	iI+ iIr&   