Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: forbid repeated deepspeed.initialize on training objects #6874

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

traincheck-team
Copy link

@traincheck-team traincheck-team commented Dec 16, 2024

Previously closed PR:
#6848

Partially Fixes: #6772 #6771 #6770 by forbidding repeated initialization.

What are changed:

  1. Marking 'model', 'optimizer' and 'lr_scheduler' in the arguments of deepspeed.initialize with the flag ds_is_inited = True.
  2. Marking 'engine', 'engine.optimizer' and 'engine.lr_scheduler' in the return values of deepspeed.initialize with the flag ds_is_inited = True.
  3. When calling deepspeed.initialize, raise an exception if detected ds_is_inited == True in the input model, optimizer or lr_scheduler

Expected Behavior:
Forbid repeated deepspeed.initialize invocations on model, optimizer and lr_scheduler objects.

@traincheck-team
Copy link
Author

@microsoft-github-policy-service agree

@traincheck-team
Copy link
Author

This fix still has interference with existing unit tests. Let me double check before we proceed.

…peedEngine propagates flag from the internal model
@traincheck-team traincheck-team force-pushed the fix-6848-forbid-repeated-init branch from dc81325 to d1e7777 Compare December 16, 2024 21:02
@traincheck-team
Copy link
Author

traincheck-team commented Dec 16, 2024

The unit tests in tests/unit/runtime/test_ds_initialize.py all passed.
The PR is ready for review @tjruwase.

I am not able to check other unit tests due to GPU memory constraint.

deepspeed/__init__.py Outdated Show resolved Hide resolved
deepspeed/__init__.py Outdated Show resolved Hide resolved
if _is_initialized(model):
raise ValueError(
"Model has already been initialized, please make sure to only call deepspeed.initialize on a model once.")
if optimizer is not None and _is_initialized(optimizer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that optimizer could be a Callable, not an object

optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,

raise ValueError(
"Optimizer has already been initialized, please make sure to only call deepspeed.initialize on an optimizer once."
)
if lr_scheduler is not None and _is_initialized(lr_scheduler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto for lr_scheduler

lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,

@@ -137,6 +181,10 @@ def initialize(args=None,
zero.partition_parameters.shutdown_init_context()

assert model is not None, "deepspeed.initialize requires a model"
# enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call
_assert_trainobjs_not_inited(model, optimizer, lr_scheduler)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this call should be moved into `_mark_trainobjs_initialized()

@traincheck-team
Copy link
Author

traincheck-team commented Dec 19, 2024

Thanks for the review @tjruwase.

  • I added handling for callable types for optimizer and lr_scheduler. The handling is to only mark is_ds_inited for objects instead of callables, as the callables are not stateful and reuse should be allowed.

Regarding,

I think this call should be moved into _mark_trainobjs_initialized()

I think _assert_trainobjs_not_inited should still be separated from _mark_trainobjs_initialized as _mark_trainobjs_initialized is also called on the wrapped model and optimizers before exiting from deepspeed.initialize. The wrapped models may already have is_ds_inited being True since in DeepSpeedEngine all model flags will be passed through on the wrapper.

If we still want to keep _assert_trainobjs_not_inited inside _mark_trainobjs_initialized, we can do either of the three:

  1. add a flag to _mark_trainobjs_initialized to indicate whether the *not_inited` check should be performed
  2. add/check a flag using __dict__ rather than setattr/getattr
  3. check whether inited use type information as well, i.e. for types of DeepSpeedEngine we directly return inited == True instead of checking for flags.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] [Fix-Suggested] ZeRO Stage 3 Overwrites Module ID Attribute Causing Incorrect Expert Placement on GPUs
2 participants