1) Add the following line to the model
for p in self.parameters():
    p.requires_grad = False
For example, after loading a ResNet pre-trained model and attaching a new module on top of it, you can freeze the ResNet part so that it is not updated for the time being and only the parameters of the other parts are updated. To do this, add the line above right after the pre-trained model is loaded:
class RESNET_MF(nn.Module):
    def __init__(self, model, pretrained):
        super(RESNET_MF, self).__init__()
        self.resnet = model(pretrained)
        # the pre-trained model has just been loaded, so all of its parameters are
        # set to not update; the layers added afterwards keep requires_grad = True
        for p in self.parameters():
            p.requires_grad = False
        # attribute names below are illustrative; keep your own layer definitions
        self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
        self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
        self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
        ...
Then add the corresponding filter in the optimizer:
filter(lambda p: p.requires_grad, model.parameters())

# Adam is used here as an example optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001,
                       betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
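As a quick sanity check (a small sketch that is not in the original text), you can list which parameters will actually be updated after freezing:

# print the names of the parameters that remain trainable
for name, p in model.named_parameters():
    if p.requires_grad:
        print("will update:", name)
print("trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))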
2) Parameters are stored in an ordered dictionary, so they can also be frozen by looking up the index that corresponds to each parameter's name
First print out the name of each layer:
model_dict = torch.load('').state_dict()   # fill in the path to your saved model
dict_name = list(model_dict)
for i, p in enumerate(dict_name):
    print(i, p)
Print this out and it looks roughly like this:
0 gamma
1 resnet.conv1.weight
2 resnet.bn1.weight
3 resnet.bn1.bias
4 resnet.bn1.running_mean
5 resnet.bn1.running_var
6 resnet.layer1.0.conv1.weight
7 resnet.layer1.0.bn1.weight
8 resnet.layer1.0.bn1.bias
9 resnet.layer1.0.bn1.running_mean
....
Then add code like this to the model (or the training script); here the first 165 entries are frozen:

for i, p in enumerate(model.parameters()):   # use self.parameters() if this lives inside the model
    if i < 165:
        p.requires_grad = False
Then add the same filter line to the optimizer as in method 1, so that the frozen parameters are screened out of the update.
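One thing worth double-checking (not covered above): state_dict() also contains buffers such as running_mean and running_var, which parameters() does not return, so the indices printed earlier may not line up exactly with enumerate(model.parameters()). A small sketch to confirm the cut-off index against the parameters themselves:

# print the index of each parameter (buffers are excluded here),
# so the threshold used in "if i < 165" refers to these indices
for i, (name, p) in enumerate(model.named_parameters()):
    print(i, name, p.requires_grad)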
Addendum: PyTorch loading pre-trained models + resuming from checkpoints + freeze training (pitfall-avoidance version)
1. The structure of the pre-trained model network matches the structure of your network
Then you can load it directly:
path="Path to your .pt file" model = "Your network." checkpoint = (path, map_location=device) model.load_state_dict(checkpoint)
2. The structure of the pre-trained model network doesn't match the structure of your network.
If you apply the recipe above directly, you will run into errors such as the familiar "unexpected key" problem.
In this case you need to inspect the two networks before deciding how to load the weights.
# model_dict is a dictionary that holds the network's layer names and parameters
model_dict = model.state_dict()
print(model_dict.keys())   # prints the layer names of your own network
checkpoint = torch.load(path, map_location=device)
for k, v in checkpoint.items():
    print("keys:", k)   # prints the layer names of the pre-trained model (checkpoint is also a dict; this is just another way to display the keys)
Then compare the similarities and differences between the structures of the two networks.
If the layer names of the two networks are mostly different, the pre-trained model is basically unusable, and you should simply switch to another model.
If the parameters of the two networks overlap substantially but are not exactly the same, you can take one of the following approaches.
(1) Some of the network keys match exactly
model.load_state_dict(checkpoint, strict=False)

Passing strict=False to load_state_dict simply ignores the keys that are absent: keys that are identical get copied, keys that do not match are skipped. With the default strict=True, the keys of the pre-trained model must exactly and strictly match the keys returned by the network's state_dict() function, or loading fails.
strict=False is still not very smart for keys that only almost match: loading appears to succeed, yet those nearly-matching parameters are never filled in and keep their initial values.
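If you are not sure which keys actually matched, load_state_dict returns the lists of missing and unexpected keys, so the mismatch is easy to inspect; a small sketch:

# load what matches and report what did not
result = model.load_state_dict(checkpoint, strict=False)
print("missing keys:", result.missing_keys)        # present in the model but not in the checkpoint
print("unexpected keys:", result.unexpected_keys)  # present in the checkpoint but not in the model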
(2) Most of the network keys only partially match (similar, but not identical), for example:
Network key: backbone.stage0.rbr_dense.
Pre-trained model key: stage0.rbr_dense.
You can see that the network key has one extra prefix compared with the pre-trained model, but the rest is identical. In this case you can read the pre-trained model's stage0.rbr_dense. weights into the network's backbone.stage0.rbr_dense.
# for dictionaries, the in / not in operators test the keys
model_dict = model.state_dict()
checkpoint = torch.load(path, map_location=device)
# k is a key of the pre-trained model, ss is a key of the network
for k, v in checkpoint.items():
    flag = False
    for ss in model_dict.keys():
        if k in ss:   # substring match inside each key
            s = ss
            flag = True
            break
        else:
            continue
    if flag:
        model_dict[s] = v   # copy the pre-trained value into the matching network entry
model.load_state_dict(model_dict)
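Since in this example the only difference is the backbone. prefix, an equivalent and often simpler variant (not from the original text) is to rewrite the checkpoint keys before loading; a sketch assuming the prefix is exactly "backbone.":

# build a new state dict whose keys carry the extra prefix used by the network
prefixed = {"backbone." + k: v for k, v in checkpoint.items()}
model.load_state_dict(prefixed, strict=False)   # strict=False tolerates keys the network does not have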
3. Resuming training from a checkpoint
I feel the main difference between this and the usual model save/load is that the epoch (and the optimizer state) is also restored.
# saving the model
state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    # ... anything else you want to preserve can be added here
}
torch.save(state, filepath)

# loading the model and resuming training
checkpoint = torch.load(filepath)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
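With start_epoch restored, the training loop can then pick up where it left off; a sketch in which num_epochs and train_one_epoch are hypothetical placeholders for your own schedule and training step:

# resume training from the saved epoch
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer)   # hypothetical training step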
4. Freeze training
Freeze training is generally applied to the backbone, and is most often used in transfer learning.
For example, epochs 0-49: train with the backbone frozen; epochs 50-99: train with nothing frozen.
Init_Epoch = 0
Freeze_Epoch = 50
Unfreeze_Epoch = 100

#------------------------------------#
#   Freeze part of the network for training
#------------------------------------#
for param in model.backbone.parameters():   # assumes the backbone is exposed as model.backbone
    param.requires_grad = False
for epoch in range(Init_Epoch, Freeze_Epoch):
    # I'm freeze-training!
    pass

#------------------------------------#
#   Training after unfreezing
#------------------------------------#
for param in model.backbone.parameters():
    param.requires_grad = True
for epoch in range(Freeze_Epoch, Unfreeze_Epoch):
    # I'm unfreeze-training!
    pass
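One more pitfall worth noting here (not covered above): if the optimizer was built with filter(lambda p: p.requires_grad, ...) while the backbone was frozen, the backbone parameters are not registered in the optimizer at all, so after unfreezing you need to rebuild the optimizer (or add a parameter group); a sketch:

import torch.optim as optim

# after switching requires_grad back to True on the backbone, rebuild the optimizer
# so the newly unfrozen parameters are actually updated
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)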
The above is my personal experience; I hope it can serve as a useful reference, and I hope you will continue to support me. If there are any mistakes or points I have not fully considered, please do not hesitate to point them out.