PyTorch is awesome. Since its inception, it has established itself as one of the leading deep learning frameworks, next to TensorFlow. Its ease of use and dynamic define-by-run nature made it especially popular among researchers, who were able to prototype and experiment faster than ever.
Since then, it has undergone explosive progress, becoming much more than a framework for fast prototyping. In this post, my aim is to introduce you to five tools that can help you improve your development and production workflow with PyTorch.
To give you a quick rundown, we will take a look at these:
- Hooks
- PyTorch Lightning
- Quantization
- Pruning
- TorchScript + JIT
To start off, let’s talk about hooks, which are one of the most useful built-in development tools in PyTorch. Have you ever littered your forward pass method with print statements and breakpoints to deal with those nasty tensor shape mismatches or mysterious NaNs appearing in random layers?
Good news: you don’t have to do that. There is a simple and elegant solution. A hook is a function that can be attached to a layer. Depending on where you attach it, it receives the layer’s inputs before the forward pass, its inputs and outputs after the forward pass, or its gradients during the backward pass, allowing you to store, inspect or even modify them.
In the example below, you can see how to use hooks to simply store the output of every convolutional layer of a ResNet model.
If you would like to dive deeper, I have written a detailed guide about hooks.
If you have used Keras, you know that a great interface can make training models a breeze. Originally, this was not available for PyTorch. However, PyTorch Lightning was developed to fill the void. Although not an official part of PyTorch, it is currently developed by a very active community and has gained significant traction recently.
To demonstrate how it helps you eliminate the boilerplate code which is usually present in PyTorch, here is a quick example, where we train a ResNet classifier on MNIST.
In addition, the `Trainer` class supports multi-GPU training, which can be useful in certain scenarios. There are many more examples in the official documentation.
As neural network architectures became more complex, their computational requirements have increased as well, making certain models infeasible in practice. For example, you may want to run a neural network in a mobile application, which has strong hardware limitations. Because of this, significant efforts are being made to overcome such obstacles.
One of the most promising ones is the quantization of networks. In essence, quantization is simply using uint8 instead of float32 or float64. This makes the network smaller and computations faster. Even though there is a trade-off between accuracy and size/speed, the performance loss can be minimal if done right.
PyTorch supports three quantization workflows:
- Dynamic quantization, which quantizes the weights ahead of time and the activations dynamically during computation. This makes inference faster, but activations are still read from and written to memory as floats. (So, no speedup from faster uint8 memory access.)
- Post-training static quantization. This converts the entire trained network, also improving the memory access speed. However, this may lead to loss in performance.
- Quantization aware training. If post-training quantization results in an unacceptable loss of accuracy, quantization can instead be simulated during training, letting the network adapt to it.
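Of the three, dynamic quantization is the easiest to try, since it is essentially a one-liner. A minimal sketch (the model and layer sizes here are made up for the example):

```python
import torch
from torch import nn

# A toy model with two Linear layers to quantize.
model = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 10))

# Convert the Linear layers to their dynamically quantized int8 versions.
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# The quantized model is used exactly like the original one.
x = torch.randn(1, 256)
y = quantized(x)
```

The second argument is the set of module types to quantize; everything else in the model is left untouched.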
If you are aiming for production, quantization is seriously worth exploring. (Keep in mind that it is currently an experimental feature and can change.)
Quantization is not the only technique for speeding up and shrinking neural networks. Even a moderately sized convolutional network contains millions of parameters, making training and inference computationally costly. Since trained networks are inherently sparse, it is a natural idea to simply remove unnecessary neurons to decrease size and increase speed.
Removing weights might not seem to be a good idea, but it is a very effective method. Just think about how a convolutional layer is really a linear layer with a bunch of zero weights. In PyTorch, there are several pruning methods implemented in the
torch.nn.utils.prune module. To use them, simply apply the pruning function to the layer to prune:
prune.random_unstructured(nn.Conv2d(3, 16, 3), name="weight", amount=0.5)
This adds a pruning forward pre-hook to the module, which is executed before each forward pass, masking the weights. As a result, computations in this layer will be faster, due to the sparsity of the weights.
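To see the effect, here is a slightly fuller sketch that checks the resulting sparsity:

```python
import torch
from torch import nn
from torch.nn.utils import prune

conv = nn.Conv2d(3, 16, 3)
prune.random_unstructured(conv, name="weight", amount=0.5)

# The original weights now live in `weight_orig`; `weight` is recomputed
# as weight_orig * weight_mask by the forward pre-hook.
sparsity = (conv.weight == 0).float().mean().item()
print(f"sparsity: {sparsity:.2f}")  # roughly 0.50

# prune.remove makes the pruning permanent and drops the mask/hook.
prune.remove(conv, "weight")
```

Calling `prune.remove` is the final step if you want to export the pruned layer as an ordinary module, with the zeros baked into `weight`.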
TorchScript + JIT
As you might know, the internals of PyTorch are actually implemented in C++, using CUDA, cuDNN and other high-performance computing tools. This is what makes it really fast. What you use for training is just a Python wrapper on top of a C++ tensor library. This has some disadvantages; for instance, it adds overhead to the computations. Python is really convenient for development, but in production, you don’t really need that convenience.
What you need is a way to run your models lightning fast. TorchScript and the JIT provide just that. They translate your model into an intermediate representation, which can then be loaded in environments other than Python. In addition, this representation can be optimized further for even faster performance.
To translate your model, you can use tracing or scripting. Tracing requires an example input, which is passed through the model while the executed operations are recorded into the intermediate representation. However, if your forward pass contains control flow such as
if statements, the representation won’t be correct: tracing only records the branch that was actually taken, so the other branches won’t be present. In these cases, scripting should be used, which analyzes the source code of the model directly.
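The difference is easy to demonstrate with a sketch (the `Gate` module is a made-up example with data-dependent control flow):

```python
import torch
from torch import nn

class Gate(nn.Module):
    # The branch taken depends on the input data.
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x - 1

model = Gate()
example = torch.ones(3)  # sum > 0, so tracing records only the `* 2` branch

# Tracing freezes the branch taken for this particular example input...
traced = torch.jit.trace(model, example)

# ...while scripting analyzes the source code and keeps both branches.
scripted = torch.jit.script(model)

# The result can be serialized and loaded without any Python source:
# scripted.save("gate.pt")  # later: torch.jit.load("gate.pt")
```

Feeding the traced model an input with a negative sum still takes the recorded `* 2` branch, which is exactly the pitfall described above; the scripted model handles both branches correctly.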
If you would like to learn more, here are some excellent resources:
- PyTorch TorchScript tutorial
- Research to Production: PyTorch JIT/TorchScript Updates by Michael Suo
- From Research to Production, talk by Jeff Smith at QCon New York 2019
Have you used any of these in your work? Do you know any best practices or great tutorials? Let us know! 🙂
Until then, let’s level up our PyTorch skills and build something awesome!