Solving the Vanishing Gradient Problem with LSTM
Recurrent neural networks (RNNs) are a natural choice for sequential data because they maintain an internal state (memory) that carries information from earlier inputs forward. Their drawback is that simple RNNs suffer from the vanishing gradient problem during training. For this reason, an improved recurrent architecture, the Long Short-Term Memory (LSTM) network, is widely used.
In this article, we will see how RNNs give rise to vanishing gradients and how LSTMs solve this problem.
Vanishing Gradient Problem
RNNs are plagued by the problem of vanishing gradients, which makes learning long data sequences difficult. The gradients carry the information used to update the RNN's parameters. As the gradients shrink, the parameter updates become negligible, and no meaningful learning occurs.
Now let us see why an RNN suffers from the vanishing gradient problem.
RNN and Vanishing Gradient Problem
Let us have a look at the basic architecture of a recurrent neural network, shown in the image below.
The network receives an input sequence [x(1), x(2), …, x(k)]; at time step t, it is given the input x(t). Past information and learned knowledge are encoded in the network as the state vectors [c(1), c(2), …, c(k-1)]; at time step t, the network holds the state vector c(t-1). The state vector c(t-1) and the input vector x(t) are concatenated to form the complete input vector at time step t, [c(t-1), x(t)].
Two weight matrices, Wrec and Win, connect the two parts of the input vector, c(t-1) and x(t), to the hidden layer. We ignore the bias vectors in our calculations and write W = [Wrec, Win].
The hidden layer uses the tanh function as its activation function.
At the last time step, the network outputs a single prediction vector h(k). (RNNs can output a vector at every time step, but we'll use this simplified model.)
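To make the setup concrete, here is a minimal NumPy sketch of the forward pass described above. It assumes the tanh hidden activation used in the derivation that follows and no biases. It also omits the readout layer that turns c(k) into the prediction h(k). The function name rnn_forward is illustrative only.

```python
import numpy as np

def rnn_forward(W_rec, W_in, xs, c0):
    """Unroll a bias-free tanh RNN over the inputs xs and return the states c(1..k)."""
    W = np.hstack([W_rec, W_in])       # W = [Wrec, Win]
    c = c0
    states = []
    for x in xs:
        z = np.concatenate([c, x])     # complete input vector [c(t-1), x(t)]
        c = np.tanh(W @ z)             # same as tanh(Wrec @ c(t-1) + Win @ x(t)); tanh is assumed
        states.append(c)
    return np.array(states)

rng = np.random.default_rng(0)
n_state, n_in, k = 4, 3, 5
W_rec = rng.standard_normal((n_state, n_state))
W_in = rng.standard_normal((n_state, n_in))
xs = rng.standard_normal((k, n_in))
states = rnn_forward(W_rec, W_in, xs, np.zeros(n_state))
print(states.shape)  # (5, 4): one state vector per time step
```

Concatenating [c(t-1), x(t)] and multiplying by W is just a compact way of computing Wrec·c(t-1) + Win·x(t).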
Backpropagation Through Time in RNNs
After the RNN outputs the prediction vector h(k), we compute the prediction error E(k) and use Backpropagation Through Time (BPTT) to compute the gradient.
We use the gradient to update the model's weights as follows:
This is the gradient descent update, which we repeat to continue the learning process.
If the learning process has T time steps in total, the gradient of the error at the kth time step is:
Since W = [Wrec, Win], c(t) can be written as:
Computing the derivative of c(t) with respect to c(t-1):
Now substitute this derivative into the equation for the gradient of the error at the kth step:
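The update equation itself is missing here. A standard gradient descent step would be the following, with α denoting the learning rate (a symbol introduced for this sketch):

```latex
W \leftarrow W - \alpha \, \frac{\partial E_k}{\partial W}
```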
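Below is a plausible reconstruction of the missing expression. It writes c(t) as c_t and follows only the path through the state vectors, which is a simplification of full BPTT:

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial h_k}\,\frac{\partial h_k}{\partial c_k}\,
    \frac{\partial c_k}{\partial c_{k-1}} \cdots \frac{\partial c_2}{\partial c_1}\,
    \frac{\partial c_1}{\partial W}
  = \frac{\partial E_k}{\partial h_k}\,\frac{\partial h_k}{\partial c_k}
    \left(\prod_{t=2}^{k} \frac{\partial c_t}{\partial c_{t-1}}\right)
    \frac{\partial c_1}{\partial W}
```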
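This is reconstructed from the definitions above: concatenated input, W = [Wrec, Win], tanh activation, and no bias.

```latex
c_t = \tanh\!\left(W \cdot [c_{t-1}, x_t]\right)
    = \tanh\!\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right)
```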
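Assuming the tanh activation, the Jacobian of one step would be as follows. Here diag turns the element-wise tanh derivative into a diagonal matrix:

```latex
\frac{\partial c_t}{\partial c_{t-1}}
  = \operatorname{diag}\!\left(\tanh'\!\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right)\right) W_{rec}
```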
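Substituting that Jacobian into the chain-rule expansion gives, under the same assumptions:

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial h_k}\,\frac{\partial h_k}{\partial c_k}
    \left(\prod_{t=2}^{k} \operatorname{diag}\!\left(\tanh'\!\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right)\right) W_{rec}\right)
    \frac{\partial c_1}{\partial W}
```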
When k is large, the product of these derivatives tends to vanish. The derivative of tanh is at most 1 (it equals 1 only at zero), so unless Wrec is large enough to compensate, each factor tends to shrink the gradient.
Conversely, the exploding gradient problem can occur when the weights Wrec are large enough to outweigh the small tanh derivative, causing the product of derivatives to grow without bound.
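The following NumPy sketch is an illustration, not part of the original derivation. It multiplies the per-step Jacobians diag(1 - c(t)^2)·Wrec of a bias-free tanh RNN. Wrec is rescaled so that its spectral norm is 0.5, an arbitrary choice for the demo. Since tanh' is at most 1, each factor then shrinks the spectral norm of the product by at least half.

```python
import numpy as np

def jacobian_norms(W_rec, W_in, xs, c0):
    """Spectral norm of dc(t)/dc(0) for t = 1..len(xs) in a bias-free tanh RNN."""
    c = c0
    J = np.eye(len(c0))                        # dc(0)/dc(0)
    norms = []
    for x in xs:
        c = np.tanh(W_rec @ c + W_in @ x)
        # dc(t)/dc(t-1) = diag(tanh'(a_t)) @ Wrec, where tanh'(a_t) = 1 - c(t)**2
        J = (np.diag(1.0 - c**2) @ W_rec) @ J
        norms.append(np.linalg.norm(J, 2))     # largest singular value
    return norms

rng = np.random.default_rng(1)
n_state, n_in, steps = 4, 3, 50
A = rng.standard_normal((n_state, n_state))
W_rec = 0.5 * A / np.linalg.norm(A, 2)         # rescale so that ||Wrec||_2 = 0.5
W_in = rng.standard_normal((n_state, n_in))
xs = rng.standard_normal((steps, n_in))

norms = jacobian_norms(W_rec, W_in, xs, np.zeros(n_state))
print(norms[0], norms[9], norms[49])           # shrinks by at least 2x per step
```

Raising the 0.5 scale factor above 1 removes this guarantee. That is the regime in which the product can instead grow and gradients can explode.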
For a sufficiently large time step k, the product term tends to zero, and as a result the complete error gradient vanishes:
The network's weights are then updated as:
So the weights W do not change significantly, and the network makes no learning progress.
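In symbols, under the same assumptions as the reconstructed equations above:

```latex
\prod_{t=2}^{k} \operatorname{diag}\!\left(\tanh'\!\left(W_{rec}\, c_{t-1} + W_{in}\, x_t\right)\right) W_{rec} \;\to\; 0
\quad\Longrightarrow\quad
\frac{\partial E_k}{\partial W} \;\to\; 0
```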
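With the gradient close to zero, the update step leaves W almost unchanged. Here α is again the assumed learning rate.

```latex
W \leftarrow W - \alpha \, \frac{\partial E_k}{\partial W} \approx W
```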
How LSTM Solves the Vanishing Gradient Problem
At time step t, the LSTM receives the input vector [h(t-1), x(t)]. The cell state of the LSTM unit is denoted by c(t), and the output vector passed through the network from time step t to t+1 is denoted by h(t).
The LSTM unit has three gates that update and control the cell state: the forget gate, the input gate, and the output gate.
When new data enters the network, the forget gate decides which information in the cell state should be discarded. The output of the forget gate is given by:
Given the new input, the input gate decides what new information is written into the cell state. The output of the input gate is given by:
The contribution of the input gate is the element-wise product of the outputs of two fully connected layers: a sigmoid layer (the gate itself) and a tanh layer that proposes candidate values.
The output gate is computed from [h(t-1), x(t)]. It determines which information encoded in the cell state is exposed as the output vector h(t), which is passed to the network as input at the next time step t+1.
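The equation image is missing. A common bias-free formulation is shown below, with W_f as an assumed name for the forget gate's weight matrix and σ as the logistic sigmoid:

```latex
f_t = \sigma\!\left(W_f \cdot [h_{t-1}, x_t]\right)
```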
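Again assuming bias-free layers, with W_i and W_c as illustrative names for the two weight matrices, the gate, the candidate vector, and their element-wise product would be:

```latex
i_t = \sigma\!\left(W_i \cdot [h_{t-1}, x_t]\right), \qquad
\tilde{c}_t = \tanh\!\left(W_c \cdot [h_{t-1}, x_t]\right), \qquad
i_t \odot \tilde{c}_t
```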
The activation of the output gate is given by:
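A reconstruction in the same bias-free style, with W_o as an assumed name:

```latex
o_t = \sigma\!\left(W_o \cdot [h_{t-1}, x_t]\right)
```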
The output vector of the cell is given by:
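Assuming the standard LSTM readout, where ⊙ denotes element-wise multiplication:

```latex
h_t = o_t \odot \tanh(c_t)
```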
The cell state is therefore updated as:
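Combining the forget and input gates gives the following, written as a reconstruction in the usual notation:

```latex
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t
```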
Backpropagation Through Time in LSTM
Like the RNN model, the LSTM network produces a prediction vector h(k) at the kth time step. The knowledge stored in the state vectors c(t) captures long-term dependencies and relationships in the sequential data.
Data sequences can be hundreds or thousands of time steps long, which makes learning with a standard RNN extremely difficult.
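Here is a minimal NumPy sketch of unrolling an LSTM over a sequence to produce h(k). It assumes bias-free layers acting on [h(t-1), x(t)], sigmoid gates, and a tanh candidate layer. The parameter names W_f, W_i, W_c, W_o and the helper functions are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(params, h_prev, c_prev, x):
    """One bias-free LSTM step (biases omitted by assumption)."""
    z = np.concatenate([h_prev, x])          # [h(t-1), x(t)]
    f = sigmoid(params["W_f"] @ z)           # forget gate
    i = sigmoid(params["W_i"] @ z)           # input gate
    c_tilde = np.tanh(params["W_c"] @ z)     # candidate values
    o = sigmoid(params["W_o"] @ z)           # output gate
    c = f * c_prev + i * c_tilde             # additive cell-state update
    h = o * np.tanh(c)                       # output vector h(t)
    return h, c, f

def lstm_forward(params, xs, n_hidden):
    h = np.zeros(n_hidden)
    c = np.zeros(n_hidden)
    for x in xs:
        h, c, _ = lstm_step(params, h, c, x)
    return h, c                              # h(k) is the prediction at the last step

rng = np.random.default_rng(2)
n_hidden, n_in, k = 4, 3, 6
params = {name: rng.standard_normal((n_hidden, n_hidden + n_in))
          for name in ("W_f", "W_i", "W_c", "W_o")}
xs = rng.standard_normal((k, n_in))
h_k, c_k = lstm_forward(params, xs, n_hidden)
print(h_k.shape, c_k.shape)  # (4,) (4,)
```

The cell state c is updated additively. The output h is bounded in (-1, 1) because it is a sigmoid gate times tanh.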
The gradient is computed across T time steps, and it is used to update the network parameters.
As in the RNN, the gradient of the error term is the sum of T gradients:
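The missing equation, reconstructed as a plain sum over time steps, would read:

```latex
\frac{\partial E}{\partial W} = \sum_{k=1}^{T} \frac{\partial E_k}{\partial W} \qquad (3)
```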
For the total error gradient to vanish, all T of these subgradients must vanish. If we view (3) as a series of functions, then, by definition, the series converges to zero if the sequence of its partial sums tends to zero:
Therefore, to keep the gradient in (3) from vanishing, we only need to make it more likely that at least one of the subgradients does not vanish. In other words, we need to keep the series of subgradients in (3) from converging to zero.
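Writing S_τ for the τ-th partial sum (a symbol introduced here), the condition reads:

```latex
S_\tau = \sum_{k=1}^{\tau} \frac{\partial E_k}{\partial W} \;\to\; 0, \qquad \tau = 1, \dots, T
```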
The Error Gradients in the LSTM Network
The gradient of the error for some time step k has the form:
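By analogy with the RNN case, the missing expression is likely the same chain-rule product, now over the LSTM cell states:

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial h_k}\,\frac{\partial h_k}{\partial c_k}
    \left(\prod_{t=2}^{k} \frac{\partial c_t}{\partial c_{t-1}}\right)
    \frac{\partial c_1}{\partial W} \qquad (4)
```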
As in the RNN, the following product term is what can cause the gradient to vanish:
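That is, the factor taken from equation (4):

```latex
\prod_{t=2}^{k} \frac{\partial c_t}{\partial c_{t-1}}
```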
The state vector c(t) in the LSTM has the following form:
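Substituting the bias-free gate definitions, with the assumed weight names W_f, W_i, and W_c, gives:

```latex
c_t = c_{t-1} \odot \sigma\!\left(W_f \cdot [h_{t-1}, x_t]\right)
    + \tanh\!\left(W_c \cdot [h_{t-1}, x_t]\right) \odot \sigma\!\left(W_i \cdot [h_{t-1}, x_t]\right)
```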
In short, we can write it as:
It is important to remember that the state vector c(t) is a function of the following elements, which must be taken into account when computing its derivative during backpropagation:
Differentiating equation (5) with respect to c(t-1) gives:
Computing each of the four derivative terms, we can write:
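In the shorter gate notation, with \tilde{c}_t as the candidate vector:

```latex
c_t = c_{t-1} \odot f_t + \tilde{c}_t \odot i_t \qquad (5)
```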
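Those elements are presumably c(t-1) and the three gate-related vectors. Each of the three depends on h(t-1), and therefore on c(t-1):

```latex
c_t = c_t\!\left(c_{t-1},\; f_t,\; i_t,\; \tilde{c}_t\right), \qquad
f_t,\; i_t,\; \tilde{c}_t \text{ depend on } h_{t-1} = o_{t-1} \odot \tanh(c_{t-1})
```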
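A reconstruction by the multivariate chain rule, summing over the four paths from c_{t-1} to c_t (the last term is the direct path with the gates held fixed):

```latex
\frac{\partial c_t}{\partial c_{t-1}}
  = \frac{\partial c_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}}
  + \frac{\partial c_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}}
  + \frac{\partial c_t}{\partial \tilde{c}_t}\frac{\partial \tilde{c}_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial c_{t-1}}
  + \left.\frac{\partial c_t}{\partial c_{t-1}}\right|_{f_t,\, i_t,\, \tilde{c}_t \text{ fixed}}
```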
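Evaluating each factor under the bias-free assumptions gives the following. It is written loosely in element-wise form, where W_f, W_i, and W_c stand for the parts of those matrices that multiply h_{t-1}:

```latex
\frac{\partial c_t}{\partial c_{t-1}}
  = c_{t-1}\,\sigma'\!\left(W_f \cdot [h_{t-1}, x_t]\right) W_f \, o_{t-1} \tanh'(c_{t-1})
  + \tilde{c}_t\,\sigma'\!\left(W_i \cdot [h_{t-1}, x_t]\right) W_i \, o_{t-1} \tanh'(c_{t-1})
  + i_t\,\tanh'\!\left(W_c \cdot [h_{t-1}, x_t]\right) W_c \, o_{t-1} \tanh'(c_{t-1})
  + f_t
```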
We denote the four terms of the cell-state derivative by A(t), B(t), C(t), and D(t). Summing them gives the full derivative:
Plugging equation (6) into equation (4) gives the LSTM state gradient:
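Reading the four terms off the previous expression gives the following reconstruction:

```latex
A_t = c_{t-1}\,\sigma'\!\left(W_f \cdot [h_{t-1}, x_t]\right) W_f \, o_{t-1} \tanh'(c_{t-1})
B_t = \tilde{c}_t\,\sigma'\!\left(W_i \cdot [h_{t-1}, x_t]\right) W_i \, o_{t-1} \tanh'(c_{t-1})
C_t = i_t\,\tanh'\!\left(W_c \cdot [h_{t-1}, x_t]\right) W_c \, o_{t-1} \tanh'(c_{t-1})
D_t = f_t

\frac{\partial c_t}{\partial c_{t-1}} = A_t + B_t + C_t + D_t \qquad (6)
```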
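Substituting (6) into (4) would give:

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial h_k}\,\frac{\partial h_k}{\partial c_k}
    \left(\prod_{t=2}^{k} \left[A_t + B_t + C_t + D_t\right]\right)
    \frac{\partial c_1}{\partial W}
```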
Preventing the Error Gradient from Vanishing
The gradient contains the activation vector of the forget gate. This lets the network better regulate the gradient values at each time step through suitable updates of the forget gate's parameters. The forget gate's activation allows the LSTM to decide, at each time step, whether particular information should be remembered, and to update the model's parameters accordingly.
Suppose that, at some time step k < T, we have:
Then, at time step k+1, we can find a suitable update of the forget gate's parameters such that the gradient does not vanish:
The presence of the forget gate's activation vector in the gradient term, together with the additive structure of the cell-state update, allows the LSTM to find such a parameter update at every time step, which yields:
As a result, the total gradient does not vanish.
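A reconstruction of the assumed condition, in the notation of equation (3):

```latex
\sum_{t=1}^{k} \frac{\partial E_t}{\partial W} \;\to\; 0
```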
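That is, the next subgradient stays away from zero:

```latex
\frac{\partial E_{k+1}}{\partial W} \;\not\to\; 0
```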
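In symbols, the partial sum up to k+1 then does not tend to zero either:

```latex
\sum_{t=1}^{k+1} \frac{\partial E_t}{\partial W} \;\not\to\; 0
```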
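As a rough numerical illustration (not from the original article), the sketch below isolates the direct term ∂c(t)/∂c(t-1) = f(t) through the forget gate and ignores the other three terms, which can add to or cancel it. The forget-gate pre-activation of 5.0 is a hypothetical value; in practice it would come from learned parameters. Over 100 steps, the product of forget activations stays near 0.5. By contrast, a per-step factor of 0.5, as in a tanh RNN whose recurrent weights have spectral norm 0.5, collapses to about 1e-30.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

steps = 100
# Hypothetical forget-gate pre-activations: large positive values keep f(t) close to 1.
f = sigmoid(np.full(steps, 5.0))
forget_path = np.prod(f)        # product of the direct f(t) terms alone
rnn_bound = 0.5 ** steps        # per-step bound for a tanh RNN with ||Wrec||_2 = 0.5
print(forget_path)              # about 0.51
print(rnn_bound)                # about 7.9e-31
```

Because the cell-state update is additive, this path is a product of gate values the network can learn to keep near 1. It is not a product of squashed derivatives times a fixed weight matrix.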
Frequently Asked Questions
What is LSTM used for?
Because there can be lags of unknown duration between important events in a time series, LSTM networks are well suited to classifying, processing, and making predictions based on such data. LSTMs were developed to address the vanishing gradient problem that can occur when training traditional RNNs.
What is the advantage of LSTM?
LSTMs work well over a broad range of hyperparameters, such as the learning rate and the input and output gate biases, so they need little fine-tuning. In addition, in the original LSTM formulation, the cost of updating each weight per time step is O(1).
Why is LSTM good for text classification?
LSTM models can capture long-term dependencies between the words in a sequence, which makes them well suited to text classification.
How does LSTM learn?
LSTMs have internal mechanisms called gates that regulate the flow of information. The gates learn which data in a sequence is important to keep and which can be thrown away.
In this article, we have discussed the following topics:
- Vanishing gradient problem
- Vanishing gradient problem in RNN
- Solving vanishing gradient problem using LSTM