An Easy Way to Change CNN Structure at Runtime with PyTorch
Introduction
In some network-compression applications such as pruning, we need to change the model structure at runtime, since we have to train a similar new model. However, it’s not easy to change parts of a complex structure, such as an nn.Conv2d inside an nn.Sequential.
In this post, I’ll show you an easy way to change part of a network structure at runtime with PyTorch.
Why is it hard to change structures at runtime?
First, we have to start from Python itself.
The general Python behavior
We need to note the following:
- Python has names and objects.
- Assignment binds a name to an object.
- Passing an argument into a function also binds a name (the parameter name of the function) to an object.
- Some data types are mutable, but others aren’t.
Simple examples
- For an immutable object, we cannot mutate the object, and if we rebind the name to a new value, the connection to the original object is removed.
a = 3
b = a
b = 5
print(a, b)
The result is 3 5.
- For a mutable object like a list, we can change its values in place, while the name still points at the original object.
a = [0, 1, 2]
b = a
b[1] = 10
print(a, b)
The result is [0, 10, 2] [0, 10, 2].
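A quick way to check these bindings is the built-in is operator, which reports whether two names refer to the same object (a minimal sketch):
a = [0, 1, 2]
b = a
print(a is b)  # True: both names are bound to the same list object
b = [3, 4]
print(a is b)  # False: rebinding b points it at a different object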
Argument passing examples
The same rules apply to argument passing.
- For an immutable object, we cannot mutate the object, and if we rebind the parameter name inside the function, the connection to the outer object is removed; the object inside the function is then completely different from the one originally passed in.
def foo(b):
    b = 10
    print("b in function is:", b)

a = 3
b = a
foo(b)
print(a, b)
## output:
## b in function is: 10
## 3 3
- For a mutable object like a list, we can change its values inside the function, while the parameter name still points at the original object.
def foo(b):
    b[0] = 10
    print("b in function is:", b)

a = [2, 5]
b = a
foo(b)
print(a, b)
## output:
## b in function is: [10, 5]
## [10, 5] [10, 5]
We have to keep this in mind, since it’s a double-edged sword. If we want a function to change the elements of a list, that is easy. However, in many situations we only want to pass the values of a list into a function; then we need to copy the list with copy().
from copy import copy

def foo(b):
    b[0] = 10
    print("b in function is:", b)

a = [2, 5]
b = a
foo(copy(b))  # pass a shallow copy so foo cannot change the caller's list
print(a, b)
## output:
## b in function is: [10, 5]
## [2, 5] [2, 5]
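One caveat: copy() makes a shallow copy, so nested mutable elements are still shared. If the list contains other lists, deepcopy() may be what you actually want (a small sketch):
from copy import copy, deepcopy

a = [[1, 2], 3]
b = copy(a)      # shallow copy: the inner list is shared between a and b
b[0][0] = 99
print(a)
## output:
## [[99, 2], 3]

c = deepcopy(a)  # deep copy: fully independent
c[0][0] = 0
print(a)
## output:
## [[99, 2], 3]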
About PyTorch
Now let’s see why it’s not so easy to change structures at runtime for a network in PyTorch. Consider resnet18.
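The printout below can be reproduced in a couple of lines (assuming torchvision is installed):
from torchvision import models

print(models.resnet18())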
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
There are mainly two difficulties:
- There are too many substructures, which makes it hard to locate the target module.
- In Python we usually only get a reference to the original object. That means even if we can locate the target part, rebinding that reference does not change the model, as the snippet below shows.
import torch
from torchvision import models

net = models.resnet18()

# Locate the first conv of the first block in layer4 and print it.
print(next(next(net.layer4.children()).children()))

# Bind a name to that module, then rebind the name to a new Conv2d.
reference = next(next(net.layer4.children()).children())
reference = torch.nn.Conv2d(111, 111, 3, 2, 1)

# The name now points at the new module, but the model itself is unchanged.
print(reference, next(next(net.layer4.children()).children()))
## output:
## Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
## Conv2d(111, 111, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
How to change structures at runtime?
It’s not so hard, since Sequential actually behaves much like a list. That means as long as we make sure a Sequential is used to wrap Conv2d (or whatever module we want to swap), we can replace that module at runtime by index assignment!
import torch

net = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 1, 1, 1), torch.nn.ReLU())
print(net)
net[0] = torch.nn.Conv2d(2, 1, 1, 1, 1)  # index assignment replaces the first module in place
print(net)
The result is:
Sequential(
(0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
Sequential(
(0): Conv2d(2, 1, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1))
(1): ReLU()
)
That’s it! The alternative, building an entirely new network and swapping it in for the previous one, is much more troublesome.
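The same trick works on resnet18, because layer1 through layer4 are Sequential containers. As an illustrative sketch (replacing the second block of layer1 with an Identity, which keeps the 64-channel shapes valid):
import torch
from torchvision import models

net = models.resnet18()
net.layer1[1] = torch.nn.Identity()  # works because layer1 is a Sequential

# The modified model still runs end to end.
x = torch.randn(1, 3, 224, 224)
print(net(x).shape)
## output:
## torch.Size([1, 1000])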
However, it’s not perfect: if a submodule sits inside something that is not a Sequential, we cannot replace it by indexing as above.
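We can see this limitation on resnet18: net.layer4 is a Sequential, but its first element is a BasicBlock, which does not support item assignment. A quick check (the exact error message may vary):
import torch
from torchvision import models

net = models.resnet18()
try:
    net.layer4[0][0] = torch.nn.Conv2d(256, 512, 3, 2, 1)
except TypeError as e:
    print(e)
## output:
## 'BasicBlock' object does not support item assignment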
Summary
In this post, we’ve introduced how to change a network at runtime with PyTorch. The method is suitable for personally designed networks (or simple ones) whose replaceable parts are wrapped in Sequential containers.