import os
import textwrap
from collections import defaultdict
from concurrent import futures
from pathlib import Path
from typing import Any, Callable, Optional, Union
from warnings import warn

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import PyTorchModelHubMixin
from transformers import is_wandb_available

from ..models import DDPOStableDiffusionPipeline
from .ddpo_config import DDPOConfig
from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url


if is_wandb_available():
    import wandb

logger = get_logger(__name__)


class DDPOTrainer(PyTorchModelHubMixin):
    """
    The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. Note, this trainer is heavily
    inspired by the work here: https://github.com/kvablack/ddpo-pytorch As of now only Stable Diffusion based pipelines
    are supported
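
    Example (a minimal sketch; the reward function, prompt function, and checkpoint name below are placeholders, not
    part of this module):

    ```python
    from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


    def prompt_fn():
        return "a photo of a cat", {}


    def reward_fn(images, prompts, prompt_metadata):
        # one scalar reward per image, plus arbitrary metadata
        return images.new_ones(len(images)), {}


    pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
    trainer = DDPOTrainer(DDPOConfig(), reward_fn, prompt_fn, pipeline)
    trainer.train()
    ```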

    Attributes:
        **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for
            more details.
        **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
        **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
        **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
        **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
    """

    _tag_names = ["trl", "ddpo"]

    def __init__(
        self,
        config: DDPOConfig,
        reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
        prompt_function: Callable[[], tuple[str, Any]],
        sd_pipeline: DDPOStableDiffusionPipeline,
        image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
    ):
        if image_samples_hook is None:
            warn("No image_samples_hook provided; no images will be logged")

        self.prompt_fn = prompt_function
        self.reward_fn = reward_function
        self.config = config
        self.image_samples_callback = image_samples_hook

        accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)

        if self.config.resume_from:
            self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
            if "checkpoint_" not in os.path.basename(self.config.resume_from):
                # get the most recent checkpoint in this directory
                checkpoints = list(filter(lambda x: "checkpoint_" in x, os.listdir(self.config.resume_from)))
                if len(checkpoints) == 0:
                    raise ValueError(f"No checkpoints found in {self.config.resume_from}")
                checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
                self.config.resume_from = os.path.join(
                    self.config.resume_from,
                    f"checkpoint_{checkpoint_numbers[-1]}",
                )

                accelerator_project_config.iteration = checkpoint_numbers[-1] + 1

        # number of timesteps within each trajectory to train on
        self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)

        self.accelerator = Accelerator(
            log_with=self.config.log_with,
            mixed_precision=self.config.mixed_precision,
            project_config=accelerator_project_config,
            # gradients are always accumulated across timesteps, so multiply by the number of trained timesteps to make
            # `train_gradient_accumulation_steps` count whole samples rather than individual timesteps
            gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
            **self.config.accelerator_kwargs,
        )

        is_okay, message = self._config_check()
        if not is_okay:
            raise ValueError(message)

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers(
                self.config.tracker_project_name,
                config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
                init_kwargs=self.config.tracker_kwargs,
            )

        logger.info(f"\n{config}")

        set_seed(self.config.seed, device_specific=True)

        self.sd_pipeline = sd_pipeline
        self.sd_pipeline.set_progress_bar_config(
            position=1,
            disable=not self.accelerator.is_local_main_process,
            leave=False,
            desc="Timestep",
            dynamic_ncols=True,
        )

        # For mixed precision training, cast the non-trainable weights (vae, text_encoder, unet) to half precision;
        # they are only used for inference here.
        if self.accelerator.mixed_precision == "fp16":
            inference_dtype = torch.float16
        elif self.accelerator.mixed_precision == "bf16":
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
        self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)

        trainable_layers = self.sd_pipeline.get_trainable_layers()

        self.accelerator.register_save_state_pre_hook(self._save_model_hook)
        self.accelerator.register_load_state_pre_hook(self._load_model_hook)

        # Enable TF32 for faster training on Ampere GPUs,
        # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

        self.optimizer = self._setup_optimizer(
            trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
        )

        self.neg_prompt_embed = self.sd_pipeline.text_encoder(
            self.sd_pipeline.tokenizer(
                [""] if self.config.negative_prompts is None else self.config.negative_prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
        )[0]

        if config.per_prompt_stat_tracking:
            self.stat_tracker = PerPromptStatTracker(
                config.per_prompt_stat_tracking_buffer_size,
                config.per_prompt_stat_tracking_min_count,
            )

        # NOTE: autocast is necessary for non-lora training, but for lora training it is not necessary and uses more memory
        self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast

        if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
            unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
            self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
        else:
            self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)

        if self.config.async_reward_computation:
            self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)

        if config.resume_from:
            logger.info(f"Resuming from {config.resume_from}")
            self.accelerator.load_state(config.resume_from)
            self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
        else:
            self.first_epoch = 0

    def compute_rewards(self, prompt_image_pairs, is_async=False):
        if not is_async:
            rewards = []
            for images, prompts, prompt_metadata in prompt_image_pairs:
                reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
                rewards.append(
                    (
                        torch.as_tensor(reward, device=self.accelerator.device),
                        reward_metadata,
                    )
                )
        else:
            # asynchronous path: each (images, prompts, metadata) tuple is scored on the thread pool executor
            results = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
            rewards = [
                (torch.as_tensor(reward, device=self.accelerator.device), reward_metadata)
                for reward, reward_metadata in results
            ]

        return zip(*rewards)

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step,
              and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.
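
        Example (sketch of the loop that `train` runs, assuming `trainer` is an already constructed `DDPOTrainer`):

        ```python
        global_step = 0
        for epoch in range(trainer.first_epoch, trainer.config.num_epochs):
            global_step = trainer.step(epoch, global_step)
        ```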
        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into a dict where each entry has shape (num_batches_per_epoch * sample_batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )

        for i, image_data in enumerate(prompt_image_data):
            image_data.extend([rewards[i], rewards_metadata[i]])

        if self.image_samples_callback is not None:
            self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

        rewards = torch.cat(rewards)
        rewards = self.accelerator.gather(rewards).cpu().numpy()

        self.accelerator.log(
            {
                "reward": rewards,
                "epoch": epoch,
                "reward_mean": rewards.mean(),
                "reward_std": rewards.std(),
            },
            step=global_step,
        )

        if self.config.per_prompt_stat_tracking:
            # gather the prompts across processes and normalize rewards per prompt
            prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
            prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
            advantages = self.stat_tracker.update(prompts, rewards)
        else:
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # ungather advantages; keep only the entries corresponding to the samples on this process
        samples["advantages"] = (
            torch.as_tensor(advantages)
            .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
            .to(self.accelerator.device)
        )

        del samples["prompt_ids"]

        total_batch_size, num_timesteps = samples["timesteps"].shape

        for inner_epoch in range(self.config.train_num_inner_epochs):
            # shuffle samples along the batch dimension
            perm = torch.randperm(total_batch_size, device=self.accelerator.device)
            samples = {k: v[perm] for k, v in samples.items()}

            # shuffle along the time dimension independently for each sample
            perms = torch.stack(
                [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
            )

            for key in ["timesteps", "latents", "next_latents", "log_probs"]:
                samples[key] = samples[key][
                    torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
                    perms,
                ]

            original_keys = samples.keys()
            original_values = samples.values()
            # rebatch: the user-defined train_batch_size may differ from sample_batch_size
            reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]

            # transpose the list of reshaped values and build one dict per training batch
            transposed_values = zip(*reshaped_values)
            samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]

            self.sd_pipeline.unet.train()
            global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
            # ensure an optimization step happened by the end of the inner epoch
            if not self.accelerator.sync_gradients:
                raise ValueError(
                    "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
                )

        if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
            self.accelerator.save_state()

        return global_step

    def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
        """
        Calculate the loss for a batch of an unpacked sample

        Args:
            latents (torch.Tensor):
                The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
            timesteps (torch.Tensor):
                The timesteps sampled from the diffusion model, shape: [batch_size]
            next_latents (torch.Tensor):
                The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height,
                width]
            log_probs (torch.Tensor):
                The log probabilities of the latents, shape: [batch_size]
            advantages (torch.Tensor):
                The advantages of the latents, shape: [batch_size]
            embeds (torch.Tensor):
                The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] Note: the "or" is because if
                train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
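
        Note:
            With the per-timestep ratio `r = exp(log_prob - old_log_prob)`, the loss is the clipped PPO surrogate
            `mean(max(-A * r, -A * clip(r, 1 - clip_range, 1 + clip_range)))` (see `loss` below); `approx_kl` and
            `clipfrac` are logged for monitoring only.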

        Returns:
            loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) (all of these are of shape (1,))
        """
        with self.autocast():
            if self.config.train_cfg:
                noise_pred = self.sd_pipeline.unet(
                    torch.cat([latents] * 2),
                    torch.cat([timesteps] * 2),
                    embeds,
                ).sample
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
            else:
                noise_pred = self.sd_pipeline.unet(
                    latents,
                    timesteps,
                    embeds,
                ).sample

            # compute the log prob of next_latents given latents under the current model
            scheduler_step_output = self.sd_pipeline.scheduler_step(
                noise_pred,
                timesteps,
                latents,
                eta=self.config.sample_eta,
                prev_sample=next_latents,
            )

            log_prob = scheduler_step_output.log_probs

        advantages = torch.clamp(
            advantages,
            -self.config.train_adv_clip_max,
            self.config.train_adv_clip_max,
        )

        ratio = torch.exp(log_prob - log_probs)

        loss = self.loss(advantages, self.config.train_clip_range, ratio)

        approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)

        clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())

        return loss, approx_kl, clipfrac

    def loss(
        self,
        advantages: torch.Tensor,
        clip_range: float,
        ratio: torch.Tensor,
    ):
        unclipped_loss = -advantages * ratio
        clipped_loss = -advantages * torch.clamp(
            ratio,
            1.0 - clip_range,
            1.0 + clip_range,
        )
        return torch.mean(torch.maximum(unclipped_loss, clipped_loss))

    def _setup_optimizer(self, trainable_layers_parameters):
        if self.config.train_use_8bit_adam:
            import bitsandbytes

            optimizer_cls = bitsandbytes.optim.AdamW8bit
        else:
            optimizer_cls = torch.optim.AdamW

        return optimizer_cls(
            trainable_layers_parameters,
            lr=self.config.train_learning_rate,
            betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
            weight_decay=self.config.train_adam_weight_decay,
            eps=self.config.train_adam_epsilon,
        )

    def _save_model_hook(self, models, weights, output_dir):
        self.sd_pipeline.save_checkpoint(models, weights, output_dir)
        weights.pop()  # ensures that accelerate doesn't try to handle saving of the model

    def _load_model_hook(self, models, input_dir):
        self.sd_pipeline.load_checkpoint(models, input_dir)
        models.pop()  # ensures that accelerate doesn't try to handle loading of the model

    def _generate_samples(self, iterations, batch_size):
        """
        Generate samples from the model

        Args:
            iterations (int): Number of iterations to generate samples for
            batch_size (int): Batch size to use for sampling
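
        Note:
            Each returned sample dict carries the tensors needed later for training: `prompt_ids`, `prompt_embeds`,
            `timesteps`, `latents`, `next_latents`, `log_probs` and `negative_prompt_embeds`.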

        Returns:
            samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
        """
        samples = []
        prompt_image_pairs = []
        self.sd_pipeline.unet.eval()

        sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)

        for _ in range(iterations):
            prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])

            prompt_ids = self.sd_pipeline.tokenizer(
                prompts,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.sd_pipeline.tokenizer.model_max_length,
            ).input_ids.to(self.accelerator.device)
            prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]

            with self.autocast():
                sd_output = self.sd_pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=sample_neg_prompt_embeds,
                    num_inference_steps=self.config.sample_num_steps,
                    guidance_scale=self.config.sample_guidance_scale,
                    eta=self.config.sample_eta,
                    output_type="pt",
                )

                images = sd_output.images
                latents = sd_output.latents
                log_probs = sd_output.log_probs

            latents = torch.stack(latents, dim=1)  # (batch_size, num_steps + 1, ...)
            log_probs = torch.stack(log_probs, dim=1)  # (batch_size, num_steps, 1)
            timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1)  # (batch_size, num_steps)

            samples.append(
                {
                    "prompt_ids": prompt_ids,
                    "prompt_embeds": prompt_embeds,
                    "timesteps": timesteps,
                    "latents": latents[:, :-1],  # each entry is the latent before timestep t
                    "next_latents": latents[:, 1:],  # each entry is the latent after timestep t
                    "log_probs": log_probs,
                    "negative_prompt_embeds": sample_neg_prompt_embeds,
                }
            )
            prompt_image_pairs.append([images, prompts, prompt_metadata])

        return samples, prompt_image_pairs

    def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
        """
        Train on a batch of samples. Main training segment

        Args:
            inner_epoch (int): The current inner epoch
            epoch (int): The current epoch
            global_step (int): The current global step
            batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
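
        Note:
            Gradients are accumulated via `accelerator.accumulate`; the accumulation window was configured in
            `__init__` as `train_gradient_accumulation_steps * num_train_timesteps`, so a single optimizer update
            spans several samples and all of their trained timesteps.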

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.

        Returns:
            global_step (int): The updated global step
        """
        info = defaultdict(list)
        for _i, sample in enumerate(batched_samples):
            if self.config.train_cfg:
                # concat negative prompts to sample prompts to avoid two forward passes
                embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
            else:
                embeds = sample["prompt_embeds"]

            for j in range(self.num_train_timesteps):
                with self.accelerator.accumulate(self.sd_pipeline.unet):
                    loss, approx_kl, clipfrac = self.calculate_loss(
                        sample["latents"][:, j],
                        sample["timesteps"][:, j],
                        sample["next_latents"][:, j],
                        sample["log_probs"][:, j],
                        sample["advantages"],
                        embeds,
                    )
                    info["approx_kl"].append(approx_kl)
                    info["clipfrac"].append(clipfrac)
                    info["loss"].append(loss)

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(
                            self.trainable_layers.parameters()
                            if not isinstance(self.trainable_layers, list)
                            else self.trainable_layers,
                            self.config.train_max_grad_norm,
                        )
                    self.optimizer.step()
                    self.optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if self.accelerator.sync_gradients:
                # log training-related stats
                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                info = self.accelerator.reduce(info, reduction="mean")
                info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                self.accelerator.log(info, step=global_step)
                global_step += 1
                info = defaultdict(list)
        return global_step

    def _config_check(self) -> tuple[bool, str]:
        samples_per_epoch = (
            self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
        )
        total_train_batch_size = (
            self.config.train_batch_size
            * self.accelerator.num_processes
            * self.config.train_gradient_accumulation_steps
        )

        if not self.config.sample_batch_size >= self.config.train_batch_size:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
            )
        if not self.config.sample_batch_size % self.config.train_batch_size == 0:
            return (
                False,
                f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
            )
        if not samples_per_epoch % total_train_batch_size == 0:
            return (
                False,
                f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
            )
        return True, ""

    def train(self, epochs: Optional[int] = None):
        """
        Train the model for a given number of epochs
        """
        global_step = 0
        if epochs is None:
            epochs = self.config.num_epochs
        for epoch in range(self.first_epoch, epochs):
            global_step = self.step(epoch, global_step)

    def _save_pretrained(self, save_directory):
        self.sd_pipeline.save_pretrained(save_directory)
        self.create_model_card()

    def _save_checkpoint(self, model, trial):
        # Name the model card from the Hub id when available, otherwise from the output directory.
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the model.
            dataset_name (`str` or `None`, *optional*, defaults to `None`):
                Name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        # normalize `tags` to a mutable set
        if tags is None:
            tags = set()
        elif isinstance(tags, str):
            tags = {tags}
        else:
            tags = set(tags)

        if hasattr(self.model.config, "unsloth_version"):
            tags.add("unsloth")

        tags.update(self._tag_names)

        citation = textwrap.dedent("""\
        @inproceedings{black2024training,
            title        = {{Training Diffusion Models with Reinforcement Learning}},
            author       = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
            year         = 2024,
            booktitle    = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
            publisher    = {OpenReview.net},
            url          = {https://openreview.net/forum?id=YCWjhGrJFD},
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name="DDPO",
            trainer_citation=citation,
            paper_title="Training Diffusion Models with Reinforcement Learning",
            paper_id="2305.13301",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))