COMP5329 Deep Learning Week 8 – Recurrent Neural Networks
1. Backpropagation through time (BPTT) in RNNs
After the RNN outputs the prediction vector h(t), we compute the prediction error E(t) and use the Backpropagation Through Time (BPTT) algorithm to compute the gradient
(1)    \frac{\partial E}{\partial W} = \sum_{t=1}^{T} \frac{\partial E_t}{\partial W}.
The gradient is used to update the model parameters by:
(2)    W \leftarrow W - \alpha \frac{\partial E}{\partial W}.
We then continue the learning process using the gradient descent algorithm.
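A minimal sketch of this update loop in plain Python (my own illustration; the names sgd_step and per_step_grads are hypothetical and not part of the lecture material):

def sgd_step(W, per_step_grads, lr=0.1):
    """Accumulate the per-time-step gradients of Eq. (1) and apply the update of Eq. (2)."""
    dE_dW = sum(per_step_grads)   # dE/dW = sum over t of dE_t/dW
    return W - lr * dE_dW         # W <- W - alpha * dE/dW

# Toy usage: a scalar parameter and three time steps with placeholder gradient values.
W = 0.5
W = sgd_step(W, per_step_grads=[0.2, -0.1, 0.05])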
Say we have a learning task that spans T time steps; the gradient of the error at the k-th time step is given by:
(3)    \frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k} \frac{\partial h_k}{\partial c_k} \frac{\partial c_k}{\partial c_{k-1}} \cdots \frac{\partial c_2}{\partial c_1} \frac{\partial c_1}{\partial W} = \frac{\partial E_k}{\partial h_k} \frac{\partial h_k}{\partial c_k} \left( \prod_{t=2}^{k} \frac{\partial c_t}{\partial c_{t-1}} \right) \frac{\partial c_1}{\partial W}
Notice that since W = [W_{hh}, W_{hx}], c_t can be written as:
(4)    c_t = \tanh(W_{hh} c_{t-1} + W_{hx} x_t).
Computing the derivative of c_t with respect to c_{t-1}, we get:
(5)    \frac{\partial c_t}{\partial c_{t-1}} = \tanh'(W_{hh} c_{t-1} + W_{hx} x_t) \frac{\partial}{\partial c_{t-1}} \left[ W_{hh} c_{t-1} + W_{hx} x_t \right] = \tanh'(W_{hh} c_{t-1} + W_{hx} x_t) \, W_{hh}
Plugging Eq. (5) into Eq. (3), we get the backpropagated gradient:
(6)    \frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k} \frac{\partial h_k}{\partial c_k} \left( \prod_{t=2}^{k} \tanh'(W_{hh} c_{t-1} + W_{hx} x_t) \, W_{hh} \right) \frac{\partial c_1}{\partial W}
The last expression tends to vanish when k is large; this is because the derivative of the tanh activation function is smaller than 1, so a long product of such factors shrinks the gradient toward zero. This is known as the vanishing gradient problem.
The product of derivatives can also explode if the weights W_{hh} are large enough to overpower the small tanh derivatives; this is known as the exploding gradient problem.
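A small numerical sketch of this effect for a scalar hidden state (my own illustration with hypothetical values, not part of the lecture material): each factor in the product of Eq. (6) is tanh'(a_t) * W_hh, so the product shrinks toward zero when these factors are below 1 and blows up when they exceed 1.

import numpy as np

def gradient_product(w_hh, k, seed=0):
    """Product of the per-step factors tanh'(a_t) * w_hh appearing in Eq. (6),
    for a scalar hidden state with pre-activations a_t drawn near zero."""
    rng = np.random.default_rng(seed)
    a = rng.normal(0.0, 0.5, size=k)          # stand-in values for W_hh*c_{t-1} + W_hx*x_t
    factors = (1.0 - np.tanh(a) ** 2) * w_hh  # tanh'(a) = 1 - tanh(a)^2
    return factors.prod()

print(gradient_product(w_hh=0.9, k=100))  # shrinks toward 0: vanishing gradient
print(gradient_product(w_hh=1.5, k=100))  # grows very large: exploding gradient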
2. Long Short-Term Memory
An LSTM network takes the input vector [h(t − 1), x(t)] at time step t. The network's cell state is denoted by c(t), and the output vector passed between consecutive time steps t and t + 1 is denoted by h(t).
An LSTM network has three gates that update and control the cell state: the forget gate, the input gate and the output gate. The gates use hyperbolic tangent and sigmoid activation functions. The forget gate controls what information in the cell state to forget, given the new information that has entered the network.
The forget gate’s output is given by:
(7)    f_t = \sigma(W_{fh} h_{t-1} + W_{fx} x_t).
The input gate controls what new information will be encoded into the cell state, given the new
input information.
The input gate’s output has the form:
(8)    \tanh(W_{hh} h_{t-1} + W_{hx} x_t) \otimes \sigma(W_{ih} h_{t-1} + W_{ix} x_t),
where we can define
(9)    \tilde{c}_t = \tanh(W_{hh} h_{t-1} + W_{hx} x_t)
and
(10)    i_t = \sigma(W_{ih} h_{t-1} + W_{ix} x_t).
The output gate controls what information encoded in the cell state is sent to the network as input at the following time step; this is done via the output vector h(t). The output gate's activations are given by:
(11)    o_t = \sigma(W_{oh} h_{t-1} + W_{ox} x_t)
and the cell’s output vector is given by:
(12)    h_t = o_t \otimes \tanh(c_t).
The long-term dependencies and relations are encoded in the cell state vectors, and it is the cell state derivative that can prevent the LSTM gradients from vanishing. The LSTM cell state has the form:
(13)    c_t = c_{t-1} \otimes f_t + \tilde{c}_t \otimes i_t.
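The equations above translate directly into code. Below is a minimal NumPy sketch of one LSTM time step (my own illustration; the function name lstm_step, the parameter dictionary p and the sizes are hypothetical, and biases are omitted to match Eqs. (7)-(13)):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(h_prev, c_prev, x, p):
    """One LSTM time step following Eqs. (7)-(13); p holds the weight matrices."""
    f = sigmoid(p["W_fh"] @ h_prev + p["W_fx"] @ x)        # forget gate, Eq. (7)
    c_tilde = np.tanh(p["W_hh"] @ h_prev + p["W_hx"] @ x)  # candidate state, Eq. (9)
    i = sigmoid(p["W_ih"] @ h_prev + p["W_ix"] @ x)        # input gate, Eq. (10)
    o = sigmoid(p["W_oh"] @ h_prev + p["W_ox"] @ x)        # output gate, Eq. (11)
    c = c_prev * f + c_tilde * i                           # cell state, Eq. (13)
    h = o * np.tanh(c)                                     # output vector, Eq. (12)
    return h, c

# Toy usage with hidden size 4 and input size 3.
rng = np.random.default_rng(0)
p = {k: rng.normal(scale=0.1, size=(4, 4 if k.endswith("h") else 3))
     for k in ["W_fh", "W_fx", "W_hh", "W_hx", "W_ih", "W_ix", "W_oh", "W_ox"]}
h, c = lstm_step(np.zeros(4), np.zeros(4), rng.normal(size=3), p)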
2.1. Backpropagation through time in LSTMs. As in the RNN model, our LSTM network outputs a prediction vector h(k) at the k-th time step. The knowledge encoded in the state vectors c(t) captures long-term dependencies and relations in the sequential data.
The length of the data sequences can be hundreds or even thousands of time steps, which makes learning with a basic RNN extremely difficult. We compute the gradient used to update the network parameters; the computation is done over T time steps.
As in RNNs, the error term gradient is given by the following sum of T gradients:
(14)    \frac{\partial E}{\partial W} = \sum_{t=1}^{T} \frac{\partial E_t}{\partial W}.
The gradient of the error for some time step k has the form:
(15)    \frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k} \frac{\partial h_k}{\partial c_k} \frac{\partial c_k}{\partial c_{k-1}} \cdots \frac{\partial c_2}{\partial c_1} \frac{\partial c_1}{\partial W} = \frac{\partial E_k}{\partial h_k} \frac{\partial h_k}{\partial c_k} \left( \prod_{t=2}^{k} \frac{\partial c_t}{\partial c_{t-1}} \right) \frac{\partial c_1}{\partial W}
As we have seen, the following product causes the gradients to vanish:
(16)    \prod_{t=2}^{k} \frac{\partial c_t}{\partial c_{t-1}}
In an LSTM, the state vector c(t) has the form:
(17)    c_t = c_{t-1} \otimes f_t + \tilde{c}_t \otimes i_t.
Notice that the state vector c(t) is a function of the following elements, which should be taken
into account when computing the derivative during backpropagation:
(18)    c_{t-1}, \; f_t, \; \tilde{c}_t, \; i_t
Computing the derivative of Eq. (17), we get:
\frac{\partial c_t}{\partial c_{t-1}} = \frac{\partial}{\partial c_{t-1}} \left[ c_{t-1} \otimes f_t + \tilde{c}_t \otimes i_t \right] = \frac{\partial}{\partial c_{t-1}} \left[ c_{t-1} \otimes f_t \right] + \frac{\partial}{\partial c_{t-1}} \left[ \tilde{c}_t \otimes i_t \right]    (19)
= \frac{\partial f_t}{\partial c_{t-1}} c_{t-1} + f_t + \frac{\partial i_t}{\partial c_{t-1}} \tilde{c}_t + \frac{\partial \tilde{c}_t}{\partial c_{t-1}} i_t,    (20)
where the four derivative terms are respectively denoted as A_t, B_t, C_t and D_t. The LSTM cell state gradient can then be written as
(21)    \frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k} \frac{\partial h_k}{\partial c_k} \left( \prod_{t=2}^{k} \left[ A_t + B_t + C_t + D_t \right] \right) \frac{\partial c_1}{\partial W}
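For reference, here is a sketch of the four terms written out explicitly. This expansion is not given in the lecture notes; it follows from the gate definitions in Eqs. (7), (9) and (10), together with \frac{\partial h_{t-1}}{\partial c_{t-1}} = o_{t-1} \otimes \tanh'(c_{t-1}) from Eq. (12):

A_t = \frac{\partial f_t}{\partial c_{t-1}} c_{t-1} = c_{t-1} \otimes \sigma'(W_{fh} h_{t-1} + W_{fx} x_t) \, W_{fh} \left( o_{t-1} \otimes \tanh'(c_{t-1}) \right)
B_t = f_t
C_t = \frac{\partial i_t}{\partial c_{t-1}} \tilde{c}_t = \tilde{c}_t \otimes \sigma'(W_{ih} h_{t-1} + W_{ix} x_t) \, W_{ih} \left( o_{t-1} \otimes \tanh'(c_{t-1}) \right)
D_t = \frac{\partial \tilde{c}_t}{\partial c_{t-1}} i_t = i_t \otimes \tanh'(W_{hh} h_{t-1} + W_{hx} x_t) \, W_{hh} \left( o_{t-1} \otimes \tanh'(c_{t-1}) \right)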
Notice that the gradient contains the forget gate's vector of activations, which allows the network to better control the gradient values at each time step through suitable parameter updates of the forget gate. The presence of the forget gate's activations allows the LSTM to decide, at each time step, that certain information should not be forgotten and to update the model's parameters accordingly.
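A simplified scalar illustration of this point (my own sketch with hypothetical values, assuming the terms A_t, C_t and D_t are small so that each per-step factor A_t + B_t + C_t + D_t is dominated by the forget gate activation f_t):

import numpy as np

k = 100
rng = np.random.default_rng(0)

f = 0.98 * np.ones(k)                        # forget gate activations kept near 1
other = rng.normal(0.0, 0.01, size=(3, k))   # small hypothetical A_t, C_t, D_t terms
lstm_factors = f + other.sum(axis=0)         # per-step factor A_t + B_t + C_t + D_t

# RNN per-step factor tanh'(a_t) * W_hh, with the same kind of stand-in values as before.
rnn_factors = (1.0 - np.tanh(rng.normal(0.0, 0.5, size=k)) ** 2) * 0.9

print("LSTM product:", lstm_factors.prod())  # stays at a usable magnitude
print("RNN product: ", rnn_factors.prod())   # collapses toward zero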
Say that for some time step k < T we have \sum_{t=1}^{k} \frac{\partial E_t}{\partial W} \to 0. Then, for the gradient not to vanish, we can find a suitable parameter update of the forget gate at time step k + 1 such that \frac{\partial E_{k+1}}{\partial W} \not\to 0. It is the presence of the forget gate's vector of activations in the gradient term, along with the additive structure, that allows the LSTM to find such a parameter update at any time step; this yields \sum_{t=1}^{k+1} \frac{\partial E_t}{\partial W} \not\to 0, and the gradient does not vanish.
Another important property to notice is that the cell state gradient is an additive function made up of four elements, denoted A(t), B(t), C(t) and D(t). This additive property enables better balancing of gradient values during backpropagation: the LSTM updates and balances the values of the four components, making it more likely that the additive expression does not vanish. This is different from the RNN case, where the gradient contained a single element inside the product.
In LSTMs, however, the presence of the forget gate, along with the additive property of the cell state gradients, enables the network to update its parameters in such a way that the different sub-gradients do not necessarily agree or behave in a similar manner. This makes it less likely that all of the T gradients will vanish; in other words, the series of functions does not converge to zero, and our gradients do not vanish.
Summing up, we have seen that RNNs suffer from vanishing gradients, caused by long series of multiplications of small values that diminish the gradient and cause the learning process to degenerate. In an analogous way, RNNs suffer from exploding gradients, caused by large gradient values that hamper the learning process. LSTMs address the problem through a unique additive gradient structure that includes direct access to the forget gate's activations, enabling the network to encourage the desired behaviour of the error gradient using suitable updates of the gates at every time step of the learning process.