
Incorrect Use of torch.no_grad() in fit_epoch Method in d2l/torch.py::Trainer::fit_epoch #2573

Open
caydenwei opened this issue Dec 19, 2023 · 3 comments

caydenwei commented Dec 19, 2023

Hello,

I noticed a potential issue in the fit_epoch method in https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py, where loss.backward() is called within a torch.no_grad() block:

self.optim.zero_grad()
with torch.no_grad():
    loss.backward()
    ...

This usage likely prevents the calculation of gradients, as loss.backward() should not be inside a torch.no_grad() block. The correct approach would be:

self.optim.zero_grad()
loss.backward()
...

Here is the original code:

    def fit_epoch(self):
        """Defined in :numref:`sec_linear_scratch`"""
        self.model.train()
        for batch in self.train_dataloader:
            loss = self.model.training_step(self.prepare_batch(batch))
            self.optim.zero_grad()
            with torch.no_grad():
                loss.backward()
                if self.gradient_clip_val > 0:  # To be discussed later
                    self.clip_gradients(self.gradient_clip_val, self.model)
                self.optim.step()
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        self.model.eval()
        for batch in self.val_dataloader:
            with torch.no_grad():
                self.model.validation_step(self.prepare_batch(batch))
            self.val_batch_idx += 1
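
For reference, the conventional PyTorch training-step order keeps loss.backward() outside any no_grad block. The following is only a minimal, self-contained sketch with a toy model and data (nothing here comes from d2l) to show that pattern:

    import torch

    # Toy model, optimizer, and data, purely to illustrate the usual step order.
    model = torch.nn.Linear(2, 1)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    X, y = torch.randn(8, 2), torch.randn(8, 1)

    loss = torch.nn.functional.mse_loss(model(X), y)
    optim.zero_grad()   # clear gradients from the previous step
    loss.backward()     # compute gradients with autograd enabled
    optim.step()        # apply the parameter update
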
@Wu-Zongyu

I think it should be

self.optim.zero_grad()
loss.backward()
with torch.no_grad():
    self.optim.step()
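
Expanded into the surrounding loop, and assuming the same Trainer attributes that d2l's fit_epoch already uses, that suggestion would read roughly as the sketch below (illustrative only, not the library's code):

    def fit_epoch(self):
        self.model.train()
        for batch in self.train_dataloader:
            loss = self.model.training_step(self.prepare_batch(batch))
            self.optim.zero_grad()
            loss.backward()                    # gradients computed with autograd enabled
            with torch.no_grad():              # clipping and the update need no graph
                if self.gradient_clip_val > 0:
                    self.clip_gradients(self.gradient_clip_val, self.model)
                self.optim.step()
            self.train_batch_idx += 1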

@caydenwei (Author)

Quoting @Wu-Zongyu:

I think it should be

self.optim.zero_grad()
loss.backward()
with torch.no_grad():
    self.optim.step()

Apologies for not being clear earlier. I'm uncertain about the correctness of a specific part of the code found at https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py, namely the fit_epoch method quoted in full in the issue description above.

@Brianwind

Same question. I'm confused about why the code still works in the examples (LeNet, etc.).
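
Whether backward() is actually affected by the surrounding no_grad block can be checked with a small probe, independent of the d2l code. This is just a sketch of the experiment, with the graph built outside no_grad as in fit_epoch (where training_step runs before the block):

    import torch

    # Probe: the forward pass (graph construction) happens outside no_grad,
    # backward() is then called inside it.
    w = torch.ones(3, requires_grad=True)
    loss = (2 * w).sum()

    with torch.no_grad():
        loss.backward()

    print(w.grad)  # expected: tensor([2., 2., 2.])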
