{"cells": [{"cell_type": "markdown", "id": "ca70018b", "metadata": {"papermill": {"duration": 0.031767, "end_time": "2022-04-28T12:56:14.098871", "exception": false, "start_time": "2022-04-28T12:56:14.067104", "status": "completed"}, "tags": []}, "source": ["\n", "# How to train a Deep Q Network\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2022-04-28T08:05:34.347059\n", "\n", "Main takeaways:\n", "\n", "1. RL has the same flow as previous models we have seen, with a few additions\n", "2. Handle unsupervised learning by using an IterableDataset where the dataset itself is constantly updated during training\n", "3. Each training step has the agent take an action in the environment and store the experience in the IterableDataset\n", "\n", "\n", "---\n", "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/reinforce-learning-DQN.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/stable/)\n", "| Join us [on Slack](https://www.pytorchlightning.ai/community)"]}, {"cell_type": "markdown", "id": "a723e429", "metadata": {"papermill": {"duration": 0.027631, "end_time": "2022-04-28T12:56:14.156552", "exception": false, "start_time": "2022-04-28T12:56:14.128921", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "523f5401", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2022-04-28T12:56:14.219388Z", "iopub.status.busy": "2022-04-28T12:56:14.218850Z", "iopub.status.idle": "2022-04-28T12:56:17.557671Z", "shell.execute_reply": "2022-04-28T12:56:17.558094Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, 
"papermill": {"duration": 3.374029, "end_time": "2022-04-28T12:56:17.558377", "exception": false, "start_time": "2022-04-28T12:56:14.184348", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: You are using pip version 21.3.1; however, version 22.0.4 is available.\r\n", "You should consider upgrading via the '/usr/bin/python3.8 -m pip install --upgrade pip' command.\u001b[0m\r\n"]}], "source": ["! pip install --quiet \"ipython[notebook]\" \"seaborn\" \"torchmetrics>=0.6\" \"pygame\" \"gym\" \"pandas\" \"pytorch-lightning>=1.4\" \"torch>=1.6, <1.9\""]}, {"cell_type": "code", "execution_count": 2, "id": "20ddd89b", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:17.627508Z", "iopub.status.busy": "2022-04-28T12:56:17.626981Z", "iopub.status.idle": "2022-04-28T12:56:19.713569Z", "shell.execute_reply": "2022-04-28T12:56:19.714005Z"}, "papermill": {"duration": 2.123494, "end_time": "2022-04-28T12:56:19.714180", "exception": false, "start_time": "2022-04-28T12:56:17.590686", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "from collections import OrderedDict, deque, namedtuple\n", "from typing import Iterator, List, Tuple\n", "\n", "import gym\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sn\n", "import torch\n", "from IPython.core.display import display\n", "from pytorch_lightning import LightningModule, Trainer\n", "from pytorch_lightning.loggers import CSVLogger\n", "from torch import Tensor, nn\n", "from torch.optim import Adam, Optimizer\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import IterableDataset\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")"]}, {"cell_type": "code", "execution_count": 3, "id": "cdd4342e", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:19.777436Z", "iopub.status.busy": "2022-04-28T12:56:19.776915Z", "iopub.status.idle": 
"2022-04-28T12:56:19.778807Z", "shell.execute_reply": "2022-04-28T12:56:19.779212Z"}, "papermill": {"duration": 0.036088, "end_time": "2022-04-28T12:56:19.779345", "exception": false, "start_time": "2022-04-28T12:56:19.743257", "status": "completed"}, "tags": []}, "outputs": [], "source": [
"class DQN(nn.Module):\n",
"    \"\"\"Simple MLP network mapping an observation to one Q-value per action.\"\"\"\n",
"\n",
"    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):\n",
"        \"\"\"\n",
"        Args:\n",
"            obs_size: observation/state size of the environment\n",
"            n_actions: number of discrete actions available in the environment\n",
"            hidden_size: size of hidden layers\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.net = nn.Sequential(\n",
"            nn.Linear(obs_size, hidden_size),\n",
"            nn.ReLU(),\n",
"            nn.Linear(hidden_size, n_actions),\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return the Q-values for ``x``; the input is cast with ``.float()`` before the first layer.\"\"\"\n",
"        return self.net(x.float())"
]}, {"cell_type": "markdown", "id": "4ad8cc8a", "metadata": {"papermill": {"duration": 0.030859, "end_time": "2022-04-28T12:56:19.839724", "exception": false, "start_time": "2022-04-28T12:56:19.808865", "status": "completed"}, "tags": []}, "source": ["### Memory"]}, {"cell_type": "code", "execution_count": 4, "id": "a25bdee7", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:19.905479Z", "iopub.status.busy": "2022-04-28T12:56:19.904955Z", "iopub.status.idle": "2022-04-28T12:56:19.906808Z", "shell.execute_reply": "2022-04-28T12:56:19.907215Z"}, "papermill": {"duration": 0.037822, "end_time": "2022-04-28T12:56:19.907361", "exception": false, "start_time": "2022-04-28T12:56:19.869539", "status": "completed"}, "tags": []}, "outputs": [], "source": [
"\n",
"# Named tuple for storing experience steps gathered in training\n",
"Experience = namedtuple(\n",
"    \"Experience\",\n",
"    field_names=[\"state\", \"action\", \"reward\", \"done\", \"new_state\"],\n",
")"
]}, {"cell_type": "code", "execution_count": 5, "id": "f58ca859", "metadata": {"execution": {"iopub.execute_input": 
"2022-04-28T12:56:19.970796Z", "iopub.status.busy": "2022-04-28T12:56:19.970276Z", "iopub.status.idle": "2022-04-28T12:56:19.972820Z", "shell.execute_reply": "2022-04-28T12:56:19.972388Z"}, "papermill": {"duration": 0.036848, "end_time": "2022-04-28T12:56:19.972936", "exception": false, "start_time": "2022-04-28T12:56:19.936088", "status": "completed"}, "tags": []}, "outputs": [], "source": [
"class ReplayBuffer:\n",
"    \"\"\"Replay Buffer for storing past experiences allowing the agent to learn from them.\n",
"\n",
"    Args:\n",
"        capacity: size of the buffer\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, capacity: int) -> None:\n",
"        self.buffer = deque(maxlen=capacity)\n",
"\n",
"    def __len__(self) -> int:\n",
"        return len(self.buffer)\n",
"\n",
"    def append(self, experience: Experience) -> None:\n",
"        \"\"\"Add experience to the buffer.\n",
"\n",
"        Args:\n",
"            experience: tuple (state, action, reward, done, new_state)\n",
"        \"\"\"\n",
"        self.buffer.append(experience)\n",
"\n",
"    def sample(self, batch_size: int) -> Tuple:\n",
"        \"\"\"Draw ``batch_size`` distinct experiences uniformly at random.\n",
"\n",
"        Args:\n",
"            batch_size: number of experiences to draw; must not exceed ``len(self)``\n",
"\n",
"        Returns:\n",
"            arrays of (states, actions, rewards, dones, next_states), one row per sampled experience\n",
"        \"\"\"\n",
"        indices = np.random.choice(len(self.buffer), batch_size, replace=False)\n",
"        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))\n",
"\n",
"        return (\n",
"            np.array(states),\n",
"            np.array(actions),\n",
"            np.array(rewards, dtype=np.float32),\n",
"            np.array(dones, dtype=bool),\n",
"            np.array(next_states),\n",
"        )"
]}, {"cell_type": "code", "execution_count": 6, "id": "5c108e7e", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:20.036044Z", "iopub.status.busy": "2022-04-28T12:56:20.033381Z", "iopub.status.idle": "2022-04-28T12:56:20.037860Z", "shell.execute_reply": "2022-04-28T12:56:20.038266Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.035997, "end_time": "2022-04-28T12:56:20.038397", "exception": false, "start_time": "2022-04-28T12:56:20.002400", "status": "completed"}, "tags": []}, "outputs": [], "source": [
"class RLDataset(IterableDataset):\n",
"    \"\"\"Iterable Dataset containing the ReplayBuffer which will be updated with new experiences during training.\n",
"\n",
"    Args:\n",
"        buffer: replay buffer\n",
"        sample_size: number of experiences to sample at a time\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:\n",
"        self.buffer = buffer\n",
"        self.sample_size = sample_size\n",
"\n",
"    def __iter__(self) -> Iterator[Tuple]:\n",
"        # A fresh random sample is drawn from the buffer every time the dataset is iterated\n",
"        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)\n",
"        for i in range(len(dones)):\n",
"            yield states[i], actions[i], rewards[i], dones[i], new_states[i]"
]}, {"cell_type": "markdown", "id": "20779fc6", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.029204, "end_time": "2022-04-28T12:56:20.098456", "exception": false, "start_time": "2022-04-28T12:56:20.069252", "status": "completed"}, "tags": []}, "source": ["### Agent"]}, {"cell_type": "code", "execution_count": 7, "id": "128a705a", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:20.166166Z", "iopub.status.busy": "2022-04-28T12:56:20.159927Z", "iopub.status.idle": "2022-04-28T12:56:20.168030Z", "shell.execute_reply": "2022-04-28T12:56:20.168464Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.040394, "end_time": "2022-04-28T12:56:20.168611", "exception": false, "start_time": "2022-04-28T12:56:20.128217", "status": "completed"}, "tags": []}, "outputs": [], "source": [
"class Agent:\n",
"    \"\"\"Base Agent class handling the interaction with the environment.\"\"\"\n",
"\n",
"    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:\n",
"        \"\"\"\n",
"        Args:\n",
"            env: training environment\n",
"            replay_buffer: replay buffer storing experiences\n",
"        \"\"\"\n",
"        self.env = env\n",
"        self.replay_buffer = replay_buffer\n",
"        # reset() stores the initial observation in self.state, so one environment reset is enough\n",
"        self.reset()\n",
"\n",
"    def reset(self) -> None:\n",
"        \"\"\"Resets the environment and updates the state.\"\"\"\n",
"        self.state = self.env.reset()\n",
"\n",
"    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:\n",
"        \"\"\"Using the given network, decide what action to carry out using an epsilon-greedy policy.\n",
"\n",
"        Args:\n",
"            net: DQN network\n",
"            epsilon: value to determine likelihood of taking a random action\n",
"            device: \"cpu\", or the CUDA device the state tensor is moved to before the forward pass\n",
"\n",
"        Returns:\n",
"            action\n",
"        \"\"\"\n",
"        if np.random.random() < epsilon:\n",
"            action = self.env.action_space.sample()\n",
"        else:\n",
"            state = torch.tensor([self.state])\n",
"\n",
"            if device not in [\"cpu\"]:\n",
"                state = state.cuda(device)\n",
"\n",
"            q_values = net(state)\n",
"            _, action = torch.max(q_values, dim=1)\n",
"            action = int(action.item())\n",
"\n",
"        return action\n",
"\n",
"    @torch.no_grad()\n",
"    def play_step(\n",
"        self,\n",
"        net: nn.Module,\n",
"        epsilon: float = 0.0,\n",
"        device: str = \"cpu\",\n",
"    ) -> Tuple[float, bool]:\n",
"        \"\"\"Carries out a single interaction step between the agent and the environment.\n",
"\n",
"        The resulting transition is appended to the replay buffer, and the environment is reset when the\n",
"        episode ends.\n",
"\n",
"        Args:\n",
"            net: DQN network\n",
"            epsilon: value to determine likelihood of taking a random action\n",
"            device: current device\n",
"\n",
"        Returns:\n",
"            reward, done\n",
"        \"\"\"\n",
"\n",
"        action = self.get_action(net, epsilon, device)\n",
"\n",
"        # do step in the environment\n",
"        new_state, reward, done, _ = self.env.step(action)\n",
"\n",
"        exp = Experience(self.state, action, reward, done, new_state)\n",
"\n",
"        self.replay_buffer.append(exp)\n",
"\n",
"        self.state = new_state\n",
"        if done:\n",
"            self.reset()\n",
"        return reward, done"
]}, {"cell_type": "markdown", "id": "0526a1ef", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.028897, "end_time": "2022-04-28T12:56:20.226764", "exception": false, "start_time": "2022-04-28T12:56:20.197867", "status": "completed"}, "tags": []}, "source": ["### DQN Lightning Module"]}, {"cell_type": "code", "execution_count": 8, "id": "d77076c3", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:20.301345Z", 
"iopub.status.busy": "2022-04-28T12:56:20.292991Z", "iopub.status.idle": "2022-04-28T12:56:20.303346Z", "shell.execute_reply": "2022-04-28T12:56:20.303755Z"}, "papermill": {"duration": 0.047912, "end_time": "2022-04-28T12:56:20.303897", "exception": false, "start_time": "2022-04-28T12:56:20.255985", "status": "completed"}, "tags": []}, "outputs": [], "source": [
"class DQNLightning(LightningModule):\n",
"    \"\"\"Basic DQN Model.\"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        batch_size: int = 16,\n",
"        lr: float = 1e-2,\n",
"        env: str = \"CartPole-v0\",\n",
"        gamma: float = 0.99,\n",
"        sync_rate: int = 10,\n",
"        replay_size: int = 1000,\n",
"        warm_start_size: int = 1000,\n",
"        eps_last_frame: int = 1000,\n",
"        eps_start: float = 1.0,\n",
"        eps_end: float = 0.01,\n",
"        episode_length: int = 200,\n",
"        warm_start_steps: int = 1000,\n",
"    ) -> None:\n",
"        \"\"\"\n",
"        Args:\n",
"            batch_size: size of the batches\n",
"            lr: learning rate\n",
"            env: gym environment tag\n",
"            gamma: discount factor\n",
"            sync_rate: number of training steps between copies of the online network into the target network\n",
"            replay_size: capacity of the replay buffer\n",
"            warm_start_size: not read by this module; kept so the constructor signature stays unchanged\n",
"            eps_last_frame: training step at which epsilon stops decaying\n",
"            eps_start: starting value of epsilon\n",
"            eps_end: final value of epsilon\n",
"            episode_length: number of experiences the dataset samples from the replay buffer per pass\n",
"            warm_start_steps: number of random environment steps used to pre-fill the replay buffer\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.save_hyperparameters()\n",
"\n",
"        self.env = gym.make(self.hparams.env)\n",
"        obs_size = self.env.observation_space.shape[0]\n",
"        n_actions = self.env.action_space.n\n",
"\n",
"        self.net = DQN(obs_size, n_actions)\n",
"        self.target_net = DQN(obs_size, n_actions)\n",
"\n",
"        self.buffer = ReplayBuffer(self.hparams.replay_size)\n",
"        self.agent = Agent(self.env, self.buffer)\n",
"        # Kept as floats: `self.log` warns and converts when it is handed an integer value\n",
"        self.total_reward = 0.0\n",
"        self.episode_reward = 0.0\n",
"        self.populate(self.hparams.warm_start_steps)\n",
"\n",
"    def populate(self, steps: int = 1000) -> None:\n",
"        \"\"\"Carries out several random steps through the environment to initially fill up the replay buffer with\n",
"        experiences.\n",
"\n",
"        Args:\n",
"            steps: number of random steps to populate the buffer with\n",
"        \"\"\"\n",
"        for _ in range(steps):\n",
"            # epsilon=1.0 makes every action random\n",
"            self.agent.play_step(self.net, epsilon=1.0)\n",
"\n",
"    def forward(self, x: Tensor) -> Tensor:\n",
"        \"\"\"Passes in a state x through the network and gets the q_values of each action as an output.\n",
"\n",
"        Args:\n",
"            x: environment state\n",
"\n",
"        Returns:\n",
"            q values\n",
"        \"\"\"\n",
"        output = self.net(x)\n",
"        return output\n",
"\n",
"    def dqn_mse_loss(self, batch: Tuple[Tensor, ...]) -> Tensor:\n",
"        \"\"\"Calculates the mse loss using a mini batch from the replay buffer.\n",
"\n",
"        Args:\n",
"            batch: current mini batch of replay data as (states, actions, rewards, dones, next_states)\n",
"\n",
"        Returns:\n",
"            loss\n",
"        \"\"\"\n",
"        states, actions, rewards, dones, next_states = batch\n",
"\n",
"        # Q-value the online network assigns to the action that was actually taken\n",
"        state_action_values = self.net(states).gather(1, actions.long().unsqueeze(-1)).squeeze(-1)\n",
"\n",
"        with torch.no_grad():\n",
"            next_state_values = self.target_net(next_states).max(1)[0]\n",
"            # finished episodes contribute no future value\n",
"            next_state_values[dones] = 0.0\n",
"            next_state_values = next_state_values.detach()\n",
"\n",
"        expected_state_action_values = next_state_values * self.hparams.gamma + rewards\n",
"\n",
"        return nn.MSELoss()(state_action_values, expected_state_action_values)\n",
"\n",
"    def get_epsilon(self, start: float, end: float, frames: int) -> float:\n",
"        \"\"\"Linearly interpolates epsilon from ``start`` to ``end`` over the first ``frames`` training steps.\n",
"\n",
"        Args:\n",
"            start: epsilon at step 0\n",
"            end: epsilon returned once ``frames`` steps have been taken\n",
"            frames: number of steps over which epsilon decays\n",
"\n",
"        Returns:\n",
"            epsilon for the current global step\n",
"        \"\"\"\n",
"        # `>=` returns `end` exactly at the last frame and avoids a division by zero when frames == 0\n",
"        if self.global_step >= frames:\n",
"            return end\n",
"        return start - (self.global_step / frames) * (start - end)\n",
"\n",
"    def training_step(self, batch: Tuple[Tensor, ...], nb_batch) -> Tensor:\n",
"        \"\"\"Carries out a single step through the environment to update the replay buffer. Then calculates loss\n",
"        based on the minibatch received.\n",
"\n",
"        Args:\n",
"            batch: current mini batch of replay data\n",
"            nb_batch: batch number\n",
"\n",
"        Returns:\n",
"            Training loss\n",
"        \"\"\"\n",
"        device = self.get_device(batch)\n",
"        epsilon = self.get_epsilon(self.hparams.eps_start, self.hparams.eps_end, self.hparams.eps_last_frame)\n",
"        self.log(\"epsilon\", epsilon)\n",
"\n",
"        # step through environment with agent\n",
"        reward, done = self.agent.play_step(self.net, epsilon, device)\n",
"        self.episode_reward += reward\n",
"        self.log(\"episode reward\", self.episode_reward)\n",
"\n",
"        # calculates training loss\n",
"        loss = self.dqn_mse_loss(batch)\n",
"\n",
"        if done:\n",
"            self.total_reward = self.episode_reward\n",
"            self.episode_reward = 0.0\n",
"\n",
"        # Hard update: copy the online network weights into the target network every `sync_rate` steps\n",
"        if self.global_step % self.hparams.sync_rate == 0:\n",
"            self.target_net.load_state_dict(self.net.state_dict())\n",
"\n",
"        self.log_dict(\n",
"            {\n",
"                \"reward\": reward,\n",
"                \"train_loss\": loss,\n",
"            }\n",
"        )\n",
"        self.log(\"total_reward\", self.total_reward, prog_bar=True)\n",
"        # cast to float so `self.log` does not warn about an integer value\n",
"        self.log(\"steps\", float(self.global_step), logger=False, prog_bar=True)\n",
"\n",
"        return loss\n",
"\n",
"    def configure_optimizers(self) -> Optimizer:\n",
"        \"\"\"Initialize Adam optimizer for the online network only.\"\"\"\n",
"        optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)\n",
"        return optimizer\n",
"\n",
"    def __dataloader(self) -> DataLoader:\n",
"        \"\"\"Initialize the Replay Buffer dataset used for retrieving experiences.\"\"\"\n",
"        dataset = RLDataset(self.buffer, self.hparams.episode_length)\n",
"        dataloader = DataLoader(\n",
"            dataset=dataset,\n",
"            batch_size=self.hparams.batch_size,\n",
"        )\n",
"        return dataloader\n",
"\n",
"    def train_dataloader(self) -> DataLoader:\n",
"        \"\"\"Get train loader.\"\"\"\n",
"        return self.__dataloader()\n",
"\n",
"    def get_device(self, batch) -> str:\n",
"        \"\"\"Retrieve device currently being used by minibatch.\n",
"\n",
"        Returns:\n",
"            the CUDA device index of the first tensor in the batch when running on GPU, otherwise \"cpu\"\n",
"        \"\"\"\n",
"        return batch[0].device.index if self.on_gpu else \"cpu\""
]}, {"cell_type": "markdown", "id": "5d33f5a0", "metadata": {"papermill": {"duration": 0.029164, "end_time": "2022-04-28T12:56:20.362319", "exception": false, "start_time": "2022-04-28T12:56:20.333155", "status": "completed"}, "tags": []}, "source": ["### Trainer"]}, {"cell_type": "code", "execution_count": 9, "id": "e026a5a3", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:20.427732Z", "iopub.status.busy": "2022-04-28T12:56:20.425538Z", "iopub.status.idle": "2022-04-28T12:56:38.761133Z", "shell.execute_reply": "2022-04-28T12:56:38.760681Z"}, "papermill": {"duration": 18.36895, "end_time": "2022-04-28T12:56:38.761267", "exception": false, "start_time": "2022-04-28T12:56:20.392317", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True, used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "------------------------------------\n", "0 | net | DQN | 898 \n", "1 | target_net | DQN | 898 \n", "------------------------------------\n", "1.8 K Trainable params\n", "0 Non-trainable params\n", "1.8 K Total params\n", "0.007 Total estimated model params size (MB)\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:240: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "dbe8a00c1de94b83b3013847ea2f5940", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: 0it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('total_reward', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.\n", " warning_cache.warn(\n", "/home/AzDevOps_azpcontainer/.local/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:229: UserWarning: You called `self.log('steps', ...)` in your `training_step` but the value needs to be floating point. 
Converting it to torch.float32.\n", " warning_cache.warn(\n"]}], "source": ["\n", "model = DQNLightning()\n", "\n", "trainer = Trainer(\n", " accelerator=\"auto\",\n", " devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs\n", " max_epochs=150,\n", " val_check_interval=50,\n", " logger=CSVLogger(save_dir=\"logs/\"),\n", ")\n", "\n", "trainer.fit(model)"]}, {"cell_type": "code", "execution_count": 10, "id": "ec06ad77", "metadata": {"execution": {"iopub.execute_input": "2022-04-28T12:56:38.838651Z", "iopub.status.busy": "2022-04-28T12:56:38.838157Z", "iopub.status.idle": "2022-04-28T12:56:39.245747Z", "shell.execute_reply": "2022-04-28T12:56:39.246149Z"}, "papermill": {"duration": 0.448227, "end_time": "2022-04-28T12:56:39.246313", "exception": false, "start_time": "2022-04-28T12:56:38.798086", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>epsilon</th>\n", " <th>episode reward</th>\n", " <th>reward</th>\n", " <th>train_loss</th>\n", " <th>total_reward</th>\n", " </tr>\n", " <tr>\n", " <th>epoch</th>\n", " <th></th>\n", " <th></th>\n", " <th></th>\n", " <th></th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>3</th>\n", " <td>0.95149</td>\n", " <td>5.0</td>\n", " <td>1.0</td>\n", " <td>0.189056</td>\n", " <td>22.0</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>0.90199</td>\n", " <td>15.0</td>\n", " <td>1.0</td>\n", " <td>1.432721</td>\n", " <td>12.0</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>0.85249</td>\n", " <td>18.0</td>\n", " <td>1.0</td>\n", " <td>30.838800</td>\n", " 
<td>14.0</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>0.80299</td>\n", " <td>68.0</td>\n", " <td>1.0</td>\n", " <td>3.394485</td>\n", " <td>14.0</td>\n", " </tr>\n", " <tr>\n", " <th>19</th>\n", " <td>0.75349</td>\n", " <td>21.0</td>\n", " <td>1.0</td>\n", " <td>18.886366</td>\n", " <td>15.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>"], "text/plain": [" epsilon episode reward reward train_loss total_reward\n", "epoch \n", "3 0.95149 5.0 1.0 0.189056 22.0\n", "7 0.90199 15.0 1.0 1.432721 12.0\n", "11 0.85249 18.0 1.0 30.838800 14.0\n", "15 0.80299 68.0 1.0 3.394485 14.0\n", "19 0.75349 21.0 1.0 18.886366 15.0"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"text/plain": ["<seaborn.axisgrid.FacetGrid at 0x7f02190640d0>"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}, {"data": {"image/png": "\n", "text/plain": ["<Figure size 472.75x360 with 1 Axes>"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["\n", "metrics = pd.read_csv(f\"{trainer.logger.log_dir}/metrics.csv\")\n", "del metrics[\"step\"]\n", "metrics.set_index(\"epoch\", inplace=True)\n", "display(metrics.dropna(axis=1, how=\"all\").head())\n", "sn.relplot(data=metrics, kind=\"line\")"]}, {"cell_type": "markdown", "id": "a5152630", "metadata": {"papermill": {"duration": 0.040482, "end_time": "2022-04-28T12:56:39.328199", "exception": false, "start_time": "2022-04-28T12:56:39.287717", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! 
This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://www.pytorchlightning.ai/community)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in `#general` channel\n", "\n", "\n", "### Contributions !\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to\n", "[Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples !\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", "\n", "[{height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", " :header: How to train a Deep Q Network\n", " :card_description: Main takeaways: 1. RL has the same flow as previous models we have seen, with a few additions 2. 
Handle unsupervised learning by using an IterableDataset where the dataset...\n", " :tags: RL,GPU/TPU,Lightning-Examples"]}], "metadata": {"jupytext": {"cell_metadata_filter": "colab_type,id,colab,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12"}, "papermill": {"default_parameters": {}, "duration": 27.646782, "end_time": "2022-04-28T12:56:40.177952", "environment_variables": {}, "exception": null, "input_path": "lightning_examples/reinforce-learning-DQN/dqn.ipynb", "output_path": ".notebooks/lightning_examples/reinforce-learning-DQN.ipynb", "parameters": {}, "start_time": "2022-04-28T12:56:12.531170", "version": "2.3.4"}, "widgets": {"application/vnd.jupyter.widget-state+json": {"state": {"05064253be494384a9d635234217b5dd": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": ""}}, "1c02c26a2a414891b8fee0a16336d57f": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": "2", "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "1f07ee7908544bdb8bec422b9d950040": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_4621f6ebf4cf40289ae7e035b9adb3fd", "placeholder": "\u200b", "style": "IPY_MODEL_2d3298321de24088b41bb438812b2795", "value": "Epoch 149: "}}, "22c787b404204893a82471f1a84a5d74": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": "inline-flex", "flex": null, "flex_flow": "row wrap", "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": 
null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": "100%"}}, "2d3298321de24088b41bb438812b2795": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": ""}}, "4621f6ebf4cf40289ae7e035b9adb3fd": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "566919922f744e4886abceb34314e954": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": 
"ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "813f369b769448a689615dd47cc401fd": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_1c02c26a2a414891b8fee0a16336d57f", "max": 1.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_566919922f744e4886abceb34314e954", "value": 1.0}}, "9d315e5cd5694a43b579f0ef98d98f84": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_b9552087e7f945cabf53728d7bec0c6b", "placeholder": "\u200b", "style": "IPY_MODEL_05064253be494384a9d635234217b5dd", "value": " 13/? 
[00:15<00:00, 1.16s/it, loss=10.2, v_num=4, total_reward=144.0, steps=1949.0]"}}, "b9552087e7f945cabf53728d7bec0c6b": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "dbe8a00c1de94b83b3013847ea2f5940": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_1f07ee7908544bdb8bec422b9d950040", "IPY_MODEL_813f369b769448a689615dd47cc401fd", "IPY_MODEL_9d315e5cd5694a43b579f0ef98d98f84"], "layout": "IPY_MODEL_22c787b404204893a82471f1a84a5d74"}}}, "version_major": 2, "version_minor": 0}}}, "nbformat": 4, "nbformat_minor": 5}