Can't load saved model after training #898

Open
meowmoewrainbow opened this issue Jul 28, 2024 · 4 comments
meowmoewrainbow commented Jul 28, 2024

Hi, I followed this notebook to train a new model and saved the checkpoint as a PyTorch Lightning checkpoint. After that, I want to load the checkpoint, but it can't be loaded.

The code:

class PolypModel(pl.LightningModule):

    def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch, encoder_name=encoder_name, in_channels=in_channels, classes=out_classes, **kwargs
        )

        # preprocessing parameters for image
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # for image segmentation dice loss could be the best first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

    def forward(self, image):
        # normalize image here
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch, stage):
        
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32:
        # the encoder and decoder are connected by `skip connections`, and the encoder usually has
        # 5 stages of downsampling by a factor of 2 (2 ^ 5 = 32). E.g. for a 65x65 image the encoder
        # feature sizes become 65 -> 33 -> 17 -> 9 -> 5 -> 3, while the decoder upsamples
        # 3 -> 6 -> 12 -> 24 -> 48, which no longer match the encoder sizes, so concatenating
        # them through the skip connections raises an error
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)
        
        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Let's compute metrics for some threshold:
        # first convert logits to probabilities, then
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image 
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over the whole dataset
        # and then compute the IoU score. The difference between dataset_iou and per_image_iou
        # in this particular case will not be large; however, for a dataset
        # with "empty" images (images without the target class) a large gap could be observed.
        # Empty images influence per_image_iou a lot and dataset_iou much less.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
        }
        
        self.log_dict(metrics, prog_bar=True)

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")            

    def training_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "train")

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def validation_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "valid")

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")  

    def test_epoch_end(self, outputs):
        return self.shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)
    
model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)


checkpoint = torch.load('/datadrive/thaonp47/WeakPolyp/deeplabv3plus-eff/epoch=04-valid_per_image_iou=0.64.ckpt')
model.load_state_dict(checkpoint)

And the error:
Traceback (most recent call last): File "test.py", line 205, in <module> model.load_state_dict(checkpoint) File "/home/azureuser/anaconda3/envs/thaonp47/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for PolypModel: Missing key(s) in state_dict: "std", "mean", "model.encoder._conv_stem.weight", "model.encoder._bn0.weight", "model.encoder._bn0.bias", "model.encoder._bn0.running_mean", "model.encoder._bn0.running_var", "model.encoder._blocks.0._depthwise_conv.weight", "model.encoder._blocks.0._bn1.weight", "model.encoder._blocks.0._bn1.bias", "model.encoder._blocks.0._bn1.running_mean", "model.encoder._blocks.0._bn1.running_var", "model.encoder._blocks.0._se_reduce.weight", "model.encoder._blocks.0._se_reduce.bias", "model.encoder._blocks.0._se_expand.weight", "model.encoder._blocks.0._se_expand.bias", "model.encoder._blocks.0._project_conv.weight", "model.encoder._blocks.0._bn2.weight", "model.encoder._blocks.0._bn2.bias", "model.encoder._blocks.0._bn2.running_mean", "model.encoder._blocks.0._bn2.running_var", "model.encoder._blocks.1._expand_conv.weight", "model.encoder._blocks.1._bn0.weight", "model.encoder._blocks.1._bn0.bias", "model.encoder._blocks.1._bn0.running_mean", "model.encoder._blocks.1._bn0.running_var", "model.encoder._blocks.1._depthwise_conv.weight", "model.encoder._blocks.1._bn1.weight", "model.encoder._blocks.1._bn1.bias", "model.encoder._blocks.1._bn1.running_mean", "model.encoder._blocks.1._bn1.running_var", "model.encoder._blocks.1._se_reduce.weight", "model.encoder._blocks.1._se_reduce.bias", "model.encoder._blocks.1._se_expand.weight", "model.encoder._blocks.1._se_expand.bias", "model.encoder._blocks.1._project_conv.weight", "model.encoder._blocks.1._bn2.weight", "model.encoder._blocks.1._bn2.bias", "model.encoder._blocks.1._bn2.running_mean", "model.encoder._blocks.1._bn2.running_var", "model.encoder._blocks.2._expand_conv.weight", "model.encoder._blocks.2._bn0.weight", "model.encoder._blocks.2._bn0.bias", "model.encoder._blocks.2._bn0.running_mean", "model.encoder._blocks.2._bn0.running_var", "model.encoder._blocks.2._depthwise_conv.weight", "model.encoder._blocks.2._bn1.weight", "model.encoder._blocks.2._bn1.bias", "model.encoder._blocks.2._bn1.running_mean", "model.encoder._blocks.2._bn1.running_var", "model.encoder._blocks.2._se_reduce.weight", "model.encoder._blocks.2._se_reduce.bias", "model.encoder._blocks.2._se_expand.weight", "model.encoder._blocks.2._se_expand.bias", "model.encoder._blocks.2._project_conv.weight", "model.encoder._blocks.2._bn2.weight", "model.encoder._blocks.2._bn2.bias", "model.encoder._blocks.2._bn2.running_mean", "model.encoder._blocks.2._bn2.running_var", "model.encoder._blocks.3._expand_conv.weight", "model.encoder._blocks.3._bn0.weight", "model.encoder._blocks.3._bn0.bias", "model.encoder._blocks.3._bn0.running_mean", "model.encoder._blocks.3._bn0.running_var", "model.encoder._blocks.3._depthwise_conv.weight", "model.encoder._blocks.3._bn1.weight", "model.encoder._blocks.3._bn1.bias", "model.encoder._blocks.3._bn1.running_mean", "model.encoder._blocks.3._bn1.running_var", "model.encoder._blocks.3._se_reduce.weight", "model.encoder._blocks.3._se_reduce.bias", "model.encoder._blocks.3._se_expand.weight", "model.encoder._blocks.3._se_expand.bias", "model.encoder._blocks.3._project_conv.weight", "model.encoder._blocks.3._bn2.weight", "model.encoder._blocks.3._bn2.bias", 
"model.encoder._blocks.3._bn2.running_mean", "model.encoder._blocks.3._bn2.running_var", "model.encoder._blocks.4._expand_conv.weight", "model.encoder._blocks.4._bn0.weight", "model.encoder._blocks.4._bn0.bias", "model.encoder._blocks.4._bn0.running_mean", "model.encoder._blocks.4._bn0.running_var", "model.encoder._blocks.4._depthwise_conv.weight", "model.encoder._blocks.4._bn1.weight", "model.encoder._blocks.4._bn1.bias", "model.encoder._blocks.4._bn1.running_mean", "model.encoder._blocks.4._bn1.running_var", "model.encoder._blocks.4._se_reduce.weight", "model.encoder._blocks.4._se_reduce.bias", "model.encoder._blocks.4._se_expand.weight", "model.encoder._blocks.4._se_expand.bias", "model.encoder._blocks.4._project_conv.weight", "model.encoder._blocks.4._bn2.weight", "model.encoder._blocks.4._bn2.bias", "model.encoder._blocks.4._bn2.running_mean", "model.encoder._blocks.4._bn2.running_var", "model.encoder._blocks.5._expand_conv.weight", "model.encoder._blocks.5._bn0.weight", "model.encoder._blocks.5._bn0.bias", "model.encoder._blocks.5._bn0.running_mean", "model.encoder._blocks.5._bn0.running_var", "model.encoder._blocks.5._depthwise_conv.weight", "model.encoder._blocks.5._bn1.weight", "model.encoder._blocks.5._bn1.bias", "model.encoder._blocks.5._bn1.running_mean", "model.encoder._blocks.5._bn1.running_var", "model.encoder._blocks.5._se_reduce.weight", "model.encoder._blocks.5._se_reduce.bias", "model.encoder._blocks.5._se_expand.weight", "model.encoder._blocks.5._se_expand.bias", "model.encoder._blocks.5._project_conv.weight", "model.encoder._blocks.5._bn2.weight", "model.encoder._blocks.5._bn2.bias", "model.encoder._blocks.5._bn2.running_mean", "model.encoder._blocks.5._bn2.running_var", "model.encoder._blocks.6._expand_conv.weight", "model.encoder._blocks.6._bn0.weight", "model.encoder._blocks.6._bn0.bias", "model.encoder._blocks.6._bn0.running_mean", "model.encoder._blocks.6._bn0.running_var", "model.encoder._blocks.6._depthwise_conv.weight", "model.encoder._blocks.6._bn1.weight", "model.encoder._blocks.6._bn1.bias", "model.encoder._blocks.6._bn1.running_mean", "model.encoder._blocks.6._bn1.running_var", "model.encoder._blocks.6._se_reduce.weight", "model.encoder._blocks.6._se_reduce.bias", "model.encoder._blocks.6._se_expand.weight", "model.encoder._blocks.6._se_expand.bias", "model.encoder._blocks.6._project_conv.weight", "model.encoder._blocks.6._bn2.weight", "model.encoder._blocks.6._bn2.bias", "model.encoder._blocks.6._bn2.running_mean", "model.encoder._blocks.6._bn2.running_var", "model.encoder._blocks.7._expand_conv.weight", "model.encoder._blocks.7._bn0.weight", "model.encoder._blocks.7._bn0.bias", "model.encoder._blocks.7._bn0.running_mean", "model.encoder._blocks.7._bn0.running_var", "model.encoder._blocks.7._depthwise_conv.weight", "model.encoder._blocks.7._bn1.weight", "model.encoder._blocks.7._bn1.bias", "model.encoder._blocks.7._bn1.running_mean", "model.encoder._blocks.7._bn1.running_var", "model.encoder._blocks.7._se_reduce.weight", "model.encoder._blocks.7._se_reduce.bias", "model.encoder._blocks.7._se_expand.weight", "model.encoder._blocks.7._se_expand.bias", "model.encoder._blocks.7._project_conv.weight", "model.encoder._blocks.7._bn2.weight", "model.encoder._blocks.7._bn2.bias", "model.encoder._blocks.7._bn2.running_mean", "model.encoder._blocks.7._bn2.running_var", "model.encoder._blocks.8._expand_conv.weight", "model.encoder._blocks.8._bn0.weight", "model.encoder._blocks.8._bn0.bias", "model.encoder._blocks.8._bn0.running_mean", 
"model.encoder._blocks.8._bn0.running_var", "model.encoder._blocks.8._depthwise_conv.weight", "model.encoder._blocks.8._bn1.weight", "model.encoder._blocks.8._bn1.bias", "model.encoder._blocks.8._bn1.running_mean", "model.encoder._blocks.8._bn1.running_var", "model.encoder._blocks.8._se_reduce.weight", "model.encoder._blocks.8._se_reduce.bias", "model.encoder._blocks.8._se_expand.weight", "model.encoder._blocks.8._se_expand.bias", "model.encoder._blocks.8._project_conv.weight", "model.encoder._blocks.8._bn2.weight", "model.encoder._blocks.8._bn2.bias", "model.encoder._blocks.8._bn2.running_mean", "model.encoder._blocks.8._bn2.running_var", "model.encoder._blocks.9._expand_conv.weight", "model.encoder._blocks.9._bn0.weight", "model.encoder._blocks.9._bn0.bias", "model.encoder._blocks.9._bn0.running_mean", "model.encoder._blocks.9._bn0.running_var", "model.encoder._blocks.9._depthwise_conv.weight", "model.encoder._blocks.9._bn1.weight", "model.encoder._blocks.9._bn1.bias", "model.encoder._blocks.9._bn1.running_mean", "model.encoder._blocks.9._bn1.running_var", "model.encoder._blocks.9._se_reduce.weight", "model.encoder._blocks.9._se_reduce.bias", "model.encoder._blocks.9._se_expand.weight", "model.encoder._blocks.9._se_expand.bias", "model.encoder._blocks.9._project_conv.weight", "model.encoder._blocks.9._bn2.weight", "model.encoder._blocks.9._bn2.bias", "model.encoder._blocks.9._bn2.running_mean", "model.encoder._blocks.9._bn2.running_var", "model.encoder._blocks.10._expand_conv.weight", "model.encoder._blocks.10._bn0.weight", "model.encoder._blocks.10._bn0.bias", "model.encoder._blocks.10._bn0.running_mean", "model.encoder._blocks.10._bn0.running_var", "model.encoder._blocks.10._depthwise_conv.weight", "model.encoder._blocks.10._bn1.weight", "model.encoder._blocks.10._bn1.bias", "model.encoder._blocks.10._bn1.running_mean", "model.encoder._blocks.10._bn1.running_var", "model.encoder._blocks.10._se_reduce.weight", "model.encoder._blocks.10._se_reduce.bias", "model.encoder._blocks.10._se_expand.weight", "model.encoder._blocks.10._se_expand.bias", "model.encoder._blocks.10._project_conv.weight", "model.encoder._blocks.10._bn2.weight", "model.encoder._blocks.10._bn2.bias", "model.encoder._blocks.10._bn2.running_mean", "model.encoder._blocks.10._bn2.running_var", "model.encoder._blocks.11._expand_conv.weight", "model.encoder._blocks.11._bn0.weight", "model.encoder._blocks.11._bn0.bias", "model.encoder._blocks.11._bn0.running_mean", "model.encoder._blocks.11._bn0.running_var", "model.encoder._blocks.11._depthwise_conv.weight", "model.encoder._blocks.11._bn1.weight", "model.encoder._blocks.11._bn1.bias", "model.encoder._blocks.11._bn1.running_mean", "model.encoder._blocks.11._bn1.running_var", "model.encoder._blocks.11._se_reduce.weight", "model.encoder._blocks.11._se_reduce.bias", "model.encoder._blocks.11._se_expand.weight", "model.encoder._blocks.11._se_expand.bias", "model.encoder._blocks.11._project_conv.weight", "model.encoder._blocks.11._bn2.weight", "model.encoder._blocks.11._bn2.bias", "model.encoder._blocks.11._bn2.running_mean", "model.encoder._blocks.11._bn2.running_var", "model.encoder._blocks.12._expand_conv.weight", "model.encoder._blocks.12._bn0.weight", "model.encoder._blocks.12._bn0.bias", "model.encoder._blocks.12._bn0.running_mean", "model.encoder._blocks.12._bn0.running_var", "model.encoder._blocks.12._depthwise_conv.weight", "model.encoder._blocks.12._bn1.weight", "model.encoder._blocks.12._bn1.bias", "model.encoder._blocks.12._bn1.running_mean", 
"model.encoder._blocks.12._bn1.running_var", "model.encoder._blocks.12._se_reduce.weight", "model.encoder._blocks.12._se_reduce.bias", "model.encoder._blocks.12._se_expand.weight", "model.encoder._blocks.12._se_expand.bias", "model.encoder._blocks.12._project_conv.weight", "model.encoder._blocks.12._bn2.weight", "model.encoder._blocks.12._bn2.bias", "model.encoder._blocks.12._bn2.running_mean", "model.encoder._blocks.12._bn2.running_var", "model.encoder._blocks.13._expand_conv.weight", "model.encoder._blocks.13._bn0.weight", "model.encoder._blocks.13._bn0.bias", "model.encoder._blocks.13._bn0.running_mean", "model.encoder._blocks.13._bn0.running_var", "model.encoder._blocks.13._depthwise_conv.weight", "model.encoder._blocks.13._bn1.weight", "model.encoder._blocks.13._bn1.bias", "model.encoder._blocks.13._bn1.running_mean", "model.encoder._blocks.13._bn1.running_var", "model.encoder._blocks.13._se_reduce.weight", "model.encoder._blocks.13._se_reduce.bias", "model.encoder._blocks.13._se_expand.weight", "model.encoder._blocks.13._se_expand.bias", "model.encoder._blocks.13._project_conv.weight", "model.encoder._blocks.13._bn2.weight", "model.encoder._blocks.13._bn2.bias", "model.encoder._blocks.13._bn2.running_mean", "model.encoder._blocks.13._bn2.running_var", "model.encoder._blocks.14._expand_conv.weight", "model.encoder._blocks.14._bn0.weight", "model.encoder._blocks.14._bn0.bias", "model.encoder._blocks.14._bn0.running_mean", "model.encoder._blocks.14._bn0.running_var", "model.encoder._blocks.14._depthwise_conv.weight", "model.encoder._blocks.14._bn1.weight", "model.encoder._blocks.14._bn1.bias", "model.encoder._blocks.14._bn1.running_mean", "model.encoder._blocks.14._bn1.running_var", "model.encoder._blocks.14._se_reduce.weight", "model.encoder._blocks.14._se_reduce.bias", "model.encoder._blocks.14._se_expand.weight", "model.encoder._blocks.14._se_expand.bias", "model.encoder._blocks.14._project_conv.weight", "model.encoder._blocks.14._bn2.weight", "model.encoder._blocks.14._bn2.bias", "model.encoder._blocks.14._bn2.running_mean", "model.encoder._blocks.14._bn2.running_var", "model.encoder._blocks.15._expand_conv.weight", "model.encoder._blocks.15._bn0.weight", "model.encoder._blocks.15._bn0.bias", "model.encoder._blocks.15._bn0.running_mean", "model.encoder._blocks.15._bn0.running_var", "model.encoder._blocks.15._depthwise_conv.weight", "model.encoder._blocks.15._bn1.weight", "model.encoder._blocks.15._bn1.bias", "model.encoder._blocks.15._bn1.running_mean", "model.encoder._blocks.15._bn1.running_var", "model.encoder._blocks.15._se_reduce.weight", "model.encoder._blocks.15._se_reduce.bias", "model.encoder._blocks.15._se_expand.weight", "model.encoder._blocks.15._se_expand.bias", "model.encoder._blocks.15._project_conv.weight", "model.encoder._blocks.15._bn2.weight", "model.encoder._blocks.15._bn2.bias", "model.encoder._blocks.15._bn2.running_mean", "model.encoder._blocks.15._bn2.running_var", "model.encoder._conv_head.weight", "model.encoder._bn1.weight", "model.encoder._bn1.bias", "model.encoder._bn1.running_mean", "model.encoder._bn1.running_var", "model.decoder.aspp.0.convs.0.0.weight", "model.decoder.aspp.0.convs.0.1.weight", "model.decoder.aspp.0.convs.0.1.bias", "model.decoder.aspp.0.convs.0.1.running_mean", "model.decoder.aspp.0.convs.0.1.running_var", "model.decoder.aspp.0.convs.1.0.0.weight", "model.decoder.aspp.0.convs.1.0.1.weight", "model.decoder.aspp.0.convs.1.1.weight", "model.decoder.aspp.0.convs.1.1.bias", "model.decoder.aspp.0.convs.1.1.running_mean", 
"model.decoder.aspp.0.convs.1.1.running_var", "model.decoder.aspp.0.convs.2.0.0.weight", "model.decoder.aspp.0.convs.2.0.1.weight", "model.decoder.aspp.0.convs.2.1.weight", "model.decoder.aspp.0.convs.2.1.bias", "model.decoder.aspp.0.convs.2.1.running_mean", "model.decoder.aspp.0.convs.2.1.running_var", "model.decoder.aspp.0.convs.3.0.0.weight", "model.decoder.aspp.0.convs.3.0.1.weight", "model.decoder.aspp.0.convs.3.1.weight", "model.decoder.aspp.0.convs.3.1.bias", "model.decoder.aspp.0.convs.3.1.running_mean", "model.decoder.aspp.0.convs.3.1.running_var", "model.decoder.aspp.0.convs.4.1.weight", "model.decoder.aspp.0.convs.4.2.weight", "model.decoder.aspp.0.convs.4.2.bias", "model.decoder.aspp.0.convs.4.2.running_mean", "model.decoder.aspp.0.convs.4.2.running_var", "model.decoder.aspp.0.project.0.weight", "model.decoder.aspp.0.project.1.weight", "model.decoder.aspp.0.project.1.bias", "model.decoder.aspp.0.project.1.running_mean", "model.decoder.aspp.0.project.1.running_var", "model.decoder.aspp.1.0.weight", "model.decoder.aspp.1.1.weight", "model.decoder.aspp.2.weight", "model.decoder.aspp.2.bias", "model.decoder.aspp.2.running_mean", "model.decoder.aspp.2.running_var", "model.decoder.block1.0.weight", "model.decoder.block1.1.weight", "model.decoder.block1.1.bias", "model.decoder.block1.1.running_mean", "model.decoder.block1.1.running_var", "model.decoder.block2.0.0.weight", "model.decoder.block2.0.1.weight", "model.decoder.block2.1.weight", "model.decoder.block2.1.bias", "model.decoder.block2.1.running_mean", "model.decoder.block2.1.running_var", "model.segmentation_head.0.weight", "model.segmentation_head.0.bias". Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "callbacks", "optimizer_states", "lr_schedulers".


qubvel commented Jul 28, 2024

Hi, try selecting the state dict from the checkpoint to load:
model.load_state_dict(checkpoint["state_dict"])
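
For context, a minimal sketch of that suggestion applied to the model above (the checkpoint path is a placeholder and map_location="cpu" is an assumption, not from the thread):

# Assumes the PolypModel class (and its imports: torch, pytorch_lightning, smp) defined above is in scope
import torch

# Rebuild the model with the same constructor arguments used during training
model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)

# A Lightning .ckpt is a dict holding training metadata ("epoch", "optimizer_states", ...);
# the weights themselves live under the "state_dict" key
checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

Alternatively, since the hyperparameters were not saved with save_hyperparameters(), Lightning's PolypModel.load_from_checkpoint("path/to/checkpoint.ckpt", arch="DeepLabV3Plus", encoder_name="efficientnet-b0", in_channels=3, out_classes=1) should rebuild and load the model in one step.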

@YUHSINCHENG1230

I also tried model.load_state_dict(checkpoint["state_dict"]) but I still get the same problem.
What should I do next?


qubvel commented Sep 2, 2024

It would be easier to understand if you compare the keys of the loaded checkpoint's state dict with the keys of the model's state dict; that way you can figure out what is going wrong with loading.
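
As an illustration of that comparison, a hypothetical diagnostic snippet (not posted in the thread) that diffs the two key sets directly:

import torch

model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)
checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")

ckpt_keys = set(checkpoint["state_dict"].keys())
model_keys = set(model.state_dict().keys())

# Keys the model expects but the checkpoint does not contain
print("missing from checkpoint:", sorted(model_keys - ckpt_keys))
# Keys present in the checkpoint that the model does not expect
print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys))

If every checkpoint key differs from the corresponding model key only by a prefix (for example an extra "model." added by a different wrapper class), stripping that prefix from the checkpoint keys before calling load_state_dict is usually enough.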

@CvBokchoy

How did you solve it in the end? If you solved it, please let me know, thank you. Also, please provide a method for saving the model as .pt or .ckpt and converting it to ONNX.
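
The thread does not answer the export question, but a minimal ONNX-export sketch under the same assumptions as above (the 256x256 input size and opset 17 are arbitrary choices, not from the thread):

import torch

model = PolypModel("DeepLabV3Plus", "efficientnet-b0", in_channels=3, out_classes=1)
checkpoint = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Dummy input; height and width must be divisible by 32 for this architecture
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(
    model,
    dummy,
    "polyp_model.onnx",
    input_names=["image"],
    output_names=["mask"],
    opset_version=17,
    dynamic_axes={"image": {0: "batch"}, "mask": {0: "batch"}},
)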
