mirror of
https://github.com/huggingface/deep-rl-class.git
synced 2026-02-03 02:14:53 +08:00
fix: wrap eval_env in Monitor 🐛
This commit is contained in:
@@ -338,8 +338,9 @@
|
||||
"from huggingface_hub import notebook_login # To log to our Hugging Face account to be able to upload models to the Hub.\n",
|
||||
"\n",
|
||||
"from stable_baselines3 import PPO\n",
|
||||
"from stable_baselines3.common.env_util import make_vec_env\n",
|
||||
"from stable_baselines3.common.evaluation import evaluate_policy\n",
|
||||
"from stable_baselines3.common.env_util import make_vec_env"
|
||||
"from stable_baselines3.common.monitor import Monitor"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -738,6 +739,7 @@
|
||||
},
|
||||
"source": [
|
||||
"## Evaluate the agent 📈\n",
|
||||
"- Remember to wrap the environment in a [Monitor](https://stable-baselines3.readthedocs.io/en/master/common/monitor.html).\n",
|
||||
"- Now that our Lunar Lander agent is trained 🚀, we need to **check its performance**.\n",
|
||||
"- Stable-Baselines3 provides a method to do that: `evaluate_policy`.\n",
|
||||
"- To fill that part you need to [check the documentation](https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#basic-usage-training-saving-loading)\n",
|
||||
@@ -784,7 +786,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#@title\n",
|
||||
"eval_env = gym.make(\"LunarLander-v2\")\n",
|
||||
"eval_env = Monitor(gym.make(\"LunarLander-v2\"))\n",
|
||||
"mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)\n",
|
||||
"print(f\"mean_reward={mean_reward:.2f} +/- {std_reward}\")"
|
||||
]
|
||||
@@ -917,7 +919,7 @@
|
||||
"env_id = \n",
|
||||
"\n",
|
||||
"# Create the evaluation env and set the render_mode=\"rgb_array\"\n",
|
||||
"eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode=\"rgb_array\")])\n",
|
||||
"eval_env = DummyVecEnv([lambda: Monitor(gym.make(env_id, render_mode=\"rgb_array\"))])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# TODO: Define the model architecture we used\n",
|
||||
@@ -1096,7 +1098,7 @@
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"#@title\n",
|
||||
"eval_env = gym.make(\"LunarLander-v2\")\n",
|
||||
"eval_env = Monitor(gym.make(\"LunarLander-v2\"))\n",
|
||||
"mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)\n",
|
||||
"print(f\"mean_reward={mean_reward:.2f} +/- {std_reward}\")"
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user