Optimizing PyTorch Training Workflows: Compile, Profile, Scale, and Checkpoint

Mastering PyTorch is no longer just about knowing the available features. It requires running a repeatable engineering workflow in which training code remains fast, scalable, and recoverable under realistic production workloads. Tools such as torch.compile, torch.profiler, DDP/FSDP, and Distributed Checkpointing are highly effective for keeping training efficient. Still, they only deliver value when they are introduced in the right order and validated carefully.

This article presents a recommended workflow: baseline → compile → profile → scale → checkpoint. It explains what should be measured before optimization begins, highlights common mistakes to avoid when using the compiler and profiler, outlines decision criteria for choosing between DDP and FSDP, and shows how to implement fault-tolerant checkpointing for multi-node training jobs.

Key Takeaways

  • Think of PyTorch performance tuning as an iterative engineering workflow rather than a list of features to switch on. Moving through baseline → compile → profile → scale → checkpoint produces more reliable long-term improvements than activating optimizations one by one without structure.
  • Establish a correct baseline before optimizing. Without a stable single-GPU eager-mode baseline with measured throughput and verified correctness, later benchmarks cannot be compared meaningfully and regressions are hard to debug.
  • Use torch.compile intentionally rather than automatically. Watch for graph breaks, handle shape behavior carefully, warm up before benchmarking, and confirm that steady-state execution actually outperforms eager mode.
  • Use profiling to guide decisions instead of confirming assumptions. torch.profiler should be used to identify CPU stalls, kernel hotspots, shape retracing, and communication overhead in distributed runs.
  • Plan checkpointing around failure from the start. Distributed Checkpointing, optionally combined with asynchronous saves, should capture the entire training state, support resharding across different GPU counts, and be validated regularly with restore tests to ensure true fault tolerance.

Baseline: Create a Reliable Reference Point

Begin with a working single-GPU training example. This serves as the reference point for both functionality and performance. Define the model, dataloader, and training loop, and verify that everything runs end-to-end in eager mode, with no compilation and only one process. A simple example training loop could look like this:

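import torch
import torch.nn as nn
# Dummy model and data for illustration, placed on the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10)).to(device)
data = torch.randn(32, 100, device=device)
targets = torch.randint(0, 10, (32,), device=device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Baseline training steps: forward + backward + optimizer update
for _ in range(10):
    optimizer.zero_grad(set_to_none=True)  # clear gradients from the previous step
    outputs = model(data)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()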

Run several training iterations and measure throughput, such as samples per second. A correct baseline confirms that the model trains as expected and gives you a reference for future optimizations. At this stage, correctness and baseline performance matter most. Make sure the GPU is actually being used by checking tools such as nvidia-smi or PyTorch logs, and verify that there are no obvious bottlenecks such as data loading stalls or unnecessary CPU-heavy work. Do not proceed to compilation or other advanced optimization steps until the baseline is stable.
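
As a rough sketch of the throughput measurement described above, the helper below times a training step after a warm-up phase. The names measure_throughput, step_fn, and sync_fn are illustrative and not part of PyTorch. On a GPU, sync_fn would typically be torch.cuda.synchronize, so that kernels still queued on the device are included in the timing.

```python
import time

def measure_throughput(step_fn, batch_size, warmup=5, iters=20, sync_fn=None):
    """Return samples per second for step_fn, excluding warm-up iterations.

    step_fn runs one training step; sync_fn (optional) blocks until the
    device has finished all queued work.
    """
    for _ in range(warmup):
        step_fn()
    if sync_fn is not None:
        sync_fn()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    if sync_fn is not None:
        sync_fn()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed
```

Recording this number for the eager baseline gives later compile and scaling experiments a fixed point of comparison.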

Baseline Checklist

  • Functional correctness: The model trains correctly and produces expected results on a single GPU.
  • Basic performance recorded: For example, time per batch or GPU utilization.
  • No obvious bottlenecks: The data pipeline keeps the GPU occupied without long idle periods.

Once you have a dependable baseline in place, you can begin introducing advanced features step by step, starting with the compiler introduced in PyTorch 2.x to speed up training.

Compile: Speed Up Training with torch.compile

PyTorch 2 introduced just-in-time compilation through torch.compile, allowing users to improve model performance by compiling models for optimized execution. By wrapping a model or function with torch.compile, PyTorch generates optimized code automatically behind the scenes.

The change is minimal:

# Switch to compiled mode
model = torch.compile(model)  # uses default backend 'inductor'

Once this is enabled, calling model(data) uses an optimized execution path. The first few iterations compile the model just in time, while later iterations reuse the optimized kernels directly. PyTorch 2.9 also caches compilation results automatically to improve future runs, including across processes.

To get the most out of compilation, however, it is important to understand graph breaks and dynamic shapes.

Graph Breaks

Graph breaks occur whenever the compiler cannot capture part of the code into a single graph, such as when Python control flow depends on runtime data. By default, torch.compile falls back automatically to eager execution for unsupported sections while compiling the rest. This keeps the code running, but every graph break leaves part of the workload unoptimized because that portion continues to execute in Python rather than in the fused graph.

For development-time debugging, torch.compile(fullgraph=True) is useful because it raises an error as soon as any part of the model cannot be compiled. For example:

model = torch.compile(model, fullgraph=True)
try:
    model(data)
except Exception as e:
    print("Graph break:", e)

This causes the first unsupported operation to raise an exception and helps identify the code that needs refactoring. The error message usually contains hints or a URL explaining why the break happened and how to avoid it. PyTorch logging can also help. Running the script with the environment variable TORCH_LOGS="graph_breaks" prints the reasons and locations of graph breaks.

Use that information to rewrite or remove Python-side operations that interfere with compilation. Examples include replacing Python list() usage or data-dependent if/else logic with tensor-based operations, or eliminating those constructs entirely.

Dynamic Shapes

By default, torch.compile specializes the compiled graph for the shapes it encounters. If the model receives inputs with different sizes, it recompiles whenever it sees a new shape, which introduces overhead. The dynamic flag of torch.compile controls this behavior. Setting dynamic=True attempts to generate a single kernel that can handle multiple shapes through symbolic shapes. For example:

model = torch.compile(model, dynamic=True)

This tells the compiler to trace a generalized graph, reducing recompilations when input sizes vary, such as changing sequence lengths. By contrast, dynamic=False forces exact-shape specialization and can produce faster execution when shapes are truly fixed. In recent releases, including PyTorch 2.9, dynamic=None is the default. It begins by specializing and then automatically switches to a dynamic kernel when a recompilation caused by a changed input size is detected.

  • Leave dynamic at the default or set it to False when input shapes are constant or rarely change and maximum specialization is preferred.
  • Set dynamic=True when shapes vary often and repeated recompilations become a problem, such as in NLP workloads with variable sequence lengths, especially if the recompile limit (8 by default in recent releases) is being reached, after which the affected code falls back to eager execution.
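
A complementary way to limit recompilations is to reduce the number of distinct shapes the compiler sees in the first place. The sketch below pads sequence lengths up to a bucket boundary. The helper name bucket_length, the bucket size of 64, and the sample lengths are all illustrative assumptions, and the extra padding costs some wasted compute.

```python
def bucket_length(length, multiple=64):
    """Round a sequence length up to the next multiple (the bucket size)."""
    if length <= 0:
        raise ValueError("length must be positive")
    return ((length + multiple - 1) // multiple) * multiple

# Hypothetical sequence lengths seen across batches
raw_lengths = [37, 64, 65, 100, 128, 130, 190, 191]
padded_lengths = [bucket_length(n) for n in raw_lengths]

# Fewer distinct lengths means fewer shape specializations to compile
distinct_raw = len(set(raw_lengths))
distinct_padded = len(set(padded_lengths))
print(distinct_raw, "distinct raw lengths ->", distinct_padded, "distinct padded lengths")
```

Whether bucketing or dynamic=True works better depends on the workload, so both are worth measuring against the baseline.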

Compiled Example

The following example compiles a simple model and shows how graph breaks and dynamic shapes can be handled:

import torch
# Sample model with a potential graph break (data-dependent control flow)
class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Example: data-dependent branching (not traceable by Dynamo)
        if x.sum() > 0:
            return x * 2
        else:
            return x
eager_model = ToyModel()
# Attempt full-graph compilation to catch breaks
try:
    model = torch.compile(eager_model, fullgraph=True)
    output = model(torch.randn(4, 4))
except Exception as e:
    print("Graph break detected:", e)
    # Rewrite model or accept partial graph compilation.
    # Compile the original eager module again rather than the already-compiled wrapper.
    model = torch.compile(eager_model, fullgraph=False)  # fallback to allow breaks

In this example, the data-dependent if statement causes a graph break when fullgraph=True is used. The code catches the exception, prints it, and recompiles with the default behavior that permits graph breaks. In real-world usage, the goal would be to modify the model so that it compiles successfully with fullgraph=True and no exceptions, meaning the entire model can be compiled as a single static graph for maximum performance.
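
One possible refactoring of such a branch, shown here only as a sketch, expresses the choice as a tensor operation with torch.where so that no Python-level decision depends on tensor values. The class names BranchModel and BranchFreeModel are illustrative. Both branches are evaluated, so the pattern suits cheap branches.

```python
import torch

class BranchModel(torch.nn.Module):
    """Original pattern: Python `if` on a tensor value (data-dependent)."""
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x

class BranchFreeModel(torch.nn.Module):
    """Same result, with the choice expressed as a tensor operation."""
    def forward(self, x):
        # The 0-dim boolean condition broadcasts over x
        return torch.where(x.sum() > 0, x * 2, x)

# Check that both versions agree on inputs that take each branch
pos = torch.ones(4, 4)
neg = -torch.ones(4, 4)
for sample in (pos, neg):
    assert torch.equal(BranchModel()(sample), BranchFreeModel()(sample))
```

The branch-free module is the version one would then pass to torch.compile with fullgraph=True and verify against the eager output.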

Compile Checklist

  • Wrap the model with torch.compile(): This small change can often produce meaningful speedups. Use the default inductor backend unless there is a clear reason to choose another.
  • Warm up first: Run several iterations before measuring because the initial runs include compilation overhead.
  • Inspect graph breaks: During development, try fullgraph=True and use TORCH_LOGS="graph_breaks" to locate unsupported code paths. Refactor or remove them.
  • Adjust dynamic shape handling: If repeated recompilations appear in logs or output, consider dynamic=True to generate a more flexible graph.
  • Measure the speedup: Compare throughput against the baseline. The first iteration may be slower, but steady-state execution should be faster than eager mode. If not, enable performance hints with TORCH_LOGS="perf_hints".
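
To separate one-time compilation cost from steady-state speed, per-step wall times can be recorded for both eager and compiled runs and summarized as below. This is a minimal sketch: the function names and the choice of three warm-up steps are assumptions, and the step times would come from your own measurements.

```python
import statistics

def summarize_step_times(step_times, warmup=3):
    """Split per-step wall times (seconds) into warm-up and steady-state parts."""
    if len(step_times) <= warmup:
        raise ValueError("need more steps than warm-up iterations")
    steady = step_times[warmup:]
    steady_median = statistics.median(steady)
    # Extra time spent in warm-up steps beyond the steady-state cost
    warmup_overhead = sum(step_times[:warmup]) - warmup * steady_median
    return {"steady_median": steady_median, "warmup_overhead": warmup_overhead}

def speedup(eager_times, compiled_times, warmup=3):
    """Steady-state speedup of compiled over eager (>1 means compiled is faster)."""
    eager = summarize_step_times(eager_times, warmup)["steady_median"]
    compiled = summarize_step_times(compiled_times, warmup)["steady_median"]
    return eager / compiled
```

A large warm-up overhead with a speedup close to 1 suggests that compilation is not paying for itself on that workload.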

After compiling the model and confirming improved speed, the next step is to inspect its execution closely so that remaining bottlenecks can be found and addressed.

Profile: Find Bottlenecks with torch.profiler

Even after compilation, performance limitations may still remain, including underused GPUs, I/O bottlenecks, or inefficient kernels. PyTorch includes a built-in profiler that captures execution traces and helps identify those issues. In PyTorch 2.9, torch.profiler can trace CPU and GPU activity, record tensor shapes, and integrate with tools such as Chrome Trace Viewer and TensorBoard for visualization. It can also capture traces on a schedule, so that selected windows of a training run are profiled without tracing the whole job.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.profiler
# -----------------------------
# Device setup
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Model definition
# -----------------------------
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
).to(device)
# -----------------------------
# Optimizer and loss
# -----------------------------
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# -----------------------------
# Data iterator (synthetic)
# -----------------------------
def data_generator(batch_size=32):
    while True:
        inputs = torch.randn(batch_size, 100, device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        yield inputs, targets
data_iter = data_generator()
# -----------------------------
# Training step
# -----------------------------
def train_step(batch):
    model.train()
    inputs, targets = batch
    optimizer.zero_grad(set_to_none=True)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss
# -----------------------------
# Warm-up phase
# -----------------------------
for _ in range(5):
    train_step(next(data_iter))
# -----------------------------
# Profiling phase
# -----------------------------
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for _ in range(3):
        train_step(next(data_iter))
        prof.step()
# -----------------------------
# Export trace
# -----------------------------
prof.export_chrome_trace("trace.json")
print("Profiler trace saved to trace.json")

This example defines a minimal but complete PyTorch training loop and uses torch.profiler to collect steady-state CPU and GPU performance data. It creates a small neural network, generates synthetic data through an iterator, and places one training iteration inside a train_step() function.

Several warm-up steps are executed first so that one-time initialization and caching effects are not included in the measurements. Then a few training steps are profiled while recording operator shapes and memory usage. The script also exports a Chrome trace file named trace.json, which can be opened in the browser to inspect GPU utilization, kernel launches, CPU-GPU overlap, and other performance bottlenecks. The settings record_shapes=True and profile_memory=True help expose shape-related issues and memory allocation patterns, which is useful when diagnosing inefficiencies or out-of-memory conditions.

Tip: Open chrome://tracing in Google Chrome and load trace.json to view the execution timeline.

To capture periodic snapshots during a longer run, pass torch.profiler.schedule to the profiler. (torch.utils.bottleneck is a separate command-line tool that produces a quick summary of a whole script.) For example, the following setup profiles a few steps at a time using a schedule:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.profiler
# -----------------------------
# Device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Model
# -----------------------------
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
).to(device)
# -----------------------------
# Optimizer and loss
# -----------------------------
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# -----------------------------
# Dataset and DataLoader
# -----------------------------
num_samples = 10_000
batch_size = 32
inputs = torch.randn(num_samples, 100)
targets = torch.randint(0, 10, (num_samples,))
dataset = TensorDataset(inputs, targets)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
# -----------------------------
# Training step
# -----------------------------
def train_step(batch):
    model.train()
    inputs, targets = batch
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss
# -----------------------------
# Warm-up (outside profiler)
# -----------------------------
for i, batch in enumerate(data_loader):
    train_step(batch)
    if i >= 5:
        break
# -----------------------------
# Profiling with schedule
# -----------------------------
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=3,
        repeat=2,
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./prof_log"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, batch in enumerate(data_loader):
        loss = train_step(batch)
        prof.step()

        if step >= 50:
            break

This setup captures two tracing windows, each lasting three steps. Every cycle first skips one step (wait) and then runs one warm-up step in which the profiler is active but its results are discarded. The on_trace_ready callback is invoked at the end of each window and writes the trace to ./prof_log, where it can be inspected in TensorBoard’s profiling tab without halting training. Profiling only short windows within a larger training run reduces overhead while still providing visibility into system behavior.
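
To make the step arithmetic concrete, the following plain-Python model maps each step index to its phase for a wait/warmup/active cycle. It is only an illustration of the cycle logic, does not call torch.profiler, and the function name schedule_phase is hypothetical.

```python
def schedule_phase(step, wait, warmup, active, repeat):
    """Return the phase ('wait', 'warmup', 'active', or 'done') for a step index.

    A plain-Python model of a wait/warmup/active cycle repeated `repeat` times
    (repeat=0 is treated as "repeat indefinitely").
    """
    cycle = wait + warmup + active
    if repeat > 0 and step >= cycle * repeat:
        return "done"
    pos = step % cycle
    if pos < wait:
        return "wait"
    if pos < wait + warmup:
        return "warmup"
    return "active"

phases = [schedule_phase(s, wait=1, warmup=1, active=3, repeat=2) for s in range(12)]
recorded_steps = [s for s, p in enumerate(phases) if p == "active"]
print(recorded_steps)  # steps whose events end up in a trace
```

With these settings, steps 2 to 4 and 7 to 9 are recorded, and every step from 10 onward runs without profiling.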

After collecting profiling data, use the results to guide concrete action:

  • CPU-bound execution: If the GPU shows many gaps while waiting, consider moving more work onto the GPU or overlapping data loading with computation through asynchronous loading, such as configuring worker processes in the DataLoader or preprocessing on the GPU. Also inspect whether Python loops are performing work during training, and remove or compile them where possible.
  • GPU kernels are the bottleneck: Optimize those operations with fused kernels, lower precision, or other kernel-level improvements. PyTorch’s performance guidance generally recommends relying on high-level operations such as torch.nn.functional so the underlying libraries can optimize them. If one operation is disproportionately slow, determine whether that is expected or whether a better alternative exists.
  • Multiple shapes are triggering retracing: Consider grouping inputs by size or enabling dynamic=True as described earlier.
  • Communication dominates in distributed runs: Verify that communication overhead is not dominating runtime. Signs include long-running NCCL kernels or CPU time spent waiting on the network.

Profiling Checklist

  • Warm up before profiling: Avoid measuring compilation or lazy initialization overhead.
  • Capture both CPU and GPU activity: Use activities=[CPU, CUDA] to understand how they interact.
  • Record shapes and memory: This helps identify shape-driven issues and memory spikes.
  • Use scheduling for long runs: profiler.schedule can capture short windows periodically, reducing overhead while preserving visibility.
  • Analyze utilization: Confirm that GPUs are fully used. If not, determine whether CPU or I/O is limiting throughput.
  • Find the most expensive operations: The profiler can rank operations by self-time and other metrics. Focus optimization efforts on the operations consuming the most time, whether by improving algorithms or adjusting batch sizes when GPUs are underutilized.
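
Ranking operations by self-time can be done directly from the profiler object. The sketch below profiles a small CPU-only workload so that it runs on any machine and prints the top operators sorted by self CPU time. The matrix sizes and iteration count are arbitrary.

```python
import torch
import torch.profiler

# Small CPU-only workload so the sketch runs anywhere
a = torch.randn(256, 256)
b = torch.randn(256, 256)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,
) as prof:
    for _ in range(5):
        torch.relu(a @ b)

# Aggregate events by operator name and sort by self CPU time
summary = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)
print(summary)
```

For GPU runs, sorting by a device-time column is usually more informative. The accepted sort keys vary somewhat between PyTorch versions, so check the ones your release supports.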

Once the application is optimized on a single machine, the next step is to scale across multiple GPUs or nodes, which introduces PyTorch’s distributed training strategies.

Scale: Distributed Training with DDP or FSDP

When the workload or model outgrows a single GPU, training must be scaled across multiple devices. PyTorch provides two primary methods for multi-GPU training: Distributed Data Parallel and Fully Sharded Data Parallel. Both approaches rely on data parallelism, where each process handles a different subset of the data, but they differ significantly in how model parameters are stored and synchronized in memory. The following sections explain when each approach is appropriate and how to configure them for both single-node and multi-node training with torchrun.

Distributed Data Parallel

DDP keeps a full copy of the model on every GPU and synchronizes gradients by all-reducing them during each backward pass, overlapping communication with computation where possible. It is a good fit for models that comfortably fit into the memory of a single GPU. The overall concept is straightforward: initialize a process group, then wrap the model with torch.nn.parallel.DistributedDataParallel.

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
# -----------------------------
# Setup distributed (safe)
# -----------------------------
def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        distributed = True
    else:
        local_rank = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        distributed = False

    return distributed, local_rank, device

def cleanup_distributed(distributed):
    if distributed:
        dist.destroy_process_group()

# -----------------------------
# Main training
# -----------------------------
def main():
    distributed, local_rank, device = setup_distributed()

    # -----------------------------
    # Model
    # -----------------------------
    model = nn.Sequential(
        nn.Linear(100, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ).to(device)
    if distributed:
        model = DDP(model, device_ids=[local_rank])
    # -----------------------------
    # Optimizer and loss
    # -----------------------------
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    # -----------------------------
    # Dataset
    # -----------------------------
    num_samples = 20_000
    batch_size = 32
    inputs = torch.randn(num_samples, 100)
    targets = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(inputs, targets)
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=True)
    else:
        sampler = None
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=2,
        pin_memory=True,
    )

    # -----------------------------
    # Training loop
    # -----------------------------
    epochs = 3
    for epoch in range(epochs):
        if distributed:
            sampler.set_epoch(epoch)

        for step, (x, y) in enumerate(train_loader):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            # Log from global rank 0 only; local_rank 0 exists once on every node
            if step % 50 == 0 and (not distributed or dist.get_rank() == 0):
                print(
                    f"[Epoch {epoch}] Step {step} | "
                    f"Loss {loss.item():.4f}"
                )
    cleanup_distributed(distributed)
# -----------------------------
# Entry point
# -----------------------------
if __name__ == "__main__":
    main()

A DDP job is launched with torchrun. For example, to start a job on a single machine with four GPUs:

torchrun --nproc_per_node=4 train.py

This command starts four processes, one per GPU, and configures the needed environment variables such as rank and world size automatically. For multi-node execution, additional node and network information must be provided:

# On node 0:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr="<IP of node0>" --master_port=12345 train.py
# On node 1:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 --master_addr="<IP of node0>" --master_port=12345 train.py

Fully Sharded Data Parallel (FSDP)

FSDP goes further than DDP by sharding model parameters and optimizer states across GPUs instead of fully replicating them on every device. In PyTorch 2.9, using the FullyShardedDataParallel wrapper is relatively straightforward.

Key Points

  • Wrap the model, or selected submodules, with FSDP before creating the optimizer.
  • Call torch.cuda.set_device(local_rank) and place the model on the target GPU just as with DDP.
  • An optional auto-wrap policy can be supplied when sharding at module granularity is desired, such as wrapping each transformer block separately. Without such a policy, FSDP(model) treats the entire model as one shard unit, which is often inefficient for deep models.

With a size-based auto-wrap policy, FSDP recursively wraps submodules above a chosen parameter threshold, such as 100k parameters per unit. If there are too few shards and the whole model is treated as one large shard, memory savings between layers are limited. If there are too many shards, communication overhead increases. The table below summarizes the major differences between DDP and FSDP:

| Aspect | DDP (Distributed Data Parallel) | FSDP (Fully Sharded Data Parallel) |
| --- | --- | --- |
| Launch method | Started with torchrun, one process per GPU. | Also started with torchrun, one process per GPU. |
| Model replication | Each rank stores a full copy of the model. | The model is split across ranks, so each rank stores only a portion of the parameters. |
| Memory usage per GPU | High, because it scales with full model size. | Much lower, roughly equal to model size divided by the number of GPUs. |
| Communication pattern | Gradients are all-reduced during backward. | Parameters are all-gathered before forward, and gradients are reduce-scattered after backward. |
| Model wrapping requirement | All ranks wrap the full model in DistributedDataParallel. | All ranks must apply identical FSDP wrapping logic so each process knows which shard it owns. |
| Ease of use | Simple to configure and requires only modest code changes. | More complex and demands careful wrapping policies and optimizer setup. |
| Scalability | Limited by the memory available on each GPU, so it is not well suited to extremely large models. | Designed for very large models that do not fit on a single GPU. |
| Typical use case | Models that fit comfortably in GPU memory and mainly need faster training. | Very large models or workloads where memory efficiency is critical. |

As a practical starting point, DDP is often the easier choice. If memory pressure becomes a problem or larger models are required, FSDP becomes worth considering. In PyTorch 2.9, FSDP has matured further. By default, it now effectively “shards everything” in a ZeRO Stage 3 style and uses defaults such as limit_all_gathers=True to reduce unexpected memory spikes.
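
A back-of-the-envelope estimate can help with this decision. The sketch below assumes fp32 parameters and gradients plus two fp32 Adam moments (16 bytes per parameter in total) and ignores activations, temporary all-gather buffers, and fragmentation, so real usage will be higher. The function name and the one-billion-parameter example are illustrative.

```python
def state_gb_per_gpu(num_params, world_size, sharded,
                     bytes_per_param=4, optimizer_bytes_per_param=8):
    """Rough per-GPU memory (GB) for parameters, gradients, and optimizer state.

    Assumes fp32 parameters and gradients plus two fp32 Adam moments, and
    ignores activations, temporary all-gather buffers, and fragmentation.
    """
    per_param = 2 * bytes_per_param + optimizer_bytes_per_param
    total_bytes = num_params * per_param
    if sharded:
        # Full sharding spreads all three kinds of state across the ranks
        total_bytes /= world_size
    return total_bytes / 1e9

ddp_gb = state_gb_per_gpu(1_000_000_000, world_size=8, sharded=False)
fsdp_gb = state_gb_per_gpu(1_000_000_000, world_size=8, sharded=True)
print(f"DDP: {ddp_gb:.1f} GB per GPU, FSDP: {fsdp_gb:.1f} GB per GPU")
```

If the replicated estimate already approaches the capacity of one GPU before activations are counted, that is a strong hint to evaluate FSDP.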

Scaling Checklist

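  • Control randomness deliberately: Seed every rank explicitly if reproducibility is required. Use the same seed on all ranks wherever state must match, and a rank-dependent seed such as torch.manual_seed(seed + rank) for randomness that should differ between ranks, such as data augmentation or dropout.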
  • Use DistributedSampler: Split the dataset across ranks and call set_epoch() each epoch so shuffling differs from one epoch to the next.
  • DDP: Gradients are synchronized automatically across GPUs when backward() is called on models wrapped with DDP. Because the effective batch size is the per-GPU batch size multiplied by the number of GPUs, revisit the learning rate or the per-GPU batch size as the GPU count grows.
  • FSDP: Wrap modules before building the optimizer. Adjust auto_wrap_policy to avoid ending up with one oversized shard in very deep models. Confirm that shard sizes genuinely fit within GPU memory.
  • Communication backend: In multi-node runs, verify that networking is configured correctly, typically with NCCL. Environment variables such as NCCL_P2P_LEVEL=NVL may improve communication when NVLink is available.
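  • Gradient accumulation: When accumulating gradients over several micro-batches, run the forward and backward passes of all but the last micro-batch inside the model.no_sync() context manager so that gradient synchronization is skipped on those iterations. Both DDP and FSDP provide this context manager.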
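
The batch-size arithmetic behind these checklist items can be sketched as follows. The linear scaling rule used here is a common heuristic, not a guarantee, and the function names and example values are assumptions. Any scaled learning rate should be validated against the baseline's convergence.

```python
def effective_batch_size(per_gpu_batch, world_size, grad_accum_steps=1):
    """Number of samples contributing to each optimizer step."""
    return per_gpu_batch * world_size * grad_accum_steps

def linearly_scaled_lr(base_lr, base_batch, new_batch):
    """Linear scaling heuristic: keep lr / batch_size constant."""
    return base_lr * new_batch / base_batch

# Going from 1 GPU to 8 GPUs with the same per-GPU batch size
single_gpu_batch = effective_batch_size(32, world_size=1)
multi_gpu_batch = effective_batch_size(32, world_size=8)
scaled_lr = linearly_scaled_lr(1e-3, single_gpu_batch, multi_gpu_batch)
print(multi_gpu_batch, scaled_lr)
```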

Checkpoint: Make Training Recovery Reliable with Distributed Checkpoints

In long-running or large-scale training jobs, checkpointing is essential. You may need to resume after a crash, continue training from an intermediate point, or inspect partial results. PyTorch 2.9 includes mature support for distributed checkpointing through DCP, which makes saving and loading state more efficient and reliable than traditional approaches based on torch.save. Before moving to the recommended approach, it is useful to compare both methods.

Traditional torch.save / torch.load

In a traditional single-GPU or DDP training script, the usual approach is to save the model’s state_dict() to a file. Under DDP this is typically done from rank 0 only, because every rank holds an identical copy of the weights and a single writer avoids redundant or conflicting writes. For example:

# On rank 0 only:
torch.save(model.state_dict(), "checkpoint.pt")

And to restore it:

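# map_location avoids every rank loading the tensors onto the GPU they were saved from
model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))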

This works well for small models or straightforward setups where one process can hold the full model state in memory. However, it becomes problematic in distributed training when the model is sharded across processes.

Distributed Checkpoint (DCP)

The torch.distributed.checkpoint module in PyTorch, introduced in earlier versions and further matured by 2.9, addresses these limitations. It parallelizes saving so that each rank writes its own portion of the model state, creating multiple files that together form one checkpoint. It also supports resharding during load, meaning a checkpoint saved on N ranks can later be loaded on M ranks with the required gathering and redistribution performed automatically.

A common usage pattern is to place the model and optimizer inside a stateful container that exposes a state_dict. Utility functions such as get_state_dict and set_state_dict make this FSDP-aware:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

# Define a stateful container for model & optimizer
class AppState(Stateful):
    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer
    def state_dict(self):
        model_sd, opt_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": opt_sd}
    def load_state_dict(self, state):
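        # model_state_dict and optim_state_dict are keyword-only arguments
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state["model"],
            optim_state_dict=state["optim"],
        )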

In this example, AppState implements the Stateful interface so DCP knows how to retrieve and restore state. If the model is sharded, get_state_dict automatically returns the correct sharded state for each rank.

Saving with DCP

All ranks execute the following:

# dcp.save expects a dictionary; Stateful values are expanded via state_dict()
state_dict = {"app": AppState(model, optimizer)}
dcp.save(state_dict, checkpoint_id="mycheckpoint")

This creates a checkpoint directory named mycheckpoint that contains a metadata file and the shard files, with each rank writing its shard in parallel. Compared with one-rank saving, this can reduce checkpoint time significantly.

Loading with DCP

To resume training, create the model and optimizer again, wrap them in AppState, and then load:

# dcp.load restores in place, calling load_state_dict() on Stateful values
state_dict = {"app": AppState(model, optimizer)}
dcp.load(state_dict, checkpoint_id="mycheckpoint")

If the number of ranks differs from the original save configuration, DCP handles the shard redistribution automatically. For example, loading a checkpoint saved on eight GPUs onto four GPUs means each rank may load two previously saved shards. If the number of ranks increases, DCP redistributes them accordingly. This built-in resharding removes the need for manual conversion of checkpoints.
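
The shard arithmetic can be illustrated with a toy model. The sketch below treats the whole state as one interval split evenly and contiguously across ranks and reports which saved shards a loading rank would need to read. This is a deliberate simplification for intuition only: real checkpoints are planned per tensor, and the function overlapping_shards is hypothetical, not a DCP API.

```python
from fractions import Fraction

def overlapping_shards(load_rank, load_world, save_world):
    """Saved shard indices whose range overlaps the range owned by load_rank.

    Toy model: the state is treated as the interval [0, 1), split evenly and
    contiguously across ranks. Real checkpoints are planned per tensor.
    """
    lo = Fraction(load_rank, load_world)
    hi = Fraction(load_rank + 1, load_world)
    shards = []
    for saved in range(save_world):
        s_lo = Fraction(saved, save_world)
        s_hi = Fraction(saved + 1, save_world)
        if s_lo < hi and lo < s_hi:  # half-open intervals overlap
            shards.append(saved)
    return shards

# Saved on 8 ranks, loaded on 4: each loading rank covers two saved shards
print([overlapping_shards(r, load_world=4, save_world=8) for r in range(4)])
# Saved on 4 ranks, loaded on 8: two loading ranks read from the same saved shard
print([overlapping_shards(r, load_world=8, save_world=4) for r in range(8)])
```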

Asynchronous Checkpointing

A notable feature in recent PyTorch releases, including 2.9, is asynchronous checkpoint saving, which allows checkpoint writes to overlap with ongoing training. With dcp.async_save, the API returns a future and performs the heavy I/O in background threads. The pattern looks like this:

state_dict = {"app": AppState(model, optimizer)}
save_future = dcp.async_save(state_dict, checkpoint_id="chkpt_epoch10")
# ... training can continue immediately ...
save_future.result()  # later, block until the save completes (re-raises any error)

async_save first stages the data locally on the CPU by copying the model and optimizer state into pinned memory buffers, then writes those buffers asynchronously to storage. This allows GPU compute to continue while the checkpoint is being flushed.

The tradeoff is temporary memory overhead. Asynchronous saving can roughly double the memory needed to store model state during the save process, because each rank allocates CPU-side buffers approximately equal to its checkpoint shard. For very large models and many ranks, CPU RAM and pinned memory usage can rise substantially.
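
One way to keep this overhead bounded is to allow only a single save in flight. The sketch below shows that pattern with a thread pool standing in for the real asynchronous checkpoint call. The class SingleInFlightSaver and the save interval of 10 steps are illustrative assumptions, and the only requirement on the save function is that it returns a future-like object with a result() method.

```python
from concurrent.futures import ThreadPoolExecutor

class SingleInFlightSaver:
    """Allow at most one asynchronous save to be pending at a time."""

    def __init__(self, async_save_fn):
        # async_save_fn(step) must return a future-like object with .result()
        self._async_save_fn = async_save_fn
        self._pending = None

    def save(self, step):
        self.finish()  # wait for (and surface errors from) the previous save
        self._pending = self._async_save_fn(step)

    def finish(self):
        if self._pending is not None:
            self._pending.result()
            self._pending = None

# Stand-in for a real asynchronous checkpoint call
executor = ThreadPoolExecutor(max_workers=1)
written = []
saver = SingleInFlightSaver(lambda step: executor.submit(written.append, step))

for step in range(1, 31):
    # ... one training step would run here ...
    if step % 10 == 0:
        saver.save(step)
saver.finish()       # make sure the last save completed before exiting
executor.shutdown()
```

Waiting on the previous future before starting the next save means that at most one set of staging buffers exists at any time, at the cost of a possible stall when storage is slower than the save interval.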

The following table compares the available checkpointing approaches and when to use them:

| Checkpoint Method | Description | When to Use |
| --- | --- | --- |
| torch.save/torch.load (rank 0) | Single-file checkpoint written by one process. If the model is sharded, state is gathered and saved as one file, usually by rank 0. It is simple and works well for non-sharded models. | Use for small or medium models, or single-GPU training. It can also be acceptable in DDP when model size remains moderate because every rank has a full copy. It is not a good fit for very large or heavily distributed models because it can become slow or cause out-of-memory failures. |
| DCP (synchronous) | Distributed checkpointing in parallel. Every rank writes its own shard of the state dictionary. Training pauses until all ranks finish writing. It supports FSDP, ShardedTensor, and loading across different world sizes through automatic resharding. Multiple files are produced. | Use for large models on multi-GPU setups with either DDP or FSDP. It is recommended when one-rank checkpointing would be too slow or too large. Use synchronous mode when training can tolerate brief pauses, such as between epochs, or when storage performance is fast enough that the pause is minor. |
| DCP + Async (async_save) | Distributed checkpointing with background writes. It returns immediately and performs saving asynchronously, allowing training to continue while data is written. Extra memory is required for staging buffers, and overlapping saves must be managed carefully. | Use for very large models or training jobs where checkpoint time is a significant fraction of runtime. It is especially useful in production workflows where pauses are too costly. Ensure sufficient CPU RAM and pinned memory are available, and validate behavior in the target environment because storage performance can still affect the training job indirectly. |

Checkpointing Checklist

  • Include all required state: Save model weights, optimizer state, and if necessary scheduler state, RNG state, and other training context so that training can resume fully.
  • Test restoration: Immediately after saving, restore the checkpoint in a clean environment, ideally both on the same number of GPUs and on a different GPU count for distributed setups. This helps reveal missing state early.
  • Manage storage carefully: Checkpoints can become very large, especially when distributed. Retention policies such as keeping only the most recent k checkpoints help prevent storage exhaustion.
  • Use async mode carefully: async_save is powerful, but multiple saves should not be left in flight simultaneously unless memory usage is very carefully managed. Always call .result() on the returned future eventually so errors are surfaced and success is confirmed.
  • Maintain consistency: When restoring from a checkpoint, every rank must make the same dcp.load call against the same checkpoint. This is especially important with FSDP, where loading involves coordinated synchronization. Also ensure the model is already on the expected devices when using offloading or similar mechanisms.
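
The retention item in the checklist above can be sketched with the standard library alone. This is a minimal example under stated assumptions: checkpoints are directories named with a hypothetical `chkpt_step<N>` pattern under one root, and the helper `prune_checkpoints` is illustrative rather than part of PyTorch.

```python
import shutil
import tempfile
from pathlib import Path


def prune_checkpoints(root: Path, keep_last: int, prefix: str = "chkpt_step") -> list[Path]:
    """Delete all but the `keep_last` highest-numbered checkpoint directories.

    Returns the list of directories that were removed.
    """
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")

    def step_of(path: Path) -> int:
        return int(path.name[len(prefix):])

    # Only consider directories that match "<prefix><digits>"; ignore everything else.
    checkpoints = sorted(
        (
            p
            for p in root.iterdir()
            if p.is_dir() and p.name.startswith(prefix) and p.name[len(prefix):].isdigit()
        ),
        key=step_of,  # numeric order, so step 1000 sorts after step 300
    )
    stale = checkpoints[:-keep_last]
    for path in stale:
        shutil.rmtree(path)
    return stale


if __name__ == "__main__":
    root = Path(tempfile.mkdtemp())
    for step in (100, 200, 300, 1000):
        (root / f"chkpt_step{step}").mkdir()
    removed = prune_checkpoints(root, keep_last=2)
    print([p.name for p in removed])
```

In a distributed job, a reasonable design is to run pruning from a single rank, and only after the newest save has been confirmed complete, so that a failed write never leaves the job without a usable checkpoint.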

Conclusion

Following this workflow helps produce PyTorch training code that is both high-performing and resilient to failure:

  • Begin with a working baseline.
  • Compile for better execution speed.
  • Profile to uncover further optimization opportunities.
  • Scale with the right data-parallel or model-parallel strategy.
  • Checkpoint correctly so training can recover from interruptions.

Each stage has its own practical playbook. Use them iteratively. For example, profile again after scaling changes the performance profile, or revisit compile settings after the environment changes. PyTorch has evolved quickly and now provides tools to support every step of this process. Knowing when and how to apply each one makes it possible to train large models efficiently at scale while preserving robust recovery behavior.

FAQs

Why is establishing a baseline important before optimization?

A reliable single-GPU eager-mode baseline provides a trusted reference point for both correctness and performance. Without it, optimization techniques can obscure underlying issues, making it difficult to determine whether performance changes are the result of actual improvements or hidden bugs. A validated baseline allows you to accurately measure the impact of each optimization and identify regressions with confidence.

When should I use torch.compile, and what should I consider?

torch.compile is best introduced after you have established a stable performance baseline. It can significantly improve execution speed during steady-state training and inference, but it requires careful monitoring. Pay attention to graph breaks, allow for adequate warm-up before benchmarking, and manage dynamic shapes thoughtfully to avoid unnecessary recompilation overhead.

How does torch.profiler help beyond validating assumptions?

torch.profiler provides detailed visibility into actual application performance, revealing bottlenecks that may not be obvious from code inspection alone. It can identify issues such as CPU bottlenecks, inefficient GPU kernels, repeated graph retracing, memory inefficiencies, and communication overhead in distributed workloads. This enables data-driven optimization rather than relying on assumptions or intuition.

How do I choose between DDP and FSDP for distributed training?

If your model fits comfortably within the memory available on each GPU, Distributed Data Parallel (DDP) is usually the preferred option due to its simplicity and strong performance. For very large models that approach or exceed GPU memory limits, Fully Sharded Data Parallel (FSDP) offers greater scalability by distributing model parameters across devices, although it introduces additional configuration and operational complexity.

Why is Distributed Checkpointing preferred over torch.save for large-scale training?

Distributed Checkpointing is designed for large multi-GPU and multi-node environments. Unlike torch.save, it can parallelize checkpoint operations across multiple ranks, support parameter resharding during loading, and enable asynchronous checkpoint creation. These capabilities improve checkpoint performance, reduce training interruptions, and provide more reliable recovery for large-scale training workloads.

Source: digitalocean.com
