- Basic Principles of LSTM
- Using LSTM in PyTorch
- Practical Example: Text Classification
- Notes
- Summary
- References
The Long Short-Term Memory Network (LSTM) is an important variant of Recurrent Neural Networks (RNN), designed to address the issues of gradient vanishing and gradient exploding that traditional RNNs encounter when processing long sequences. By introducing gating mechanisms, LSTM effectively captures long-range dependencies, making it particularly powerful in tasks such as natural language processing and time-series forecasting.
Basic Principles of LSTM
What is LSTM?
Traditional RNNs tend to suffer from gradient vanishing or exploding when handling long sequences, making it difficult for the model to capture long-term dependencies. LSTM introduces three gates (input gate, forget gate, and output gate) along with a cell state to regulate the flow of information, thereby mitigating these issues. The design of LSTM enables it to retain and propagate important information over long sequences, significantly improving the model’s memory capability.
Structure of LSTM
The core of an LSTM is the LSTM cell, and each cell contains the following key components:
- Cell State: Runs through the entire sequence, serving as a primary pathway for information.
- Forget Gate: Determines how much past information to retain.
- Input Gate: Decides how much new information to add to the cell state.
- Output Gate: Determines how much information from the current cell state will be output.
Suppose we have an input sequence of length $T$, denoted $x_1, x_2, \dots, x_T$. At each time step $t$, the LSTM cell performs the following computations:

- Forget Gate:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

where $f_t$ is the forget gate output, $W_f$ is the forget gate weight matrix, $h_{t-1}$ is the previous hidden state, $x_t$ is the input at time $t$, $b_f$ is the bias, and $\sigma$ is the Sigmoid activation function.

- Input Gate:

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

where $i_t$ is the input gate output and $\tilde{C}_t$ is the candidate cell state.

- Cell State Update:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

where $C_t$ is the updated cell state and $\odot$ denotes element-wise multiplication.

- Output Gate:

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$
$$h_t = o_t \odot \tanh(C_t)$$

where $o_t$ is the output gate output and $h_t$ is the hidden state at time $t$.
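To make the formulas concrete, here is a minimal sketch of a single LSTM step written with plain tensor operations. The parameter names (`W_f`, `b_f`, etc.) and the concatenation of $h_{t-1}$ with $x_t$ mirror the equations above and are purely illustrative; `nn.LSTM` uses its own internal parameter layout.

```python
import torch

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_C, W_o, b_f, b_i, b_C, b_o):
    """One LSTM time step implementing the gate equations above.

    x_t: (batch, input_size); h_prev, c_prev: (batch, hidden_size)
    Each W_* has shape (hidden_size, hidden_size + input_size).
    """
    hx = torch.cat([h_prev, x_t], dim=1)          # [h_{t-1}, x_t]
    f_t = torch.sigmoid(hx @ W_f.T + b_f)         # forget gate
    i_t = torch.sigmoid(hx @ W_i.T + b_i)         # input gate
    c_tilde = torch.tanh(hx @ W_C.T + b_C)        # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde            # cell state update
    o_t = torch.sigmoid(hx @ W_o.T + b_o)         # output gate
    h_t = o_t * torch.tanh(c_t)                   # new hidden state
    return h_t, c_t

# Illustrative usage with random parameters
batch, input_size, hidden_size = 4, 10, 20
weights = [torch.randn(hidden_size, hidden_size + input_size) for _ in range(4)]
biases = [torch.zeros(hidden_size) for _ in range(4)]
h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)
h, c = lstm_step(torch.randn(batch, input_size), h, c, *weights, *biases)
```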
Advantages of LSTM
- Captures long-range dependencies: Thanks to the gating mechanism.
- Selective information retention: The input and forget gates allow the model to keep or discard information efficiently.
- Stable gradient propagation: LSTM alleviates gradient vanishing/exploding problems, making training more stable.
Using LSTM in PyTorch
PyTorch provides a powerful `nn.LSTM` module, making it convenient to build and train LSTM models. Below are examples demonstrating basic usage of `nn.LSTM`, multi-layer LSTMs, `nn.LSTMCell`, and manually stacked multi-layer LSTMCells.
1. Basic Usage of LSTM
```python
import torch
import torch.nn as nn

# Define LSTM model
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        # Define the LSTM layer
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        c0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

input_size = 10
hidden_size = 20
output_size = 1
model = SimpleLSTM(input_size, hidden_size, output_size)
print(model)
```

```
SimpleLSTM(
  (lstm): LSTM(10, 20, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
```
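As a quick sanity check (not part of the original example), you can push a random batch through the model and confirm the output shape; the batch size of 4 and sequence length of 7 are arbitrary:

```python
x = torch.randn(4, 7, input_size)   # (batch, seq, feature)
y = model(x)
print(y.shape)                      # torch.Size([4, 1]) - one prediction per sequence
```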
2. Multi-layer LSTM
```python
class MultiLayerLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(MultiLayerLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out

num_layers = 3
model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
print(model)
```

```
MultiLayerLSTM(
  (lstm): LSTM(10, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
```
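The stacked model above returns only the final prediction, but it can help to see what `nn.LSTM` itself returns when `num_layers > 1`. An illustrative shape check (the batch size of 4 and sequence length of 7 are arbitrary; the other sizes come from the hyperparameters above):

```python
x = torch.randn(4, 7, input_size)   # (batch, seq, feature)
lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
out, (hn, cn) = lstm(x)
print(out.shape)  # torch.Size([4, 7, 20]) - per-step outputs of the top layer
print(hn.shape)   # torch.Size([3, 4, 20]) - final hidden state of each of the 3 layers
print(cn.shape)   # torch.Size([3, 4, 20]) - final cell state of each of the 3 layers
```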
3. Using LSTMCell
```python
class LSTMWithCell(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMWithCell, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.hidden_size)
        c = torch.zeros(batch_size, self.hidden_size)
        for t in range(seq_len):
            h, c = self.lstm_cell(x[:, t, :], (h, c))
        out = self.fc(h)
        return out

model = LSTMWithCell(input_size, hidden_size, output_size)
print(model)
```

```
LSTMWithCell(
  (lstm_cell): LSTMCell(10, 20)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
```
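Because `nn.LSTMCell` exposes each time step explicitly, you can intervene inside the loop, for example to keep every intermediate hidden state. A small sketch (collecting the states into `outputs` is an illustrative addition, not part of the model above):

```python
cell = nn.LSTMCell(input_size, hidden_size)
x = torch.randn(4, 7, input_size)        # (batch, seq, feature)
h = torch.zeros(4, hidden_size)
c = torch.zeros(4, hidden_size)

states = []
for t in range(x.size(1)):
    h, c = cell(x[:, t, :], (h, c))
    states.append(h)                     # keep the hidden state of every step

outputs = torch.stack(states, dim=1)
print(outputs.shape)  # torch.Size([4, 7, 20]) - same layout as nn.LSTM's `out`
```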
4. Multi-layer LSTMCell
```python
class MultiLayerLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(MultiLayerLSTMCell, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([
            nn.LSTMCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]
        c = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]
        for t in range(seq_len):
            input_t = x[:, t, :]
            for layer in range(self.num_layers):
                h[layer], c[layer] = self.cells[layer](input_t, (h[layer], c[layer]))
                input_t = h[layer]
        out = self.fc(h[-1])
        return out

num_layers = 3
model = MultiLayerLSTMCell(input_size, hidden_size, output_size, num_layers)
print(model)
```

```
MultiLayerLSTMCell(
  (cells): ModuleList(
    (0): LSTMCell(10, 20)
    (1): LSTMCell(20, 20)
    (2): LSTMCell(20, 20)
  )
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
```
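A quick way to confirm that the manually stacked cells expose the same interface as the `nn.LSTM`-based version is to run the same random batch through both. This is only an illustrative shape check; the actual outputs differ because the two models are initialized independently:

```python
x = torch.randn(4, 7, input_size)
cell_model = MultiLayerLSTMCell(input_size, hidden_size, output_size, num_layers)
lstm_model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
print(cell_model(x).shape)  # torch.Size([4, 1])
print(lstm_model(x).shape)  # torch.Size([4, 1])
```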
Practical Example: Text Classification
Data Preparation
```python
batch_size = 32
seq_len = 50
input_size = 100
hidden_size = 128
output_size = 2
num_layers = 2

inputs = torch.randn(batch_size, seq_len, input_size)
labels = torch.randint(0, output_size, (batch_size,))
```
Define Model, Loss, Optimizer
```python
model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
```
Training
```python
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
```
```
Epoch [1/10], Loss: 0.6932
Epoch [2/10], Loss: 0.6796
Epoch [3/10], Loss: 0.6660
Epoch [4/10], Loss: 0.6512
Epoch [5/10], Loss: 0.6342
Epoch [6/10], Loss: 0.6139
Epoch [7/10], Loss: 0.5893
Epoch [8/10], Loss: 0.5595
Epoch [9/10], Loss: 0.5239
Epoch [10/10], Loss: 0.4827
```
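After training, you would normally evaluate on held-out data. With the random data used here the numbers are meaningless, but the pattern looks like this (an illustrative sketch only):

```python
model.eval()
with torch.no_grad():
    logits = model(inputs)                  # (batch, output_size)
    preds = logits.argmax(dim=1)            # predicted class per sequence
    accuracy = (preds == labels).float().mean()
    print(f'Accuracy: {accuracy.item():.4f}')
```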
An LSTM contains multiple gating mechanisms, so gradients must propagate through these gates during backpropagation. This increases the computation required per update, and training can therefore be slower than with a simple RNN. In exchange, the gating structure greatly mitigates gradient vanishing and exploding.
Notes
- Batch dimension: Setting `batch_first=True` means inputs/outputs follow the shape `(batch, seq, feature)`.
- Initialization: Hidden and cell states must be initialized for each batch, typically as zeros.
- Gradient clipping: Use `torch.nn.utils.clip_grad_norm_` to prevent gradient explosion (see the sketch after this list).
- Variable-length sequences: Use `pack_padded_sequence` and `pad_packed_sequence` (see the sketch after this list).
- Bidirectional LSTM: Enable via `bidirectional=True`:

```python
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
```
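For the gradient-clipping and variable-length notes above, here is a minimal sketch; the `lengths` values, the padded batch, and the `max_norm` value are made up for illustration:

```python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Gradient clipping: call between loss.backward() and optimizer.step()
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Variable-length sequences: pack the padded batch before the LSTM, unpack after
lengths = torch.tensor([50, 42, 37, 20])     # true length of each sequence (hypothetical)
padded = torch.randn(4, 50, input_size)      # zero-padded batch: (batch, max_seq, feature)
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
packed_out, (hn, cn) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)  # torch.Size([4, 50, 128]) - padded back to the longest length in the batch
```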
Summary
This article introduced the basic principles of LSTM and demonstrated its implementation in PyTorch. Through theoretical formulas and code examples, we explored single-layer/multi-layer LSTMs, LSTMCell, and multi-layer LSTMCell. Although Transformers have achieved breakthrough results in many tasks, LSTMs remain widely used in practical applications due to their simplicity and effectiveness.
