Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: qwen2 rotaty embed inv_freq not in gpu #35417

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

NileZhou
Copy link

@NileZhou NileZhou commented Dec 26, 2024

What does this PR do?

fix an issue when I run InternVL2.5(which contains Qwen2)

Fixes # (issue)

When I run InternVL2.5(which contains Qwen2) on my 8*A100 machine, I got this error:

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████| 16/16 [03:29<00:00, 13.12s/it]
model device: cuda:0
pixel_values device: cuda:0
Setting pad_token_id to eos_token_id:151645 for open-end generation.
Traceback (most recent call last):
File "/njfs/train-nlp/zhouyi9/projects/ImageComment/InternVL/internvl_chat/inference_test.py", line 141, in
response = model.chat(tokenizer, pixel_values, question, generation_config)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2_5-38B/modeling_internvl_chat.py", line 290, in chat
generation_output = self.generate(
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/InternVL2_5-38B/modeling_internvl_chat.py", line 339, in generate
outputs = self.language_model.generate(
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/transformers/generation/utils.py", line 2252, in generate
result = self._sample(
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/transformers/generation/utils.py", line 3251, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1165, in forward
outputs = self.model(
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 871, in forward
position_embeddings = self.rotary_emb(hidden_states, position_ids)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data0/users/software/20240312_conda/miniconda/envs/zhouyi_internvl/lib/python3.9/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 163, in forward
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)

The reason:

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
it's on the cpu, not on gpu

so I add:
inv_freq_expanded = inv_freq_expanded.to(position_ids.device)

solved this problem

Before submitting

  • Did you write any new necessary tests?
    yes

Who can review?

If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of who to tag.
Please tag fewer than 3 people.

Models:

-->

@NileZhou
Copy link
Author

Could you review this PR?
thanks!

PS: I can't resolve the problem that assign or set reviewers.

@ArthurZucker , @qubvel

@mumtozee
Copy link

Hi, I have faced the same problem while running InternVL2_5-38B-MPO splitted on 4 GPUs. However, InternVL2_5-4B-MPO has finished successfully in the same setup but on a single GPU. Both of these models have Qwen2 as their LLM backbone and run on the same code

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants