
Fix num_items_in_batch not being an integer #35115

Merged

Conversation

@xspirus (Contributor) commented Dec 6, 2024

In the method `Trainer#get_batch_samples`, the return values should be a list of batch samples and an integer indicating the number of items in the batch. However, this was not actually the case: instead of an integer, a tensor with one element was returned. In a multi-GPU setup, this tensor is placed on a different device than the loss tensor, causing the loss function to raise a `RuntimeError`.

The problem arises from this snippet in `trainer.py` (https://github.com/huggingface/transformers/blob/5d7739f15a6e50de416977fe2cc9cb516d67edda/src/transformers/trainer.py#L5139-L5144):

```python
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
    # For now we don't support object detection
    try:
        num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
    except (TypeError, AttributeError):
        pass
```

Here the outer `sum` operates over a list of tensors, which means the final result is also a tensor. To counter this issue, a new check (after the accelerator gathering) has been added to convert a potential tensor to an integer before returning `num_items_in_batch`.
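To illustrate, here is a minimal sketch of the conversion described above; the helper name is made up for illustration and this is not the exact upstream diff:

```python
import torch

def to_item_count(accelerator, num_items_in_batch):
    # Hypothetical helper: gather the per-device counts, then collapse the
    # resulting tensor to a plain Python int so the value is no longer tied
    # to any particular GPU device.
    if torch.is_tensor(num_items_in_batch):
        num_items_in_batch = accelerator.gather(num_items_in_batch).sum()
        num_items_in_batch = int(num_items_in_batch.item())
    return num_items_in_batch
```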

What does this PR do?

Fixes #35086

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @muellerzr @SunMarc

@xspirus force-pushed the fix-num-items-in-batch-tensor branch from 9739107 to 2134505 on December 6, 2024 08:45
@SunMarc (Member) left a comment

LGTM! Thanks for the details and for opening this PR @xspirus!

@SunMarc requested a review from ArthurZucker on December 6, 2024 10:49
@SunMarc (Member) commented Dec 6, 2024

If you are planning to do a patch, it would be nice to include this @ArthurZucker

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment

Got it, will include in the patch then!

@ArthurZucker merged commit dada0fd into huggingface:main on Dec 10, 2024
24 checks passed
ArthurZucker pushed a commit that referenced this pull request Dec 10, 2024
@xspirus deleted the fix-num-items-in-batch-tensor branch on December 10, 2024 08:42
@chiragjn

I am facing a related issue where the gather itself crashes when num_items_in_batch ends up being None. There is some mismatch in computing the total steps; my dataloader is getting exhausted beforehand.
I'll check why, but maybe this needs to be handled as well.

[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/axolotl/train.py", line 192, in train
[rank1]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "<string>", line 302, in _fixed_inner_training_loop
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/transformers/trainer.py", line 5149, in get_batch_samples
[rank1]:     num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
[rank1]:                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/accelerator.py", line 2458, in gather
[rank1]:     return gather(tensor)
[rank1]:            ^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 376, in wrapper
[rank1]:     return function(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 437, in gather
[rank1]:     return _gpu_gather(tensor)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 356, in _gpu_gather
[rank1]:     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 129, in recursively_apply
[rank1]:     raise TypeError(
[rank1]: TypeError: Unsupported types (<class 'NoneType'>) passed to `_gpu_gather_one`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.

@xspirus (Contributor, Author) commented Dec 17, 2024

This is probably because of https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L5151-L5154, where an exception has occurred. A check for None would probably fix the issue.
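A minimal sketch of such a guard, reusing the `accelerator.gather(...).sum().item()` call that appears in the traceback above; the helper name is made up and this is an illustration rather than the actual patch:

```python
def gather_item_count(accelerator, num_items_in_batch):
    # Hypothetical guard: skip the cross-process gather entirely when no
    # count was computed (e.g. because batch_samples ended up empty), so
    # None is never handed to accelerate's gather.
    if num_items_in_batch is None:
        return None
    return accelerator.gather(num_items_in_batch).sum().item()
```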

@chiragjn

I confirmed that in my case the batch_samples length is 0, which means next(epoch_iterator) is raising a StopIteration. That is a bit worrying and points to a miscalculation somewhere in the total steps.

While we should handle None, I would also like to get to the root cause and see if everything makes sense.
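For context, a rough, assumed simplification of how the batch-collection loop can end up in this state; this is not the actual `Trainer` code, just an illustration of the failure mode described above:

```python
def collect_batch_samples(epoch_iterator, num_batches):
    # Assumed simplification: if the epoch iterator is exhausted before the
    # expected number of batches has been pulled, batch_samples stays empty
    # and num_items_in_batch is never computed, so it remains None.
    batch_samples = []
    num_items_in_batch = None
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:
            break  # iterator ran out earlier than the step count predicted
    return batch_samples, num_items_in_batch
```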

@ArthurZucker (Collaborator)

cc @muellerzr there should be a simple fix! 🤗

@chiragjn

Opened a new issue #35387 with reproduction details

Successfully merging this pull request may close these issues.

CausalLM loss function throws runtime error in multi-gpu setup