
    biP                       d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
 d dlZd dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZmZmZ d dlmZmZmZ ddlmZ ddl m!Z! ddl"m#Z# ddl$m%Z%m&Z& 	 	 	 	 	 	 ddZ' G d de      Z(y)    )annotationsN)asdict)Enum)OptionalUnion)_calculate_correct_fan)tqdm)Conv1D)is_bnb_4bit_availableis_bnb_available)	BaseTunerBaseTunerLayercheck_target_module_exists)2TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPINGModulesToSaveWrapper_get_submodules   )
BufferDict) _maybe_include_all_linear_layers   )
VeraConfig)Linear	VeraLayerc                r   t        | t              rt        j                  |       }n| }t	        |d      }t        j                  d      }|t        j                  |      z  }t        j                  d      |z  }t        j                         5  |j                  | ||      cddd       S # 1 sw Y   yxY w)a  
    Kaiming Uniform Initialisation adapted to accept a `torch.Generator` object for PRNG.

    Args:
        tensor_or_shape (`Union[torch.Tensor, tuple[int, ...]]`):
            Tensor to initialise, or shape of new tensor to create and then initialise.
        generator: (`torch.Generator`):
            Generator object that manages the state of the PRNG algorithm in use.

    Returns:
        `torch.Tensor`: The initialised tensor.
    fan_inr   g      @	generatorN)	
isinstancetupletorchemptyr   mathsqrtno_graduniform_)tensor_or_shaper   tensorfangainstdbounds          Q/home/cdr/jupyterlab/.venv/lib/python3.12/site-packages/peft/tuners/vera/model.py_kaiming_initr-   +   s      /5)_- 
 
2C99Q<D
3
CIIcNS E	 Cvu	BC C Cs   B--B6c                      e Zd ZU dZdZded<   ddZddZddZddZ	e
d	        Zd
 Ze
d        ZddZe
d        Zd fdZdd dZd!dZd Zd Zd Ze
d        Z	 	 	 	 d"	 	 	 	 	 d#dZd$dZ	 d%	 	 	 	 	 d#dZd Z xZS )&	VeraModela=  
    Creates Vector-based Random Matrix Adaptation (Vera) model from a pretrained transformers model.

    Args:
        model ([`~transformers.PreTrainedModel`]): The model to be adapted.
        config ([`VeraConfig`]): The configuration of the Vera model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.
        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
            Create empty adapter weights on meta device. Useful to speed up the loading process.

    Returns:
        `torch.nn.Module`: The Vera model.

    Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import VeraConfig, get_peft_model

        >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        >>> config = VeraConfig(r=128)
        >>> model = get_peft_model(base_model, config)
        ```

    **Attributes**:
        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`VeraConfig`]): The configuration of the Vera model.
    vera_lambda_strprefixc                   | j                  | j                        }| j                  ||      }t        || j                        }d}| j                  j	                         D ]  \  }}| j                  ||      st        |t        j                        r|j                  |j                  f}n\t        |t              rKt        |j                  d      r|j                  j                  n|j                  j                  }|ddd   }n||}||k7  st!        d t#        ||      D              } |d}t%        |      |S )z
        Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA.

        This will be used for determining the size of the shared vera_A and vera_B matrices.
        Nds_shapec              3  :   K   | ]  \  }}t        ||        y wN)max).0abs      r,   	<genexpr>z&VeraModel._find_dim.<locals>.<genexpr>   s     %]DAqc!Qi%]s   z[No layers types compatible with VeRA were found. Please check `peft_config.target_modules`.)get_model_configmodel_prepare_adapter_configr   named_modules_check_target_module_existsr   nnr   out_featuresin_featuresr
   hasattrweightr4   shaper   zip
ValueError)	selfconfigmodel_configpeft_configlargest_shapekeymodulemodule_shapemsgs	            r,   	_find_dimzVeraModel._find_dimh   s)    ,,TZZ8226<H6{DJJO::335 	^KC33KE&")),%22F4F4FFFF+9@PZ9[v}}55agananatat+DbD1$ ,}, %%]C|<\%] ]#	^&  oCS/!    c                   | j                  |      \  }}t        i |j                        | _        t        i |j                        | _        t        j                  d      j                  |j                        }t        |j                  |f|      }t        ||j                  f|      }|| j                  |<   || j                  |<   y )N)
persistentcpudevicer   )rS   r   save_projectionvera_Avera_Br    	Generatormanual_seedprojection_prng_keyr-   r)rJ   rK   adapter_namelinear_out_dimlinear_in_dimr   r[   r\   s           r,   _init_vera_A_vera_BzVeraModel._init_vera_A_vera_B   s    (,v(>% !0F0FG 0F0FG OO51==f>X>XY	-8IN9YO$*L!$*L!rT   c                (    | j                  ||       y r7   )rd   )rJ   r>   rK   ra   s       r,   _pre_injection_hookzVeraModel._pre_injection_hook   s      6rT   c                   t        | j                        dkD  r1|j                  dk7  r"t        | j                  j
                   d      | j                  j                         D ]F  }||u r|j                  |j                  k7  s"t        d|j                  d|j                   d       t        | j                  j                         D ch c]  }|j                   c}      }t        |      dkD  rt        d|       yc c}w )	z
        A helper method to check the config when a new adapter is being added.

        Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.

        r   nonezf supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters.z_Vera PRNG initialisation key must be the same for all adapters. Got config.projection_prng_key=z but previous config had .zcVeRA projection weights must be saved for all adapters or none, but got multiple different values: N)
lenrM   biasrI   	__class____name__valuesr_   sortedrZ   )rJ   rK   existing_configsave_project_unique_valuess       r,   _check_new_adapter_configz#VeraModel._check_new_adapter_config   s       !A%FKK6,A>>**+ ,7 7 
  $//668 		O&(22f6P6PP v[a[u[uZw x++:+N+N*OqR 		 &,RVRbRbRiRiRk,lV-C-C,l%m")*Q.u-.0  / -ms   Dc                    t        | |      S r7   )r   )vera_configrO   s     r,   rA   z%VeraModel._check_target_module_exists   s    )+s;;rT   c           
        |t        d      |j                  }t        |d      xr |j                  d u}	||j                  |j
                  |j                  t        | j                  dd      t        | j                  dd      d}
|	|
d<   t        |t              rK|j                  || j                  | j                  ||j                  |j                  |j                         y  | j                  || j                  | j                  ||fi |
}|| j                   vr|j#                  d       | j%                  ||||       y )NzCurrent Key shouldn't be `None`rk   is_loaded_in_8bitFis_loaded_in_4bit)r`   vera_dropoutfan_in_fan_outinit_weightsloaded_in_8bitloaded_in_4bit)	d_initial)rI   r`   rE   rk   rx   ry   rz   getattrr>   r   r   update_layerr[   r\   r}   _create_new_moduleactive_adapterrequires_grad__replace_module)rJ   rt   ra   targettarget_nameparentcurrent_keyoptional_kwargsr`   rk   kwargs
new_modules               r,   _create_and_replacezVeraModel._create_and_replace   s3    >??MMvv&B6;;d+B'44)88'44%djj2EuM%djj2EuM
 vff%((((%//    100dkk4;;XdflwpvwJ4#6#66))%0  j&IrT   c                   t        | ||       t        |d      r|j                  }t        |d      s.|j                  |_        t        |d      r|j                  |_        t        |dd       ^t        |d      r|j                  |j                  _        n|j                  |_        |j                  |j                  j                         t        j                  d      |j                         D ]R  \  }}d|v st        fd|j                         D              r.|j                  |j                  j                         T y )N
base_layerrk   statemetavera_c              3  <   K   | ]  }|j                   k(    y wr7   rX   )r9   pr   s     r,   r<   z,VeraModel._replace_module.<locals>.<genexpr>	  s     I188t+Is   )setattrrE   r   rF   rk   r~   r   torY   r    r@   any
parameters)r   
child_namer   childnamerP   r   s         @r,   r   zVeraModel._replace_module   s    
J/
 5,'$$Ez<0 %Juf%"'**
5'4(4z<0.3kk
%%+#(;;
 MM%,,--.||F#&446 	3LD&$IV5F5F5HIIIIell112	3rT   c                   |j                         D ]  \  }}| j                  |vsd|_         | j                  D ]  }| j                  |   j
                  }|dk(  r"|dk(  r%|j                         D ]  \  }}d|v sd|_         L|dk(  rR|j                         D ]>  }t        |t              st        |d      s!|j
                  .d|j
                  _        @ t        d| d       y )	NFrh   allrk   T	vera_onlyzRequested bias: z, is not implemented.)named_parametersr2   requires_gradactive_adaptersrM   rk   modulesr   r   rE   NotImplementedError)rJ   r>   nr   r   rk   ms          r,    _mark_only_adapters_as_trainablez*VeraModel._mark_only_adapters_as_trainable  s    **, 	(DAq{{!#"'	( #22 	ZN##N388Dv~u}!224 /DAq{*./ $ 4A!!Y/GAv4F166K]/3,4 *,<TFBW*XYY	ZrT   c                x   t               r
dd l}ddlm} t	               rddlm} |j                  dd      }	|j                  dd      }
|j                  dd      }t        |t              r|j                         }n|}|
rt        |j                  j                        rc|j                         }|j                  |j                  j                  |j                  j                   |j"                  d	        ||||fi |S |rt        |j                  j
                        rc|j                         }|j                  |j$                  |j&                  j(                  |j&                  j*                  d
        ||||fi |S t        |t,        j                  j.                        r'|d   rmt1        j2                  d       dx|d<   | _        nKt        |t6              r,d|d<   |d   s1t1        j2                  d       dx|d<   | _        nt9        d| d      t/        ||||f|	| j:                  d|}|S )Nr   r   )Linear8bitLt)
Linear4bitrk   Fr{   r|   )has_fp16_weights	thresholdindex)compute_dtypecompress_statistics
quant_typery   zjfan_in_fan_out is set to True but the target module is `torch.nn.Linear`. Setting fan_in_fan_out to False.Tis_target_conv_1d_layerzafan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True.zTarget module z is not supported. Currently, only the following modules are supported: `torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`.)rk   r}   )r   bitsandbytesbnbr   r   r   popgetr   r   get_base_layerrB   copyupdater   r   r   r   r   rF   r   r   r    r   warningswarnry   r
   rI   r}   )rt   r[   r\   ra   r   r   r   r   r   rk   r{   r|   target_base_layereightbit_kwargsfourbit_kwargsr   s                   r,   r   zVeraModel._create_new_module!  s8    &) "'zz&%($4e<$4e<fn- & 5 5 7 &j):CFF<O<OP$kkmO""(9(?(?(P(P!2!8!8!B!B.44  ffXXX
+<cff>O>O P#[[]N!!%6%D%D+<+C+C+W+W"3":":"E"E flFFUnUU)588??;&'7 INM'(;+E)6204F,-*+w IML'(;+E  )J J  	

 !++
 

 rT   c                z    	 t         |   |      S # t        $ r |dk(  r t        | j                  |      cY S w xY w)z1Forward missing attributes to the wrapped module.r>   )super__getattr__AttributeErrorr~   r>   )rJ   r   rl   s     r,   r   zVeraModel.__getattr__h  sB    	-7&t,, 	-w4::t,,	-s    %::c           
        i }| j                   j                         D ]U  \  }}t        |      j                         D ci c]$  \  }}|t        |t              r|j
                  n|& }}}|sQd|d<   W |<   |S c c}}w )NTinference_mode)rM   itemsr   r   r   value)rJ   	inferenceconfig_dictrO   r   kvrK   s           r,   get_peft_config_as_dictz!VeraModel.get_peft_config_as_dictq  s    **002 	0JCKQRW=K^K^K`a41aaJq$$7Q>aFa+/'(	0 "C	 bs   )A<c                    | j                   j                         D ]*  }t        |t        t        f      s|j                  |       , y r7   )r>   r   r   r   r   enable_adapters)rJ   enabledrP   s      r,   _set_adapter_layerszVeraModel._set_adapter_layersz  s<    jj((* 	0F&>3G"HI&&w/	0rT   c                (    | j                  d       y )NTr   )r   rJ   s    r,   enable_adapter_layerszVeraModel.enable_adapter_layers  s       .rT   c                    | j                   D ]<  }| j                  |   j                  }|dk7  s"d| d}t        j                  |       > | j                  d       y )Nrh   z>Careful, disabling adapter layers with bias configured to be 'zL' does not produce the same output as the base model would without adaption.Fr   )r   rM   rk   r   r   r   )rJ   r   valrR   s       r,   disable_adapter_layersz VeraModel.disable_adapter_layers  sp    "22 	#N"">277Cf}TUXTY ZG G  c"	# 	   /rT   c                    | j                   j                         D ]U  }t        |t              s|j                  r%t        j                  d       |j                          |j                  |       W || _	        y )NzJAdapter cannot be set when the model is merged. Unmerging the model first.)
r>   r   r   r   mergedr   r   unmergeset_adapterr   )rJ   ra   rP   s      r,   r   zVeraModel.set_adapter  s^    jj((* 	1F&),==MM"noNN$""<0	1 +rT   c                ~    | j                   0|d   t        vrt        d      t        t        |d            | _         | S )N
model_typez0Please specify `target_modules` in `peft_config`)target_modulesr   rI   set)rM   rL   s     r,   r?   z!VeraModel._prepare_adapter_config  sK    %%-L)1cc !STT),B<P\C]^*K& rT   c                   | j                   j                         D cg c]  \  }}d|vs| }}}d|rdndz   dz   }t        || |      D ]  }	 t        | j                   |      \  }	}
}t        |
d      r8|r|
j                  ||       | j                  |	||
j                         |
       bt        |
t              sst        |	||
j                  |
j                             | j                   S c c}}w # t        $ r Y w xY w)	Nveraz
Unloading zand merging  r>   )disabledescr   )
safe_mergeadapter_names)r>   r@   r	   r   r   rE   merger   r   r   r   r   modules_to_saver   )rJ   r   progressbarr   r   rO   _key_listr   r   r   r   s               r,   _unload_and_optionally_mergez&VeraModel._unload_and_optionally_merge  s    '+jj&>&>&@VFCFRUDUCVV~B?'Ik/E 	\C.=djj#.N+ v|,LLJmLT$$V[&:O:O:QSYZF$89V-C-CFDYDY-Z[	\ zz# W
 " s   C2C2C88	DDc                   |t        | j                  j                               vrt        d| d      | j                  |= | j                  j                         D cg c]  \  }}d|vs| }}}d}|D ]P  }t        | j                  |      \  }}}t        |t              s.|j                  |       |B|j                  dd }R |xs g | _
        | j                  ||       yc c}}w )z
        Deletes an existing adapter.

        Args:
            adapter_name (str): Name of the adapter to be deleted.
        zAdapter z does not existr   N)new_active_adapters)listrM   keysrI   r>   r@   r   r   r   delete_adapterr   _delete_auxiliary_adapter)rJ   ra   rO   r   r   new_adapterr   s          r,   r   zVeraModel.delete_adapter  s     tD$4$4$9$9$;<<x~_EFF\* '+jj&>&>&@VFCFRUDUCVV 	;C*4::s;LAvq&),%%l3&"("7"7":K	; */R&&|&U Ws   C),C)c                *    | j                  |||      S )aH  
        This method merges the Vera layers into the base model. This is needed if someone wants to use the base model
        as a standalone model.

        Args:
            progressbar (`bool`):
                whether to show a progressbar indicating the unload and merge process
            safe_merge (`bool`):
                whether to activate the safe merging check to check if there is any potential Nan in the adapter
                weights
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import PeftModel

        >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
        >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
        >>> model = PeftModel.from_pretrained(base_model, peft_model_id)
        >>> merged_model = model.merge_and_unload()
        ```
        )r   r   r   r   )rJ   r   r   r   s       r,   merge_and_unloadzVeraModel.merge_and_unload  s#    : 00#
- 1 
 	
rT   c                &    | j                  d      S )z
        Gets back the base model by removing all the Vera modules without merging. This gives back the original base
        model.
        F)r   r   r   s    r,   unloadzVeraModel.unload  s    
 00u0==rT   )returnztuple[int, int])rK   r   ra   r1   r   None)r>   	nn.ModulerK   r   ra   r1   r   r   )rK   r   r   r   )r>   r   r   r   )r   r1   )F)r   bool)T)TFFN)r   r   r   r   r   zOptional[list[str]])ra   r1   )FFN)rm   
__module____qualname____doc__r2   __annotations__rS   rd   rf   rr   staticmethodrA   r   r   r   r   r   r   r   r   r   r   r?   r   r   r   r   __classcell__)rl   s   @r,   r/   r/   H   s	   : !FC #J+7 D < <(JT 3 38Z* D DL-0
/	0+   ! -1  	
 +6V2 im

59
Re
B>rT   r/   )r&   z$Union[torch.Tensor, tuple[int, ...]]r   ztorch.Generatorr   ztorch.Tensor))
__future__r   r"   r   dataclassesr   enumr   typingr   r   r    torch.nnrB   torch.nn.initr   r	   transformers.pytorch_utilsr
   peft.import_utilsr   r   peft.tuners.tuners_utilsr   r   r   
peft.utilsr   r   r   _buffer_dictr   tuners_utilsr   rK   r   layerr   r   r-   r/    rT   r,   <module>r     sz    #     "   0  - E Z Z  & ;  $C9CC C:q>	 q>rT   