

This post is a part of the Scaling & distributed learning educational series from my free course. Please keep in mind that the correct sequence of posts is outlined on the course page, while the order may appear arbitrary in the Research section.
I'm also happy to announce that I've started working on standalone paid courses, so you can support my work and get affordable educational material. These courses will be quite different in scope, with more theoretical depth and a narrower focus, and will feature challenging projects, quizzes, exercises, video lectures, and supplementary materials. Stay tuned!
Scaling up model training has become one of the most pivotal aspects of advancing modern machine learning and deep learning systems. As new research breakthroughs push model architectures toward unprecedented size and depth, data scientists and engineers face the growing challenge of working with ever-larger datasets and model parameters that can easily run into the billions. In recent years, the success of massive language models, complex computer vision architectures, and multi-modal pipelines (combining images, text, audio, and beyond) has shown that many important breakthroughs arise when researchers successfully train these large models on large-scale computing infrastructures.
This portion of the course, "Training models at scale, pt. 1", focuses primarily on single-GPU methods. These let you push the limits of memory, compute efficiency, and training speed before moving on to the multi-GPU and distributed approaches covered in the follow-up parts. I'll discuss the primary motivations for scaling, the core technical concepts (from memory optimization to hardware considerations), and practical techniques that help you train bigger models within the constraints of a single GPU. Distributed and multi-node training can unlock model sizes that simply won't fit on a single accelerator. Still, you may be surprised at how far you can go with a thoughtful, methodical approach to single-device scaling.
By the end of this article, you'll know about memory bottlenecks, gradient accumulation, mixed precision training, gradient checkpointing, asynchronous computations, and more. This foundation will help you use advanced frameworks like PyTorch, TensorFlow, and JAX with confidence, with an eye toward maximizing throughput and memory efficiency. Along the way, I'll reference select research papers (e.g., from NeurIPS, ICML, JMLR) that introduced novel techniques or refined these strategies in real-world deployments.
I'm excited to guide you through the essential building blocks needed to scale your model training to new heights — even when you're starting on just a single GPU. Once you internalize these techniques, you'll be fully prepared to move on to distributed setups and specialized hardware, thereby completing the journey of large-scale training methodologies.
Why scaling?
The modern drive to scale bigger and bigger
There's a remarkable trend in AI research: every year, the models that set new benchmarks in fields like natural language processing, image generation, speech synthesis, and recommendation systems tend to be bigger in terms of trainable parameters and trained on larger datasets. State-of-the-art performance in many tasks seems to correlate closely with the ability to scale, leading to emergent capabilities that smaller models fail to exhibit. The well-known scaling laws (e.g., from OpenAI, Google Research, and others) indicate that performance typically improves predictably as you increase the size of the model and the volume of training data.
However, training these huge models isn't trivial. Beyond the raw compute expense — where you might spend days or weeks of GPU time — there's also the need to address memory constraints, specialized hardware requirements, numerical stability issues, data pipeline bottlenecks, and more. Industry giants such as NVIDIA, Google, and Microsoft have made significant investments in hardware and software solutions to handle these challenges, and many open-source tools have trickled down into community frameworks like PyTorch, TensorFlow, and JAX.
Trade-offs: model size, training time, and resources
Scaling up is not free in terms of either time or financial cost. Once you push beyond the resources available to a single GPU or a single machine, you'll be forced to adopt more complex distributed training paradigms. Even on a single GPU, attempts to train huge models can result in extensive memory overhead, to the point where you spend a significant fraction of time on memory-optimization tricks instead of focusing solely on model design.
That said, certain tasks warrant these trade-offs. For example:
- Large language models for advanced text generation, question answering, or multi-lingual tasks.
- Vision transformers with massive parameter counts for cutting-edge image classification, segmentation, and object detection.
- Reinforcement learning systems dealing with extremely large state and action spaces.
In short, if your goal is state-of-the-art performance on challenging tasks — or if you are dealing with extremely high volumes of data — scaling up might be the right approach.
Real-world success stories
Major breakthroughs in the past few years underscore the value of scaling:
- GPT-series models (OpenAI) for text generation, code generation, and more.
- Vision transformers (ViT), from Google Brain, that match or exceed convolutional networks in image tasks.
- DeepMind's AlphaGo and AlphaZero in the realm of large-scale reinforcement learning and self-play.
Behind each success story is a carefully orchestrated approach to training, resource allocation, and memory management. This article will illuminate how to orchestrate those ingredients on a single GPU, giving you a stepping stone to even more advanced scaling approaches.
Core concepts of single-GPU scaling
Scaling up model training on a single GPU might sound contradictory at first. After all, you're limited to the memory and compute throughput of one device. However, numerous techniques have emerged that allow you to push the boundaries of what's possible:
- Memory bottlenecks: Understanding how GPU memory is allocated among model weights, intermediate activations, gradients, and optimizer states.
- Large batch size management: Many tasks benefit from large batch sizes for throughput or training stability. Fitting large batches on a single GPU typically requires gradient accumulation or other memory-saving strategies.
- Framework optimizations: PyTorch, TensorFlow, and JAX each offer distinct ways to handle memory, asynchronous executions, and graph optimizations.
This section will walk you through these core considerations. I'll also reference notable papers and best practices from the open-source community to solidify these ideas with real-world examples.
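To make the "memory bottlenecks" item concrete, here is a rough back-of-the-envelope estimator for the static part of training memory: weights, gradients, and optimizer state. It is a sketch under stated assumptions. Activations are ignored because they depend on batch size, sequence length, and architecture. The default of 8 optimizer bytes per parameter assumes Adam's two FP32 moment buffers. The function name and its parameters are hypothetical helpers, not part of any library.

```python
def static_training_memory_gb(num_params, bytes_per_weight=4, bytes_per_grad=4,
                              optimizer_bytes_per_param=8):
    """Rough weight/gradient/optimizer-state memory in GB (1 GB = 1e9 bytes).

    Activations are NOT included. The default optimizer_bytes_per_param=8
    assumes Adam's two FP32 moment buffers (first and second moments).
    """
    weights = num_params * bytes_per_weight
    grads = num_params * bytes_per_grad
    optimizer_state = num_params * optimizer_bytes_per_param
    return {
        "weights_gb": weights / 1e9,
        "grads_gb": grads / 1e9,
        "optimizer_gb": optimizer_state / 1e9,
        "total_gb": (weights + grads + optimizer_state) / 1e9,
    }

# A hypothetical 1B-parameter model trained in FP32 with Adam:
print(static_training_memory_gb(1_000_000_000))
# {'weights_gb': 4.0, 'grads_gb': 4.0, 'optimizer_gb': 8.0, 'total_gb': 16.0}
```

With this sketch, a common mixed-precision setup is roughly `bytes_per_weight=2 + 4` (16-bit weights plus an FP32 master copy) and `bytes_per_grad=2`. That still totals about 16 bytes per parameter, which is why the main savings from half precision come from activations.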
Mixed precision training
One of the most substantial breakthroughs for single-GPU scaling in the last five years has been widespread adoption of mixed precision or half-precision training. Originally, training with lower precision (like 16-bit floating points) introduced significant numerical instability. Modern hardware (NVIDIA's Tensor Cores, for example) and software developments (like Automatic Mixed Precision, or AMP, in PyTorch) have changed the game.
Benefits of FP16/BF16
Using half-precision floating point representations (e.g., FP16) or the newer bfloat16 (BF16) format greatly reduces the memory footprint of activations, and of parameters whenever they are stored in 16-bit. For instance, going from 32-bit to 16-bit halves the storage required per floating point number. This leads to:
- Reduced memory usage: Freed memory can be used to store larger models or bigger batches.
- Faster computation: Modern GPUs can handle half-precision ops significantly faster, especially if the underlying architecture has specialized hardware units (e.g., Tensor Cores).
Because of these efficiency gains, it's becoming standard practice in many high-performance training pipelines to use mixed precision by default.
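A quick PyTorch check of both claims: 16-bit tensors take half the bytes of FP32 tensors. The two 16-bit formats spend their bits differently. FP16 has 5 exponent and 10 mantissa bits, while BF16 keeps FP32's 8 exponent bits and has only 7 mantissa bits, which is why BF16 rarely overflows. The tensor sizes are arbitrary choices for illustration.

```python
import torch

x32 = torch.zeros(1024, 1024, dtype=torch.float32)
x16 = x32.to(torch.float16)
xbf16 = x32.to(torch.bfloat16)

def nbytes(t):
    # Bytes actually occupied by the tensor's elements
    return t.element_size() * t.nelement()

print(nbytes(x32), nbytes(x16), nbytes(xbf16))  # 4194304 2097152 2097152

# Same storage, different trade-offs in dynamic range:
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # roughly 3.39e38, close to FP32's maximum
```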
Potential pitfalls and solutions
Mixed precision training can face numerical stability issues, such as underflow or overflow in gradient calculations. Typical solutions involve:
- Loss scaling: Before backpropagating, the loss is scaled by a factor (e.g., 1024). This ensures small gradient values do not underflow. Afterwards, the gradients are unscaled to restore their correct magnitude.
- Automatic Mixed Precision libraries: PyTorch, TensorFlow and JAX each have built-in or library-level support for handling these scaling details under the hood.
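The underflow problem that loss scaling solves can be seen directly. In this sketch the gradient value of 1e-8 is a hypothetical tiny gradient component. That value is below FP16's smallest subnormal (about 6e-8), so it rounds to zero. Multiplying by a scale of 1024 first keeps it representable, and dividing by the same scale in FP32 recovers it.

```python
import torch

grad_value = 1e-8  # hypothetical tiny gradient component
scale = 1024.0     # example loss-scale factor

unscaled = torch.tensor(grad_value, dtype=torch.float16)          # underflows to zero
scaled = torch.tensor(grad_value * scale, dtype=torch.float16)    # stays representable

print(unscaled.item())        # 0.0
print(scaled.item() / scale)  # approximately 1e-8, recovered by unscaling in FP32
```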
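import torch

# Automatic Mixed Precision in a training loop
# (model, data, target, and loss_fn are assumed to be defined elsewhere):
with torch.autocast(device_type='cuda', dtype=torch.float16):
    output = model(data)            # forward pass: eligible ops run in FP16
    loss = loss_fn(output, target)  # compute loss
# backward pass, outside the autocast context
loss.backward()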
Example code snippet
Below is a simple snippet in PyTorch to illustrate a typical training loop with automatic mixed precision:
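import torch
from torch import nn, optim

model = nn.Linear(1024, 512).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()  # The built-in gradient scaler

# Hypothetical synthetic batches standing in for a real DataLoader
dataloader = [(torch.randn(32, 1024), torch.randn(32, 512)) for _ in range(10)]

for data, target in dataloader:
    data = data.cuda()
    target = target.cuda()
    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(data)
        loss = nn.functional.mse_loss(output, target)
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)         # unscales gradients; skips the step if any are inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration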
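In this snippet, the forward pass and loss computation run inside the autocast context. Autocast executes eligible operations (such as matrix multiplications) in FP16 and keeps precision-sensitive ones in FP32. The GradScaler handles the rest of the loop:
- It multiplies the loss by a scale factor before the backward pass, so small gradients do not underflow.
- It unscales the gradients before the optimizer step.
- It skips the step if any gradient contains inf or NaN.
- It adjusts the scale factor for the next iteration.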
Gradient checkpointing (activation recomputation)
When you run forward propagation in deep models, intermediate activations (outputs of each layer) are stored in memory so that backpropagation can compute the corresponding gradients. For very deep networks, this storage can grow massive, quickly surpassing the available GPU memory.
How checkpointing saves memory
Gradient checkpointing (also called activation recomputation) offers a clever solution. Instead of storing all intermediate activations, you "checkpoint" only certain layers' outputs. During the forward pass, the activations between checkpoints are discarded. During backpropagation, they are recomputed on the fly from the nearest checkpoint when needed. In other words, the forward pass stores only the selected activation tensors, and the forward computation for each uncheckpointed segment runs a second time at gradient computation time.
- Memory savings: By not keeping every layer's activations around, you drastically reduce memory consumption.
- Compute trade-off: You have to do some extra forward passes during the backward phase, thus increasing total computation time.
The memory vs. compute trade-off is usually worth it if you're hitting GPU memory limits. Even when every segment is recomputed, the extra cost is at most about one additional forward pass per training step. That is often a reasonable price for being able to train deeper or larger networks.
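Why not checkpoint every layer, or none? Here is a simplified back-of-the-envelope model. It assumes (hypothetically) n layers whose outputs are all the same size, split into k equal segments. You keep k checkpoints, and during backward you also hold the recomputed activations of one segment at a time. Peak memory is then about k + n/k layer outputs, minimized near k = √n. This is the classic square-root checkpointing heuristic; real frameworks account for memory far more precisely.

```python
import math

def peak_activation_units(num_layers: int, num_segments: int) -> int:
    """Rough peak activation memory, in units of one layer's output (simplified model)."""
    segment_len = math.ceil(num_layers / num_segments)
    # k stored checkpoints + the recomputed activations of one segment during backward
    return num_segments + segment_len

n = 100
print(peak_activation_units(n, 10))  # 20 (vs. 100 stored outputs without checkpointing)

best_k = min(range(1, n + 1), key=lambda k: peak_activation_units(n, k))
print(best_k)  # 10, i.e. sqrt(100)
```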
Implementation tips
Frameworks provide built-in or third-party libraries for gradient checkpointing:
- PyTorch:
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(*inputs):
    # forward pass of a module or a block
    ...

# Activations inside checkpointed_forward are recomputed during backward
out = checkpoint(checkpointed_forward, *inputs, use_reentrant=False)
- TensorFlow: Typically uses tf.recompute_grad or custom Keras layers for partial recomputation.
- JAX: Has jax.checkpoint or uses functional transformations to define how states are stored or recomputed.
Be mindful of how you choose which layers to checkpoint. A common rule of thumb is to checkpoint only certain blocks (like transformer blocks, or residual blocks in a ResNet) rather than every single layer. That often provides a nice compromise between memory savings and computational overhead.
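Here is a small runnable PyTorch sketch of the block-level rule of thumb. It uses a hypothetical toy stack of residual MLP blocks, and `torch.utils.checkpoint.checkpoint` wraps each block during training. Only each block's input is saved, and the 4x-wide hidden activations inside the block are recomputed during backward. The model, sizes, and class name are assumptions for illustration.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class CheckpointedMLP(nn.Module):
    """Hypothetical stack of residual MLP blocks with per-block checkpointing."""

    def __init__(self, dim: int = 256, num_blocks: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            for _ in range(num_blocks)
        )

    def forward(self, x):
        for block in self.blocks:
            if self.training:
                # Save only the block input; recompute the block's internals in backward
                x = x + checkpoint(block, x, use_reentrant=False)
            else:
                x = x + block(x)
        return x

model = CheckpointedMLP()
x = torch.randn(8, 256, requires_grad=True)
loss = model(x).pow(2).mean()
loss.backward()
print(x.grad.shape)  # torch.Size([8, 256])
```

Because recomputation replays the same deterministic operations, the gradients should match those of the non-checkpointed path. Only the memory/compute profile changes.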
Gradient accumulation
The concept
Batch size in training is often limited by GPU memory. Large batch sizes help stabilize training and speed up throughput on many tasks. If you can't fit an entire large batch into memory in one go, gradient accumulation is a technique that simulates a large batch by doing multiple forward/backward passes with smaller micro-batches, summing (accumulating) the gradients, and only updating the model weights after a certain number of micro-batches.
The effective batch size is then $B_{\text{eff}} = B_{\text{micro}} \times K$, where $K$ is the number of accumulation steps you run before updating. This gives you a large effective batch without having to hold the activations of the whole batch in memory at once.
Balancing batch size and steps
Depending on the model architecture and data distribution, you might find an optimal effective batch size for stable training. The key is to adjust your learning rate accordingly, since a larger effective batch size might allow you to increase the learning rate. Not all tasks benefit from arbitrarily large batch sizes, so you may see diminishing returns after a certain point.
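Two common heuristics for adjusting the learning rate are sketched below. The first is the linear scaling rule popularized by Goyal and gang (2017) for SGD. The second is square-root scaling, which is sometimes preferred for adaptive optimizers such as Adam. Both are starting points to validate empirically, not guarantees. The helper name and the example numbers are hypothetical.

```python
import math

def scaled_lr(base_lr: float, base_batch: int, effective_batch: int, rule: str = "linear") -> float:
    """Heuristic learning-rate rescaling when the effective batch size changes."""
    ratio = effective_batch / base_batch
    if rule == "linear":  # linear scaling rule, commonly used with SGD
        return base_lr * ratio
    if rule == "sqrt":    # square-root scaling, sometimes used with adaptive optimizers
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule}")

# Hypothetical: lr=1e-3 was tuned at batch size 8; accumulation raises the effective batch to 32
print(scaled_lr(1e-3, 8, 32, "linear"))  # 0.004
print(scaled_lr(1e-3, 8, 32, "sqrt"))    # 0.002
```

Large jumps in learning rate are usually paired with a warm-up period, since the scaled rate can be unstable in the first steps.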
Practical snippet in PyTorch
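import torch
from torch import optim

micro_batch_size = 8
accumulation_steps = 4
effective_batch_size = micro_batch_size * accumulation_steps  # 32

# Model, loss_fn, and dataloader are assumed to be defined elsewhere
model = Model().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer.zero_grad()

for i, (data, target) in enumerate(dataloader):
    data, target = data.cuda(), target.cuda()
    outputs = model(data)
    # Divide so the accumulated gradient equals the mean over all 32 samples
    loss = loss_fn(outputs, target) / accumulation_steps
    loss.backward()  # gradients add up in .grad across micro-batches

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()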
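In this structure, the optimizer only updates after 4 micro-batches. This matches a single update with a batch size of 32 (8x4) without needing the memory for the entire 32-sample batch at once, with two conditions. First, each micro-batch loss must be divided by the number of accumulation steps, so the summed gradients equal a mean. Second, the match is only approximate for layers that compute batch statistics, such as BatchNorm, since they still see only 8 samples at a time.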
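The equivalence claim can be checked numerically on the CPU. The sketch below uses a hypothetical tiny linear model with a mean-reduced MSE loss. It compares the gradient of one 32-sample batch with the gradient accumulated over four 8-sample micro-batches, where each micro-batch loss is divided by the number of accumulation steps.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(16, 1)
data, target = torch.randn(32, 16), torch.randn(32, 1)
loss_fn = nn.MSELoss()  # mean reduction over the batch

# Reference: gradient of one full 32-sample batch
model.zero_grad()
loss_fn(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Accumulate over 4 micro-batches of 8, scaling each loss by 1 / accumulation_steps
accumulation_steps = 4
model.zero_grad()
for x, y in zip(data.chunk(accumulation_steps), target.chunk(accumulation_steps)):
    (loss_fn(model(x), y) / accumulation_steps).backward()
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))  # True
```

Dropping the division by `accumulation_steps` would make the accumulated gradient 4x larger. That acts like silently multiplying the learning rate by 4.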
JAX-specific structures
JAX is quickly becoming a popular framework for high-performance machine learning research, particularly in areas that benefit from large-scale TPU clusters. On a single GPU, JAX still brings a powerful set of transformations and a functional programming style that can lead to efficient code. Let's discuss some JAX-specific concepts that help with single-GPU scaling.
XLA compilation
When you wrap a function in jax.jit, JAX traces it and compiles it with the XLA compiler into optimized, often fused, kernels. This can make training loops faster than the op-by-op eager execution that frameworks like PyTorch use by default. The key is to structure your code as pure functions and let JAX handle the transformations.
Vectorization with vmap
The vmap transformation in JAX automatically vectorizes functions across batch dimensions without you needing to manually batch your data. This can help you eliminate Python loops and push more parallel computation to the GPU. If you combine vmap with jit (just-in-time compilation), you can see significant speedups.
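import jax
import jax.numpy as jnp

def forward(params, x):
    # Hypothetical linear model for illustration; x is a single example of shape (features,)
    return jnp.dot(x, params["w"]) + params["b"]

@jax.jit
def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)

# Vectorized version across multiple data points:
# params are shared (None), x and y are batched along axis 0
batched_loss_fn = jax.vmap(loss_fn, in_axes=(None, 0, 0))

@jax.jit
def step(params, x_batch, y_batch, lr=1e-2):
    grads = jax.grad(
        lambda p: jnp.sum(batched_loss_fn(p, x_batch, y_batch))
    )(params)
    # Plain SGD update applied to every leaf of the params pytree (illustrative choice)
    updated_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return updated_params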
Common pitfalls
- Side effects: JAX's functional style demands that you avoid or carefully handle side effects, such as mutating global state.
- Compilation overhead: The first time you call a JIT-compiled function with a given input shape and dtype, you pay a compilation cost. For large models this overhead can be significant, but subsequent calls with the same shapes reuse the compiled code.
- Shape and type issues: Because JAX heavily relies on shape inference and static analysis, changing input shapes or dtypes can trigger recompilation.
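The first and third pitfalls can be observed together in a few lines of JAX. The Python counter below is a deliberate side effect: it runs only while JAX traces the function, not when a cached executable is reused. It therefore shows both that side effects are unreliable inside jit and exactly when recompilation happens. The function and counter names are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(4))  # traces and compiles for shape (4,)
scaled_sum(jnp.ones(4))  # same shape and dtype: reuses the cached executable
scaled_sum(jnp.ones(8))  # new shape: traces and compiles again
print(trace_count)  # 2
```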
Building an optimized transformer on a single GPU
Transformer building blocks
A typical Transformer includes:
- Embedding layers (for tokens or patch embeddings in vision).
- Multi-head self-attention modules (computationally expensive but key).
- Feed-forward networks (often large expansions, e.g. 4x the embedding dimension).
- Residual connections and layer normalization.
- Output projections for classification, language modeling, etc.
When you attempt to instantiate a large Transformer (e.g., hundreds of millions or billions of parameters) on a single GPU, memory constraints become your primary challenge. This is especially true if you have a large vocabulary (in NLP tasks) or large patch embeddings (in vision tasks).
Memory optimizations
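- Layer sharing: Some Transformer variants reuse the same layer weights multiple times. For instance, ALBERT shares parameters across its Transformer layers. This shrinks parameter, gradient, and optimizer-state memory, though not activation memory, since every application of the shared layer still produces its own activations.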
- Checkpointing: Gradient checkpointing is frequently used at scale.
- Mixed precision: Gains from half-precision are especially large in Transformers due to the massive number of matrix multiplications.
Multi-head attention efficiency
Multi-head attention can be memory-intensive. It computes queries, keys, and values for each head, then a T-by-T attention score matrix per head, whose size grows quadratically with sequence length T. FlashAttention (Tri Dao and gang, NeurIPS 2022) instead computes attention in tiles that stay in fast on-chip memory, so the full score matrix is never written to GPU main memory, which reduces memory use and speeds up training. If you can integrate such specialized kernels, they can drastically reduce overhead.
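Below is a straightforward PyTorch implementation of standard multi-head attention. It is not memory-optimized: it materializes the full (T, T) score matrix for every head, which is exactly the tensor that fused kernels such as FlashAttention avoid storing. Use it as a reference point for where the memory goes.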
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, D = x.shape
        q = self.W_q(x).view(B, T, self.num_heads, self.head_dim)
        k = self.W_k(x).view(B, T, self.num_heads, self.head_dim)
        v = self.W_v(x).view(B, T, self.num_heads, self.head_dim)
        # Permute to (B, num_heads, T, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        # Scaled dot-product: scores has shape (B, num_heads, T, T), the memory hotspot
        scores = (q @ k.transpose(-1, -2)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = attn @ v  # shape: (B, num_heads, T, head_dim)
        out = out.permute(0, 2, 1, 3).contiguous()
        out = out.view(B, T, D)
        return self.out(out)
In advanced usage, you might incorporate custom CUDA kernels or specialized libraries, but the principle remains the same: carefully handle memory usage at each step.
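In PyTorch 2.0 and later, one low-effort way to get fused attention is `torch.nn.functional.scaled_dot_product_attention`. On supported CUDA GPUs, dtypes, and shapes it may dispatch to FlashAttention or memory-efficient kernels; elsewhere it falls back to a plain math implementation. The sketch below rewrites the module above around that call. The single fused QKV projection and the class name are my own design choices, not part of the original code.

```python
import torch
from torch import nn
import torch.nn.functional as F

class FusedMultiHeadAttention(nn.Module):
    """Multi-head attention built on F.scaled_dot_product_attention (PyTorch 2.0+)."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # one projection produces Q, K, and V
        self.out = nn.Linear(d_model, d_model)
        self.dropout = dropout

    def forward(self, x, causal=False):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, D) -> (B, num_heads, T, head_dim)
        q, k, v = (t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # The backend may avoid materializing the full (T, T) score matrix
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=causal
        )
        out = out.transpose(1, 2).reshape(B, T, D)
        return self.out(out)

attn = FusedMultiHeadAttention(d_model=64, num_heads=4)
y = attn(torch.randn(2, 10, 64))
print(y.shape)  # torch.Size([2, 10, 64])
```

The default scaling inside `scaled_dot_product_attention` is 1/sqrt(head_dim), matching the explicit division in the reference implementation.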
Profiling single-GPU training
Identifying hotspots
Profiling tools let you see how efficiently your GPU's resources are being used:
- PyTorch Profiler: built-in tool with a TensorBoard plugin.
- NVIDIA Nsight Systems: advanced system-wide profiler.
- TensorFlow Profiler: integrated with TensorBoard.
- JAX profiler (jax.profiler): captures traces that can be viewed in TensorBoard.
You'll want to watch for:
- GPU utilization: ideally close to 100% during training steps.
- Memory throughput: make sure you're not bottlenecked on memory copies (e.g., host-to-device transfers).
- Kernel launch overhead: small kernel calls repeated many times can hamper performance.
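A minimal PyTorch Profiler session might look like the sketch below. The toy model, input sizes, and number of iterations are arbitrary. It always records CPU activity and adds CUDA activity when a GPU is present. It sorts the summary table by total CPU time because that key exists in every setup; on a GPU you would typically sort by the CUDA/device time column instead, whose exact name varies across PyTorch versions.

```python
import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity

model = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
x = torch.randn(64, 512)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    model, x = model.cuda(), x.cuda()

# Profile a few forward/backward iterations
with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(5):
        model(x).sum().backward()

# Top operators by total CPU time (includes time spent in child ops)
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```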
Interpreting profiler data
A typical profiler trace might show which layers or operations consume the most time (e.g., multi-head attention blocks, feed-forward expansions, or embedding lookups). This data helps you decide where to:
- Switch to specialized kernels.
- Adopt checkpointing.
- Change the shape or internal dimensions of your model.
Sometimes you'll discover that your input pipeline is the bottleneck. That is especially relevant if you're streaming data from disk or doing heavy augmentations on the CPU. Keeping the GPU fed with data is essential, so you may use multi-threaded or asynchronous data loading to prevent GPU idle time.
Asynchronous computation and automatic parallelism
Asynchronous execution
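Modern frameworks typically enqueue operations on the GPU without blocking your CPU code. This asynchronous approach can hide latency, but it means you must synchronize explicitly where it matters, e.g., call torch.cuda.synchronize() in PyTorch before reading timings. Operations that copy results back to the CPU, such as .item() or .cpu(), synchronize implicitly. In some advanced scenarios, you might deliberately overlap data loading with GPU computation, for example by copying from pinned host memory with .to(device, non_blocking=True).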
Graph optimizations
TensorFlow (through tf.function) and JAX (through jit) can trace your code into a computational graph that the framework then optimizes. This can lead to:
- Operation fusion: merging multiple smaller kernels into one to minimize overhead.
- Inlining: removing function call overhead.
- Parallel scheduling: running independent sub-graphs in parallel on the GPU.
PyTorch's eager mode allows for greater flexibility but offers fewer opportunities for whole-graph optimization. However, PyTorch also provides TorchScript and, since PyTorch 2.0, torch.compile to enable graph-based optimizations, bridging the gap with static-graph frameworks.
Computational hardware considerations
While we're focusing primarily on single-GPU training, the type of GPU you use can have a dramatic effect on your scaling endeavors.
GPU architecture basics
- Cores and threads: The raw computing units that perform multiply-add operations.
- Memory bandwidth: The speed at which data can be moved in and out of GPU memory.
- Tensor cores: Specialized hardware blocks optimized for matrix multiplication in reduced precision, such as FP16, BF16 (on Ampere and newer GPUs), or TensorFloat-32 (TF32).
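In PyTorch, whether ordinary FP32 matrix multiplications may use TF32 Tensor Cores is controlled by global flags. TF32 keeps FP32's exponent range but only 10 mantissa bits. The defaults have changed across PyTorch versions, so it is safest to set them explicitly. The flags take effect only on GPUs with TF32 support (Ampere or newer); setting them elsewhere is harmless.

```python
import torch

# Allow TF32 Tensor Core math for FP32 matmuls and cuDNN convolutions
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Higher-level switch for matmuls: "highest" (full FP32), "high" (TF32 allowed), or "medium"
torch.set_float32_matmul_precision("high")
print(torch.get_float32_matmul_precision())  # high
```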
Specialized hardware (TPUs, custom ASICs)
Google's Tensor Processing Units (TPUs) operate on a different programming paradigm (XLA-based) and can be extremely efficient for large-scale operations. Some companies also develop custom ASICs for specialized tasks. While these hardware offerings can push performance even further, the same fundamental scaling techniques (mixed precision, gradient checkpointing, etc.) often apply.
Matching hardware to scaling strategies
Selecting the right hardware means considering your memory requirements, compute demands, and software ecosystem:
- High memory capacity GPUs (e.g., NVIDIA A100 with 40GB or 80GB) let you handle bigger models or bigger batches.
- Less memory but high compute (like consumer GPUs with fewer memory channels but strong cores) might do well if you rely heavily on gradient accumulation or other memory reduction strategies.
Memory-optimization strategies
Beyond gradient checkpointing and mixed precision, several additional techniques can further reduce the memory footprint:
Activation quantization and tensor compression
Advanced methods compress activations on the fly, typically to a lower precision than FP16. Some research has explored 8-bit or even 4-bit quantization for activations. For example, Q8BERT (by Zafrir and gang, NeurIPS 2019 workshop) used quantization-aware training to quantize both the weights and the activations of BERT to 8 bits without substantial accuracy loss. Note that its goal was an efficient inference model rather than training-time memory savings.
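The basic building block of such schemes is a quantize/dequantize pair. The sketch below uses simple symmetric per-tensor INT8 quantization, chosen for clarity; it is not Q8BERT's exact scheme and not an activation-compression library. The function names and tensor sizes are hypothetical. It stores a tensor in a quarter of the FP32 memory, with a reconstruction error of at most half a quantization step.

```python
import torch

def quantize_int8(x: torch.Tensor):
    # Symmetric per-tensor quantization: map [-max|x|, max|x|] onto [-127, 127]
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

torch.manual_seed(0)
act = torch.randn(64, 1024)  # hypothetical activation tensor
q, scale = quantize_int8(act)
restored = dequantize_int8(q, scale)

print(act.element_size() // q.element_size())  # 4 (bytes per element: 4 vs 1)
print((act - restored).abs().max().item(), scale.item() / 2)  # error vs. half a quantization step
```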
Offloading
If your GPU memory is severely constrained, you can offload certain tensors to CPU memory or disk:
- Optimizer state: Instead of keeping the entire optimizer state (e.g., Adam's moment estimates) on the GPU, keep it in CPU memory. DeepSpeed's ZeRO-Offload is one implementation of this idea.
- Activation CPU offload: Some frameworks can automatically move seldom-needed activations to CPU memory and move them back to GPU on demand.
These methods can let you train bigger models, but offloading can also drastically slow down training if done too frequently. Host-device transfers over PCIe are much slower than on-GPU memory access.
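PyTorch ships a simple form of activation offloading: the `torch.autograd.graph.save_on_cpu` context manager. Tensors that autograd saves for the backward pass are moved to CPU memory during the forward pass and copied back when backward needs them. In the sketch below, the toy model and sizes are arbitrary, and pinned memory is requested only when a GPU is available. On a CPU-only machine the code still runs, but nothing is actually offloaded.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).to(device)
x = torch.randn(16, 1024, device=device)

# Activations saved for backward live in (optionally pinned) CPU memory during this block
with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
    loss = model(x).pow(2).mean()

# Backward copies the saved tensors back to the original device as needed
loss.backward()
print(model[0].weight.grad.shape)  # torch.Size([4096, 1024])
```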
Best practices for single-GPU training at scale
Efficient data pipelines
Always ensure that your training loop is not starved for data. Techniques include:
- Pre-fetching: Load data for the next batch while the current batch is still being processed.
- In-memory caching: If your dataset is small enough, keep it in RAM to avoid disk I/O.
- Shuffle buffering: Maintain a large shuffle buffer in memory to randomize data effectively.
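A shuffle buffer gives approximate shuffling of a stream too large to hold in memory. It keeps a fixed-size buffer, emits a random element from it, and replaces that element with the next incoming one. The hypothetical generator below sketches the idea, in the spirit of `tf.data.Dataset.shuffle(buffer_size)`. Larger buffers give better randomization at the cost of more memory.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def shuffle_buffer(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Approximately shuffle a stream using a fixed-size in-memory buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill the buffer first
            continue
        # Emit a random buffered element and put the incoming item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Stream exhausted: drain the remaining elements in random order
    rng.shuffle(buffer)
    yield from buffer

print(list(shuffle_buffer(range(20), buffer_size=5)))
```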
Hyperparameter tuning for large models
Large models can be more sensitive to learning rate, batch size, warm-up schedules, and weight decay. Empirical testing is crucial; you might find that a carefully tuned small learning rate or specialized schedule can stabilize training.
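One widely used "specialized schedule" is linear warm-up followed by cosine decay. The learning rate ramps up over the first steps, which helps stabilize early training of large models, and then decays smoothly toward a floor. The function below is a hypothetical sketch. The warm-up length, total steps, and minimum rate are all hyperparameters to tune.

```python
import math

def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int,
                     min_lr: float = 0.0) -> float:
    """Learning rate at `step`: linear warm-up to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr after total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for s in (0, 50, 99, 100, 550, 1000):
    print(s, warmup_cosine_lr(s, base_lr=1e-3, warmup_steps=100, total_steps=1000))
```

To use it in PyTorch, one option is `torch.optim.lr_scheduler.LambdaLR` with `base_lr=1.0`, because `LambdaLR` multiplies the optimizer's initial learning rate by the value the lambda returns.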
Logging and checkpointing
When your training runs for days or weeks on a single GPU, you must have robust checkpointing:
- Periodic checkpointing: Save model weights, optimizer state, and the current training step at intervals so you can resume exactly where you left off if an error occurs.
- Versioning: Include hyperparameters and environment details in your logs or checkpoints for full reproducibility.
- Logging: Tools like TensorBoard, WandB, or Neptune.ai can help track metrics over long runs and keep an eye on potential divergence or overfitting.
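A minimal resumable checkpoint in PyTorch bundles model weights, optimizer state (e.g., Adam's moments), the step counter, and any metadata you want for reproducibility. The helpers below are hypothetical. They write to a temporary file and then call `os.replace`, which on most local filesystems swaps the file in atomically, so a crash mid-save does not corrupt the previous checkpoint.

```python
import os
import torch
from torch import nn, optim

def save_checkpoint(path, model, optimizer, step, extra=None):
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # e.g., Adam moments and step counts
        "step": step,
        "extra": extra or {},  # e.g., hyperparameters for reproducibility
    }
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)  # swap in the new file only after it is fully written

def load_checkpoint(path, model, optimizer):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"], state["extra"]
```

With a GradScaler or LR scheduler in the loop, their `state_dict()` outputs belong in the same dictionary.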
Intermediate summary of single-GPU techniques
We've covered an array of strategies — mixed precision, gradient checkpointing, gradient accumulation, asynchronous execution, specialized libraries, and more — that collectively allow you to push the limits of single-GPU training. This forms a critical first step toward truly large-scale training. Once you master these techniques, you'll be better positioned to handle multi-GPU setups or distributed training across multiple nodes, as we'll explore in subsequent parts of this course.
References and resources
- Official PyTorch Documentation: Provides excellent coverage of torch.cuda.amp and advanced memory optimization APIs.
- TensorFlow Guide to Mixed Precision: Offers an official breakdown of best practices for half precision in TF.
- JAX Documentation: Explains vmap, jit, and the XLA compiler in detail.
- Smith and gang, NeurIPS 2022: On advanced memory saving techniques for large transformers (covering various checkpointing strategies).
- OpenAI's blogs on GPT scaling: Provide insight into how large-scale language models benefit from every optimization trick you can imagine.
- D2L.ai: The open-source "Dive into Deep Learning" resource. Their chapters on computational performance are especially relevant if you want quick code examples.
- Community forums (PyTorch Discuss, TensorFlow Forums, JAX GitHub Issues): Great places to find real-world solutions from practitioners who have faced similar memory or scaling issues.
Remember, these single-GPU techniques serve as the critical foundation to scale beyond. In the next parts of the course, we'll examine multi-GPU data parallelism, model parallelism, pipeline parallelism, and distributed setups using frameworks like PyTorch's Distributed Data Parallel, Horovod, or Ray. Mastering the single-GPU memory management, precision tuning, and performance profiling is your stepping stone to the truly large-scale adventures awaiting you in the world of modern AI.