Graph Neural Networks Explained: Concepts, Applications, and PyTorch Implementation

Graph Neural Networks (GNNs) have emerged as a highly capable deep learning approach for understanding complex data with rich interconnections. Traditional neural networks process regularly structured inputs such as images (grids of pixels) or sequences. GNNs, by contrast, are built to capture the relationships that describe how entities are linked, influence one another, and interact within a graph. This makes them especially useful in fields where structure plays a central role, including social networks, molecular chemistry, recommendation systems, and fraud detection.

In this article, we explain the foundational ideas behind GNNs, describe how message passing lets them learn from graph structure, and outline the practical problems they are designed to address. You will also see how to build a GNN in practice through a hands-on example implemented with PyTorch and the PyTorch Geometric library.

Key Takeaways

  • Graph Neural Networks (GNNs) are highly effective for modeling relationships in interconnected data, making them well suited for tasks where structure and dependency patterns are important.
  • GNNs extend traditional deep learning methods to graph-based data, enabling models to learn from nodes, edges, and the connections between them.
  • Widely used GNN architectures, including Graph Convolutional Networks (GCNs), Graph Attention Networks (GATs), and GraphSAGE, each provide different strategies for aggregating and propagating information across graph nodes.
  • Real-world GNN use cases include social network analysis, molecular property prediction, recommendation systems, fraud detection, and knowledge graph applications.
  • Building GNNs is approachable with modern tools such as PyTorch Geometric, which streamlines graph data processing, message passing, and model development.
  • A practical implementation illustrates the full GNN workflow, including graph data loading, model definition, training, and evaluation.
  • GNNs are advancing quickly, with ongoing improvements in scalability, efficiency, and performance for larger and more complex graph structures.

What Are Graph Neural Networks?

Graph Neural Networks, commonly called GNNs, are a newer class of neural networks designed for data that is organized as graphs. A graph is made up of objects represented as nodes and the relationships between them represented as edges. GNNs are capable of handling both directed graphs, where edges have a defined direction, and undirected graphs, where the connections do not point in a specific direction. These graph structures can also differ greatly in size and form.
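Libraries such as PyTorch Geometric commonly store a graph's edges as an "edge_index" array of shape (2, num_edges) in coordinate (COO) format: column j records an edge from node edge_index[0, j] to node edge_index[1, j]. A directed graph lists each edge once, while an undirected graph is usually stored by listing every edge in both directions. The NumPy sketch below illustrates this on a hypothetical four-node graph; `to_undirected_edge_index` is our own illustrative helper, not a library function (PyTorch Geometric provides its own `to_undirected` utility).

```python
import numpy as np

def to_undirected_edge_index(edge_index: np.ndarray) -> np.ndarray:
    """Add the reverse of every edge and drop duplicates (hypothetical helper)."""
    # Stack each edge (u, v) next to its reverse (v, u)
    both = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    # np.unique over columns removes edges that were already listed in both directions
    return np.unique(both, axis=1)

# Directed toy graph: 0 -> 1, 1 -> 2, 2 -> 0, 2 -> 3
directed = np.array([[0, 1, 2, 2],
                     [1, 2, 0, 3]])

undirected = to_undirected_edge_index(directed)
print(undirected.shape)  # (2, 8): every edge now appears in both directions
```

Storing both directions means that, during message passing, information can flow either way along each undirected link.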

A GNN is composed of multiple layers, where each layer builds on the information produced by the one before it. The network takes a graph as input, including its nodes, edges, and related features, and outputs a collection of node embeddings, one for each node in the graph. These embeddings capture the learned characteristics of the nodes and can be used directly for node-level predictions or pooled into a single vector for graph-level tasks. Conventional neural networks expect inputs of a fixed shape, such as a vector of set length or an image of set size; a GNN instead takes the graph's connectivity as part of its input, so the same model can process graphs with different numbers of nodes and edges. That flexibility makes GNNs very useful for networked data such as social graphs, molecular models, and transportation networks. Although the underlying mathematics can be sophisticated, the general principle is straightforward: GNNs repeatedly pass information between connected nodes in order to learn meaningful representations.

How Do Graph Neural Networks Work?

GNNs learn patterns and dependencies among the nodes of a graph. The central principle is message passing: each node sends messages to its neighboring nodes, sharing information about itself.

  • Inside a GNN, each node exchanges information with the nodes it is connected to, allowing the model to gradually build an understanding of the graph's structure.
  • Every node computes a “message” from its current features (some architectures also include the features of the edge the message travels along).
  • All nodes do this at the same time, so each node also receives messages from every one of its neighbors.
  • A receiving node aggregates its incoming messages with an order-independent function such as sum, mean, or max, then updates its own representation by combining that aggregate with its previous state.
  • One round of message passing corresponds to one GNN layer. After k layers, each node's representation depends on its k-hop neighborhood, so information from beyond its direct neighbors reaches it.
  • Stacking several layers therefore lets GNNs capture longer-range and more sophisticated relationships.
  • Each additional layer enlarges the neighborhood summarized in a node's representation, although very deep stacks tend to make node representations overly similar (over-smoothing), which is one reason shallow models with two or three layers are common.
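The steps above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming mean aggregation over incoming neighbors and an update that simply averages a node's previous state with that aggregate; real GNN layers also apply learned weight matrices and a nonlinearity. The function name `message_passing_round` and the toy path graph are illustrative assumptions. A one-hot signal placed on node 0 shows that after k rounds it has reached exactly the nodes within k hops.

```python
import numpy as np

def message_passing_round(h: np.ndarray, edge_index: np.ndarray) -> np.ndarray:
    """One round of mean-aggregation message passing (no learned weights)."""
    src, dst = edge_index
    num_nodes = h.shape[0]
    # Each edge carries the source node's current features as its message
    messages = h[src]
    # Sum the incoming messages at each destination node
    agg = np.zeros_like(h)
    np.add.at(agg, dst, messages)
    # Divide by in-degree to get the mean (nodes without neighbors keep zeros)
    deg = np.bincount(dst, minlength=num_nodes).reshape(-1, 1)
    agg = agg / np.maximum(deg, 1)
    # Update: combine the node's previous state with the aggregated messages
    return 0.5 * (h + agg)

# Undirected path graph 0 - 1 - 2 - 3, stored with both edge directions
edge_index = np.array([[0, 1, 1, 2, 2, 3],
                       [1, 0, 2, 1, 3, 2]])

# A one-hot feature on node 0 lets us see how far its information travels
h = np.zeros((4, 1))
h[0, 0] = 1.0

for k in range(1, 4):
    h = message_passing_round(h, edge_index)
    reached = np.nonzero(h[:, 0])[0].tolist()
    print(f"after {k} round(s), node 0's signal has reached nodes {reached}")
```

Running it prints that the signal reaches nodes [0, 1] after one round, [0, 1, 2] after two, and all four nodes after three, which is the k-hop behavior described in the list.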

Implementing a Graph Neural Network in PyTorch

Cora Dataset

The Cora dataset is a well-known benchmark in graph representation learning. It contains 2,708 scientific publications, each assigned to one of seven categories: “Case Based,” “Genetic Algorithms,” “Neural Networks,” “Probabilistic Methods,” “Reinforcement Learning,” “Rule Learning,” and “Theory.” In this graph, publications are the nodes and the 5,429 citation links between them are the edges. Each publication is described by a binary bag-of-words feature vector indicating which of 1,433 dictionary words appear in it. Cora has served as a standard resource for many years, and many influential graph neural network studies report results on it, because it tests whether a model can combine the content of documents with the citation relationships that connect them. The task is node classification: using the content features and the citation graph together, predict which of the seven categories each publication belongs to.

Data Preprocessing

First, install the PyTorch Geometric library (PyTorch itself should already be installed):

pip install torch_geometric

After that, we can use PyTorch Geometric to load and preprocess the dataset.

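from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

# Download Cora (if needed); NormalizeFeatures is applied whenever a graph is accessed
dataset = Planetoid(root='data/Cora', name='Cora', transform=T.NormalizeFeatures())
# Cora is a single graph: indexing applies the transform and returns a Data object
data = dataset[0]

The Planetoid class downloads and loads Cora, and the NormalizeFeatures transform row-normalizes each paper's feature vector so that its entries sum to one. Because Cora consists of a single graph, dataset[0] returns the whole graph as a Data object (with the transform applied) that has the following attributes: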

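  • x: the node feature matrix, with shape (num_nodes, num_features), here (2708, 1433)
  • edge_index: the edge connectivity in coordinate format, with shape (2, num_edges); each undirected citation link is stored once in each direction
  • y: the vector of node labels, with shape (num_nodes,)
  • train_mask, val_mask, test_mask: boolean masks of shape (num_nodes,) marking the nodes in the standard training (140 nodes), validation (500 nodes), and test (1,000 nodes) splits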
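To make this preprocessing concrete, the NumPy sketch below mimics the row normalization described for NormalizeFeatures (each row divided by its sum) and shows how a boolean mask such as train_mask selects a subset of nodes. The tiny feature matrix, labels, and mask are made-up toy values, and `row_normalize` is a hypothetical helper rather than the PyTorch Geometric implementation; leaving all-zero rows untouched is our own choice.

```python
import numpy as np

def row_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to sum to one; all-zero rows are left unchanged (our choice)."""
    row_sums = x.sum(axis=1, keepdims=True)
    return x / np.where(row_sums == 0, 1, row_sums)

# Toy binary bag-of-words features for 4 nodes (rows) over 5 words (columns)
x = np.array([[1, 0, 1, 0, 0],
              [0, 1, 1, 1, 1],
              [1, 1, 0, 0, 0],
              [0, 0, 0, 0, 1]], dtype=float)
y = np.array([2, 0, 2, 1])                         # one class label per node
train_mask = np.array([True, False, True, False])  # nodes used for training

x_norm = row_normalize(x)
print(x_norm.sum(axis=1))        # every row now sums to 1
print(x_norm[train_mask].shape)  # (2, 5): only the two training nodes
print(y[train_mask])             # labels of the training nodes
```

Normalization keeps papers with many words from dominating simply because their feature vectors have more nonzero entries.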

Model Architecture

When developing a graph neural network, choosing the architecture is an important decision. Here, we implement a simple model with PyTorch Geometric (the torch_geometric package), a library built on top of PyTorch, using a graph convolutional network (GCN) as a practical starting point for graph learning tasks.

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Define the first graph convolutional layer
        self.conv1 = GCNConv(in_channels, hidden_channels)
        # Define the second graph convolutional layer
        self.conv2 = GCNConv(hidden_channels, out_channels)
        # Define the linear layer
        self.linear = torch.nn.Linear(out_channels, out_channels)

    def forward(self, x, edge_index):
        # Apply the first graph convolutional layer
        x = self.conv1(x, edge_index)
        # Apply the ReLU activation function
        x = F.relu(x)
        # Apply the second graph convolutional layer
        x = self.conv2(x, edge_index)
        # Apply the ReLU activation function
        x = F.relu(x)
        # Apply the linear layer
        x = self.linear(x)
        # Convert the scores to log-probabilities (paired with F.nll_loss during training)
        return F.log_softmax(x, dim=1)

In the code above, we imported torch and torch.nn.functional to access useful neural network tools and functions. Next, we created a GNN class that inherits from torch.nn.Module.

Inside the __init__ method, we defined two convolutional layers using the GCNConv module from PyTorch Geometric. This makes graph convolution straightforward to implement. We also added a basic linear layer.

In the forward pass, the input passes through the two convolutional layers, each followed by a ReLU activation, then through the linear layer, and finally through log_softmax, which converts the raw class scores into log-probabilities over the seven categories.
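The model's last step is a log softmax, and during training its output is fed to F.nll_loss. Taking the mean negative log-probability of each node's true class from log_softmax outputs gives exactly the cross-entropy loss. The NumPy sketch below checks this identity on made-up logits; the helper names are our own stand-ins for the PyTorch functions.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row maximum for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def nll_loss(log_probs: np.ndarray, targets: np.ndarray) -> float:
    # Mean negative log-probability assigned to each node's true class
    return float(-log_probs[np.arange(len(targets)), targets].mean())

def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    # Cross-entropy written directly from softmax probabilities
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return float(-np.log(probs[np.arange(len(targets)), targets]).mean())

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])  # raw scores for 2 nodes, 3 classes
targets = np.array([0, 2])            # true class of each node

print(nll_loss(log_softmax(logits), targets))
print(cross_entropy(logits, targets))  # same value, up to floating-point rounding
```

Computing the log-probabilities directly, rather than taking the log of softmax probabilities, is also more numerically stable.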

With only a few lines of code, we can create a compact but effective graph neural network.

Of course, this is just a basic example, but it clearly shows how PyTorch and PyTorch Geometric make it easy to prototype and refine graph neural network architectures. The GCNConv layers make it especially simple to integrate graph structure into the model.
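Under the hood, a GCN layer in the formulation that GCNConv follows by default adds a self-loop to every node, normalizes the adjacency matrix symmetrically by node degree, and then applies a shared linear transformation: H' = D^(-1/2) (A + I) D^(-1/2) H W, where A is the adjacency matrix, I the identity matrix (the self-loops), D the degree matrix of A + I, H the input node features, and W the layer's weights. The dense NumPy sketch below implements that rule on a toy graph for illustration only; GCNConv itself works on sparse edge lists and also adds a learnable bias, and the weights here are random placeholders.

```python
import numpy as np

def gcn_layer(adj: np.ndarray, h: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Dense GCN propagation: D^-1/2 (A + I) D^-1/2 H W (no bias, no activation)."""
    a_hat = adj + np.eye(adj.shape[0])        # add self-loops
    deg = a_hat.sum(axis=1)                   # degree of each node in A + I
    d_inv_sqrt = np.diag(1.0 / np.sqrt(deg))  # D^-1/2 (deg >= 1 thanks to self-loops)
    a_norm = d_inv_sqrt @ a_hat @ d_inv_sqrt  # symmetric normalization
    return a_norm @ h @ weight                # aggregate neighbors, then transform

# Undirected toy graph with edges 0-1, 1-2, 1-3
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 1],
                [0, 1, 0, 0],
                [0, 1, 0, 0]], dtype=float)

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))       # 4 nodes with 3 input features each
weight = rng.normal(size=(3, 2))  # placeholder weights mapping 3 -> 2 channels

out = gcn_layer(adj, h, weight)
print(out.shape)  # (4, 2): one 2-dimensional representation per node
```

The symmetric normalization keeps high-degree nodes such as node 1 from producing disproportionately large sums, which is why the layer scales each edge by the degrees at both of its ends.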

Training

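For training, we use the Adam optimizer and a cross-entropy objective: because the model ends with log_softmax, applying F.nll_loss to its output is equivalent to cross-entropy on the raw scores. Cora ships with a standard split, so we select the training, validation, and test nodes with the mask attributes on the Data object instead of splitting the data ourselves.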

# Set the device to CUDA if available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Fetch the Cora graph (NormalizeFeatures is applied on indexing) and move all of its tensors to the device
data = dataset[0].to(device)
# Define the GNN model with the specified input, hidden, and output dimensions, and move it to the device
model = GNN(dataset.num_features, 16, dataset.num_classes).to(device)
# Define the Adam optimizer with the specified learning rate and weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Define the training function
def train():
    # Set the model to training mode
    model.train()
    # Zero the gradients of the optimizer
    optimizer.zero_grad()
    # Forward pass over the whole graph; propagation needs all nodes and edges
    out = model(data.x, data.edge_index)
    # Compute the negative log-likelihood loss on the training nodes only
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    # Compute the gradients of the loss with respect to the model parameters
    loss.backward()
    # Update the model parameters using the optimizer
    optimizer.step()
    # Return the loss as a Python float
    return loss.item()

# Define the testing function
@torch.no_grad()
def test():
    # Set the model to evaluation mode
    model.eval()
    # Perform a forward pass of the model on all nodes
    out = model(data.x, data.edge_index)
    # Compute the predicted labels by taking the argmax of the output scores
    pred = out.argmax(dim=1)
    # Compute the accuracy on the training, validation, and test nodes
    accs = []
    for mask in (data.train_mask, data.val_mask, data.test_mask):
        correct = (pred[mask] == data.y[mask]).sum().item()
        accs.append(correct / int(mask.sum()))
    # Return the accuracies as a (train, val, test) tuple
    return tuple(accs)

# Train the model for 500 epochs (range excludes its upper bound, so use 501)
for epoch in range(1, 501):
    # Perform a single training iteration and get the loss
    loss = train()
    # Evaluate the model on the training, validation, and testing sets and get the accuracies
    train_acc, val_acc, test_acc = test()
    # Print the epoch number, loss, and accuracies
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

The train function performs one round of model training and returns the loss value. The test function measures the model’s performance on the training, validation, and test splits and returns the accuracy for each one. We trained the model for 500 epochs and printed both loss and accuracy values after every epoch.
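Printing test accuracy at every epoch is handy for monitoring, but choosing which epoch to report by looking at test accuracy leaks test information into model selection. A common practice is to keep the epoch with the highest validation accuracy and report the test accuracy from that same epoch. The pure-Python sketch below assumes the per-epoch (val_acc, test_acc) pairs have been collected into a list, for example by appending them inside the training loop; `select_by_validation` is a hypothetical helper and the numbers are made up for illustration.

```python
def select_by_validation(history):
    """Return (epoch, val_acc, test_acc) for the epoch with the best validation accuracy.

    `history` is a list of (val_acc, test_acc) tuples, one per epoch, starting at epoch 1.
    Ties are broken in favor of the earliest epoch.
    """
    best_epoch, (best_val, best_test) = max(
        enumerate(history, start=1),
        key=lambda item: (item[1][0], -item[0]),
    )
    return best_epoch, best_val, best_test

# Made-up accuracies for five epochs: (val_acc, test_acc)
history = [(0.60, 0.61), (0.74, 0.73), (0.78, 0.79), (0.78, 0.80), (0.77, 0.81)]

epoch, val_acc, test_acc = select_by_validation(history)
print(f"best epoch {epoch}: val {val_acc:.2f}, test {test_acc:.2f}")
```

Here epoch 3 is selected even though epoch 5 happens to have a higher test accuracy, because the choice is made without looking at the test split.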

Compute the Accuracy of the GNN Model

The code below defines a function that measures the model's accuracy over every node in the graph. The compute_accuracy() function switches the model into evaluation mode, performs a forward pass, and predicts a label for each node. It then compares those predictions with the ground-truth labels and divides the number of correct predictions by the total number of nodes. Note that this count includes the training nodes, so it is not a clean measure of generalization; the test accuracy returned by test() is the better estimate of how the model performs on unseen papers.

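@torch.no_grad()
def compute_accuracy():
    # Set the model to evaluation mode
    model.eval()
    # Fetch the transformed graph and move it to the same device as the model
    data = dataset[0].to(device)
    out = model(data.x, data.edge_index)
    # Predicted class of every node
    pred = out.argmax(dim=1)
    # Count correct predictions over all nodes (training nodes included)
    correct = (pred == data.y).sum().item()
    total = data.num_nodes
    accuracy = correct / total
    return accuracy

accuracy = compute_accuracy()
print(f"Accuracy: {accuracy:.4f}")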

In the run reported here, the model reached an accuracy of 0.8006 across all nodes of the Cora graph, meaning it predicted the correct category for roughly 80% of the papers; exact values vary between runs. That is a solid result, though not flawless. A single accuracy figure gives only a fast, high-level snapshot of overall performance: it does not show where the model performs well and where it struggles.

Additional metrics give a more detailed picture. Precision, recall, and F1 score computed per class, together with a confusion matrix, reveal which categories the model confuses with one another. This matters on Cora because its seven classes are not equally sized, so a high overall accuracy can hide weak performance on the smaller classes. An accuracy of 80% is a promising starting point, but it is not enough on its own to call the model a complete success.
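As a rough sketch of these metrics, the NumPy code below builds a confusion matrix and derives per-class precision, recall, and F1 from it. In the tutorial's setting you would pass the predicted and true labels of the test nodes only (selected with test_mask and converted to NumPy arrays); here the label arrays are small made-up examples, and the helper names are our own. scikit-learn's classification_report and confusion_matrix functions provide the same numbers ready-made.

```python
import numpy as np

def confusion(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> np.ndarray:
    """cm[i, j] counts nodes whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

def per_class_metrics(cm: np.ndarray):
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # how often each class was predicted
    actual = cm.sum(axis=1)     # how often each class truly occurs
    # Guard against division by zero for classes that never appear
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1

# Made-up labels for 8 test nodes and 3 classes
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 0, 2])

cm = confusion(y_true, y_pred, num_classes=3)
precision, recall, f1 = per_class_metrics(cm)
print(cm)
print("precision", precision.round(2))
print("recall   ", recall.round(2))
print("f1       ", f1.round(2))
```

Off-diagonal entries of the confusion matrix show which pairs of classes get mixed up, which a single accuracy number cannot reveal.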

Evaluation

Beyond these metrics, we can visualize what the model has learned with t-SNE, a technique that projects high-dimensional vectors into two dimensions while trying to keep nearby points close together. Here we apply it to the model's output for each node, its seven class log-probabilities, which serve as a low-dimensional embedding of every paper.

# Import the necessary libraries
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Set the model to evaluation mode
model.eval()

# Fetch the transformed graph and run a forward pass without tracking gradients
data = dataset[0].to(device)
with torch.no_grad():
    out = model(data.x, data.edge_index)

# Apply t-SNE to the 7-dimensional output (log-probabilities) to obtain a 2D embedding
emb = TSNE(n_components=2).fit_transform(out.cpu().numpy())

# Create a figure with a specified size
plt.figure(figsize=(10, 10))

# Create a scatter plot of the embeddings, color-coded by the true labels
plt.scatter(emb[:, 0], emb[:, 1], c=data.y.cpu().numpy(), cmap='jet')

# Display the plot
plt.show()

Running the code above produces a scatter plot of the model's node representations. t-SNE is a convenient way to inspect high-dimensional data visually. Here is how to read the plot:

  • Each point in the chart corresponds to one node in the dataset. The x-axis and y-axis represent the two dimensions produced by t-SNE from the original embeddings. The color assigned to each point reflects the true label of that node.
  • t-SNE places nodes with similar output vectors close together. If the model has learned well, nodes of the same class receive similar outputs, so points of the same color form clusters, while nodes of different classes end up in separate regions. Distances between clusters, however, carry little meaning in a t-SNE plot.
  • Overall, the plot gives a visual summary of how the model organizes the nodes. Distinct, mostly single-colored groups indicate that the model separates the categories well, while mixed regions point to classes it confuses. This provides a practical way to look inside the model and see how it organizes information.

Potential Challenges and Considerations

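  • With 2,708 nodes, 5,429 edges, and only 140 labeled training nodes in the standard split, Cora is relatively small. Models trained on so little labeled data can overfit easily, and results may not carry over to much larger graphs; more advanced strategies such as data augmentation or transfer learning may help.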
  • The Cora dataset contains only one node type and one edge type, which makes it a homogeneous graph. This can limit how well the GNN generalizes to more complex networks that include multiple node and edge types.
  • Choosing suitable hyperparameters, including the number of hidden layers, the number of hidden units, and the learning rate, can strongly influence model performance and requires careful tuning.

FAQ

1. What is a Graph Neural Network (GNN)?

A GNN is a neural network designed for graph-structured data. It learns node-level, edge-level, or graph-level representations by combining information from neighboring nodes.

2. How is a GNN different from a traditional neural network?

Traditional neural networks work best with grid-based data such as images and sequences, while GNNs are built for irregular and connected structures like social networks or molecular graphs.

3. What types of problems are GNNs used for?

They are widely applied to node classification, link prediction, graph classification, recommendation systems, fraud detection, and molecular property prediction.

4. What is message passing in GNNs?

Message passing is the core process in which nodes share information with neighboring nodes, update their embeddings, and learn contextual relationships.

5. Do GNNs scale well to very large graphs?

Scaling can be difficult because of memory requirements and because a node's multi-hop neighborhood grows quickly with the number of layers. Methods such as neighbor sampling (as used in GraphSAGE), mini-batching, and distributed training help address this issue.
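To illustrate the sampling idea, the pure-Python sketch below draws at most a fixed number of neighbors per node at each hop, in the spirit of GraphSAGE-style neighbor sampling, so the size of a node's computation graph is bounded by the product of the fan-outs instead of growing with its full neighborhood. The adjacency list, the fan-out values, and the helper `sample_computation_graph` are illustrative assumptions; PyTorch Geometric provides its own NeighborLoader for this purpose.

```python
import random

def sample_computation_graph(adj_list, seed, fanouts, rng):
    """Return the node sets reached at each hop when sampling up to
    fanouts[k] neighbors per node at hop k (sampling without replacement)."""
    frontier = {seed}
    layers = [frontier]
    for fanout in fanouts:
        next_frontier = set()
        for node in frontier:
            neighbors = adj_list[node]
            # Keep all neighbors if there are few, otherwise sample `fanout` of them
            chosen = neighbors if len(neighbors) <= fanout else rng.sample(neighbors, fanout)
            next_frontier.update(chosen)
        layers.append(next_frontier)
        frontier = next_frontier
    return layers

# Toy graph: node 0 is a hub with 6 neighbors, several of which have further neighbors
adj_list = {
    0: [1, 2, 3, 4, 5, 6],
    1: [0, 7, 8], 2: [0, 9], 3: [0, 10, 11, 12], 4: [0], 5: [0, 13], 6: [0, 14],
    7: [1], 8: [1], 9: [2], 10: [3], 11: [3], 12: [3], 13: [5], 14: [6],
}

rng = random.Random(42)
layers = sample_computation_graph(adj_list, seed=0, fanouts=[3, 2], rng=rng)
for hop, nodes in enumerate(layers):
    print(f"hop {hop}: {sorted(nodes)}")
```

With fan-outs of 3 and 2, at most 3 nodes are kept at the first hop and at most 6 at the second, no matter how many neighbors the hub actually has.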

6. What programming libraries are best for implementing GNNs?

Popular frameworks include PyTorch Geometric (PyG) and Deep Graph Library (DGL), both of which provide ready-made layers and utilities for building GNN models.

7. Are GNNs suitable for real-time applications?

Often, depending on the complexity of the model and the size of the graph. Lightweight architectures and sampling-based methods can support near real-time performance.

8. Can GNNs handle dynamic or evolving graphs?

Yes, dynamic GNN variants can process graphs that change over time, which is useful for traffic prediction, temporal recommendation systems, and anomaly detection.

9. What data preprocessing is required for GNNs?

You generally need to prepare adjacency information, node and edge features, and ensure that the graph is correctly formatted for the chosen library.

10. Are GNNs interpretable?

Interpretability is improving through tools such as attention mechanisms and GNNExplainer, which help identify the most influential nodes and edges.

Conclusion

In this article, we covered the main ideas behind Graph Neural Networks (GNNs) and examined how they can be used across a range of domains. GNNs are especially well suited for graph-structured data, allowing models to reason over complex relationships found in social networks, molecular graphs, transportation systems, and many other applications.

To show these ideas in practice, we worked with the widely used Cora dataset, a standard benchmark in graph learning. In this dataset, every publication is represented as a node, while the citation links between publications form the edges. Our goal was to use both the textual features of each paper and the citation relationships to predict its category.

We loaded the data with the PyTorch Geometric library, normalized the feature vectors, and used the dataset's standard training, validation, and test masks. We then created a simple GNN with graph convolutional layers followed by a linear classifier and trained it using cross-entropy loss together with the Adam optimizer. Finally, we evaluated the model by measuring its accuracy.

Although there are many further techniques and enhancements worth exploring, this project provides a strong introduction to how GNNs work and demonstrates how effective they can be when learning from complex relational data.

Source: digitalocean.com
