5 Advanced PyTorch Tools to Level up Your Workflow

PyTorch is awesome. Since its inception, it has established itself as one of the leading deep learning frameworks, next to TensorFlow. Its ease of use and dynamic define-by-run nature were especially popular among researchers, who were able to prototype and experiment faster than ever.

Since then, it has developed rapidly, becoming much more than a framework for fast prototyping. In this post, my aim is to introduce you to five tools that can help you improve your development and production workflow with PyTorch.

To give you a quick rundown, we will take a look at these.

  • Hooks
  • PyTorch Lightning
  • Quantization
  • Pruning
  • TorchScript + JIT

Hooks

To start off, let’s talk about hooks, which are among the most useful built-in development tools in PyTorch. Have you ever littered your forward pass method with print statements and breakpoints to deal with those nasty tensor shape mismatches or mysterious NaNs appearing in random layers?

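Good news: you don’t have to do that. There is a simple and elegant solution. A hook is a function that can be attached to a layer (module). Depending on its type, it is called before the forward pass (a forward pre-hook, which receives the layer's input), after the forward pass (a forward hook, which receives both the input and the output), or during the backward pass (a backward hook, which receives the gradients). This allows you to store, inspect or even modify these values.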

In the example below, you can see how to use hooks to simply store the output of every convolutional layer of a ResNet model.

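```python
import torch
from torchvision.models import resnet34, ResNet34_Weights


class SaveOutput:
    """Forward hook that stores the output of every module it is attached to."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Called after the module's forward pass with its inputs and output
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


# torchvision >= 0.13; older versions use resnet34(pretrained=True)
model = resnet34(weights=ResNet34_Weights.DEFAULT)
save_output = SaveOutput()

hook_handles = []
for layer in model.modules():
    if isinstance(layer, torch.nn.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)  # keep the handles so the hooks can be removed later
```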

If you would like to learn more, I have written a detailed guide about hooks.
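Hooks can also take over the debugging chores mentioned at the start of this section. The sketch below uses a small hypothetical three-layer model; it registers a forward pre-hook that prints the input shape of each linear layer and a forward hook that raises an error as soon as a layer outputs a NaN. The function names `check_nan` and `print_input_shape` are illustrative choices, not part of PyTorch.

```python
import torch
import torch.nn as nn


def check_nan(module, inputs, output):
    # Forward hook: runs after module.forward with its inputs and output
    if torch.isnan(output).any():
        raise RuntimeError(f"NaN detected in output of {module.__class__.__name__}")


def print_input_shape(module, inputs):
    # Forward pre-hook: runs before module.forward; inputs is a tuple
    print(f"{module.__class__.__name__} input shape: {tuple(inputs[0].shape)}")


model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))

handles = []
for layer in model:
    if isinstance(layer, nn.Linear):
        handles.append(layer.register_forward_pre_hook(print_input_shape))
        handles.append(layer.register_forward_hook(check_nan))

out = model(torch.randn(3, 8))
# Linear input shape: (3, 8)
# Linear input shape: (3, 4)

for handle in handles:
    handle.remove()  # detach the hooks once debugging is done
```

Keeping the handles returned by the `register_*` methods is what makes it possible to remove the hooks cleanly afterwards.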

PyTorch Lightning

If you have used Keras, you know that a great interface can make training models a breeze. Originally, this was not available for PyTorch. However, PyTorch Lightning was developed to fill the void. Although not an official part of PyTorch, it is currently developed by a very active community and has gained significant traction recently.

To demonstrate how it helps you eliminate the boilerplate code that is usually present in PyTorch, here is a quick example in which we train a ResNet classifier on MNIST.

```python
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.models import resnet18

# ResNet expects 3-channel inputs, so the grayscale digits are upscaled and replicated
transform = transforms.Compose(
    [transforms.Resize(224), transforms.Grayscale(3), transforms.ToTensor()]
)
dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
train_loader = DataLoader(dataset, batch_size=32)


class ResNetModel(LightningModule):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        # torchvision >= 0.13; older versions use pretrained=False
        self.model = resnet18(weights=None, num_classes=10)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)  # recorded by the Trainer's logger
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


model = ResNetModel()
trainer = Trainer(num_nodes=1, max_epochs=50)
trainer.fit(model, train_loader)
```

In addition, the Trainer class supports multi-GPU training, which can be useful in certain scenarios. There are many more examples in the official documentation.

There is an excellent introduction by its author, William Falcon, which I seriously recommend if you are interested.
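For comparison, here is a rough sketch of the kind of boilerplate that `trainer.fit` spares you in the Lightning example above: device placement, zeroing gradients, backpropagation, optimizer steps and logging, all written by hand. To keep it small and fast to run, it assumes random stand-in data and a single linear layer instead of MNIST and ResNet. It is only an illustration, not a description of Lightning's internals, which also take care of things like checkpointing and distributed training.

```python
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-in data: 256 random 1x28x28 "images" with 10 classes
x = torch.randn(256, 1, 28, 28)
y = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # small stand-in for ResNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

losses = []
for epoch in range(3):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)  # device placement
        optimizer.zero_grad()                  # reset gradients
        loss = F.cross_entropy(model(xb), yb)  # the part training_step defines
        loss.backward()                        # backpropagation
        optimizer.step()                       # parameter update
        losses.append(loss.item())             # manual logging
    print(f"epoch {epoch}: last batch loss {losses[-1]:.4f}")
```

In the Lightning version, only the loss computation (in `training_step`) and the optimizer choice (in `configure_optimizers`) remain your responsibility.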

Quantization

As neural network architectures became more complex, their computational requirements increased as well. This made certain models infeasible in practice, for instance when you want to run a neural network in a mobile application, which has strong hardware limitations. Because of this, significant efforts are being made to overcome such obstacles.

One of the most promising ones is the quantization of networks. In essence, quantization means representing weights and activations with 8-bit integers (such as uint8) instead of float32 or float64. This makes the network smaller and computations faster. Even though there is a trade-off between accuracy and size/speed, the loss in accuracy can be minimal if done right.
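Concretely, PyTorch maps a float value x to an 8-bit integer with an affine rule, q = clamp(round(x / scale) + zero_point), and recovers an approximation with (q - zero_point) * scale. The sketch below uses `torch.quantize_per_tensor` with a hand-picked scale and zero point (chosen only for illustration) to show the stored integers and the dequantized values; note that 1.0 lies just outside the representable range and gets clamped.

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.5, 1.0])

# Illustrative parameters: the quint8 range [0, 255] then covers [-1, 1 - 1/128]
scale, zero_point = 1 / 128, 128

# q = clamp(round(x / scale) + zero_point, 0, 255)
q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.quint8)

print(q.int_repr())    # stored uint8 values: 0, 128, 192, 255
print(q.dequantize())  # recovered floats: -1.0, 0.0, 0.5, 0.9921875
```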

PyTorch supports three quantization workflows:

  1. Dynamic quantization. The weights are converted to 8-bit integers ahead of time, while the activations are quantized on the fly during computation. This makes the computation faster, but activations are still read from and written to memory as floats. (So, no speedup from faster 8-bit memory access for them.)
  2. Post-training static quantization. This converts both the weights and the activations of the entire trained network, using calibration data to determine the activation ranges, which also speeds up memory access. However, this may lead to a drop in accuracy.
  3. Quantization-aware training. If post-training quantization causes too large a drop in accuracy, the effects of quantization can be simulated during training, so that the model learns to compensate for them.
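Of the three workflows, dynamic quantization is the easiest to try, since it needs no calibration data or retraining. Here is a minimal sketch on a hypothetical small fully connected model, using `torch.ao.quantization.quantize_dynamic` (available as `torch.quantization.quantize_dynamic` in older releases). It compares the serialized size and the outputs of the float and quantized versions; the helper `size_in_bytes` is illustrative.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).eval()

# Replace every nn.Linear with a dynamically quantized version:
# int8 weights, activations quantized on the fly inside each layer
qmodel = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)


def size_in_bytes(m):
    # Serialize the state dict into memory and measure it
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes


x = torch.randn(4, 256)
with torch.no_grad():
    max_diff = (model(x) - qmodel(x)).abs().max().item()

print(qmodel)
print(f"float32: {size_in_bytes(model)} bytes, quantized: {size_in_bytes(qmodel)} bytes")
print(f"max abs difference in outputs: {max_diff:.4f}")
```

Only the module types listed in the second argument are converted, which is why the ReLU is left untouched.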

If you are aiming for production, quantization is seriously worth exploring. (Keep in mind that it is currently an experimental feature and can change.)

Pruning

There are more techniques to speed up or shrink neural networks besides quantization. Even a moderately sized convolutional network contains millions of parameters, making training and inference computationally costly. Since many of the weights in a trained network contribute little to its output, it is a natural idea to simply remove unnecessary weights or neurons to decrease size and increase speed.

Removing weights might not seem to be a good idea, but it is a very effective method. Just think about how a convolutional layer is really a linear layer with a bunch of zero weights. In PyTorch, there are several pruning methods implemented in the torch.nn.utils.prune module. To use them, simply apply the pruning function to the layer to prune:

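```python
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 16, 3)
# Randomly zero out 50% of the entries of conv.weight
prune.random_unstructured(conv, name="weight", amount=0.5)
```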

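This reparametrizes the weight: the original tensor is kept as weight_orig, a binary weight_mask buffer is added, and a forward pre-hook recomputes weight = weight_orig * weight_mask before each forward pass. Keep in mind that the zeroed weights are still stored and multiplied in a dense tensor, so this alone does not make the layer faster or smaller; to actually gain speed or memory, you need structured pruning that removes whole channels, or sparse storage and kernels.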
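To see what a pruning call does to a module, here is a small sketch using two other methods from torch.nn.utils.prune: `l1_unstructured`, which zeroes the weights with the smallest magnitude, and `ln_structured`, which zeroes entire output filters. It also shows `prune.remove`, which makes the pruning permanent by dropping the mask and the hook. The layer sizes and pruning amounts are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Unstructured: zero out the 30% of weights with the smallest absolute value
conv = nn.Conv2d(3, 16, 3)
prune.l1_unstructured(conv, name="weight", amount=0.3)

# The original tensor is kept as a parameter, the mask as a buffer,
# and conv.weight is computed as weight_orig * weight_mask
print(sorted(name for name, _ in conv.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(name for name, _ in conv.named_buffers()))     # ['weight_mask']

sparsity = (conv.weight == 0).float().mean().item()
print(f"unstructured sparsity: {sparsity:.2f}")

# Make the pruning permanent: drop the mask and the hook, keep the zeroed weights
prune.remove(conv, "weight")
print(sorted(name for name, _ in conv.named_parameters()))  # ['bias', 'weight']

# Structured: zero out the 25% of output filters (dim=0) with the smallest L2 norm
conv2 = nn.Conv2d(3, 16, 3)
prune.ln_structured(conv2, name="weight", amount=0.25, n=2, dim=0)
zero_filters = (conv2.weight.flatten(1).abs().sum(dim=1) == 0).sum().item()
print(f"zeroed filters: {zero_filters} of {conv2.out_channels}")
```

Even in the structured case the zeroed filters stay in the weight tensor; shrinking the layer itself would mean rebuilding it with fewer output channels.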

TorchScript + JIT

As you know, the internals of PyTorch are actually implemented in C++, using CUDA, cuDNN and other high-performance computing tools. This is what makes it really fast. What you use for training is just a Python wrapper on top of a C++ tensor library. This has some disadvantages; for instance, the Python layer adds overhead to the computations. Python is really convenient for development; however, in production, you don’t really need that convenience.

What you need is a way to run your models lightning fast. TorchScript and the JIT compiler provide just that. They translate your model into an intermediate representation, which can be loaded in environments other than Python, such as C++. In addition, this representation can be optimized further to achieve even faster performance.

To translate your model, you can use either

  • torch.jit.trace, or
  • torch.jit.script.

Tracing requires an example input, which is passed through your model while the executed operations are recorded in the intermediate representation. However, if your forward pass contains data-dependent control flow such as if statements, the representation won’t be correct: only the branch taken for the example input is recorded, and the other branches won’t be present. In these cases, scripting should be used, which analyzes the source code of the model directly.
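The difference is easy to see on a toy module with an if statement. The `Gate` module below is a hypothetical example: tracing it on a positive input bakes in the `x * 2` branch (PyTorch emits a TracerWarning about this), while scripting keeps both branches. The last lines show that the scripted module can be serialized with `torch.jit.save` and loaded back with `torch.jit.load`, without needing the original Python class.

```python
import io

import torch
from torch import nn


class Gate(nn.Module):
    # Hypothetical toy module with data-dependent control flow
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x


model = Gate()
pos, neg = torch.ones(3), -torch.ones(3)

traced = torch.jit.trace(model, pos)  # records only the branch taken for `pos`
scripted = torch.jit.script(model)    # compiles both branches from the source

print(traced(neg))    # wrong result: -2 for every element (the x * 2 branch)
print(scripted(neg))  # correct result: 1 for every element (the -x branch)

# Serialize the scripted module and load it back
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
print(loaded(neg))    # same as scripted(neg)
```

A file saved this way can also be loaded from C++ via LibTorch, which is what makes TorchScript useful for Python-free deployment.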

Have you used any of these in your work? Do you know any best practices or great tutorials? Let us know! 🙂

Until then, let’s level up our PyTorch skills and build something awesome!
