Applying Temporal Difference Methods to Machine Learning — Part 1

In this part, I will be covering the concepts underlying this application.


When introducing these methods, Sutton claims that temporal-difference methods have an advantage over typical supervised learning: they spread out the computational load and generate more accurate predictions.

He emphasizes that this holds in a particular setting, multi-step prediction problems, where what we are trying to predict is only revealed after a sequence of predictions. In other words, these are problems where the true outcome is known only after multiple steps have been observed and predicted.

The author further argues that these problems occur more often in the real world than single-step prediction problems, where the real outcome can be verified each time a prediction is made.

Here are some examples of multi-step prediction problems

What I propose in this case study is to work with a classic machine learning problem, the MNIST dataset. I will detail further along how this problem can be modified so that it can be treated as a multi-step prediction problem.

Temporal-Difference learning

The main concept behind temporal-difference learning methods is to learn from the differences between the successive predictions made at each step, as opposed to waiting for the real outcome at the end of a sequence of observations.

Let’s consider a sequence of observations $x_1, x_2, \dots, x_m$ that leads to the outcome $z$. Let’s further denote the prediction of $z$ at each time step $t \in \{1, \dots, m\}$ as $P_t$, with the convention $P_{m+1} = z$. Each $x_t$ may itself be a vector of several attributes.

We are therefore trying to predict the final outcome after each observation $x_t$ at time $t$. To do so, we will use a set of weights denoted $w$, so that $P_t$ can be written as $P(x_t, w)$. Sutton analyses the case where $P(x_t, w)$ is a linear function of $x_t$ and $w$. I will explore the non-linear case later, when applying the method to MNIST.

For now, let’s focus on a variant of this problem where the weights are only updated at the end of the sequence under the following update rule,

$w \leftarrow w + \sum \limits_{t=1}^m \Delta w_t$

Traditional supervised learning approach

Under the traditional supervised learning approach, the observations $\{ x_1, x_2, \dots, x_m \}$ are each paired with the outcome $z$. Under this approach, and given our prediction function $P(x_t, w)$, a very popular gradient update rule for $w$, based on backpropagation of the error, is the following,

$\Delta w_t = \alpha (z - P_t) \nabla_w P_t$

Here $\nabla_w P_t$ is the gradient of the prediction at time $t$ with respect to the weights of our function, and $\alpha$ is our learning rate.

An important observation that the author emphasizes is that every $\Delta w_t$ depends on the error at its time step, $(z - P_t)$, which in turn depends on $z$. In the types of problems we are exploring, $z$ is only known at the end of the sequence.

In practical terms, we would therefore make a prediction and determine the gradient at each time step, store them in memory until the end of the full sequence and once the outcome is known, compute the errors at each time step and do the update to our weights.

This is what TD learning aims to circumvent: it allows the calculations to be made iteratively rather than stacking up information, in order to reduce the memory requirements.
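The supervised procedure described above can be sketched as follows. This is only a minimal illustration, assuming the linear case $P(x_t, w) = w \cdot x_t$ (so that $\nabla_w P_t = x_t$); the function name `supervised_update` and the toy data are hypothetical and not taken from Sutton's paper.

```python
import numpy as np

def supervised_update(xs, z, w, alpha=0.1):
    """Total weight change for one sequence, computed only once z is known.

    Assumes a linear predictor P(x_t, w) = w . x_t, so grad_w P_t = x_t.
    """
    preds, grads = [], []
    for x in xs:
        # During the sequence: predict and store, since z is not known yet.
        preds.append(w @ x)
        grads.append(x)  # gradient of w . x with respect to w
    # After the sequence: z is revealed, so the errors can finally be formed.
    delta_w = np.zeros_like(w)
    for p, g in zip(preds, grads):
        delta_w += alpha * (z - p) * g
    return delta_w

# Toy sequence of m = 3 observations with 2 attributes each (made-up values).
xs = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
w = np.array([0.5, -0.5])
delta_w = supervised_update(xs, z=1.0, w=w)
```

Note how `preds` and `grads` grow with the length of the sequence: everything has to be kept until $z$ arrives.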

TD approach

The main issue with the traditional machine learning approach described above is that the update rule refers to the real outcome at each time step. Sutton suggests that, rather than seeing the error as the outcome versus our current prediction, we consider it as the sum of all the differences between successive future predictions, $(P_{t+1} - P_t)$. These differences between predictions at consecutive time steps are called temporal differences, hence the name of the method!

Now let’s do a bit of math to figure out what would be the update rule based on this approach.

Arithmetically, we can rewrite $z - P_t$ as $\sum \limits_{k=t}^m(P_{k+1} - P_k)$, where $P_{m+1} = z$. We can then re-write the update rule from the first approach,

$w \leftarrow w + \sum \limits_{t=1}^m \alpha (z - P_t) \nabla_w P_t = w + \sum \limits_{t=1}^m \alpha \sum \limits_{k=t}^m (P_{k+1} - P_k) \nabla_w P_t$
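To see why the rewriting of $z - P_t$ works, here is a small worked case, taking $m = 3$ and $t = 2$ purely for illustration. The sum telescopes, so every intermediate prediction cancels and only the two endpoints remain:

$\sum \limits_{k=2}^3 (P_{k+1} - P_k) = (P_3 - P_2) + (P_4 - P_3) = P_4 - P_2 = z - P_2$

where the last equality uses the convention $P_{m+1} = P_4 = z$.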

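Since both sums are finite, we can switch the order of the double summation (a discrete version of Fubini’s theorem). The pairs $(t, k)$ with $1 \le t \le k \le m$ are the same in both orderings, so we obtain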

$w + \sum \limits_{t=1}^m \alpha \sum \limits_{k=t}^m (P_{k+1} - P_k) \nabla_w P_t = w + \sum \limits_{k=1}^m \alpha \sum \limits_{t=1}^k (P_{k+1} - P_k) \nabla_w P_t$.

By simply swapping the indices $k$ and $t$ for clearer understanding and moving around constants, we can finally obtain the update rule,

$w \leftarrow w + \sum \limits_{t=1}^m \alpha (P_{t+1} - P_t) \sum \limits_{k=1}^t \nabla_w P_k$

We can then read this update rule as a sum of per-step updates $\Delta w_t$, where for each $t$,

$\Delta w_t = \alpha (P_{t+1} - P_t) \sum \limits_{k=1}^t \nabla_w P_k$.
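As a quick sanity check of this rule, take $m = 2$ (a case chosen only for illustration). Summing the two TD updates gives

$\Delta w_1 + \Delta w_2 = \alpha (P_2 - P_1) \nabla_w P_1 + \alpha (z - P_2) (\nabla_w P_1 + \nabla_w P_2)$

Collecting the terms that multiply each gradient, and using $(P_2 - P_1) + (z - P_2) = z - P_1$, this becomes

$\alpha (z - P_1) \nabla_w P_1 + \alpha (z - P_2) \nabla_w P_2$

which is exactly the total update of the supervised approach for a sequence of length 2.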

Notice that the update rule for the TD approach doesn’t require the actual outcome $z$, except for the last prediction at step $m$. It therefore doesn’t require us to keep track of all the predictions that were made during the sequence. To compensate, though, we need the sum of the gradients over previous time steps. This is cheap in terms of memory, as we only need to store the current sum and add each new gradient to it when it is obtained.

In other words, when we compute the prediction at time step $t+1$, we can obtain the update for step $t$ easily. We determine the TD error $(P_{t+1} - P_t)$, multiply it by the running sum of gradients kept in memory, and add the result to our accumulated update. When we reach the end of the sequence, no more computation is required than at any previous step, other than applying the accumulated update to $w$.

This reduces the memory requirements compared to the traditional machine learning approach detailed above, especially for long sequences: we keep a single running sum of gradients instead of one gradient and one prediction per time step.
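The incremental procedure can be sketched as follows. As before, this is a minimal illustration that assumes the linear case $P(x_t, w) = w \cdot x_t$ (so $\nabla_w P_t = x_t$) and a sequence with at least one observation; the name `td_update` and the toy data are hypothetical.

```python
import numpy as np

def td_update(xs, z, w, alpha=0.1):
    """Total weight change for one sequence, accumulated step by step.

    Assumes a linear predictor P(x_t, w) = w . x_t, so grad_w P_t = x_t.
    Only the previous prediction and one running gradient sum are stored.
    """
    delta_w = np.zeros_like(w)
    grad_sum = np.zeros_like(w)  # sum of grad_w P_k for k = 1..t
    prev_pred = None
    for x in xs:
        pred = w @ x
        if prev_pred is not None:
            # P_{t+1} is now available: apply the TD error for step t.
            delta_w += alpha * (pred - prev_pred) * grad_sum
        grad_sum += x            # include the gradient of the current step
        prev_pred = pred
    # Last step: the outcome plays the role of P_{m+1} = z.
    delta_w += alpha * (z - prev_pred) * grad_sum
    return delta_w

# Toy sequence of m = 3 observations with 2 attributes each (made-up values).
xs = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
w = np.array([0.5, -0.5])
delta_w = td_update(xs, z=1.0, w=w)
```

On this toy sequence the result matches the supervised total $\sum_t \alpha (z - P_t) \nabla_w P_t$, as the derivation predicts, while the memory used inside the loop does not grow with the sequence length.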

Using MNIST as a multi-step prediction problem

Traditionally, MNIST has been seen as a single-step prediction problem, i.e. we see the 28x28 pixel input as a whole, compute a prediction and compare it to the true digit. In order to use it as a multi-step prediction problem, we can simply consider the image input as a sequence of 784 pixels! This way, after each pixel is observed, we can make a prediction about the image, and we only get the real outcome once the image has been fully covered, making it a multi-step prediction problem.

We can therefore denote pixel $i$ as $p_i$ and, going from left to right and from the top row to the bottom row, obtain a sequence $p_1, p_2, \dots, p_{784}$. We can further denote the outcome of the sequence as the class of the image, $z \in \{ 0, 1, 2, \dots, 9 \}$.

Additionally, to make the prediction $P_t$ at time step $t$ a function of all previously observed pixels in the sequence rather than just the current pixel $p_t$, we can express our observation at time $t$ as $x_t = [p_1, p_2, \dots, p_{t-1}, p_t, 0, \dots, 0, 0]^T$, a vector of size 784 containing all pixels seen up to time step $t$ and zeros afterwards.
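The construction of the observations $x_t$ can be sketched as follows. This is a minimal illustration: the generator name `masked_observations` is hypothetical, and the array used in the example is a stand-in with made-up values rather than a real MNIST digit.

```python
import numpy as np

def masked_observations(image):
    """Yield x_1, ..., x_n for one image: pixels seen so far, zeros after."""
    # Row-major flattening reads left to right, then top row to bottom row.
    pixels = np.asarray(image, dtype=float).reshape(-1)
    for t in range(1, pixels.size + 1):
        x_t = np.zeros_like(pixels)
        x_t[:t] = pixels[:t]  # reveal the first t pixels only
        yield x_t

# Stand-in for a 28x28 digit image (values 1..784, for illustration only).
image = np.arange(1, 785, dtype=float).reshape(28, 28)
xs = list(masked_observations(image))
```

Each image therefore yields 784 observation vectors of size 784. Generating them lazily, as the generator does, avoids holding all of them in memory at once when that is not needed.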

Next steps

With the concepts underlying the case study laid out, in the following part I will cover the performance of the TD approach in comparison to the traditional machine learning approach for multi-step prediction problems.

To be fair, there are some very powerful methods that perform extremely well on MNIST these days. The goal here is not to compare the best of machine learning to the TD learning approach mentioned above. It is meant to be an exercise of applying the fundamental concept introduced by Sutton.

The machine learning approach used will receive the same sequenced inputs as the TD method. Admittedly, since MNIST is mostly seen as a purely single-step prediction problem, the supervised approach could exploit this by taking the full sequence of pixels as its input.

I trust the reader to understand this nuance :).
