Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Nov 15, 2024
1 parent d475ff8 commit cf43485
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
14 changes: 11 additions & 3 deletions axlearn/common/checkpointer_orbax_emergency.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright © 2024 Apple Inc.

"""Checkpointing utilities using orbax.
"""Implements Orbax emergency checkpointing and provide utilities for correct store.
See also checkpointer.py for other checkpointing utilities and checkpointer_test.py for tests.
See the docstring of `OrbaxEmergencyCheckpointer` for more details.
"""

import copy
Expand Down Expand Up @@ -414,6 +414,10 @@ def get_consistent_proc_info(
class OrbaxEmergencyCheckpointer(BaseCheckpointer):
"""Checkpointer implementation that uses Orbax emergency checkpoint.
TLDR: To use the checkpointer, besides configuring it properly, it also requires
`get_consistent_proc_info` to be called and pass `inv_proc_id` and `address` as
`process_id` and `coordinator_address` to `jax.distributed.initialize`.
This checkpointer is intended for multi-slice training that uses data-parallelism across
slices. Orbax emergency checkpoint works by exploiting the following properties:
1. Tensors are replicated across data-parallel replicas.
Expand All @@ -428,7 +432,11 @@ class OrbaxEmergencyCheckpointer(BaseCheckpointer):
When a failure occurs, Orbax checkpointer will find the latest step from all local and
persistent checkpoints. If the checkpoint is local, the slice on which that checkpoint is
stored will read the checkpoint and broadcast the read values to other slices.
stored will read the checkpoint and broadcast the read values to other slices. Since local
checkpoints are scattered across different hosts, the process id, which determines the shard id
of locally stored shards, must stay the same for nodes in the healthy replicas to guarantee a
correct restore. We provide an utility function `get_consistent_proc_info` that returns the
process id and global coordinator address. They must be passed to `jax.distributed.initialize`.
However, the above procedure doesn't apply to some non-tensor states such as data iterators.
Data iterators are unique across jax processes, and thus cannot be stored on nodes. Orbax
Expand Down
5 changes: 1 addition & 4 deletions axlearn/common/checkpointer_orbax_emergency_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# Copyright © 2024 Apple Inc.

"""Tests orbax checkpointer.
See also checkpointer_test.py for common checkpointing tests.
"""
"""Tests orbax emergency checkpointer."""

# pylint: disable=protected-access

Expand Down

0 comments on commit cf43485

Please sign in to comment.