mirror of
https://github.com/huggingface/deep-rl-class.git
synced 2026-04-01 09:40:26 +08:00
Update colab
This commit is contained in:
@@ -46,12 +46,12 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"###🎮 Environments: \n",
|
||||
"### 🎮 Environments: \n",
|
||||
"\n",
|
||||
"- [CartPole-v1](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n",
|
||||
"- [PixelCopter](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n",
|
||||
"\n",
|
||||
"###📚 RL-Library: \n",
|
||||
"### 📚 RL-Library: \n",
|
||||
"\n",
|
||||
"- Python\n",
|
||||
"- PyTorch\n",
|
||||
@@ -351,15 +351,6 @@
|
||||
"### [The environment 🎮](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "vVwcV9LjMzQk"
|
||||
},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
@@ -1028,6 +1019,8 @@
|
||||
"import json\n",
|
||||
"import imageio\n",
|
||||
"\n",
|
||||
"import tempfile\n",
|
||||
"\n",
|
||||
"import os"
|
||||
],
|
||||
"metadata": {
|
||||
@@ -1068,18 +1061,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "D1ywQFrkf3t8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def push_to_hub(repo_id, \n",
|
||||
" model,\n",
|
||||
" hyperparameters,\n",
|
||||
" eval_env,\n",
|
||||
" video_fps=30,\n",
|
||||
" local_repo_path=\"hub\"\n",
|
||||
" video_fps=30\n",
|
||||
" ):\n",
|
||||
" \"\"\"\n",
|
||||
" Evaluate, Generate a video and Upload a model to Hugging Face Hub.\n",
|
||||
@@ -1094,7 +1081,6 @@
|
||||
" :param hyperparameters: training hyperparameters\n",
|
||||
" :param eval_env: evaluation environment\n",
|
||||
" :param video_fps: how many frame per seconds to record our video replay \n",
|
||||
" :param local_repo_path: where the local repository is\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" _, repo_name = repo_id.split(\"/\")\n",
|
||||
@@ -1106,97 +1092,101 @@
|
||||
" exist_ok=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # Step 2: Download files\n",
|
||||
" repo_local_path = Path(snapshot_download(repo_id=repo_id))\n",
|
||||
"\n",
|
||||
" # Step 3: Save the model\n",
|
||||
" torch.save(model, os.path.join(repo_local_path,\"model.pt\"))\n",
|
||||
"\n",
|
||||
" # Step 4: Save the hyperparameters to JSON\n",
|
||||
" with open(Path(repo_local_path) / \"hyperparameters.json\", \"w\") as outfile:\n",
|
||||
" json.dump(hyperparameters, outfile)\n",
|
||||
" with tempfile.TemporaryDirectory() as tmpdirname:\n",
|
||||
" local_directory = Path(tmpdirname)\n",
|
||||
" \n",
|
||||
" # Step 5: Evaluate the model and build JSON\n",
|
||||
" mean_reward, std_reward = evaluate_agent(eval_env, \n",
|
||||
" hyperparameters[\"max_t\"],\n",
|
||||
" hyperparameters[\"n_evaluation_episodes\"], \n",
|
||||
" model)\n",
|
||||
" # Get datetime\n",
|
||||
" eval_datetime = datetime.datetime.now()\n",
|
||||
" eval_form_datetime = eval_datetime.isoformat()\n",
|
||||
" # Step 2: Save the model\n",
|
||||
" torch.save(model, local_directory / \"model.pt\")\n",
|
||||
"\n",
|
||||
" evaluate_data = {\n",
|
||||
" \"env_id\": hyperparameters[\"env_id\"], \n",
|
||||
" \"mean_reward\": mean_reward,\n",
|
||||
" \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
|
||||
" \"eval_datetime\": eval_form_datetime,\n",
|
||||
" }\n",
|
||||
" # Step 3: Save the hyperparameters to JSON\n",
|
||||
" with open(local_directory / \"hyperparameters.json\", \"w\") as outfile:\n",
|
||||
" json.dump(hyperparameters, outfile)\n",
|
||||
" \n",
|
||||
" # Step 4: Evaluate the model and build JSON\n",
|
||||
" mean_reward, std_reward = evaluate_agent(eval_env, \n",
|
||||
" hyperparameters[\"max_t\"],\n",
|
||||
" hyperparameters[\"n_evaluation_episodes\"], \n",
|
||||
" model)\n",
|
||||
" # Get datetime\n",
|
||||
" eval_datetime = datetime.datetime.now()\n",
|
||||
" eval_form_datetime = eval_datetime.isoformat()\n",
|
||||
"\n",
|
||||
" # Write a JSON file\n",
|
||||
" with open(Path(repo_local_path) / \"results.json\", \"w\") as outfile:\n",
|
||||
" json.dump(evaluate_data, outfile)\n",
|
||||
" evaluate_data = {\n",
|
||||
" \"env_id\": hyperparameters[\"env_id\"], \n",
|
||||
" \"mean_reward\": mean_reward,\n",
|
||||
" \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
|
||||
" \"eval_datetime\": eval_form_datetime,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" # Step 6: Create the model card\n",
|
||||
" # Env id\n",
|
||||
" env_name = hyperparameters[\"env_id\"]\n",
|
||||
" \n",
|
||||
" metadata = {}\n",
|
||||
" metadata[\"tags\"] = [\n",
|
||||
" env_name,\n",
|
||||
" \"reinforce\",\n",
|
||||
" \"reinforcement-learning\",\n",
|
||||
" \"custom-implementation\",\n",
|
||||
" \"deep-rl-class\"\n",
|
||||
" ]\n",
|
||||
" # Write a JSON file\n",
|
||||
" with open(local_directory / \"results.json\", \"w\") as outfile:\n",
|
||||
" json.dump(evaluate_data, outfile)\n",
|
||||
"\n",
|
||||
" # Add metrics\n",
|
||||
" eval = metadata_eval_result(\n",
|
||||
" model_pretty_name=repo_name,\n",
|
||||
" task_pretty_name=\"reinforcement-learning\",\n",
|
||||
" task_id=\"reinforcement-learning\",\n",
|
||||
" metrics_pretty_name=\"mean_reward\",\n",
|
||||
" metrics_id=\"mean_reward\",\n",
|
||||
" metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
|
||||
" dataset_pretty_name=env_name,\n",
|
||||
" dataset_id=env_name,\n",
|
||||
" )\n",
|
||||
" # Step 5: Create the model card\n",
|
||||
" env_name = hyperparameters[\"env_id\"]\n",
|
||||
" \n",
|
||||
" metadata = {}\n",
|
||||
" metadata[\"tags\"] = [\n",
|
||||
" env_name,\n",
|
||||
" \"reinforce\",\n",
|
||||
" \"reinforcement-learning\",\n",
|
||||
" \"custom-implementation\",\n",
|
||||
" \"deep-rl-class\"\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # Merges both dictionaries\n",
|
||||
" metadata = {**metadata, **eval}\n",
|
||||
" # Add metrics\n",
|
||||
" eval = metadata_eval_result(\n",
|
||||
" model_pretty_name=repo_name,\n",
|
||||
" task_pretty_name=\"reinforcement-learning\",\n",
|
||||
" task_id=\"reinforcement-learning\",\n",
|
||||
" metrics_pretty_name=\"mean_reward\",\n",
|
||||
" metrics_id=\"mean_reward\",\n",
|
||||
" metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
|
||||
" dataset_pretty_name=env_name,\n",
|
||||
" dataset_id=env_name,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" model_card = f\"\"\"\n",
|
||||
" # Merges both dictionaries\n",
|
||||
" metadata = {**metadata, **eval}\n",
|
||||
"\n",
|
||||
" model_card = f\"\"\"\n",
|
||||
" # **Reinforce** Agent playing **{env_id}**\n",
|
||||
" This is a trained model of a **Reinforce** agent playing **{env_id}** .\n",
|
||||
" To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" readme_path = repo_local_path / \"README.md\"\n",
|
||||
" readme = \"\"\n",
|
||||
" if readme_path.exists():\n",
|
||||
" with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
|
||||
" readme = f.read()\n",
|
||||
" else:\n",
|
||||
" readme = model_card\n",
|
||||
" readme_path = local_directory / \"README.md\"\n",
|
||||
" readme = \"\"\n",
|
||||
" if readme_path.exists():\n",
|
||||
" with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
|
||||
" readme = f.read()\n",
|
||||
" else:\n",
|
||||
" readme = model_card\n",
|
||||
"\n",
|
||||
" with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
|
||||
" f.write(readme)\n",
|
||||
" with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
|
||||
" f.write(readme)\n",
|
||||
"\n",
|
||||
" # Save our metrics to Readme metadata\n",
|
||||
" metadata_save(readme_path, metadata)\n",
|
||||
" # Save our metrics to Readme metadata\n",
|
||||
" metadata_save(readme_path, metadata)\n",
|
||||
"\n",
|
||||
" # Step 7: Record a video\n",
|
||||
" video_path = repo_local_path / \"replay.mp4\"\n",
|
||||
" record_video(env, model, video_path, video_fps)\n",
|
||||
" # Step 6: Record a video\n",
|
||||
" video_path = local_directory / \"replay.mp4\"\n",
|
||||
" record_video(env, model, video_path, video_fps)\n",
|
||||
"\n",
|
||||
" # Step 7. Push everything to the Hub\n",
|
||||
" api.upload_folder(\n",
|
||||
" repo_id=repo_id,\n",
|
||||
" folder_path=repo_local_path,\n",
|
||||
" path_in_repo=\".\",\n",
|
||||
" )\n",
|
||||
" # Step 7. Push everything to the Hub\n",
|
||||
" api.upload_folder(\n",
|
||||
" repo_id=repo_id,\n",
|
||||
" folder_path=local_directory,\n",
|
||||
" path_in_repo=\".\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" print(f\"Your model is pushed to the hub. You can view your model here: {repo_url}\")"
|
||||
]
|
||||
" print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "_TPdq47D7_f_"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -1240,7 +1230,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
@@ -1275,8 +1264,7 @@
|
||||
" cartpole_policy, # The model we want to save\n",
|
||||
" cartpole_hyperparameters, # Hyperparameters\n",
|
||||
" eval_env, # Evaluation environment\n",
|
||||
" video_fps=30,\n",
|
||||
" local_repo_path=\"hub\",\n",
|
||||
" video_fps=30\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
@@ -1529,8 +1517,7 @@
|
||||
" pixelcopter_policy, # The model we want to save\n",
|
||||
" pixelcopter_hyperparameters, # Hyperparameters\n",
|
||||
" eval_env, # Evaluation environment\n",
|
||||
" video_fps=30,\n",
|
||||
" local_repo_path=\"hub\",\n",
|
||||
" video_fps=30\n",
|
||||
" )"
|
||||
],
|
||||
"metadata": {
|
||||
@@ -1598,7 +1585,6 @@
|
||||
"JoTC9o2SczNn",
|
||||
"gfGJNZBUP7Vn",
|
||||
"YB0Cxrw1StrP",
|
||||
"Jmhs1k-cftIq",
|
||||
"47iuAFqV8Ws-",
|
||||
"x62pP0PHdA-y"
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user