Fix `num_items_in_batch` not being an integer
The method `Trainer#get_batch_samples` is expected to return a list of batch samples together with an integer indicating the number of items in the batch. However, this was not actually the case: instead of an integer, it returned a one-element tensor. In a multi-GPU setup, this tensor can be placed on a different device than the loss tensor, causing the loss function to raise a `RuntimeError`. The problem arises from https://github.com/huggingface/transformers/blob/5d7739f15a6e50de416977fe2cc9cb516d67edda/src/transformers/trainer.py#L5139-L5144, where the outer `sum` operates over a list of tensors, so the final result is itself a tensor. To fix this, a new check (after the accelerator gathering) converts a potential tensor to an integer before `num_items_in_batch` is returned.
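A minimal sketch of the failure mode and the fix (hypothetical counts, not the actual `Trainer` code): Python's built-in `sum` over a list of tensors accumulates with tensor addition, so the result is a 0-dim tensor rather than an `int`, and `.item()` converts it back.

```python
import torch

# Hypothetical per-micro-batch label counts, each a 0-dim tensor
batch_counts = [torch.tensor(7), torch.tensor(5)]

# The outer sum accumulates tensors, so the result is a tensor, not an int
num_items_in_batch = sum(batch_counts)
assert isinstance(num_items_in_batch, torch.Tensor)

# The fix: convert a potential tensor to a plain Python integer
if torch.is_tensor(num_items_in_batch):
    num_items_in_batch = num_items_in_batch.item()

assert isinstance(num_items_in_batch, int)  # now safe to use with a loss on any device
```

Because a plain `int` carries no device, the later division of the loss by `num_items_in_batch` no longer mixes tensors from different GPUs.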