Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/course/en/chapter13/grpo_format.py
Views: 2935
import marimo

__generated_with = "0.10.6"
app = marimo.App(width="medium")


@app.cell(hide_code=True)
def _():
    import marimo as mo

    # Introductory text for the notebook; as the cell's final expression it
    # is what the cell displays.
    mo.md(
        """
        ## Structure Format Reward Function

        This example demonstrates a reward function that evaluates whether
        completions follow a specific structure format. Either a `<think>...</think><answer>...</answer>`
        or a `<code>...</code><explanation>...</explanation>` format.

        Use the buttons to select which structure format to reward.
        """
    )
    return (mo,)


@app.cell(hide_code=True)
def _(mo):
    # Radio selector whose value chooses which structure the next cell rewards.
    format_buttons = mo.ui.radio(
        options=["think-answer", "code-explanation"],
        value="think-answer",
        label="Format to reward",
    )
    format_buttons
    return (format_buttons,)


@app.cell(hide_code=True)
def _(mo, format_buttons):
    import plotly.express as px
    import re

    # Sample completions with different formats
    completions = [
        # Think-answer format examples
        "<think>Let me solve this step by step</think><answer>42</answer>",
        "The answer is 15 without any special format",
        "<code>print('Hello world')</code><explanation>This prints a greeting</explanation>",
        # Code-explanation format examples
        "<code>def add(a, b): return a + b</code><explanation>A function to add numbers</explanation>",
        "<code>for i in range(10): print(i)</code>",
        "<think>I should use a loop</think><code>while True: pass</code>",
    ]

    # Create shortened versions for display (first 30 characters plus "...")
    short_completions = [c[:30] + "..." if len(c) > 30 else c for c in completions]

    def format_reward(completions, format_type="think-answer", **kwargs):
        """Score each completion by how closely it follows a tag structure.

        Each completion receives one of four rewards:

        * 1.0 - the full ``<a>...</a><b>...</b>`` structure occurs somewhere
          in the completion (optional whitespace between the two elements);
        * 0.5 - the structure is absent but the first opening tag is present;
        * 0.2 - only some other required opening tag is present;
        * 0.0 - none of the required opening tags is present.

        Args:
            completions: Iterable of completion strings to evaluate.
            format_type: ``"think-answer"`` or ``"code-explanation"``. Any
                other value is treated as ``"think-answer"``.
            **kwargs: Accepted and ignored.

        Returns:
            A tuple ``(rewards, details, categories)`` of three lists, each
            holding one entry per completion, in input order.
        """
        # Define patterns for different formats
        patterns = {
            "think-answer": r"<think>.*?</think>\s*<answer>.*?</answer>",
            "code-explanation": r"<code>.*?</code>\s*<explanation>.*?</explanation>",
        }

        # Apply the fallback to the format name itself so the regex, the tag
        # checks and the detail messages all describe the same format.
        if format_type not in patterns:
            format_type = "think-answer"

        # Loop-invariant work: compile the pattern and derive the tags once.
        # DOTALL lets the tag contents span multiple lines.
        pattern = re.compile(patterns[format_type], re.DOTALL)
        tag_names = format_type.split("-")
        opening_tag = f"<{tag_names[0]}>"
        required_tags = [f"<{name}>" for name in tag_names]

        rewards = []
        details = []
        categories = []

        for completion in completions:
            if pattern.search(completion):
                # The full structure appears in the completion
                rewards.append(1.0)
                details.append(f"Correct {format_type} format")
                categories.append("Exact Format Match")
            elif opening_tag in completion:
                # Partial match - has the opening tag of the format
                rewards.append(0.5)
                details.append(f"Has {tag_names[0]} tag but incomplete")
                categories.append("Partial Format Match")
            elif any(tag in completion for tag in required_tags):
                # Contains at least one of the required tags
                rewards.append(0.2)
                details.append("Has some required tags but wrong format")
                categories.append("Some Tags Present")
            else:
                # No match at all
                rewards.append(0.0)
                details.append("Incorrect format")
                categories.append("No Format Match")

        return rewards, details, categories

    # Calculate rewards
    rewards, details, categories = format_reward(
        completions=completions, format_type=format_buttons.value
    )

    # One row per completion, used by both the table and the chart
    results = [
        {
            "Completion": completion,
            "Reward": reward,
            "Detail": detail,
            "Category": category,
        }
        for completion, reward, detail, category in zip(
            short_completions, rewards, details, categories
        )
    ]

    # Create a bar chart comparing rewards by completion
    fig = px.bar(
        results,
        x="Completion",
        y="Reward",
        color="Category",
        title=f"Format Rewards by Completion ({format_buttons.value})",
        hover_data=["Detail"],
    )

    # Stack the heading, table and chart into a single output: marimo only
    # renders a cell's final expression, so standalone calls placed earlier
    # in the cell would never be displayed.
    mo.vstack(
        [
            mo.md(f"### Results for {format_buttons.value} format"),
            mo.ui.table(results),
            mo.ui.plotly(fig),
        ]
    )


if __name__ == "__main__":
    app.run()