
ModelCheckpoint callback is triggered by mistake after every validation stage when using manual optimization #20459

Open

silverbulletmdc opened this issue Nov 29, 2024 · 4 comments

Assignees: lantiga
Labels: design (Includes a design discussion), working as intended (Working as intended)

Comments


silverbulletmdc commented Nov 29, 2024

Bug description

I set ModelCheckpoint's every_n_epochs param to 1 and the trainer's val_check_interval to 200, with 1000 iterations per epoch. Checkpoints should not be saved after each mid-epoch validation run, but they are.
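
The original report does not include code; the following is a minimal sketch of the described configuration, using Lightning's BoringModel/RandomDataset demo classes as stand-ins (the 1000-batch epoch is assumed from the description):

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from torch.utils.data import DataLoader

# Reported configuration: every_n_epochs=1 should mean "one checkpoint per
# epoch", yet with val_check_interval=200 a checkpoint appears after every
# mid-epoch validation run.
checkpoint = ModelCheckpoint(every_n_epochs=1)
trainer = L.Trainer(
    max_epochs=3,            # kept short for the sketch
    val_check_interval=200,  # 1000 batches per epoch -> validation runs 5x per epoch
    callbacks=[checkpoint],
)
model = BoringModel()
train_dl = DataLoader(RandomDataset(32, 1000))  # 1000 samples, batch_size=1
val_dl = DataLoader(RandomDataset(32, 64))
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)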

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.4.0):
#- PyTorch Version (e.g., 2.4):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @tchaton @justusschock @awaelchli @Borda

silverbulletmdc added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Nov 29, 2024
lantiga removed the needs triage (Waiting to be triaged by maintainers) label on Dec 4, 2024
lantiga self-assigned this on Dec 4, 2024

lantiga (Collaborator) commented Dec 4, 2024

Thanks for reporting this. Is that the case with the latest master as well?

lantiga (Collaborator) commented Dec 11, 2024

OK, I verified and can reproduce it:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.demos import Transformer, WikiText2
from lightning.pytorch.callbacks import ModelCheckpoint


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.model = Transformer(
            vocab_size=self.vocab_size,
            nlayers=2,
            nhid=4096,
            ninp=1024,
            nhead=8,
        )

    def training_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)



def train():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)
    val_dataloader = DataLoader(dataset, num_workers=8, batch_size=1)

    model = LanguageModel(vocab_size=dataset.vocab_size)

    # Intent: keep all checkpoints (save_top_k=-1), but only save on every 2nd epoch
    model_checkpoint = ModelCheckpoint(save_top_k=-1, every_n_epochs=2, save_last=False)

    trainer = L.Trainer(
        max_steps=100,  # 10 epochs of 10 batches each
        precision="bf16-true",
        limit_train_batches=10,
        limit_val_batches=2,
        callbacks=model_checkpoint,
        val_check_interval=5,  # validation runs mid-epoch (step 5 of 10) and at epoch end
    )
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)


if __name__ == "__main__":
    torch.set_float32_matmul_precision("high")
    train()

lantiga (Collaborator) commented Dec 11, 2024

BTW if you change

model_checkpoint = ModelCheckpoint(save_top_k=-1, every_n_epochs=2, save_last=False)

to

model_checkpoint = ModelCheckpoint(save_top_k=-1, every_n_epochs=2, save_on_train_epoch_end=True, save_last=False)

then you'll get checkpoints saved at the right interval.

However, if you keep the settings as in the snippet above, you indeed get something like this:

'epoch=1-step=15.ckpt'
'epoch=1-step=20.ckpt'
'epoch=3-step=35.ckpt'
'epoch=3-step=40.ckpt'
'epoch=5-step=55.ckpt'
'epoch=5-step=60.ckpt'
'epoch=7-step=75.ckpt'
'epoch=7-step=80.ckpt'
'epoch=9-step=95.ckpt'
'epoch=9-step=100.ckpt'

i.e., you get a checkpoint at every validation run, but only in the training epochs where a save is supposed to happen (every second epoch here), and not in the others. This is consistent with the code in ModelCheckpoint, but it's indeed a bit surprising.
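
For context, this is roughly how the callback's validation-end hook gates saving (a simplified paraphrase of Lightning 2.x's ModelCheckpoint, not verbatim source):

# Simplified paraphrase of ModelCheckpoint.on_validation_end (Lightning 2.x).
def on_validation_end(self, trainer, pl_module):
    if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
        monitor_candidates = self._monitor_candidates(trainer)
        # every_n_epochs only selects WHICH epochs may save; it does not limit
        # saving to once per epoch, so every validation run inside an eligible
        # epoch produces a checkpoint.
        if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
            self._save_topk_checkpoint(trainer, monitor_candidates)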

lantiga (Collaborator) commented Dec 11, 2024

So the resolution for this is to use

ModelCheckpoint(..., save_on_train_epoch_end=True)

to avoid saving at the end of validation.

In the future we could introduce a save_on_validation_end argument to make the behavior more explicit; it would clarify cases like this one. Wdyt?
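
A hypothetical sketch of how that might look (save_on_validation_end does not exist in Lightning today; the name and default here are purely illustrative):

# Hypothetical API as proposed above -- NOT part of Lightning today.
model_checkpoint = ModelCheckpoint(
    save_top_k=-1,
    every_n_epochs=2,
    save_on_validation_end=False,  # hypothetical: opt out of validation-end saves
    save_last=False,
)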

lantiga added the working as intended (Working as intended) and design (Includes a design discussion) labels and removed the bug (Something isn't working) and ver: 2.4.x labels on Dec 11, 2024