From 9b5db4e879c1f8a377d15be5aed95d335eade02b Mon Sep 17 00:00:00 2001 From: Thomas Simonini Date: Wed, 4 Jan 2023 14:24:22 +0100 Subject: [PATCH] Update colab --- notebooks/unit4/unit4.ipynb | 184 +++++++++++++++++------------------- 1 file changed, 85 insertions(+), 99 deletions(-) diff --git a/notebooks/unit4/unit4.ipynb b/notebooks/unit4/unit4.ipynb index e76b3c6..45d1c0c 100644 --- a/notebooks/unit4/unit4.ipynb +++ b/notebooks/unit4/unit4.ipynb @@ -46,12 +46,12 @@ { "cell_type": "markdown", "source": [ - "###🎮 Environments: \n", + "### 🎮 Environments: \n", "\n", "- [CartPole-v1](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n", "- [PixelCopter](https://pygame-learning-environment.readthedocs.io/en/latest/user/games/pixelcopter.html)\n", "\n", - "###📚 RL-Library: \n", + "### 📚 RL-Library: \n", "\n", "- Python\n", "- PyTorch\n", @@ -351,15 +351,6 @@ "### [The environment 🎮](https://www.gymlibrary.dev/environments/classic_control/cart_pole/)\n" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "vVwcV9LjMzQk" - }, - "source": [ - "![cartpole.jpg](data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEAkACQAAD/2wBDAAIBAQIBAQICAgICAgICAwUDAwMDAwYEBAMFBwYHBwcGBwcICQsJCAgKCAcHCg0KCgsMDAwMBwkODw0MDgsMDAz/2wBDAQICAgMDAwYDAwYMCAcIDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAz/wAARCAC6AUsDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD9+KKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACikJwK+bfC/wDwUx8LalaSNqmiaxp0yvhEgaO4VxxySSmD14wenXnFceKzDD4ZpV58t9vkdWGwVfEJujG9t/mfSdFQafqCanZQXEYIjnjWRQeuCM81PXZ0ucrVgooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAQ9K/JWL/WfiK/Wo9K/JRG2yfiK/PuOv+XH/b36H2/B21b/ALd/U/Vzwlx4W03/AK9Yv/QRWhWf4SH/ABSunf8AXrF/6CK0K++p/AvQ+LqfGwoooqyAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigBD0r8kv8Alp+NfraelfkoV+cfga/P+Ov+XH/b36H2/Bv/AC++X6n6ueD/APkUdN/69Iv/AEAVoVn+D/8AkUdM/wCvSH/0AVoV97S+Beh8XU+NhRRRVkBRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFACHpX5KH74+g/nX61npX5KH74+g/nX59x1/y4/7e/Q+34N/5ff9u/qfq54P/wCRR0z/AK9If/QBWhWf4P8A+RR0z/r0h/8AQBWhX31L4F6HxdT42FFFFWQFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAIelfkofvj6D+dfrWelfkofvj6D+dfn3HX/Lj/ALe/Q+34N/5ff9u/qfq54P8A+RR0z/r0h/8AQBWhWf4P/wCRR0z/AK9If/QBWhX31L4F6HxdT42FFFFWQFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAIelfkofvj6D+dfrWelfkofvj6D+dfn3HX/Lj/t79D7fg3/l9/27+p+rng//AJFHTP8Ar0h/9AFaFZ/g/wD5FHTP+vSH/wBAFaFffUvgXofF1PjYUUUVZAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRXA/tV/Fi9+Av7L3xI8c6bb2t5qPgvwtqeu2tvc7vJnltbSWdEfaQ20sgBwQcE4Ir+c/VP+D1n9pWW6zZfDj4GwQ4HyzabqsrZ7/ML9R+lAH9Np6V+SWf3n41+e/wDxGqftQH/mn3wF/wDBPq3/AMsa+85/GNlb6g8BZ96PsJ2HAI618hxVlGNxzpLBUpVOXmb5U3Zab2Pr+FcXQoe19tNRva1/mfrl4Q/5FPS/+vSL/wBAFaNfzX67/wAHnXx38E61eaPbfC74STW2kzvZxSSJqG90jYoCcXIGSAOlU/8AiNj+Pv8A0Sn4P/8AfGo//JNfWU1aCTPlKnxM/pcory39h747aj+1H+xl8KPiVq9nZafqnj7wlpfiC8tbPd9nt5rq1jmdI9xLbAzkDJJwBkmvUqszCiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA8c/4eGfAv/orPgP8A8HMP/wAVR/w8M+Bf/RWfAf8A4OYf/iq9jop6D0PHP+HhnwL/AOis+A//AAcw/wDxVH/Dwz4F/wDRWfAf/g5h/wDiq9joo0DQ8c/4eGfAv/orPgP/AMHMP/xVH/Dwz4F/9FZ8B/8Ag5h/+Kr2OijQNDxz/h4Z8C/+is+A/wDwcw//ABVH/Dwz4F/9FZ8B/wDg5h/+Kr2OijQNDxz/AIeGfAv/AKKz4D/8HMP/AMVR/wAPDPgX/wBFZ8B/+DmH/wCKr2OijQNDxz/h4Z8C/wDorPgP/wAHMP8A8VR/w8M+Bf8A0VnwH/4OYf8A4qvY6KNA0PHP+HhnwL/6Kz4D/wDBzD/8VR/w8M+Bf/RWfAf/AIOYf/iq9joo0DQ8c/4eG/Av/orPgP8A8HMH/wAVR/w8N+Bf/RWPAf8A4OYP/iq9joo0EeMf8PEfgZ/0VfwJ/wCDmD/4qj/h4l8DP+ir+BP/AAcwf/FV7PRSA+O/+CgX7e3wY8T/ALBnxt03T/id4KvL/UPAOu21tbw6tC0k8r6dOqIoDZJLEAD1Nfxl1/ch/wAFH/8AlHj8ef8AsnXiH/02XFfw30AKOtfsh4C/4LIfAjxvqF6l5q/iHwba25QwvrWkSSm5DZ4QWhn5XAzv2D5hgnnH43U4vmvSy7NsXged4WfK5Kz0T/PqLlhK3Or2Nfx3qEOqeMtXubaQTW1xezSxSAEblZyQcEAjIPesccmkzSjrXnNtu7Kbuf2Lf8ErP26fg34H/wCCZH7Pej6t8TvBVhqemfDrQba7tptUjSW3lTT4VdHUnIYEEEHuK97/AOHiXwK/6Kz4D/8ABvF/jXMf8Efv+UUn7N3/AGTTw/8A+m6Cvo2gR4z/AMPEvgV/0VnwH/4N4v8AGj/h4l8Cv+is+A//AAbxf417NRSA8Z/4eJfAr/orPgP/AMG8X+NH/DxL4Ff9FZ8B/wDg3i/xr2aigDxn/h4l8Cv+is+A/wDwbxf40f8ADxL4Ff8ARWfAf/g3i/xr2aigDxn/AIeJfAr/AKKz4D/8G8X+NH/DxL4Ff9FZ8B/+DeL/ABr2aigDxn/h4l8Cv+is+A//AAbxf40f8PEvgV/0VnwH/wCDeL/GvZqKAPGf+HiXwK/6Kz4D/wDBvF/jR/w8S+BX/RWfAf8A4N4v8a9mooA8Z/4eJfAr/orPgP8A8G8X+NH/AA8S+BX/AEVnwH/4N4v8a9mooA8Z/wCHiXwK/wCis+A//BvF/jR/w8S+BX/RWfAf/g3i/wAa9mooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA4n9pT4RN+0D+zn4/wDAaX40p/G3hvUdAW9aHzhZm6tZIPNKZXft8zdt3DOMZHWvwo/4gcdW/wCjk9O/8IV//k+v6DaKAP58v+IHHVv+jk9O/wDCFf8A+T6P+IHDVv8Ao5TTv/CEf/5Pr+g2igD+fL/iBw1b/o5TTv8AwhH/APk+j/iBw1Yf83Kad/4Qr/8AyfX9BtFAHmv7Gv7P7/so/sl/DT4ZSaouuSfD/wAMaf4efUVg+zrfG1to4TKI9zbA2zO3c2M4yetelUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQB/9k=)" - ] - }, { "cell_type": "markdown", "metadata": { @@ -1028,6 +1019,8 @@ "import json\n", "import imageio\n", "\n", + "import tempfile\n", + "\n", "import os" ], "metadata": { @@ -1068,18 +1061,12 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "D1ywQFrkf3t8" - }, - "outputs": [], "source": [ "def push_to_hub(repo_id, \n", " model,\n", " hyperparameters,\n", " eval_env,\n", - " video_fps=30,\n", - " local_repo_path=\"hub\"\n", + " video_fps=30\n", " ):\n", " \"\"\"\n", " Evaluate, Generate a video and Upload a model to Hugging Face Hub.\n", @@ -1094,7 +1081,6 @@ " :param hyperparameters: training hyperparameters\n", " :param eval_env: evaluation environment\n", " :param video_fps: how many frame per seconds to record our video replay \n", - " :param local_repo_path: where the local repository is\n", " \"\"\"\n", "\n", " _, repo_name = repo_id.split(\"/\")\n", @@ -1106,97 +1092,101 @@ " exist_ok=True,\n", " )\n", "\n", - " # Step 2: Download files\n", - " repo_local_path = Path(snapshot_download(repo_id=repo_id))\n", - "\n", - " # Step 3: Save the model\n", - " torch.save(model, os.path.join(repo_local_path,\"model.pt\"))\n", - "\n", - " # Step 4: Save the hyperparameters to JSON\n", - " with open(Path(repo_local_path) / \"hyperparameters.json\", \"w\") as outfile:\n", - " json.dump(hyperparameters, outfile)\n", + " with tempfile.TemporaryDirectory() as tmpdirname:\n", + " local_directory = Path(tmpdirname)\n", " \n", - " # Step 5: Evaluate the model and build JSON\n", - " mean_reward, std_reward = evaluate_agent(eval_env, \n", - " hyperparameters[\"max_t\"],\n", - " hyperparameters[\"n_evaluation_episodes\"], \n", - " model)\n", - " # Get datetime\n", - " eval_datetime = datetime.datetime.now()\n", - " eval_form_datetime = eval_datetime.isoformat()\n", + " # Step 2: Save the model\n", + " torch.save(model, local_directory / \"model.pt\")\n", "\n", - " evaluate_data = {\n", - " \"env_id\": hyperparameters[\"env_id\"], \n", - " \"mean_reward\": mean_reward,\n", - " \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n", - " \"eval_datetime\": eval_form_datetime,\n", - " }\n", + " # Step 3: Save the hyperparameters to JSON\n", + " with open(local_directory / \"hyperparameters.json\", \"w\") as outfile:\n", + " json.dump(hyperparameters, outfile)\n", + " \n", + " # Step 4: Evaluate the model and build JSON\n", + " mean_reward, std_reward = evaluate_agent(eval_env, \n", + " hyperparameters[\"max_t\"],\n", + " hyperparameters[\"n_evaluation_episodes\"], \n", + " model)\n", + " # Get datetime\n", + " eval_datetime = datetime.datetime.now()\n", + " eval_form_datetime = eval_datetime.isoformat()\n", "\n", - " # Write a JSON file\n", - " with open(Path(repo_local_path) / \"results.json\", \"w\") as outfile:\n", - " json.dump(evaluate_data, outfile)\n", + " evaluate_data = {\n", + " \"env_id\": hyperparameters[\"env_id\"], \n", + " \"mean_reward\": mean_reward,\n", + " \"n_evaluation_episodes\": hyperparameters[\"n_evaluation_episodes\"],\n", + " \"eval_datetime\": eval_form_datetime,\n", + " }\n", "\n", - " # Step 6: Create the model card\n", - " # Env id\n", - " env_name = hyperparameters[\"env_id\"]\n", - " \n", - " metadata = {}\n", - " metadata[\"tags\"] = [\n", - " env_name,\n", - " \"reinforce\",\n", - " \"reinforcement-learning\",\n", - " \"custom-implementation\",\n", - " \"deep-rl-class\"\n", - " ]\n", + " # Write a JSON file\n", + " with open(local_directory / \"results.json\", \"w\") as outfile:\n", + " json.dump(evaluate_data, outfile)\n", "\n", - " # Add metrics\n", - " eval = metadata_eval_result(\n", - " model_pretty_name=repo_name,\n", - " task_pretty_name=\"reinforcement-learning\",\n", - " task_id=\"reinforcement-learning\",\n", - " metrics_pretty_name=\"mean_reward\",\n", - " metrics_id=\"mean_reward\",\n", - " metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n", - " dataset_pretty_name=env_name,\n", - " dataset_id=env_name,\n", - " )\n", + " # Step 5: Create the model card\n", + " env_name = hyperparameters[\"env_id\"]\n", + " \n", + " metadata = {}\n", + " metadata[\"tags\"] = [\n", + " env_name,\n", + " \"reinforce\",\n", + " \"reinforcement-learning\",\n", + " \"custom-implementation\",\n", + " \"deep-rl-class\"\n", + " ]\n", "\n", - " # Merges both dictionaries\n", - " metadata = {**metadata, **eval}\n", + " # Add metrics\n", + " eval = metadata_eval_result(\n", + " model_pretty_name=repo_name,\n", + " task_pretty_name=\"reinforcement-learning\",\n", + " task_id=\"reinforcement-learning\",\n", + " metrics_pretty_name=\"mean_reward\",\n", + " metrics_id=\"mean_reward\",\n", + " metrics_value=f\"{mean_reward:.2f} +/- {std_reward:.2f}\",\n", + " dataset_pretty_name=env_name,\n", + " dataset_id=env_name,\n", + " )\n", "\n", - " model_card = f\"\"\"\n", + " # Merges both dictionaries\n", + " metadata = {**metadata, **eval}\n", + "\n", + " model_card = f\"\"\"\n", " # **Reinforce** Agent playing **{env_id}**\n", " This is a trained model of a **Reinforce** agent playing **{env_id}** .\n", " To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction\n", " \"\"\"\n", "\n", - " readme_path = repo_local_path / \"README.md\"\n", - " readme = \"\"\n", - " if readme_path.exists():\n", - " with readme_path.open(\"r\", encoding=\"utf8\") as f:\n", - " readme = f.read()\n", - " else:\n", - " readme = model_card\n", + " readme_path = local_directory / \"README.md\"\n", + " readme = \"\"\n", + " if readme_path.exists():\n", + " with readme_path.open(\"r\", encoding=\"utf8\") as f:\n", + " readme = f.read()\n", + " else:\n", + " readme = model_card\n", "\n", - " with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n", - " f.write(readme)\n", + " with readme_path.open(\"w\", encoding=\"utf-8\") as f:\n", + " f.write(readme)\n", "\n", - " # Save our metrics to Readme metadata\n", - " metadata_save(readme_path, metadata)\n", + " # Save our metrics to Readme metadata\n", + " metadata_save(readme_path, metadata)\n", "\n", - " # Step 7: Record a video\n", - " video_path = repo_local_path / \"replay.mp4\"\n", - " record_video(env, model, video_path, video_fps)\n", + " # Step 6: Record a video\n", + " video_path = local_directory / \"replay.mp4\"\n", + " record_video(env, model, video_path, video_fps)\n", "\n", - " # Step 7. Push everything to the Hub\n", - " api.upload_folder(\n", - " repo_id=repo_id,\n", - " folder_path=repo_local_path,\n", - " path_in_repo=\".\",\n", - " )\n", + " # Step 7. Push everything to the Hub\n", + " api.upload_folder(\n", + " repo_id=repo_id,\n", + " folder_path=local_directory,\n", + " path_in_repo=\".\",\n", + " )\n", "\n", - " print(f\"Your model is pushed to the hub. You can view your model here: {repo_url}\")" - ] + " print(f\"Your model is pushed to the Hub. You can view your model here: {repo_url}\")" + ], + "metadata": { + "id": "_TPdq47D7_f_" + }, + "execution_count": null, + "outputs": [] }, { "cell_type": "markdown", @@ -1240,7 +1230,6 @@ }, "outputs": [], "source": [ - "from huggingface_hub import notebook_login\n", "notebook_login()" ] }, @@ -1275,8 +1264,7 @@ " cartpole_policy, # The model we want to save\n", " cartpole_hyperparameters, # Hyperparameters\n", " eval_env, # Evaluation environment\n", - " video_fps=30,\n", - " local_repo_path=\"hub\",\n", + " video_fps=30\n", " )" ] }, @@ -1529,8 +1517,7 @@ " pixelcopter_policy, # The model we want to save\n", " pixelcopter_hyperparameters, # Hyperparameters\n", " eval_env, # Evaluation environment\n", - " video_fps=30,\n", - " local_repo_path=\"hub\",\n", + " video_fps=30\n", " )" ], "metadata": { @@ -1598,7 +1585,6 @@ "JoTC9o2SczNn", "gfGJNZBUP7Vn", "YB0Cxrw1StrP", - "Jmhs1k-cftIq", "47iuAFqV8Ws-", "x62pP0PHdA-y" ],