Vision Transformers (ViTs): How Transformers Are Reshaping Computer Vision
In recent years, transformers have reshaped natural language processing. Architectures such as GPT and BERT set new standards for understanding and generating human language, and the same idea is now being carried over to computer vision. One of the most notable results is the vision transformer (ViT), introduced in the paper An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale. ViTs and related transformer-based models are positioned as an alternative to convolutional neural networks (CNNs), which have supported image-based applications for many years. Rather than sliding convolutional filters over the image, a ViT applies the transformer architecture directly: it treats image patches the way a language model treats words in a sentence, so the model learns relationships between patches much as it learns context within a block of text.
In contrast to CNNs, ViTs split the input image into smaller patches, flatten each patch into a vector, and map it to the model's embedding dimension with a learned linear projection (a matrix multiplication). A transformer encoder then processes these vectors as token embeddings. One reason ViTs stand out is their strength in capturing global image patterns, something CNNs often find more difficult. In this article, we examine vision transformers and highlight their main differences from convolutional neural networks.
Prerequisites
- Basics of Neural Networks: A general understanding of how neural networks handle data.
- Convolutional Neural Networks (CNNs): Familiarity with CNNs and their importance in computer vision.
- Transformer Architecture: Knowledge of transformers, especially their application in NLP.
- Image Processing: Understanding of core ideas such as image representation, channels, and pixel arrays.
- Attention Mechanism: Familiarity with self-attention and how it models relationships across inputs.
What Are Vision Transformers?
Vision transformers apply attention and transformer mechanisms to image processing, much like transformers do in natural language processing. The difference is that, instead of word tokens, an image is broken into patches and supplied as a sequence of linearly projected embeddings. These patches are handled in the same way tokens or words are processed in NLP.
Rather than analyzing the full image at once, a ViT divides the image into smaller parts, similar to pieces of a puzzle. Every piece is converted into a numeric representation, or vector, that captures its features. The model then studies all of these pieces together and determines how they are connected by using a transformer-based mechanism.
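The puzzle-piece idea can be sketched in a few lines of NumPy. This is only an illustration under simple assumptions: the function name `image_to_patches` is made up for this example, the image is stored as (height, width, channels), and both sides are divisible by the patch size.

```python
import numpy as np

def image_to_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # Separate each spatial axis into (grid index, offset inside the patch)
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    # Bring the two grid axes to the front: (gh, gw, P, P, C)
    patches = patches.transpose(0, 2, 1, 3, 4)
    # One row per patch, each of length P * P * C
    return patches.reshape(gh * gw, patch_size * patch_size * c)

# Toy 8x8 RGB image cut into four 4x4 patches
image = np.arange(8 * 8 * 3).reshape(8, 8, 3)
patches = image_to_patches(image, patch_size=4)
print(patches.shape)  # (4, 48)
```

Each row of `patches` is one flattened puzzle piece, listed in row-major order over the patch grid.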
Unlike CNNs, ViTs do not rely on applying dedicated filters or kernels across an image to identify particular features such as edges. In CNNs, this convolution process resembles scanning an image line by line. These filters move across the entire image and emphasize important features. The network stacks multiple layers of such filters so that increasingly complex patterns can be recognized over time.
In CNNs, pooling layers shrink the feature maps, and later layers use the extracted features to generate predictions for tasks like image recognition and object detection. However, each CNN layer sees only a limited receptive field, which restricts its ability to capture long-range dependencies.
How CNNs View Images
ViTs, although they often contain more parameters, rely on self-attention for richer feature representation and reduce the dependence on extremely deep architectures. CNNs usually need much deeper networks to reach a similar representational strength, which raises computational cost.
Another limitation of CNNs is that they do not naturally capture image-wide patterns, because their filters focus on local regions. To interpret the whole image or understand distant relationships, CNNs must stack many layers and apply pooling to gradually enlarge their field of view. This gradual aggregation can cause loss of global information along the way.
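To make the "gradually enlarge their field of view" point concrete, the sketch below computes how many input pixels a single output unit can see after a stack of convolution and pooling layers. It uses the standard receptive-field recurrence; the helper name `receptive_field` and the layer configurations are illustrative assumptions, not taken from any particular network.

```python
def receptive_field(layers):
    """Receptive field (in input pixels, per axis) of one output unit.

    layers: list of (kernel_size, stride) tuples, ordered from input to output.
    """
    rf, jump = 1, 1  # jump = distance in input pixels between adjacent units
    for kernel, stride in layers:
        rf += (kernel - 1) * jump
        jump *= stride
    return rf

# Four 3x3 convolutions with stride 1
print(receptive_field([(3, 1)] * 4))            # 9
# 3x3 convolutions alone need 112 layers to span a 224-pixel side
print(receptive_field([(3, 1)] * 112))          # 225
# Four blocks of 3x3 convolution followed by 2x2 pooling with stride 2
print(receptive_field([(3, 1), (2, 2)] * 4))    # 46
```

Pooling makes the field of view grow much faster, but it does so by discarding spatial resolution, which is the trade-off described above.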
ViTs take a different route by splitting the image into patches that are treated as separate input tokens. Through self-attention, ViTs compare all patches at the same time and learn how they relate to one another. This enables them to identify dependencies and patterns across the entire image without needing to build them step by step through many layers.
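A minimal single-head version of this all-to-all comparison can be written with NumPy. It is a sketch only: the weights are random rather than trained, and the sizes (9 patches, 16 dimensions) are arbitrary choices for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a token sequence."""
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])   # (N, N): every patch vs. every patch
    weights = softmax(scores, axis=-1)        # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
num_patches, dim = 9, 16
tokens = rng.normal(size=(num_patches, dim))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) * dim ** -0.5 for _ in range(3))

out, weights = self_attention(tokens, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (9, 16) (9, 9)
```

The `weights` matrix has one row per patch and one column per patch, so even a single layer lets any patch draw on information from any other patch.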
What Is Inductive Bias?
Before moving ahead, it is helpful to understand inductive bias. Inductive bias describes the assumptions a model makes about the structure of the data. When these assumptions match the data, they help the model generalize to unseen examples, particularly when training data is limited. In CNNs, inductive biases include the following:
- Locality: Features in images, such as edges or textures, tend to exist within small localized regions.
- Two-dimensional neighborhood structure: Pixels that are close to one another are more likely to be related, so filters operate on neighboring spatial regions.
- Translation equivariance: Shifting the input shifts the resulting feature map by the same amount, so a feature such as an edge is detected in the same way wherever it appears in the image.
These biases make CNNs very effective for image tasks because they are naturally built to exploit the spatial and structural characteristics of images.
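Translation equivariance can be checked numerically. The sketch below uses a one-dimensional signal and a circular (wrap-around) sliding filter so that the property holds exactly; the signal, the kernel, and the function name `circular_conv1d` are illustrative assumptions.

```python
import numpy as np

def circular_conv1d(signal, kernel):
    """Slide a kernel over a 1D signal with wrap-around (cross-correlation, as in CNNs)."""
    n = len(signal)
    return np.array([
        sum(kernel[j] * signal[(i + j) % n] for j in range(len(kernel)))
        for i in range(n)
    ])

signal = np.array([0.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 0.0])
kernel = np.array([-1.0, 0.0, 1.0])  # a simple edge-like filter
shift = 2

filter_after_shift = circular_conv1d(np.roll(signal, shift), kernel)
shift_after_filter = np.roll(circular_conv1d(signal, kernel), shift)
print(np.allclose(filter_after_shift, shift_after_filter))  # True
```

Filtering a shifted signal gives the same result as shifting the filtered signal, which is the behavior a convolutional layer has built in and a ViT has to learn from data.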
Vision Transformers, in comparison, contain much less image-specific inductive bias than CNNs. In ViTs:
- Global processing: Self-attention layers operate across the entire image, allowing the model to capture broad dependencies and relationships without being limited to local areas.
- Minimal 2D structure: The two-dimensional structure of the image is used only at the start, when the image is cut into patches, and during fine-tuning, when positional embeddings are adjusted for different resolutions. Unlike CNNs, ViTs do not assume that nearby pixels must be related.
- Learned spatial relations: Positional embeddings in ViTs do not begin with fixed two-dimensional spatial relationships. Instead, the model learns these relationships directly from the training data.
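The resolution adjustment mentioned above is one of the few places where the 2D layout is used explicitly: the learned positional embeddings are laid out on their patch grid and interpolated to the new grid size. The sketch below is one possible way to do this with separable linear interpolation in NumPy. The function name `resize_pos_embed`, the interpolation mode, and the small embedding width are assumptions made for this example.

```python
import numpy as np

def resize_pos_embed(pos, old_grid, new_grid):
    """Interpolate positional embeddings from an old_grid x old_grid patch
    layout to new_grid x new_grid. pos has shape (1 + old_grid**2, D) and
    its first row belongs to the class token, which is kept unchanged."""
    cls_pos, patch_pos = pos[:1], pos[1:]
    d = patch_pos.shape[1]
    grid = patch_pos.reshape(old_grid, old_grid, d)
    old_coords = np.linspace(0.0, 1.0, old_grid)
    new_coords = np.linspace(0.0, 1.0, new_grid)

    # Interpolate along the row axis first
    rows = np.empty((new_grid, old_grid, d))
    for j in range(old_grid):
        for k in range(d):
            rows[:, j, k] = np.interp(new_coords, old_coords, grid[:, j, k])
    # Then along the column axis
    out = np.empty((new_grid, new_grid, d))
    for i in range(new_grid):
        for k in range(d):
            out[i, :, k] = np.interp(new_coords, old_coords, rows[i, :, k])

    return np.concatenate([cls_pos, out.reshape(new_grid * new_grid, d)], axis=0)

rng = np.random.default_rng(0)
pos = rng.normal(size=(1 + 14 * 14, 8))   # 224 / 16 = 14 patches per side
new_pos = resize_pos_embed(pos, old_grid=14, new_grid=24)  # 384 / 16 = 24
print(new_pos.shape)  # (577, 8)
```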
How Vision Transformers Work
Vision Transformers use the standard Transformer architecture originally created for one-dimensional text sequences. To make this architecture suitable for two-dimensional images, each image is first divided into fixed-size patches of P × P pixels, and each patch is flattened into a vector of length P² · C. For an image of size H × W with C channels, the number of patches is N = (H × W) / (P × P) = HW / P², which is also the effective sequence length for the Transformer. The flattened patches are then linearly projected to a fixed dimension D, producing what are known as patch embeddings.
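As a quick worked example of these quantities, the helper below computes the sequence length N = HW / P² and the flattened patch length P² · C for two common input sizes. The function name `patch_sequence_shape` is made up for this illustration.

```python
def patch_sequence_shape(height, width, channels, patch_size):
    """Return (number of patches N, flattened patch length P*P*C)."""
    assert height % patch_size == 0 and width % patch_size == 0
    num_patches = (height // patch_size) * (width // patch_size)
    patch_dim = patch_size * patch_size * channels
    return num_patches, patch_dim

print(patch_sequence_shape(224, 224, 3, 16))  # (196, 768)
print(patch_sequence_shape(384, 384, 3, 16))  # (576, 768)
```

Raising the resolution from 224 to 384 keeps the patch length fixed but nearly triples the sequence length, which matters because self-attention cost grows with the square of N.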
A special learnable token, similar to the [CLS] token in BERT, is added to the front of the sequence of patch embeddings. This token learns a global representation of the image that is later used for classification. Positional embeddings are also added so that the model can retain information about where each patch belongs, which helps preserve the spatial structure of the image.
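In terms of array shapes, this step is a concatenation followed by an addition. The NumPy sketch below uses random values as stand-ins for parameters that would be learned in a real model; the sizes match a 224 × 224 image with 16 × 16 patches and D = 768.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, dim = 196, 768

patch_embeddings = rng.normal(size=(num_patches, dim))
cls_token = np.zeros((1, dim))                                   # learnable in a real model
pos_embeddings = rng.normal(size=(num_patches + 1, dim)) * 0.02  # learnable in a real model

# Prepend the class token, then add one positional embedding per position
sequence = np.concatenate([cls_token, patch_embeddings], axis=0)
sequence = sequence + pos_embeddings
print(sequence.shape)  # (197, 768)
```

The sequence grows from N to N + 1 entries, and position 0 is reserved for the class token.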
The resulting sequence of embeddings is then passed through the Transformer encoder. Inside the encoder, two main operations alternate: Multi-Headed Self-Attention (MSA) and a feedforward neural network, often called an MLP block. Each layer applies Layer Normalization (LN) before these operations and uses residual connections afterward to improve training stability. The output of the Transformer encoder, especially the final state of the [CLS] token, serves as the representation of the image.
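One encoder layer with this layout (LayerNorm before each operation, residual connection after it) can be sketched in NumPy as follows. This is a simplified illustration, not a reference implementation: weights are random, bias terms and the learnable LayerNorm scale and shift are omitted, dropout is left out, and the dictionary keys such as `w_q` and `w_1` are names invented for the example.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def multi_head_self_attention(x, params, num_heads):
    n, d = x.shape
    head_dim = d // num_heads

    def split(t):  # (n, d) -> (heads, n, head_dim)
        return t.reshape(n, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = (split(x @ params[name]) for name in ("w_q", "w_k", "w_v"))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, n, n)
    heads = softmax(scores) @ v                            # (heads, n, head_dim)
    merged = heads.transpose(1, 0, 2).reshape(n, d)        # concatenate heads
    return merged @ params["w_o"]

def encoder_block(x, params, num_heads):
    # MSA sub-layer: normalize, attend, add residual
    x = x + multi_head_self_attention(layer_norm(x), params, num_heads)
    # MLP sub-layer: normalize, two linear maps with GELU in between, add residual
    x = x + gelu(layer_norm(x) @ params["w_1"]) @ params["w_2"]
    return x

rng = np.random.default_rng(0)
n_tokens, dim, mlp_dim = 5, 16, 32
shapes = {"w_q": (dim, dim), "w_k": (dim, dim), "w_v": (dim, dim),
          "w_o": (dim, dim), "w_1": (dim, mlp_dim), "w_2": (mlp_dim, dim)}
params = {name: rng.normal(size=shape) * 0.1 for name, shape in shapes.items()}

x = rng.normal(size=(n_tokens, dim))
out = encoder_block(x, params, num_heads=4)
print(out.shape)  # (5, 16)
```

Because each sub-layer only adds to its input, the block preserves the sequence shape, which is what allows many such blocks to be stacked.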
For classification tasks, a simple head is attached to the final [CLS] token. During pretraining, this head is usually a small multi-layer perceptron (MLP), while during fine-tuning it is often reduced to a single linear layer. This structure allows ViTs to model global relationships among patches effectively and take full advantage of self-attention for image understanding.
In a hybrid Vision Transformer model, the process starts differently. Instead of cutting raw images directly into patches, the input sequence is produced from feature maps created by a CNN. The CNN first processes the image and extracts meaningful spatial information, and those resulting features are then turned into patches. These patches are flattened and projected into a fixed-dimensional space using the same trainable linear projection used in a standard Vision Transformer. One special version of this method uses patches of size 1×1, where every patch represents a single spatial position in the CNN feature map.
In that setup, the spatial dimensions of the feature map are flattened, and the sequence is projected into the Transformer’s input dimension. As in the regular ViT, a classification token and positional embeddings are added to preserve positional information and support global image understanding. This hybrid design combines the local feature extraction strengths of CNNs with the global modeling power of Transformers.
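For the 1×1 patch case, turning a CNN feature map into a token sequence amounts to a reshape and a projection. The sketch below assumes a hypothetical feature map with 256 channels on a 14 × 14 grid and a model dimension of 64; these numbers and the random projection matrix are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical CNN output: (channels, height, width)
feature_map = rng.normal(size=(256, 14, 14))
model_dim = 64
projection = rng.normal(size=(256, model_dim)) * 256 ** -0.5  # trainable in practice

channels, fh, fw = feature_map.shape
# Flatten the spatial grid: one token per spatial position, each of length `channels`
tokens = feature_map.reshape(channels, fh * fw).T   # (196, 256)
# Project every token to the Transformer input dimension
embeddings = tokens @ projection                    # (196, 64)
print(tokens.shape, embeddings.shape)  # (196, 256) (196, 64)
```

From here the sequence is handled exactly like patch embeddings from raw pixels: a classification token is prepended and positional embeddings are added.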
Code Demo
Below is a code example showing how to use vision transformers with images.
# Install the necessary libraries (run this line in a shell)
pip install -q transformers torch pillow requests
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import requests
import torch
# Load the model and move it to the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.to(device)
# Load the image to perform predictions
url = 'link to your image'  # replace with the URL of an image
# Convert to RGB so grayscale or RGBA images also have 3 channels
image = Image.open(requests.get(url, stream=True).raw).convert('RGB')
# Resize and normalize the image the way the model expects
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
# print(pixel_values.shape)
The ViT model analyzes the image. It includes a BERT-style encoder along with a linear classification head placed on top of the final hidden state of the [CLS] token.
with torch.no_grad():
    outputs = model(pixel_values)
    logits = outputs.logits
# logits.shape
prediction = logits.argmax(-1)
print("Predicted class:", model.config.id2label[prediction.item()])
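The argmax reports only the single most likely class. If the top few candidates and their probabilities are of interest, the logits can be passed through a softmax first. The helper below is a small NumPy sketch with made-up toy labels; the name `top_k_predictions` is hypothetical. With the model above, one could presumably call it with `logits[0].cpu().numpy()` and `model.config.id2label`.

```python
import numpy as np

def top_k_predictions(logits, id2label, k=5):
    """Return the k most probable (label, probability) pairs from raw logits."""
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    top = np.argsort(probs)[::-1][:k]      # indices sorted by descending probability
    return [(id2label[int(i)], float(probs[i])) for i in top]

# Toy example with four hypothetical classes
toy_logits = [2.0, 0.5, 3.0, -1.0]
toy_labels = {0: "cat", 1: "dog", 2: "fox", 3: "owl"}
for label, prob in top_k_predictions(toy_logits, toy_labels, k=2):
    print(f"{label}: {prob:.3f}")
```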
Here is a basic implementation of a Vision Transformer (ViT) using PyTorch. The code contains the essential building blocks, including patch embedding, positional encoding, and the Transformer encoder. It can be applied to simple classification tasks.
import torch
import torch.nn as nn


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=1000, dim=768,
                 depth=12, heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        # Image and patch dimensions
        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_dim = 3 * patch_size ** 2  # Assuming 3 channels (RGB)
        # Layers
        self.patch_embeddings = nn.Linear(self.patch_dim, dim)
        self.position_embeddings = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)
        # Transformer Encoder. batch_first=True is required because the input
        # is shaped (batch, sequence, dim). Note that PyTorch's layer defaults
        # to post-norm and ReLU; norm_first=True and activation="gelu" are
        # closer to the original ViT.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim,
                                       dropout=dropout, batch_first=True),
            num_layers=depth
        )
        # Classification head (LayerNorm followed by a linear layer)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        batch_size = x.shape[0]
        p = self.patch_size
        # Cut into non-overlapping patches: (B, C, H/P, W/P, P, P)
        x = x.unfold(2, p, p).unfold(3, p, p)
        # Flatten each patch into a vector: (B, N, C * P * P)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(batch_size, self.num_patches, self.patch_dim)
        x = self.patch_embeddings(x)
        # Prepend the CLS token and add positional embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.position_embeddings
        x = self.dropout(x)
        # Transformer Encoder
        x = self.transformer(x)
        # Classification Head
        x = x[:, 0]  # CLS token
        return self.mlp_head(x)


# Example usage
if __name__ == "__main__":
    model = VisionTransformer(img_size=224, patch_size=16, num_classes=10, dim=768,
                              depth=12, heads=12, mlp_dim=3072)
    print(model)
    dummy_img = torch.randn(8, 3, 224, 224)  # Batch of 8 images, 3 channels, 224x224 size
    preds = model(dummy_img)
    print(preds.shape)  # torch.Size([8, 10]) (batch size, number of classes)
Key Components
- Patch Embedding: Images are split into smaller patches, flattened, and linearly mapped into embeddings.
- Positional Encoding: Positional information is added to patch embeddings because Transformers do not inherently understand position.
- Transformer Encoder: Self-attention and feed-forward layers are applied to learn the relationships among patches.
- Classification Head: Class logits are computed from the final CLS token representation; applying a softmax turns them into class probabilities.
You can train this model on any image dataset by using an optimizer such as Adam and a loss function like cross-entropy. To achieve stronger results, pretraining on a large dataset before fine-tuning is often beneficial.
Popular Follow-up Work
DeiT (Data-efficient Image Transformers) by Facebook AI: These are vision transformers trained efficiently through knowledge distillation. DeiT provides four versions: deit-tiny, deit-small, and two deit-base variants. Use DeiTImageProcessor to prepare images.
BEiT (BERT Pre-Training of Image Transformers) by Microsoft Research: Inspired by BERT, BEiT uses self-supervised masked image modeling, and the resulting models outperform supervised pre-trained ViTs. Its pretraining relies on a VQ-VAE-based image tokenizer, which supplies the visual tokens the model learns to predict.
DINO (Self-supervised Vision Transformer Training) by Facebook AI: ViTs trained with DINO can segment objects without explicit supervision. Checkpoints are available online.
MAE (Masked Autoencoders) by Facebook AI: MAE pretrains ViTs by masking a large fraction of the image patches (75%) and reconstructing the missing content. After fine-tuning, this straightforward method outperforms supervised pretraining.
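The random masking step used in MAE-style pretraining is simple to sketch. The NumPy example below hides 75% of the patch tokens and keeps the rest for the encoder. It is only an illustration of the masking idea, not the MAE codebase; the function name `random_masking` and the array sizes are assumptions.

```python
import numpy as np

def random_masking(patches, mask_ratio, rng):
    """Keep a random subset of patch tokens and report which ones were hidden.

    patches: (N, D) array of patch tokens.
    Returns the visible tokens, a boolean mask (True = masked), and the kept indices.
    """
    n = patches.shape[0]
    num_keep = int(n * (1 - mask_ratio))
    shuffle = rng.permutation(n)
    keep_idx = np.sort(shuffle[:num_keep])  # keep the original patch order
    mask = np.ones(n, dtype=bool)
    mask[keep_idx] = False
    return patches[keep_idx], mask, keep_idx

rng = np.random.default_rng(0)
patches = rng.normal(size=(196, 768))
visible, mask, keep_idx = random_masking(patches, mask_ratio=0.75, rng=rng)
print(visible.shape, int(mask.sum()))  # (49, 768) 147
```

Only the 49 visible tokens would be passed to the encoder in such a setup, and a decoder would be trained to reconstruct the 147 masked patches.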
Conclusion
To conclude, ViTs provide a strong alternative to CNNs by applying transformers to image recognition, reducing inductive bias, and interpreting images as sequences of patches. This straightforward yet highly scalable method has achieved state-of-the-art results on many image classification benchmarks, especially when combined with large-scale pretraining. Still, several challenges remain, including extending ViTs to tasks such as object detection and segmentation, advancing self-supervised pretraining techniques further, and investigating how scaling ViTs even more might lead to stronger performance.


