Skip to content

Commit

Permalink
add additional fix and unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
NirSonnenschein committed Dec 5, 2024
1 parent 4d9e98f commit 441a328
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
7 changes: 5 additions & 2 deletions deepspeed/runtime/pipe/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,5 +665,8 @@ def get_additional_losses(self):

def compile(self, *args, **kwargs):
for idx, layer in enumerate(self.forward_funcs):
new_layer = torch.compile(layer, *args, **kwargs)
self.forward_funcs[idx] = new_layer
if isinstance(layer, nn.Module):
layer.compile(*args, **kwargs)
else:
new_layer = torch.compile(layer, *args, **kwargs)
self.forward_funcs[idx] = new_layer
6 changes: 4 additions & 2 deletions tests/unit/pipe/test_pipe_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class TestPipeModuleSequential(DistributedTest):
world_size = 2

@pytest.mark.parametrize("activation_checkpoints", [False, True])
def test(self, sequential_model, simple_config, batch_input, activation_checkpoints):
@pytest.mark.parametrize("use_compile", [False, True])
def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile):
base_model = copy.deepcopy(sequential_model)
base_input = batch_input.clone().detach()
base_output = base_model(base_input)
Expand All @@ -71,7 +72,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi

pipe_model = copy.deepcopy(sequential_model)
pipe_model = PipelineModule(layers=pipe_model, num_stages=2)

if (use_compile):
pipe_model.compile()
# Ensure all parameters are accounted for.
my_params = sum(p.numel() for p in pipe_model.parameters())
total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())
Expand Down

0 comments on commit 441a328

Please sign in to comment.