{"cells": [{"cell_type": "markdown", "id": "ef8a43a5", "metadata": {"papermill": {"duration": 0.039794, "end_time": "2021-08-31T12:37:17.276676", "exception": false, "start_time": "2021-08-31T12:37:17.236882", "status": "completed"}, "tags": []}, "source": ["\n", "# How to train a Deep Q Network\n", "\n", "* **Author:** PL team\n", "* **License:** CC BY-SA\n", "* **Generated:** 2021-08-31T13:56:11.349578\n", "\n", "Main takeaways:\n", "\n", "1. RL has the same flow as previous models we have seen, with a few additions\n", "2. Handle unsupervised learning by using an IterableDataset where the dataset itself is constantly updated during training\n", "3. Each training step carries has the agent taking an action in the environment and storing the experience in the IterableDataset\n", "\n", "\n", "---\n", "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/reinforce-learning-DQN.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", "| Check out [the documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", "| Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)"]}, {"cell_type": "markdown", "id": "a4f9cec5", "metadata": {"papermill": {"duration": 0.007549, "end_time": "2021-08-31T12:37:17.292223", "exception": false, "start_time": "2021-08-31T12:37:17.284674", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "f96d7041", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2021-08-31T12:37:17.314301Z", "iopub.status.busy": "2021-08-31T12:37:17.313817Z", "iopub.status.idle": "2021-08-31T12:37:19.681259Z", "shell.execute_reply": "2021-08-31T12:37:19.680722Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 2.381402, "end_time": "2021-08-31T12:37:19.681392", "exception": false, "start_time": "2021-08-31T12:37:17.299990", "status": "completed"}, "tags": []}, "outputs": [], "source": ["! 
pip install --quiet \"torchmetrics>=0.3\" \"torch>=1.6, <1.9\" \"pytorch-lightning>=1.3\" \"gym\""]}, {"cell_type": "code", "execution_count": 2, "id": "8447fac2", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:19.702696Z", "iopub.status.busy": "2021-08-31T12:37:19.702216Z", "iopub.status.idle": "2021-08-31T12:37:22.180422Z", "shell.execute_reply": "2021-08-31T12:37:22.179976Z"}, "papermill": {"duration": 2.4905, "end_time": "2021-08-31T12:37:22.180549", "exception": false, "start_time": "2021-08-31T12:37:19.690049", "status": "completed"}, "tags": []}, "outputs": [], "source": ["import os\n", "from collections import OrderedDict, deque, namedtuple\n", "from typing import List, Tuple\n", "\n", "import gym\n", "import numpy as np\n", "import torch\n", "from pytorch_lightning import LightningModule, Trainer\n", "from pytorch_lightning.utilities import DistributedType\n", "from torch import Tensor, nn\n", "from torch.optim import Adam, Optimizer\n", "from torch.utils.data import DataLoader\n", "from torch.utils.data.dataset import IterableDataset\n", "\n", "PATH_DATASETS = os.environ.get(\"PATH_DATASETS\", \".\")\n", "AVAIL_GPUS = min(1, torch.cuda.device_count())"]}, {"cell_type": "code", "execution_count": 3, "id": "eb2c666d", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.201758Z", "iopub.status.busy": "2021-08-31T12:37:22.201258Z", "iopub.status.idle": "2021-08-31T12:37:22.203391Z", "shell.execute_reply": "2021-08-31T12:37:22.202923Z"}, "papermill": {"duration": 0.014583, "end_time": "2021-08-31T12:37:22.203491", "exception": false, "start_time": "2021-08-31T12:37:22.188908", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class DQN(nn.Module):\n", " \"\"\"Simple MLP network.\"\"\"\n", "\n", " def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):\n", " \"\"\"\n", " Args:\n", " obs_size: observation/state size of the environment\n", " n_actions: number of discrete actions available in the environment\n", " hidden_size: size of hidden layers\n", " \"\"\"\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(obs_size, hidden_size),\n", " nn.ReLU(),\n", " nn.Linear(hidden_size, n_actions),\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x.float())"]}, {"cell_type": "markdown", "id": "395cefea", "metadata": {"papermill": {"duration": 0.007826, "end_time": "2021-08-31T12:37:22.219352", "exception": false, "start_time": "2021-08-31T12:37:22.211526", "status": "completed"}, "tags": []}, "source": ["### Memory"]}, {"cell_type": "code", "execution_count": 4, "id": "8491d2e4", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.239138Z", "iopub.status.busy": "2021-08-31T12:37:22.238662Z", "iopub.status.idle": "2021-08-31T12:37:22.240820Z", "shell.execute_reply": "2021-08-31T12:37:22.240330Z"}, "papermill": {"duration": 0.013722, "end_time": "2021-08-31T12:37:22.240922", "exception": false, "start_time": "2021-08-31T12:37:22.227200", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "# Named tuple for storing experience steps gathered in training\n", "Experience = namedtuple(\n", " \"Experience\",\n", " field_names=[\"state\", \"action\", \"reward\", \"done\", \"new_state\"],\n", ")"]}, {"cell_type": "code", "execution_count": 5, "id": "dba6653a", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.263599Z", "iopub.status.busy": "2021-08-31T12:37:22.263108Z", "iopub.status.idle": "2021-08-31T12:37:22.265253Z", 
"shell.execute_reply": "2021-08-31T12:37:22.264753Z"}, "papermill": {"duration": 0.016357, "end_time": "2021-08-31T12:37:22.265353", "exception": false, "start_time": "2021-08-31T12:37:22.248996", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class ReplayBuffer:\n", " \"\"\"Replay Buffer for storing past experiences allowing the agent to learn from them.\n", "\n", " Args:\n", " capacity: size of the buffer\n", " \"\"\"\n", "\n", " def __init__(self, capacity: int) -> None:\n", " self.buffer = deque(maxlen=capacity)\n", "\n", " def __len__(self) -> None:\n", " return len(self.buffer)\n", "\n", " def append(self, experience: Experience) -> None:\n", " \"\"\"Add experience to the buffer.\n", "\n", " Args:\n", " experience: tuple (state, action, reward, done, new_state)\n", " \"\"\"\n", " self.buffer.append(experience)\n", "\n", " def sample(self, batch_size: int) -> Tuple:\n", " indices = np.random.choice(len(self.buffer), batch_size, replace=False)\n", " states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))\n", "\n", " return (\n", " np.array(states),\n", " np.array(actions),\n", " np.array(rewards, dtype=np.float32),\n", " np.array(dones, dtype=np.bool),\n", " np.array(next_states),\n", " )"]}, {"cell_type": "code", "execution_count": 6, "id": "d4d0109b", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.285666Z", "iopub.status.busy": "2021-08-31T12:37:22.285190Z", "iopub.status.idle": "2021-08-31T12:37:22.286884Z", "shell.execute_reply": "2021-08-31T12:37:22.287257Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.014122, "end_time": "2021-08-31T12:37:22.287372", "exception": false, "start_time": "2021-08-31T12:37:22.273250", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class RLDataset(IterableDataset):\n", " \"\"\"Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training.\n", "\n", " Args:\n", " buffer: replay buffer\n", " sample_size: number of experiences to sample at a time\n", " \"\"\"\n", "\n", " def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:\n", " self.buffer = buffer\n", " self.sample_size = sample_size\n", "\n", " def __iter__(self) -> Tuple:\n", " states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)\n", " for i in range(len(dones)):\n", " yield states[i], actions[i], rewards[i], dones[i], new_states[i]"]}, {"cell_type": "markdown", "id": "a0b14f16", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007984, "end_time": "2021-08-31T12:37:22.303700", "exception": false, "start_time": "2021-08-31T12:37:22.295716", "status": "completed"}, "tags": []}, "source": ["### Agent"]}, {"cell_type": "code", "execution_count": 7, "id": "fe9a42c1", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.328705Z", "iopub.status.busy": "2021-08-31T12:37:22.328223Z", "iopub.status.idle": "2021-08-31T12:37:22.330404Z", "shell.execute_reply": "2021-08-31T12:37:22.329938Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.018735, "end_time": "2021-08-31T12:37:22.330502", "exception": false, "start_time": "2021-08-31T12:37:22.311767", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class Agent:\n", " \"\"\"Base Agent class handeling the interaction with the environment.\"\"\"\n", "\n", " def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:\n", " \"\"\"\n", " Args:\n", " env: training environment\n", " replay_buffer: replay 
buffer storing experiences\n", " \"\"\"\n", " self.env = env\n", " self.replay_buffer = replay_buffer\n", " self.reset()\n", " self.state = self.env.reset()\n", "\n", " def reset(self) -> None:\n", " \"\"\"Resents the environment and updates the state.\"\"\"\n", " self.state = self.env.reset()\n", "\n", " def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:\n", " \"\"\"Using the given network, decide what action to carry out using an epsilon-greedy policy.\n", "\n", " Args:\n", " net: DQN network\n", " epsilon: value to determine likelihood of taking a random action\n", " device: current device\n", "\n", " Returns:\n", " action\n", " \"\"\"\n", " if np.random.random() < epsilon:\n", " action = self.env.action_space.sample()\n", " else:\n", " state = torch.tensor([self.state])\n", "\n", " if device not in [\"cpu\"]:\n", " state = state.cuda(device)\n", "\n", " q_values = net(state)\n", " _, action = torch.max(q_values, dim=1)\n", " action = int(action.item())\n", "\n", " return action\n", "\n", " @torch.no_grad()\n", " def play_step(\n", " self,\n", " net: nn.Module,\n", " epsilon: float = 0.0,\n", " device: str = \"cpu\",\n", " ) -> Tuple[float, bool]:\n", " \"\"\"Carries out a single interaction step between the agent and the environment.\n", "\n", " Args:\n", " net: DQN network\n", " epsilon: value to determine likelihood of taking a random action\n", " device: current device\n", "\n", " Returns:\n", " reward, done\n", " \"\"\"\n", "\n", " action = self.get_action(net, epsilon, device)\n", "\n", " # do step in the environment\n", " new_state, reward, done, _ = self.env.step(action)\n", "\n", " exp = Experience(self.state, action, reward, done, new_state)\n", "\n", " self.replay_buffer.append(exp)\n", "\n", " self.state = new_state\n", " if done:\n", " self.reset()\n", " return reward, done"]}, {"cell_type": "markdown", "id": "87c06ab2", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.008003, "end_time": "2021-08-31T12:37:22.346701", "exception": false, "start_time": "2021-08-31T12:37:22.338698", "status": "completed"}, "tags": []}, "source": ["### DQN Lightning Module"]}, {"cell_type": "code", "execution_count": 8, "id": "2674d9e4", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.378167Z", "iopub.status.busy": "2021-08-31T12:37:22.369838Z", "iopub.status.idle": "2021-08-31T12:37:22.379784Z", "shell.execute_reply": "2021-08-31T12:37:22.380160Z"}, "papermill": {"duration": 0.025499, "end_time": "2021-08-31T12:37:22.380273", "exception": false, "start_time": "2021-08-31T12:37:22.354774", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class DQNLightning(LightningModule):\n", " \"\"\"Basic DQN Model.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " batch_size: int = 16,\n", " lr: float = 1e-2,\n", " env: str = \"CartPole-v0\",\n", " gamma: float = 0.99,\n", " sync_rate: int = 10,\n", " replay_size: int = 1000,\n", " warm_start_size: int = 1000,\n", " eps_last_frame: int = 1000,\n", " eps_start: float = 1.0,\n", " eps_end: float = 0.01,\n", " episode_length: int = 200,\n", " warm_start_steps: int = 1000,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " batch_size: size of the batches\")\n", " lr: learning rate\n", " env: gym environment tag\n", " gamma: discount factor\n", " sync_rate: how many frames do we update the target network\n", " replay_size: capacity of the replay buffer\n", " warm_start_size: how many samples do we use to fill our buffer at the start of training\n", " eps_last_frame: what frame should 
epsilon stop decaying\n", " eps_start: starting value of epsilon\n", " eps_end: final value of epsilon\n", " episode_length: max length of an episode\n", " warm_start_steps: max episode reward in the environment\n", " \"\"\"\n", " super().__init__()\n", " self.save_hyperparameters()\n", "\n", " self.env = gym.make(self.hparams.env)\n", " obs_size = self.env.observation_space.shape[0]\n", " n_actions = self.env.action_space.n\n", "\n", " self.net = DQN(obs_size, n_actions)\n", " self.target_net = DQN(obs_size, n_actions)\n", "\n", " self.buffer = ReplayBuffer(self.hparams.replay_size)\n", " self.agent = Agent(self.env, self.buffer)\n", " self.total_reward = 0\n", " self.episode_reward = 0\n", " self.populate(self.hparams.warm_start_steps)\n", "\n", " def populate(self, steps: int = 1000) -> None:\n", " \"\"\"Carries out several random steps through the environment to initially fill up the replay buffer with\n", " experiences.\n", "\n", " Args:\n", " steps: number of random steps to populate the buffer with\n", " \"\"\"\n", " for i in range(steps):\n", " self.agent.play_step(self.net, epsilon=1.0)\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"Passes in a state x through the network and gets the q_values of each action as an output.\n", "\n", " Args:\n", " x: environment state\n", "\n", " Returns:\n", " q values\n", " \"\"\"\n", " output = self.net(x)\n", " return output\n", "\n", " def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:\n", " \"\"\"Calculates the mse loss using a mini batch from the replay buffer.\n", "\n", " Args:\n", " batch: current mini batch of replay data\n", "\n", " Returns:\n", " loss\n", " \"\"\"\n", " states, actions, rewards, dones, next_states = batch\n", "\n", " state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)\n", "\n", " with torch.no_grad():\n", " next_state_values = self.target_net(next_states).max(1)[0]\n", " next_state_values[dones] = 0.0\n", " next_state_values = next_state_values.detach()\n", "\n", " expected_state_action_values = next_state_values * self.hparams.gamma + rewards\n", "\n", " return nn.MSELoss()(state_action_values, expected_state_action_values)\n", "\n", " def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:\n", " \"\"\"Carries out a single step through the environment to update the replay buffer. 
Then calculates loss\n", " based on the minibatch recieved.\n", "\n", " Args:\n", " batch: current mini batch of replay data\n", " nb_batch: batch number\n", "\n", " Returns:\n", " Training loss and log metrics\n", " \"\"\"\n", " device = self.get_device(batch)\n", " epsilon = max(\n", " self.hparams.eps_end,\n", " self.hparams.eps_start - self.global_step + 1 / self.hparams.eps_last_frame,\n", " )\n", "\n", " # step through environment with agent\n", " reward, done = self.agent.play_step(self.net, epsilon, device)\n", " self.episode_reward += reward\n", "\n", " # calculates training loss\n", " loss = self.dqn_mse_loss(batch)\n", "\n", " if self.trainer._distrib_type in {DistributedType.DP, DistributedType.DDP2}:\n", " loss = loss.unsqueeze(0)\n", "\n", " if done:\n", " self.total_reward = self.episode_reward\n", " self.episode_reward = 0\n", "\n", " # Soft update of target network\n", " if self.global_step % self.hparams.sync_rate == 0:\n", " self.target_net.load_state_dict(self.net.state_dict())\n", "\n", " log = {\n", " \"total_reward\": torch.tensor(self.total_reward).to(device),\n", " \"reward\": torch.tensor(reward).to(device),\n", " \"train_loss\": loss,\n", " }\n", " status = {\n", " \"steps\": torch.tensor(self.global_step).to(device),\n", " \"total_reward\": torch.tensor(self.total_reward).to(device),\n", " }\n", "\n", " return OrderedDict({\"loss\": loss, \"log\": log, \"progress_bar\": status})\n", "\n", " def configure_optimizers(self) -> List[Optimizer]:\n", " \"\"\"Initialize Adam optimizer.\"\"\"\n", " optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)\n", " return [optimizer]\n", "\n", " def __dataloader(self) -> DataLoader:\n", " \"\"\"Initialize the Replay Buffer dataset used for retrieving experiences.\"\"\"\n", " dataset = RLDataset(self.buffer, self.hparams.episode_length)\n", " dataloader = DataLoader(\n", " dataset=dataset,\n", " batch_size=self.hparams.batch_size,\n", " )\n", " return dataloader\n", "\n", " def train_dataloader(self) -> DataLoader:\n", " \"\"\"Get train loader.\"\"\"\n", " return self.__dataloader()\n", "\n", " def get_device(self, batch) -> str:\n", " \"\"\"Retrieve device currently being used by minibatch.\"\"\"\n", " return batch[0].device.index if self.on_gpu else \"cpu\""]}, {"cell_type": "markdown", "id": "34468e32", "metadata": {"papermill": {"duration": 0.008184, "end_time": "2021-08-31T12:37:22.396696", "exception": false, "start_time": "2021-08-31T12:37:22.388512", "status": "completed"}, "tags": []}, "source": ["### Trainer"]}, {"cell_type": "code", "execution_count": 9, "id": "d99173d5", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:22.416827Z", "iopub.status.busy": "2021-08-31T12:37:22.416355Z", "iopub.status.idle": "2021-08-31T12:37:37.151109Z", "shell.execute_reply": "2021-08-31T12:37:37.151493Z"}, "papermill": {"duration": 14.746762, "end_time": "2021-08-31T12:37:37.151642", "exception": false, "start_time": "2021-08-31T12:37:22.404880", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True, used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["IPU available: False, using: 0 IPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params\n", "------------------------------------\n", "0 
| net | DQN | 898 \n", "1 | target_net | DQN | 898 \n", "------------------------------------\n", "1.8 K Trainable params\n", "0 Non-trainable params\n", "1.8 K Total params\n", "0.007 Total estimated model params size (MB)\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:105: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "f962e76bc6bd4477847f5dedbf4d5f6d", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: -1it [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/tmp/ipykernel_13751/3638216480.py:30: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " np.array(dones, dtype=np.bool),\n", "/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:397: LightningDeprecationWarning: One of the returned values {'progress_bar', 'log'} has a `grad_fn`. We will detach it automatically but this behaviour will change in v1.6. Please detach it manually: `return {'loss': ..., 'something': something.detach()}`\n", " warning_cache.deprecation(\n"]}], "source": ["\n", "model = DQNLightning()\n", "\n", "trainer = Trainer(\n", " gpus=AVAIL_GPUS,\n", " max_epochs=200,\n", " val_check_interval=100,\n", ")\n", "\n", "trainer.fit(model)"]}, {"cell_type": "code", "execution_count": 10, "id": "4d7a7f5b", "metadata": {"execution": {"iopub.execute_input": "2021-08-31T12:37:37.176779Z", "iopub.status.busy": "2021-08-31T12:37:37.176322Z", "iopub.status.idle": "2021-08-31T12:37:38.234163Z", "shell.execute_reply": "2021-08-31T12:37:38.234540Z"}, "papermill": {"duration": 1.071924, "end_time": "2021-08-31T12:37:38.234673", "exception": false, "start_time": "2021-08-31T12:37:37.162749", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["\n", " \n", " \n", " "], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["# Start tensorboard.\n", "%load_ext tensorboard\n", "%tensorboard --logdir lightning_logs/"]}, {"cell_type": "markdown", "id": "f022a7f9", "metadata": {"papermill": {"duration": 0.010581, "end_time": "2021-08-31T12:37:38.256184", "exception": false, "start_time": "2021-08-31T12:37:38.245603", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! 
This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in `#general` channel\n", "\n", "\n", "### Contributions !\n", "The best way to contribute to our community is to become a code contributor! At any time you can go to\n", "[Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples !\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", "\n", "{height=\"60px\" width=\"240px\"}"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", " :header: How to train a Deep Q Network\n", " :card_description: Main takeaways: 1. RL has the same flow as previous models we have seen, with a few additions 2. Handle unsupervised learning by using an IterableDataset where the dataset...\n", " :tags: RL,GPU/TPU,Lightning-Examples"]}], "metadata": {"jupytext": {"cell_metadata_filter": "colab_type,colab,id,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6"}, "papermill": {"default_parameters": {}, "duration": 23.186754, "end_time": "2021-08-31T12:37:39.374766", "environment_variables": {}, "exception": null, "input_path": "lightning_examples/reinforce-learning-DQN/dqn.ipynb", "output_path": ".notebooks/lightning_examples/reinforce-learning-DQN.ipynb", "parameters": {}, "start_time": "2021-08-31T12:37:16.188012", "version": "2.3.3"}, "widgets": {"application/vnd.jupyter.widget-state+json": {"state": {"4c55a94d1d6e4fd7a0df1813d82588a3": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, 
"width": null}}, "6d83ec49723d45a6931b6ba9dbd64ec4": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": "2", "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "943f4388dd7b46e68aba0643420fc6ce": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": ""}}, "aade554e920a484b943345379a01cf83": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": "inline-flex", "flex": null, "flex_flow": "row wrap", "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": "100%"}}, "d0caef34ba7f48a3a5a32083a7622c04": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_e8087a4edab0420cbecaba04e73c4cfb", "placeholder": "\u200b", "style": "IPY_MODEL_f3ddb52ac2e24d3db75c8fb658d94a41", "value": " 13/? 
[00:00<00:00, 238.76it/s, loss=7.33, v_num=6]"}}, "d514fe7f72904e5fb9f4727b9a0d1e5b": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_6d83ec49723d45a6931b6ba9dbd64ec4", "max": 1.0, "min": 0.0, "orientation": "horizontal", "style": "IPY_MODEL_eb85619e8c9f4ae7bb3ad91dd9896fb3", "value": 1.0}}, "e8087a4edab0420cbecaba04e73c4cfb": {"model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": {"_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null}}, "eb85619e8c9f4ae7bb3ad91dd9896fb3": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": ""}}, "f3ddb52ac2e24d3db75c8fb658d94a41": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": {"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": ""}}, "f962e76bc6bd4477847f5dedbf4d5f6d": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": {"_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": ["IPY_MODEL_ffed8dbcdc9d4bd8ab6c5c4411ccf6af", "IPY_MODEL_d514fe7f72904e5fb9f4727b9a0d1e5b", "IPY_MODEL_d0caef34ba7f48a3a5a32083a7622c04"], "layout": "IPY_MODEL_aade554e920a484b943345379a01cf83"}}, "ffed8dbcdc9d4bd8ab6c5c4411ccf6af": {"model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": {"_dom_classes": [], 
"_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_4c55a94d1d6e4fd7a0df1813d82588a3", "placeholder": "\u200b", "style": "IPY_MODEL_943f4388dd7b46e68aba0643420fc6ce", "value": "Epoch 199: "}}}, "version_major": 2, "version_minor": 0}}}, "nbformat": 4, "nbformat_minor": 5}