I am having trouble loading a model from a checkpoint after training on the cloud with Lightning. The error suggests there are extra keys in the state dictionary, all of which come from nn.TransformerEncoder. It looks like only element (0) of the ModuleList is loaded properly. Do I need to specify the model in some way before calling load_from_checkpoint? It is acting as if the number of layers in the ModuleList is set to 1.
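For context, the load that fails looks roughly like this (a simplified sketch of my call; the checkpoint path is the same one used in the working snippet below):

# Simplified sketch of the failing call
ckpt_path = glob.glob('logs/hits_95/lightning_logs/heads_4_layers_2/*.ckpt')[0]
model = phla_transformer.Transformer.load_from_checkpoint(ckpt_path)  # raises the RuntimeError below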
If I manually create the model with parameters matching the checkpoint file, I can use model.load_state_dict to initialize the model without errors:
import glob

import torch

import phla_transformer

# Create a model whose hyperparameters match the checkpoint
model = phla_transformer.Transformer(
    peptide_length=12,
    allele_length=60,
    dropout_rate=0.3,
    transformer_heads=4,
    transformer_layers=2,
)

# Load the trained weights from the checkpoint file
ckpt_path = glob.glob('logs/hits_95/lightning_logs/heads_4_layers_2/*.ckpt')[0]
checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
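As a sanity check on whether any hyperparameters were saved into the file at all, the same checkpoint dict can be inspected (sketch, assuming the standard Lightning checkpoint layout where constructor arguments live under the 'hyper_parameters' key):

# A Lightning checkpoint is a plain dict; 'hyper_parameters' is only present
# if the LightningModule stored its constructor arguments during training.
print(checkpoint.keys())
print(checkpoint.get('hyper_parameters'))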
I expected load_from_checkpoint to recover the model structure from the ckpt file, but perhaps I'm mistaken.
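My (possibly wrong) understanding is that load_from_checkpoint can only rebuild the module with the right sizes if the constructor arguments were stored in the checkpoint, e.g. by calling self.save_hyperparameters() in __init__. A minimal sketch of that pattern, not my actual training code:

import pytorch_lightning as pl

class Transformer(pl.LightningModule):
    def __init__(self, peptide_length=12, allele_length=60, dropout_rate=0.3,
                 transformer_heads=4, transformer_layers=2):
        super().__init__()
        # Stores the __init__ arguments under 'hyper_parameters' in the checkpoint,
        # which load_from_checkpoint uses to re-instantiate the class.
        self.save_hyperparameters()

If the real __init__ never does this, I'd expect load_from_checkpoint to fall back to the default arguments, which would match the single-layer behaviour I'm seeing if transformer_layers defaults to 1.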
Error message and model definition are below:
RuntimeError: Error(s) in loading state_dict for Transformer:
Unexpected key(s) in state_dict: "transformer_encoder.transformer_encoder.layers.1.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.1.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.1.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.1.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.1.linear1.weight", "transformer_encoder.transformer_encoder.layers.1.linear1.bias", "transformer_encoder.transformer_encoder.layers.1.linear2.weight", "transformer_encoder.transformer_encoder.layers.1.linear2.bias", "transformer_encoder.transformer_encoder.layers.1.norm1.weight", "transformer_encoder.transformer_encoder.layers.1.norm1.bias", "transformer_encoder.transformer_encoder.layers.1.norm2.weight", "transformer_encoder.transformer_encoder.layers.1.norm2.bias".
Model definition from training is below.
Transformer(
  (embedding): Embedding(22, 32)
  (transformer_encoder): MaskedTransformerEncoder(
    (encoder_layer): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (linear1): Linear(in_features=32, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=32, bias=True)
      (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
          )
          (linear1): Linear(in_features=32, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=32, bias=True)
          (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
          )
          (linear1): Linear(in_features=32, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=32, bias=True)
          (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (positional_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=2304, out_features=2304, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=2304, out_features=1, bias=True)
  (criterion): BCEWithLogitsLoss()
)