RNN architecture
Learning to remember
#️⃣   ⌛  ~1 h 🤓  Intermediate
12.06.2023
#55


🎓 74/167

This post is a part of the Fundamental NN architectures educational series from my free course. Please keep in mind that the correct sequence of posts is outlined on the course page, while it can be arbitrary in Research.

I'm also happy to announce that I've started working on standalone paid courses, so you could support my work and get cheap educational material. These courses will be of completely different quality, with more theoretical depth and niche focus, and will feature challenging projects, quizzes, exercises, video lectures and supplementary stuff. Stay tuned!


In this article, i will focus on the core concepts, theoretical underpinnings, and practical implementations of recurrent neural networks (rnns). this piece is intended for individuals with substantial machine learning experience who are looking to deeply understand not only the fundamental structures of rnns, but also the reasoning behind their design, their various important modifications, and the ways in which they can be used effectively in real-world applications. i will follow the general outline below, aiming to combine a thorough theoretical discussion with plenty of practical context and references to research:

  1. introduction
  2. fundamentals of rnn
  3. lstm: long short-term memory networks
  4. gru: gated recurrent units
  5. bidirectional rnn
  6. embedding layer
  7. training and optimization
  8. advanced extensions
  9. case studies and applications

because recurrent neural networks underpin so many sequence-based learning tasks — from language modeling and machine translation to speech recognition and time series analysis — it is especially critical to explore them in detail. i'll begin by defining rnns and setting their historical context, then move into the intricacies of how they process sequence data, including a discussion of backpropagation through time (bptt). i'll dedicate several chapters to the modern variants of rnn, notably lstm and gru, which were introduced to address challenges like vanishing and exploding gradients. these challenges, in turn, limited the depth and sequence length that basic rnns could manage.

the article also addresses advanced extensions such as attention mechanisms (briefly, because they become especially central in transformer-based models), and combinations of rnns with convolutional layers for multi-modal sequence data. i will close by discussing real-world use cases, relevant code snippets, and references to state-of-the-art research papers.

2. fundamentals of rnn

2.1 definition and historical context of recurrent neural networks

recurrent neural networks are a class of artificial neural networks designed to capture temporal or sequential patterns in data. unlike feed-forward networks, where signals only travel in one direction (from input to hidden layers to output layers), rnns contain recurrent (or feedback) connections that can pass information from one time step to the next. this allows rnns to maintain a form of internal state that can, in principle, store information about previous inputs and thereby handle variable-length sequences.

the earliest formulations of rnns date back to the 1980s and 1990s. two notable early architectures include:

  • elman networks: introduced by jeff elman, these networks have a "context" layer that stores a copy of the hidden layer's previous state. the hidden state at time $t$ therefore depends on both the input $x_t$ and the hidden state at time $t-1$.
  • jordan networks: introduced by michael jordan, these have recurrent connections from the output layer to a set of context units, which feed into the hidden layer.

in modern usage, these older forms of recurrent networks are collectively referred to as "simple rnns" or "vanilla rnns." while conceptually straightforward, they exhibit significant difficulties in practice when processing long sequences due to exploding and vanishing gradients (problems discussed more thoroughly in chapter 2.5). hence, later researchers, notably hochreiter and schmidhuber (1997), introduced the lstm architecture. cho et al. (2014) proposed the gated recurrent unit (gru) variant, which similarly addresses the limitations of simple rnns.

2.2 importance of sequence modeling in machine learning

sequence modeling is central to many tasks:

  • language modeling: words or tokens in a text are processed sequentially.
  • speech recognition: audio waveforms are inherently sequential in time.
  • time series analysis: from financial forecasting to industrial sensor monitoring, the data are temporal.
  • music generation: notes unfold over time, each depending on the context of preceding notes.
  • video processing: frames in video are sequential with temporal dependencies, though many modern approaches rely on 3d convolutions or transformers as well.

the unique advantage of rnns is their parameter sharing across different time steps. in a feed-forward network, each input dimension might be associated with a separate set of weights, but in rnns, the same recurrent weight matrices are reused at each time step, capturing patterns that shift in time.

2.3 core rnn architecture

a simple rnn cell transforms an input $x_t$ and a hidden state $h_{t-1}$ from the previous time step into the next hidden state $h_t$ using a function such as

$$h_t = \sigma(W_{hh} \, h_{t-1} + W_{xh} \, x_t + b_h),$$

where:

  • $W_{hh}$ is the recurrent weight matrix that connects the hidden state from $t-1$ to the hidden state at time $t$.
  • $W_{xh}$ is the input-to-hidden weight matrix.
  • $b_h$ is the bias vector.
  • $\sigma(\cdot)$ is typically a nonlinear activation function like $\tanh$ or $\mathrm{ReLU}$.

this is the minimal building block. in a classification problem, the output at time step $t$ might be computed as:

$$y_t = \mathrm{softmax}(W_{hy} \, h_t + b_y),$$

where $W_{hy}$ is the hidden-to-output weight matrix and $b_y$ is the bias for the output layer.
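
to make the two equations above concrete, here is a minimal numpy sketch of a single time step followed by the softmax output. the dimensions, the random initialization, and the helper names rnn_step and softmax are assumptions chosen for illustration; they are not taken from any particular library.

```python
import numpy as np

def softmax(z):
    # subtract the max before exponentiating for numerical stability
    e = np.exp(z - z.max())
    return e / e.sum()

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    # h_t = tanh(W_hh h_{t-1} + W_xh x_t + b_h), with tanh as the activation
    return np.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)

rng = np.random.default_rng(0)
input_dim, hidden_dim, num_classes = 3, 4, 2   # illustrative sizes

W_xh = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)
W_hy = rng.normal(scale=0.1, size=(num_classes, hidden_dim))
b_y = np.zeros(num_classes)

x_t = rng.normal(size=input_dim)
h_prev = np.zeros(hidden_dim)          # initial hidden state

h_t = rnn_step(x_t, h_prev, W_xh, W_hh, b_h)
y_t = softmax(W_hy @ h_t + b_y)        # class probabilities at this step
print(h_t.shape, y_t.shape, y_t.sum())
```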

2.4 forward propagation through time

in typical feed-forward networks, we apply transformations layer by layer. in rnns, we apply the same cell transformations across time steps. conceptually, we can "unfold" the rnn for $T$ steps. for a sequence $x_1, x_2, \ldots, x_T$, the hidden state at each step is computed in a chain-like manner.

unrolled representation
to visualize the recurrence properly, we often depict the network "unrolled" over time:

x_1 ---> [ RNN cell ] ---> h_1 ---> [ RNN cell ] ---> h_2 ---> ... ---> [ RNN cell ] ---> h_T

the important thing is that the same set of parameters ($W_{hh}$, $W_{xh}$, etc.) is used at every time step. this drastically reduces the total number of parameters needed to process sequences of arbitrary length.
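
the unrolling can be sketched as a plain loop that reuses one set of weights at every step. this is an illustrative numpy version under assumed sizes (the function name rnn_forward is hypothetical); note that the parameter count does not change when the sequence length does.

```python
import numpy as np

def rnn_forward(xs, h0, W_xh, W_hh, b_h):
    """run a tanh rnn over xs of shape (T, input_dim); returns (T, hidden_dim)."""
    h = h0
    hs = []
    for x_t in xs:                                   # one iteration per time step
        h = np.tanh(W_hh @ h + W_xh @ x_t + b_h)     # same weights at every step
        hs.append(h)
    return np.stack(hs)

rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 5                         # illustrative sizes
W_xh = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)
n_params = W_xh.size + W_hh.size + b_h.size          # independent of T

for T in (4, 50):
    xs = rng.normal(size=(T, input_dim))
    hs = rnn_forward(xs, np.zeros(hidden_dim), W_xh, W_hh, b_h)
    print(T, hs.shape, n_params)
```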

2.5 backpropagation through time (bptt)

training rnns typically involves backpropagation through time (bptt), an extension of the standard backpropagation algorithm that accounts for the unrolled structure. to compute the gradients of a loss function with respect to the model parameters, we sum the contributions of each time step's partial gradients. in practice, we usually perform truncated bptt, limiting how far we go back in time before stopping gradient flow.

vanishing and exploding gradients
two significant training stability problems often occur during bptt:

  1. vanishing gradients: if the eigenvalues of the recurrent weight matrix $W_{hh}$ (or, more precisely, of the effective jacobian $\partial h_t / \partial h_{t-1}$) are all less than 1 in magnitude, repeated multiplication over many time steps shrinks the gradients exponentially. that causes them to approach zero, thus "vanishing."
  2. exploding gradients: if some eigenvalues are greater than 1 in magnitude, repeated multiplication can cause gradients to grow exponentially large, "exploding" in magnitude and making training unstable.
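
a small numerical experiment illustrates both regimes. the sketch below makes a simplifying assumption: it ignores the activation derivative and treats the backward pass as repeated multiplication by the transpose of $W_{hh}$, where $W_{hh}$ is a scaled orthogonal matrix so that every singular value equals the chosen scale. the helper name backprop_norm is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, T = 8, 50

# random orthogonal matrix: every singular value is 1
Q, _ = np.linalg.qr(rng.normal(size=(hidden_dim, hidden_dim)))

def backprop_norm(scale, steps):
    """norm of a unit gradient vector after `steps` multiplications by (scale * Q)^T."""
    W_hh = scale * Q
    g = np.ones(hidden_dim) / np.sqrt(hidden_dim)    # unit-norm gradient at the last step
    for _ in range(steps):
        g = W_hh.T @ g                               # one step back in time
    return np.linalg.norm(g)

for scale in (0.9, 1.0, 1.1):
    print(scale, backprop_norm(scale, T))
```

with this construction the norm is simply scale raised to the power T, so a scale of 0.9 shrinks the gradient to roughly 0.005 after 50 steps, while 1.1 inflates it to roughly 117.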

mitigation strategies

  • gradient clipping: bounding the gradient norm to a fixed value to avoid exploding gradients.
  • careful initialization: using orthogonal or identity initializations for $W_{hh}$.
  • gated architectures (discussed in chapters 3 and 4) that let the model learn to maintain stable long-range dependencies.
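
as an illustration of the first strategy, clipping by global norm rescales all gradients by a common factor whenever their combined norm exceeds a threshold. the function below is a framework-free sketch (clip_by_global_norm is a hypothetical helper name, and the example gradients are made up); deep learning libraries ship their own equivalents.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """rescale a list of gradient arrays so their joint l2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-12))   # never scale up
    return [g * scale for g in grads], total_norm

# made-up gradients for two parameter tensors
grads = [np.full((2, 2), 3.0), np.full(2, 4.0)]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float(np.sum(g ** 2)) for g in clipped))
print(norm_before, norm_after)

# gradients that are already small pass through unchanged
small, _ = clip_by_global_norm([np.array([0.1, -0.2])], max_norm=1.0)
print(small[0])
```

because every tensor is multiplied by the same factor, the direction of the overall update is preserved and only its length is bounded.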

2.6 typical use cases for simple rnns

despite the issues of vanishing/exploding gradients, simple rnns can still be used effectively in tasks where sequences are not excessively long or complicated:

  • basic sentiment classification where input sequences are typically short.
  • certain straightforward time-series tasks with limited or short memory requirements.
  • small prototype experiments or teaching demonstrations in academic contexts.

these simpler rnn variants remain conceptually valuable for building intuition about sequential processing.

3. lstm: long short-term memory networks

3.1 motivation

as mentioned, a vanilla rnn struggles with learning dependencies across many time steps. once sequences become moderately long (e.g., 50–100 steps or more), standard rnns have difficulty propagating useful gradients back to the early part of the sequence.

in response, hochreiter & schmidhuber (1997) introduced the long short-term memory (lstm) architecture. by introducing carefully designed gating mechanisms and an internal "cell state," the lstm allows the model to selectively remember and forget information over potentially large time intervals. this effectively reduces the vanishing gradient problem and enables stable training over longer sequences.

3.2 architecture of the lstm cell

the architecture of an lstm cell includes:

  • cell state ($C_t$): acts as an internal "conveyor belt" that can carry information across many time steps unchanged, subject only to minor linear interactions. this is the key to preserving long-range dependencies.
  • hidden state ($h_t$): the traditional hidden state that is output to the next layer or next time step.
  • three gates:
    1. forget gate ($f_t$): decides which information to keep and which to discard from the cell state.
    2. input gate ($i_t$): determines how much new information enters the cell state from the current input.
    3. output gate ($o_t$): controls how much of the cell state flows into the hidden state at time $t$.

lstm equations
the typical implementation of an lstm includes the following steps at time $t$:

  1. forget gate: $f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)$
    determines which parts of the previous cell state $C_{t-1}$ to forget.
  2. input gate: $i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)$
    decides how much of the candidate cell update $\tilde{C}_t$ gets added to $C_{t-1}$.
  3. candidate cell state: $\tilde{C}_t = \tanh(W_{xC} x_t + W_{hC} h_{t-1} + b_C)$
    is a typical tanh layer that provides new candidate values that could be added to the cell state.
  4. update cell state: $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$
    merges the old cell state and the newly scaled candidate, controlling what to remember and what new information to add.
  5. output gate: $o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)$
    is a gating factor for the final hidden state.
  6. hidden state: $h_t = o_t \odot \tanh(C_t)$
    chooses how much of the updated cell state to reveal as hidden output.

above, $\sigma(\cdot)$ denotes the sigmoid function, $\tanh(\cdot)$ is the hyperbolic tangent function, $\odot$ denotes elementwise (pointwise) multiplication, and $W_{x\cdot}$ / $W_{h\cdot}$ are parameter matrices.
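
the six steps translate almost line by line into code. the following numpy sketch keeps the parameters in a dictionary keyed by the same subscripts as the equations; the sizes, the random initialization, and the function name lstm_step are assumptions for illustration only.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, p):
    """one lstm time step following the six equations above."""
    f_t = sigmoid(p["W_xf"] @ x_t + p["W_hf"] @ h_prev + p["b_f"])      # forget gate
    i_t = sigmoid(p["W_xi"] @ x_t + p["W_hi"] @ h_prev + p["b_i"])      # input gate
    C_tilde = np.tanh(p["W_xC"] @ x_t + p["W_hC"] @ h_prev + p["b_C"])  # candidate
    C_t = f_t * C_prev + i_t * C_tilde                                  # cell update
    o_t = sigmoid(p["W_xo"] @ x_t + p["W_ho"] @ h_prev + p["b_o"])      # output gate
    h_t = o_t * np.tanh(C_t)                                            # hidden state
    return h_t, C_t

rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4                     # illustrative sizes
p = {}
for gate in ("f", "i", "C", "o"):
    p[f"W_x{gate}"] = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
    p[f"W_h{gate}"] = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
    p[f"b_{gate}"] = np.zeros(hidden_dim)

h, C = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in rng.normal(size=(6, input_dim)):      # a toy sequence of 6 steps
    h, C = lstm_step(x_t, h, C, p)
print(h.shape, C.shape)
```

note that the cell state is updated only through elementwise multiplication and addition, which is what lets information (and gradients) pass through many steps with little distortion when the forget gate stays close to 1.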

3.3 peephole connections and other lstm variants

the "peephole" variant proposed by gers and schmidhuber (2000) adds direct connections from the cell state to each gate, letting gates access the exact cell state. other variations unify the forget and input gates. overall, the essential principle remains: an lstm cell has a carefully designed architecture for controlling how information is added to, retained in, and extracted from the cell state.

3.4 practical advantages of lstm

  • effective for long sequences: they can capture dependencies across hundreds of time steps.
  • stable gradients: gating mechanisms largely mitigate the vanishing gradient problem.
  • widespread empirical success: used extensively in machine translation (before transformers became the norm), speech recognition, text classification, etc.

3.5 limitations of lstm

  • computational cost: the gating mechanisms add more parameters.
  • long inference times: longer sequences must be processed step by step.
  • memory usage: storing states for all steps is more expensive.

nonetheless, lstms remain a proven architecture in many production systems and remain relevant, especially for tasks like smaller-scale language modeling or specialized rnn-based pipelines.

4. gru: gated recurrent units

4.1 motivation and background

the gated recurrent unit (gru), introduced by cho et al. (2014), is a simplification of the lstm architecture. it merges the forget and input gates into a single update gate and combines the cell state and hidden state, yielding fewer parameters and often comparable performance. this is especially relevant in resource-constrained scenarios or when quick iteration is needed.

4.2 architecture of the gru cell

in a gru, the gating system is conceptually simpler:

  • update gate ($z_t$): decides how much of the previous hidden state to keep around.
  • reset gate ($r_t$): decides how to combine the new input with the previous hidden state.

the hidden state $h_t$ serves a role similar to that of the lstm's cell state $C_t$ combined with the hidden state.

the main equations are:

  1. update gate: $z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)$
  2. reset gate: $r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r)$
  3. candidate hidden state: $\tilde{h}_t = \tanh(W_{xh} x_t + r_t \odot (W_{hh} h_{t-1}) + b_h)$
  4. final hidden state: $h_t = z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t$

intuitively, if $z_t$ is close to 1, the model largely retains the previous hidden state and ignores the candidate. if $z_t$ is close to 0, the model overwrites the previous hidden state with the new candidate. the reset gate $r_t$ determines how to blend old information into the candidate.
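
the four equations and the intuition about the update gate can be checked with a short numpy sketch. everything here (sizes, initialization, the name gru_step, and the trick of forcing the gate with a large bias) is an assumption made for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, p):
    """one gru time step following the four equations above."""
    z_t = sigmoid(p["W_xz"] @ x_t + p["W_hz"] @ h_prev + p["b_z"])              # update gate
    r_t = sigmoid(p["W_xr"] @ x_t + p["W_hr"] @ h_prev + p["b_r"])              # reset gate
    h_tilde = np.tanh(p["W_xh"] @ x_t + r_t * (p["W_hh"] @ h_prev) + p["b_h"])  # candidate
    return z_t * h_prev + (1.0 - z_t) * h_tilde                                 # blend

rng = np.random.default_rng(0)
input_dim, hidden_dim = 3, 4                     # illustrative sizes
p = {}
for name in ("z", "r", "h"):
    p[f"W_x{name}"] = rng.normal(scale=0.1, size=(hidden_dim, input_dim))
    p[f"W_h{name}"] = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
    p[f"b_{name}"] = np.zeros(hidden_dim)

x_t = rng.normal(size=input_dim)
h_prev = rng.uniform(-0.5, 0.5, size=hidden_dim)
h_default = gru_step(x_t, h_prev, p)

# a large positive bias pushes z_t towards 1, so the old state is kept
p_keep = dict(p, b_z=np.full(hidden_dim, 20.0))
h_keep = gru_step(x_t, h_prev, p_keep)
print(np.abs(h_keep - h_prev).max())             # very close to zero
```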

4.3 comparison between gru and lstm

  • parameters: grus have fewer parameters and are simpler to implement.
  • memory usage: slightly lower for grus.
  • expressive power: lstms can represent more intricate gating behaviors due to separate input and forget gates.
  • empirical performance: can be similar, though details depend on the task and dataset specifics.
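
the parameter difference can be estimated directly from the equations in this article: a simple rnn has one block of input weights, recurrent weights, and bias, a gru has three such blocks, and an lstm has four. the helper below is a back-of-the-envelope sketch under that assumption (one bias vector per block); some frameworks, pytorch among them, keep two bias vectors per block, so their reported counts are slightly higher.

```python
def rnn_param_count(input_dim, hidden_dim, num_blocks):
    """parameters for num_blocks sets of (W_x, W_h, b) as in the equations above."""
    per_block = hidden_dim * input_dim + hidden_dim * hidden_dim + hidden_dim
    return num_blocks * per_block

input_dim, hidden_dim = 128, 64                         # illustrative sizes
simple = rnn_param_count(input_dim, hidden_dim, 1)      # 12352
gru = rnn_param_count(input_dim, hidden_dim, 3)         # 37056
lstm = rnn_param_count(input_dim, hidden_dim, 4)        # 49408
print(simple, gru, lstm, gru / lstm)
```

for the same hidden size, a gru therefore carries three quarters of the lstm's recurrent-layer parameters.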

4.4 use-case recommendations

  • if memory or compute is constrained, a gru might be preferable.
  • if you suspect your data requires a strong capacity for long-term memory, you might choose an lstm, though grus can also excel.
  • many speech- and text-based tasks have historically found grus to be a sweet spot between performance and overhead.

4.5 variable-length input and partial sequences

both lstm and gru can handle variable-length sequences natively by simply unrolling the recurrence as far as needed. for partial sequences or streaming data, one can maintain an internal hidden state and update the model step by step.

5. bidirectional rnn

5.1 concept of forward and backward passes

a bidirectional recurrent neural network processes the sequence in both directions: from $t=1$ to $t=T$ and from $t=T$ down to $t=1$. the hidden states for the forward pass $\overrightarrow{h}_t$ and backward pass $\overleftarrow{h}_t$ are concatenated at each time step to form a combined representation:

$$h_t = [\overrightarrow{h}_t; \overleftarrow{h}_t].$$

the result is often richer, as each output state is informed not just by the current input and historical context, but also by "future" context.
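
a minimal numpy sketch of this idea runs two independent simple rnns, one over the sequence and one over its reversal, then re-aligns the backward states to time order before concatenating. the helper names run_rnn and make_params and all sizes are assumptions for illustration.

```python
import numpy as np

def run_rnn(xs, W_xh, W_hh, b_h):
    """tanh rnn over xs of shape (T, input_dim), starting from a zero state."""
    h = np.zeros(W_hh.shape[0])
    hs = []
    for x_t in xs:
        h = np.tanh(W_hh @ h + W_xh @ x_t + b_h)
        hs.append(h)
    return np.stack(hs)

def make_params(rng, input_dim, hidden_dim):
    return (rng.normal(scale=0.1, size=(hidden_dim, input_dim)),
            rng.normal(scale=0.1, size=(hidden_dim, hidden_dim)),
            np.zeros(hidden_dim))

rng = np.random.default_rng(0)
T, input_dim, hidden_dim = 5, 3, 4               # illustrative sizes
xs = rng.normal(size=(T, input_dim))
fwd_params = make_params(rng, input_dim, hidden_dim)
bwd_params = make_params(rng, input_dim, hidden_dim)   # separate weights per direction

h_fwd = run_rnn(xs, *fwd_params)                 # reads x_1 .. x_T
h_bwd = run_rnn(xs[::-1], *bwd_params)[::-1]     # reads x_T .. x_1, re-aligned to time order
h_bi = np.concatenate([h_fwd, h_bwd], axis=1)    # (T, 2 * hidden_dim)
print(h_bi.shape)
```

after re-alignment, row t of h_bwd summarizes inputs from step t to the end of the sequence, while row t of h_fwd summarizes inputs from the start up to step t.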

5.2 integrating bidirectional layers with lstm and gru

bidirectional rnns are typically used in tasks such as speech recognition, machine translation (particularly when used as an encoder for a subsequent decoder), and question-answering, especially in tasks where the entire sequence is available. to build a bidirectional lstm or gru, one can simply place a forward lstm/gru and a backward lstm/gru side by side, then concatenate or add their outputs.

practically, major frameworks provide this out of the box: tensorflow/keras offers the Bidirectional(LSTM(...)) wrapper, while pytorch's nn.LSTM and nn.GRU accept a bidirectional=True argument. both handle the forward-backward unrolling behind the scenes.

5.3 performance improvements and trade-offs

  • pros: capturing context from both directions can yield improved accuracy on tasks like part-of-speech tagging or sentiment classification, where future context helps interpret the meaning of earlier tokens.
  • cons: at inference time, you must wait to see the entire sequence to process it in reverse. it is thus not suitable for real-time streaming tasks. also, memory usage is higher.

5.4 real-world use cases

  • speech recognition: audio frames are often processed with a bidirectional rnn for transcription.
  • machine translation: used on the encoder side if the entire source sentence is known in advance.
  • question answering: capturing future context is beneficial for analyzing a question's structure and relevant context in the input text.

6. embedding layer

6.1 role of embeddings in sequence modeling

when dealing with natural language, tokens (words, subwords, characters) are discrete. to feed them into an rnn, we must convert them to numeric vectors that capture semantic and syntactic properties. this is done via an embedding layer.

embedding layers map each token (identified by an index) to a trainable dense vector representation. for example:

$$\mathrm{embedding}(w) \in \mathbb{R}^d,$$

where $d$ is the embedding dimension.

these embeddings are learned jointly with the rest of the network.

6.2 learning embedding representations

the embedding layer parameters are typically initialized randomly (or with pretrained embeddings from something like word2vec or glove) and then updated by backpropagation.

example: if your vocabulary has 10,000 words and embedding size is 300, the embedding layer is effectively a 10,000 x 300 matrix, where each row is the vector for a particular word index.
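
in code, the lookup is just row indexing into that matrix, which is equivalent to multiplying a one-hot vector by it. the sketch below uses numpy with random values standing in for learned embeddings; the token ids are made up.

```python
import numpy as np

vocab_size, embedding_dim = 10_000, 300
rng = np.random.default_rng(0)
E = rng.normal(scale=0.01, size=(vocab_size, embedding_dim))   # stand-in for learned weights

token_ids = np.array([12, 7, 4051, 7])           # a made-up sequence of token indices
embedded = E[token_ids]                          # row lookup, shape (4, 300)

# the same lookup expressed as a one-hot vector times the matrix
one_hot = np.zeros(vocab_size)
one_hot[7] = 1.0
same = np.allclose(one_hot @ E, E[7])
print(embedded.shape, E.size, same)
```

repeated tokens (index 7 here) map to the same row, so they share one vector and receive accumulated gradient updates during training.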

6.3 incorporating embeddings in rnn architectures

once a token is converted to its embedding $e_t$, that embedding is used as the input $x_t$ for the rnn at time $t$.

for example, a typical processing pipeline might look like:

raw_token -> embedding -> (RNN cell) -> ...

this approach is standard in nlp tasks.

6.4 pretrained vs. trainable embeddings

  • pretrained (e.g., glove, word2vec, fasttext): can help the model converge faster if domain vocabulary matches the pretrained corpus.
  • trainable: for domain-specific tasks, learning embeddings from scratch is often beneficial. in practice, some mix is used, or embeddings are pretrained and then fine-tuned.

6.5 handling out-of-vocabulary and rare words

a major issue in nlp is how to handle tokens not seen in training or extremely rare words. solutions include:

  • subword embeddings: splitting tokens into morphological units.
  • byte-pair encoding (bpe): widely used in modern nlp.
  • character-based embedding: letting the rnn handle characters, though that typically requires deeper or more advanced networks for performance.

7. training and optimization

7.1 data preprocessing for rnn models

preparing sequence data often requires:

  1. tokenization and normalization (text).
  2. sequence padding or truncation to achieve uniform batch shapes.
  3. batching strategy: deciding how to batch sequences of different lengths. common approaches:
    • bucketed batching: group sequences by similar length.
    • masking: typically used so that the network ignores padded tokens.
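
the padding, masking, and bucketing steps can be sketched in plain python. the helpers pad_batch and bucket_by_length are hypothetical names, and the sketch assumes token id 0 is reserved for padding and pads at the end of each sequence (some library utilities pad at the front by default).

```python
def pad_batch(sequences, pad_id=0):
    """right-pad sequences to the batch maximum and build a 1/0 mask of real tokens."""
    max_len = max(len(s) for s in sequences)
    padded = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return padded, mask

def bucket_by_length(sequences, batch_size):
    """group sequences of similar length so each batch needs little padding."""
    ordered = sorted(sequences, key=len)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

# made-up token id sequences of different lengths
sequences = [[5, 3], [7, 1, 4, 9, 2], [8], [6, 6, 6, 6], [2, 9, 1]]
batches = bucket_by_length(sequences, batch_size=2)
for batch in batches:
    padded, mask = pad_batch(batch)
    print(padded, mask)
```

sorting before batching means the single-token sequence is padded to length 2 rather than length 5, which saves computation on padded positions.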

7.2 hyperparameter tuning

common hyperparameters to tune:

  • learning rate: rnns can be sensitive to the chosen optimizer and learning rate schedule.
  • hidden state dimension: the size of $h_t$ strongly impacts capacity.
  • number of layers: deeper rnns are more expressive but can be harder to train.
  • dropout rate: especially crucial in recurrent connections to prevent overfitting.

7.3 regularization techniques

  • recurrent dropout: random dropout of hidden connections within the rnn cell.
  • input/output dropout: dropping inputs or outputs to/from the rnn.
  • weight decay: standard $L_2$ regularization.
  • early stopping: stopping training when validation loss stagnates or worsens.

these methods help reduce overfitting, particularly in tasks involving large networks but limited data.
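
early stopping in particular is easy to express as a rule over the history of validation losses. the function below is a plain-python sketch; the name early_stopping_epoch, the patience value, and the loss values are all made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """return the epoch index at which training would stop, or None if it never does."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss              # new best validation loss
            bad_epochs = 0
        else:
            bad_epochs += 1          # no improvement this epoch
            if bad_epochs >= patience:
                return epoch
    return None

val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.71]   # made-up validation curve
stop_at = early_stopping_epoch(val_losses, patience=2)
print(stop_at)
```

in practice one would also keep a copy of the weights from the best epoch (index 2 in this toy curve) and restore them after stopping.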

7.4 monitoring and dealing with overfitting

validation metrics: for language modeling, perplexity or cross-entropy. for classification, accuracy or f1.

overfitting indicators: training loss decreases but validation loss stops decreasing or starts increasing.

techniques to mitigate:

  • reduce model size.
  • apply dropout more aggressively.
  • gather more data or use data augmentation if feasible (in certain domains like text, data augmentation is non-trivial but possible with synonym replacement or back-translation).

7.5 practical tools and libraries

  • pytorch: torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU, or torch.nn.RNNCell if going low-level.
  • tensorflow/keras: keras.layers.SimpleRNN, keras.layers.LSTM, keras.layers.GRU, or their Bidirectional wrappers.
  • mxnet: mxnet.gluon.rnn.

each library typically provides easy-to-use modules for building advanced rnn architectures, including multi-layer stacked rnns, residual connections, and more.

8. advanced extensions

8.1 attention mechanisms

while not strictly part of a standard rnn cell, attention revolutionized sequence-to-sequence tasks by allowing the model to "focus" on certain parts of the input sequence when predicting each token of the output.

key idea: rather than compressing the entire source sequence into a single vector (like a naive encoder-decoder rnn might do), an attention mechanism provides context vectors that vary at each output time step. this overcame the bottleneck of a single representation and improved performance in tasks such as machine translation.
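
the key idea can be sketched in a few lines of numpy. this version scores each encoder state with a plain dot product against the current decoder state, which is one simple choice among several (additive scoring with a small learned network is another); the function name attention_context and the random states are assumptions for illustration.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def attention_context(decoder_state, encoder_states):
    """dot-product attention: weight encoder states by similarity to the decoder state."""
    scores = encoder_states @ decoder_state      # (T,) one score per source position
    weights = softmax(scores)                    # attention distribution over positions
    context = weights @ encoder_states           # (hidden_dim,) weighted average
    return context, weights

rng = np.random.default_rng(0)
T, hidden_dim = 6, 4                             # illustrative sizes
encoder_states = rng.normal(size=(T, hidden_dim))    # stand-in for encoder rnn outputs

# two different decoder steps produce two different context vectors
for step, decoder_state in enumerate(rng.normal(size=(2, hidden_dim))):
    context, weights = attention_context(decoder_state, encoder_states)
    print(step, np.round(weights, 3))
```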

soft attention: introduced by bahdanau et al. (2014), with further variants by luong et al. (2015).

transformers: introduced by vaswani et al. (2017), they remove the rnn altogether, replacing recurrence with purely attention-based blocks.

8.2 transformers vs. rnn-based architectures

  • parallelization: rnns process sequences step by step, limiting parallelization. transformers handle all positions simultaneously via self-attention.
  • long-range dependencies: attention-based models handle them more gracefully, whereas rnns still might degrade for very long sequences.
  • resource usage: transformers can require large memory for extremely long sequences.

that said, rnns remain valuable especially for smaller or streaming tasks.

8.3 combining rnns with convolutional networks

in tasks like audio or video processing, one can combine a convolutional front end that processes local structure in time (or space-time for video) with a subsequent rnn that captures longer-range dependencies. for instance, a convolutional layer can embed short audio windows into higher-level features, then the rnn processes these features sequentially.

8.4 memory-augmented networks

beyond lstms, researchers have proposed advanced memory constructs, such as neural turing machines and differentiable neural computers (graves et al., 2014, 2016), which explicitly handle read and write operations to an external memory. these approaches aim to let the network store information for indefinite lengths of time.

8.5 reinforcement learning with rnns

when an agent interacts with an environment, partial observability can require maintaining hidden states over time. rnns such as lstms or grus can track historical observations for decision making. deepmind's atari agents with lstm layers, such as the recurrent variant of a3c, are a well-known example (mnih et al., 2016).

8.6 multi-task learning

rnns are sometimes used in multi-task settings. for example, a single rnn might perform language modeling and sequence tagging simultaneously, sharing hidden layers. the gating and memory aspects can help the model not "forget" essential features across tasks.

8.7 hierarchical rnns

a hierarchical approach might group tokens into sentences, paragraphs, or documents, with an rnn at each level. for instance, an rnn processes words in a sentence to produce a sentence embedding, then another rnn processes these sentence embeddings at the paragraph level, etc.

9. case studies and applications

9.1 natural language processing

9.1.1 machine translation
classic sequence-to-sequence rnns used an encoder rnn to encode a source sentence into a context vector, then a decoder rnn to generate target words step by step.

  • attention: attention-based rnns improved translation quality by letting the decoder attend to different parts of the source.

9.1.2 sentiment analysis
rnns ingest word embeddings of a sentence in order, culminating in a final hidden state that can be used for classification (positive/negative).

9.1.3 text summarization
similar to machine translation, except the "target" is a compressed version of the input.

9.2 speech recognition and generation

9.2.1 end-to-end automatic speech recognition (asr)
frameworks like deep speech (hannun et al., 2014) used multi-layer rnns (often bidirectional lstms) on spectrogram frames to map them directly to phoneme or character sequences.

9.2.2 text-to-speech (tts)
networks like tacotron combined a recurrent seq2seq approach with attention to generate spectrogram frames from text, followed by a vocoder to synthesize waveforms.

9.3 time series forecasting and anomaly detection

rnns can be used to forecast future values in a univariate or multivariate time series. the ability to keep track of hidden states over time helps the model identify underlying patterns or anomalies.

9.4 real-world implementations

  • personalized recommendations: sequence-based recommendation systems for user event histories.
  • conversational ai: older chatbots used hierarchical rnns. new systems more often adopt transformers.
  • healthcare: event sequences in medical records can be used for diagnoses or risk prediction.

9.5 challenges and future directions

  1. interpretability: gating helps, but rnns are still often viewed as black boxes.
  2. scalability: step-by-step nature can be slow for large sequences.
  3. hybrid approaches: mixture of rnns, attention, or memory modules.
  4. transformer dominance: many modern tasks have shifted to transformers, but rnns remain valuable for specialized or resource-constrained cases.

below, i will provide expanded details, including code snippets for implementing various rnn structures using python (keras and pytorch examples). these examples illustrate the typical usage patterns for rnns, lstms, and grus in practice. i will also reference additional research as relevant.


expanded details and implementation code

a. code example: simple rnn for sentiment analysis (keras)

the snippet below shows how one might implement a simple rnn for a text classification task, such as imdb sentiment classification:


import numpy as np
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense, Embedding
from tensorflow.keras.optimizers import Adam

# parameters
max_features = 5000   # size of vocabulary
maxlen = 100          # cut texts after this number of words
batch_size = 32
embedding_dim = 128
rnn_units = 64
epochs = 5

# load the data
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# pad sequences
x_train = pad_sequences(x_train, maxlen=maxlen)
x_test  = pad_sequences(x_test,  maxlen=maxlen)

# build the model
model = Sequential()
# input_length is omitted: it is deprecated in recent keras versions and inferred from the data
model.add(Embedding(input_dim=max_features, output_dim=embedding_dim))
model.add(SimpleRNN(units=rnn_units, activation='tanh'))
model.add(Dense(1, activation='sigmoid'))

# compile
model.compile(loss='binary_crossentropy',
              optimizer=Adam(learning_rate=0.001),
              metrics=['accuracy'])

# train
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          verbose=1)

# evaluate
test_loss, test_acc = model.evaluate(x_test, y_test, batch_size=batch_size)
print("test accuracy:", test_acc)

this is a simplistic example using a standard "vanilla" rnn layer (SimpleRNN). for short sequences, it can work acceptably. for longer sequences, the model might struggle with capturing context from the beginning of the text.


b. code example: lstm for time series forecasting (pytorch)

imagine we have a univariate time series, and we want to forecast the next value based on the last 20 time steps. we can do this with an lstm in pytorch:


import torch
import torch.nn as nn
import torch.optim as optim

class LSTMForecast(nn.Module):
    def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1):
        super(LSTMForecast, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc   = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # x shape: (batch_size, seq_length, input_size)
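        # create the initial states on the same device and dtype as the input
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device, dtype=x.dtype)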
        
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # out shape: (batch_size, seq_length, hidden_size)
        # we want only the last time step
        out = out[:, -1, :]  # (batch_size, hidden_size)
        out = self.fc(out)
        return out

# synthetic data
import numpy as np
seq_length = 20
num_samples = 1000

# random walk or some synthetic time series
time_series = np.zeros(num_samples + seq_length)
for i in range(1, num_samples + seq_length):
    time_series[i] = time_series[i-1] + np.random.normal()

x_data = []
y_data = []
for i in range(num_samples):
    x_data.append(time_series[i:i+seq_length])
    y_data.append(time_series[i+seq_length])

x_data = np.array(x_data, dtype=np.float32).reshape(num_samples, seq_length, 1)
y_data = np.array(y_data, dtype=np.float32).reshape(num_samples, 1)

train_size = int(num_samples * 0.8)
x_train = x_data[:train_size]
y_train = y_data[:train_size]
x_test  = x_data[train_size:]
y_test  = y_data[train_size:]

x_train_torch = torch.from_numpy(x_train)
y_train_torch = torch.from_numpy(y_train)
x_test_torch  = torch.from_numpy(x_test)
y_test_torch  = torch.from_numpy(y_test)

model = LSTMForecast()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 30
batch_size = 32
for epoch in range(epochs):
    # simple mini-batch iteration
    permutation = torch.randperm(train_size)
    for i in range(0, train_size, batch_size):
        indices = permutation[i:i+batch_size]
        batch_x = x_train_torch[indices]
        batch_y = y_train_torch[indices]
        
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        
        # gradient clipping 
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        
        optimizer.step()
    
    # report train and test loss after each epoch (no gradients needed)
    with torch.no_grad():
        train_preds = model(x_train_torch)
        train_loss = criterion(train_preds, y_train_torch).item()
        test_preds = model(x_test_torch)
        test_loss = criterion(test_preds, y_test_torch).item()
    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

print("training complete!")

here, we see the typical approach for time series:

  1. we define a recurrent architecture (nn.LSTM) to take an input of shape (batch_size, seq_length, input_size).
  2. we do a forward pass, capturing the last time step's hidden representation.
  3. we add a final fully connected layer for regression output.
  4. we train using an mse loss, with gradient clipping.
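
step 2 is worth a quick shape check. the sketch below uses illustrative sizes that are my own assumptions (not the hyperparameters of the model above) and a single-layer, unidirectional nn.LSTM with batch_first=True. under those assumptions, slicing the last time step of out gives the same tensor as the final hidden state hn[-1]:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# illustrative sizes (assumptions, not taken from the model above)
batch_size, seq_length, input_size, hidden_size = 4, 20, 1, 16

lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
fc = nn.Linear(hidden_size, 1)

x = torch.randn(batch_size, seq_length, input_size)
# out: (batch, seq, hidden); hn and cn: (num_layers, batch, hidden)
out, (hn, cn) = lstm(x)

last_step = out[:, -1, :]                       # hidden state at the final time step
same_as_hn = torch.allclose(last_step, hn[-1])  # holds for a unidirectional LSTM

pred = fc(last_step)                            # (batch, 1) regression output
print(tuple(out.shape), tuple(hn.shape), tuple(pred.shape), same_as_hn)
```

this equivalence does not carry over to a bidirectional layer, where the backward direction finishes at the first time step rather than the last.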

c. code example: bidirectional gru for nlp sequence labeling (pytorch)

for tasks like named entity recognition (ner) or part-of-speech (pos) tagging, a label must be emitted for every token. a bidirectional gru fits this well, because the representation at each position combines left and right context:


import torch
import torch.nn as nn
import torch.optim as optim

class BiGRU(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_labels):
        super(BiGRU, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.bigru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(hidden_dim*2, num_labels)
    
    def forward(self, x):
        # x shape: (batch_size, seq_length)
        emb = self.embed(x)  # (batch_size, seq_length, embed_dim)
        out, h_n = self.bigru(emb)  # out shape: (batch_size, seq_length, hidden_dim * 2)
        logits = self.classifier(out)  # (batch_size, seq_length, num_labels)
        return logits

# example usage
vocab_size = 5000
embed_dim = 128
hidden_dim = 64
num_labels = 10  # e.g. # of pos tags or entity classes

model = BiGRU(vocab_size, embed_dim, hidden_dim, num_labels)
x_sample = torch.randint(0, vocab_size, (8, 12))  # batch of 8, seq_length=12
output = model(x_sample)
print("output shape:", output.shape)  # expected: [8, 12, 10]

in a real training loop for sequence labeling, you would compute cross-entropy at each token position. the final dimensionality is (batch_size, seq_length, num_labels), so you can measure the classification loss for each token.
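
a minimal sketch of that per-token loss follows. the random logits stand in for the model output, and marking padded positions with -100 (the default ignore_index of nn.CrossEntropyLoss) is my assumption about how padding would be handled; the original example does not specify it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_length, num_labels = 8, 12, 10
PAD_LABEL = -100  # assumed marker for padded positions

# stand-in for the model output; in practice this would be model(x_sample)
logits = torch.randn(batch_size, seq_length, num_labels, requires_grad=True)
labels = torch.randint(0, num_labels, (batch_size, seq_length))
labels[:, -3:] = PAD_LABEL  # pretend the last 3 positions of each sequence are padding

criterion = nn.CrossEntropyLoss(ignore_index=PAD_LABEL)

# CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten batch and time
loss = criterion(logits.reshape(-1, num_labels), labels.reshape(-1))
loss.backward()

# ignored positions contribute nothing to the loss, so their gradient is zero
pad_grad = logits.grad[:, -3:, :].abs().sum().item()
print(f"loss: {loss.item():.4f}, gradient on padded positions: {pad_grad}")
```

flattening the batch and time dimensions treats every token as an independent classification example, and the loss is averaged over the non-ignored tokens only.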


references to advanced research

  • hopfield networks (hopfield, 1982) introduced the idea of stable attractor states and memory, which preceded modern rnn memory-like concepts.
  • elman networks (elman, 1990) introduced the notion of simple context units.
  • hochreiter & schmidhuber (1997): original paper on lstms.
  • gers et al. (2000): introduced peephole connections.
  • cho et al. (2014): introduced the gru, widely used in neural machine translation.
  • bahdanau et al. (2014): introduced the attention mechanism for neural machine translation.
  • vaswani et al. (2017): introduced the transformer model, which replaced recurrence with multi-head self-attention.

lengthier theoretical insights

rnn and dynamical systems perspective

rnns can be framed as discrete-time dynamical systems, where the hidden state update

$$h_t = f(h_{t-1}, x_t; \theta)$$

resembles a step in a dynamical system with parameters $\theta$. from a systems theory standpoint, one can analyze stability by examining the eigenvalues of the jacobian of $f(\cdot)$ with respect to the previous state $h_{t-1}$. because backpropagated gradients are multiplied by this jacobian at every time step, eigenvalues with magnitude above 1 expand perturbations in the state space (exploding gradients), while magnitudes below 1 contract them (vanishing gradients).
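
to make this concrete, here is a small numerical sketch for a vanilla tanh rnn. the update h_t = tanh(W h_{t-1} + U x_t), the random gaussian weights, and the two scale values are all assumptions chosen for illustration. the code accumulates the product of per-step jacobians, which is the factor that multiplies gradients flowing from step T back to step 0:

```python
import torch

torch.manual_seed(0)
hidden_size, input_size, steps = 32, 4, 50

def jacobian_product_norm(scale):
    """Spectral norm of d h_T / d h_0 for h_t = tanh(W h_{t-1} + U x_t)."""
    # Gaussian entries with std scale / sqrt(n) give W a spectral radius of roughly `scale`
    W = scale * torch.randn(hidden_size, hidden_size) / hidden_size ** 0.5
    U = torch.randn(hidden_size, input_size) / input_size ** 0.5
    h = torch.zeros(hidden_size)
    J = torch.eye(hidden_size)
    for _ in range(steps):
        h = torch.tanh(W @ h + U @ torch.randn(input_size))
        # one-step Jacobian d h_t / d h_{t-1} = diag(1 - h_t^2) @ W, chained onto the running product
        J = ((1 - h ** 2).unsqueeze(1) * W) @ J
    return torch.linalg.matrix_norm(J, ord=2).item()

small = jacobian_product_norm(scale=0.5)  # contracting recurrence
large = jacobian_product_norm(scale=3.0)  # expanding recurrence
print(f"scale 0.5: {small:.3e}   scale 3.0: {large:.3e}")
```

with the smaller scale the product shrinks toward zero over 50 steps, and with the larger one it grows, even though tanh saturation damps each individual step.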

truncated bptt

in practice, if sequences are extremely long, it is common to train rnns by unrolling for only a fixed window size of, say, 20 or 30 steps (the "truncation" length). the hidden state is then detached from the computational graph before continuing. this prevents computational blow-up but also means that the model might not fully learn extremely long dependencies.
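
a minimal sketch of this procedure is shown below. the gru, the synthetic random-walk data, and the window length of 20 are illustrative assumptions; the essential part is the detach() call at each window boundary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# illustrative sizes; trunc_len plays the role of the truncation window
batch_size, total_len, trunc_len, hidden_size = 4, 100, 20, 16

rnn = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
head = nn.Linear(hidden_size, 1)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
criterion = nn.MSELoss()

# one long synthetic random walk per batch element; the target is the next value
series = torch.randn(batch_size, total_len + 1, 1).cumsum(dim=1)
inputs, targets = series[:, :-1], series[:, 1:]

h = torch.zeros(1, batch_size, hidden_size)
num_chunks = 0
for start in range(0, total_len, trunc_len):
    x_chunk = inputs[:, start:start + trunc_len]
    y_chunk = targets[:, start:start + trunc_len]

    # keep the state's values but cut the graph, so gradients stop at the window boundary
    h = h.detach()

    optimizer.zero_grad()
    out, h = rnn(x_chunk, h)
    loss = criterion(head(out), y_chunk)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, max_norm=2.0)
    optimizer.step()
    num_chunks += 1

print(f"processed {num_chunks} windows, last window loss {loss.item():.4f}")
```

the forward state still carries information across windows, so the model can use long-range context at inference time; what is lost is the gradient signal that would teach it to store that context deliberately.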

second-order methods

some research has investigated second-order optimization methods (using curvature information) to address the difficulties of training rnns. these are rarely used in mainstream libraries due to computational overhead, but occasionally appear in large-scale specialized systems.

interpretability and gating analyses

some interpretability research examines how the gates in lstms or grus behave. for instance, does the forget gate open or close for certain input patterns? in some cases, gating patterns correlate with semantic boundaries in text.
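
nn.LSTM does not return gate activations, so one way to inspect them is to recompute the recurrence by hand from the layer's weights. the sketch below does this for an untrained single-layer lstm with made-up sizes, relying on pytorch's documented packing order of the gates (input, forget, cell, output), and checks the result against the library output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, seq_length = 3, 8, 6  # illustrative sizes

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
x = torch.randn(1, seq_length, input_size)

W_ih, W_hh = lstm.weight_ih_l0, lstm.weight_hh_l0
b = lstm.bias_ih_l0 + lstm.bias_hh_l0

h = torch.zeros(1, hidden_size)
c = torch.zeros(1, hidden_size)
forget_gates, hidden_states = [], []
with torch.no_grad():
    for t in range(seq_length):
        gates = x[:, t] @ W_ih.T + h @ W_hh.T + b
        # PyTorch packs the four gates in the order: input, forget, cell, output
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        forget_gates.append(f)
        hidden_states.append(h)

    forget_gates = torch.cat(forget_gates)      # (seq_length, hidden_size)
    manual_out = torch.stack(hidden_states, 1)  # (1, seq_length, hidden_size)
    reference_out, _ = lstm(x)

matches = torch.allclose(manual_out, reference_out, atol=1e-5)
print("manual recomputation matches nn.LSTM:", matches)
print("mean forget-gate activation per step:", forget_gates.mean(dim=1))
```

on a trained model, the forget_gates matrix can be plotted against the input tokens to look for units whose gate closes at particular positions.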


conclusion

i have walked through the fundamentals and modern forms of recurrent neural networks, covering vanilla rnns, lstms, grus, and bidirectional rnns, along with their gating equations, training procedures, advanced uses, and real-world applications. while rnns are no longer the top approach for many language tasks, especially since the emergence of attention-only transformers, they remain a foundational technique for sequence modeling. understanding rnns is important not only for historical perspective but also for tackling certain specialized tasks (low-resource settings, streaming tasks, certain time-series forecasting problems, or scenarios requiring explicit stateful processing).

above all, rnns underscore the importance of memory in machine learning: how to preserve and propagate relevant context from previous inputs to influence future predictions. lstms and grus introduced gating to make memory management more robust over time, mitigating the vanishing gradient problem. these gating ideas strongly influenced subsequent architectures, including many of the memory-augmented networks and the gating logic in some transformer variants.

rnns also remain relevant in fields such as music generation, real-time inference on low-power devices, or any domain in which a step-by-step approach is natural. while i've provided thorough background, theoretical commentary, and code snippets, there is of course much more to explore, including specialized initialization methods, advanced regularization approaches, and combinations with convolutional or attention-based layers.

i hope this comprehensive overview helps to build a deeper understanding of rnn architecture and fosters readiness to implement, debug, optimize, and apply rnns in a variety of projects that involve sequential or time-dependent data.


additional references

  • graves a., liwicki m., fernández s., bertolami r., bunke h., schmidhuber j. "a novel connectionist system for unconstrained handwriting recognition." ieee transactions on pattern analysis and machine intelligence (2008).
  • sutskever i., vinyals o., le q.v. "sequence to sequence learning with neural networks." neurips (2014).
  • lipton z.c., kale d.c., elkan c., wetzel r. "learning to diagnose with lstm recurrent neural networks." iclr (2016).
  • vaswani a. et al. "attention is all you need." neurips (2017).
