mirror of
https://github.com/huggingface/deep-rl-class.git
synced 2026-04-08 13:20:41 +08:00
Merge pull request #12 from diskshima/fix/unit1-missing-import-and-save-model
Add missing pyglet import and step to save model to file
This commit is contained in:
@@ -265,6 +265,7 @@
|
||||
"!pip install gym[box2d]\n",
|
||||
"!pip install stable-baselines3[extra]\n",
|
||||
"!pip install huggingface_sb3\n",
|
||||
"!pip install pyglet\n",
|
||||
"!pip install ale-py==0.7.4 # To overcome an issue with gym (https://github.com/DLR-RM/stable-baselines3/issues/875)"
|
||||
]
|
||||
},
|
||||
@@ -694,7 +695,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Train it for 500,000 timesteps"
|
||||
"# TODO: Train it for 500,000 timesteps\n",
|
||||
"\n",
|
||||
"# TODO: Specify file name for model and save the model to file\n",
|
||||
"model_name = \"\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -716,7 +720,10 @@
|
||||
"source": [
|
||||
"# SOLUTION\n",
|
||||
"# Train it for 500,000 timesteps\n",
|
||||
"model.learn(total_timesteps=500000)"
|
||||
"model.learn(total_timesteps=500000)\n",
|
||||
"# Save the model\n",
|
||||
"model_name = \"ppo-LunarLander-v2\"\n",
|
||||
"model.save(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -979,9 +986,6 @@
|
||||
"# Define the name of the environment\n",
|
||||
"env_id = \"LunarLander-v2\"\n",
|
||||
"\n",
|
||||
"# Define the name of the trained model that we defined in model_save\n",
|
||||
"model_name = \"ppo-LunarLander-v2\"\n",
|
||||
"\n",
|
||||
"# TODO: Define the model architecture we used\n",
|
||||
"model_architecture = \"PPO\"\n",
|
||||
"\n",
|
||||
@@ -1108,4 +1112,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user