
import concurrent.futures
import logging
from abc import ABC, abstractmethod
from typing import Optional, Union

import numpy as np
from accelerate import Accelerator
from huggingface_hub import InferenceClient
from transformers.utils import is_openai_available

from ..import_utils import is_llm_blender_available


if is_llm_blender_available():
    import llm_blender

if is_openai_available():
    from openai import OpenAI


DEFAULT_PAIRWISE_SYSTEM_PROMPT = '''I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective.

## Instruction

{{
    "instruction": """{prompt}""",
}}

## Model Outputs

Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier.

{{
    {{
        "model_identifier": "0",
        "output": """{response0}"""
    }},
    {{
        "model_identifier": "1",
        "output": """{response1}"""
    }}
}}

## Task

Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...).
'''


class BaseJudge(ABC):
    """
    Base class for judges. The subclasses of this class should implement the `judge` method.
    """

    @abstractmethod
    def judge(self, prompts: list[str], completions: list[str], shuffle_order: bool = True) -> list:
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class BaseRankJudge(ABC):
    """
    Base class for LLM ranking judges.

    **Example**:
    ```python
    class MyRankJudge(BaseRankJudge):
        def judge(self, prompts, completions, shuffle_order=True):
            return ...  # Your ranking logic here


    judge = MyRankJudge()
    judge.judge(
        prompts=["The capital of France is", "The capital of Germany is"],
        completions=[[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]],
    )  # [[0, 1, 2], [1, 0]]
    ```
    """

    @abstractmethod
    def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[list[int]]:
        """
        Judge the completion for the given prompts and return the ranks of each completion.

        Args:
            prompts (`list[str]`):
                List of prompts.
            completions (`list[list[str]]`):
                List of completion lists, where each element is a list of completions for the corresponding prompt.
            shuffle_order (`bool`, *optional*, defaults to `True`):
                Whether to shuffle the order of the completions to avoid positional bias.

        Returns:
            `list[list[int]]`:
                List of lists of idxs, where each list contains the ranks of the completions for the corresponding
                prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the
                third, and then the first.
        """
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class BasePairwiseJudge(BaseJudge):
    """
    Base class for pairwise judges.
    """

    @abstractmethod
    def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
        """
        Judge the completion pairs for the given prompts.

        Args:
            prompts (`list[str]`):
                List of prompts.
            completions (`list[list[str]]`):
                List of completion pairs, where each element is a pair of completions for the corresponding prompt.
            shuffle_order (`bool`, *optional*, defaults to `True`):
                Whether to shuffle the order of the completions to avoid positional bias.

        Returns:
            `list[int]`:
                List of idxs, where each idx is the rank of the best completion for the corresponding prompt. E.g., `1`
                means that the second completion (`idx=1`) is the best.

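        **Example** (illustrative; a minimal subclass that always prefers the shorter completion):
        ```python
        >>> class ShorterJudge(BasePairwiseJudge):
        ...     def judge(self, prompts, completions, shuffle_order=True):
        ...         return [0 if len(pair[0]) <= len(pair[1]) else 1 for pair in completions]
        >>> ShorterJudge().judge(["Hi"], [["Hey", "Hello there"]])
        [0]
        ```
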
        Note:
            If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the
            preference has failed. For instance, this could occur if the underlying language model returned an invalid
            answer. In such cases, the caller should handle these invalid indices appropriately, possibly by
            implementing fallback logic or error handling.
        """
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class BaseBinaryJudge(BaseJudge):
    """
    Base class for binary judges.
    """

    @abstractmethod
    def judge(
        self,
        prompts: list[str],
        completions: list[str],
        gold_completions: Optional[list[str]] = None,
        shuffle_order: bool = True,
    ) -> list[int]:
        """
        Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint.

        This base class should be used to implement binary evaluations as done in section 4.1.4 of the [CGPO
        paper](https://huggingface.co/papers/2409.20370). It is relevant for assessing whether a prompt-completion
        pair satisfies a specific constraint.

        Args:
            prompts (`list[str]`): List of prompts.
            completions (`list[str]`): List of completions.
            gold_completions (`list[str]`, *optional*): List of gold completions, if available.
            shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias.

        Returns:
            list[int]: A list of binary labels:
                - 1 indicates that the completion satisfies the evaluated constraint.
                - 0 indicates that the completion does not satisfy the evaluated constraint.

        Note:
            If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference
            has failed. For instance, this could occur if the underlying language model or rule-based constraint
            returned an invalid answer. In such cases, the caller should handle these invalid indices appropriately,
            possibly by implementing fallback logic or error handling.
        """
        raise NotImplementedError("Judge subclasses must implement the `judge` method.")


class PairRMJudge(BasePairwiseJudge):
    """
    LLM judge based on the PairRM model from AllenAI.

    This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise
    comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the
    default Accelerator device.

    **Attributes**:

        blender (`llm_blender.Blender`):
            An instance of the Blender class from llm-blender.

    **Example**:
    ```python
    >>> pairrm_judge = PairRMJudge()
    >>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"]
    >>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
    >>> results = pairrm_judge.judge(prompts, completions)
    >>> print(results)  # [0, 1] (the first completion is preferred for the first prompt, the second for the second)
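    >>> # Soft judging (values below are illustrative): probability that the first completion is preferred
    >>> pairrm_judge.judge(prompts, completions, return_scores=True)
    [0.95, 0.04]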
    ```

    <Tip>

    This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`.

    </Tip>
    """

    def __init__(self):
        if not is_llm_blender_available():
            raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.")
        self.blender = llm_blender.Blender()
        self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device)

    def judge(
        self,
        prompts: list[str],
        completions: list[list[str]],
        shuffle_order: bool = True,
        return_scores: bool = False,
        temperature: float = 1.0,
    ) -> list[Union[int, float]]:
        """
        Judge the completion pairs for the given prompts using the PairRM model.

        Args:
            prompts (`list[str]`):
                List of prompts to judge.
            completions (`list[list[str]]`):
                List of completion pairs for each prompt.
            shuffle_order (`bool`, *optional*, defaults to `True`):
                Whether to shuffle the order of the completions to avoid positional bias.
            return_scores (`bool`, *optional*, defaults to `False`):
                If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*).
            temperature (`float`, *optional*, defaults to `1.0`):
                Temperature for scaling logits if `return_scores` is True.

        Returns:
            `list[Union[int, float]]`:
                If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which
                completion is preferred. If `return_scores` is `True`, returns softmax probabilities for the first
                completion.

        Raises:
            `ValueError`:
                If the number of completions per prompt is not exactly 2.

        Note:
            Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred).
        """
        if len(completions[0]) != 2:
            raise ValueError("PairRM judge requires exactly 2 completions per prompt.")

        # Shuffle the order of the completions to avoid positional bias
        if shuffle_order:
            flip_mask = np.random.choice([True, False], size=len(prompts))
            completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]

        # Rank the completions
        ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True)
        if not return_scores:
            ranks -= 1  # PairRM ranks are 1-indexed; shift them to be 0-indexed
        else:
            ranks /= temperature  # scale the logits by the temperature

        # Flip back the ranks or scores to the original order if needed
        if shuffle_order:
            ranks[flip_mask] = ranks[flip_mask][:, ::-1]

        # Return the ranks or score probabilities
        if return_scores:
            logit_max = np.amax(ranks, axis=-1, keepdims=True)
            exp_logit_shifted = np.exp(ranks - logit_max)
            probs = exp_logit_shifted / np.sum(exp_logit_shifted, axis=-1, keepdims=True)
            return probs[:, 0].tolist()
        else:
            return ranks[:, 0].tolist()


class HfPairwiseJudge(BasePairwiseJudge):
    """
    Pairwise judge based on the Hugging Face API with chat completion.

    This judge is relevant for assessing the quality of chat models, where the completion is a response to a given prompt.

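    **Example** (illustrative; calls the Hugging Face Inference API, so it requires network access and a valid token):
    ```python
    >>> judge = HfPairwiseJudge()
    >>> judge.judge(
    ...     prompts=["What is the capital of France?"],
    ...     completions=[["Paris", "Lyon"]],
    ... )
    [0]
    ```
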
    Args:
        model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`):
            Model to use for the judge.
        token (`str`, *optional*):
            Hugging Face API token to use for the [`huggingface_hub.InferenceClient`].
        system_prompt (`str` or `None`, *optional*, defaults to `None`):
            The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
            prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
            inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
            response.
    """

    def __init__(
        self,
        model="meta-llama/Meta-Llama-3-70B-Instruct",
        token: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ):
        self.client = InferenceClient(model=model, token=token)
        self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT

    def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
        # Shuffle the order of the completions to avoid positional bias
        if shuffle_order:
            flip_mask = np.random.choice([True, False], size=len(prompts))
            completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]

        # Get the rank for a single prompt; called concurrently below
        def get_rank(prompt, candidates):
            content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1])
            completion = self.client.chat_completion(messages=[{"role": "user", "content": content}], max_tokens=1)
            response = completion.choices[0].message.content
            if response in ["0", "1"]:
                return int(response)
            else:
                logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.")
                return -1

        # Call the completions concurrently
        with concurrent.futures.ThreadPoolExecutor() as executor:
            ranks = list(executor.map(get_rank, prompts, completions))

        # Flip back the ranks to the original order if needed
        if shuffle_order:
            ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)]

        return ranks


class OpenAIPairwiseJudge(BasePairwiseJudge):
    """
    Judge based on the OpenAI API.

    This judge is relevant for assessing the quality of chat models, where the completion is a response to a given prompt.

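    **Example** (illustrative; requires the `openai` package and a configured `OPENAI_API_KEY`):
    ```python
    >>> judge = OpenAIPairwiseJudge()
    >>> judge.judge(
    ...     prompts=["What is the capital of France?"],
    ...     completions=[["Paris", "Lyon"]],
    ... )
    [0]
    ```
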
    Args:
        model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`):
            Model to use for the judge.
        system_prompt (`str` or `None`, *optional*, defaults to `None`):
            System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system
            prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the
            inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token
            response.
        max_requests (`int` or `None`, *optional*, defaults to `1000`):
            Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit.
    """

    def __init__(
        self, model="gpt-4-turbo-preview", system_prompt: Optional[str] = None, max_requests: Union[int, None] = 1000
    ):
        if not is_openai_available():
            raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.")
        self.client = OpenAI()
        self.model = model
        self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
        self.max_requests = max_requests
        self.num_requests = 0
        self._warned = False

    def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]:
        # If the request limit has been reached, return -1 for every prompt instead of calling the API
        if self.max_requests is not None and self.num_requests >= self.max_requests:
            if not self._warned:  # Print the warning only once
                logging.warning(
                    f"Reached the maximum number of requests ({self.max_requests}). From now on, returning -1 "
                    "instead. To increase the limit, set `max_requests` to a higher value, or to `None` for no limit."
                )
                self._warned = True
            return [-1] * len(prompts)

        # Shuffle the order of the completions to avoid positional bias
        if shuffle_order:
            flip_mask = np.random.choice([True, False], size=len(prompts))
            completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]

        # Get the rank for a single prompt; called concurrently below
        def get_rank(prompt, candidates):
            content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1])
            messages = [{"role": "user", "content": content}]
            completion = self.client.chat.completions.create(model=self.model, messages=messages, max_tokens=1)
            response = completion.choices[0].message.content
            if response in ["0", "1"]:
                return int(response)
            else:
                logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.")
                return -1

        # Call the completions concurrently
        with concurrent.futures.ThreadPoolExecutor() as executor:
            ranks = list(executor.map(get_rank, prompts, completions))

        # Flip back the ranks to the original order if needed
        if shuffle_order:
            ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)]

        # Update the number of requests
        self.num_requests += len(prompts)

        return ranks


class AllTrueJudge(BaseBinaryJudge):
    """
    Unify the decision of multiple [`BaseBinaryJudge`] instances.

    Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`. If any judge
    returns `-1`, indicating a failure in its process, this judge will also return `-1`.

    Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370).

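    **Example** (illustrative sketch with a toy constraint judge):
    ```python
    >>> class ContainsDigitJudge(BaseBinaryJudge):
    ...     def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
    ...         return [int(any(char.isdigit() for char in completion)) for completion in completions]
    >>> judge = AllTrueJudge([ContainsDigitJudge()])
    >>> judge.judge(prompts=["a", "b"], completions=["version 2", "hello"])
    [1, 0]
    ```
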
    Args:
        judges (`list[BaseBinaryJudge]`): A list of [`BaseBinaryJudge`] instances whose decisions will be unified.
    """

    def __init__(self, judges: list[BaseBinaryJudge]):
        self.judges = judges

    def judge(
        self,
        prompts: list[str],
        completions: list[str],
        gold_completions: Optional[list[str]] = None,
        shuffle_order: bool = True,
    ) -> list[int]:
        all_binary_judgments = [
            judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges
        ]
        output = []
        for binary_judgments in zip(*all_binary_judgments):
            # Validate the judgments
            if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments):
                raise ValueError(
                    f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}."
                )

            # Unify the decision: -1 propagates failures, 1 requires unanimity, otherwise 0
            if -1 in binary_judgments:
                output.append(-1)
            elif all(binary_judgment == 1 for binary_judgment in binary_judgments):
                output.append(1)
            else:
                output.append(0)
        return output