mirror of
https://github.com/huggingface/deep-rl-class.git
synced 2026-04-14 18:31:36 +08:00
Update notebook for gymnasium
This commit is contained in:
@@ -5,6 +5,9 @@
"colab": {
"provenance": [],
"private_outputs": true,
"collapsed_sections": [
"tF42HvI7-gs5"
],
"include_colab_link": true
},
"kernelspec": {
@@ -31,39 +34,25 @@
{
"cell_type": "markdown",
"source": [
"# Unit 6: Advantage Actor Critic (A2C) using Robotics Simulations with PyBullet and Panda-Gym 🤖\n",
"# Unit 6: Advantage Actor Critic (A2C) using Robotics Simulations with Panda-Gym 🤖\n",
"\n",
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/thumbnail.png\" alt=\"Thumbnail\"/>\n",
"\n",
"In this notebook, you'll learn to use A2C with PyBullet and Panda-Gym, two set of robotics environments.\n",
"In this notebook, you'll learn to use A2C with [Panda-Gym](https://github.com/qgallouedec/panda-gym). You're going **to train a robotic arm** (Franka Emika Panda robot) to perform a task:\n",
"\n",
"With [PyBullet](https://github.com/bulletphysics/bullet3), you're going to **train a robot to move**:\n",
"- `AntBulletEnv-v0` 🕸️ More precisely, a spider (they say Ant but come on... it's a spider 😆) 🕸️\n",
"\n",
"Then, with [Panda-Gym](https://github.com/qgallouedec/panda-gym), you're going **to train a robotic arm** (Franka Emika Panda robot) to perform a task:\n",
"- `Reach`: the robot must place its end-effector at a target position.\n",
"\n",
"After that, you'll be able **to train in other robotics environments**.\n"
"After that, you'll be able **to train in other robotics tasks**.\n"
],
"metadata": {
"id": "-PTReiOw-RAN"
}
},
{
"cell_type": "markdown",
"source": [
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/environments.gif\" alt=\"Robotics environments\"/>"
],
"metadata": {
"id": "2VGL_0ncoAJI"
}
},
{
"cell_type": "markdown",
"source": [
"### 🎮 Environments:\n",
"\n",
"- [PyBullet](https://github.com/bulletphysics/bullet3)\n",
"- [Panda-Gym](https://github.com/qgallouedec/panda-gym)\n",
"\n",
"###📚 RL-Library:\n",
@@ -90,7 +79,7 @@
"\n",
"At the end of the notebook, you will:\n",
"\n",
"- Be able to use **PyBullet** and **Panda-Gym**, the environment libraries.\n",
"- Be able to use **Panda-Gym**, the environment library.\n",
"- Be able to **train robots using A2C**.\n",
"- Understand why **we need to normalize the input**.\n",
"- Be able to **push your trained agent and the code to the Hub** with a nice video replay and an evaluation score 🔥.\n",
@@ -148,15 +137,12 @@
{
"cell_type": "markdown",
"source": [
"To validate this hands-on for the [certification process](https://huggingface.co/deep-rl-course/en/unit0/introduction#certification-process), you need to push your two trained models to the Hub and get the following results:\n",
"To validate this hands-on for the [certification process](https://huggingface.co/deep-rl-course/en/unit0/introduction#certification-process), you need to push your trained model to the Hub and get the following results:\n",
"\n",
"- `AntBulletEnv-v0` get a result of >= 650.\n",
"- `PandaReachDense-v2` get a result of >= -3.5.\n",
"- `PandaReachDense-v3` get a result of >= -3.5.\n",
"\n",
"To find your result, go to the [leaderboard](https://huggingface.co/spaces/huggingface-projects/Deep-Reinforcement-Learning-Leaderboard) and find your model, **the result = mean_reward - std of reward**\n",
"\n",
"If you don't find your model, **go to the bottom of the page and click on the refresh button**\n",
"\n",
"For more information about the certification process, check this section 👉 https://huggingface.co/deep-rl-course/en/unit0/introduction#certification-process"
],
"metadata": {
@@ -233,13 +219,15 @@
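Editor's note: the leaderboard score the diff above refers to is `mean_reward - std of reward`. A quick plain-Python illustration (the episode returns below are made-up values, not from an actual run):

```python
import statistics

# Hypothetical episode returns from evaluating an agent (illustrative values only)
episode_rewards = [-2.1, -3.0, -2.4, -2.8, -2.6]

mean_reward = statistics.mean(episode_rewards)
std_reward = statistics.pstdev(episode_rewards)  # population std as an approximation

# The leaderboard score used for certification: mean minus standard deviation
result = mean_reward - std_reward
print(f"result = {mean_reward:.2f} - {std_reward:.2f} = {result:.2f}")
```

Subtracting the standard deviation penalizes agents whose performance is high on average but unstable across episodes.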
"cell_type": "markdown",
"source": [
"### Install dependencies 🔽\n",
"The first step is to install the dependencies, we’ll install multiple ones:\n",
"\n",
"- `pybullet`: Contains the walking robots environments.\n",
"The first step is to install the dependencies, we’ll install multiple ones:\n",
"- `gymnasium`\n",
"- `panda-gym`: Contains the robotics arm environments.\n",
"- `stable-baselines3[extra]`: The SB3 deep reinforcement learning library.\n",
"- `stable-baselines3`: The SB3 deep reinforcement learning library.\n",
"- `huggingface_sb3`: Additional code for Stable-baselines3 to load and upload models from the Hugging Face 🤗 Hub.\n",
"- `huggingface_hub`: Library allowing anyone to work with the Hub repositories."
"- `huggingface_hub`: Library allowing anyone to work with the Hub repositories.\n",
"\n",
"⏲ The installation can **take 10 minutes**."
],
"metadata": {
"id": "e1obkbdJ_KnG"
@@ -248,28 +236,27 @@
{
"cell_type": "code",
"source": [
"# Install the specific setuptools and wheel version required to install the dependencies\n",
"!pip install setuptools==65.5.0 wheel==0.38.4"
"!pip install stable-baselines3[extra]\n",
"!pip install gymnasium"
],
"metadata": {
"id": "eUamMNshb6ee"
"id": "TgZUkjKYSgvn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2yZRi_0bQGPM"
},
"outputs": [],
"source": [
"!pip install stable-baselines3[extra]==1.8.0\n",
"!pip install huggingface_sb3\n",
"!pip install panda_gym==2.0.0\n",
"!pip install pyglet==1.5.1"
]
"!pip install git+https://github.com/huggingface/huggingface_sb3@gymnasium-v2 # We didn't merged this branch yet\n",
"!pip install huggingface_hub\n",
"!pip install panda_gym"
],
"metadata": {
"id": "ABneW6tOSpyU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
@@ -283,12 +270,11 @@
{
"cell_type": "code",
"source": [
"import pybullet_envs\n",
"import panda_gym\n",
"import gym\n",
"\n",
"import os\n",
"\n",
"import gymnasium as gym\n",
"import panda_gym\n",
"\n",
"from huggingface_sb3 import load_from_hub, package_to_hub\n",
"\n",
"from stable_baselines3 import A2C\n",
@@ -307,7 +293,22 @@
{
"cell_type": "markdown",
"source": [
"## Environment 1: AntBulletEnv-v0 🕸\n",
"## PandaReachDense-v3 🦾\n",
"\n",
"The agent we're going to train is a robotic arm that needs to do controls (moving the arm and using the end-effector).\n",
"\n",
"In robotics, the *end-effector* is the device at the end of a robotic arm designed to interact with the environment.\n",
"\n",
"In `PandaReach`, the robot must place its end-effector at a target position (green ball).\n",
"\n",
"We're going to use the dense version of this environment. It means we'll get a *dense reward function* that **will provide a reward at each timestep** (the closer the agent is to completing the task, the higher the reward). Contrary to a *sparse reward function* where the environment **return a reward if and only if the task is completed**.\n",
"\n",
"Also, we're going to use the *End-effector displacement control*, it means the **action corresponds to the displacement of the end-effector**. We don't control the individual motion of each joint (joint control).\n",
"\n",
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/robotics.jpg\" alt=\"Robotics\"/>\n",
"\n",
"\n",
"This way **the training will be easier**.\n",
"\n"
],
"metadata": {
@@ -317,10 +318,11 @@
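Editor's note: the dense-vs-sparse distinction in the cell above can be sketched in plain Python. These hypothetical reward functions only illustrate the idea; panda-gym's actual reward implementation may differ:

```python
import math

def distance(a, b):
    """Euclidean distance between two (x, y, z) points."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def dense_reward(ee_pos, goal_pos):
    # A reward at every timestep: the closer the end-effector,
    # the higher (less negative) the reward.
    return -distance(ee_pos, goal_pos)

def sparse_reward(ee_pos, goal_pos, threshold=0.05):
    # A reward only when the task is completed
    # (end-effector within `threshold` of the goal).
    return 0.0 if distance(ee_pos, goal_pos) < threshold else -1.0

goal = (0.1, 0.2, 0.3)
far, near = (0.5, 0.2, 0.3), (0.11, 0.2, 0.3)
print(dense_reward(far, goal), dense_reward(near, goal))    # dense signal shrinks toward 0
print(sparse_reward(far, goal), sparse_reward(near, goal))  # sparse flips only near the goal
```

The dense signal changes smoothly with every action, which is what makes the task easier to learn than its sparse counterpart.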
{
"cell_type": "markdown",
"source": [
"### Create the AntBulletEnv-v0\n",
"### Create the environment\n",
"\n",
"#### The environment 🎮\n",
"In this environment, the agent needs to use correctly its different joints to walk correctly.\n",
"You can find a detailled explanation of this environment here: https://hackmd.io/@jeffreymo/SJJrSJh5_#PyBullet"
"\n",
"In `PandaReachDense-v3` the robotic arm must place its end-effector at a target position (green ball)."
],
"metadata": {
"id": "frVXOrnlBerQ"
@@ -329,16 +331,17 @@
{
"cell_type": "code",
"source": [
"env_id = \"AntBulletEnv-v0\"\n",
"env_id = \"PandaReachDense-v3\"\n",
"\n",
"# Create the env\n",
"env = gym.make(env_id)\n",
"\n",
"# Get the state space and action space\n",
"s_size = env.observation_space.shape[0]\n",
"s_size = env.observation_space.shape\n",
"a_size = env.action_space"
],
"metadata": {
"id": "JpU-JCDQYYax"
"id": "zXzAu3HYF1WD"
},
"execution_count": null,
"outputs": []
@@ -351,7 +354,7 @@
"print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
],
"metadata": {
"id": "2ZfvcCqEYgrg"
"id": "E-U9dexcF-FB"
},
"execution_count": null,
"outputs": []
@@ -359,14 +362,15 @@
{
"cell_type": "markdown",
"source": [
"The observation Space (from [Jeffrey Y Mo](https://hackmd.io/@jeffreymo/SJJrSJh5_#PyBullet)):\n",
"The observation space **is a dictionary with 3 different elements**:\n",
"- `achieved_goal`: (x,y,z) position of the goal.\n",
"- `desired_goal`: (x,y,z) distance between the goal position and the current object position.\n",
"- `observation`: position (x,y,z) and velocity of the end-effector (vx, vy, vz).\n",
"\n",
"The difference is that our observation space is 28 not 29.\n",
"\n",
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/obs_space.png\" alt=\"PyBullet Ant Obs space\"/>\n"
"Given it's a dictionary as observation, **we will need to use a MultiInputPolicy policy instead of MlpPolicy**."
],
"metadata": {
"id": "QzMmsdMJS7jh"
"id": "g_JClfElGFnF"
}
},
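Editor's note: to make the MultiInputPolicy point above concrete, here is a minimal plain-Python sketch of such a dict observation (illustrative numbers; the real environment returns NumPy arrays):

```python
# A PandaReach-style dict observation (illustrative values only)
obs = {
    "achieved_goal": [0.38, 0.0, 0.19],                  # (x, y, z)
    "desired_goal":  [0.10, -0.05, 0.25],                # (x, y, z)
    "observation":   [0.38, 0.0, 0.19, 0.0, 0.0, 0.0],   # position + velocity
}

# An MlpPolicy expects a single flat vector, so a dict observation needs
# one input head per key. Conceptually, the policy encodes each key and
# concatenates the results -- roughly this flattening:
flat = [v for key in sorted(obs) for v in obs[key]]
print(len(flat))  # 3 + 3 + 6 = 12 values once flattened
```

This is only the intuition; Stable-Baselines3's `MultiInputPolicy` handles the per-key encoding and concatenation internally.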
{
@@ -377,7 +381,7 @@
"print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
],
"metadata": {
"id": "Tc89eLTYYkK2"
"id": "ib1Kxy4AF-FC"
},
"execution_count": null,
"outputs": []
@@ -385,12 +389,11 @@
{
"cell_type": "markdown",
"source": [
"The action Space (from [Jeffrey Y Mo](https://hackmd.io/@jeffreymo/SJJrSJh5_#PyBullet)):\n",
"\n",
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/action_space.png\" alt=\"PyBullet Ant Obs space\"/>\n"
"The action space is a vector with 3 values:\n",
"- Control x, y, z movement"
],
"metadata": {
"id": "3RfsHhzZS9Pw"
"id": "5MHTHEHZS4yp"
}
},
{
@@ -458,8 +461,6 @@
"source": [
"### Create the A2C Model 🤖\n",
"\n",
"In this case, because we have a vector of 28 values as input, we'll use an MLP (multi-layer perceptron) as policy.\n",
"\n",
"For more information about A2C implementation with StableBaselines3 check: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html#notes\n",
"\n",
"To find the best parameters I checked the [official trained agents by Stable-Baselines3 team](https://huggingface.co/sb3)."
@@ -491,20 +492,8 @@
{
"cell_type": "code",
"source": [
"model = A2C(policy = \"MlpPolicy\",\n",
"model = A2C(policy = \"MultiInputPolicy\",\n",
" env = env,\n",
" gae_lambda = 0.9,\n",
" gamma = 0.99,\n",
" learning_rate = 0.00096,\n",
" max_grad_norm = 0.5,\n",
" n_steps = 8,\n",
" vf_coef = 0.4,\n",
" ent_coef = 0.0,\n",
" policy_kwargs=dict(\n",
" log_std_init=-2, ortho_init=False),\n",
" normalize_advantage=False,\n",
" use_rms_prop= True,\n",
" use_sde= True,\n",
" verbose=1)"
],
"metadata": {
@@ -517,7 +506,7 @@
"cell_type": "markdown",
"source": [
"### Train the A2C agent 🏃\n",
"- Let's train our agent for 2,000,000 timesteps, don't forget to use GPU on Colab. It will take approximately ~25-40min"
"- Let's train our agent for 1,000,000 timesteps, don't forget to use GPU on Colab. It will take approximately ~25-40min"
],
"metadata": {
"id": "opyK3mpJ1-m9"
@@ -526,7 +515,7 @@
{
"cell_type": "code",
"source": [
"model.learn(2_000_000)"
"model.learn(1_000_000)"
],
"metadata": {
"id": "4TuGHZD7RF1G"
@@ -538,7 +527,7 @@
"cell_type": "code",
"source": [
"# Save the model and VecNormalize statistics when saving the agent\n",
"model.save(\"a2c-AntBulletEnv-v0\")\n",
"model.save(\"a2c-PandaReachDense-v3\")\n",
"env.save(\"vec_normalize.pkl\")"
],
"metadata": {
@@ -552,8 +541,7 @@
"source": [
"### Evaluate the agent 📈\n",
"- Now that's our agent is trained, we need to **check its performance**.\n",
"- Stable-Baselines3 provides a method to do that: `evaluate_policy`\n",
"- In my case, I got a mean reward of `2371.90 +/- 16.50`"
"- Stable-Baselines3 provides a method to do that: `evaluate_policy`"
],
"metadata": {
"id": "01M9GCd32Ig-"
@@ -565,16 +553,19 @@
"from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize\n",
"\n",
"# Load the saved statistics\n",
"eval_env = DummyVecEnv([lambda: gym.make(\"AntBulletEnv-v0\")])\n",
"eval_env = DummyVecEnv([lambda: gym.make(\"PandaReachDense-v3\")])\n",
"eval_env = VecNormalize.load(\"vec_normalize.pkl\", eval_env)\n",
"\n",
"# We need to override the render_mode\n",
"eval_env.render_mode = \"rgb_array\"\n",
"\n",
"# do not update them at test time\n",
"eval_env.training = False\n",
"# reward normalization is not needed at test time\n",
"eval_env.norm_reward = False\n",
"\n",
"# Load the agent\n",
"model = A2C.load(\"a2c-AntBulletEnv-v0\")\n",
"model = A2C.load(\"a2c-PandaReachDense-v3\")\n",
"\n",
"mean_reward, std_reward = evaluate_policy(model, eval_env)\n",
"\n",
@@ -592,11 +583,7 @@
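Editor's note: the `training = False` / `norm_reward = False` pattern in the evaluation cell above matters because VecNormalize keeps running statistics that would otherwise keep shifting at test time. A toy version of the idea (plain Python, not the SB3 implementation):

```python
import math

class RunningNormalizer:
    """Toy observation normalizer: updates mean/variance only while training."""
    def __init__(self):
        self.n, self.mean, self.m2 = 0, 0.0, 0.0
        self.training = True

    def __call__(self, x):
        if self.training:  # Welford's online update of mean and variance
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (x - self.mean)
        var = self.m2 / self.n if self.n > 1 else 1.0
        return (x - self.mean) / math.sqrt(var + 1e-8)

norm = RunningNormalizer()
for x in [1.0, 2.0, 3.0]:   # "training": statistics keep moving
    norm(x)
norm.training = False        # "evaluation": statistics are frozen
frozen_mean = norm.mean
norm(100.0)                  # a test-time observation no longer shifts the stats
assert norm.mean == frozen_mean
```

Freezing the statistics at evaluation time ensures the agent sees observations scaled exactly as it did at the end of training.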
"### Publish your trained model on the Hub 🔥\n",
"Now that we saw we got good results after the training, we can publish our trained model on the Hub with one line of code.\n",
"\n",
"📚 The libraries documentation 👉 https://github.com/huggingface/huggingface_sb3/tree/main#hugging-face--x-stable-baselines3-v20\n",
"\n",
"Here's an example of a Model Card (with a PyBullet environment):\n",
"\n",
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/modelcardpybullet.png\" alt=\"Model Card Pybullet\"/>"
"📚 The libraries documentation 👉 https://github.com/huggingface/huggingface_sb3/tree/main#hugging-face--x-stable-baselines3-v20\n"
],
"metadata": {
"id": "44L9LVQaavR8"
@@ -666,9 +653,20 @@
"3️⃣ We're now ready to push our trained agent to the 🤗 Hub 🔥 using `package_to_hub()` function"
]
},
{
"cell_type": "markdown",
"source": [
"For this environment, **running this cell can take approximately 10min**"
],
"metadata": {
"id": "juxItTNf1W74"
}
},
{
"cell_type": "code",
"source": [
"from huggingface_sb3 import package_to_hub\n",
"\n",
"package_to_hub(\n",
" model=model,\n",
" model_name=f\"a2c-{env_id}\",\n",
@@ -680,7 +678,7 @@
")"
],
"metadata": {
"id": "ueuzWVCUTkfS"
"id": "V1N8r8QVwcCE"
},
"execution_count": null,
"outputs": []
@@ -688,143 +686,34 @@
{
"cell_type": "markdown",
"source": [
"## Take a coffee break ☕\n",
"- You already trained your first robot that learned to move congratutlations 🥳!\n",
"- It's **time to take a break**. Don't hesitate to **save this notebook** `File > Save a copy to Drive` to work on this second part later.\n"
],
"metadata": {
"id": "Qk9ykOk9D6Qh"
}
},
{
"cell_type": "markdown",
"source": [
"## Environment 2: PandaReachDense-v2 🦾\n",
"## Some additional challenges 🏆\n",
"The best way to learn **is to try things by your own**! Why not trying `PandaPickAndPlace-v3`?\n",
"\n",
"The agent we're going to train is a robotic arm that needs to do controls (moving the arm and using the end-effector).\n",
"If you want to try more advanced tasks for panda-gym, you need to check what was done using **TQC or SAC** (a more sample-efficient algorithm suited for robotics tasks). In real robotics, you'll use a more sample-efficient algorithm for a simple reason: contrary to a simulation **if you move your robotic arm too much, you have a risk of breaking it**.\n",
"\n",
"In robotics, the *end-effector* is the device at the end of a robotic arm designed to interact with the environment.\n",
"PandaPickAndPlace-v1 (this model uses the v1 version of the environment): https://huggingface.co/sb3/tqc-PandaPickAndPlace-v1\n",
"\n",
"In `PandaReach`, the robot must place its end-effector at a target position (green ball).\n",
"And don't hesitate to check panda-gym documentation here: https://panda-gym.readthedocs.io/en/latest/usage/train_with_sb3.html\n",
"\n",
"We're going to use the dense version of this environment. It means we'll get a *dense reward function* that **will provide a reward at each timestep** (the closer the agent is to completing the task, the higher the reward). Contrary to a *sparse reward function* where the environment **return a reward if and only if the task is completed**.\n",
"We provide you the steps to train another agent (optional):\n",
"\n",
"Also, we're going to use the *End-effector displacement control*, it means the **action corresponds to the displacement of the end-effector**. We don't control the individual motion of each joint (joint control).\n",
"\n",
"<img src=\"https://huggingface.co/datasets/huggingface-deep-rl-course/course-images/resolve/main/en/unit8/robotics.jpg\" alt=\"Robotics\"/>\n",
"\n",
"\n",
"This way **the training will be easier**.\n",
"\n"
],
"metadata": {
"id": "5VWfwAA7EJg7"
}
},
{
"cell_type": "markdown",
"source": [
"\n",
"\n",
"In `PandaReachDense-v2` the robotic arm must place its end-effector at a target position (green ball).\n",
"\n"
],
"metadata": {
"id": "oZ7FyDEi7G3T"
}
},
{
"cell_type": "code",
"source": [
"import gym\n",
"\n",
"env_id = \"PandaReachDense-v2\"\n",
"\n",
"# Create the env\n",
"env = gym.make(env_id)\n",
"\n",
"# Get the state space and action space\n",
"s_size = env.observation_space.shape\n",
"a_size = env.action_space"
],
"metadata": {
"id": "zXzAu3HYF1WD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(\"_____OBSERVATION SPACE_____ \\n\")\n",
"print(\"The State Space is: \", s_size)\n",
"print(\"Sample observation\", env.observation_space.sample()) # Get a random observation"
],
"metadata": {
"id": "E-U9dexcF-FB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The observation space **is a dictionary with 3 different elements**:\n",
"- `achieved_goal`: (x,y,z) position of the goal.\n",
"- `desired_goal`: (x,y,z) distance between the goal position and the current object position.\n",
"- `observation`: position (x,y,z) and velocity of the end-effector (vx, vy, vz).\n",
"\n",
"Given it's a dictionary as observation, **we will need to use a MultiInputPolicy policy instead of MlpPolicy**."
],
"metadata": {
"id": "g_JClfElGFnF"
}
},
{
"cell_type": "code",
"source": [
"print(\"\\n _____ACTION SPACE_____ \\n\")\n",
"print(\"The Action Space is: \", a_size)\n",
"print(\"Action Space Sample\", env.action_space.sample()) # Take a random action"
],
"metadata": {
"id": "ib1Kxy4AF-FC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The action space is a vector with 3 values:\n",
"- Control x, y, z movement"
],
"metadata": {
"id": "5MHTHEHZS4yp"
}
},
{
"cell_type": "markdown",
"source": [
"Now it's your turn:\n",
"\n",
"1. Define the environment called \"PandaReachDense-v2\"\n",
"1. Define the environment called \"PandaPickAndPlace-v3\"\n",
"2. Make a vectorized environment\n",
"3. Add a wrapper to normalize the observations and rewards. [Check the documentation](https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecnormalize)\n",
"4. Create the A2C Model (don't forget verbose=1 to print the training logs).\n",
"5. Train it for 1M Timesteps\n",
"6. Save the model and VecNormalize statistics when saving the agent\n",
"7. Evaluate your agent\n",
"8. Publish your trained model on the Hub 🔥 with `package_to_hub`"
"8. Publish your trained model on the Hub 🔥 with `package_to_hub`\n"
],
"metadata": {
"id": "nIhPoc5t9HjG"
"id": "G3xy3Nf3c2O1"
}
},
{
"cell_type": "markdown",
"source": [
"### Solution (fill the todo)"
"### Solution (optional)"
],
"metadata": {
"id": "sKGbFXZq9ikN"
@@ -834,11 +723,11 @@
"cell_type": "code",
"source": [
"# 1 - 2\n",
"env_id = \"PandaReachDense-v2\"\n",
"env_id = \"PandaPickAndPlace-v3\"\n",
"env = make_vec_env(env_id, n_envs=4)\n",
"\n",
"# 3\n",
"env = VecNormalize(env, norm_obs=True, norm_reward=False, clip_obs=10.)\n",
"env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.)\n",
"\n",
"# 4\n",
"model = A2C(policy = \"MultiInputPolicy\",\n",
@@ -857,7 +746,7 @@
"cell_type": "code",
"source": [
"# 6\n",
"model_name = \"a2c-PandaReachDense-v2\";\n",
"model_name = \"a2c-PandaPickAndPlace-v3\";\n",
"model.save(model_name)\n",
"env.save(\"vec_normalize.pkl\")\n",
"\n",
@@ -865,7 +754,7 @@
"from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize\n",
"\n",
"# Load the saved statistics\n",
"eval_env = DummyVecEnv([lambda: gym.make(\"PandaReachDense-v2\")])\n",
"eval_env = DummyVecEnv([lambda: gym.make(\"PandaPickAndPlace-v3\")])\n",
"eval_env = VecNormalize.load(\"vec_normalize.pkl\", eval_env)\n",
"\n",
"# do not update them at test time\n",
@@ -897,27 +786,6 @@
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Some additional challenges 🏆\n",
"The best way to learn **is to try things by your own**! Why not trying `HalfCheetahBulletEnv-v0` for PyBullet and `PandaPickAndPlace-v1` for Panda-Gym?\n",
"\n",
"If you want to try more advanced tasks for panda-gym, you need to check what was done using **TQC or SAC** (a more sample-efficient algorithm suited for robotics tasks). In real robotics, you'll use a more sample-efficient algorithm for a simple reason: contrary to a simulation **if you move your robotic arm too much, you have a risk of breaking it**.\n",
"\n",
"PandaPickAndPlace-v1: https://huggingface.co/sb3/tqc-PandaPickAndPlace-v1\n",
"\n",
"And don't hesitate to check panda-gym documentation here: https://panda-gym.readthedocs.io/en/latest/usage/train_with_sb3.html\n",
"\n",
"Here are some ideas to achieve so:\n",
"* Train more steps\n",
"* Try different hyperparameters by looking at what your classmates have done 👉 https://huggingface.co/models?other=https://huggingface.co/models?other=AntBulletEnv-v0\n",
"* **Push your new trained model** on the Hub 🔥\n"
],
"metadata": {
"id": "G3xy3Nf3c2O1"
}
},
{
"cell_type": "markdown",
"source": [