Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/course/en/chapter13/grpo_math.py
Views: 2935
import marimo

__generated_with = "0.10.6"
app = marimo.App(width="medium")


@app.cell(hide_code=True)
def _():
    import marimo as mo

    mo.md(
        """
        ## Math Problem Reward Function

        This example demonstrates a reward function for math problems
        with verifiable answers.

        The slider controls the tolerance for approximate answers.
        """
    )
    return (mo,)


@app.cell(hide_code=True)
def _(mo):
    # Absolute tolerance within which an answer still earns full credit.
    tolerance_slider = mo.ui.slider(
        start=0, stop=25, step=5, value=0, label="Tolerance"
    )
    tolerance_slider
    return (tolerance_slider,)


@app.cell(hide_code=True)
def _(mo, tolerance_slider):
    import plotly.express as px

    # Sample math problems and their correct answers
    problems = [
        "What is 5 + 7?",
        "Calculate 12 * 6",
        "What is 100 / 4?",
        "Solve for x: 3x = 15",
        "What is the square root of 81?",
    ]

    # Correct answers, aligned index-by-index with `problems`
    correct_answers = [12, 72, 25, 5, 9]

    # Model completions (simulated), aligned with `problems`
    model_completions = [
        12,  # Correct
        92,  # Wrong
        15,  # Wrong
        0,  # Wrong
        9,  # Correct
    ]

    def extract_final_answer(completion):
        """
        Return the numeric answer contained in a completion.

        In a real scenario, this would parse the completion text to extract
        the answer. For this example, completions are already integers, so
        the value is returned unchanged.
        """
        return completion

    def problem_reward(completions, answers, tolerance=0):
        """
        Reward function for math problems with verifiable answers.

        Args:
            completions: list of completions to evaluate
            answers: list of correct answers to the problems
            tolerance: maximum absolute difference for an answer to still
                receive full credit

        Returns:
            list of float rewards in [0.0, 1.0], one per
            (completion, answer) pair. As with ``zip``, trailing items of
            the longer input are ignored.
        """
        rewards = []

        for completion, correct_answer in zip(completions, answers):
            try:
                # Extract the answer and measure how far it is from the truth
                answer = extract_final_answer(completion)
                difference = abs(answer - correct_answer)
            except Exception:
                # Deliberately best-effort: a completion we cannot parse gets
                # the minimum reward instead of aborting the whole batch.
                rewards.append(0.0)
                continue

            if difference <= tolerance:
                # Within tolerance: full credit
                rewards.append(1.0)
            else:
                # Partial credit decays linearly with the error. The scale
                # grows with the answer's magnitude (abs() so negative answers
                # scale too) but never drops below 10, so it is never zero.
                max_diff = max(abs(correct_answer) * 0.5, 10)
                rewards.append(max(0.0, 1.0 - difference / max_diff))

        return rewards

    # Calculate rewards
    rewards = problem_reward(
        completions=model_completions,
        answers=correct_answers,
        tolerance=tolerance_slider.value,
    )

    # One row per problem for the table and the chart
    results = [
        {
            "Problem": problem,
            "Correct Answer": correct,
            "Model Answer": completion,
            "Difference": abs(correct - completion),
            "Reward": reward,
        }
        for problem, correct, completion, reward in zip(
            problems, correct_answers, model_completions, rewards
        )
    ]

    # Bar chart of rewards, coloured by how far off each answer was
    fig = px.bar(
        results,
        x="Problem",
        y="Reward",
        color="Difference",
        hover_data=["Correct Answer", "Model Answer"],
        title="Rewards by Problem",
    )

    # marimo renders only a cell's last expression, so the header, table and
    # chart must be combined into one output; otherwise only the chart shows.
    mo.vstack(
        [
            mo.md("### Results"),
            mo.ui.table(results),
            mo.ui.plotly(fig),
        ]
    )


if __name__ == "__main__":
    app.run()