Fix num_items_in_batch not being an integer
In method `Trainer#get_batch_samples`, the return values should be a
list of batch samples and an integer indicating the number of items in
the batch. However, this was not actually the case: instead of an
integer, a tensor with one element was returned. In a multi-GPU setup,
this tensor is placed on a different device than the loss tensor,
causing the loss function to raise a `RuntimeError`.

The problem arises from
https://github.com/huggingface/transformers/blob/5d7739f15a6e50de416977fe2cc9cb516d67edda/src/transformers/trainer.py#L5139-L5144,
where the outer `sum` operates over a list of tensors, which means the
final result is also a tensor. To counter this issue, a new check
(after the accelerator gathering) has been added to convert a potential
tensor to an integer before `num_items_in_batch` is returned.
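
A minimal, runnable sketch of the root cause, using hypothetical label
tensors rather than the actual Trainer internals: Python's built-in
`sum` over zero-dim tensors yields a zero-dim tensor, and the new guard
converts it back to a plain integer.

    import torch

    # Hypothetical accumulated batches; trainer.py counts label tokens
    # that are not the ignore index (-100).
    batch_samples = [
        {"labels": torch.tensor([5, -100, 7])},
        {"labels": torch.tensor([-100, 2, 3])},
    ]

    # The outer sum() runs over zero-dim tensors, so the result is
    # itself a zero-dim tensor rather than a Python int.
    num_items_in_batch = sum(batch["labels"].ne(-100).sum() for batch in batch_samples)
    print(torch.is_tensor(num_items_in_batch))  # True

    # The added guard: convert a potential tensor to a plain integer.
    if torch.is_tensor(num_items_in_batch):
        num_items_in_batch = num_items_in_batch.item()
    print(num_items_in_batch, type(num_items_in_batch))  # 4 <class 'int'>

Note that an unconditional `.item()` would not be safe here: when
`average_tokens_across_devices` is set, the gather path shown in the
diff below already produces a plain int.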
xspirus committed Dec 6, 2024
1 parent 98e8062 commit 2134505
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/trainer.py
@@ -5145,4 +5145,8 @@ def get_batch_samples(self, epoch_iterator, num_batches):
 
         if self.args.average_tokens_across_devices:
             num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
+
+        if torch.is_tensor(num_items_in_batch):
+            num_items_in_batch = num_items_in_batch.item()
+
         return batch_samples, num_items_in_batch
