Skip to content

Commit

Permalink
init process id
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Nov 11, 2024
1 parent 36ad92a commit b4a00eb
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 18 deletions.
201 changes: 191 additions & 10 deletions axlearn/common/checkpointer_orbax.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@
import copy
import dataclasses
import functools
import hashlib
import os
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Process
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import jax
import jax.lib
import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as oecp
import tensorflow as tf
from absl import logging
from jax._src.distributed import global_state
from jax._src.mesh import thread_resources
from jax.experimental.array_serialization import serialization

Expand Down Expand Up @@ -505,11 +509,181 @@ def _initialize_runtime_to_distributed_ids(timeout: int):
)


_PROCESS_ID_FILE_NAME: str = "process_id.txt"


def _get_previous_process_id(local_dir: str, *, unique_str: str) -> int:
"""Gets previous process id from local checkpoint directory. Returns -1 if file isn't found."""
path = os.path.join(local_dir, _get_unique_id(unique_str), _PROCESS_ID_FILE_NAME)
if not fs.exists(path):
return -1

with fs.open(path) as f:
proc_id = int(f.read())
return proc_id


def _dump_process_id(local_dir: str, *, unique_str: str, process_index: int):
"""Dumps process id to local checkpoint directory."""
local_dir = os.path.join(local_dir, _get_unique_id(unique_str))
fs.makedirs(local_dir)
process_id_file = os.path.join(local_dir, _PROCESS_ID_FILE_NAME)
with fs.open(process_id_file, "w") as f:
f.write(str(process_index))


def _get_unique_id(unique_str: str) -> str:
return hashlib.sha256(unique_str.encode(), usedforsecurity=False).hexdigest()


def _init_consistent_proc_ids(
*,
distributed_coordinator: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None,
initialization_timeout: Optional[int] = None,
trainer_dir: str,
local_ckpt_dir: str,
):
"""Reads local process id file and assigns globally consistent process ids through rank 0.
During failover, healthy nodes will read their locally stored process id file, but failed nodes
will lost their process ids. To assign ids that are free in the global id range (i.e. 0 to
num_processes - 1), we let each node report its process id (-1 if missing) to rank 0, and rank
0 will figure out suitable IDs to assign to each failed node. We reuse Jax's distributed client
to avoid writing our own coordinator.
"""
jax.distributed.initialize(
coordinator_address=distributed_coordinator,
num_processes=num_processes,
process_id=process_id,
initialization_timeout=initialization_timeout,
)
timeout_in_ms = 300 * 1000
client: jax.lib.xla_extension.DistributedRuntimeClient = global_state.client
prev_process_id = _get_previous_process_id(local_ckpt_dir, unique_str=trainer_dir)
prefix = "axlearn/id_reassign"
# Local key just needs to be unique for each process.
local_set_key = f"{prefix}/{jax.process_index()}"
# For TPU backend, only GKE is supported for now.
if jax.default_backend() == "tpu":
# For TPUs, we have the additional requirement that process ids in slice id X must be in
# range [X * num_processes_per_slice, (X + 1) * num_processes_per_slice). Therefore, we
# first identify the healthy slices' ids and then figure out the slice ids to assign to
# failed slices. Each process in the failed slice will then get id `new_slice_id *
# num_proc_per_slice + worker_id`.
client.key_value_set(
local_set_key,
f"{os.environ['MEGASCALE_SLICE_ID']}|{prev_process_id}|{os.environ['TPU_WORKER_ID']}",
)
client.wait_at_barrier("axlearn/id-reassign-gather-id", timeout_in_ms=timeout_in_ms)
if jax.process_index() == 0:
ids = client.key_value_dir_get(prefix)
parsed_ids: list[tuple[int, int, int]] = []
for _, v in ids:
data = v.split("|")
assert len(data) == 3
parsed_ids.append(tuple(int(x) for x in data))

num_proc_per_slice = len(str(os.environ.get("TPU_WORKER_HOSTNAMES", None)).split(","))
failed_slices_new_ids = {}
for slice_id, prev_proc_id, _ in parsed_ids:
if prev_proc_id == -1:
failed_slices_new_ids[slice_id] = -1

already_assigned_slice_ids = set()
for slice_id, prev_proc_id, _ in parsed_ids:
if slice_id not in failed_slices_new_ids:
already_assigned_slice_ids.add(prev_proc_id // num_proc_per_slice)

to_be_assigned_slice_ids = (
set(range(int(os.environ["MEGASCALE_NUM_SLICES"]))) - already_assigned_slice_ids
)
assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids)
for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids):
failed_slices_new_ids[k] = new_id

for (k, _), (slice_id, prev_proc_id, worker_id) in zip(ids, parsed_ids):
if (new_slice_id := failed_slices_new_ids.get(slice_id)) is not None:
client.key_value_set(
k + "/get", str(new_slice_id * num_proc_per_slice + worker_id)
)
else:
client.key_value_set(k + "/get", str(prev_proc_id))
elif jax.default_backend() == "gpu":
# For GPU backend, failed nodes are assigned with ids that are missing in the global id
# range with arbitrary order.
client.key_value_set(local_set_key, str(prev_process_id))
client.wait_at_barrier("axlearn/id-reassign-gather-id", timeout_in_ms=timeout_in_ms)
if jax.process_index() == 0:
ids = client.key_value_dir_get(prefix)
to_be_assigned_proc_ids = list(
set(range(num_processes)) - set(int(value) for _, value in ids if int(value) != -1)
)
counter = 0
for k, value in ids:
if int(value) == -1:
client.key_value_set(k + "/get", str(to_be_assigned_proc_ids[counter]))
counter += 1
else:
client.key_value_set(k + "/get", value)
assert counter == len(to_be_assigned_proc_ids)
else:
raise RuntimeError(f"Unsupported backend {jax.default_backend()}")

_dump_process_id(
local_ckpt_dir,
unique_str=trainer_dir,
process_index=int(
client.blocking_key_value_get(local_set_key + "/get", timeout_in_ms=timeout_in_ms)
),
)
# Block to avoid coordinator exiting too early.
client.wait_at_barrier("axlearn/id-reassign-finalize", timeout_in_ms=timeout_in_ms)
jax.distributed.shutdown()


def get_consistent_proc_id(
*,
distributed_coordinator: Optional[str] = None,
num_processes: Optional[int] = None,
process_id: Optional[int] = None,
initialization_timeout: Optional[int] = None,
trainer_dir: str,
local_ckpt_dir: str,
) -> int:
"""Returns process id so that process id <-> node mapping stays the same for health nodes.
This is required to preserve shard order for in-memory checkpoint recovery. For GPU training,
all healthy nodes will have their process id unchanged. For TPU, all nodes in the healthy
slices will have their process id unchanged. See docstring of `_init_consistent_proc_ids` for
implementation details.
"""
proc = Process(
target=_init_consistent_proc_ids,
kwargs=dict(
distributed_coordinator=distributed_coordinator,
num_processes=num_processes,
process_id=process_id,
initialization_timeout=initialization_timeout,
trainer_dir=trainer_dir,
local_ckpt_dir=local_ckpt_dir,
),
)
proc.start()
proc.join()
assert proc.exitcode == 0

proc_id = _get_previous_process_id(local_ckpt_dir, unique_str=trainer_dir)
assert proc_id != -1
return proc_id


class OrbaxEmergencyCheckpointer(BaseCheckpointer):
"""Checkpointer implementation that uses Orbax emergency checkpoint.
This checkpointer is intended for multi-slice training that uses data-parallelism across
slices. Orbax emergency checkpoint works by exploiting the following properties
slices. Orbax emergency checkpoint works by exploiting the following properties:
1. Tensors are replicated across data-parallel replicas.
2. When a slice fails in a multi-slice training and failover is started, only nodes
corresponding to the non-healthy slice may be restarted. Healthy nodes from healthy slices
Expand Down Expand Up @@ -560,9 +734,12 @@ class Config(BaseCheckpointer.Config):
keep_every_n_steps: If > 0, keeps at least one persistent checkpoint every N steps.
local_keep_last_n: Keep this many past ckpts in local storage (e.g. node memory).
This should almost always set to 1 to avoid OOM.
local_dir: Ckpt path for local storage. The content in this path must persist across
pod restarts unless the restart is caused by node failure. `local_dir` must be the
same for all processes or processes may hang.
local_dir: Ckpt base path for local storage. The content in this path must persist
across pod restarts unless the restart is caused by node failure. `local_dir` must
be the same for all processes or processes may hang.
unqiue_str: A string that's unique for the current run. Typically, this is set to
trainer_dir. Local checkpoint will be stored in local_dir/sha256(unique_str).
During init, all other folders in local_dir will be removed.
save_policy: Save policy for persistent checkpoints.
local_save_policy: Save policy for local checkpoints. This should be more frequent than
`save_policy`. Note that data iterator will be saved with either `save_policy` or
Expand All @@ -580,6 +757,7 @@ class Config(BaseCheckpointer.Config):
every_n_steps_policy
).set(n=10)
local_dir: str = "/host-tmp/checkpoints"
unique_str: Required[str] = REQUIRED
non_tensor_async_timeout_secs: int = 300
async_timeout_secs: int = 3600
replica_axis_index: Required[int] = REQUIRED
Expand Down Expand Up @@ -624,12 +802,15 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
if jax.process_index() == 0:
fs.makedirs(os.path.join(cfg.dir, self._NON_TENSORS_PREFIX))
fs.makedirs(os.path.join(cfg.dir, self._TENSORS_PREFIX))
fs.makedirs(cfg.local_dir)
ocp.multihost.sync_global_processes(
"axlearn-persistent-dir-create", timeout=cfg.non_tensor_async_timeout_secs
)
# Cleanup local checkpoints from different runs.
unique_id = _get_unique_id(cfg.unique_str)
for fd in fs.listdir(cfg.local_dir):
if not fd.startswith(".") and fd != unique_id:
fs.rmtree(os.path.join(cfg.local_dir, fd))
self._local_dir = os.path.join(cfg.local_dir, unique_id)
fs.makedirs(self._local_dir)
# Orbax emergency ckpt requires this function to be called prior to checkpointer
# operations.
# operations. This function also serves as a barrier.
_initialize_runtime_to_distributed_ids(cfg.non_tensor_async_timeout_secs)
ckpt_cfg: Checkpointer.Config = Checkpointer.default_config()
# TODO(hanzhi-zhou): this `keep_last_n` may not be what users expect since non-tensor
Expand Down Expand Up @@ -695,7 +876,7 @@ def _orbax_save_fn(
# For meaning of these options, refer to
# https://github.com/google/orbax/blob/95be2c021bc8cbf4badd83a053ff57b7a9f9b314/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L277
self._tensor_manager = oecp.CheckpointManager(
cfg.local_dir,
self._local_dir,
persistent_directory=os.path.join(cfg.dir, self._TENSORS_PREFIX),
global_mesh=thread_resources.env.physical_mesh,
abstract_state=self._get_abstract_state(state_with_tensors),
Expand Down
Loading

0 comments on commit b4a00eb

Please sign in to comment.