Change compile for pipeline module torch.compile #6478

Merged
Changes from all commits (18 commits):
4d9e98f  Change compile for pipeline module torch.compile (NirSonnenschein, Aug 28, 2024)
441a328  add additional fix and unit test (NirSonnenschein, Dec 4, 2024)
69978c4  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 5, 2024)
0b81ff8  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 5, 2024)
40fb3c4  make test process non-daemonic (NirSonnenschein, Dec 9, 2024)
effd52f  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 9, 2024)
43ebd48  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 9, 2024)
52e0a3d  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 10, 2024)
52bbd46  Merge branch 'master' into torch_compile_micro_offset_fix (NirSonnenschein, Dec 10, 2024)
607e5ff  Merge branch 'master' into torch_compile_micro_offset_fix (NirSonnenschein, Dec 11, 2024)
02243a6  Merge branch 'master' into torch_compile_micro_offset_fix (NirSonnenschein, Dec 12, 2024)
22ad132  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 12, 2024)
ff37644  Merge branch 'master' into torch_compile_micro_offset_fix (NirSonnenschein, Dec 18, 2024)
2dd8fc9  Merge branch 'master' into torch_compile_micro_offset_fix (NirSonnenschein, Dec 19, 2024)
9bb37bf  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 19, 2024)
b40011b  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 20, 2024)
8d8b6e6  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 20, 2024)
47b7ebf  Merge branch 'master' into torch_compile_micro_offset_fix (loadams, Dec 26, 2024)
8 changes: 8 additions & 0 deletions deepspeed/runtime/pipe/module.py
@@ -662,3 +662,11 @@ def get_additional_losses(self):
         Return a dictionary of {"loss name": loss_value} or None if no additional losses.
         """
         return None
+
+    def compile(self, *args, **kwargs):
+        for idx, layer in enumerate(self.forward_funcs):
+            if isinstance(layer, nn.Module):
+                layer.compile(*args, **kwargs)
+            else:
+                new_layer = torch.compile(layer, *args, **kwargs)
+                self.forward_funcs[idx] = new_layer
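
For context, the new PipelineModule.compile walks self.forward_funcs: entries that are nn.Module instances are compiled in place via Module.compile, while plain callables are wrapped with torch.compile and written back into the list. The standalone sketch below illustrates that dispatch pattern; the list, the relu_fn helper, and the backend="eager" choice are illustrative assumptions for this sketch, not part of the PR.

# Minimal standalone sketch of the per-layer dispatch in the new compile() method
# (a plain Python list stands in for PipelineModule.forward_funcs, and
# backend="eager" is used so the example runs without a compiler toolchain).
# Assumes PyTorch >= 2.2 for nn.Module.compile().
import torch
import torch.nn as nn


def relu_fn(x):
    # a non-module "layer", as can appear in forward_funcs
    return torch.nn.functional.relu(x)


forward_funcs = [nn.Linear(4, 4), relu_fn, nn.Linear(4, 2)]

for idx, layer in enumerate(forward_funcs):
    if isinstance(layer, nn.Module):
        # Modules compile in place, so the list entry can stay as-is.
        layer.compile(backend="eager")
    else:
        # torch.compile returns a new callable, so swap it back into the list.
        forward_funcs[idx] = torch.compile(layer, backend="eager")

x = torch.randn(2, 4)
for layer in forward_funcs:
    x = layer(x)
print(x.shape)  # torch.Size([2, 2])

In the PR's unit test below, the same entry point is exercised by calling pipe_model.compile() when the use_compile parameter is true.
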
8 changes: 6 additions & 2 deletions tests/unit/pipe/test_pipe_module.py
@@ -60,9 +60,12 @@ def batch_input():

 class TestPipeModuleSequential(DistributedTest):
     world_size = 2
+    # needs to be set for torch.compile: running torch.compile with daemonic process causes an error
+    non_daemonic_procs = True
 
     @pytest.mark.parametrize("activation_checkpoints", [False, True])
-    def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
+    @pytest.mark.parametrize("use_compile", [False, True])
+    def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
         base_model = copy.deepcopy(sequential_model)
         base_input = batch_input.clone().detach()
         base_output = base_model(base_input)
@@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi

         pipe_model = copy.deepcopy(sequential_model)
         pipe_model = PipelineModule(layers=pipe_model, num_stages=2)
-
+        if (use_compile):
+            pipe_model.compile()
         # Ensure all parameters are accounted for.
         my_params = sum(p.numel() for p in pipe_model.parameters())
         total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())
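
The test class sets non_daemonic_procs = True because, as the in-diff comment notes, running torch.compile from a daemonic process fails: torch.compile's compilation workers can spawn subprocesses, and Python does not allow daemonic processes to have children. The sketch below reproduces that restriction in isolation; it is independent of DeepSpeed's test harness and uses only the standard library.

# Illustrative sketch: a daemonic process cannot spawn children, which is why
# tests exercising torch.compile are run with non-daemonic worker processes.
import multiprocessing as mp


def spawn_child():
    # Inside a daemonic parent this raises:
    # AssertionError: daemonic processes are not allowed to have children
    child = mp.Process(target=print, args=("hello from child",))
    child.start()
    child.join()


if __name__ == "__main__":
    daemon_proc = mp.Process(target=spawn_child, daemon=True)
    daemon_proc.start()
    daemon_proc.join()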