
PyTorch: how to freeze some parameters and train only the rest

1) Add the following lines to the model:

for p in model.parameters():  # or the sub-module whose parameters you want to freeze
    p.requires_grad = False

For example, after loading a ResNet pre-trained model and attaching a new module on top of it, you can freeze the ResNet part so that it is not updated for the time being and only the parameters of the other parts are trained; to do that, add the line above as shown below:

class RESNET_MF(nn.Module):
    def __init__(self, model, pretrained):
        super(RESNET_MF, self).__init__()
        self.resnet = model(pretrained)  # load the pre-trained backbone
        for p in self.resnet.parameters():
            p.requires_grad = False  # everything loaded from the pre-trained model is frozen; the layers added below stay trainable
        # the attribute names below are placeholders for the newly added layers
        self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
        self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
        self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
        ...

Also filter the parameters that are passed to the optimizer:

filter(lambda p: p.requires_grad, model.parameters())
# Adam is assumed here; the betas/eps arguments match its signature
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, \
    betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
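
Before training starts, it can be worth checking that the freeze actually took effect. A minimal sanity check, assuming `model` is the wrapped network built above:

# Count trainable vs. frozen parameters (assumes `model` is the network built above)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print("trainable:", trainable, "frozen:", frozen)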

2) Parameters are stored in an ordered dictionary, so they can also be frozen by position: look up the index that corresponds to each parameter's name.

First print the name of every layer:

model_dict = model.state_dict()
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
    print(i, p)

The output looks roughly like this:

0 gamma
1 resnet.conv1.weight
2 resnet.bn1.weight
3 resnet.bn1.bias
4 resnet.bn1.running_mean
5 resnet.bn1.running_var
6 resnet.layer1.0.conv1.weight
7 resnet.layer1.0.bn1.weight
8 resnet.layer1.0.bn1.bias
9 resnet.layer1.0.bn1.running_mean
....

Then add code like this to the model:

# note: parameters() does not include buffers such as running_mean,
# so the index here may differ slightly from the state_dict listing above
for i, p in enumerate(self.parameters()):
    if i < 165:
        p.requires_grad = False

Together with the requires_grad filter added to the optimizer above, this masks those parameters out of training.
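
Freezing by index is fragile if the architecture changes. As an alternative sketch (not from the original article), you can freeze by name prefix with named_parameters(), assuming the pre-trained part lives under an attribute called "resnet":

# Alternative: freeze by name instead of by index
# ("resnet." is an assumed prefix; use whatever your state_dict listing shows)
for name, p in model.named_parameters():
    if name.startswith("resnet."):
        p.requires_grad = False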

Addendum: PyTorch loading pre-trained models + resuming from checkpoints + freeze training (pitfall-avoidance version)

1. The pre-trained model's network structure is identical to the network you want to load it into

In that case, just load it directly:

path="Path to your .pt file"
model = "Your network."
checkpoint = (path, map_location=device)
model.load_state_dict(checkpoint)
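
One related pitfall: some .pt files are not a bare state_dict but a wrapper dictionary (for example with a 'state_dict' key, as in the checkpoint format of section 3 below). A hedged sketch that handles both cases:

# Unwrap the checkpoint if it was saved as a wrapper dict (an assumption about how the file was saved)
checkpoint = torch.load(path, map_location=device)
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
model.load_state_dict(state_dict)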

2. The structure of the pre-trained model network doesn't match the structure of your network.

If you load it directly as above, you will run into problems such as the "unexpected key" error.

In this case, you need to analyze the network information specifically before deciding how to load it.

# model_dict is a dictionary holding the network's layer names and parameters
model_dict = model.state_dict()
print(model_dict.keys())
# this prints the layer names of your network
checkpoint = torch.load(path, map_location=device)
for k, v in checkpoint.items():
    print("keys:", k)
# this prints the layer names of the pre-trained model (it is also a dictionary, just displayed in a different way)

Then compare the two networks' layer names and see where they agree and where they differ.
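
A small helper sketch (my addition, reusing model_dict and checkpoint from the code above) that prints exactly which keys the two sides do not share:

# Compare the two key sets (assumes model_dict and checkpoint from above)
model_keys = set(model_dict.keys())
ckpt_keys = set(checkpoint.keys())
print("only in the network:", model_keys - ckpt_keys)
print("only in the pre-trained model:", ckpt_keys - model_keys)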

If the layer names are mostly different, the pre-trained model is basically unusable; just switch to a different model.

If the two networks' parameters have a lot in common but are not exactly the same, you can use one of the approaches below.

(1) Some of the network's keys match the pre-trained keys exactly ---- use strict=False

model.load_state_dict(checkpoint, strict=False)

Passing strict=False to load_state_dict simply ignores the keys that the two dicts do not share: a value is copied when the key is identical and the assignment is skipped otherwise. A pre-trained key must match a key returned by the network's state_dict() exactly and strictly before its value is assigned.

strict=False is still not very smart when keys only roughly match: those parameters are silently skipped, so the load can "succeed" while part of the network remains randomly initialized.
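
In recent PyTorch versions load_state_dict also returns which keys were missing or unexpected, which makes it easy to see what was actually loaded. A minimal sketch:

# load_state_dict reports the keys it could not match
result = model.load_state_dict(checkpoint, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)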

(2) Most of the network's keys only partially match the pre-trained keys (not identical, but similar) ---- for example:

Network key: backbone.stage0.rbr_dense.

Pre-trained model key: stage0.rbr_dense.

You can see that the network key has an extra "backbone." prefix compared with the pre-trained model, but the rest is identical. In this case you can read the pre-trained model's stage0.rbr_dense. weights into the network's backbone.stage0.rbr_dense.

# For a dictionary, the in / not in operators test against the keys
model_dict = model.state_dict()
checkpoint = torch.load(path, map_location=device)
# k is a key of the pre-trained model, ss is a key of the network
for k, v in checkpoint.items():
    flag = False
    for ss in model_dict.keys():
        if k in ss:  # substring match against each network key
            s = ss; flag = True; break
        else:
            continue
    if flag:
        model_dict[s] = v  # copy the pre-trained weight into the matching network entry
model.load_state_dict(model_dict)
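
For the specific prefix case above, another sketch (my addition) is to rename the pre-trained keys by adding the missing prefix and then load with strict=False; "backbone." is taken from the example keys above:

# Add the missing "backbone." prefix to every pre-trained key,
# then let strict=False skip anything that still does not match
renamed = {'backbone.' + k: v for k, v in checkpoint.items()}
model.load_state_dict(renamed, strict=False)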

3. Resuming training from a checkpoint

The main difference from ordinary [model save/load] is restoring the epoch count.

# Save the model
state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    # ... anything else you want to keep can be added here as well
}
torch.save(state, filepath)

# Load the model and resume training
checkpoint = torch.load(filepath)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
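
With start_epoch restored, the training loop simply continues from where it stopped, for example:

# Resume the loop from the restored epoch
for epoch in range(start_epoch, num_epochs):  # num_epochs: your total number of epochs
    ...  # one epoch of training as usual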

4. Freeze training

Freeze training is generally applied to the [backbone] and is used most often in [transfer learning].

For example, epochs 0-49: train with the backbone frozen; epochs 50-99: train with nothing frozen.

Init_Epoch = 0
Freeze_Epoch = 50
Unfreeze_Epoch = 100
#------------------------------------#
# Freeze part of the network for training
#------------------------------------#
for param in model.backbone.parameters():  # "backbone" is whatever attribute holds the frozen part in your model
    param.requires_grad = False
for epoch in range(Init_Epoch, Freeze_Epoch):
    # training with the backbone frozen
    pass
#------------------------------------#
# Training after unfreezing
#------------------------------------#
for param in model.backbone.parameters():
    param.requires_grad = True
for epoch in range(Freeze_Epoch, Unfreeze_Epoch):
    # training with everything unfrozen
    pass
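
One pitfall worth noting (my addition, not from the original): if the optimizer was created with the requires_grad filter while the backbone was frozen, the newly unfrozen parameters are not registered in it. A hedged sketch of rebuilding the optimizer at the unfreeze point:

# Rebuild the optimizer after unfreezing so the backbone parameters are actually updated
# (Adam and lr=0.001 are only examples; reuse whatever settings you trained with)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)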

The above is my personal experience; I hope it can serve as a useful reference, and I hope you will continue to support me. If anything is wrong or not fully considered, please do not hesitate to point it out.