Skip to content

Commit

Permalink
Add in mem ckpt
Browse files Browse the repository at this point in the history
Add coordinator address processing
  • Loading branch information
hanzhi713 committed Nov 19, 2024
1 parent e080157 commit 56f51de
Show file tree
Hide file tree
Showing 8 changed files with 1,141 additions and 11 deletions.
30 changes: 29 additions & 1 deletion axlearn/common/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,8 @@ class Config(Module.Config):
every_n_steps_policy
)

# TODO(hanzhi-zhou): deprecate all checkpoint_paths related class methods in favor of
# checkpoint_steps.
@classmethod
def checkpoint_paths(cls, base_dir: str) -> list[str]:
"""Returns complete checkpoint paths under base dir.
Expand All @@ -758,6 +760,24 @@ def latest_checkpoint_path(cls, base_dir: str) -> str:
# Note: checkpoint_paths should already filter incomplete checkpoints.
return sorted(cls.checkpoint_paths(base_dir)).pop()

@classmethod
def checkpoint_steps(cls, base_dir: str) -> list[int]:
"""Returns complete checkpoint steps under base dir.
Args:
base_dir: Path to checkpoints dir.
Returns:
A list of committed checkpoint steps. Incomplete checkpoints are dropped.
"""
raise NotImplementedError(cls)

@classmethod
def latest_checkpoint_step(cls, base_dir: str) -> int:
"""Returns the most recent (highest step count) checkpoint step under base dir."""
# Note: checkpoint_steps should already filter incomplete checkpoints.
return max(cls.checkpoint_steps(base_dir))

def __init__(self, cfg: Module.Config, *, parent: Optional[Module]):
super().__init__(cfg, parent=parent)
self._within_context = False
Expand Down Expand Up @@ -864,7 +884,11 @@ class Config(BaseCheckpointer.Config):
@classmethod
def checkpoint_paths(cls, base_dir: str) -> list[str]:
"""See `BaseCheckpointer.checkpointer_paths`."""

logging.log_first_n(
logging.WARNING,
msg="checkpoint_paths is deprecated. Use checkpoint_steps instead.",
n=1,
)
# The default checkpointer commits under "<base_dir>/<step_prefix>_<step>/index". Using a
# concurrent `exists` check for the index file can be several times faster than `glob` on
# gcs when there are many checkpoint files, even if using a "native" solution like
Expand All @@ -881,6 +905,10 @@ def checkpoint_paths(cls, base_dir: str) -> list[str]:
index_exists = pool.map(fs.exists, paths)
return [os.path.dirname(path) for path, committed in zip(paths, index_exists) if committed]

@classmethod
def checkpoint_steps(cls, base_dir: str) -> list[int]:
return [parse_step_from_dir(path) for path in cls.checkpoint_paths(base_dir)]

@classmethod
def cleanup_checkpoint(cls, ckpt_dir: str, *, sync: bool = True):
"""Removes ckpt_dir if it exists.
Expand Down
11 changes: 11 additions & 0 deletions axlearn/common/checkpointer_orbax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax
import jax.lib
import orbax.checkpoint as ocp
import tensorflow as tf
from absl import logging
Expand Down Expand Up @@ -184,8 +185,18 @@ class Config(BaseCheckpointer.Config):
@classmethod
def checkpoint_paths(cls, base_dir: str) -> List[str]:
"""See `BaseCheckpointer.checkpointer_paths`."""
logging.log_first_n(
logging.WARNING,
msg="checkpoint_paths is deprecated. Use checkpoint_steps instead.",
n=1,
)
return [str(path) for path in ocp.utils.checkpoint_steps_paths(base_dir)]

@classmethod
def checkpoint_steps(cls, base_dir) -> list[int]:
"""See `BaseCheckpointer.checkpointer_steps`."""
return ocp.utils.checkpoint_steps(base_dir)

def __init__(self, cfg: Config, *, parent: Optional[Module]):
super().__init__(cfg, parent=parent)

Expand Down
Loading

0 comments on commit 56f51de

Please sign in to comment.