Recurrent Neural Networks (RNN)

  1. Basic Principles of RNNs
    1. What Is an RNN?
    2. Structure of an RNN
    3. Gradient Vanishing and Exploding in RNNs
  2. Using RNNs in PyTorch
    1. Basic RNN
    2. Multi-Layer RNN
    3. Using RNNCell
    4. Multi-Layer RNNCell
  3. Practical Example: Text Classification
    1. Data Preparation
    2. Model, Loss, and Optimizer
    3. Training
    4. Example Output
  4. Notes
  5. Summary

Recurrent Neural Networks (RNNs) are particularly effective for sequential data, which makes them a natural fit for tasks such as natural language processing and time-series prediction. This post provides a detailed introduction to the principles behind RNNs and demonstrates how to use them in PyTorch, including constructing multi-layer RNNs, using RNNCell, and building multi-layer RNNCell models. We will also analyze the internal mechanisms of RNNs using formulas and provide practical code examples to help you gain a comprehensive understanding.

Basic Principles of RNNs

What Is an RNN?

Traditional feed-forward neural networks perform well on tasks involving independent samples, but they struggle with sequential data because they cannot capture temporal dependencies. RNNs introduce “recurrent connections,” allowing the hidden state from the previous time step to influence the computation at the current step, making it possible to capture temporal patterns within sequences.

Structure of an RNN

The core of an RNN is a recurrent unit capable of transmitting information across time steps. For an input sequence of length $T$, $x_1, x_2, \dots, x_T$, the computation at each time step $t$ proceeds as follows (a minimal tensor-level sketch of the recurrence follows this list):

  1. Hidden State Update:

    $h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$

    where:

    • $h_t$ is the hidden state at time $t$
    • $x_t$ is the input at time $t$
    • $W_{xh}$ and $W_{hh}$ are the weight matrices from input to hidden state and from hidden state to hidden state, respectively
    • $b_h$ is the bias term
    • $\tanh$ is the activation function (ReLU is also used in some cases)
  2. Output Calculation (Optional):

    $y_t = W_{hy} h_t + b_y$

    where:

    • $y_t$ is the output at time $t$
    • $W_{hy}$ is the weight matrix from hidden state to output
    • $b_y$ is the output bias term
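
To make this recurrence concrete, here is a minimal sketch of the hidden-state update written with plain tensor operations. The sizes and the random initialization are illustrative assumptions for this sketch, not the parameterization PyTorch uses internally.

import torch

# Illustrative sizes (assumptions for this sketch)
input_size, hidden_size, seq_len = 4, 6, 5

# Parameters of the recurrence h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
W_xh = torch.randn(hidden_size, input_size) * 0.1
W_hh = torch.randn(hidden_size, hidden_size) * 0.1
b_h = torch.zeros(hidden_size)

x = torch.randn(seq_len, input_size)   # one sequence, no batch dimension
h = torch.zeros(hidden_size)           # h_0

for t in range(seq_len):
    h = torch.tanh(W_xh @ x[t] + W_hh @ h + b_h)

print(h.shape)  # torch.Size([6]) -- the final hidden state h_T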

Gradient Vanishing and Exploding in RNNs

During training, RNNs often suffer from vanishing or exploding gradients, especially when processing long sequences. Since gradients must propagate through repeated weight multiplications across time steps, they may exponentially shrink or grow. LSTM and GRU architectures were introduced to mitigate these issues by incorporating gating mechanisms that stabilize gradient flow.
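
A quick way to observe this empirically is to backpropagate a loss that depends only on the last time step of a long sequence and compare how much gradient reaches the first step. The sequence length and layer sizes below are arbitrary choices for illustration.

import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)

x = torch.randn(1, 200, 8, requires_grad=True)   # one sequence of 200 steps
out, hn = rnn(x)
out[:, -1, :].sum().backward()                   # loss depends only on the last step

print(x.grad[0, -1].norm())   # gradient at the last time step
print(x.grad[0, 0].norm())    # gradient at the first step -- typically orders of magnitude smaller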

Using RNNs in PyTorch

PyTorch offers a powerful RNN module that simplifies building and training RNN models. Below are examples of using a standard RNN, multi-layer RNN, RNNCell, and multi-layer RNNCell.

1. Basic RNN

import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size); h0: (num_layers, batch, hidden_size)
        h0 = torch.zeros(1, x.size(0), self.hidden_size, device=x.device)
        out, hn = self.rnn(x, h0)         # out: (batch, seq_len, hidden_size)
        out = self.fc(out[:, -1, :])      # predict from the last time step
        return out

input_size = 10
hidden_size = 20
output_size = 1
model = SimpleRNN(input_size, hidden_size, output_size)
print(model)
SimpleRNN(
  (rnn): RNN(10, 20, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
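
A quick sanity check with a random batch (shapes chosen arbitrarily) confirms the output dimensions:

x = torch.randn(5, 15, input_size)   # (batch=5, seq_len=15, input_size=10)
y = model(x)
print(y.shape)   # torch.Size([5, 1]) -- one prediction per sequence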

2. Multi-Layer RNN

Multi-layer RNNs (deep RNNs) stack several recurrent layers to capture more complex temporal structures.

class MultiLayerRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(MultiLayerRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # h0: one initial hidden state per layer -> (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, hn = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])      # top layer's last time step
        return out

num_layers = 3
input_size = 10
hidden_size = 20
output_size = 1
model = MultiLayerRNN(input_size, hidden_size, output_size, num_layers)
print(model)
MultiLayerRNN(
  (rnn): RNN(10, 20, num_layers=3, batch_first=True)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
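
The multi-layer model is used the same way. Inside forward, hn has shape (num_layers, batch, hidden_size), and for a uni-directional RNN its last entry equals the last time step of out:

x = torch.randn(5, 15, input_size)   # (batch=5, seq_len=15, input_size=10)
y = model(x)
print(y.shape)   # torch.Size([5, 1])
# hn would have shape (3, 5, 20), and hn[-1] == out[:, -1, :] for this uni-directional RNN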

3. Using RNNCell

nn.RNNCell gives fine-grained control over the computation at each time step, which is useful when you need to customize the recurrence loop.

class RNNWithCell(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNWithCell, self).__init__()
        self.hidden_size = hidden_size
        self.rnn_cell = nn.RNNCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
        # Step through the sequence manually, one time step at a time
        for t in range(seq_len):
            h = self.rnn_cell(x[:, t, :], h)
        return self.fc(h)                 # classify from the final hidden state

model = RNNWithCell(input_size, hidden_size, output_size)
print(model)
RNNWithCell(
  (rnn_cell): RNNCell(10, 20)
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
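
Because the loop is explicit, it is easy to keep every intermediate hidden state. The snippet below reruns the recurrence outside the model just to show the idea; the tensors and sizes are illustrative.

x = torch.randn(5, 15, input_size)
print(model(x).shape)   # torch.Size([5, 1])

# Collecting the hidden state at every time step
h = torch.zeros(5, hidden_size)
states = []
for t in range(x.size(1)):
    h = model.rnn_cell(x[:, t, :], h)
    states.append(h)
states = torch.stack(states, dim=1)
print(states.shape)     # torch.Size([5, 15, 20])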

4. Multi-Layer RNNCell

Manually stacking multiple RNNCell units reproduces a deep RNN while keeping full control over each layer's recurrence.

class MultiLayerRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super(MultiLayerRNNCell, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Layer 0 consumes the input; deeper layers consume the layer below
        self.cells = nn.ModuleList([
            nn.RNNCell(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        h = [torch.zeros(batch_size, self.hidden_size, device=x.device)
             for _ in range(self.num_layers)]
        for t in range(seq_len):
            input_t = x[:, t, :]
            for layer in range(self.num_layers):
                h[layer] = self.cells[layer](input_t, h[layer])
                input_t = h[layer]        # feed this layer's state to the layer above
        return self.fc(h[-1])             # top layer's final hidden state

num_layers = 3
model = MultiLayerRNNCell(input_size, hidden_size, output_size, num_layers)
print(model)
MultiLayerRNNCell(
  (cells): ModuleList(
    (0): RNNCell(10, 20)
    (1): RNNCell(20, 20)
    (2): RNNCell(20, 20)
  )
  (fc): Linear(in_features=20, out_features=1, bias=True)
)
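
The stacked-cell model exposes the same interface as the nn.RNN-based MultiLayerRNN above, which a quick shape check confirms:

x = torch.randn(5, 15, input_size)
print(model(x).shape)   # torch.Size([5, 1])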

Practical Example: Text Classification

Data Preparation

batch_size = 32
seq_len = 50
input_size = 100
hidden_size = 128
output_size = 2
num_layers = 2

inputs = torch.randn(batch_size, seq_len, input_size)   # stands in for embedded text: (batch, seq_len, embedding_dim)
labels = torch.randint(0, output_size, (batch_size,))   # random binary class labels

Model, Loss, and Optimizer

model = MultiLayerRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Training

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
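
If the loss diverges on longer sequences, gradient clipping (see the Notes below) can be added between the backward pass and the optimizer step. A sketch of the modified loop, with an illustrative max norm of 1.0:

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)   # cap the total gradient norm
    optimizer.step()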

Example Output

Epoch [1/10], Loss: 0.6801
Epoch [2/10], Loss: 0.5558
Epoch [3/10], Loss: 0.4489
...
Epoch [10/10], Loss: 0.0228

Notes

  1. Batch Dimension: Use batch_first=True so inputs follow (batch, seq, feature) format.
  2. Hidden State Initialization: Reset hidden states before each batch.
  3. Gradient Clipping: Use torch.nn.utils.clip_grad_norm_ to prevent gradient explosion.
  4. Variable-Length Sequences: Use pack_padded_sequence and pad_packed_sequence for efficient processing (see the sketch after this list).
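
As a sketch of point 4, the snippet below packs a padded batch before running it through an RNN and unpacks the result afterwards. The batch contents and sequence lengths are made up for illustration.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Hypothetical padded batch: 3 sequences of true lengths 5, 3, and 2, padded to length 5
padded = torch.randn(3, 5, 10)
lengths = torch.tensor([5, 3, 2])

rnn = nn.RNN(input_size=10, hidden_size=20, batch_first=True)

packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
packed_out, hn = rnn(packed)                                  # nn.RNN accepts a PackedSequence
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(out.shape)   # torch.Size([3, 5, 20]) -- padded positions are zero-filled
print(hn.shape)    # torch.Size([1, 3, 20]) -- hidden state at each sequence's true last step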

Summary

This article introduced the core principles of RNNs and their PyTorch implementation. By combining theoretical formulas with code examples, we demonstrated how to build single-layer and multi-layer RNNs, as well as flexible models based on RNNCell. Thanks to their simplicity and ability to capture temporal patterns, RNNs remain widely used in many sequence modeling tasks.
