PyTorch-Training optimieren: Kompilieren, Profiling, Skalierung und Checkpointing

PyTorch zu beherrschen bedeutet heute nicht mehr nur, einzelne Funktionen zu kennen. Entscheidend ist ein wiederholbarer Engineering-Prozess, bei dem Trainingscode auch unter realistischen Produktionslasten schnell, skalierbar und wiederherstellbar bleibt. Werkzeuge wie torch.compile, torch.profiler, DDP/FSDP und Distributed Checkpointing sind sehr wirkungsvoll, um Training effizient zu halten. Ihren tatsächlichen Nutzen entfalten sie jedoch erst, wenn sie in der richtigen Reihenfolge eingesetzt und gründlich überprüft werden.

Dieser Beitrag beschreibt einen empfohlenen Ablauf: Baseline → Compile → Profile → Scale → Checkpoint. Er zeigt, welche Werte vor der Optimierung gemessen werden sollten, welche typischen Fehler beim Compiler und Profiler vermieden werden müssen, nach welchen Kriterien DDP oder FSDP ausgewählt werden können und wie fehlertolerantes Checkpointing für Multi-Node-Trainingsläufe umgesetzt wird.

Wichtige Erkenntnisse

  • Betrachten Sie PyTorch-Performance-Tuning als iterativen Engineering-Prozess und nicht als reine Liste aktivierbarer Funktionen. Der strukturierte Weg über Baseline → Compile → Profile → Scale → Checkpoint führt langfristig zu stabileren Verbesserungen als isolierte Einzeloptimierungen.
  • Ohne eine stabile Single-GPU-Baseline im Eager Mode mit bekannter Durchsatzrate und geprüfter Korrektheit lassen sich Performance-Vergleiche und Fehleranalysen kaum sinnvoll durchführen.
  • Optimieren Sie nicht, bevor eine korrekte Baseline steht. Eine verlässliche Single-GPU-Baseline im Eager Mode mit gemessenem Durchsatz und validierten Ergebnissen ist die Grundlage für späteres Benchmarking und Debugging.
  • Setzen Sie torch.compile bewusst ein, nicht automatisch. Beobachten Sie Graph Breaks, berücksichtigen Sie das Verhalten dynamischer Shapes, führen Sie Warm-up-Schritte vor Messungen aus und prüfen Sie, ob die stabile Laufzeit wirklich schneller ist als im Eager Mode.
  • Profiling sollte Entscheidungen ermöglichen, nicht bloß Annahmen bestätigen. Mit torch.profiler lassen sich CPU-Stalls, Kernel-Hotspots, erneutes Tracing durch Shape-Änderungen und Kommunikationsaufwand in verteilten Trainingsläufen sichtbar machen.
  • Checkpointing sollte von Beginn an auf Ausfälle ausgelegt sein. Distributed Checkpointing, optional mit asynchronem Speichern, sollte den vollständigen Trainingszustand erfassen, Resharding über unterschiedliche GPU-Anzahlen unterstützen und regelmäßig durch Restore-Tests validiert werden.

Baseline: Einen zuverlässigen Referenzpunkt schaffen

Starten Sie mit einem funktionierenden Trainingsbeispiel auf einer einzelnen GPU. Dieses dient als Referenz für Funktionalität und Performance. Definieren Sie Modell, Dataloader und Trainingsschleife und stellen Sie sicher, dass alles vollständig im Eager Mode läuft, also ohne Compilation und mit nur einem Prozess. Eine einfache Trainingsschleife kann beispielsweise so aussehen:

import torch
import torch.nn as nn
# Dummy model and data for illustration
model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
data = torch.randn(32, 100)
targets = torch.randint(0, 10, (32,))
# Baseline forward + backward
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
outputs = model(data)
loss = nn.CrossEntropyLoss()(outputs, targets)
loss.backward()
optimizer.step()

Führen Sie mehrere Trainingsiterationen aus und messen Sie den Durchsatz, zum Beispiel Samples pro Sekunde. Eine korrekte Baseline bestätigt, dass das Modell wie erwartet trainiert, und liefert einen Vergleichswert für spätere Optimierungen. In dieser Phase stehen Korrektheit und Basisperformance im Vordergrund. Prüfen Sie, ob die GPU tatsächlich genutzt wird, etwa mit nvidia-smi oder PyTorch-Logs, und stellen Sie sicher, dass keine offensichtlichen Engpässe bestehen, etwa stockendes Data Loading oder unnötig CPU-lastige Operationen. Wechseln Sie erst zu Compilation oder weiteren Optimierungsschritten, wenn diese Baseline stabil ist.

Baseline-Checkliste

  • Funktionale Korrektheit: Das Modell trainiert auf einer einzelnen GPU korrekt und liefert erwartete Ergebnisse.
  • Grundlegende Performance erfasst: Zum Beispiel Zeit pro Batch oder GPU-Auslastung.
  • Keine offensichtlichen Engpässe: Die Datenpipeline hält die GPU beschäftigt und verursacht keine langen Leerlaufphasen.

Sobald eine belastbare Baseline vorhanden ist, können fortgeschrittene Funktionen Schritt für Schritt ergänzt werden. Der erste sinnvolle Schritt ist der in PyTorch 2.x eingeführte Compiler zur Beschleunigung des Trainings.

Compile: Training mit torch.compile beschleunigen

PyTorch 2 hat mit torch.compile eine Just-in-Time-Compilation eingeführt, mit der Modelle für eine optimierte Ausführung kompiliert werden können. Wird ein Modell oder eine Funktion mit torch.compile umschlossen, erzeugt PyTorch im Hintergrund automatisch optimierten Code.

Die Änderung ist sehr klein:

# Switch to compiled mode
model = torch.compile(model)  # uses default backend 'inductor'

Nach der Aktivierung nutzt ein Aufruf von model(data) automatisch einen optimierten Ausführungspfad. Die ersten Iterationen kompilieren das Modell just in time, spätere Iterationen verwenden die optimierten Kernel direkt weiter. PyTorch 2.9 speichert Compilation-Ergebnisse außerdem automatisch zwischen, um spätere Ausführungen zu verbessern, auch über Prozesse hinweg.

Um den größten Nutzen aus der Compilation zu ziehen, sollten Graph Breaks und dynamische Shapes verstanden werden.

Graph Breaks

Graph Breaks entstehen, wenn der Compiler einen Teil des Codes nicht in einem einzelnen Graphen erfassen kann, etwa wenn Python-Kontrollfluss von Laufzeitdaten abhängt. Standardmäßig fällt torch.compile für nicht unterstützte Abschnitte automatisch auf Eager Execution zurück und kompiliert den übrigen Teil weiter. Dadurch läuft der Code weiter, allerdings bleibt jeder Abschnitt mit Graph Break unoptimiert, weil er weiterhin in Python statt im fusionierten Graphen ausgeführt wird.

Für das Debugging während der Entwicklung ist torch.compile(fullgraph=True) hilfreich, da sofort ein Fehler ausgelöst wird, sobald ein Teil des Modells nicht kompiliert werden kann. Beispiel:

model = torch.compile(model, fullgraph=True)
try:
    model(data)
except Exception as e:
    print("Graph break:", e)
model = torch.compile(model, fullgraph=True)
try:
    model(data)
except Exception as e:
    print("Graph break:", e)

Dadurch wird beim ersten nicht unterstützten Vorgang eine Ausnahme ausgelöst, was den betroffenen Code leichter auffindbar macht. Die Fehlermeldung enthält üblicherweise Hinweise oder eine URL, die erklärt, warum der Break aufgetreten ist und wie er vermieden werden kann. Auch PyTorch-Logging ist nützlich. Wenn das Skript mit der Umgebungsvariable TORCH_LOGS="graph_breaks" gestartet wird, werden Ursachen und Positionen von Graph Breaks ausgegeben.

Nutzen Sie diese Informationen, um Python-seitige Operationen umzuschreiben oder zu entfernen, die Compilation verhindern. Dazu gehört beispielsweise, Python-list()-Operationen oder datenabhängige if/else-Logik durch Tensor-Operationen zu ersetzen oder solche Konstrukte vollständig zu entfernen.

Dynamische Shapes

Standardmäßig spezialisiert torch.compile den kompilierten Graphen auf die Shapes, die beim Ausführen beobachtet werden. Wenn ein Modell Eingaben unterschiedlicher Größe erhält, wird bei neuen Shapes erneut kompiliert, was zusätzlichen Aufwand verursacht. In PyTorch 2.9 wurde dafür ein dynamic-Flag eingeführt. Mit dynamic=True versucht der Compiler, einen Kernel zu erzeugen, der mehrere Shapes über symbolische Shapes verarbeiten kann. Beispiel:

model = torch.compile(model, dynamic=True)

Damit wird der Compiler angewiesen, einen allgemeineren Graphen zu tracen. Das reduziert Recompilations, wenn Eingabegrößen variieren, beispielsweise bei wechselnden Sequenzlängen. Im Gegensatz dazu erzwingt dynamic=False eine Spezialisierung auf exakte Shapes und kann schneller sein, wenn die Shapes tatsächlich konstant bleiben. In PyTorch 2.9 ist dynamic=None die Standardeinstellung. Dabei wird zunächst spezialisiert, und bei wiederholten Recompilations wird automatisch auf einen dynamischen Kernel gewechselt.

  • Lassen Sie dynamic auf dem Standardwert oder setzen Sie es auf False, wenn Eingabeshapes konstant sind oder sich nur selten ändern und maximale Spezialisierung gewünscht ist.
  • Setzen Sie dynamic=True, wenn Shapes häufig variieren und wiederholte Recompilations zum Problem werden, etwa bei NLP-Workloads mit unterschiedlich langen Sequenzen, insbesondere wenn das Recompile-Limit von 8 erreicht wird und danach auf Eager Execution zurückgefallen würde.

Kompiliertes Beispiel

Das folgende Beispiel kompiliert ein einfaches Modell und zeigt, wie Graph Breaks und dynamische Shapes behandelt werden können:

import torch
# Sample model with a potential graph break (data-dependent control flow)
class ToyModel(torch.nn.Module):
    def forward(self, x):
        # Example: data-dependent branching (not traceable by Dynamo)
        if x.sum() > 0:
            return x * 2
        else:
            return x
model = ToyModel()
# Attempt full-graph compilation to catch breaks
try:
    model = torch.compile(model, fullgraph=True)
    output = model(torch.randn(4, 4))
except Exception as e:
    print("Graph break detected:", e)
    # Rewrite model or accept partial graph compilation
    model = torch.compile(model, fullgraph=False)  # fallback to allow breaks

In diesem Beispiel verursacht die datenabhängige if-Anweisung einen Graph Break, wenn fullgraph=True verwendet wird. Der Code fängt die Ausnahme ab, gibt sie aus und kompiliert anschließend mit dem Standardverhalten neu, bei dem Graph Breaks erlaubt sind. In der Praxis sollte das Modell so angepasst werden, dass es mit fullgraph=True ohne Ausnahme kompiliert, sodass das gesamte Modell als statischer Graph mit maximaler Geschwindigkeit ausführbar ist.

Compile-Checkliste

  • Modell mit torch.compile() umschließen: Diese kleine Änderung kann häufig spürbare Beschleunigungen bringen. Verwenden Sie standardmäßig das Backend inductor, sofern kein klarer Grund für ein anderes Backend besteht.
  • Vor Messungen aufwärmen: Führen Sie mehrere Iterationen aus, bevor Sie Zeiten messen, da die ersten Durchläufe Compilation-Aufwand enthalten.
  • Graph Breaks prüfen: Testen Sie während der Entwicklung fullgraph=True und nutzen Sie TORCH_LOGS="graph_breaks", um nicht unterstützte Codepfade zu finden. Refaktorieren oder entfernen Sie diese Stellen.
  • Dynamische Shapes abstimmen: Wenn in Logs oder Ausgaben wiederholte Recompilations sichtbar werden, kann dynamic=True sinnvoll sein, um einen flexibleren Graphen zu erzeugen.
  • Beschleunigung messen: Vergleichen Sie den Durchsatz mit der Baseline. Die erste Iteration kann langsamer sein, aber die stabile Ausführung sollte schneller sein als im Eager Mode. Falls nicht, können Performance-Hinweise mit TORCH_LOGS="perf_hints" aktiviert werden.

Nachdem das Modell kompiliert wurde und eine bessere Geschwindigkeit bestätigt ist, sollte die Ausführung genauer untersucht werden, um verbleibende Engpässe zu finden und gezielt zu beheben.

Profile: Engpässe mit torch.profiler erkennen

Auch nach der Compilation können Performance-Probleme bestehen bleiben, etwa schlecht ausgelastete GPUs, I/O-Engpässe oder ineffiziente Kernel. PyTorch enthält einen integrierten Profiler, der Ausführungstraces erfasst und bei der Analyse solcher Probleme hilft. In PyTorch 2.9 kann torch.profiler CPU- und GPU-Aktivität aufzeichnen, Tensor-Shapes erfassen und mit Werkzeugen wie Chrome Trace Viewer oder TensorBoard zur Visualisierung genutzt werden. Zudem unterstützt der Profiler asynchrone Trace-Erfassung, sodass Teile des Trainings untersucht werden können, ohne den Programmablauf zu unterbrechen.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.profiler
# -----------------------------
# Device setup
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Model definition
# -----------------------------
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
).to(device)
# -----------------------------
# Optimizer and loss
# -----------------------------
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# -----------------------------
# Data iterator (synthetic)
# -----------------------------
def data_generator(batch_size=32):
    while True:
        inputs = torch.randn(batch_size, 100, device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)
        yield inputs, targets
data_iter = data_generator()
# -----------------------------
# Training step
# -----------------------------
def train_step(batch):
    model.train()
    inputs, targets = batch
    optimizer.zero_grad(set_to_none=True)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss
# -----------------------------
# Warm-up phase
# -----------------------------
for _ in range(5):
    train_step(next(data_iter))
# -----------------------------
# Profiling phase
# -----------------------------
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    record_shapes=True,
    profile_memory=True,
) as prof:
    for _ in range(3):
        train_step(next(data_iter))
        prof.step()
# -----------------------------
# Export trace
# -----------------------------
prof.export_chrome_trace("trace.json")
print("Profiler trace saved to trace.json")

Dieses Beispiel definiert eine minimale, aber vollständige PyTorch-Trainingsschleife und nutzt torch.profiler, um stabile CPU- und GPU-Performance-Daten zu erfassen. Es erstellt ein kleines neuronales Netzwerk, erzeugt synthetische Daten über einen Iterator und kapselt eine Trainingsiteration in der Funktion train_step().

Zunächst werden mehrere Warm-up-Schritte ausgeführt, damit einmalige Initialisierungs- und Cache-Effekte nicht in die Messungen einfließen. Anschließend werden einige Trainingsschritte profiliert, wobei Operator-Shapes und Speichernutzung aufgezeichnet werden. Zusätzlich exportiert das Skript eine Chrome-Trace-Datei namens trace.json, die im Browser geöffnet werden kann, um GPU-Auslastung, Kernel-Starts, CPU-GPU-Überlappung und weitere Performance-Engpässe zu untersuchen. Die Optionen record_shapes=True und profile_memory=True helfen dabei, Shape-bezogene Probleme und Speicherzuweisungen zu erkennen, was besonders bei Ineffizienzen oder Out-of-Memory-Fehlern nützlich ist.

Tipp: Öffnen Sie chrome://tracing in Google Chrome und laden Sie trace.json, um die Ausführung als Timeline zu betrachten.

Für eine automatische Engpassanalyse können torch.utils.bottleneck oder torch.profiler.schedule verwendet werden, um periodische Ausschnitte zu erfassen. Das folgende Beispiel profiliert einige Schritte pro Epoche mithilfe eines Zeitplans:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.profiler
# -----------------------------
# Device
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------
# Model
# -----------------------------
model = nn.Sequential(
    nn.Linear(100, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
).to(device)
# -----------------------------
# Optimizer and loss
# -----------------------------
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
# -----------------------------
# Dataset and DataLoader
# -----------------------------
num_samples = 10_000
batch_size = 32
inputs = torch.randn(num_samples, 100)
targets = torch.randint(0, 10, (num_samples,))
dataset = TensorDataset(inputs, targets)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
# -----------------------------
# Training step
# -----------------------------
def train_step(batch):
    model.train()
    inputs, targets = batch
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss
# -----------------------------
# Warm-up (outside profiler)
# -----------------------------
for i, batch in enumerate(data_loader):
    train_step(batch)
    if i >= 5:
        break
# -----------------------------
# Profiling with schedule
# -----------------------------
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(
        wait=1,
        warmup=1,
        active=3,
        repeat=2,
    ),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./prof_log"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, batch in enumerate(data_loader):
        loss = train_step(batch)
        prof.step()

        if step >= 50:
            break

Diese Konfiguration erfasst asynchron zwei Trace-Fenster mit jeweils drei aktiven Schritten, nachdem zunächst eine Warte- und Warm-up-Phase durchlaufen wurde. Der Callback on_trace_ready schreibt die Traces an einen Speicherort, der über den Profiling-Bereich von TensorBoard ausgewertet werden kann, ohne das Training anzuhalten. Die asynchrone Erfassung ermöglicht periodisches Profiling, etwa alle paar Schritte innerhalb eines größeren Trainingsfensters, und reduziert dadurch den Overhead, während das Laufzeitverhalten weiterhin sichtbar bleibt.

Nach dem Erfassen der Profiling-Daten sollten die Ergebnisse für konkrete Optimierungsentscheidungen genutzt werden:

  • CPU-gebundene Ausführung: Wenn die GPU häufig wartet, sollte geprüft werden, ob mehr Arbeit auf die GPU verlagert oder Data Loading mit Berechnungen überlappt werden kann, etwa durch Worker-Prozesse im DataLoader oder durch Preprocessing auf der GPU. Prüfen Sie außerdem, ob Python-Schleifen während des Trainings relevante Arbeit ausführen, und entfernen oder kompilieren Sie diese nach Möglichkeit.
  • GPU-Kernel als Engpass: Optimieren Sie betroffene Operationen beispielsweise durch fusionierte Kernel, niedrigere Präzision oder andere Kernel-nahe Verbesserungen. Die PyTorch-Performance-Empfehlungen bevorzugen in der Regel hochstufige Operationen wie torch.nn.functional, damit Bibliotheken diese effizient optimieren können. Wenn eine einzelne Operation deutlich langsamer ist als erwartet, sollte geprüft werden, ob dies fachlich begründet ist oder ob eine bessere Berechnungsvariante existiert.
  • Mehrere Shapes verursachen erneutes Tracing: Gruppieren Sie Eingaben gegebenenfalls nach Größe oder verwenden Sie wie beschrieben dynamic=True.
  • Externe Engpässe ausschließen: Bei verteiltem Training sollte geprüft werden, ob Kommunikationsaufwand die Laufzeit dominiert. Hinweise darauf sind NCCL-Kernel oder CPU-Zeit, die durch Warten auf Netzwerkkommunikation entsteht.

Profiling-Checkliste

  • Vor dem Profiling aufwärmen: Compilation- und Lazy-Initialization-Overhead sollten nicht mitgemessen werden.
  • CPU- und GPU-Aktivität erfassen: Mit activities=[CPU, CUDA] wird sichtbar, wie beide Bereiche zusammenspielen.
  • Shapes und Speicher aufzeichnen: Das hilft, Shape-Probleme und Speicherspitzen zu identifizieren.
  • Scheduling für lange Läufe verwenden: profiler.schedule kann regelmäßig kurze Zeitfenster erfassen, wodurch Overhead reduziert und Sichtbarkeit erhalten bleibt.
  • Auslastung analysieren: Stellen Sie sicher, dass GPUs vollständig genutzt werden. Falls nicht, ermitteln Sie, ob CPU oder I/O den Durchsatz begrenzen.
  • Teuerste Operationen finden: Der Profiler kann Operationen nach Self-Time und anderen Metriken sortieren. Konzentrieren Sie Optimierungen auf die Operationen mit dem größten Zeitanteil, etwa durch bessere Algorithmen oder größere Batch-Größen, wenn GPUs unterausgelastet sind.

Wenn die Anwendung auf einer einzelnen Maschine optimiert ist, besteht der nächste Schritt darin, über mehrere GPUs oder Nodes zu skalieren. Dafür kommen die verteilten Trainingsstrategien von PyTorch ins Spiel.

Scale: Verteiltes Training mit DDP oder FSDP

Wenn Workload oder Modell über eine einzelne GPU hinauswachsen, muss das Training über mehrere Geräte skaliert werden. PyTorch stellt dafür zwei zentrale Methoden bereit: Distributed Data Parallel und Fully Sharded Data Parallel. Beide Ansätze basieren auf Datenparallelität, bei der jeder Prozess einen anderen Teil der Daten verarbeitet. Sie unterscheiden sich jedoch stark darin, wie Modellparameter gespeichert und synchronisiert werden. Die folgenden Abschnitte erklären, wann welcher Ansatz geeignet ist und wie er für Single-Node- und Multi-Node-Training mit torchrun konfiguriert wird.

Distributed Data Parallel

DDP hält auf jeder GPU eine vollständige Kopie des Modells und synchronisiert Gradienten nach jedem Backward-Pass per All-Reduce. Dieser Ansatz eignet sich besonders für Modelle, die problemlos in den Speicher einer einzelnen GPU passen. Das Konzept ist vergleichsweise einfach: Zuerst wird eine Process Group initialisiert, anschließend wird das Modell mit torch.nn.parallel.DistributedDataParallel umschlossen.

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
# -----------------------------
# Setup distributed (safe)
# -----------------------------
def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size > 1:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        distributed = True
    else:
        local_rank = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        distributed = False

    return distributed, local_rank, device

def cleanup_distributed(distributed):
    if distributed:
        dist.destroy_process_group()

# -----------------------------
# Main training
# -----------------------------
def main():
    distributed, local_rank, device = setup_distributed()

    # -----------------------------
    # Model
    # -----------------------------
    model = nn.Sequential(
        nn.Linear(100, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ).to(device)
    if distributed:
        model = DDP(model, device_ids=[local_rank])
    # -----------------------------
    # Optimizer and loss
    # -----------------------------
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    # -----------------------------
    # Dataset
    # -----------------------------
    num_samples = 20_000
    batch_size = 32
    inputs = torch.randn(num_samples, 100)
    targets = torch.randint(0, 10, (num_samples,))
    dataset = TensorDataset(inputs, targets)
    if distributed:
        sampler = DistributedSampler(dataset, shuffle=True)
    else:
        sampler = None
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=2,
        pin_memory=True,
    )

    # -----------------------------
    # Training loop
    # -----------------------------
    epochs = 3
    for epoch in range(epochs):
        if distributed:
            sampler.set_epoch(epoch)

        for step, (x, y) in enumerate(train_loader):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            outputs = model(x)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()
            if step % 50 == 0 and local_rank == 0:
                print(
                    f"[Epoch {epoch}] Step {step} | "
                    f"Loss {loss.item():.4f}"
                )
    cleanup_distributed(distributed)
# -----------------------------
# Entry point
# -----------------------------
if __name__ == "__main__":
    main()

Ein DDP-Job wird mit torchrun gestartet. Um beispielsweise auf einer einzelnen Maschine mit vier GPUs zu starten, wird folgender Befehl verwendet:

torchrun --nproc_per_node=4 train.py

Dieser Befehl startet vier Prozesse, jeweils einen pro GPU, und setzt die benötigten Umgebungsvariablen wie Rank und World Size automatisch. Für Multi-Node-Ausführung müssen zusätzlich Node- und Netzwerkdaten angegeben werden:

# On node 0:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr="<IP of node0>" --master_port=12345 train.py
# On node 1:
torchrun --nnodes=2 --nproc_per_node=4 --node_rank=1 --master_addr="<IP of node0>" --master_port=12345 train.py

Fully Sharded Data Parallel (FSDP)

FSDP geht über DDP hinaus, indem Modellparameter und Optimizer-States über GPUs verteilt werden, statt sie vollständig auf jedem Gerät zu replizieren. In PyTorch 2.9 ist die Nutzung des Wrappers FullyShardedDataParallel vergleichsweise direkt möglich.

Wichtige Punkte

  • Umschließen Sie das Modell oder ausgewählte Submodule mit FSDP, bevor der Optimizer erzeugt wird.
  • Verwenden Sie torch.cuda.set_device(rank) und platzieren Sie das Modell wie bei DDP auf der Ziel-GPU.
  • Optional kann eine Auto-Wrap-Policy angegeben werden, wenn Sharding auf Modulebene gewünscht ist, etwa für einzelne Transformer-Blöcke. Ohne eine solche Policy behandelt FSDP(model) das gesamte Modell als eine Shard-Einheit, was bei tiefen Modellen häufig ineffizient ist.

Dieser Ansatz kann Submodule oberhalb eines gewählten Parameter-Schwellenwerts, beispielsweise 100.000 Parameter, rekursiv in eigene FSDP-Shards aufteilen. Gibt es zu wenige Shards und das gesamte Modell wird als ein großer Shard behandelt, bleiben die Speichereinsparungen zwischen Layern begrenzt. Zu viele Shards erhöhen hingegen den Kommunikationsaufwand. Die folgende Tabelle fasst die wichtigsten Unterschiede zwischen DDP und FSDP zusammen:

Aspekt DDP (Distributed Data Parallel) FSDP (Fully Sharded Data Parallel)
Startmethode Wird mit torchrun gestartet, ein Prozess pro GPU. Wird ebenfalls mit torchrun gestartet, ein Prozess pro GPU.
Modellreplikation Jeder Rank hält eine vollständige Kopie des Modells. Das Modell wird über Ranks aufgeteilt, sodass jeder Rank nur einen Teil der Parameter hält.
Speichernutzung pro GPU Hoch, da sie von der vollständigen Modellgröße abhängt. Deutlich niedriger, näherungsweise Modellgröße geteilt durch die Anzahl der GPUs.
Kommunikationsmuster Gradienten werden nach dem Backward-Pass per All-Reduce synchronisiert. Parameter werden vor dem Forward-Pass per All-Gather gesammelt, Gradienten nach dem Backward-Pass per Reduce-Scatter verteilt.
Anforderung an das Wrapping Alle Ranks umschließen das vollständige Modell mit DistributedDataParallel. Alle Ranks müssen dieselbe FSDP-Wrapping-Logik anwenden, damit jeder Prozess seine Shards korrekt kennt.
Bedienbarkeit Einfach einzurichten und nur mit moderaten Codeänderungen verbunden. Komplexer und abhängig von sorgfältigen Wrapping-Policies sowie korrektem Optimizer-Setup.
Skalierbarkeit Durch den Speicher jeder einzelnen GPU begrenzt und daher weniger geeignet für extrem große Modelle. Für sehr große Modelle ausgelegt, die nicht auf eine einzelne GPU passen.
Typischer Einsatzfall Modelle, die bequem in den GPU-Speicher passen und vor allem schneller trainiert werden sollen. Sehr große Modelle oder Szenarien, in denen Speichereffizienz entscheidend ist.

Als praktischer Einstieg ist DDP häufig die einfachere Wahl. Wenn Speicherdruck entsteht oder größere Modelle benötigt werden, sollte FSDP geprüft werden. In PyTorch 2.9 ist FSDP weiter gereift. Standardmäßig arbeitet es effektiv nach dem Prinzip „alles sharden“ im Stil von ZeRO Stage 3 und nutzt Voreinstellungen wie limit_all_gathers=True, um unerwartete Speicherspitzen zu reduzieren.

Skalierungs-Checkliste

  • Determinismus sicherstellen: Verwenden Sie auf allen Ranks denselben Zufalls-Seed, beispielsweise torch.manual_seed(seed + rank_offset), wenn Reproduzierbarkeit erforderlich ist.
  • DistributedSampler verwenden: Teilen Sie den Datensatz auf die Ranks auf und rufen Sie pro Epoche set_epoch() auf, damit die Daten in jeder Epoche unterschiedlich gemischt werden.
  • DDP: Gradienten werden automatisch über GPUs synchronisiert, nachdem backward() auf einem mit DDP umschlossenen Modell ausgeführt wurde. Denken Sie daran, Lernrate oder Batch-Größe anzupassen, wenn die Anzahl der GPUs steigt.
  • FSDP: Module müssen vor dem Erstellen des Optimizers umschlossen werden. Passen Sie auto_wrap_policy so an, dass bei tiefen Modellen kein einzelner übergroßer Shard entsteht. Prüfen Sie außerdem, ob die Shard-Größen tatsächlich in den GPU-Speicher passen.
  • Kommunikationsbackend: Bei Multi-Node-Läufen muss die Netzwerkkonfiguration stimmen, typischerweise mit NCCL. Umgebungsvariablen wie NCCL_P2P_LEVEL=NVL können bei vorhandenen NVLink-Verbindungen die Kommunikation verbessern.
  • Gradient Accumulation: Wenn sie mit FSDP eingesetzt wird, sollte model.no_sync() in Iterationen verwendet werden, in denen die Gradientensynchronisierung übersprungen werden soll, ähnlich wie bei DDP.

Checkpoint: Training mit Distributed Checkpoints zuverlässig wiederherstellen

Bei langen oder groß angelegten Trainingsläufen ist Checkpointing unverzichtbar. Es kann nötig sein, nach einem Absturz fortzufahren, aus einem Zwischenstand weiterzutrainieren oder Zwischenergebnisse zu untersuchen. PyTorch 2.9 bietet ausgereifte Unterstützung für Distributed Checkpointing über DCP, wodurch Speichern und Laden effizienter und zuverlässiger werden als bei traditionellen Ansätzen mit torch.save. Vor dem empfohlenen Ansatz lohnt sich der Vergleich beider Methoden.

Klassisches torch.save / torch.load

In klassischen Single-GPU- oder DDP-Trainingsskripten wird üblicherweise das state_dict() des Modells in eine Datei geschrieben, meist nur von Rank 0, da alle Prozesse in der Regel auf denselben Speicherort zugreifen können. Beispiel:

# On rank 0 only:
torch.save(model.state_dict(), "checkpoint.pt")

Und zum Wiederherstellen:

model.load_state_dict(torch.load("checkpoint.pt"))

Das funktioniert gut für kleine Modelle oder einfache Setups, bei denen ein Prozess den vollständigen Modellzustand im Speicher halten kann. In verteiltem Training wird dieser Ansatz jedoch problematisch, wenn das Modell über Prozesse hinweg geshardet ist.

Distributed Checkpoint (DCP)

Das Modul torch.distributed.checkpoint in PyTorch, das in früheren Versionen eingeführt und bis PyTorch 2.9 weiter ausgebaut wurde, adressiert diese Einschränkungen. Es parallelisiert das Speichern, indem jeder Rank seinen eigenen Teil des Modellzustands schreibt. Dadurch entstehen mehrere Dateien, die gemeinsam einen Checkpoint bilden. Außerdem unterstützt DCP Resharding beim Laden: Ein auf N Ranks gespeicherter Checkpoint kann später auf M Ranks geladen werden, wobei Sammeln und Neuverteilung automatisch erfolgen.

Ein typisches Nutzungsmuster besteht darin, Modell und Optimizer in einen Stateful-Container zu legen, der ein state_dict bereitstellt. Hilfsfunktionen wie get_state_dict und set_state_dict sind dabei FSDP-kompatibel:

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

# Define a stateful container for model & optimizer
class AppState(Stateful):
    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer
    def state_dict(self):
        model_sd, opt_sd = get_state_dict(self.model, self.optimizer)
        return {"model": model_sd, "optim": opt_sd}
    def load_state_dict(self, state):
        set_state_dict(self.model, self.optimizer, state["model"], state.get("optim"))

In diesem Beispiel implementiert AppState das Stateful-Interface, sodass DCP weiß, wie Zustand gelesen und wiederhergestellt wird. Wenn das Modell geshardet ist, liefert get_state_dict automatisch den passenden geshardeten Zustand für jeden Rank.

Speichern mit DCP

Alle Ranks führen Folgendes aus:

app_state = AppState(model, optimizer)
dcp.save(app_state, checkpoint_id="mycheckpoint")

Dadurch werden Checkpoint-Verzeichnisse erzeugt, deren Dateien mit mycheckpoint beginnen. Jeder Rank schreibt seinen Shard parallel. Im Vergleich zum Speichern durch nur einen Rank kann das die Checkpoint-Zeit deutlich reduzieren.

Laden mit DCP

Um das Training fortzusetzen, werden Modell und Optimizer erneut erstellt, in AppState verpackt und anschließend geladen:

app_state = AppState(model, optimizer)
dcp.load(app_state, checkpoint_id="mycheckpoint");

Wenn die Anzahl der Ranks von der ursprünglichen Speicherkonfiguration abweicht, übernimmt DCP die Neuverteilung der Shards automatisch. Wird beispielsweise ein auf acht GPUs gespeicherter Checkpoint auf vier GPUs geladen, kann jeder Rank zwei zuvor gespeicherte Shards laden. Steigt die Anzahl der Ranks, verteilt DCP die Daten entsprechend neu. Dieses integrierte Resharding macht eine manuelle Checkpoint-Konvertierung überflüssig.

Asynchrones Checkpointing

Eine wichtige Funktion in PyTorch 2.9 ist das asynchrone Speichern von Checkpoints. Dadurch können Checkpoint-Schreibvorgänge parallel zum laufenden Training ausgeführt werden. Mit dcp.async_save gibt die API ein Future zurück und erledigt die I/O-lastige Arbeit in Hintergrund-Threads. Das Muster sieht folgendermaßen aus:

save_future = dcp.async_save(app_state, checkpoint_id="chkpt_epoch10")
# ... training can continue immediately ...
save_future.wait()  # later, wait for completion (or periodically check)

async_save kopiert die Modell- und Optimizer-Zustände zunächst lokal auf die CPU in gepinnte Speicherpuffer und schreibt diese anschließend asynchron auf den Speicher. Dadurch kann die GPU mit Berechnungen fortfahren, während der Checkpoint geschrieben wird.

Der Nachteil ist ein temporärer zusätzlicher Speicherbedarf. Asynchrones Speichern kann den Speicherbedarf für Modellzustände während des Speichervorgangs ungefähr verdoppeln, weil jeder Rank CPU-seitige Puffer in etwa der Größe seines Checkpoint-Shards anlegt. Bei sehr großen Modellen und vielen Ranks können CPU-RAM und gepinnter Speicher dadurch deutlich stärker belastet werden.

Die folgende Tabelle vergleicht die verfügbaren Checkpointing-Ansätze und zeigt, wann sie sinnvoll sind:

Checkpoint-Methode Beschreibung Wann verwenden?
torch.save/torch.load (Rank 0) Single-File-Checkpoint, der von einem Prozess geschrieben wird. Wenn das Modell geshardet ist, wird der Zustand gesammelt und meist von Rank 0 als eine Datei gespeichert. Der Ansatz ist einfach und funktioniert gut für nicht geshardete Modelle. Geeignet für kleine bis mittlere Modelle oder Single-GPU-Training. Auch für DDP kann dieser Ansatz funktionieren, solange die Modellgröße moderat bleibt, da jeder Rank eine vollständige Kopie besitzt. Für sehr große oder stark verteilte Modelle ist er weniger geeignet, da er langsam werden oder Out-of-Memory-Probleme verursachen kann.
DCP (synchron) Verteiltes Checkpointing mit parallelem Speichern. Jeder Rank schreibt seinen eigenen Shard des State Dictionary. Das Training pausiert, bis alle Ranks den Schreibvorgang abgeschlossen haben. Unterstützt FSDP, ShardedTensor und Laden mit unterschiedlichen World Sizes durch automatisches Resharding. Es entstehen mehrere Dateien. Geeignet für große Modelle in Multi-GPU-Setups mit DDP oder FSDP. Empfohlen, wenn ein Checkpoint über nur einen Rank zu langsam oder zu groß wäre. Synchrones Speichern ist sinnvoll, wenn das Training kurze Pausen toleriert, etwa zwischen Epochen, oder wenn das Speichersystem schnell genug ist.
DCP + Async (async_save) Verteiltes Checkpointing mit Hintergrund-Schreibvorgängen. Die Funktion kehrt sofort zurück und speichert asynchron, sodass das Training weiterlaufen kann, während Daten geschrieben werden. Zusätzlicher Speicher für Staging-Puffer ist erforderlich, und überlappende Speichervorgänge müssen sorgfältig kontrolliert werden. Geeignet für sehr große Modelle oder Trainingsjobs, bei denen Checkpoint-Zeit einen relevanten Anteil der Laufzeit ausmacht. Besonders nützlich in produktionsnahen Workflows, in denen Pausen zu teuer wären. Ausreichend CPU-RAM und gepinnter Speicher sollten vorhanden sein, und das Verhalten sollte in der jeweiligen Umgebung getestet werden, da Speicher-I/O das Training indirekt beeinflussen kann.

Checkpointing-Checkliste

  • Alle erforderlichen Zustände einbeziehen: Speichern Sie Modellgewichte, Optimizer-State und bei Bedarf Scheduler-State, RNG-State sowie weitere Trainingskontexte, damit das Training vollständig fortgesetzt werden kann.
  • Wiederherstellung testen: Laden Sie den Checkpoint direkt nach dem Speichern in einer sauberen Umgebung, idealerweise sowohl mit derselben GPU-Anzahl als auch mit einer anderen GPU-Anzahl in verteilten Setups. So fallen fehlende Zustände früh auf.
  • Speicherplatz verwalten: Checkpoints können sehr groß werden, besonders in verteilten Setups. Aufbewahrungsregeln, etwa nur die letzten k Checkpoints zu behalten, helfen dabei, Speicherüberläufe zu vermeiden.
  • Asynchronen Modus vorsichtig nutzen: async_save ist leistungsfähig, aber mehrere gleichzeitige Speichervorgänge sollten nur mit sehr sorgfältigem Speichermanagement zugelassen werden. Rufen Sie auf dem zurückgegebenen Future schließlich immer .wait() auf, damit Fehler sichtbar werden und der Erfolg bestätigt wird.
  • Konsistenz sicherstellen: Beim Wiederherstellen müssen alle Ranks denselben dcp.load-Aufruf ausführen. Das ist besonders bei FSDP wichtig, da das Laden koordinierte Synchronisierung erfordert. Außerdem sollte das Modell bereits auf den erwarteten Geräten liegen, wenn Offloading oder ähnliche Mechanismen verwendet werden.

Fazit

Dieser Workflow hilft dabei, PyTorch-Trainingscode zu entwickeln, der schnell arbeitet und zugleich gegen Ausfälle abgesichert ist:

  • Mit einer funktionierenden Baseline beginnen.
  • Für bessere Ausführungsgeschwindigkeit kompilieren.
  • Durch Profiling weitere Optimierungsmöglichkeiten erkennen.
  • Mit der passenden daten- oder modellparallelen Strategie skalieren.
  • Checkpointing korrekt umsetzen, damit Training nach Unterbrechungen wiederhergestellt werden kann.

Jede Phase besitzt ihren eigenen praktischen Ablauf. Nutzen Sie diese Schritte iterativ. Profilen Sie beispielsweise erneut, nachdem Skalierung das Performance-Profil verändert hat, oder prüfen Sie Compile-Einstellungen erneut, wenn sich die Umgebung ändert. PyTorch hat sich schnell weiterentwickelt und stellt heute Werkzeuge für jeden dieser Schritte bereit. Wer versteht, wann und wie sie eingesetzt werden, kann große Modelle effizient und skalierbar trainieren und gleichzeitig eine robuste Wiederherstellung sicherstellen.

FAQs

Warum ist es wichtig, vor der Optimierung eine Baseline zu erstellen?

Eine zuverlässige Single-GPU-Baseline im Eager Mode liefert einen vertrauenswürdigen Referenzpunkt für Korrektheit und Performance. Ohne sie können Optimierungstechniken zugrunde liegende Probleme verdecken. Dadurch wird es schwierig zu beurteilen, ob Performance-Veränderungen tatsächlich auf Verbesserungen zurückzuführen sind oder durch versteckte Fehler entstehen. Eine validierte Baseline ermöglicht es, die Auswirkungen jeder Optimierung genau zu messen und Regressionen sicher zu erkennen.

Wann sollte ich torch.compile verwenden, und worauf sollte ich achten?

torch.compile sollte am besten erst eingesetzt werden, nachdem eine stabile Performance-Baseline erstellt wurde. Es kann die Ausführungsgeschwindigkeit bei stabilem Training und Inferenz deutlich verbessern, erfordert jedoch sorgfältige Überwachung. Achte auf Graph Breaks, plane vor Benchmarks ausreichend Warm-up-Zeit ein und gehe bewusst mit dynamischen Shapes um, um unnötigen Recompilation-Overhead zu vermeiden.

Wie hilft torch.profiler über die Validierung von Annahmen hinaus?

torch.profiler bietet detaillierte Einblicke in die tatsächliche Anwendungsperformance und macht Engpässe sichtbar, die durch reine Code-Inspektion oft nicht erkennbar sind. Es kann Probleme wie CPU-Bottlenecks, ineffiziente GPU-Kernels, wiederholtes Graph-Retracing, Speicherineffizienzen und Kommunikations-Overhead in verteilten Workloads identifizieren. Dadurch wird eine datenbasierte Optimierung möglich, statt sich auf Annahmen oder Intuition zu verlassen.

Wie entscheide ich zwischen DDP und FSDP für verteiltes Training?

Wenn dein Modell problemlos in den verfügbaren Speicher jeder GPU passt, ist Distributed Data Parallel (DDP) aufgrund seiner Einfachheit und starken Performance meist die bevorzugte Option. Für sehr große Modelle, die sich den GPU-Speichergrenzen nähern oder diese überschreiten, bietet Fully Sharded Data Parallel (FSDP) eine bessere Skalierbarkeit, indem Modellparameter über mehrere Geräte verteilt werden. Dafür bringt FSDP jedoch zusätzliche Konfigurations- und Betriebskomplexität mit sich.

Warum wird Distributed Checkpointing bei großskaligem Training gegenüber torch.save bevorzugt?

Distributed Checkpointing ist für große Multi-GPU- und Multi-Node-Umgebungen ausgelegt. Im Gegensatz zu torch.save kann es Checkpoint-Operationen über mehrere Ranks parallelisieren, Parameter beim Laden neu sharden und asynchrone Checkpoints erstellen. Diese Funktionen verbessern die Checkpoint-Performance, reduzieren Trainingsunterbrechungen und ermöglichen eine zuverlässigere Wiederherstellung bei großskaligen Trainings-Workloads.

Quelle: digitalocean.com

Jetzt 200€ Guthaben sichern

Registrieren Sie sich jetzt in unserer ccloud³ und erhalten Sie 200€ Startguthaben für Ihr Projekt.

Das könnte Sie auch interessieren:

Moderne Hosting Services mit Cloud Server, Managed Server und skalierbarem Cloud Hosting für professionelle IT-Infrastrukturen

Coreflux MQTT Broker mit Managed Databases einrichten

Databases, Tutorial
Vijonavor 3 Stunden Coreflux MQTT Broker mit Managed Databases für IoT-Datenverarbeitung bereitstellen MQTT Broker verbinden IoT-Geräte und Anwendungen über ein Publish-Subscribe-Messaging-Modell und sind damit ein zentraler Bestandteil moderner IoT-Infrastrukturen. Coreflux erweitert…
Moderne Hosting Services mit Cloud Server, Managed Server und skalierbarem Cloud Hosting für professionelle IT-Infrastrukturen

NVIDIA DGX B300 erklärt: Blackwell GPU, Specs & KI Performance

AI/ML, Tutorial
VijonaHeute um 8:04 Uhr NVIDIA DGX B300: Architektur, Funktionen, Spezifikationen und ideale Einsatzbereiche Cloud-Anbieter und Infrastrukturplattformen arbeiten kontinuierlich daran, moderne Technologien in ihren Umgebungen bereitzustellen. Diese Entwicklung reicht von grundlegenden Cloud-Services…