Refactoring a Machine Learning Model
This blog post is a tutorial that will take you from a naive implementation of a multilayer perceptron (MLP) in PyTorch to an enlightened one that leverages the power of PyTorch, Python’s built-ins, and some powerful third-party packages.
This tutorial is going to assume the following imports for all code blocks:
import itertools as itt
import more_itertools
from torch import nn
from torch.nn import functional as F
- itertools is a built-in library that helps deal with lists, sets, and other iterables.
- more_itertools is a third-party extension to itertools, highly regarded in the Python community.
- You should already be familiar with PyTorch and writing your own subclasses of torch.nn.Module by implementing your own __init__() and forward() functions.
This tutorial isn’t really about the theory or application of machine learning models - it’s just about the best ways to implement them. I’m also going to commit the sin of omitting docstrings and a lot of type annotations, since most of the MLP should be pretty obvious.
Let’s start with a naive implementation that reflects some old habits from C or Java programming:
import torch
from torch import nn
from torch.nn import functional as F
class MLP1(nn.Module):
    def __init__(self, dims: list[int]):
        super().__init__()
        layers = []
        for i in range(len(dims) - 1):
            in_features, out_features = dims[i], dims[i + 1]
            layers.append(nn.Linear(in_features, out_features))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        rv = x
        for layer in self.layers:
            rv = layer(rv)
            rv = F.relu(rv)
        return rv
Incremental Improvements
MLP1 uses the dreaded range(len(...)) pattern, which can almost always be replaced with direct iteration. In this case, however, the index is also used to look up the next element. Luckily, more_itertools has a function pairwise() that yields exactly these consecutive pairs.
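To see what pairwise() does, here is a quick illustration (the dimensions are just an arbitrary example); each consecutive pair of sizes corresponds to one linear layer’s input and output features:
from more_itertools import pairwise

# consecutive overlapping pairs of layer sizes
print(list(pairwise([64, 32, 16, 2])))
# [(64, 32), (32, 16), (16, 2)]
MLP1 can then be refactored into: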
import torch
from more_itertools import pairwise
from torch import nn
from torch.nn import functional as F
class MLP2(nn.Module):
    def __init__(self, dims: list[int]):
        super().__init__()
        layers = []
        for in_features, out_features in pairwise(dims):  # this line changed
            layers.append(nn.Linear(in_features, out_features))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        rv = x
        for layer in self.layers:
            rv = layer(rv)
            rv = F.relu(rv)
        return rv
The application of F.relu in forward() is suspect for a few reasons:
- Because it lives as a hard-coded call in forward(), there’s no way to make it into a hyper-parameter that can be chosen by a user.
- Because it’s the functional form F.relu and not nn.ReLU, it can’t be stacked with other layers.

MLP2 can be refactored to address both of those points by using the modular form nn.ReLU in the layers list after creating each nn.Linear.
import torch
from more_itertools import pairwise
from torch import nn
class MLP3(nn.Module):
    def __init__(self, dims: list[int]):
        super().__init__()
        layers = []
        for in_features, out_features in pairwise(dims):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())  # this line changed
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        rv = x
        for layer in self.layers:
            rv = layer(rv)
        return rv
Now that the forward() function is just a successive application of layers, it can be exchanged with an nn.Sequential. MLP3 can be refactored to look like:
import torch
from more_itertools import pairwise
from torch import nn
class MLP4(nn.Module):
    def __init__(self, dims: list[int]):
        super().__init__()
        layers = []
        for in_features, out_features in pairwise(dims):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())
        self.layers = nn.Sequential(*layers)  # this line changed

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.layers(x)
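As a quick sanity check (the dimensions, batch size, and variable names here are just an arbitrary example), MLP4 can be used like any other module:
import torch

model = MLP4([64, 32, 16, 2])  # three linear layers, each followed by ReLU
x = torch.randn(8, 64)         # a batch of 8 random input vectors
y = model(x)
print(y.shape)                 # torch.Size([8, 2])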
The following two improvements will make the construction of the layers list that goes into the nn.Sequential much more elegant. First, we’ll refactor MLP4 to use the extend method of a list rather than append:
import torch
from more_itertools import pairwise
from torch import nn
class MLP5(nn.Module):
    def __init__(self, dims: list[int]):
        super().__init__()
        layers = []
        for in_features, out_features in pairwise(dims):
            layers.extend((
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            ))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.layers(x)
An Aside on List Comprehensions
As we prepare to refactor MLP5, we’ll take a short aside to discuss list comprehensions in Python. Here are a few resources to get you started:
- Ned Batchelder - Loop like a native: while, for, iterators, generators
- Trey Hunner - Comprehensible Comprehensions
The minimum amount of information you need to know for this tutorial is that anytime we see code that looks like
old_list = ...
new_list = []
for x in old_list:
    new_list.append(transform(x))
we know that we can transform it using a list comprehension like
old_list = ...
new_list = [
    transform(x)
    for x in old_list
]
There’s an analogous pattern for when we’re successively extending a list, like what we did when writing MLP5. If we see code that looks like
old_list = ...
new_list = []
for x in old_list:
    new_list.extend(transform(x))
we can transform it into something more elegant
using itertools.chain.from_iterable()
like
from itertools import chain
old_list = ...
new_list = list(chain.from_iterable(
    transform(x)
    for x in old_list
))
While this may be a few extra lines (because it’s broken up for readability), it has the advantage that it’s only one logical line and can be used in more clever ways.
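For instance, here is a toy demonstration (unrelated to the MLP) of flattening the pairs produced by a generator expression in a single expression:
from itertools import chain

# each number is replaced by itself and its double, flattened into one list
print(list(chain.from_iterable((x, 2 * x) for x in [1, 2, 3])))
# [1, 2, 2, 4, 3, 6]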
Bringing it All Together
We’ll apply this template to our code to get a one-liner for instantiating
our nn.Sequential
(though notice it’s again broken up onto multiple lines for
readability):
from itertools import chain
import torch
from more_itertools import pairwise
from torch import nn
class MLP6(nn.Module):
    def __init__(self, dims: list[int]):
        super().__init__()
        self.layers = nn.Sequential(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            )
            for in_features, out_features in pairwise(dims)
        ))

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.layers(x)
Finally, since we’re now just creating a module that wraps the exact functionality of nn.Sequential, it’s possible to directly subclass nn.Sequential. We’ll refactor MLP6 to get our final result:
from itertools import chain
from more_itertools import pairwise
from torch import nn
class MLP7(nn.Sequential):
    def __init__(self, dims: list[int]):
        super().__init__(*chain.from_iterable(
            (
                nn.Linear(in_features, out_features),
                nn.ReLU(),
            )
            for in_features, out_features in pairwise(dims)
        ))
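As a final sanity check (again with arbitrary example dimensions), MLP7 behaves just like the earlier versions, and because it subclasses nn.Sequential, printing it reveals the generated layers (the output should look roughly like the comment below):
import torch

model = MLP7([64, 32, 2])
print(model)
# MLP7(
#   (0): Linear(in_features=64, out_features=32, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=32, out_features=2, bias=True)
#   (3): ReLU()
# )
y = model(torch.randn(8, 64))  # forward() is inherited from nn.Sequential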
MLP7 is now a much simpler implementation that uses a few neat tricks to reduce error-prone logic. I hope you enjoy applying these patterns to your own models, and if you have any other ideas you’d like me to include here, please leave a comment or get in touch!
While we were originally aiming at reducing complexity, this model still has the issue that it hard-codes the ReLU non-linear activation function, which could easily be generalized to support alternative activation functions. In my next post, I’ll demonstrate the thought process behind this generalization and the ultimate solution using the class-resolver package.
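As a teaser, here is a minimal sketch of one way to lift the activation into a constructor argument; this is my own illustrative generalization (the class name MLPSketch and the parameter name activation are hypothetical), not the class-resolver solution from the next post:
from itertools import chain

from more_itertools import pairwise
from torch import nn


class MLPSketch(nn.Sequential):
    def __init__(self, dims: list[int], activation: type[nn.Module] = nn.ReLU):
        # a fresh instance of the chosen activation follows each linear layer
        super().__init__(*chain.from_iterable(
            (nn.Linear(in_features, out_features), activation())
            for in_features, out_features in pairwise(dims)
        ))


# the activation class (not an instance) is now a hyper-parameter
model = MLPSketch([64, 32, 2], activation=nn.Tanh)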