GitHub Repository: huggingface/notebooks
Path: blob/main/diffusers/reinforcement_learning_for_control.ipynb

Diffusion-based Policy Learning for RL

This notebook implements Diffusion Policy, a diffusion model that predicts robot action sequences in reinforcement learning tasks.

This example implements a robot control model for pushing a T-shaped block into a target area. The model takes the current state observation as input (agent x/y position, block x/y position, and block angle) and outputs a trajectory of 16 subsequent (x, y) steps to follow. The script was contributed by Dorsa Rohani and the notebook by Parag Ekbote.
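Before the observation reaches the network, the notebook's code rescales each component to the range [-1, 1], and it maps the predicted actions back to pixel coordinates afterwards. The following is a minimal standalone sketch of that min-max scaling in plain Python. The ranges are copied from the `stats` dictionary defined in the `DiffusionPolicy` class of this notebook; the helper names `normalize` and `unnormalize` are illustrative and are not the notebook's own methods.

```python
import math

# Ranges copied from the `stats` dictionary in the notebook's DiffusionPolicy class.
OBS_MIN = [0.0, 0.0, 0.0, 0.0, 0.0]
OBS_MAX = [512.0, 512.0, 512.0, 512.0, 2 * math.pi]


def normalize(values, lows, highs):
    """Map each value from [low, high] to [-1, 1]."""
    return [2 * (v - lo) / (hi - lo) - 1 for v, lo, hi in zip(values, lows, highs)]


def unnormalize(values, lows, highs):
    """Inverse of normalize: map each value from [-1, 1] back to [low, high]."""
    return [(v + 1) / 2 * (hi - lo) + lo for v, lo, hi in zip(values, lows, highs)]


# Sample observation: agent x, agent y, block x, block y, block angle.
obs = [256.0, 256.0, 200.0, 300.0, math.pi / 2]
norm_obs = normalize(obs, OBS_MIN, OBS_MAX)
print(norm_obs)  # [0.0, 0.0, -0.21875, 0.171875, -0.5]

# Round trip back to the original units.
restored = unnormalize(norm_obs, OBS_MIN, OBS_MAX)
```

A position at the center of the 512-pixel workspace maps to 0.0, and an angle of pi/2 (a quarter of the 2*pi range) maps to -0.5.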

!pip install torch==2.0.1+cu117 \
    torchvision==0.15.2+cu117 \
    torchaudio==2.0.2+cu117 \
    git+https://github.com/rail-berkeley/d4rl.git \
    gym==0.23.1 \
    protobuf==3.20.1 \
    einops \
    mediapy \
    Pillow==9.0.0 \
    -f https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting git+https://github.com/rail-berkeley/d4rl.git
  Resolved https://github.com/rail-berkeley/d4rl.git to commit 89141a689b0353b0dac3da5cba60da4b1b16254d
[... dependency resolution, download, and wheel-build output omitted ...]
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gymnasium-robotics 1.3.1 requires mujoco<3.2.0,>=2.2.0, but you have mujoco 3.2.7 which is incompatible.
lightning 2.5.0.post0 requires torch<4.0,>=2.1.0, but you have torch 2.0.1+cu117 which is incompatible.
pytorch-lightning 2.5.0.post0 requires torch>=2.1.0, but you have torch 2.0.1+cu117 which is incompatible.
Successfully installed Cython-3.0.12 D4RL-1.1 Pillow-9.0.0 cmake-3.31.4 dm-env-1.6 dm-tree-0.1.9 dm_control-1.0.27 einops-0.8.1 fasteners-0.19 gym-0.23.1 gym_notices-0.0.8 h5py-3.12.1 labmaze-1.0.6 lit-18.1.8 lxml-5.3.1 mediapy-1.2.2 mjrl-1.0.0 mujoco-3.2.7 mujoco_py-2.1.2.14 protobuf-3.20.1 pybullet-3.2.7 torch-2.0.1+cu117 torchaudio-2.0.2+cu117 torchvision-0.15.2+cu117 triton-2.0.0 wrapt-1.17.2
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from diffusers import DDPMScheduler


class ObservationEncoder(nn.Module):
    def __init__(self, state_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

    def forward(self, x):
        return self.net(x)


class ObservationProjection(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(32, 512))
        self.bias = nn.Parameter(torch.zeros(32))

    def forward(self, x):
        # Zero-pad 256-dimensional inputs to the 512 features the weight expects
        if x.size(-1) == 256:
            x = torch.cat([x, torch.zeros(*x.shape[:-1], 256, device=x.device)], dim=-1)
        return nn.functional.linear(x, self.weight, self.bias)


class UNet1D(nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels=128):
        super().__init__()
        # Downsampling path
        self.down1 = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # Time embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(1, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels)
        )

        # Middle
        self.mid = nn.Sequential(
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # Upsampling path
        self.up1 = nn.Sequential(
            nn.Conv1d(2 * hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, t):
        # Ensure proper tensor dimensions
        if not isinstance(t, torch.Tensor):
            t = torch.tensor([t], device=x.device)
        if t.dim() == 0:
            t = t.view(1)
        if t.dim() == 1:
            t = t.unsqueeze(-1)

        # Time embedding
        t_emb = self.time_mlp(t.float())  # [B, H]

        # Reshape time embedding to match spatial dimensions
        t_emb = t_emb.unsqueeze(-1)  # [B, H, 1]
        t_emb = t_emb.expand(-1, -1, x.shape[-1])  # [B, H, L]

        # Downsampling
        d1 = self.down1(x)  # [B, H, L]

        # Add time embedding
        mid = self.mid(d1 + t_emb)  # [B, H, L]

        # Upsampling with skip connections
        up = self.up1(torch.cat([mid, d1], dim=1))  # [B, out_channels, L]
        return up


class DiffusionPolicy:
    def __init__(self, state_dim=5, device="cpu"):
        self.device = device

        # Define valid ranges
        self.stats = {
            "obs": {
                "min": torch.zeros(5, device=device),
                "max": torch.tensor([512, 512, 512, 512, 2 * np.pi], device=device)
            },
            "action": {
                "min": torch.zeros(2, device=device),
                "max": torch.full((2,), 512, device=device)
            },
        }

        self.obs_encoder = ObservationEncoder(state_dim).to(device)
        self.obs_projection = ObservationProjection().to(device)

        # Use custom UNet1D implementation
        self.model = UNet1D(
            in_channels=34,  # 2 action channels + 32 context channels
            out_channels=2,  # x,y coordinates
            hidden_channels=128
        ).to(device)

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=100,
            beta_schedule="squaredcos_cap_v2"
        )

        # Load pre-trained weights using a more compatible approach
        try:
            checkpoint_path = hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt")
            checkpoint = torch.load(checkpoint_path, map_location=device)

            # Load weights for encoder and projection
            self.obs_encoder.load_state_dict(self._fix_state_dict(checkpoint["encoder_state_dict"]))
            self.obs_projection.load_state_dict(self._fix_state_dict(checkpoint["projection_state_dict"]))

            # Transfer UNet weights
            self._transfer_weights(checkpoint["model_state_dict"])
        except Exception as e:
            print(f"Warning: Could not load pre-trained weights: {e}")
            print("The model will use randomly initialized weights.")

    def _fix_state_dict(self, state_dict):
        """Helper function to fix state dict keys if needed"""
        new_state_dict = {}
        for k, v in state_dict.items():
            # Remove 'module.' prefix if it exists
            k = k.replace('module.', '')
            new_state_dict[k] = v
        return new_state_dict

    def _transfer_weights(self, original_state_dict):
        custom_state_dict = self.model.state_dict()

        # Create mapping between original and custom architecture
        layer_mapping = {
            'down_blocks.0.resnets.0': 'down1.0',
            'down_blocks.0.resnets.1': 'down1.2',
            'mid_block.resnets.0': 'mid.0',
            'mid_block.resnets.1': 'mid.2',
            'up_blocks.0.resnets.0': 'up1.0',
            'up_blocks.0.resnets.1': 'up1.2',
        }

        # Transfer weights for compatible layers
        transferred = set()
        for orig_name, param in original_state_dict.items():
            for orig_prefix, custom_prefix in layer_mapping.items():
                if orig_name.startswith(orig_prefix):
                    custom_name = orig_name.replace(orig_prefix, custom_prefix)
                    if custom_name in custom_state_dict:
                        if custom_state_dict[custom_name].shape == param.shape:
                            custom_state_dict[custom_name].copy_(param)
                            transferred.add(custom_name)

        # Load the transferred weights
        self.model.load_state_dict(custom_state_dict, strict=False)
        print(f"Transferred weights for {len(transferred)} layers")

    def normalize_data(self, data, stats):
        return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1

    def unnormalize_data(self, ndata, stats):
        return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"]

    @torch.no_grad()
    def predict(self, observation):
        # Ensure observation is a tensor and has batch dimension
        if not isinstance(observation, torch.Tensor):
            observation = torch.tensor(observation, device=self.device)
        if observation.dim() == 1:
            observation = observation.unsqueeze(0)

        observation = observation.to(self.device)
        normalized_obs = self.normalize_data(observation, self.stats["obs"])

        # Generate context
        cond = self.obs_projection(self.obs_encoder(normalized_obs))
        cond = cond.view(normalized_obs.shape[0], -1, 1).expand(-1, -1, 16)

        # Initialize with noise
        action = torch.randn((observation.shape[0], 2, 16), device=self.device)

        # Denoise
        self.noise_scheduler.set_timesteps(100)
        for t in self.noise_scheduler.timesteps:
            model_input = torch.cat([action, cond], dim=1)
            model_output = self.model(model_input, t.to(self.device))
            action = self.noise_scheduler.step(model_output, t.to(self.device), action).prev_sample

        action = action.transpose(1, 2)
        action = self.unnormalize_data(action, self.stats["action"])
        return action


if __name__ == "__main__":
    policy = DiffusionPolicy()

    # Test with sample observation
    obs = torch.tensor([[
        256.0,      # robot arm x position
        256.0,      # robot arm y position
        200.0,      # block x position
        300.0,      # block y position
        np.pi / 2,  # block angle
    ]])

    action = policy.predict(obs)
    print("Action shape:", action.shape)
    print("\nPredicted trajectory:")
    for i, (x, y) in enumerate(action[0]):
        print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")
Transferred weights for 0 layers
Action shape: torch.Size([1, 16, 2])

Predicted trajectory:
Step  0: x=  48.9, y= 449.9
Step  1: x= 118.0, y=  49.0
Step  2: x= 191.8, y= 110.5
Step  3: x= 501.7, y= 512.0
Step  4: x=   0.0, y= 425.6
Step  5: x= 378.3, y=   0.0
Step  6: x=  39.3, y=   0.5
Step  7: x= 474.6, y= 372.0
Step  8: x=  17.0, y= 398.2
Step  9: x=  30.2, y= 369.9
Step 10: x=  11.4, y= 503.2
Step 11: x= 512.0, y= 424.3
Step 12: x= 415.7, y= 508.0
Step 13: x= 357.8, y= 503.9
Step 14: x= 294.6, y= 512.0
Step 15: x= 219.7, y=  87.5
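The first output line reports `Transferred weights for 0 layers`. Given how `_transfer_weights` is written, that means no checkpoint key was both renamed onto an existing `UNet1D` parameter and matched in shape, so the UNet used for this trajectory appears to be still randomly initialized. One way to investigate is to dry-run the same prefix-mapping logic on key names and shapes alone. The sketch below does that in plain Python. The function name `plan_transfer` is hypothetical, and the checkpoint keys in the example are invented for illustration; the real keys in `push_tblock.pt` are not shown in this notebook.

```python
def plan_transfer(source_shapes, target_shapes, layer_mapping):
    """Return {source_key: target_key} for entries that would actually be copied.

    Mirrors the notebook's rule: rename by prefix, then require that the
    renamed key exists in the target with an identical shape.
    """
    plan = {}
    for src_name, src_shape in source_shapes.items():
        for src_prefix, dst_prefix in layer_mapping.items():
            if not src_name.startswith(src_prefix):
                continue
            dst_name = dst_prefix + src_name[len(src_prefix):]
            if target_shapes.get(dst_name) == src_shape:
                plan[src_name] = dst_name
    return plan


# Subset of the prefix mapping used in the notebook.
layer_mapping = {
    "down_blocks.0.resnets.0": "down1.0",
    "mid_block.resnets.0": "mid.0",
}

# Hypothetical checkpoint keys (assumed for illustration only).
source_shapes = {
    "down_blocks.0.resnets.0.conv1.weight": (128, 34, 3),  # nested sub-module name
    "mid_block.resnets.0.weight": (128, 128, 3),
}

# Parameter names and shapes of the notebook's UNet1D (first conv layers only).
target_shapes = {
    "down1.0.weight": (128, 34, 3),
    "down1.0.bias": (128,),
    "mid.0.weight": (128, 128, 3),
}

plan = plan_transfer(source_shapes, target_shapes, layer_mapping)
print(plan)  # {'mid_block.resnets.0.weight': 'mid.0.weight'}
```

In this made-up example the first key is dropped because its renamed form, `down1.0.conv1.weight`, does not exist in the target. With real PyTorch state dicts, the shape dictionaries could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`, and comparing the plan against the full key list shows which layers were silently skipped.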