
Solving the key mismatch encountered when modifying a pre-trained model in PyTorch

I. PyTorch encounters a key mismatch when modifying a pre-trained model

Recently I wanted to modify a network's pre-trained model, but found that after loading the pre-trained weights into a newly created model and saving it, I got a key mismatch when using the newly saved network model.

# Save after loading (no network modifications)
base_weights = torch.load(args.save_folder + args.basenet)   # filename elided in the original; args.basenet assumed
ssd_net.vgg.load_state_dict(base_weights)
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# Replace the previous pre-trained model with the newly saved network
ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
net = ssd_net
...
if args.resume:   # condition elided in the original; the resume check is assumed
    ...
else:
    base_weights = torch.load(args.save_folder + args.basenet)   # for ssd_base.pth
    print('Loading base network...')
    ssd_net.vgg.load_state_dict(base_weights)

The following error occurs:

Loading base network...
Traceback (most recent call last):
  File "", line 264, in <module>
    train()
  File "", line 110, in train
    ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
    Missing key(s) in state_dict: "0.weight", "0.bias", ...
    Unexpected key(s) in state_dict: "vgg.0.weight", "vgg.0.bias", ...

That is, the keys in the original pre-trained model had the form "0.weight", "0.bias", but after loading and re-saving they became "vgg.0.weight", "vgg.0.bias".

I think this is because of the line in the model definition file itself that wraps the backbone, presumably something like self.vgg = nn.ModuleList(base).

So the problem is that the keys saved from your own model definition carry an extra prefix.
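To see where the prefix comes from, here is a minimal runnable sketch (the Wrapper class, layer list, and sizes are made up for illustration): wrapping a ModuleList in a module under the attribute name vgg prepends "vgg." to every key.

import torch.nn as nn

layers = nn.ModuleList([nn.Conv2d(3, 64, kernel_size=3)])
print(list(layers.state_dict().keys()))     # ['0.weight', '0.bias']

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = layers                   # the attribute name becomes the key prefix

print(list(Wrapper().state_dict().keys()))  # ['vgg.0.weight', 'vgg.0.bias']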

It can be fixed by renaming the keys with the following code before loading:

import torch
from collections import OrderedDict   # import this module

base_weights = torch.load(args.save_folder + 'ssd_base' + '.pth')   # the re-saved checkpoint
print('Loading base network...')
new_state_dict = OrderedDict()
for k, v in base_weights.items():
    if not k.startswith('vgg.'):
        continue              # the checkpoint holds the full model; keep only the backbone entries
    name = k[4:]              # remove the 'vgg.' prefix (4 characters), keeping the rest of the key
    new_state_dict[name] = v
ssd_net.vgg.load_state_dict(new_state_dict)   # load once, outside the loop

No more errors at this point.
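An alternative that avoids renaming keys altogether, assuming the checkpoint was saved from the full model as above: load it back into the whole network rather than the vgg submodule, so the prefixed keys match as-is.

state = torch.load(args.save_folder + 'ssd_base' + '.pth')
ssd_net.load_state_dict(state)   # the 'vgg.'-prefixed keys match the full model directly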

I referred to another article for this fix and adapted it to my own model.


II. PyTorch encounters a problem loading pre-trained models: KeyError: 'bn1.num_batches_tracked'

Recently, when using PyTorch 1.0 to load a ResNet pre-trained model, I ran into a problem; I record it here.

KeyError: 'layer1.0.bn1.num_batches_tracked'

Actually, it's a matter of versions: PyTorch 0.4.1 added the num_batches_tracked buffer, controlled by the BN layer's track_running_stats option.

The function of this parameter is as follows:

During training it counts the number of mini-batches that have been forwarded: for each mini-batch, num_batches_tracked += 1.

If momentum is not specified (momentum=None), 1/num_batches_tracked is used as the averaging factor when computing the running mean and variance.
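A quick sketch to see this buffer on any PyTorch >= 0.4.1 (layer size and input shape are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)                           # track_running_stats=True by default
print('num_batches_tracked' in bn.state_dict())   # True on PyTorch >= 0.4.1
bn(torch.randn(2, 16, 8, 8))                      # one forward pass in training mode
print(bn.num_batches_tracked)                     # tensor(1)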

Actually, this buffer doesn't really matter for loading weights. But the official pre-trained models were trained with PyTorch 0.3, which doesn't have this parameter, so the checkpoint is missing the keys that the newer model expects.

So just filter the keyword 'num_batches_tracked' out of the pre-trained weight dictionary. Code examples below.

Problematic code:

def load_specific_param(self, state_dict, param_name, model_path):
    param_dict = torch.load(model_path)      # load the pre-trained checkpoint
    for i in state_dict:
        key = param_name + '.' + i           # e.g. builds 'layer1.0.bn1.num_batches_tracked'
        state_dict[i].copy_(param_dict[key])  # KeyError: the old checkpoint lacks that key
    del param_dict

Filter for 'num_batches_tracked':

def load_specific_param(self, state_dict, param_name, model_path):
    param_dict = torch.load(model_path)      # load the pre-trained checkpoint
    # drop any 'num_batches_tracked' entries the checkpoint might contain
    param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
    for i in state_dict:
        key = param_name + '.' + i
        if 'num_batches_tracked' in key:
            continue                         # skip the buffer keys the old checkpoint is missing
        state_dict[i].copy_(param_dict[key])
    del param_dict
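If you are loading a whole model rather than copying tensors field by field, PyTorch (since 0.4) also offers a built-in escape hatch: load_state_dict with strict=False, which tolerates missing and unexpected keys such as num_batches_tracked. The model and checkpoint path below are hypothetical.

import torch
import torchvision

model = torchvision.models.resnet50()             # hypothetical model
checkpoint = torch.load('resnet50.pth')           # hypothetical checkpoint path
model.load_state_dict(checkpoint, strict=False)   # missing num_batches_tracked keys are ignored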

The above is my personal experience; I hope it can give you a reference, and I hope you will support us more.