Reinforcement Learning (DQN) Tutorial
Author: Adam Paszke
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent on the CartPole-v1 task from Gymnasium.
You might find it helpful to read the original Deep Q Learning (DQN) paper.
Task
The agent has to decide between two actions - moving the cart left or right - so that the pole attached to it stays upright. You can find more information about the environment and other more challenging environments at Gymnasium’s website.
[Figure: CartPole]
As the agent observes the current state of the environment and chooses an action, the environment transitions to a new state and also returns a reward that indicates the consequences of the action. In this task, rewards are +1 for every incremental timestep, and the environment terminates if the pole falls over too far or the cart moves more than 2.4 units away from the center. This means better-performing scenarios will run for a longer duration, accumulating a larger return.
The CartPole task is designed so that the inputs to the agent are 4 real values representing the environment state (position, velocity, etc.). We take these 4 inputs without any scaling and pass them through a small fully-connected network with 2 outputs, one for each action. The network is trained to predict the expected value of each action, given the input state. The action with the highest expected value is then chosen.
Packages
First, let’s import the needed packages. We need gymnasium for the environment, installed using pip. This is a fork of the original OpenAI Gym project, maintained by the same team since Gym v0.19. If you are running this in Google Colab, run:
%%bash
pip3 install gymnasium[classic_control]
We’ll also use the following from PyTorch:
- neural networks (torch.nn)
- optimization (torch.optim)
- automatic differentiation (torch.autograd)
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
Replay Memory
We’ll be using experience replay memory for training our DQN. It stores the transitions that the agent observes, allowing us to reuse this data later. By sampling from it randomly, the transitions that build up a batch are decorrelated. It has been shown that this greatly stabilizes and improves the DQN training procedure.
For this, we’re going to need two classes:
- Transition - a named tuple representing a single transition in our environment. It essentially maps (state, action) pairs to their (next_state, reward) result, with the state being the 4-value observation vector returned by the environment.
- ReplayMemory - a cyclic buffer of bounded size that holds the transitions observed recently. It also implements a .sample() method for selecting a random batch of transitions for training.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
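To see the cyclic-buffer behavior in isolation, here is a minimal stdlib-only sketch that uses a bare deque with maxlen the same way ReplayMemory does. The capacity of 3 and the toy integer states are illustrative assumptions, not values from the tutorial (which stores tensors and uses a capacity of 10000).

```python
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

# Same structure as ReplayMemory.memory, but with a tiny illustrative capacity.
buffer = deque([], maxlen=3)
for step in range(5):
    # Toy integer "states" stand in for the real observation tensors.
    buffer.append(Transition(state=step, action=step % 2, next_state=step + 1, reward=1.0))

# Once full, each append evicts the oldest transition, so only the last 3 remain.
kept_states = [t.state for t in buffer]
print(kept_states)  # [2, 3, 4]

# sample() draws without replacement, so a batch never repeats a stored transition.
rng = random.Random(0)
batch = rng.sample(buffer, 2)
print(sorted(t.state for t in batch))
```

Because eviction is handled by the deque itself, push() needs no index bookkeeping; the trade-off is that old experience is discarded strictly in arrival order.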
Now, let’s define our model. But first, let’s quickly recap what a DQN is.
DQN algorithm
Our environment is deterministic, so all equations presented here are also formulated deterministically for the sake of simplicity. In the reinforcement learning literature, they would also contain expectations over stochastic transitions in the environment.
Our aim will be to train a policy that tries to maximize the discounted, cumulative reward \(R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t\), where \(R_{t_0}\) is also known as the return. The discount, \(\gamma\), should be a constant between \(0\) and \(1\) that ensures the sum converges. A lower \(\gamma\) makes rewards from the uncertain far future less important for our agent than the ones in the near future that it can be fairly confident about. It also encourages agents to collect reward closer in time than equivalent rewards that are temporally far away in the future.
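As a worked example under this task's reward scheme (+1 per timestep), a T-step episode has return \(R_{t_0} = \sum_{k=0}^{T-1} \gamma^k = (1 - \gamma^T)/(1 - \gamma)\), so with \(\gamma = 0.99\) (the GAMMA value used later) no return can reach \(1/(1 - \gamma) = 100\), however long the episode runs. The sketch below checks the sum against the closed form; the helper name discounted_return is hypothetical.

```python
def discounted_return(rewards, gamma):
    """R_{t0} = sum_{t >= t0} gamma^(t - t0) * r_t, indexing the first reward as t = t0."""
    return sum(gamma ** k * r for k, r in enumerate(rewards))


gamma = 0.99
for T in (10, 100, 500):
    rewards = [1.0] * T  # CartPole pays +1 for every step the episode continues
    closed_form = (1 - gamma ** T) / (1 - gamma)
    print(T, round(discounted_return(rewards, gamma), 2), round(closed_form, 2))
```

Lowering gamma to 0.9 would cap the return at 10, which is the "near future matters more" effect described above.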
The main idea behind Q-learning is that if we had a function \(Q^*: State \times Action \rightarrow \mathbb{R}\) that could tell us what our return would be if we were to take an action in a given state, then we could easily construct a policy that maximizes our rewards:
\[\pi^*(s) = \arg\!\max_a \ Q^*(s, a)\]
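The greedy policy is just an argmax over the action values of one state. A minimal sketch, assuming Q*(s, ·) is already available as a plain list of floats (the values below are made up); in CartPole, action index 0 pushes the cart left and index 1 pushes it right.

```python
def greedy_action(q_values):
    """pi*(s) = argmax_a Q*(s, a); q_values[a] holds Q*(s, a) for action index a."""
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Hypothetical Q*(s, .) for one CartPole state: index 0 = push left, index 1 = push right.
q_s = [1.7, 2.3]
print(greedy_action(q_s))  # 1
```

On ties, max() keeps the first maximal index, so this sketch prefers the lower action index.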
However, we don’t know everything about the world, so we don’t have access to \(Q^*\). But, since neural networks are universal function approximators, we can simply create one and train it to resemble \(Q^*\).
For our training update rule, we’ll use the fact that every \(Q\) function for some policy obeys the Bellman equation:
\[Q^{\pi}(s, a) = r + \gamma Q^{\pi}(s', \pi(s'))\]
The difference between the two sides of the equality is known as the temporal difference error, \(\delta\):
\[\delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))\]
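A worked instance of the temporal difference error with toy numbers (all values hypothetical; gamma = 0.5 is chosen only to keep the arithmetic exact). The terminal branch follows the convention, used later in optimize_model, that the bootstrap term is 0 when the next state is terminal.

```python
def td_error(q_sa, reward, next_q_values, gamma, terminal=False):
    """delta = Q(s, a) - (r + gamma * max_a' Q(s', a')); the bootstrap term is 0 if s' is terminal."""
    bootstrap = 0.0 if terminal else max(next_q_values)
    return q_sa - (reward + gamma * bootstrap)


# Hypothetical values: Q(s, a) = 10, r = 1, Q(s', .) = [8, 9].
print(td_error(10.0, 1.0, [8.0, 9.0], gamma=0.5))                 # 10 - (1 + 0.5 * 9) = 4.5
print(td_error(10.0, 1.0, [8.0, 9.0], gamma=0.5, terminal=True))  # 10 - 1 = 9.0
```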
To minimize this error, we will use the Huber loss. The Huber loss acts like the mean squared error when the error is small, but like the mean absolute error when the error is large; this makes it more robust to outliers when the estimates of \(Q\) are very noisy. We calculate this over a batch of transitions, \(B\), sampled from the replay memory:
\[\mathcal{L} = \frac{1}{|B|}\sum_{(s, a, s', r) \ \in \ B} \mathcal{L}(\delta)\]
\[\text{where} \quad \mathcal{L}(\delta) = \begin{cases} \frac{1}{2}{\delta^2} & \text{for } |\delta| \le 1, \\ |\delta| - \frac{1}{2} & \text{otherwise.}\end{cases}\]
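The piecewise definition can be checked directly in plain Python. For an outlier with \(\delta = 3\), half the squared error would be 4.5, while the Huber loss grows only linearly and gives 2.5. PyTorch's nn.SmoothL1Loss, used later in optimize_model, computes this same function with its default beta=1.0 and averages over the batch with its default reduction='mean'; the helper names below are hypothetical.

```python
def huber(delta):
    """Per-transition loss from the piecewise definition above (threshold 1)."""
    if abs(delta) <= 1:
        return 0.5 * delta ** 2
    return abs(delta) - 0.5


def batch_loss(deltas):
    """Average of the per-transition losses over a batch B."""
    return sum(huber(d) for d in deltas) / len(deltas)


print(huber(0.5), huber(3.0), huber(-3.0))  # 0.125 2.5 2.5
print(batch_loss([0.5, 3.0]))               # 1.3125
```

Both branches equal 0.5 at |δ| = 1, so the loss is continuous where the quadratic and linear pieces meet.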
Q-network
Our model will be a feed-forward neural network that takes in the 4-value state observation returned by the environment. It has two outputs, representing \(Q(s, \mathrm{left})\) and \(Q(s, \mathrm{right})\) (where \(s\) is the input to the network). In effect, the network is trying to predict the expected return of taking each action given the current input.
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
Training
Hyperparameters and utilities
This cell instantiates our model and its optimizer, and defines some utilities:
- select_action - will select an action according to an epsilon-greedy policy. Simply put, we’ll sometimes use our model for choosing the action, and sometimes we’ll just sample one uniformly. The probability of choosing a random action will start at EPS_START and will decay exponentially towards EPS_END. EPS_DECAY controls the rate of the decay.
- plot_durations - a helper for plotting the duration of episodes, along with an average over the last 100 episodes (the measure used in the official evaluations). The plot will be underneath the cell containing the main training loop, and will update after every episode.
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) returns the largest value of each row together with its
            # column index; .indices keeps the index, i.e. the action with the
            # larger expected reward.
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())
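The arithmetic behind both utilities can be sketched without torch or gymnasium. In this stdlib-only sketch, epsilon() reproduces the eps_threshold formula from select_action, epsilon_greedy() is a hypothetical stand-in that takes a plain list of Q-values instead of calling policy_net, and padded_moving_average() mirrors the unfold-and-pad logic of plot_durations; returning an empty list for fewer than window episodes is an assumption matching the fact that plot_durations only draws the average from 100 episodes on.

```python
import math
import random

# Same schedule constants as in the cell above.
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 1000


def epsilon(steps_done):
    """Random-action probability used by select_action after `steps_done` calls."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * steps_done / EPS_DECAY)


def epsilon_greedy(q_values, steps_done, rng):
    """Hypothetical stand-in for select_action: q_values is a list with one float per action."""
    if rng.random() > epsilon(steps_done):
        return max(range(len(q_values)), key=lambda a: q_values[a])  # exploit
    return rng.randrange(len(q_values))  # explore: uniform random action


def padded_moving_average(durations, window=100):
    """Mirrors durations_t.unfold(0, window, 1).mean(1) left-padded with window - 1 zeros."""
    if len(durations) < window:
        return []
    means = [sum(durations[i:i + window]) / window
             for i in range(len(durations) - window + 1)]
    return [0.0] * (window - 1) + means


for steps in (0, 1000, 5000):
    print(steps, round(epsilon(steps), 3))
print(padded_moving_average([10, 20, 30, 40], window=2))  # [0.0, 15.0, 25.0, 35.0]
```

With EPS_DECAY = 1000, epsilon falls from 0.9 to roughly 0.36 after 1000 steps and is close to EPS_END after a few thousand, so exploration is concentrated in the first episodes. The zero padding keeps the averaged curve aligned with the episode axis of the raw durations.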
Training loop
Finally, the code for training our model.
Here, you can find an optimize_model function that performs a single step of the optimization. It first samples a batch, concatenates all the tensors into a single one, computes \(Q(s_t, a_t)\) and \(V(s_{t+1}) = \max_a Q(s_{t+1}, a)\), and combines them into our loss. By definition we set \(V(s) = 0\) if \(s\) is a terminal state. We also use a target network to compute \(V(s_{t+1})\) for added stability. The target network is updated at every step with a soft update controlled by the hyperparameter TAU, which was previously defined.
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of the actions that were actually taken (as stored in the
    # replay memory) for each batch state
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
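To make the transpose and the terminal-state masking concrete, this stdlib-only sketch repeats the same steps on a hypothetical three-transition batch. Plain integers replace tensors, a dict named target_max_q stands in for target_net(...).max(1).values, and GAMMA = 0.5 is chosen only so the arithmetic is exact; none of these values come from the tutorial.

```python
from collections import namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
GAMMA = 0.5  # hypothetical value chosen so the arithmetic below is exact

# Toy batch: next_state=None marks a transition into a terminal state.
transitions = [
    Transition(state=0, action=1, next_state=1, reward=1.0),
    Transition(state=1, action=0, next_state=None, reward=1.0),
    Transition(state=2, action=1, next_state=3, reward=1.0),
]
# Hypothetical max_a Q_target(s', a) values, keyed by next state.
target_max_q = {1: 4.0, 3: 6.0}

# Transpose: a list of Transitions becomes one Transition of tuples.
batch = Transition(*zip(*transitions))
non_final_mask = [s is not None for s in batch.next_state]

# V(s') = 0 for terminal next states, otherwise the target-network maximum.
next_state_values = [target_max_q[s] if keep else 0.0
                     for s, keep in zip(batch.next_state, non_final_mask)]
expected = [v * GAMMA + r for v, r in zip(next_state_values, batch.reward)]

print(batch.state)     # (0, 1, 2)
print(non_final_mask)  # [True, False, True]
print(expected)        # [3.0, 1.0, 4.0]
```

The terminal transition's target collapses to its reward alone, which is exactly what the zero-initialized next_state_values tensor achieves in optimize_model.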
Below, you can find the main training loop. At the beginning we reset the environment and obtain the initial state Tensor. Then, we sample an action, execute it, observe the next state and the reward (always 1), and optimize our model once. When the episode ends (the pole falls, the cart leaves the track, or the episode is truncated), we restart the loop.
Below, num_episodes is set to 600 if a GPU is available; otherwise 50 episodes are scheduled so training does not take too long. However, 50 episodes is insufficient to observe good performance on CartPole. You should see the model consistently achieve 500 steps within 600 training episodes. Training RL agents can be a noisy process, so restarting training can produce better results if convergence is not observed.
if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 − τ) θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/gymnasium/utils/passive_env_checker.py:249: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)
Complete
Here is the diagram that illustrates the overall resulting data flow.
Actions are chosen either randomly or based on a policy, getting the next step sample from the gym environment. We record the results in the replay memory and also run an optimization step on every iteration. Optimization picks a random batch from the replay memory to train the new policy. The “older” target_net is also used in optimization to compute the expected Q values. A soft update of its weights is performed at every step.
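The soft update from the training loop, θ′ ← τθ + (1 − τ)θ′, can be sketched on plain dictionaries of floats, used here as a hypothetical stand-in for state_dict(). If the policy weight were held fixed, the gap between the two networks would shrink by a factor of (1 − τ) per step, so with TAU = 0.005 the target closes about 99% of it after 1000 steps. In real training the policy keeps moving, so this only illustrates how slowly the target tracks it.

```python
TAU = 0.005  # same value as in the hyperparameter cell


def soft_update(target, policy, tau):
    """theta' <- tau * theta + (1 - tau) * theta', applied key by key like the state_dict loop."""
    return {k: policy[k] * tau + target[k] * (1 - tau) for k in policy}


# Hypothetical one-parameter "networks": the policy weight is held fixed at 1.0.
policy = {"w": 1.0}
target = {"w": 0.0}
for step in range(1000):
    target = soft_update(target, policy, TAU)

# Closed form for a fixed policy: theta'_n = 1 - (1 - tau)^n when starting from 0.
print(round(target["w"], 4), round(1 - (1 - TAU) ** 1000, 4))
```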
Total running time of the script: (5 minutes 44.153 seconds)