Long Short-Term Memory Network (LSTM)

Contents
  1. Basic Principles of LSTM
    1. What is LSTM?
    2. Structure of LSTM
    3. Advantages of LSTM
  2. Using LSTM in PyTorch
    1. Basic Usage of LSTM
    2. Multi-layer LSTM
    3. Using LSTMCell
    4. Multi-layer LSTMCell
  3. Practical Example: Text Classification
    1. Data Preparation
    2. Define Model, Loss, Optimizer
    3. Training
  4. Notes
  5. Summary
  6. References

The Long Short-Term Memory Network (LSTM) is an important variant of Recurrent Neural Networks (RNN), designed to address the issues of gradient vanishing and gradient exploding that traditional RNNs encounter when processing long sequences. By introducing gating mechanisms, LSTM effectively captures long-range dependencies, making it particularly powerful in tasks such as natural language processing and time-series forecasting.

Basic Principles of LSTM

What is LSTM?

Traditional RNNs tend to suffer from gradient vanishing or exploding when handling long sequences, making it difficult for the model to capture long-term dependencies. LSTM introduces three gates (input gate, forget gate, and output gate) along with a cell state to regulate the flow of information, thereby mitigating these issues. The design of LSTM enables it to retain and propagate important information over long sequences, significantly improving the model’s memory capability.

Structure of LSTM

The core of an LSTM is the LSTM cell, and each cell contains the following key components:

  1. Cell State: Runs through the entire sequence, serving as a primary pathway for information.
  2. Forget Gate: Determines how much past information to retain.
  3. Input Gate: Decides how much new information to add to the cell state.
  4. Output Gate: Determines how much information from the current cell state will be output.

Suppose we have an input sequence of length $T$: $x = (x_1, x_2, \ldots, x_T)$.
At each time step $t$, the LSTM performs the following computations:

  1. Forget Gate:

    $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$

    where:

    • $f_t$ is the forget gate output
    • $W_f$ is the forget gate weight matrix
    • $h_{t-1}$ is the previous hidden state
    • $x_t$ is the input at time $t$
    • $b_f$ is the bias
    • $\sigma$ is the Sigmoid activation function
  2. Input Gate:

    $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$
    $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$

    where:

    • $i_t$ is the input gate output
    • $\tilde{C}_t$ is the candidate cell state
  3. Cell State Update:

    $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$

    where:

    • $C_t$ is the updated cell state
    • $\odot$ denotes element-wise multiplication
  4. Output Gate:

    $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$
    $h_t = o_t \odot \tanh(C_t)$

    where:

    • $o_t$ is the output gate output
    • $h_t$ is the hidden state at time $t$
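To make these formulas concrete, here is a minimal sketch of a single LSTM step written with plain PyTorch tensor operations. The weight matrices W_f, W_i, W_C, W_o act on the concatenation [h_{t-1}, x_t] exactly as in the notation above; the function name and argument layout are illustrative and do not mirror how nn.LSTM stores its parameters internally.

import torch

def lstm_step(x_t, h_prev, c_prev, W_f, b_f, W_i, b_i, W_C, b_C, W_o, b_o):
    # z = [h_{t-1}, x_t], shape (batch, hidden_size + input_size)
    z = torch.cat([h_prev, x_t], dim=1)
    f_t = torch.sigmoid(z @ W_f.T + b_f)       # forget gate
    i_t = torch.sigmoid(z @ W_i.T + b_i)       # input gate
    c_tilde = torch.tanh(z @ W_C.T + b_C)      # candidate cell state
    c_t = f_t * c_prev + i_t * c_tilde         # cell state update (element-wise)
    o_t = torch.sigmoid(z @ W_o.T + b_o)       # output gate
    h_t = o_t * torch.tanh(c_t)                # new hidden state
    return h_t, c_t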

Advantages of LSTM

  • Captures long-range dependencies: Thanks to the gating mechanism.
  • Selective information retention: The input and forget gates allow the model to keep or discard information efficiently.
  • Stable gradient propagation: LSTM alleviates gradient vanishing/exploding problems, making training more stable.

Using LSTM in PyTorch

PyTorch provides a powerful LSTM module, making it convenient to build and train LSTM models. Below are examples demonstrating how to use the LSTM module, build multi-layer LSTMs, use LSTMCell, and construct multi-layer LSTMCell.

1. Basic Usage of LSTM

import torch
import torch.nn as nn

# Define LSTM model
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        # Define the LSTM layer
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # Initialize hidden and cell states with zeros: (num_layers, batch, hidden_size)
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        c0 = torch.zeros(1, x.size(0), self.hidden_size)
        # out holds the hidden state of every time step: (batch, seq, hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Classify from the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

input_size = 10
hidden_size = 20
output_size = 1
model = SimpleLSTM(input_size, hidden_size, output_size)
print(model)
SimpleLSTM(
  (lstm): LSTM(10, 20, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
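A quick shape check shows the batch_first convention in action; the batch size and sequence length here are arbitrary:

x = torch.randn(32, 5, input_size)   # (batch, seq, feature) because batch_first=True
out = model(x)
print(out.shape)   # torch.Size([32, 1])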

2. Multi-layer LSTM

class MultiLayerLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(MultiLayerLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # One initial hidden/cell state per stacked layer
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Classify from the top layer's last time step
        out = self.fc(out[:, -1, :])
        return out

num_layers = 3
model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
print(model)
MultiLayerLSTM(
  (lstm): LSTM(10, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
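For a stacked LSTM, the hidden and cell states carry one slice per layer. Calling the underlying nn.LSTM layer directly (its initial states default to zeros when omitted) makes the shapes visible; the numbers below follow from num_layers=3 and hidden_size=20:

x = torch.randn(32, 5, input_size)
out, (hn, cn) = model.lstm(x)    # h0/c0 default to zeros
print(out.shape)   # torch.Size([32, 5, 20]) - top-layer outputs for every time step
print(hn.shape)    # torch.Size([3, 32, 20]) - one final hidden state per layer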

3. Using LSTMCell

class LSTMWithCell(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMWithCell, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = nn.LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.hidden_size)
        c = torch.zeros(batch_size, self.hidden_size)
        # Step through the sequence manually, one time step at a time
        for t in range(seq_len):
            h, c = self.lstm_cell(x[:, t, :], (h, c))
        # h is now the hidden state of the final time step
        out = self.fc(h)
        return out

model = LSTMWithCell(input_size, hidden_size, output_size)
print(model)
LSTMWithCell(
  (lstm_cell): LSTMCell(10, 20)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
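One advantage of the cell-level loop is access to every intermediate hidden state. A small variation (separate from the class above) collects them into a tensor shaped like nn.LSTM's output:

cell = nn.LSTMCell(input_size, hidden_size)
x = torch.randn(32, 5, input_size)
h = torch.zeros(32, hidden_size)
c = torch.zeros(32, hidden_size)
states = []
for t in range(x.size(1)):
    h, c = cell(x[:, t, :], (h, c))
    states.append(h)
states = torch.stack(states, dim=1)   # (batch, seq, hidden_size)
print(states.shape)   # torch.Size([32, 5, 20])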

4. Multi-layer LSTMCell

class MultiLayerLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(MultiLayerLSTMCell, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.cells = nn.ModuleList([
            nn.LSTMCell(input_size if i == 0 else hidden_size, hidden_size) 
            for i in range(num_layers)
        ])
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        # Separate hidden/cell state for every layer
        h = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]
        c = [torch.zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]
        for t in range(seq_len):
            input_t = x[:, t, :]
            for layer in range(self.num_layers):
                h[layer], c[layer] = self.cells[layer](input_t, (h[layer], c[layer]))
                # The output of each layer feeds the next layer
                input_t = h[layer]
        # Classify from the top layer's final hidden state
        out = self.fc(h[-1])
        return out

num_layers = 3
model = MultiLayerLSTMCell(input_size, hidden_size, output_size, num_layers)
print(model)
MultiLayerLSTMCell(
  (cells): ModuleList(
    (0): LSTMCell(10, 20)
    (1): LSTMCell(20, 20)
    (2): LSTMCell(20, 20)
  )
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
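The stack of LSTMCells holds exactly the same parameters as nn.LSTM with num_layers=3 (input weights, recurrent weights, and two bias vectors per layer), which makes for a handy consistency check; this comparison is illustrative and not part of the original example:

ref = nn.LSTM(input_size, hidden_size, num_layers)
print(sum(p.numel() for p in model.cells.parameters()) ==
      sum(p.numel() for p in ref.parameters()))   # True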

Practical Example: Text Classification

Data Preparation

batch_size = 32
seq_len = 50
input_size = 100  
hidden_size = 128
output_size = 2  
num_layers = 2

inputs = torch.randn(batch_size, seq_len, input_size)
labels = torch.randint(0, output_size, (batch_size,))
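The random tensor above stands in for text that has already been embedded. In a real classification task, integer token ids would first pass through an embedding layer; a minimal sketch, where vocab_size is a made-up value:

vocab_size = 5000   # hypothetical vocabulary size
embedding = nn.Embedding(vocab_size, input_size)
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))   # integer token ids
inputs = embedding(token_ids)   # (batch_size, seq_len, input_size), same shape as above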

Define Model, Loss, Optimizer

model = MultiLayerLSTM(input_size, hidden_size, output_size, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Training

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    
    loss.backward()
    optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
Epoch [1/10], Loss: 0.6932
Epoch [2/10], Loss: 0.6796
Epoch [3/10], Loss: 0.6660
Epoch [4/10], Loss: 0.6512
Epoch [5/10], Loss: 0.6342
Epoch [6/10], Loss: 0.6139
Epoch [7/10], Loss: 0.5893
Epoch [8/10], Loss: 0.5595
Epoch [9/10], Loss: 0.5239
Epoch [10/10], Loss: 0.4827

LSTM contains multiple gating mechanisms, and gradients must propagate through all of these gates during backpropagation, so each update costs more computation than in a simple RNN and training can be slower. In exchange, the gating structure greatly mitigates gradient vanishing and exploding.
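After training, predictions come from the class with the highest logit. A minimal evaluation sketch on the same toy data (a real setup would score a held-out set instead):

model.eval()
with torch.no_grad():
    logits = model(inputs)
    preds = logits.argmax(dim=1)
    accuracy = (preds == labels).float().mean().item()
print(f'Accuracy: {accuracy:.2f}')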

Notes

  1. Batch dimension: Setting batch_first=True means inputs/outputs follow the shape (batch, seq, feature).
  2. Initialization: Hidden and cell states must be initialized for each batch, typically as zeros.
  3. Gradient clipping: Use torch.nn.utils.clip_grad_norm_ to prevent gradient explosion (see the sketch after this list).
  4. Variable-length sequences: Use pack_padded_sequence and pad_packed_sequence (also shown in the sketch below).
  5. Bidirectional LSTM: Enable via bidirectional=True.
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
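Notes 3 and 4 in code form: a hedged sketch (the lengths and tensors are made up) showing where gradient clipping fits in the training loop and how to feed padded, variable-length batches through an LSTM:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Gradient clipping: inside the training loop, clip after backward() and before step()
outputs = model(inputs)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Variable-length sequences: pack before the LSTM, unpack after
lengths = torch.tensor([50, 42, 37])        # true lengths of three padded sequences
padded = torch.randn(3, 50, input_size)     # zero-padded batch
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
packed_out, (hn, cn) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)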

Summary

This article introduced the basic principles of LSTM and demonstrated its implementation in PyTorch. Through theoretical formulas and code examples, we explored single-layer/multi-layer LSTMs, LSTMCell, and multi-layer LSTMCell. Although Transformers have achieved breakthrough results in many tasks, LSTMs remain widely used in practical applications due to their simplicity and effectiveness.

References