
How to train a Deep Q Network

  • Author: PL team

  • License: CC BY-SA

  • Generated: 2021-07-26T23:14:45.695289

Main takeaways:

  1. RL has the same flow as previous models we have seen, with a few additions

  2. Handle learning without a fixed, labelled dataset by using an IterableDataset whose underlying data is constantly updated during training

  3. In each training step, the agent takes an action in the environment and stores the resulting experience in the IterableDataset
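
The second takeaway can be illustrated without any deep learning library. The sketch below is a minimal stand-in, not the Lightning implementation: a plain Python generator plays the role of the `IterableDataset`, and `buffer` and `iterate_buffer` are hypothetical names. Each new pass draws from whatever the buffer holds at that moment, so experiences appended between passes show up in later batches.

```python
import random
from collections import deque
from typing import Iterator, Tuple

# Hypothetical stand-in for the replay buffer: (state, reward) pairs.
buffer: deque = deque(maxlen=100)


def iterate_buffer(sample_size: int, seed: int = 0) -> Iterator[Tuple[int, float]]:
    """Yield up to `sample_size` items sampled from the buffer's current contents."""
    rng = random.Random(seed)
    k = min(sample_size, len(buffer))
    for item in rng.sample(list(buffer), k):
        yield item


# Pass 1: only three experiences exist.
buffer.extend([(0, 1.0), (1, 1.0), (2, 1.0)])
first_pass = list(iterate_buffer(sample_size=5))

# The agent "acts" and stores new experiences between passes.
buffer.extend([(3, 1.0), (4, 1.0)])
second_pass = list(iterate_buffer(sample_size=5))

print(len(first_pass), len(second_pass))  # 3 5
```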



This notebook requires some packages besides pytorch-lightning.

! pip install --quiet "gym" "pytorch-lightning>=1.3" "torch>=1.6, <1.9" "torchmetrics>=0.3"
import os
from collections import deque, namedtuple, OrderedDict
from typing import List, Tuple

import gym
import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn, Tensor
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

PATH_DATASETS = os.environ.get('PATH_DATASETS', '.')
AVAIL_GPUS = min(1, torch.cuda.device_count())
class DQN(nn.Module):
    """Simple MLP network."""

    def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128):
        """
        Args:
            obs_size: observation/state size of the environment
            n_actions: number of discrete actions available in the environment
            hidden_size: size of hidden layers
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),  # non-linearity between the two linear layers
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        # cast to float so NumPy float64 observations match the layer weights
        return self.net(x.float())



# Named tuple for storing experience steps gathered in training
Experience = namedtuple(
    'Experience',
    field_names=['state', 'action', 'reward', 'done', 'new_state'],
)
class ReplayBuffer:
    """Replay Buffer for storing past experiences allowing the agent to learn from them.

    Args:
        capacity: size of the buffer
    """

    def __init__(self, capacity: int) -> None:
        # a bounded deque drops the oldest experience once capacity is reached
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> int:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """Add experience to the buffer.

        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> Tuple:
        # draw unique indices uniformly at random, then regroup the experiences by field
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))

        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=bool),  # `np.bool` is a deprecated alias for `bool`
            np.array(next_states),
        )
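
The buffer's behaviour rests on two standard-library facts: a `deque` created with `maxlen` silently discards its oldest entry once full, and sampling without replacement never returns the same stored experience twice in one batch. A small sketch, using `random.sample` in place of `np.random.choice(..., replace=False)` and integer placeholders instead of real experiences (both are simplifications for illustration):

```python
import random
from collections import deque

capacity = 3
buffer = deque(maxlen=capacity)

# Append five placeholder experiences; only the newest three survive.
for step in range(5):
    buffer.append(step)

print(list(buffer))  # [2, 3, 4]

# Uniform sampling without replacement: indices are unique within a batch.
rng = random.Random(0)
indices = rng.sample(range(len(buffer)), 2)
batch = [buffer[idx] for idx in indices]

print(len(set(indices)) == len(indices))  # True
```

Asking for more samples than the buffer holds raises a `ValueError` in both `random.sample` and `np.random.choice(..., replace=False)`, which is why the buffer has to be filled with random steps before training starts.
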
class RLDataset(IterableDataset):
    """Iterable Dataset containing the ReplayBuffer, which will be updated with new experiences during training.

    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None:
        self.buffer = buffer
        self.sample_size = sample_size

    def __iter__(self) -> Tuple:
        states, actions, rewards, dones, new_states = self.buffer.sample(self.sample_size)
        for i in range(len(dones)):
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]


class Agent:
    """Base Agent class handling the interaction with the environment."""

    def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None:
        """
        Args:
            env: training environment
            replay_buffer: replay buffer storing experiences
        """
        self.env = env
        self.replay_buffer = replay_buffer
        self.state = self.env.reset()

    def reset(self) -> None:
        """Resets the environment and updates the state."""
        self.state = self.env.reset()

    def get_action(self, net: nn.Module, epsilon: float, device: str) -> int:
        """Using the given network, decide what action to carry out using an epsilon-greedy policy.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            action
        """
        if np.random.random() < epsilon:
            # explore: pick a random action
            action = self.env.action_space.sample()
        else:
            # exploit: pick the action with the highest predicted Q-value
            state = torch.tensor([self.state])

            if device not in ['cpu']:
                state = state.cuda(device)

            q_values = net(state)
            _, action = torch.max(q_values, dim=1)
            action = int(action.item())

        return action

    @torch.no_grad()  # acting in the environment needs no gradients
    def play_step(
        self,
        net: nn.Module,
        epsilon: float = 0.0,
        device: str = 'cpu',
    ) -> Tuple[float, bool]:
        """Carries out a single interaction step between the agent and the environment.

        Args:
            net: DQN network
            epsilon: value to determine likelihood of taking a random action
            device: current device

        Returns:
            reward, done
        """
        action = self.get_action(net, epsilon, device)

        # do step in the environment
        new_state, reward, done, _ = self.env.step(action)

        exp = Experience(self.state, action, reward, done, new_state)

        # store the transition so it can be sampled for training later
        self.replay_buffer.append(exp)

        self.state = new_state
        if done:
            self.reset()
        return reward, done
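
The epsilon-greedy rule used by `get_action` can be tried in isolation. The sketch below is a deliberate simplification: `epsilon_greedy` is a hypothetical helper, a plain list of Q-values stands in for the network output, and Python's `random` module replaces `np.random` and the environment's `action_space.sample()`.

```python
import random
from typing import List


def epsilon_greedy(q_values: List[float], epsilon: float, rng: random.Random) -> int:
    """Return a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    # exploit: index of the largest Q-value
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q_values = [0.1, 0.7, 0.2]

# epsilon = 0.0 never explores, so the greedy action is always chosen.
greedy_actions = {epsilon_greedy(q_values, 0.0, rng) for _ in range(100)}
print(greedy_actions)  # {1}

# epsilon = 1.0 always explores, so every returned action is a valid index.
random_actions = [epsilon_greedy(q_values, 1.0, rng) for _ in range(100)]
print(all(0 <= a < len(q_values) for a in random_actions))  # True
```
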

DQN Lightning Module
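
The loss computed by the module below regresses the online network's Q-value for the chosen action toward a one-step target, `reward + gamma * max_a Q_target(next_state, a)`, with the bootstrap term set to zero on terminal transitions. A hand-checkable sketch in plain Python (the numbers and the helper names `td_targets` and `mse` are made up for illustration; the real code operates on tensors):

```python
from typing import List


def td_targets(
    rewards: List[float],
    dones: List[bool],
    next_q_values: List[List[float]],
    gamma: float,
) -> List[float]:
    """One-step targets: r + gamma * max_a Q_target(s', a), or just r if done."""
    targets = []
    for reward, done, next_q in zip(rewards, dones, next_q_values):
        bootstrap = 0.0 if done else max(next_q)
        targets.append(reward + gamma * bootstrap)
    return targets


def mse(predictions: List[float], targets: List[float]) -> float:
    """Mean squared error between predicted and target Q-values."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


targets = td_targets(
    rewards=[1.0, 1.0],
    dones=[False, True],
    next_q_values=[[2.0, 4.0], [3.0, 5.0]],
    gamma=0.5,
)
print(targets)  # [3.0, 1.0]

# Q-values the online network predicted for the actions actually taken.
print(mse([2.0, 1.0], targets))  # 0.5
```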

class DQNLightning(LightningModule):
    """ Basic DQN Model """

    def __init__(
        self,
        batch_size: int = 16,
        lr: float = 1e-2,
        env: str = "CartPole-v0",
        gamma: float = 0.99,
        sync_rate: int = 10,
        replay_size: int = 1000,
        warm_start_size: int = 1000,
        eps_last_frame: int = 1000,
        eps_start: float = 1.0,
        eps_end: float = 0.01,
        episode_length: int = 200,
        warm_start_steps: int = 1000,
    ) -> None:
        """
        Args:
            batch_size: size of the batches
            lr: learning rate
            env: gym environment tag
            gamma: discount factor
            sync_rate: how many frames pass between updates of the target network
            replay_size: capacity of the replay buffer
            warm_start_size: how many samples do we use to fill our buffer at the start of training
            eps_last_frame: frame at which epsilon stops decaying
            eps_start: starting value of epsilon
            eps_end: final value of epsilon
            episode_length: max length of an episode
            warm_start_steps: number of random steps taken to populate the replay buffer before training
        """
        super().__init__()
        # store the constructor arguments so they are available as self.hparams
        self.save_hyperparameters()

        self.env = gym.make(self.hparams.env)
        obs_size = self.env.observation_space.shape[0]
        n_actions = self.env.action_space.n

        self.net = DQN(obs_size, n_actions)
        self.target_net = DQN(obs_size, n_actions)

        self.buffer = ReplayBuffer(self.hparams.replay_size)
        self.agent = Agent(self.env, self.buffer)
        self.total_reward = 0
        self.episode_reward = 0
        # fill the buffer so the first batches can be sampled
        self.populate(self.hparams.warm_start_steps)

    def populate(self, steps: int = 1000) -> None:
        """Carries out several random steps through the environment to initially fill up the replay buffer with
        experiences.

        Args:
            steps: number of random steps to populate the buffer with
        """
        for _ in range(steps):
            self.agent.play_step(self.net, epsilon=1.0)

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.net(x)
        return output

    def dqn_mse_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        """Calculates the MSE loss using a mini batch from the replay buffer.

        Args:
            batch: current mini batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        state_action_values = self.net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.hparams.gamma + rewards

        return nn.MSELoss()(state_action_values, expected_state_action_values)

    def training_step(self, batch: Tuple[Tensor, Tensor], nb_batch) -> OrderedDict:
        """Carries out a single step through the environment to update the replay buffer, then calculates the loss
        based on the minibatch received.

        Args:
            batch: current mini batch of replay data
            nb_batch: batch number

        Returns:
            Training loss and log metrics
        """
        device = self.get_device(batch)
        # linear decay from eps_start, floored at eps_end; the parentheses make the step count divide first
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - (self.global_step + 1) / self.hparams.eps_last_frame,
        )

        # step through environment with agent
        reward, done = self.agent.play_step(self.net, epsilon, device)
        self.episode_reward += reward

        # calculates training loss
        loss = self.dqn_mse_loss(batch)

        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss = loss.unsqueeze(0)

        if done:
            self.total_reward = self.episode_reward
            self.episode_reward = 0

        # Hard update of target network: copy the online weights every sync_rate steps
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())

        log = {
            'total_reward': torch.tensor(self.total_reward).to(device),
            'reward': torch.tensor(reward).to(device),
            'train_loss': loss,
        }
        status = {
            'steps': torch.tensor(self.global_step).to(device),
            'total_reward': torch.tensor(self.total_reward).to(device),
        }

        return OrderedDict({'loss': loss, 'log': log, 'progress_bar': status})

    def configure_optimizers(self) -> List[Optimizer]:
        """ Initialize Adam optimizer"""
        optimizer = Adam(self.net.parameters(), lr=self.hparams.lr)
        return [optimizer]

    def __dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences"""
        dataset = RLDataset(self.buffer, self.hparams.episode_length)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Get train loader"""
        return self.__dataloader()

    def get_device(self, batch) -> str:
        """Retrieve device currently being used by minibatch"""
        return batch[0].device.index if self.on_gpu else 'cpu'
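
The exploration rate in `training_step` is meant to fall from `eps_start` towards `eps_end` as training progresses. One way to express a linear schedule of that kind is sketched below; `linear_epsilon` is a hypothetical helper, and the exact formula is an assumption rather than a quote of the module. Note the parentheses around `step + 1`: division binds tighter than subtraction, so without them the step count would be subtracted whole and epsilon would hit its floor almost immediately.

```python
def linear_epsilon(step: int, eps_start: float, eps_end: float, eps_last_frame: int) -> float:
    """Decay epsilon linearly with the step count, never going below eps_end."""
    return max(eps_end, eps_start - (step + 1) / eps_last_frame)


# Default values taken from the module's constructor signature.
schedule = [linear_epsilon(s, 1.0, 0.01, 1000) for s in (249, 499, 5000)]
print(schedule)  # [0.75, 0.5, 0.01]
```
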



model = DQNLightning()

trainer = Trainer(
    gpus=AVAIL_GPUS,
    max_epochs=200,
    val_check_interval=100,
)

trainer.fit(model)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores

  | Name       | Type | Params
0 | net        | DQN  | 898
1 | target_net | DQN  | 898
1.8 K     Trainable params
0         Non-trainable params
1.8 K     Total params
0.007     Total estimated model params size (MB)
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
