
Solving the Vanishing Gradient Problem with LSTM

soham Medewar
Last Updated: Sep 19, 2022

Introduction

Recurrent neural networks (RNNs) are a natural choice for sequential data, because an RNN maintains an internal state that carries information from previous inputs. However, simple RNNs suffer from the vanishing gradient problem. For this reason, an improved variant of the RNN, the LSTM (long short-term memory) network, is widely used.

In this article, we show how a simple RNN runs into the vanishing gradient problem and how LSTM solves it.

Vanishing Gradient Problem

RNNs suffer from the vanishing gradient problem, which makes learning from long data sequences difficult. The gradients carry the information used to update the RNN's parameters. As the gradients shrink, the parameter updates become insignificant and no meaningful learning occurs.

Let us now see why a simple RNN suffers from vanishing gradients.

RNN and Vanishing Gradient Problem

Let us look at the basic architecture of a recurrent neural network, shown in the figure below.

RNN example

The network takes an input sequence [x(1), x(2), …, x(k)]; at time step t, it receives the input x(t). Past information and learned knowledge are encoded in the network's state vectors [c(1), c(2), …, c(k-1)]; at time step t, the network has the state vector c(t-1). The state vector c(t-1) and the input vector x(t) are concatenated to form the complete input vector at time step t, [c(t-1), x(t)].

Two weight matrices, W_rec and W_in, connect the two parts of the input vector, c(t-1) and x(t), to the hidden layer. We ignore the bias vectors in our calculations and write W = [W_rec, W_in].

In the hidden layer, the sigmoid function is utilized as the activation function. 

At the last time step, the network produces a single vector (RNNs can output a vector at each time step, but we'll use this simplified model).
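The forward pass described above can be sketched in plain Python. This is a minimal, illustrative sketch that relies on assumptions not found in the original: small hand-picked weights, no bias (as stated above), plain lists instead of a tensor library, and the hypothetical helpers `sigmoid`, `matvec`, and `rnn_forward`. The sketch omits the output layer that turns the final state into h(k). The last lines check that multiplying W = [W_rec, W_in] by the concatenated vector [c(t-1), x(t)] gives the same result as W_rec·c(t-1) + W_in·x(t).

```python
import math

def sigmoid(z):
    """Logistic sigmoid of a scalar."""
    return 1.0 / (1.0 + math.exp(-z))

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def rnn_forward(w_rec, w_in, xs, c0):
    """Simplified RNN: c(t) = sigmoid(W_rec c(t-1) + W_in x(t)), no bias.

    Returns [c(1), ..., c(k)]. The output layer producing h(k) is omitted.
    """
    # W = [W_rec, W_in]: join the two matrices column-wise, row by row.
    w = [r + s for r, s in zip(w_rec, w_in)]
    c = c0
    states = []
    for x in xs:
        z = c + x  # list concatenation gives [c(t-1), x(t)]
        c = [sigmoid(a) for a in matvec(w, z)]
        states.append(c)
    return states

# Hypothetical weights: 2-unit state, 1-dimensional input.
W_rec = [[0.5, -0.3], [0.2, 0.4]]
W_in = [[1.0], [-1.0]]
xs = [[0.1], [0.7], [-0.2]]
states = rnn_forward(W_rec, W_in, xs, c0=[0.0, 0.0])

# Concatenation identity: W [c, x] equals W_rec c + W_in x.
c_prev, x_t = [0.3, 0.9], [0.5]
lhs = matvec([r + s for r, s in zip(W_rec, W_in)], c_prev + x_t)
rhs = [a + b for a, b in zip(matvec(W_rec, c_prev), matvec(W_in, x_t))]
print(states[-1])
print(lhs, rhs)
```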

Backpropagation Through Time in RNNs

After the RNN outputs the prediction vector h(k), we compute the prediction error E(k) and use Backpropagation Through Time (BPTT) to compute its gradient with respect to the weights:

$$\frac{\partial E_k}{\partial W}$$

We use this gradient to update the model's weights as follows, where α is the learning rate:

$$W \leftarrow W - \alpha \frac{\partial E_k}{\partial W}$$

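We then continue the learning process with gradient descent.

Suppose the learning task spans T time steps. By the chain rule, the gradient of the error at the k-th time step is:

$$\frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}\frac{\partial c_k}{\partial c_{k-1}}\cdots\frac{\partial c_2}{\partial c_1}\frac{\partial c_1}{\partial W} = \frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}\left(\prod_{t=2}^{k}\frac{\partial c_t}{\partial c_{t-1}}\right)\frac{\partial c_1}{\partial W}$$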

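Since W = [W_rec, W_in], c(t) can be written as:

$$c_t = \sigma\left(W \cdot [c_{t-1}, x_t]\right) = \sigma\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right)$$

Its derivative with respect to c(t-1) is:

$$\frac{\partial c_t}{\partial c_{t-1}} = \sigma'\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right) \cdot W_{rec}$$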

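Substituting this derivative into the gradient of the error at the k-th step gives:

$$\frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}\left(\prod_{t=2}^{k}\sigma'\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right) \cdot W_{rec}\right)\frac{\partial c_1}{\partial W}$$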
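The chain-rule product above can be checked numerically on a one-unit (scalar) RNN. In this sketch, the weights, the inputs, and the helper names `run_states` and `dck_dc1` are all hypothetical. The function multiplies the factors σ'(W_rec c(t-1) + W_in x(t))·W_rec for t = 2, …, k. It then compares the result with a central finite-difference estimate of ∂c(k)/∂c(1).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_prime(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def run_states(w_rec, w_in, c1, xs):
    """Return [c(1), ..., c(k)] of a scalar RNN; xs holds x(2), ..., x(k)."""
    cs = [c1]
    for x in xs:
        cs.append(sigmoid(w_rec * cs[-1] + w_in * x))
    return cs

def dck_dc1(w_rec, w_in, c1, xs):
    """Product over t = 2..k of sigmoid'(W_rec c(t-1) + W_in x(t)) * W_rec."""
    cs = run_states(w_rec, w_in, c1, xs)
    prod = 1.0
    # cs[j] is c(j+1) and xs[j] is x(j+2), so each factor is dc(j+2)/dc(j+1).
    for j, x in enumerate(xs):
        prod *= sigmoid_prime(w_rec * cs[j] + w_in * x) * w_rec
    return prod

w_rec, w_in, c1 = 1.5, 0.8, 0.2
xs = [0.3, -0.5, 0.9, 0.1, -0.2]  # hypothetical inputs x(2)..x(6), so k = 6
analytic = dck_dc1(w_rec, w_in, c1, xs)

# Central finite difference of c(k) with respect to c(1).
eps = 1e-6
numeric = (run_states(w_rec, w_in, c1 + eps, xs)[-1]
           - run_states(w_rec, w_in, c1 - eps, xs)[-1]) / (2 * eps)
print(analytic, numeric)
```

Because σ' never exceeds 1/4, each of the five factors here is at most 0.25 × 1.5 = 0.375. This already bounds the product well below 1.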

When k is large, the product in the last expression tends to vanish, because the derivative of the sigmoid activation is at most 1/4. Unless W_rec is large, each factor shrinks the gradient further.

The opposite problem, the exploding gradient, can occur when the weights W_rec are large enough to overpower the small sigmoid derivative, so that the product of derivatives grows without bound.
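Both regimes of this product can be seen on a scalar RNN. This is a hedged sketch with assumed values: a constant input drive W_in·x(t) at every step, and a hypothetical helper `factor_product`. With W_rec = 1, every factor is at most 1/4, so the product collapses. With the contrived choice W_rec = 8 and a state held at c = 0.5, the pre-activation stays at 0. Each factor then equals 0.25 × 8 = 2, and the product explodes.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def factor_product(w_rec, w_in_x, c1, steps):
    """Multiply the per-step factors sigmoid'(a_t) * W_rec of a scalar RNN.

    w_in_x is a constant input drive W_in * x(t), the same at every step
    (a simplifying assumption of this sketch).
    """
    c, prod = c1, 1.0
    for _ in range(steps):
        a = w_rec * c + w_in_x
        s = sigmoid(a)
        prod *= s * (1.0 - s) * w_rec  # sigmoid'(a) = s * (1 - s)
        c = s
    return prod

# Moderate recurrent weight: each factor is at most 0.25, so the product vanishes.
vanishing = factor_product(w_rec=1.0, w_in_x=0.5, c1=0.0, steps=50)

# Contrived large recurrent weight: the state stays at 0.5, so a_t = 0,
# sigmoid'(0) = 0.25 and each factor is 0.25 * 8 = 2; the product explodes.
exploding = factor_product(w_rec=8.0, w_in_x=-4.0, c1=0.5, steps=50)
print(vanishing, exploding)
```

The exploding case is deliberately fragile. Here c = 0.5 is an unstable fixed point, so any perturbation of the starting state pushes the state away from 0.5. After that, the pre-activation is nonzero and the factors drop below 2.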

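The product of these factors therefore shrinks toward zero:

$$\prod_{t=2}^{k}\sigma'\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right) \cdot W_{rec} \to 0$$

So at some time step k:

$$\frac{\partial E_k}{\partial W} \approx 0$$

Because of this, the complete gradient, summed over all T time steps, vanishes:

$$\frac{\partial E}{\partial W} = \sum_{k=1}^{T} \frac{\partial E_k}{\partial W} \approx 0$$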

The weights are then updated as:

$$W \leftarrow W - \alpha \frac{\partial E}{\partial W} \approx W$$

So the weights W barely change, and the network makes no meaningful learning progress.

How LSTM Solves the Vanishing Gradient Problem

At time step t, the LSTM has the input vector [h(t-1), x(t)]. The cell state of the LSTM unit is denoted by c(t), and the output vector passed through the network from time step t to t+1 is denoted by h(t).

LSTM network cells

Three gates in the LSTM cell update and control the cell state: the forget gate, the input gate, and the output gate.


The forget gate determines which information in the cell state should be forgotten when new data enters the network. Its output is given by (bias terms are again omitted):

$$f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t]\right)$$

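Given the new input, the input gate determines what new information is encoded into the cell state. Its output is the element-wise product of the outputs of two fully connected layers:

$$\sigma\left(W_i \cdot [h_{t-1}, x_t]\right) \odot \tanh\left(W_c \cdot [h_{t-1}, x_t]\right) = i_t \odot \tilde{c}_t$$

where:

$$i_t = \sigma\left(W_i \cdot [h_{t-1}, x_t]\right)$$

$$\tilde{c}_t = \tanh\left(W_c \cdot [h_{t-1}, x_t]\right)$$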

The output gate controls the output vector h(t). It determines what information encoded in the cell state is passed to the network as input at the next time step, t+1.

The activation of the output gate is given by:

$$o_t = \sigma\left(W_o \cdot [h_{t-1}, x_t]\right)$$

The output vector of the cell is given by:

$$h_t = o_t \odot \tanh(c_t)$$

Therefore, the cell state takes the form:

$$c_t = c_{t-1} \odot f_t + i_t \odot \tilde{c}_t$$
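The gate equations above can be collected into a single cell step. This sketch assumes a one-unit (scalar) LSTM with no bias terms and hypothetical weights. It uses a hypothetical `lstm_step` helper whose `w` argument holds, for each gate, the pair of weights applied to h(t-1) and x(t). The sketch is meant to mirror the equations, not any particular library's implementation.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(h_prev, c_prev, x, w):
    """One step of a scalar LSTM cell without bias terms.

    w maps each gate name to (weight on h(t-1), weight on x(t)),
    i.e. W_gate . [h(t-1), x(t)] for one-dimensional h and x.
    """
    def pre(gate):
        wh, wx = w[gate]
        return wh * h_prev + wx * x

    f = sigmoid(pre("f"))          # forget gate f(t)
    i = sigmoid(pre("i"))          # input gate i(t)
    c_tilde = math.tanh(pre("c"))  # candidate values c~(t)
    o = sigmoid(pre("o"))          # output gate o(t)
    c = c_prev * f + i * c_tilde   # new cell state c(t)
    h = o * math.tanh(c)           # new output vector h(t)
    return h, c

# Hypothetical weights and a short input sequence.
weights = {"f": (0.5, 1.0), "i": (-0.4, 0.8), "c": (0.9, -0.6), "o": (0.3, 0.2)}
h, c = 0.0, 0.0
for x in [0.5, -1.0, 0.25]:
    h, c = lstm_step(h, c, x, weights)
print(h, c)
```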

Backpropagation Through Time in LSTM

At the k-th time step, the LSTM network produces a prediction vector h(k), just like the RNN model. The knowledge stored in the state vectors c(t) captures long-term dependencies and relationships in the sequential data.

Data sequences can be hundreds or thousands of time steps long, which makes learning with a standard RNN extremely difficult.

The gradient used to update the network parameters is computed across T time steps.

Backpropagating through time for gradient computation

As in the RNN, the error gradient is the sum of T subgradients:

$$\frac{\partial E}{\partial W} = \sum_{k=1}^{T} \frac{\partial E_k}{\partial W} \tag{3}$$

Setting aside cancellation between terms, the total error gradient vanishes only if all T of these subgradients vanish. If we view (3) as a series of functions, then by definition the series converges to zero,

$$\sum_{k=1}^{\infty} \frac{\partial E_k}{\partial W} \to 0,$$

when the sequence of partial sums

$$\{S_T\}_{T=1}^{\infty}, \qquad S_T = \sum_{k=1}^{T} \frac{\partial E_k}{\partial W}$$

tends to zero.

To keep (3) from vanishing, it is therefore enough to make it likely that at least one of the subgradients does not vanish. In other words, we need the series of subgradients in (3) not to converge to zero.

The Error Gradients in the LSTM Network

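The gradient of the error at some time step k has the form:

$$\frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}\left(\prod_{t=2}^{k}\frac{\partial c_t}{\partial c_{t-1}}\right)\frac{\partial c_1}{\partial W} \tag{4}$$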

The following product term is what causes vanishing gradients:

$$\prod_{t=2}^{k}\frac{\partial c_t}{\partial c_{t-1}}$$

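The state vector c(t) in the LSTM has the following form:

$$c_t = c_{t-1} \odot \sigma\left(W_f \cdot [h_{t-1}, x_t]\right) + \tanh\left(W_c \cdot [h_{t-1}, x_t]\right) \odot \sigma\left(W_i \cdot [h_{t-1}, x_t]\right)$$

In short, we can write it as:

$$c_t = c_{t-1} \odot f_t + \tilde{c}_t \odot i_t \tag{5}$$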

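Keep in mind that the state vector c(t) is a function of the following elements, all of which must be considered when computing its derivative during backpropagation:

$$c_{t-1}, \quad f_t, \quad i_t, \quad \tilde{c}_t$$

The last three also depend on c(t-1), through $h_{t-1} = o_{t-1} \odot \tanh(c_{t-1})$.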

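Differentiating equation (5) with respect to c(t-1) gives:

$$\frac{\partial c_t}{\partial c_{t-1}} = \frac{\partial c_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial \tilde{c}_t}\frac{\partial \tilde{c}_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}} + \left.\frac{\partial c_t}{\partial c_{t-1}}\right|_{\text{direct}}$$

where the last term is the direct dependence of c(t) on c(t-1) in (5).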

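Computing the four derivative terms gives the expression below. Outside the activation functions, W_f, W_i, and W_c denote only the blocks of the weight matrices that multiply h(t-1):

$$\begin{aligned}
\frac{\partial c_t}{\partial c_{t-1}} ={}& c_{t-1} \odot \sigma'\left(W_f \cdot [h_{t-1}, x_t]\right) \cdot W_f \cdot \left(o_{t-1} \odot \tanh'(c_{t-1})\right)\\
&+ \tilde{c}_t \odot \sigma'\left(W_i \cdot [h_{t-1}, x_t]\right) \cdot W_i \cdot \left(o_{t-1} \odot \tanh'(c_{t-1})\right)\\
&+ i_t \odot \tanh'\left(W_c \cdot [h_{t-1}, x_t]\right) \cdot W_c \cdot \left(o_{t-1} \odot \tanh'(c_{t-1})\right)\\
&+ f_t
\end{aligned}$$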

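We denote the four terms of the cell-state derivative as follows:

$$\begin{aligned}
A_t &= c_{t-1} \odot \sigma'\left(W_f \cdot [h_{t-1}, x_t]\right) \cdot W_f \cdot \left(o_{t-1} \odot \tanh'(c_{t-1})\right)\\
B_t &= \tilde{c}_t \odot \sigma'\left(W_i \cdot [h_{t-1}, x_t]\right) \cdot W_i \cdot \left(o_{t-1} \odot \tanh'(c_{t-1})\right)\\
C_t &= i_t \odot \tanh'\left(W_c \cdot [h_{t-1}, x_t]\right) \cdot W_c \cdot \left(o_{t-1} \odot \tanh'(c_{t-1})\right)\\
D_t &= f_t
\end{aligned}$$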

Summing them up:

$$\frac{\partial c_t}{\partial c_{t-1}} = A_t + B_t + C_t + D_t \tag{6}$$
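Equation (6) can be sanity-checked numerically for a one-unit LSTM. In this sketch, the weights, o(t-1), x(t), and c(t-1) are hypothetical values, biases are omitted, and the helpers `cell_state` and `dc_dcprev` are illustrative names. `dc_dcprev` builds A(t), B(t), C(t), and D(t) term by term from the expressions above. It then compares their sum with a central finite difference of c(t) with respect to c(t-1). The sketch holds o(t-1) fixed because o(t-1) does not depend on c(t-1).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def d_sigmoid(z):
    s = sigmoid(z)
    return s * (1.0 - s)

def d_tanh(z):
    return 1.0 - math.tanh(z) ** 2

# Hypothetical scalar weights: (weight on h(t-1), weight on x(t)) per gate.
W = {"f": (0.7, 0.4), "i": (-0.5, 0.9), "c": (1.1, -0.3)}
o_prev, x_t = 0.6, 0.8  # o(t-1) and x(t), held fixed

def cell_state(c_prev):
    """c(t) as a function of c(t-1), which also enters via h(t-1) = o(t-1) tanh(c(t-1))."""
    h_prev = o_prev * math.tanh(c_prev)
    f = sigmoid(W["f"][0] * h_prev + W["f"][1] * x_t)
    i = sigmoid(W["i"][0] * h_prev + W["i"][1] * x_t)
    c_tilde = math.tanh(W["c"][0] * h_prev + W["c"][1] * x_t)
    return c_prev * f + i * c_tilde

def dc_dcprev(c_prev):
    """Return (A + B + C + D, D) for the scalar cell above."""
    h_prev = o_prev * math.tanh(c_prev)
    a_f = W["f"][0] * h_prev + W["f"][1] * x_t
    a_i = W["i"][0] * h_prev + W["i"][1] * x_t
    a_c = W["c"][0] * h_prev + W["c"][1] * x_t
    dh = o_prev * d_tanh(c_prev)                          # dh(t-1)/dc(t-1)
    A = c_prev * d_sigmoid(a_f) * W["f"][0] * dh          # via the forget gate
    B = math.tanh(a_c) * d_sigmoid(a_i) * W["i"][0] * dh  # via the input gate
    C = sigmoid(a_i) * d_tanh(a_c) * W["c"][0] * dh       # via the candidate
    D = sigmoid(a_f)                                      # direct term: f(t)
    return A + B + C + D, D

c_prev = 0.35
analytic, forget = dc_dcprev(c_prev)
eps = 1e-6
numeric = (cell_state(c_prev + eps) - cell_state(c_prev - eps)) / (2 * eps)
print(analytic, numeric, forget)
```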

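Plugging equation (6) into equation (4) gives the LSTM state gradient:

$$\frac{\partial E_k}{\partial W} = \frac{\partial E_k}{\partial h_k}\frac{\partial h_k}{\partial c_k}\left(\prod_{t=2}^{k}\left[A_t + B_t + C_t + D_t\right]\right)\frac{\partial c_1}{\partial W}$$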

Preventing the Error Gradient from Vanishing

The gradient contains the forget gate's activation vector. This lets the network better regulate the gradient values at each time step through updates to the forget gate's parameters. The forget gate's activation lets the LSTM decide, at each time step, whether particular information should be remembered, and update the model's parameters accordingly.

Suppose that for some time step k < T, we have:

$$\sum_{t=1}^{k}\frac{\partial E_t}{\partial W} \to 0$$

Then, at time step k+1, we can find a suitable update of the forget gate's parameters such that the gradient does not vanish:

$$\frac{\partial E_{k+1}}{\partial W} \not\to 0$$

Two things let the LSTM find such a parameter update at any time step: the presence of the forget gate's activation vector in the gradient term, and the additive structure of the cell-state update. This yields:

$$\sum_{t=1}^{k+1}\frac{\partial E_t}{\partial W} \not\to 0$$

Now the gradient doesn’t vanish.
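The sketch below gives a small numeric contrast, under assumptions not found in the original. It multiplies the per-step factors over 50 steps for two models. The first is a scalar RNN with W_rec = 1 and a constant input drive of 0.5. The second is a scalar LSTM whose recurrent gate weights are set to zero, with a forget-gate pre-activation of 5. Zero recurrent weights are a hypothetical extreme: they make A(t) = B(t) = C(t) = 0, so each factor reduces to D(t) = f(t).

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

STEPS = 50

# Simple RNN: each factor is sigmoid'(a_t) * W_rec, with W_rec = 1 and drive 0.5.
rnn_prod, c = 1.0, 0.0
for _ in range(STEPS):
    s = sigmoid(1.0 * c + 0.5)
    rnn_prod *= s * (1.0 - s) * 1.0
    c = s

# LSTM with (hypothetically) zero recurrent gate weights: A_t = B_t = C_t = 0,
# so each factor is D_t = f_t; a forget-gate pre-activation of 5 keeps f_t near 1.
f_t = sigmoid(5.0)
lstm_prod = f_t ** STEPS
print(rnn_prod, lstm_prod)
```

This only illustrates the role of the D(t) = f(t) term. With nonzero recurrent weights, A(t), B(t), and C(t) also contribute to each factor.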

Frequently Asked Questions

What is LSTM used for?

LSTM networks are well suited to classifying, processing, and making predictions from time series data, since there can be lags of unknown duration between important events in a time series. LSTMs were developed to address the vanishing gradient problem that can occur when training traditional RNNs.

What is the advantage of LSTM?

LSTMs work well across a broad range of hyperparameter settings, such as the learning rate and the input and output gate biases, so they typically need little fine-tuning. In the original LSTM, the computational complexity of the weight update per time step and per weight is O(1).

Why is LSTM good for text classification?

LSTM models can capture long-term dependencies between words in a sequence, which makes them well suited to text classification.

How does LSTM learn?

LSTMs have internal mechanisms called gates that regulate the flow of information. These gates learn which data in a sequence is important to keep and which to throw away.

Conclusion

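In this article, we showed how backpropagation through time makes the gradients of a simple RNN vanish over long sequences. We also showed how the LSTM's forget gate and additive cell-state update keep the error gradient from vanishing.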

