Path: blob/main/diffusers/reinforcement_learning_for_control.ipynb
Diffusion-based Policy Learning for RL
This notebook implements Diffusion Policy, a diffusion model that predicts sequences of robot actions for reinforcement learning tasks.
The example is a robot control model that pushes a T-shaped block into a target area. The model takes the current state observation as input and outputs a trajectory of subsequent steps to follow. The script was contributed by Dorsa Rohani and the notebook by Parag Ekbote.
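To make the idea of "observation in, trajectory out" concrete, here is a minimal sketch of a DDPM-style reverse sampling loop over an action sequence. It is not the notebook's model: the linear noise schedule, the step count `NUM_STEPS`, and `toy_noise_predictor` (a stand-in for the trained network) are all assumptions made for illustration. Only the horizon of 16 and the action dimension of 2 come from the action shape printed later in the notebook.

```python
import math
import random

HORIZON = 16     # trajectory length, matching the printed action shape [1, 16, 2]
ACTION_DIM = 2   # each action is an (x, y) target
NUM_STEPS = 50   # hypothetical number of denoising steps


def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule (an assumption; the notebook's schedule is not shown)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def toy_noise_predictor(trajectory, t, observation):
    """Hypothetical stand-in for the trained network.

    It treats the offset from the observed agent position (first two
    observation entries) as the noise to remove.
    """
    anchor = observation[:ACTION_DIM]
    return [[a - o for a, o in zip(action, anchor)] for action in trajectory]


def sample_trajectory(observation, rng, noise_predictor=toy_noise_predictor):
    """Denoise a random trajectory into an action sequence, conditioned on the observation."""
    betas = linear_betas(NUM_STEPS)
    alphas = [1.0 - b for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)

    # Start from pure Gaussian noise in normalized action space.
    traj = [[rng.gauss(0.0, 1.0) for _ in range(ACTION_DIM)] for _ in range(HORIZON)]

    for t in reversed(range(NUM_STEPS)):
        eps = noise_predictor(traj, t, observation)
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        sigma = math.sqrt(betas[t]) if t > 0 else 0.0  # no noise on the final step
        traj = [
            [
                (x - coef * e) / math.sqrt(alphas[t]) + sigma * rng.gauss(0.0, 1.0)
                for x, e in zip(action, eps_row)
            ]
            for action, eps_row in zip(traj, eps)
        ]
    return traj
```

Calling `sample_trajectory([0.1, -0.2, 0.0, 0.0, 0.0], random.Random(0))` returns 16 `[x, y]` pairs in normalized coordinates. In a real Diffusion Policy the noise predictor would be a trained network conditioned on the observation, and the five-element observation used here is only a placeholder layout.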
In [2]:
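The source of this cell is not preserved here. Judging from the log that follows (the "Looking in links" line and the order in which packages are collected), it was probably a single pip install with pinned versions. The sketch below rebuilds such a command as an argument list; the helper name `build_pip_command` is hypothetical and the exact original flags are an assumption.

```python
import sys


def build_pip_command():
    """Assemble a pip command matching the packages named in the install log.

    The package list is inferred from the log, not copied from the original cell.
    """
    return [
        sys.executable, "-m", "pip", "install",
        # "Looking in links: ..." in the log corresponds to pip's -f / --find-links option.
        "-f", "https://download.pytorch.org/whl/torch_stable.html",
        "git+https://github.com/rail-berkeley/d4rl.git",
        "torch==2.0.1+cu117",
        "torchvision==0.15.2+cu117",
        "torchaudio==2.0.2+cu117",
        "gym==0.23.1",
        "protobuf==3.20.1",
        "einops",
        "mediapy",
        "Pillow==9.0.0",
    ]

# To actually run the install from a script:
#   import subprocess
#   subprocess.run(build_pip_command(), check=True)
```

Inside a notebook, the same arguments can be passed to the `%pip install` magic instead.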
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting git+https://github.com/rail-berkeley/d4rl.git
Cloning https://github.com/rail-berkeley/d4rl.git to /tmp/pip-req-build-tdkn3r22
Running command git clone --filter=blob:none --quiet https://github.com/rail-berkeley/d4rl.git /tmp/pip-req-build-tdkn3r22
Resolved https://github.com/rail-berkeley/d4rl.git to commit 89141a689b0353b0dac3da5cba60da4b1b16254d
Preparing metadata (setup.py) ... done
Collecting torch==2.0.1+cu117
Downloading https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl (1843.9 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 GB 49.6 MB/s eta 0:00:00:00:0100:01
Collecting torchvision==0.15.2+cu117
Downloading https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl (6.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.1/6.1 MB 192.0 MB/s eta 0:00:00
Collecting torchaudio==2.0.2+cu117
Downloading https://download.pytorch.org/whl/cu117/torchaudio-2.0.2%2Bcu117-cp310-cp310-linux_x86_64.whl (4.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.4/4.4 MB 181.3 MB/s eta 0:00:00
Collecting gym==0.23.1
Downloading gym-0.23.1.tar.gz (626 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 626.2/626.2 kB 99.6 MB/s eta 0:00:00
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Collecting protobuf==3.20.1
Downloading protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (698 bytes)
Collecting einops
Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting mediapy
Downloading mediapy-1.2.2-py3-none-any.whl.metadata (4.8 kB)
Collecting Pillow==9.0.0
Downloading Pillow-9.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.6 kB)
Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.0.1+cu117) (3.17.0)
Requirement already satisfied: typing-extensions in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.0.1+cu117) (4.12.2)
Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.0.1+cu117) (1.13.3)
Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.0.1+cu117) (3.4.2)
Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch==2.0.1+cu117) (3.1.5)
Collecting triton==2.0.0 (from torch==2.0.1+cu117)
Downloading triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.0 kB)
Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torchvision==0.15.2+cu117) (1.26.4)
Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torchvision==0.15.2+cu117) (2.32.3)
Requirement already satisfied: cloudpickle>=1.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from gym==0.23.1) (3.1.1)
Collecting gym_notices>=0.0.4 (from gym==0.23.1)
Downloading gym_notices-0.0.8-py3-none-any.whl.metadata (1.0 kB)
Collecting cmake (from triton==2.0.0->torch==2.0.1+cu117)
Downloading cmake-3.31.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.5 kB)
Collecting lit (from triton==2.0.0->torch==2.0.1+cu117)
Downloading lit-18.1.8-py3-none-any.whl.metadata (2.5 kB)
Collecting mjrl@ git+https://github.com/aravindr93/mjrl@master#egg=mjrl (from D4RL==1.1)
Cloning https://github.com/aravindr93/mjrl (to revision master) to /tmp/pip-install-sr3n9qkg/mjrl_98dfd6b68c6048399cc4826c3796d7e4
Running command git clone --filter=blob:none --quiet https://github.com/aravindr93/mjrl /tmp/pip-install-sr3n9qkg/mjrl_98dfd6b68c6048399cc4826c3796d7e4
Resolved https://github.com/aravindr93/mjrl to commit 3871d93763d3b49c4741e6daeaebbc605fe140dc
Preparing metadata (setup.py) ... done
Collecting mujoco_py (from D4RL==1.1)
Downloading mujoco_py-2.1.2.14-py3-none-any.whl.metadata (669 bytes)
Collecting pybullet (from D4RL==1.1)
Downloading pybullet-3.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting h5py (from D4RL==1.1)
Downloading h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Requirement already satisfied: termcolor in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from D4RL==1.1) (2.5.0)
Requirement already satisfied: click in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from D4RL==1.1) (8.1.8)
Collecting dm_control>=1.0.3 (from D4RL==1.1)
Downloading dm_control-1.0.27-py3-none-any.whl.metadata (1.6 kB)
Requirement already satisfied: ipython in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from mediapy) (8.17.2)
Requirement already satisfied: matplotlib in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from mediapy) (3.8.2)
Requirement already satisfied: absl-py>=0.7.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (2.1.0)
Collecting dm-env (from dm_control>=1.0.3->D4RL==1.1)
Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)
Collecting dm-tree!=0.1.2 (from dm_control>=1.0.3->D4RL==1.1)
Downloading dm_tree-0.1.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.2 kB)
Requirement already satisfied: glfw in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (2.8.0)
Collecting labmaze (from dm_control>=1.0.3->D4RL==1.1)
Downloading labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (278 bytes)
Collecting lxml (from dm_control>=1.0.3->D4RL==1.1)
Downloading lxml-5.3.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.7 kB)
Collecting mujoco>=3.2.7 (from dm_control>=1.0.3->D4RL==1.1)
Downloading mujoco-3.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
Requirement already satisfied: pyopengl>=3.1.4 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (3.1.9)
Requirement already satisfied: pyparsing>=3.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (3.2.1)
Requirement already satisfied: setuptools!=50.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (75.8.0)
Requirement already satisfied: scipy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (1.11.4)
Requirement already satisfied: tqdm in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm_control>=1.0.3->D4RL==1.1) (4.67.1)
Requirement already satisfied: decorator in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (5.1.1)
Requirement already satisfied: jedi>=0.16 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (0.19.2)
Requirement already satisfied: matplotlib-inline in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (0.1.7)
Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (3.0.50)
Requirement already satisfied: pygments>=2.4.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (2.19.1)
Requirement already satisfied: stack-data in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (0.6.3)
Requirement already satisfied: traitlets>=5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (5.14.3)
Requirement already satisfied: exceptiongroup in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (1.2.2)
Requirement already satisfied: pexpect>4.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from ipython->mediapy) (4.9.0)
Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch==2.0.1+cu117) (3.0.2)
Requirement already satisfied: contourpy>=1.0.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib->mediapy) (1.3.1)
Requirement already satisfied: cycler>=0.10 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib->mediapy) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib->mediapy) (4.55.8)
Requirement already satisfied: kiwisolver>=1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib->mediapy) (1.4.8)
Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib->mediapy) (24.2)
Requirement already satisfied: python-dateutil>=2.7 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib->mediapy) (2.9.0.post0)
Collecting Cython>=0.27.2 (from mujoco_py->D4RL==1.1)
Downloading Cython-3.0.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)
Requirement already satisfied: imageio>=2.1.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from mujoco_py->D4RL==1.1) (2.37.0)
Requirement already satisfied: cffi>=1.10 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from mujoco_py->D4RL==1.1) (1.17.1)
Collecting fasteners~=0.15 (from mujoco_py->D4RL==1.1)
Downloading fasteners-0.19-py3-none-any.whl.metadata (4.9 kB)
Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->torchvision==0.15.2+cu117) (3.4.1)
Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->torchvision==0.15.2+cu117) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->torchvision==0.15.2+cu117) (2.3.0)
Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->torchvision==0.15.2+cu117) (2025.1.31)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch==2.0.1+cu117) (1.3.0)
Requirement already satisfied: pycparser in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from cffi>=1.10->mujoco_py->D4RL==1.1) (2.22)
Requirement already satisfied: attrs>=18.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from dm-tree!=0.1.2->dm_control>=1.0.3->D4RL==1.1) (25.1.0)
Collecting wrapt>=1.11.2 (from dm-tree!=0.1.2->dm_control>=1.0.3->D4RL==1.1)
Downloading wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.4 kB)
Requirement already satisfied: parso<0.9.0,>=0.8.4 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jedi>=0.16->ipython->mediapy) (0.8.4)
Requirement already satisfied: etils[epath] in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from mujoco>=3.2.7->dm_control>=1.0.3->D4RL==1.1) (1.12.0)
Requirement already satisfied: ptyprocess>=0.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pexpect>4.3->ipython->mediapy) (0.7.0)
Requirement already satisfied: wcwidth in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython->mediapy) (0.2.13)
Requirement already satisfied: six>=1.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib->mediapy) (1.17.0)
Requirement already satisfied: executing>=1.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from stack-data->ipython->mediapy) (2.2.0)
Requirement already satisfied: asttokens>=2.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from stack-data->ipython->mediapy) (3.0.0)
Requirement already satisfied: pure-eval in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from stack-data->ipython->mediapy) (0.2.3)
Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from etils[epath]->mujoco>=3.2.7->dm_control>=1.0.3->D4RL==1.1) (2025.2.0)
Requirement already satisfied: importlib_resources in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from etils[epath]->mujoco>=3.2.7->dm_control>=1.0.3->D4RL==1.1) (6.5.2)
Requirement already satisfied: zipp in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from etils[epath]->mujoco>=3.2.7->dm_control>=1.0.3->D4RL==1.1) (3.21.0)
Downloading protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 140.3 MB/s eta 0:00:00
Downloading Pillow-9.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.3/4.3 MB 175.2 MB/s eta 0:00:00
Downloading triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.3/63.3 MB 202.4 MB/s eta 0:00:00
Downloading einops-0.8.1-py3-none-any.whl (64 kB)
Downloading mediapy-1.2.2-py3-none-any.whl (26 kB)
Downloading dm_control-1.0.27-py3-none-any.whl (56.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 184.2 MB/s eta 0:00:00
Downloading gym_notices-0.0.8-py3-none-any.whl (3.0 kB)
Downloading h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.3/5.3 MB 152.9 MB/s eta 0:00:00
Downloading mujoco_py-2.1.2.14-py3-none-any.whl (2.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.4/2.4 MB 151.7 MB/s eta 0:00:00
Downloading pybullet-3.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (103.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 103.2/103.2 MB 150.7 MB/s eta 0:00:00
Downloading Cython-3.0.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 181.3 MB/s eta 0:00:00
Downloading dm_tree-0.1.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)
Downloading fasteners-0.19-py3-none-any.whl (18 kB)
Downloading mujoco-3.2.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.4 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.4/6.4 MB 173.5 MB/s eta 0:00:00
Downloading cmake-3.31.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.8 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 27.8/27.8 MB 221.9 MB/s eta 0:00:00
Downloading dm_env-1.6-py3-none-any.whl (26 kB)
Downloading labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.9/4.9 MB 162.1 MB/s eta 0:00:00
Downloading lit-18.1.8-py3-none-any.whl (96 kB)
Downloading lxml-5.3.1-cp310-cp310-manylinux_2_28_x86_64.whl (5.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.2/5.2 MB 198.7 MB/s eta 0:00:00
Downloading wrapt-1.17.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (82 kB)
Building wheels for collected packages: gym, D4RL, mjrl
Building wheel for gym (pyproject.toml) ... done
Created wheel for gym: filename=gym-0.23.1-py3-none-any.whl size=701426 sha256=15f82ac4bfe4147aaad6d0f913ff41b813a84085d2c64b88c0bed68bd1ccb7e1
Stored in directory: /home/zeus/.cache/pip/wheels/1a/00/fb/fe5cf2860fb9b7bc860e28f00095a1f42c7b726dd6f42d1acc
Building wheel for D4RL (setup.py) ... done
Created wheel for D4RL: filename=D4RL-1.1-py3-none-any.whl size=26412345 sha256=55419958d2c8899fe20a7de17b03908cf493598d2fb33ea5bc743ac86f0407ab
Stored in directory: /tmp/pip-ephem-wheel-cache-cxisi6vd/wheels/7a/7a/27/d17500f4699272a90767c018dbd88a5e2376f2870a79b6a4ac
Building wheel for mjrl (setup.py) ... done
Created wheel for mjrl: filename=mjrl-1.0.0-py3-none-any.whl size=61962 sha256=ec762cf0195d77fce72265c2be84581f7f11f0bb58fd2d7edd4988c28daa6e06
Stored in directory: /tmp/pip-ephem-wheel-cache-cxisi6vd/wheels/8f/99/f9/efd223b38d503df5eaada10ffe96a869fb0c0f3c92d9e43ed0
Successfully built gym D4RL mjrl
Installing collected packages: pybullet, mjrl, lit, gym_notices, wrapt, protobuf, Pillow, lxml, labmaze, h5py, gym, fasteners, einops, Cython, cmake, dm-tree, mujoco_py, mujoco, mediapy, dm-env, dm_control, D4RL, triton, torch, torchvision, torchaudio
Attempting uninstall: protobuf
Found existing installation: protobuf 4.23.4
Uninstalling protobuf-4.23.4:
Successfully uninstalled protobuf-4.23.4
Attempting uninstall: Pillow
Found existing installation: pillow 11.1.0
Uninstalling pillow-11.1.0:
Successfully uninstalled pillow-11.1.0
Attempting uninstall: mujoco
Found existing installation: mujoco 3.1.6
Uninstalling mujoco-3.1.6:
Successfully uninstalled mujoco-3.1.6
Attempting uninstall: triton
Found existing installation: triton 2.2.0
Uninstalling triton-2.2.0:
Successfully uninstalled triton-2.2.0
Attempting uninstall: torch
Found existing installation: torch 2.2.1+cu121
Uninstalling torch-2.2.1+cu121:
Successfully uninstalled torch-2.2.1+cu121
Attempting uninstall: torchvision
Found existing installation: torchvision 0.17.1+cu121
Uninstalling torchvision-0.17.1+cu121:
Successfully uninstalled torchvision-0.17.1+cu121
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gymnasium-robotics 1.3.1 requires mujoco<3.2.0,>=2.2.0, but you have mujoco 3.2.7 which is incompatible.
lightning 2.5.0.post0 requires torch<4.0,>=2.1.0, but you have torch 2.0.1+cu117 which is incompatible.
pytorch-lightning 2.5.0.post0 requires torch>=2.1.0, but you have torch 2.0.1+cu117 which is incompatible.
Successfully installed Cython-3.0.12 D4RL-1.1 Pillow-9.0.0 cmake-3.31.4 dm-env-1.6 dm-tree-0.1.9 dm_control-1.0.27 einops-0.8.1 fasteners-0.19 gym-0.23.1 gym_notices-0.0.8 h5py-3.12.1 labmaze-1.0.6 lit-18.1.8 lxml-5.3.1 mediapy-1.2.2 mjrl-1.0.0 mujoco-3.2.7 mujoco_py-2.1.2.14 protobuf-3.20.1 pybullet-3.2.7 torch-2.0.1+cu117 torchaudio-2.0.2+cu117 torchvision-0.15.2+cu117 triton-2.0.0 wrapt-1.17.2
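The resolver error above shows that downgrading torch and upgrading mujoco conflicts with packages already present in the environment (gymnasium-robotics, lightning, pytorch-lightning). A quick way to confirm what is actually installed after such a run is to compare the pins against the package metadata. This is a minimal sketch using only the standard library; the `PINNED` table repeats the versions from the log, and the helper names are hypothetical.

```python
from importlib import metadata

# Versions taken from the "Successfully installed" line in the log above.
PINNED = {
    "torch": "2.0.1+cu117",
    "torchvision": "0.15.2+cu117",
    "torchaudio": "2.0.2+cu117",
    "gym": "0.23.1",
    "protobuf": "3.20.1",
    "Pillow": "9.0.0",
}


def installed_version(name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins=PINNED):
    """Map each package to (wanted, found, matches)."""
    report = {}
    for name, wanted in pins.items():
        found = installed_version(name)
        report[name] = (wanted, found, found == wanted)
    return report


if __name__ == "__main__":
    for name, (wanted, found, ok) in check_pins().items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: wanted {wanted}, found {found} [{status}]")
```

A mismatch here does not necessarily break the notebook, but it is a useful first check if imports fail after the install.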
In [6]:
Transferred weights for 0 layers
Action shape: torch.Size([1, 16, 2])
Predicted trajectory:
Step 0: x= 48.9, y= 449.9
Step 1: x= 118.0, y= 49.0
Step 2: x= 191.8, y= 110.5
Step 3: x= 501.7, y= 512.0
Step 4: x= 0.0, y= 425.6
Step 5: x= 378.3, y= 0.0
Step 6: x= 39.3, y= 0.5
Step 7: x= 474.6, y= 372.0
Step 8: x= 17.0, y= 398.2
Step 9: x= 30.2, y= 369.9
Step 10: x= 11.4, y= 503.2
Step 11: x= 512.0, y= 424.3
Step 12: x= 415.7, y= 508.0
Step 13: x= 357.8, y= 503.9
Step 14: x= 294.6, y= 512.0
Step 15: x= 219.7, y= 87.5
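Several of the printed coordinates sit exactly at 0.0 or 512.0, which suggests that the model's output is mapped into a 0 to 512 coordinate range and clipped at the boundaries. The sketch below shows one way such post-processing could look, assuming the network emits normalized values in [-1, 1]; the normalization range, the bounds, and the function names are assumptions, since the original cell is not shown.

```python
def unnormalize(value, low=0.0, high=512.0):
    """Map a value from [-1, 1] to [low, high], clipping anything outside the range."""
    scaled = (value + 1.0) / 2.0 * (high - low) + low
    return min(max(scaled, low), high)


def format_trajectory(normalized_actions):
    """Turn a sequence of normalized (x, y) pairs into printable step lines."""
    lines = []
    for step, (x, y) in enumerate(normalized_actions):
        lines.append(
            f"Step {step:2d}: x={unnormalize(x):6.1f}, y={unnormalize(y):6.1f}"
        )
    return lines


if __name__ == "__main__":
    # Hypothetical normalized outputs; the second pair is out of range and gets clipped.
    for line in format_trajectory([(-0.81, 0.76), (1.3, -1.2)]):
        print(line)
```

Note also that the log reports "Transferred weights for 0 layers", which suggests no pretrained parameters were matched when loading. If so, the large jumps between consecutive steps would be consistent with an untrained network rather than a learned pushing trajectory.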