Skip to content

Commit

Permalink
Learner: add new unittests using Model. (#900)
Browse files Browse the repository at this point in the history
These tests are similar how trainer.py uses Learner, which demonstrates how API
users should use the API while ensuring that it works correctly.
  • Loading branch information
ds-hwang authored Dec 19, 2024
1 parent a15a3bc commit 6a7d2f0
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 7 deletions.
9 changes: 3 additions & 6 deletions axlearn/common/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,15 +526,12 @@ def should_apply(tree: Nested[Any]) -> Nested[bool]:
sub_learner_updates = sub_learner_updates.mask(
# pylint: disable-next=cell-var-from-loop
lambda _: should_apply(updates.opt_params),
fields=(
"opt_params",
"delta_updates",
),
fields=("opt_params", "delta_updates"),
)
sub_learner_updated_model_params = getattr(self, name).update(sub_learner_updates)
updated_model_params = jax.tree.map(
lambda apply, new_v, old_v: new_v if apply else old_v,
should_apply(updates.param_values()),
should_apply(updated_model_params),
sub_learner_updated_model_params,
updated_model_params,
)
Expand Down Expand Up @@ -712,7 +709,7 @@ def _value_and_grad(

split_params = split_params_fn(opt_params)
model_params_grad, model_params_nograd = jax.tree.map(lambda p: p.value, split_params)
(_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
(unused_loss, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
model_params_grad, inputs=(model_params_nograd, inputs)
)
return Updates(
Expand Down
210 changes: 209 additions & 1 deletion axlearn/common/learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
import axlearn.common.update_transformation_test
from axlearn.common import schedule
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
from axlearn.common.base_model import BaseModel
from axlearn.common.config import REQUIRED, Required, config_class, config_for_function
from axlearn.common.gradient_accumulation import with_minibatch_steps
from axlearn.common.layers import Linear
from axlearn.common.learner import (
CompositeLearner,
Learner,
Expand All @@ -28,7 +30,7 @@
should_update_with_optimizers,
)
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
from axlearn.common.module import OutputCollection
from axlearn.common.module import OutputCollection, child_context
from axlearn.common.module import functional as F
from axlearn.common.module import new_output_collection
from axlearn.common.optimizer_base import OptParam, OptStateSpec
Expand All @@ -50,6 +52,7 @@
)
from axlearn.common.utils import (
Nested,
NestedTensor,
PartitionSpec,
Tensor,
VDict,
Expand All @@ -59,7 +62,113 @@
)


class TestModel(BaseModel):
"""A simple model for test."""

@config_class
class Config(BaseModel.Config):
dim: int = 4

def __init__(self, cfg, *, parent):
super().__init__(cfg, parent=parent)
enc_cfg = Linear.default_config().set(
input_dim=cfg.dim,
output_dim=cfg.dim,
)
self._add_child("encoder", enc_cfg)

dec_cfg = Linear.default_config().set(
input_dim=cfg.dim,
output_dim=1,
)
self._add_child("decoder", dec_cfg)

def forward(self, input_batch: NestedTensor) -> tuple[Tensor, NestedTensor]:
x = self.encoder(input_batch["x"])
y = self.decoder(x)
loss = jnp.mean(y**2)
aux = dict(discriminator_loss=jnp.mean(jnp.abs(y)))
return loss, aux


class LearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
learning_rate = config_for_function(schedule.stepwise).set(
sub=[0.1, 0.01, 0.001],
start_step=[100, 200],
)
optimizer_cfg = config_for_function(adam_optimizer).set(
learning_rate=learning_rate, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
cfg = Learner.default_config().set(name="test", optimizer=optimizer_cfg)
cfg.ema.decay = ema_decay
learner: Learner = cfg.instantiate(parent=None)

# Init a model.
input_dim = 4
model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
model = model_cfg.instantiate(parent=None)
prng_key = jax.random.PRNGKey(123)
init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
params = model.initialize_parameters_recursively(init_key)

# Create model and learner states.
model_param_specs = model.create_parameter_specs_recursively()
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
factorization_spec=spec.factorization if spec else None,
weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
),
params,
model_param_specs,
)
learner_state = learner.init(model_params=opt_params)

# Forward and backward.
def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
model_output_collection = new_output_collection()
with child_context(
"model",
module=model,
state=model_params,
prng_key=inputs["forward_key"],
output_collection=model_output_collection,
):
loss, aux = model(input_batch=inputs["input_batch"])
return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

batch = 2
input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
fwd_bwd_outputs, learner_output_collection = F(
learner,
method="forward_and_backward",
state=learner_state,
is_training=True,
prng_key=learner_key,
inputs=dict(
fn=_forward,
opt_params=opt_params,
inputs=dict(
input_batch=input_batch,
forward_key=fwd_key,
),
),
)
forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The structure of updated params and Adam mu states are same.
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
)

def test_prune_empty_state(self):
state = {
"state": {
Expand Down Expand Up @@ -816,6 +925,105 @@ def test__value_and_grad(self):


class CompositeLearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
encoder_lr = 0.1
opt1_cfg = config_for_function(sgd_optimizer).set(
learning_rate=encoder_lr, decouple_weight_decay=True, weight_decay=1.0
)
opt2_cfg = config_for_function(adam_optimizer).set(
learning_rate=0.0, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]

cfg = CompositeLearner.default_config().set(
name="test",
rules=learner_rules,
learners={
"encoder": Learner.default_config().set(
optimizer=opt1_cfg, enable_per_variable_summaries=True
),
"decoder": Learner.default_config().set(
optimizer=opt2_cfg, enable_per_variable_summaries=False
),
},
)
cfg.ema.decay = ema_decay
learner: CompositeLearner = cfg.instantiate(parent=None)

# Init a model.
input_dim = 4
model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
model = model_cfg.instantiate(parent=None)
prng_key = jax.random.PRNGKey(123)
init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
params = model.initialize_parameters_recursively(init_key)

# Create model and learner states.
model_param_specs = model.create_parameter_specs_recursively()
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
factorization_spec=spec.factorization if spec else None,
weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
),
params,
model_param_specs,
)
learner_state = learner.init(model_params=opt_params)

# Forward and backward.
def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
model_output_collection = new_output_collection()
with child_context(
"model",
module=model,
state=model_params,
prng_key=inputs["forward_key"],
output_collection=model_output_collection,
):
loss, aux = model(input_batch=inputs["input_batch"])
return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

batch = 2
input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
fwd_bwd_outputs, learner_output_collection = F(
learner,
method="forward_and_backward",
state=learner_state,
is_training=True,
prng_key=learner_key,
inputs=dict(
fn=_forward,
opt_params=opt_params,
inputs=dict(
input_batch=input_batch,
forward_key=fwd_key,
),
),
)
forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
# The structure of updated params and optimizer states are same.
opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
),
)
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
),
)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
# pylint: disable-next=too-many-statements
def test_learner(self, ema_decay: Optional[float], method: str):
Expand Down

0 comments on commit 6a7d2f0

Please sign in to comment.