CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/course/en/chapter13/grpo_math.py
Views: 2935
1
import marimo
2
3
__generated_with = "0.10.6"
4
app = marimo.App(width="medium")
5
6
7
@app.cell(hide_code=True)
def _():
    """Intro cell: render the notebook's title and description markdown."""
    import marimo as mo

    intro_md = """
    ## Math Problem Reward Function

    This example demonstrates a reward function for math problems
    with verifiable answers.

    The slider controls the tolerance for approximate answers.
    """
    # Last expression before `return` is the cell's rendered output.
    mo.md(intro_md)
    return (mo,)
@app.cell(hide_code=True)
def _(mo):
    """Control cell: expose a tolerance slider for the reward function."""
    # Tolerance takes values in {0, 5, 10, 15, 20, 25}, starting at 0.
    slider_config = {
        "start": 0,
        "stop": 25,
        "step": 5,
        "value": 0,
        "label": "Tolerance",
    }
    tolerance_slider = mo.ui.slider(**slider_config)
    tolerance_slider  # last expression -> rendered as the cell output
    return (tolerance_slider,)
@app.cell(hide_code=True)
def _(mo, tolerance_slider):
    """Demo cell: score simulated model answers against ground truth and
    display the per-problem rewards as a table and a bar chart.

    Reacts to `tolerance_slider`: moving the slider re-runs this cell with
    the new tolerance.
    """
    import plotly.express as px

    # Sample math problems and their correct answers.
    problems = [
        "What is 5 + 7?",
        "Calculate 12 * 6",
        "What is 100 / 4?",
        "Solve for x: 3x = 15",
        "What is the square root of 81?",
    ]

    # Ground-truth answers, aligned index-for-index with `problems`.
    correct_answers = [12, 72, 25, 5, 9]

    # Model completions (simulated; already reduced to integer answers).
    model_completions = [
        12,  # Correct
        92,  # Wrong
        15,  # Wrong
        0,  # Wrong
        9,  # Correct
    ]

    def extract_final_answer(completion):
        """
        In a real scenario, this would parse the completion to extract the answer.
        For this example, we're using direct integer completions.
        """
        return completion

    def problem_reward(completions, answers, tolerance=0):
        """
        Reward function for math problems with verifiable answers.

        Args:
            completions: list of completions to evaluate
            answers: list of correct answers to the problems
            tolerance: allowed absolute difference still counted as correct

        Returns:
            list of float rewards in [0.0, 1.0], one per completion
        """
        rewards = []

        for completion, correct_answer in zip(completions, answers):
            try:
                # Extract the answer from the completion.
                answer = extract_final_answer(completion)

                # Absolute error against the ground truth.
                difference = abs(answer - correct_answer)

                if difference <= tolerance:
                    # Within tolerance counts as fully correct.
                    reward = 1.0
                else:
                    # Partial credit decays linearly with the error, scaled to
                    # the problem's magnitude; the floor of 10 keeps tiny
                    # answers from being punished disproportionately.
                    max_diff = max(correct_answer * 0.5, 10)
                    # Use 0.0 (not 0) so every reward is a float.
                    reward = max(0.0, 1 - (difference / max_diff))

                rewards.append(reward)
            except Exception:
                # If we can't parse an answer, give a low reward.
                rewards.append(0.0)

        return rewards

    # Score every completion with the slider-selected tolerance.
    rewards = problem_reward(
        completions=model_completions,
        answers=correct_answers,
        tolerance=tolerance_slider.value,
    )

    # Build one row per problem for the table and the chart.
    results = []
    for problem, correct, completion, reward in zip(
        problems, correct_answers, model_completions, rewards
    ):
        results.append(
            {
                "Problem": problem,
                "Correct Answer": correct,
                "Model Answer": completion,
                "Difference": abs(correct - completion),
                "Reward": reward,
            }
        )

    # Bar chart of rewards, colored by how far off each answer was.
    fig = px.bar(
        results,
        x="Problem",
        y="Reward",
        color="Difference",
        hover_data=["Correct Answer", "Model Answer"],
        title="Rewards by Problem",
    )

    # FIX: marimo renders only a cell's LAST expression. The original cell's
    # bare `mo.md("### Results")` and `mo.ui.table(results)` statements were
    # silently discarded, so only the chart ever appeared. Stack all three
    # outputs so the heading, table, and chart all render.
    mo.vstack([mo.md("### Results"), mo.ui.table(results), mo.ui.plotly(fig)])
# Entry point: running this file directly serves/executes the marimo app.
if __name__ == "__main__":
    app.run()