Update colab

This commit is contained in:
Thomas Simonini
2023-01-04 14:24:22 +01:00
parent 8a35f1bf67
commit 9b5db4e879

View File

@@ -46,12 +46,12 @@
{
"cell_type": "markdown",
"source": [
"###🎮 Environments: \n",
"### 🎮 Environments: \n",
"\n",
"- [CartPole-v1](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n",
"- [PixelCopter](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n",
"\n",
"###📚 RL-Library: \n",
"### 📚 RL-Library: \n",
"\n",
"- Python\n",
"- PyTorch\n",
@@ -351,15 +351,6 @@
"### [The environment 🎮](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vVwcV9LjMzQk"
},
"source": [
"![cartpole.jpg](data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAkACQAAD/2wBDAAIBAQIBAQICAgICAgICAwUDAwMDAwYEBAMFBwYHBwcGBwcICQsJCAgKCAcHCg0KCgsMDAwMBwkODw0MDgsMDAz/2wBDAQICAgMDAwYDAwYMCAcIDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAz/wAARCAC6AUsDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD9+KKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACikJwK+bfC/wDwUx8LalaSNqmiaxp0yvhEgaO4VxxySSmD14wenXnFceKzDD4ZpV58t9vkdWGwVfEJujG9t/mfSdFQafqCanZQXEYIjnjWRQeuCM81PXZ0ucrVgooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAQ9K/JWL/WfiK/Wo9K/JRG2yfiK/PuOv+XH/b36H2/B21b/ALd/U/Vzwlx4W03/AK9Yv/QRWhWf4SH/ABSunf8AXrF/6CK0K++p/AvQ+LqfGwoooqyAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigBD0r8kv8Alp+NfraelfkoV+cfga/P+Ov+XH/b36H2/Bv/AC++X6n6ueD/APkUdN/69Iv/AEAVoVn+D/8AkUdM/wCvSH/0AVoV97S+Beh8XU+NhRRRVkBRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFACHpX5KH74+g/nX61npX5KH74+g/nX59x1/y4/7e/Q+34N/5ff9u/qfq54P/wCRR0z/AK9If/QBWhWf4P8A+RR0z/r0h/8AQBWhX31L4F6HxdT42FFFFWQFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAIelfkofvj6D+dfrWelfkofvj6D+dfn3HX/Lj/ALe/Q+34N/5ff9u/qfq54P8A+RR0z/r0h/8AQBWhWf4P/wCRR0z/AK9If/QBWhX31L4F
6HxdT42FFFFWQFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAIelfkofvj6D+dfrWelfkofvj6D+dfn3HX/Lj/t79D7fg3/l9/27+p+rng//AJFHTP8Ar0h/9AFaFZ/g/wD5FHTP+vSH/wBAFaFffUvgXofF1PjYUUUVZAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRXA/tV/Fi9+Av7L3xI8c6bb2t5qPgvwtqeu2tvc7vJnltbSWdEfaQ20sgBwQcE4Ir+c/VP+D1n9pWW6zZfDj4GwQ4HyzabqsrZ7/ML9R+lAH9Np6V+SWf3n41+e/wDxGqftQH/mn3wF/wDBPq3/AMsa+85/GNlb6g8BZ96PsJ2HAI618hxVlGNxzpLBUpVOXmb5U3Zab2Pr+FcXQoe19tNRva1/mfrl4Q/5FPS/+vSL/wBAFaNfzX67/wAHnXx38E61eaPbfC74STW2kzvZxSSJqG90jYoCcXIGSAOlU/8AiNj+Pv8A0Sn4P/8AfGo//JNfWU1aCTPlKnxM/pcory39h747aj+1H+xl8KPiVq9nZafqnj7wlpfiC8tbPd9nt5rq1jmdI9xLbAzkDJJwBkmvUqszCiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA8c/4eGfAv/orPgP8A8HMP/wAVR/w8M+Bf/RWfAf8A4OYf/iq9jop6D0PHP+HhnwL/AOis+A//AAcw/wDxVH/Dwz4F/wDRWfAf/g5h/wDiq9joo0DQ8c/4eGfAv/orPgP/AMHMP/xVH/Dwz4F/9FZ8B/8Ag5h/+Kr2OijQNDxz/h4Z8C/+is+A/wDwcw//ABVH/Dwz4F/9FZ8B/wDg5h/+Kr2OijQNDxz/AIeGfAv/AKKz4D/8HMP/AMVR/wAPDPgX/wBFZ8B/+DmH/wCKr2OijQNDxz/h4Z8C/wDorPgP/wAHMP8A8VR/w8M+Bf8A0VnwH/4OYf8A4qvY6KNA0PHP+HhnwL/6Kz4D/wDBzD/8VR/w8M+Bf/RWfAf/AIOYf/iq9joo0DQ8c/4eG/Av/orPgP8A8HMH/wAVR/w8N+Bf/RWPAf8A4OYP/iq9joo0EeMf8PEfgZ/0VfwJ/wCDmD/4qj/h4l8DP+ir+BP/AAcwf/FV7PRSA+O/+CgX7e3wY8T/ALBnxt03T/id4KvL/UPAOu21tbw6tC0k8r6dOqIoDZJLEAD1Nfxl1/ch/wAFH/8AlHj8ef8AsnXiH/02XFfw30AKOtfsh4C/4LIfAjxvqF6l5q/iHwba25QwvrWkSSm5DZ4QWhn5XAzv2D5hgnnH43U4vmvSy7NsXged4WfK5Kz0T/PqLlhK3Or2Nfx3qEOqeMtXubaQTW1xezSxSAEblZyQcEAjIPesccmkzSjrXnNtu7Kbuf2Lf8ErP26fg34H/wCCZH7Pej6t8TvBVhqemfDrQba7tptUjSW3lTT4VdHUnIYEEEHuK97/AOHiXwK/6Kz4D/8ABvF/jXMf8Efv+UUn7N3/AGTTw/8A+m6Cvo2gR4z/AMPEvgV/0VnwH/4N4v8AGj/h4l8Cv+is+A//AAbxf417NRSA8Z/4eJfAr/orPgP/AMG8X+NH/DxL4Ff9FZ8B/wDg3i/xr2aigDxn/h4l8Cv+is+A/wDwbxf40f8ADxL4Ff8ARWfAf/g3i/xr2aigDxn/AIeJfAr/AKKz4D/8G8X+NH/DxL4Ff9FZ8B/+DeL/ABr2aigDxn/h4l8Cv+is+A//AAbxf40f8PEvgV/0VnwH/wCDeL/GvZqKAPGf+HiXwK/6Kz4D/wDBvF/jR/w8S+BX/RWfAf8A4N4v8a9mooA8Z/4eJfAr/orPgP8A8G8X+NH/AA8S+BX/AEVnwH/4N4v8a9mo
oA8Z/wCHiXwK/wCis+A//BvF/jR/w8S+BX/RWfAf/g3i/wAa9mooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA4n9pT4RN+0D+zn4/wDAaX40p/G3hvUdAW9aHzhZm6tZIPNKZXft8zdt3DOMZHWvwo/4gcdW/wCjk9O/8IV//k+v6DaKAP58v+IHHVv+jk9O/wDCFf8A+T6P+IHDVv8Ao5TTv/CEf/5Pr+g2igD+fL/iBw1b/o5TTv8AwhH/APk+j/iBw1Yf83Kad/4Qr/8AyfX9BtFAHmv7Gv7P7/so/sl/DT4ZSaouuSfD/wAMaf4efUVg+zrfG1to4TKI9zbA2zO3c2M4yetelUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQB/9k=)"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -1028,6 +1019,8 @@
"import json\n",
"import imageio\n",
"\n",
"import tempfile\n",
"\n",
"import os"
],
"metadata": {
@@ -1068,18 +1061,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D1ywQFrkf3t8"
},
"outputs": [],
"source": [
"def push_to_hub(repo_id, \n",
" model,\n",
" hyperparameters,\n",
" eval_env,\n",
" video_fps=30,\n",
" local_repo_path=\"hub\"\n",
" video_fps=30\n",
" ):\n",
" \"\"\"\n",
" Evaluate a model, generate a replay video and upload the model to the Hugging Face Hub.\n",
@@ -1094,7 +1081,6 @@
" :param hyperparameters: training hyperparameters\n",
" :param eval_env: evaluation environment\n",
" :param video_fps: how many frames per second to record our video replay \n",
" :param local_repo_path: where the local repository is\n",
" \"\"\"\n",
"\n",
" _, repo_name = repo_id.split(\"/\")\n",
@@ -1106,97 +1092,101 @@
" exist_ok=True,\n",
" )\n",
"\n",
" # Step 2: Download files\n",
" repo_local_path = Path(snapshot_download(repo_id=repo_id))\n",
"\n",
" # Step 3: Save the model\n",
" torch.save(model, os.path.join(repo_local_path,\"model.pt\"))\n",
"\n",
" # Step 4: Save the hyperparameters to JSON\n",
" with open(Path(repo_local_path) / \"hyperparameters.json\", \"w\") as outfile:\n",
" json.dump(hyperparameters, outfile)\n",
" with tempfile.TemporaryDirectory() as tmpdirname:\n",
" local_directory = Path(tmpdirname)\n",
" \n",
" # Step 5: Evaluate the model and build JSON\n",
" mean_reward, std_reward = evaluate_agent(eval_env, \n",
" hyperparameters[\"max_t\"],\n",
" hyperparameters[\"n_evaluation_episodes\"], \n",
" model)\n",
" # Get datetime\n",
" eval_datetime = datetime.datetime.now()\n",
" eval_form_datetime = eval_datetime.isoformat()\n",
" # Step 2: Save the model\n",
" torch.save(model, local_directory / \"model.pt\")\n",
"\n",
" evaluate_data = {\n",
" \"env_id\": hyperparameters[\"env_id\"], \n",
" \"mean_reward\": mean_reward,\n",
" \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
" \"eval_datetime\": eval_form_datetime,\n",
" }\n",
" # Step 3: Save the hyperparameters to JSON\n",
" with open(local_directory / \"hyperparameters.json\", \"w\") as outfile:\n",
" json.dump(hyperparameters, outfile)\n",
" \n",
" # Step 4: Evaluate the model and build JSON\n",
" mean_reward, std_reward = evaluate_agent(eval_env, \n",
" hyperparameters[\"max_t\"],\n",
" hyperparameters[\"n_evaluation_episodes\"], \n",
" model)\n",
" # Get datetime\n",
" eval_datetime = datetime.datetime.now()\n",
" eval_form_datetime = eval_datetime.isoformat()\n",
"\n",
" # Write a JSON file\n",
" with open(Path(repo_local_path) / \"results.json\", \"w\") as outfile:\n",
" json.dump(evaluate_data, outfile)\n",
" evaluate_data = {\n",
" \"env_id\": hyperparameters[\"env_id\"], \n",
" \"mean_reward\": mean_reward,\n",
" \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n",
" \"eval_datetime\": eval_form_datetime,\n",
" }\n",
"\n",
" # Step 6: Create the model card\n",
" # Env id\n",
" env_name = hyperparameters[\"env_id\"]\n",
" \n",
" metadata = {}\n",
" metadata[\"tags\"] = [\n",
" env_name,\n",
" \"reinforce\",\n",
" \"reinforcement-learning\",\n",
" \"custom-implementation\",\n",
" \"deep-rl-class\"\n",
" ]\n",
" # Write a JSON file\n",
" with open(local_directory / \"results.json\", \"w\") as outfile:\n",
" json.dump(evaluate_data, outfile)\n",
"\n",
" # Add metrics\n",
" eval = metadata_eval_result(\n",
" model_pretty_name=repo_name,\n",
" task_pretty_name=\"reinforcement-learning\",\n",
" task_id=\"reinforcement-learning\",\n",
" metrics_pretty_name=\"mean_reward\",\n",
" metrics_id=\"mean_reward\",\n",
" metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
" dataset_pretty_name=env_name,\n",
" dataset_id=env_name,\n",
" )\n",
" # Step 5: Create the model card\n",
" env_name = hyperparameters[\"env_id\"]\n",
" \n",
" metadata = {}\n",
" metadata[\"tags\"] = [\n",
" env_name,\n",
" \"reinforce\",\n",
" \"reinforcement-learning\",\n",
" \"custom-implementation\",\n",
" \"deep-rl-class\"\n",
" ]\n",
"\n",
" # Merges both dictionaries\n",
" metadata = {**metadata, **eval}\n",
" # Add metrics\n",
" eval = metadata_eval_result(\n",
" model_pretty_name=repo_name,\n",
" task_pretty_name=\"reinforcement-learning\",\n",
" task_id=\"reinforcement-learning\",\n",
" metrics_pretty_name=\"mean_reward\",\n",
" metrics_id=\"mean_reward\",\n",
" metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n",
" dataset_pretty_name=env_name,\n",
" dataset_id=env_name,\n",
" )\n",
"\n",
" model_card = f\"\"\"\n",
" # Merges both dictionaries\n",
" metadata = {**metadata, **eval}\n",
"\n",
" model_card = f\"\"\"\n",
" # **Reinforce** Agent playing **{env_id}**\n",
" This is a trained model of a **Reinforce** agent playing **{env_id}** .\n",
" To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n",
" \"\"\"\n",
"\n",
" readme_path = repo_local_path / \"README.md\"\n",
" readme = \"\"\n",
" if readme_path.exists():\n",
" with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
" readme = f.read()\n",
" else:\n",
" readme = model_card\n",
" readme_path = local_directory / \"README.md\"\n",
" readme = \"\"\n",
" if readme_path.exists():\n",
" with readme_path.open(\"r\", encoding=\"utf8\") as f:\n",
" readme = f.read()\n",
" else:\n",
" readme = model_card\n",
"\n",
" with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
" f.write(readme)\n",
" with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n",
" f.write(readme)\n",
"\n",
" # Save our metrics to Readme metadata\n",
" metadata_save(readme_path, metadata)\n",
" # Save our metrics to Readme metadata\n",
" metadata_save(readme_path, metadata)\n",
"\n",
" # Step 7: Record a video\n",
" video_path = repo_local_path / \"replay.mp4\"\n",
" record_video(env, model, video_path, video_fps)\n",
" # Step 6: Record a video\n",
" video_path = local_directory / \"replay.mp4\"\n",
" record_video(env, model, video_path, video_fps)\n",
"\n",
" # Step 7. Push everything to the Hub\n",
" api.upload_folder(\n",
" repo_id=repo_id,\n",
" folder_path=repo_local_path,\n",
" path_in_repo=\".\",\n",
" )\n",
" # Step 7. Push everything to the Hub\n",
" api.upload_folder(\n",
" repo_id=repo_id,\n",
" folder_path=local_directory,\n",
" path_in_repo=\".\",\n",
" )\n",
"\n",
" print(f\"Your model is pushed to the hub. You can view your model here: {repo_url}\")"
]
" print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")"
],
"metadata": {
"id": "_TPdq47D7_f_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
@@ -1240,7 +1230,6 @@
},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
@@ -1275,8 +1264,7 @@
" cartpole_policy, # The model we want to save\n",
" cartpole_hyperparameters, # Hyperparameters\n",
" eval_env, # Evaluation environment\n",
" video_fps=30,\n",
" local_repo_path=\"hub\",\n",
" video_fps=30\n",
" )"
]
},
@@ -1529,8 +1517,7 @@
" pixelcopter_policy, # The model we want to save\n",
" pixelcopter_hyperparameters, # Hyperparameters\n",
" eval_env, # Evaluation environment\n",
" video_fps=30,\n",
" local_repo_path=\"hub\",\n",
" video_fps=30\n",
" )"
],
"metadata": {
@@ -1598,7 +1585,6 @@
"JoTC9o2SczNn",
"gfGJNZBUP7Vn",
"YB0Cxrw1StrP",
"Jmhs1k-cftIq",
"47iuAFqV8Ws-",
"x62pP0PHdA-y"
],