# Add Question 189: Compute Direct Preference Optimization Loss #583
**Open** — zhenhuan-yang wants to merge 1 commit into Open-Deep-ML:main from zhenhuan-yang:zhy-dpo (+236 −0)
New file (+42 lines):

```json
{
  "id": "189",
  "title": "Compute Direct Preference Optimization Loss",
  "difficulty": "medium",
  "category": "Deep Learning",
  "video": "",
  "likes": "0",
  "dislikes": "0",
  "contributor": [
    {
      "profile_link": "https://github.com/zhenhuan-yang",
      "name": "Zhenhuan Yang"
    }
  ],
  "description": "## Task: Compute Direct Preference Optimization Loss\n\nImplement a function to compute the Direct Preference Optimization (DPO) loss, a technique used to fine-tune language models based on human preferences without requiring a separate reward model.\n\nGiven policy log probabilities for chosen and rejected responses, along with reference model log probabilities, your function should compute the DPO loss using the Bradley-Terry preference model.\n\nThe function should take the following inputs:\n- `policy_chosen_logps`: Log probabilities from the policy model for chosen responses\n- `policy_rejected_logps`: Log probabilities from the policy model for rejected responses\n- `reference_chosen_logps`: Log probabilities from the reference model for chosen responses\n- `reference_rejected_logps`: Log probabilities from the reference model for rejected responses\n- `beta`: Temperature parameter controlling the strength of the KL constraint (default: 0.1)\n\nReturn the average DPO loss across all examples as a float.",
  "learn_section": "## Direct Preference Optimization (DPO)\n\nDirect Preference Optimization is a novel approach to aligning language models with human preferences. Unlike traditional Reinforcement Learning from Human Feedback (RLHF), which requires training a separate reward model and using reinforcement learning algorithms, DPO directly optimizes the policy using preference data.\n\n---\n\n### **Background: The Problem with RLHF**\n\nTraditional RLHF involves:\n1. Training a reward model on human preference data\n2. Using PPO or other RL algorithms to optimize the language model\n3. Maintaining both the policy model and reference model during training\n\nThis process is complex, unstable, and computationally expensive.\n\n---\n\n### **How DPO Works**\n\nDPO eliminates the need for a separate reward model by deriving a loss function directly from the preference data. It uses the Bradley-Terry model of preferences, which states that the probability of preferring response $y_w$ (chosen) over $y_l$ (rejected) is:\n\n$$\nP(y_w \\succ y_l | x) = \\frac{\\exp(r(x, y_w))}{\\exp(r(x, y_w)) + \\exp(r(x, y_l))}\n$$\n\nwhere $r(x, y)$ is the reward for response $y$ to prompt $x$.\n\n---\n\n### **Mathematical Formulation**\n\nThe key insight of DPO is that the optimal policy $\\pi^*$ can be expressed in closed form relative to the reward function and reference policy $\\pi_{ref}$:\n\n$$\n\\pi^*(y|x) = \\frac{1}{Z(x)} \\pi_{ref}(y|x) \\exp\\left(\\frac{1}{\\beta} r(x, y)\\right)\n$$\n\nwhere $\\beta$ is a temperature parameter controlling the deviation from the reference policy.\n\nBy rearranging, we can express the reward in terms of the policy:\n\n$$\nr(x, y) = \\beta \\log \\frac{\\pi^*(y|x)}{\\pi_{ref}(y|x)} + \\beta \\log Z(x)\n$$\n\nSubstituting this into the Bradley-Terry model and noting that $Z(x)$ cancels out, we get:\n\n$$\nP(y_w \\succ y_l | x) = \\sigma\\left(\\beta \\log \\frac{\\pi_\\theta(y_w|x)}{\\pi_{ref}(y_w|x)} - \\beta \\log \\frac{\\pi_\\theta(y_l|x)}{\\pi_{ref}(y_l|x)}\\right)\n$$\n\nwhere $\\sigma$ is the sigmoid function.\n\n---\n\n### **DPO Loss Function**\n\nThe DPO loss for a single example is:\n\n$$\n\\mathcal{L}_{DPO}(\\pi_\\theta) = -\\log \\sigma\\left(\\beta \\log \\frac{\\pi_\\theta(y_w|x)}{\\pi_{ref}(y_w|x)} - \\beta \\log \\frac{\\pi_\\theta(y_l|x)}{\\pi_{ref}(y_l|x)}\\right)\n$$\n\nIn terms of log probabilities:\n\n$$\n\\mathcal{L}_{DPO} = -\\log \\sigma\\left(\\beta \\left[\\log \\pi_\\theta(y_w|x) - \\log \\pi_{ref}(y_w|x) - \\log \\pi_\\theta(y_l|x) + \\log \\pi_{ref}(y_l|x)\\right]\\right)\n$$\n\nEquivalently:\n\n$$\n\\mathcal{L}_{DPO} = -\\log \\sigma\\left(\\beta \\left[(\\log \\pi_\\theta(y_w|x) - \\log \\pi_{ref}(y_w|x)) - (\\log \\pi_\\theta(y_l|x) - \\log \\pi_{ref}(y_l|x))\\right]\\right)\n$$\n\n---\n\n### **Implementation Notes**\n\n1. **Log Probabilities**: Work with log probabilities for numerical stability\n2. **Beta Parameter**: Typical values range from 0.1 to 0.5. Lower values enforce stronger KL constraints\n3. **Sigmoid Function**: Use the log-sigmoid for numerical stability: $\\log \\sigma(x) = -\\log(1 + e^{-x})$\n4. **Batch Processing**: Average the loss over all examples in the batch\n\n---\n\n### **Advantages of DPO**\n\n- **Simpler**: No need for a separate reward model or RL training loop\n- **More Stable**: Direct supervision on preferences avoids RL instabilities\n- **Efficient**: Trains faster than traditional RLHF approaches\n- **Effective**: Achieves comparable or better performance than RLHF\n\n---\n\n### **Applications**\n\nDPO is widely used for:\n- Fine-tuning large language models (LLMs) to follow instructions\n- Aligning chatbots with human preferences\n- Reducing harmful or biased outputs\n- Improving factuality and helpfulness of AI assistants\n\nUnderstanding DPO is essential for modern language model alignment and is becoming the standard approach for post-training optimization.",
  "starter_code": "import numpy as np\n\ndef compute_dpo_loss(policy_chosen_logps: np.ndarray, policy_rejected_logps: np.ndarray,\n                     reference_chosen_logps: np.ndarray, reference_rejected_logps: np.ndarray,\n                     beta: float = 0.1) -> float:\n    # Your code here\n    pass",
  "solution": "import numpy as np\n\ndef compute_dpo_loss(policy_chosen_logps: np.ndarray, policy_rejected_logps: np.ndarray,\n                     reference_chosen_logps: np.ndarray, reference_rejected_logps: np.ndarray,\n                     beta: float = 0.1) -> float:\n    \"\"\"\n    Compute the Direct Preference Optimization (DPO) loss.\n\n    Args:\n        policy_chosen_logps: Log probabilities from policy model for chosen responses\n        policy_rejected_logps: Log probabilities from policy model for rejected responses\n        reference_chosen_logps: Log probabilities from reference model for chosen responses\n        reference_rejected_logps: Log probabilities from reference model for rejected responses\n        beta: Temperature parameter (default: 0.1)\n\n    Returns:\n        Average DPO loss across all examples\n    \"\"\"\n    # Compute log ratios for chosen and rejected responses\n    policy_log_ratios = policy_chosen_logps - policy_rejected_logps\n    reference_log_ratios = reference_chosen_logps - reference_rejected_logps\n\n    # Compute the logits for the Bradley-Terry model\n    logits = beta * (policy_log_ratios - reference_log_ratios)\n\n    # Compute loss using log-sigmoid for numerical stability\n    # Loss = -log(sigmoid(logits)) = log(1 + exp(-logits))\n    losses = np.log1p(np.exp(-logits))\n\n    # Return the average loss\n    return float(np.mean(losses))",
  "example": {
    "input": "import numpy as np\n\npolicy_chosen_logps = np.array([-1.0, -0.5])\npolicy_rejected_logps = np.array([-2.0, -1.5])\nreference_chosen_logps = np.array([-1.2, -0.6])\nreference_rejected_logps = np.array([-1.8, -1.4])\nbeta = 0.1\n\nloss = compute_dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, beta)\nprint(round(loss, 4))",
    "output": "0.6783",
    "reasoning": "The DPO loss computes the preference loss by comparing policy and reference log-probability ratios. With beta=0.1, the logits are 0.1 * ([1.0, 1.0] - [0.6, 0.8]) = [0.04, 0.02]. The loss for each example is log(1 + exp(-logit)), giving approximately [0.67334, 0.68319], with a rounded average of 0.6783."
  },
  "test_cases": [
    {
      "test": "import numpy as np\npolicy_chosen = np.array([-0.1])\npolicy_rejected = np.array([-10.0])\nreference_chosen = np.array([-5.0])\nreference_rejected = np.array([-5.0])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.1)\nprint(round(result, 4))",
      "expected_output": "0.316"
    },
    {
      "test": "import numpy as np\npolicy_chosen = np.array([-1.0, -0.5])\npolicy_rejected = np.array([-2.0, -1.5])\nreference_chosen = np.array([-1.0, -0.5])\nreference_rejected = np.array([-2.0, -1.5])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.1)\nprint(round(result, 4))",
      "expected_output": "0.6931"
    },
    {
      "test": "import numpy as np\npolicy_chosen = np.array([-1.0, -0.5])\npolicy_rejected = np.array([-2.0, -1.5])\nreference_chosen = np.array([-1.2, -0.6])\nreference_rejected = np.array([-1.8, -1.4])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.1)\nprint(round(result, 4))",
      "expected_output": "0.6783"
    },
    {
      "test": "import numpy as np\npolicy_chosen = np.array([-0.5, -1.0, -0.8])\npolicy_rejected = np.array([-3.0, -2.5, -3.2])\nreference_chosen = np.array([-1.0, -1.5, -1.2])\nreference_rejected = np.array([-2.0, -2.0, -2.5])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.2)\nprint(round(result, 4))",
      "expected_output": "0.5806"
    }
  ]
}
```
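As a quick numerical check of the `example` block above (a sketch for the reader, not part of the PR; NumPy assumed):

```python
import numpy as np

# Inputs from the example block above
pc = np.array([-1.0, -0.5])   # policy_chosen_logps
pr = np.array([-2.0, -1.5])   # policy_rejected_logps
rc = np.array([-1.2, -0.6])   # reference_chosen_logps
rr = np.array([-1.8, -1.4])   # reference_rejected_logps

# logits = beta * (policy log-ratio - reference log-ratio)
#        = 0.1 * ([1.0, 1.0] - [0.6, 0.8]) = [0.04, 0.02]
logits = 0.1 * ((pc - pr) - (rc - rr))

# Per-example loss: -log(sigmoid(logits)) = log(1 + exp(-logits))
loss = float(np.mean(np.log1p(np.exp(-logits))))
print(round(loss, 4))  # 0.6783
```

This reproduces the example's expected output of 0.6783.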
`questions/189_compute-direct-preference-loss/description.md` (+14 lines):

```markdown
## Task: Compute Direct Preference Optimization Loss

Implement a function to compute the Direct Preference Optimization (DPO) loss, a technique used to fine-tune language models based on human preferences without requiring a separate reward model.

Given policy log probabilities for chosen and rejected responses, along with reference model log probabilities, your function should compute the DPO loss using the Bradley-Terry preference model.

The function should take the following inputs:
- `policy_chosen_logps`: Log probabilities from the policy model for chosen responses
- `policy_rejected_logps`: Log probabilities from the policy model for rejected responses
- `reference_chosen_logps`: Log probabilities from the reference model for chosen responses
- `reference_rejected_logps`: Log probabilities from the reference model for rejected responses
- `beta`: Temperature parameter controlling the strength of the KL constraint (default: 0.1)

Return the average DPO loss across all examples as a float.
```
New file (+5 lines):

```json
{
  "input": "import numpy as np\n\npolicy_chosen_logps = np.array([-1.0, -0.5])\npolicy_rejected_logps = np.array([-2.0, -1.5])\nreference_chosen_logps = np.array([-1.2, -0.6])\nreference_rejected_logps = np.array([-1.8, -1.4])\nbeta = 0.1\n\nloss = compute_dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, beta)\nprint(round(loss, 4))",
  "output": "0.6783",
  "reasoning": "The DPO loss computes the preference loss by comparing policy and reference log-probability ratios. With beta=0.1, the logits are 0.1 * ([1.0, 1.0] - [0.6, 0.8]) = [0.04, 0.02]. The loss for each example is log(1 + exp(-logit)), giving approximately [0.67334, 0.68319], with a rounded average of 0.6783."
}
```
New file (+104 lines):

```markdown
## Direct Preference Optimization (DPO)

Direct Preference Optimization is a novel approach to aligning language models with human preferences. Unlike traditional Reinforcement Learning from Human Feedback (RLHF), which requires training a separate reward model and using reinforcement learning algorithms, DPO directly optimizes the policy using preference data.

---

### **Background: The Problem with RLHF**

Traditional RLHF involves:
1. Training a reward model on human preference data
2. Using PPO or other RL algorithms to optimize the language model
3. Maintaining both the policy model and reference model during training

This process is complex, unstable, and computationally expensive.

---

### **How DPO Works**

DPO eliminates the need for a separate reward model by deriving a loss function directly from the preference data. It uses the Bradley-Terry model of preferences, which states that the probability of preferring response $y_w$ (chosen) over $y_l$ (rejected) is:

$$
P(y_w \succ y_l | x) = \frac{\exp(r(x, y_w))}{\exp(r(x, y_w)) + \exp(r(x, y_l))}
$$

where $r(x, y)$ is the reward for response $y$ to prompt $x$.

---

### **Mathematical Formulation**

The key insight of DPO is that the optimal policy $\pi^*$ can be expressed in closed form relative to the reward function and reference policy $\pi_{ref}$:

$$
\pi^*(y|x) = \frac{1}{Z(x)} \pi_{ref}(y|x) \exp\left(\frac{1}{\beta} r(x, y)\right)
$$

where $\beta$ is a temperature parameter controlling the deviation from the reference policy.

By rearranging, we can express the reward in terms of the policy:

$$
r(x, y) = \beta \log \frac{\pi^*(y|x)}{\pi_{ref}(y|x)} + \beta \log Z(x)
$$

Substituting this into the Bradley-Terry model and noting that $Z(x)$ cancels out, we get:

$$
P(y_w \succ y_l | x) = \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)}\right)
$$

where $\sigma$ is the sigmoid function.

---

### **DPO Loss Function**

The DPO loss for a single example is:

$$
\mathcal{L}_{DPO}(\pi_\theta) = -\log \sigma\left(\beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)}\right)
$$

In terms of log probabilities:

$$
\mathcal{L}_{DPO} = -\log \sigma\left(\beta \left[\log \pi_\theta(y_w|x) - \log \pi_{ref}(y_w|x) - \log \pi_\theta(y_l|x) + \log \pi_{ref}(y_l|x)\right]\right)
$$

Equivalently:

$$
\mathcal{L}_{DPO} = -\log \sigma\left(\beta \left[(\log \pi_\theta(y_w|x) - \log \pi_{ref}(y_w|x)) - (\log \pi_\theta(y_l|x) - \log \pi_{ref}(y_l|x))\right]\right)
$$

---

### **Implementation Notes**

1. **Log Probabilities**: Work with log probabilities for numerical stability
2. **Beta Parameter**: Typical values range from 0.1 to 0.5. Lower values enforce stronger KL constraints
3. **Sigmoid Function**: Use the log-sigmoid for numerical stability: $\log \sigma(x) = -\log(1 + e^{-x})$
4. **Batch Processing**: Average the loss over all examples in the batch

---

### **Advantages of DPO**

- **Simpler**: No need for a separate reward model or RL training loop
- **More Stable**: Direct supervision on preferences avoids RL instabilities
- **Efficient**: Trains faster than traditional RLHF approaches
- **Effective**: Achieves comparable or better performance than RLHF

---

### **Applications**

DPO is widely used for:
- Fine-tuning large language models (LLMs) to follow instructions
- Aligning chatbots with human preferences
- Reducing harmful or biased outputs
- Improving factuality and helpfulness of AI assistants

Understanding DPO is essential for modern language model alignment and is becoming the standard approach for post-training optimization.
```
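One property of the loss derived above is worth spelling out: when the policy is identical to the reference model, every logit is zero, so the loss is exactly -log sigmoid(0) = log 2 ≈ 0.6931, which is what the PR's second test case checks. A minimal sketch (assuming NumPy; `dpo_loss` here is an illustrative re-implementation, not the PR's function):

```python
import numpy as np

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # beta * (policy log-ratio minus reference log-ratio)
    logits = beta * ((policy_chosen - policy_rejected) - (ref_chosen - ref_rejected))
    # -log(sigmoid(logits)), written via logaddexp for numerical stability
    return float(np.mean(np.logaddexp(0.0, -logits)))

chosen = np.array([-1.0, -0.5])
rejected = np.array([-2.0, -1.5])

# Policy identical to reference: zero logits, loss = log(2)
print(round(dpo_loss(chosen, rejected, chosen, rejected), 4))  # 0.6931
```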
New file (+15 lines):

```json
{
  "id": "189",
  "title": "Compute Direct Preference Optimization Loss",
  "difficulty": "medium",
  "category": "Deep Learning",
  "video": "",
  "likes": "0",
  "dislikes": "0",
  "contributor": [
    {
      "profile_link": "https://github.com/zhenhuan-yang",
      "name": "Zhenhuan Yang"
    }
  ]
}
```
New file (+31 lines):

```python
import numpy as np

def compute_dpo_loss(policy_chosen_logps: np.ndarray, policy_rejected_logps: np.ndarray,
                     reference_chosen_logps: np.ndarray, reference_rejected_logps: np.ndarray,
                     beta: float = 0.1) -> float:
    """
    Compute the Direct Preference Optimization (DPO) loss.

    Args:
        policy_chosen_logps: Log probabilities from policy model for chosen responses
        policy_rejected_logps: Log probabilities from policy model for rejected responses
        reference_chosen_logps: Log probabilities from reference model for chosen responses
        reference_rejected_logps: Log probabilities from reference model for rejected responses
        beta: Temperature parameter (default: 0.1)

    Returns:
        Average DPO loss across all examples
    """
    # Compute log ratios for chosen and rejected responses
    policy_log_ratios = policy_chosen_logps - policy_rejected_logps
    reference_log_ratios = reference_chosen_logps - reference_rejected_logps

    # Compute the logits for the Bradley-Terry model
    logits = beta * (policy_log_ratios - reference_log_ratios)

    # Compute loss using log-sigmoid for numerical stability
    # Loss = -log(sigmoid(logits)) = log(1 + exp(-logits))
    losses = np.log1p(np.exp(-logits))

    # Return the average loss
    return float(np.mean(losses))
```
New file (+7 lines):

```python
import numpy as np

def compute_dpo_loss(policy_chosen_logps: np.ndarray, policy_rejected_logps: np.ndarray,
                     reference_chosen_logps: np.ndarray, reference_rejected_logps: np.ndarray,
                     beta: float = 0.1) -> float:
    # Your code here
    pass
```
New file (+18 lines):

```json
[
  {
    "test": "import numpy as np\npolicy_chosen = np.array([-0.1])\npolicy_rejected = np.array([-10.0])\nreference_chosen = np.array([-5.0])\nreference_rejected = np.array([-5.0])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.1)\nprint(round(result, 4))",
    "expected_output": "0.316"
  },
  {
    "test": "import numpy as np\npolicy_chosen = np.array([-1.0, -0.5])\npolicy_rejected = np.array([-2.0, -1.5])\nreference_chosen = np.array([-1.0, -0.5])\nreference_rejected = np.array([-2.0, -1.5])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.1)\nprint(round(result, 4))",
    "expected_output": "0.6931"
  },
  {
    "test": "import numpy as np\npolicy_chosen = np.array([-1.0, -0.5])\npolicy_rejected = np.array([-2.0, -1.5])\nreference_chosen = np.array([-1.2, -0.6])\nreference_rejected = np.array([-1.8, -1.4])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.1)\nprint(round(result, 4))",
    "expected_output": "0.6783"
  },
  {
    "test": "import numpy as np\npolicy_chosen = np.array([-0.5, -1.0, -0.8])\npolicy_rejected = np.array([-3.0, -2.5, -3.2])\nreference_chosen = np.array([-1.0, -1.5, -1.2])\nreference_rejected = np.array([-2.0, -2.0, -2.5])\nresult = compute_dpo_loss(policy_chosen, policy_rejected, reference_chosen, reference_rejected, beta=0.2)\nprint(round(result, 4))",
    "expected_output": "0.5806"
  }
]
```
**Review comment** (on the `losses = np.log1p(np.exp(-logits))` line in the solution):

> This overflows to `inf` when `logits` is large and negative; `losses = np.logaddexp(0, -logits)` might be more stable.
**Reply:**

> I would also add a test case that exercises this edge case.
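To illustrate the overflow the reviewer describes, a small sketch (NumPy assumed) comparing the naive `log1p(exp(-logits))` form with `np.logaddexp(0, -logits)` on an extreme logit:

```python
import numpy as np

logits = np.array([-1000.0, 0.0])

# Naive form: exp(1000) overflows to inf, so the first loss becomes inf
with np.errstate(over="ignore"):
    naive = np.log1p(np.exp(-logits))

# Stable form: log(1 + exp(-x)) computed as logaddexp(0, -x),
# which never materializes exp of a large argument
stable = np.logaddexp(0, -logits)

print(naive[0], stable[0])  # inf 1000.0
```

Both forms agree wherever the naive one does not overflow (e.g. at `logits = 0` both give log 2), so swapping in `logaddexp` would not change any of the PR's existing expected outputs.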