Hi, I followed this notebook to train a new model and saved the checkpoint as a PyTorch Lightning checkpoint. Afterwards, I tried to load the checkpoint, but it fails to load.
The code:
import torch
import pytorch_lightning as pl
import segmentation_models_pytorch as smp


class PolypModel(pl.LightningModule):
    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # preprocessing parameters for image
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # for image segmentation, dice loss could be the best first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        # normalize image here
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width);
        # if you work with grayscale images, expand the channels dim to [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32: encoder and decoder are connected
        # by skip connections, and the encoder usually has 5 stages of downsampling by a
        # factor of 2 (2 ^ 5 = 32), so mismatched feature shapes raise an error on concat
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width];
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values are between 0 and 1, NOT 0 and 255, for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Predicted mask contains logits, and the loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Compute metrics for some threshold: first convert logits to probabilities,
        # then apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # IoU will be computed two ways (dataset-wise and image-wise); for now just compute
        # true positive, false positive, false negative and true negative 'pixels' for each
        # image and class; these values are aggregated at the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per-image IoU: first calculate the IoU score for each image,
        # then compute the mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

        # dataset IoU: aggregate intersection and union over the whole dataset, then compute
        # IoU. The difference between dataset_iou and per_image_iou will not be large here,
        # but for datasets with "empty" images (no target class) a large gap can be observed:
        # empty images influence per_image_iou a lot and dataset_iou much less.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)


model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)
checkpoint = torch.load("/datadrive/thaonp47/WeakPolyp/deeplabv3plus-eff/epoch=04-valid_per_image_iou=0.64.ckpt")
model.load_state_dict(checkpoint)
It would be easier to debug if you compared the keys of the loaded checkpoint with the keys of the model's state dict; that way you can figure out what is wrong with the loading.
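As a minimal sketch of that comparison (reusing the checkpoint path from the post above):

import torch

ckpt = torch.load(
    "/datadrive/thaonp47/WeakPolyp/deeplabv3plus-eff/epoch=04-valid_per_image_iou=0.64.ckpt",
    map_location="cpu",
)
# The top level of a Lightning .ckpt file is training metadata, not model weights
print(ckpt.keys())
# The model itself expects flat parameter/buffer names such as "std", "mean", "model.encoder..."
print(list(model.state_dict().keys())[:5])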
How did you solve this in the end? If you solved it, please let me know, thanks. Could you also share how to save the model as a .pt or .ckpt file and convert it to ONNX?
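One possible approach, sketched below under the assumption that the checkpoint was produced by the training code above (the file names and the 256x256 dummy input size are placeholders, not values from the original post): load the weights, switch to eval mode, and export with torch.onnx.export. Any input size divisible by 32 works with this architecture.

import torch

model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)
ckpt = torch.load("model.ckpt", map_location="cpu")  # placeholder path
model.load_state_dict(ckpt["state_dict"])            # Lightning nests the weights under "state_dict"
model.eval()

# Optionally re-save just the weights as a plain .pt file
torch.save(model.state_dict(), "model.pt")

# Export to ONNX with a dummy input; (1, 3, 256, 256) is an assumed example shape
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(
    model,
    dummy,
    "model.onnx",
    input_names=["image"],
    output_names=["mask"],
    dynamic_axes={"image": {0: "batch"}, "mask": {0: "batch"}},
    opset_version=12,
)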
For reference, here are the errors from the load_state_dict call in the original post:
Traceback (most recent call last):
  File "test.py", line 205, in <module>
    model.load_state_dict(checkpoint)
  File "/home/azureuser/anaconda3/envs/thaonp47/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PolypModel:
        Missing key(s) in state_dict: "std", "mean", "model.encoder._conv_stem.weight", "model.encoder._bn0.weight", "model.encoder._bn0.bias", "model.encoder._bn0.running_mean", "model.encoder._bn0.running_var", ... (the list continues through every encoder block, decoder, and segmentation-head parameter of the model) ..., "model.segmentation_head.0.weight", "model.segmentation_head.0.bias".
        Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "callbacks", "optimizer_states", "lr_schedulers".