[train v2] Populate more deprecation warnings #49455

Open

wants to merge 45 commits into base: master

Commits (45)
0f7f22e
remove ray.train.report usage within tune code
justinvyu Dec 17, 2024
3f44c64
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu Dec 17, 2024
a6dcedb
add ray.tune accessors for Train v1 APIs
justinvyu Dec 17, 2024
ea480bd
add tune context
justinvyu Dec 17, 2024
2adb15b
add deprecation messages for train context methods in tune context
justinvyu Dec 17, 2024
b61fbab
add deprecation message if ray.train.get_context is used in ray tune …
justinvyu Dec 17, 2024
58bcfc6
remove ray.air._internal._get_session for sanity
justinvyu Dec 17, 2024
59da5b1
warn if ray.train.report or ray.train.get_checkpoint are used
justinvyu Dec 17, 2024
554f2ee
lint
justinvyu Dec 17, 2024
a6971e1
fix lint
justinvyu Dec 17, 2024
4b37af4
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu Dec 18, 2024
e919c6a
deprecation warnings gated for now
justinvyu Dec 19, 2024
c2cb318
add (gated) warnings on train context
justinvyu Dec 19, 2024
0577e56
add pass through Checkpoint, config classes for warnings
justinvyu Dec 19, 2024
47c799f
fix lint
justinvyu Dec 19, 2024
2fd2d78
mark places to add deprecation warnings
justinvyu Dec 19, 2024
6caef4a
move stuff to train utils file
justinvyu Dec 19, 2024
febc5df
log deprecation warnings for Checkpoint / incorrect configs
justinvyu Dec 19, 2024
6a5d244
fix lint
justinvyu Dec 19, 2024
1e980bc
Add failure config / checkpoint config deprecation logs
justinvyu Dec 19, 2024
7256ef2
propagate the deprecation message gate env var
justinvyu Dec 19, 2024
f86a752
fix lint
justinvyu Dec 19, 2024
9728a3d
add tests checking for warnings
justinvyu Dec 19, 2024
4ecfbf2
add to BUILD
justinvyu Dec 19, 2024
761f327
fix lint
justinvyu Dec 19, 2024
8a2a50b
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu Dec 20, 2024
cc7354b
add tune trainable fn apis to docs
justinvyu Dec 20, 2024
cc8c62c
add tune configs
justinvyu Dec 20, 2024
a4a5c04
remove need for api stability tag on internal util
justinvyu Dec 20, 2024
790a78c
maybe fix the checkpoint double reference lint check??
justinvyu Dec 20, 2024
e4a0e6d
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu Dec 23, 2024
972faef
maybe fix doc
justinvyu Dec 23, 2024
cfb927d
Merge branch 'master' of https://github.com/ray-project/ray into tune…
justinvyu Dec 26, 2024
9f46913
add get_metadata deprecation message
justinvyu Dec 27, 2024
66da94c
_in_tune_session helper
justinvyu Dec 27, 2024
b4bb212
revert backend executor changes
justinvyu Dec 27, 2024
c4fa138
improve deprecation messages for restore
justinvyu Dec 27, 2024
d7b87f2
add some limited compatibility for metrics_dataframe
justinvyu Dec 27, 2024
2c01955
remove resume_from_checkpoint / metadata usage
justinvyu Dec 27, 2024
f14cfd5
remaining resume_from_checkpoint
justinvyu Dec 27, 2024
3137ff7
minor comment removal
justinvyu Dec 27, 2024
e1162a6
raise warnings in deprecated v2 train context methods
justinvyu Dec 27, 2024
ad707e8
some more NotImplementedError -> DeprecationWarning updates
justinvyu Dec 27, 2024
34c54ab
deprecate the torch amp stuff
justinvyu Dec 27, 2024
6b85968
fix lint
justinvyu Dec 27, 2024
6 changes: 3 additions & 3 deletions doc/source/tune/api/execution.rst
@@ -29,10 +29,10 @@ Tuner Configuration
:toctree: doc/

TuneConfig
RunConfig
CheckpointConfig
FailureConfig

.. seealso::

The `Tuner` constructor also takes in a :class:`RunConfig <ray.train.RunConfig>`.

Restoring a Tuner
~~~~~~~~~~~~~~~~~
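The seealso above still points Tuner users at the Train `RunConfig`. As a minimal sketch of what that looks like in practice (the trainable, search space, and run name are illustrative, not from this PR):

```python
from ray import train, tune


def train_fn(config):
    # Tune-native reporting, matching the function-trainable APIs
    # documented in trainable.rst below.
    tune.report({"score": config["x"] ** 2})


tuner = tune.Tuner(
    train_fn,
    param_space={"x": tune.uniform(0.0, 1.0)},
    # Per the seealso, the Tuner constructor also accepts the Train
    # RunConfig (ray.train.RunConfig).
    run_config=train.RunConfig(name="example-run"),
)
results = tuner.fit()
```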
19 changes: 19 additions & 0 deletions doc/source/tune/api/trainable.rst
@@ -181,6 +181,25 @@ Function API
For reporting results and checkpoints with the function API,
see the :ref:`Ray Train utilities <train-loop-api>` documentation.

**Classes**

.. autosummary::
:nosignatures:
:toctree: doc/

~tune.Checkpoint
~tune.TuneContext

**Functions**

.. autosummary::
:nosignatures:
:toctree: doc/

~tune.get_checkpoint
~tune.get_context
~tune.report

.. _tune-trainable-docstring:

Trainable (Class API)
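Together these entries give function trainables a Tune-native mirror of the Train loop utilities. A rough sketch of how they fit together, assuming `TuneContext` exposes the same trial accessors as the Train context (the resume logic and metric values are illustrative):

```python
from ray import tune


def train_fn(config):
    # ray.tune.get_checkpoint instead of ray.train.get_checkpoint:
    # returns the latest reported checkpoint, or None.
    checkpoint = tune.get_checkpoint()
    start_step = 1 if checkpoint is not None else 0  # placeholder resume

    # Trial info comes from the Tune context rather than
    # ray.train.get_context(); get_trial_name() is assumed to carry
    # over from the Train context.
    trial_name = tune.get_context().get_trial_name()

    for step in range(start_step, 3):
        # ray.tune.report instead of ray.train.report.
        tune.report({"loss": 1.0 / (step + 1), "trial": trial_name})
```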
10 changes: 0 additions & 10 deletions python/ray/air/_internal/session.py

This file was deleted.

6 changes: 3 additions & 3 deletions python/ray/air/integrations/wandb.py
@@ -13,10 +13,10 @@
import ray
from ray import logger
from ray._private.storage import _load_class
from ray.air import session
from ray.air._internal import usage as air_usage
from ray.air.constants import TRAINING_ITERATION
from ray.air.util.node import _force_on_current_node
from ray.train._internal.session import get_session
from ray.train._internal.syncer import DEFAULT_SYNC_TIMEOUT
from ray.tune.experiment import Trial
from ray.tune.logger import LoggerCallback
@@ -121,8 +121,8 @@ def training_loop(config):

try:
# Do a try-catch here if we are not in a train session
_session = session._get_session(warn=False)
if _session and rank_zero_only and session.get_world_rank() != 0:
session = get_session()
if session and rank_zero_only and session.get_world_rank() != 0:
return RunDisabled()

default_trial_id = session.get_trial_id()
1 change: 0 additions & 1 deletion python/ray/air/session.py
@@ -1,2 +1 @@
from ray.air._internal.session import _get_session # noqa: F401
from ray.train._internal.session import * # noqa: F401,F403
63 changes: 45 additions & 18 deletions python/ray/train/_internal/session.py
@@ -12,7 +12,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type

import ray
from ray.air._internal.session import _get_session
from ray.air._internal.util import RunnerThread, StartTraceback
from ray.air.constants import (
_ERROR_FETCH_TIMEOUT,
@@ -33,8 +32,10 @@
WORKER_HOSTNAME,
WORKER_NODE_IP,
WORKER_PID,
_v2_migration_warnings_enabled,
)
from ray.train.error import SessionMisuseError
from ray.train.utils import _log_deprecation_warning
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.util.debug import log_once
from ray.util.placement_group import _valid_resource_shape
@@ -646,7 +647,7 @@ def inner(fn: Callable):

@functools.wraps(fn)
def wrapper(*args, **kwargs):
session = _get_session()
session = get_session()
if not session:
if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
warnings.warn(
@@ -745,14 +746,27 @@ def train_func(config):
metrics: The metrics you want to report.
checkpoint: The optional checkpoint you want to report.
"""
# If we are running in a Tune function, switch to `ray.tune.report`.
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

_get_session().report(metrics, checkpoint=checkpoint)
if _in_tune_session():
import ray.tune

if _v2_migration_warnings_enabled():
_log_deprecation_warning(
"`ray.train.report` should be switched to "
"`ray.tune.report` when running in a function "
"passed to Ray Tune. This will be an error in the future."
)
return ray.tune.report(metrics, checkpoint=checkpoint)

get_session().report(metrics, checkpoint=checkpoint)


@PublicAPI(stability="stable")
@_warn_session_misuse()
def get_checkpoint() -> Optional[Checkpoint]:
"""Access the session's last checkpoint to resume from if applicable.
"""Access the latest reported checkpoint to resume from if one exists.

Example:

@@ -792,50 +806,63 @@ def train_func(config):
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
"""
# If we are running in a Tune function, switch to `ray.tune.get_checkpoint`.
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
import ray.tune

if _v2_migration_warnings_enabled():
_log_deprecation_warning(
"`ray.train.get_checkpoint` should be switched to "
"`ray.tune.get_checkpoint` when running in a function "
"passed to Ray Tune. This will be an error in the future."
)
return ray.tune.get_checkpoint()

return _get_session().loaded_checkpoint
return get_session().loaded_checkpoint


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_metadata() -> Dict[str, Any]:
"""User metadata dict passed to the Trainer constructor."""
return _get_session().metadata
return get_session().metadata


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_experiment_name() -> str:
"""Experiment name for the corresponding trial."""
return _get_session().experiment_name
return get_session().experiment_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_name() -> str:
"""Trial name for the corresponding trial."""
return _get_session().trial_name
return get_session().trial_name


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_id() -> str:
"""Trial id for the corresponding trial."""
return _get_session().trial_id
return get_session().trial_id


@PublicAPI(stability="alpha")
@_warn_session_misuse()
def get_run_id() -> str:
"""Unique Train Run id for the corresponding trial."""
return _get_session().run_id
return get_session().run_id


@PublicAPI(stability="beta")
@_warn_session_misuse()
def get_trial_resources() -> "PlacementGroupFactory":
"""Trial resources for the corresponding trial."""
return _get_session().trial_resources
return get_session().trial_resources


@PublicAPI(stability="beta")
@@ -860,7 +887,7 @@ def train_func(config):

/Users/root/ray_results/train_func_2023-07-19_15-01-37/train_func_d620c_00000_0_2023-07-19_15-01-40
"""
return _get_session().trial_dir
return get_session().trial_dir


@PublicAPI(stability="beta")
@@ -893,7 +920,7 @@ def train_loop_per_worker(config):

...
"""
session = _get_session()
session = get_session()
if not hasattr(session, "world_size"):
raise RuntimeError(
"`get_world_size` can only be called for TrainSession! "
@@ -932,7 +959,7 @@ def train_loop_per_worker(config):

...
"""
session = _get_session()
session = get_session()
if not hasattr(session, "world_rank"):
raise RuntimeError(
"`get_world_rank` can only be called for TrainSession! "
@@ -974,7 +1001,7 @@ def train_loop_per_worker(config):

...
"""
session = _get_session()
session = get_session()
if not hasattr(session, "local_rank"):
raise RuntimeError(
"`get_local_rank` can only be called for TrainSession! "
@@ -1013,7 +1040,7 @@ def train_loop_per_worker():

...
"""
session = _get_session()
session = get_session()
if not hasattr(session, "local_world_size"):
raise RuntimeError(
"`get_local_world_size` can only be called for TrainSession! "
@@ -1052,7 +1079,7 @@ def train_loop_per_worker():

...
"""
session = _get_session()
session = get_session()
if not hasattr(session, "node_rank"):
raise RuntimeError(
"`get_node_rank` can only be called for TrainSession! "
@@ -1109,7 +1136,7 @@ def train_loop_per_worker(config):
The ``DataIterator`` shard to use for this worker.
If no dataset is passed into Trainer, then return None.
"""
session = _get_session()
session = get_session()
if not hasattr(session, "get_dataset_shard"):
raise RuntimeError(
"`get_dataset_shard` can only be called for TrainSession! "
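The net effect of these hunks is that the legacy `ray.train` entry points detect a Tune session via `_in_tune_session()` and forward to their `ray.tune` counterparts, warning only when the migration gate is on. A hedged sketch of the user-visible behavior, assuming the gate env var reaches the trainable workers (a dedicated commit above propagates it):

```python
import os

# Gate from python/ray/train/constants.py below; off by default.
os.environ["RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"] = "1"

import ray.train
from ray import tune


def legacy_fn(config):
    # Old-style call inside a Tune function: with the gate enabled this
    # logs a deprecation warning, then forwards to ray.tune.report, so
    # existing code keeps working unchanged.
    ray.train.report({"loss": 0.0})


tune.Tuner(legacy_fn).fit()
```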
4 changes: 2 additions & 2 deletions python/ray/train/base_trainer.py
@@ -20,7 +20,7 @@
from ray.air.config import RunConfig, ScalingConfig
from ray.air.result import Result
from ray.train import Checkpoint
from ray.train._internal.session import _get_session
from ray.train._internal.session import get_session
from ray.train._internal.storage import (
StorageContext,
_exists_at_fs_path,
@@ -83,7 +83,7 @@ def _train_coordinator_fn(
"""
assert metadata is not None, metadata
# Propagate user metadata from the Trainer constructor.
_get_session().metadata = metadata
get_session().metadata = metadata

# config already contains merged values.
# Instantiate new Trainer in Trainable.
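`_train_coordinator_fn` is where the `metadata` dict passed to a Trainer constructor gets attached to the session; `ray.train.get_metadata()` reads it back inside the worker loop. A minimal sketch, assuming a `TorchTrainer` and an illustrative payload:

```python
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop(config):
    # Returns the user metadata dict the coordinator stashed on the session.
    metadata = ray.train.get_metadata()
    print(metadata["dataset_version"])


trainer = TorchTrainer(
    train_loop,
    metadata={"dataset_version": "v2"},  # propagated via the session
    scaling_config=ScalingConfig(num_workers=2),
)
trainer.fit()
```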
9 changes: 9 additions & 0 deletions python/ray/train/constants.py
@@ -1,6 +1,7 @@
from pathlib import Path

import ray
from ray._private.ray_constants import env_bool
from ray.air.constants import ( # noqa: F401
COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
EVALUATION_DATASET_KEY,
@@ -90,6 +91,14 @@ def _get_ray_train_session_dir() -> str:
# Defaults to 0
RAY_TRAIN_ENABLE_STATE_TRACKING = "RAY_TRAIN_ENABLE_STATE_TRACKING"

# Set this to 1 to enable deprecation warnings for V2 migration.
ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR = "RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"


def _v2_migration_warnings_enabled() -> bool:
return env_bool(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, False)


# NOTE: When adding a new environment variable, please track it in this list.
TRAIN_ENV_VARS = {
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
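Call sites follow the pattern from the `session.py` hunks above: guard with `_v2_migration_warnings_enabled()` and emit via `_log_deprecation_warning`. A sketch using only the helpers imported there (the entry point and message are illustrative):

```python
from ray.train.constants import _v2_migration_warnings_enabled
from ray.train.utils import _log_deprecation_warning


def deprecated_entry_point():
    # Every v2 migration warning is gated, so behavior is unchanged
    # until RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS is set (env_bool
    # treats "1"/"true" as enabled).
    if _v2_migration_warnings_enabled():
        _log_deprecation_warning(
            "`deprecated_entry_point` is deprecated and will be removed; "
            "see the Train v2 migration notes."  # illustrative message
        )
```

Keeping the gate off by default means this PR can land the warnings without changing behavior for current users.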