CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/course/en/chapter13/grpo_format.py
Views: 2935
1
import marimo
2
3
__generated_with = "0.10.6"
4
app = marimo.App(width="medium")
5
6
7
@app.cell(hide_code=True)
def _():
    # Introductory cell: displays the notebook description and exports `mo`
    # so the other cells can receive it as a parameter.
    import marimo as mo

    mo.md(
        """
        ## Structure Format Reward Function

        This example demonstrates a reward function that evaluates whether
        completions follow a specific structure format: either
        `<think>...</think><answer>...</answer>` or
        `<code>...</code><explanation>...</explanation>`.

        Use the radio buttons to select which structure format to reward.
        """
    )
    return (mo,)
23
24
25
@app.cell(hide_code=True)
def _(mo):
    # Structure formats the reward cell knows how to score; the first entry
    # is preselected. `_`-prefixed names stay private to this cell.
    _choices = ["think-answer", "code-explanation"]
    format_buttons = mo.ui.radio(
        label="Format to reward",
        value=_choices[0],
        options=_choices,
    )
    # Bare last expression: marimo shows the radio group as this cell's output.
    format_buttons
    return (format_buttons,)
34
35
36
@app.cell(hide_code=True)
def _(mo, format_buttons):
    # Scores a fixed set of sample completions against the format chosen in
    # `format_buttons`, then shows the results as a heading, table and chart.
    import plotly.express as px
    import re

    # Sample completions with different formats
    completions = [
        # Think-answer format example
        "<think>Let me solve this step by step</think><answer>42</answer>",
        # Plain text with no tags
        "The answer is 15 without any special format",
        # Code-explanation format examples
        "<code>print('Hello world')</code><explanation>This prints a greeting</explanation>",
        "<code>def add(a, b): return a + b</code><explanation>A function to add numbers</explanation>",
        "<code>for i in range(10): print(i)</code>",
        # Mixed tags: a <think> block followed by a <code> block
        "<think>I should use a loop</think><code>while True: pass</code>",
    ]

    # Shortened versions (30 chars + "...") used as table cells / bar labels.
    short_completions = [c[:30] + "..." if len(c) > 30 else c for c in completions]

    def format_reward(completions, format_type="think-answer", **kwargs):
        """
        Reward completions that follow the desired format structure.

        Scoring per completion:
            1.0 -- the full two-tag structure is found,
            0.5 -- only the format's first opening tag is present,
            0.2 -- only another of the format's opening tags is present,
            0.0 -- none of the format's opening tags are present.

        Args:
            completions: list of completion strings to evaluate
            format_type: which format structure to reward, "think-answer" or
                "code-explanation"; any other value falls back to
                "think-answer"
            **kwargs: accepted and ignored (presumably so a trainer can pass
                extra keyword arguments -- TODO confirm)

        Returns:
            tuple ``(rewards, details, categories)`` of equal-length lists:
            the float reward, a human-readable detail message and a category
            label for each completion.
        """
        # Define patterns for different formats
        patterns = {
            "think-answer": r"<think>.*?</think>\s*<answer>.*?</answer>",
            "code-explanation": r"<code>.*?</code>\s*<explanation>.*?</explanation>",
        }

        # Resolve unknown format types up front so the regex, the tag checks
        # and the detail messages all describe the same format.
        if format_type not in patterns:
            format_type = "think-answer"

        # DOTALL lets `.` match newlines inside tags. `search` finds the
        # structure anywhere, so text around it is tolerated.
        regex = re.compile(patterns[format_type], re.DOTALL)
        tags = format_type.split("-")
        first_tag = tags[0]
        first_opening = f"<{first_tag}>"

        rewards = []
        details = []
        categories = []

        for completion in completions:
            if regex.search(completion):
                # Full match for the exact format
                rewards.append(1.0)
                details.append(f"Correct {format_type} format")
                categories.append("Exact Format Match")
            elif first_opening in completion:
                # Partial match - has the opening tag of the format
                rewards.append(0.5)
                details.append(f"Has {first_tag} tag but incomplete")
                categories.append("Partial Format Match")
            elif any(f"<{tag}>" in completion for tag in tags):
                # Contains at least one of the required tags
                rewards.append(0.2)
                details.append("Has some required tags but wrong format")
                categories.append("Some Tags Present")
            else:
                # No match at all
                rewards.append(0.0)
                details.append("Incorrect format")
                categories.append("No Format Match")

        return rewards, details, categories

    # Calculate rewards
    rewards, details, categories = format_reward(
        completions=completions, format_type=format_buttons.value
    )

    # One row per completion for the table and the chart.
    results = [
        {
            "Completion": short,
            "Reward": reward,
            "Detail": detail,
            "Category": category,
        }
        for short, reward, detail, category in zip(
            short_completions, rewards, details, categories
        )
    ]

    # Create a bar chart comparing rewards by completion
    fig = px.bar(
        results,
        x="Completion",
        y="Reward",
        color="Category",
        title=f"Format Rewards by Completion ({format_buttons.value})",
        hover_data=["Detail"],
    )

    # marimo renders only a cell's last expression, so the heading, table and
    # chart are stacked into a single output instead of being separate
    # (discarded) expression statements.
    mo.vstack(
        [
            mo.md(f"### Results for {format_buttons.value} format"),
            mo.ui.table(results),
            mo.ui.plotly(fig),
        ]
    )
138
139
140
# Script entry point: running this file directly (e.g. `python grpo_format.py`)
# calls app.run() to execute the notebook's cells.
if __name__ == "__main__":
    app.run()
142
143