Hello,

I noticed a potential issue in the fit_epoch method, where loss.backward() is called within a torch.no_grad() block:

    self.optim.zero_grad()
    with torch.no_grad():
        loss.backward()
        ...

This likely prevents gradients from being computed, because loss.backward() should not be inside a torch.no_grad() block. The correct approach would be:

    self.optim.zero_grad()
    loss.backward()
    ...
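To check whether the torch.no_grad() wrapper actually changes the gradients, the two orderings above can be compared on a toy model. This is only a minimal sketch: the Linear layer, the random data, and the helper name grads_after_backward are assumptions made for illustration and are not part of d2l. As far as I understand, torch.no_grad() controls whether operations are recorded while the graph is being built, and in fit_epoch the forward pass in training_step has already run by the time backward() is called, so it is worth confirming the behavior on your own PyTorch install.

```python
import torch


def grads_after_backward(use_no_grad: bool):
    """Run one forward/backward pass on a toy model and return the weight gradient.

    Hypothetical helper for illustration only; not part of d2l.
    """
    torch.manual_seed(0)  # identical weights and data on every call
    model = torch.nn.Linear(3, 1)
    x = torch.randn(8, 3)
    y = torch.randn(8, 1)

    # The forward pass runs with grad mode enabled, mirroring training_step.
    loss = torch.nn.functional.mse_loss(model(x), y)

    if use_no_grad:
        # Mirrors the ordering quoted from fit_epoch.
        with torch.no_grad():
            loss.backward()
    else:
        # Mirrors the ordering proposed in this issue.
        loss.backward()

    grad = model.weight.grad
    return None if grad is None else grad.clone()


g_inside = grads_after_backward(use_no_grad=True)
g_outside = grads_after_backward(use_no_grad=False)

print("gradient populated inside no_grad:", g_inside is not None)
if g_inside is not None and g_outside is not None:
    print("gradients match:", torch.allclose(g_inside, g_outside))
```

If the two gradients match, the wrapper is not blocking gradient computation in this setup, and moving loss.backward() out of the block would be a readability change rather than a bug fix. In that case the block may exist to cover the later steps (gradient clipping and self.optim.step()), which is an assumption on my part and should be checked against the optimizer being used.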
Here is the original code:
    def fit_epoch(self):
        """Defined in :numref:`sec_linear_scratch`"""
        self.model.train()
        for batch in self.train_dataloader:
            loss = self.model.training_step(self.prepare_batch(batch))
            self.optim.zero_grad()
            with torch.no_grad():
                loss.backward()
                if self.gradient_clip_val > 0:  # To be discussed later
                    self.clip_gradients(self.gradient_clip_val, self.model)
                self.optim.step()
            self.train_batch_idx += 1
        if self.val_dataloader is None:
            return
        self.model.eval()
        for batch in self.val_dataloader:
            with torch.no_grad():
                self.model.validation_step(self.prepare_batch(batch))
            self.val_batch_idx += 1
For reference, the method in question is fit_epoch in https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py, where loss.backward() is called within a torch.no_grad() block.