Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch PruningCallback not pruning #294

Open
deepakpokkalla opened this issue Dec 9, 2024 · 1 comment
Open

PyTorch PruningCallback not pruning #294

deepakpokkalla opened this issue Dec 9, 2024 · 1 comment
Labels
bug Issue/PR about behavior that is broken. Not for typos/CI but for example itself. stale Exempt from stale bot labeling.

Comments

@deepakpokkalla
Copy link

deepakpokkalla commented Dec 9, 2024

Expected behavior

I have implemented a simple example where I want to do hp tuning using optuna and each trial is spawn over two gpus (as my original data is huge). I am not looking for parallelizing trial runs themselves rather doing data parellism across 2 gpus within each trial. I expected that the trial.report() would work from custom PruningCallback I implemented, hwoever it's not pruning. I am trying the replicate the result (same trials should get pruned) when running on a single GPU with everythign else the same.

Environment

  • Optuna version: 2.10.1
  • Python version: 3.10.15
  • OS: Linux
  • torch: 2.2.2+cu121

Error messages, stack traces, or logs

When I run mlp_ddp.py script below none of the trials get pruned as a result of PruningCallback not working properly with optuna

Steps to reproduce

  1. The mlp_ddp.py code contains implementation with ddp within each trial (expect that same trials as "mlp.py" script to be pruned, but don't see any trials getting pruned)
  2. mlp.py code below contains implementation for a single gpu (reference solution for which trials should be pruned)

Reproducible examples (optional)

# my mlp_ddp.py code

import os 

import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

optuna.logging.set_verbosity(optuna.logging.DEBUG)

class MLP(torch.nn.Module):
    def __init__(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
        super().__init__()

        layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
        for _ in range(n_layers):
            layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim,out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self,x):
        return self.model(x)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class PruningCallback:
    def __init__(self,trial,monitor="accuracy"):
        self.trial = trial
        self.monitor = monitor
    
    def on_epoch_end(self,epoch,metrics):
        value = metrics.get(self.monitor)
        if value is None:
            return
        
        self.trial.report(value,step=epoch)
        if self.trial.should_prune():
            raise optuna.TrialPruned()

def objective(rank, world_size, params, callback, return_dict):

    setup(rank, world_size)
    torch.manual_seed(42)
    device = torch.device(f"cuda:{rank}")

    in_dim = 10
    out_dim = 3
    num_train_samples=500
    num_val_samples=100
    num_epochs = 10
    batch_size = 64

    train_data = torch.rand(num_train_samples,in_dim).to(device)
    val_data = torch.rand(num_val_samples,in_dim).to(device)
    train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
    val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)

    model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    loss_function = torch.nn.CrossEntropyLoss()

    out_dir = "./multirun/mlp-optuna-test"
    os.makedirs(out_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        train_outputs = model(train_data)
        loss = loss_function(train_outputs,train_targets)
        loss.backward()
        optimizer.step()
    
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_data)
            val_predictions = torch.argmax(val_outputs,dim=1)
            val_correct = (val_predictions == val_targets)
            acc = int(val_correct.sum())/len(val_targets)

        acc_tensor = torch.tensor(acc,device=device)
        dist.all_reduce(acc_tensor)
        acc_avg = acc_tensor.item()/world_size
        # print(rank,acc,acc_tensor,acc_avg)

        if rank==0:
            callback.on_epoch_end(epoch,{"accuracy":acc_avg})

    if rank==0:
        return_dict["result"] = acc_avg
    
    cleanup()

def ddp_objective(trial):
    params = {
        "n_layers": trial.suggest_int("n_layers", 1, 5),
        "hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
    }

    world_size = 2  # Number of GPUs
    manager = mp.Manager()
    return_dict = manager.dict()

    callback = PruningCallback(trial,monitor="accuracy")
    mp.spawn(
        objective,
        args=(world_size, params, callback, return_dict),
        nprocs=world_size,
        join=True,
    )

    return return_dict["result"]


if __name__ == "__main__":

    sampler = TPESampler(seed=42)
    pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    print(f"Pruner:{study.pruner}")
    print(f"Sampler:{study.sampler}")

    study.optimize(ddp_objective,n_trials=20)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics:")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
    print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")

    print("  Number of pruned trials ---: ", len(pruned_trials))
    print("  Number of complete trials ---: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")
    print(f"  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")
# my mlp_ddp.py code

import os 

import optuna
from optuna.trial import TrialState
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner, NopPruner

import torch
import torch.nn as nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, DistributedSampler

optuna.logging.set_verbosity(optuna.logging.DEBUG)

class MLP(torch.nn.Module):
    def __init__(self,n_layers,hidden_dim,in_dim=10,out_dim=3):
        super().__init__()

        layers = [nn.Linear(in_dim,hidden_dim), nn.ReLU()]
        for _ in range(n_layers):
            layers += [nn.Linear(hidden_dim,hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim,out_dim))

        self.model = nn.Sequential(*layers)

    def forward(self,x):
        return self.model(x)

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class PruningCallback:
    def __init__(self,trial,monitor="accuracy"):
        self.trial = trial
        self.monitor = monitor
    
    def on_epoch_end(self,epoch,metrics):
        value = metrics.get(self.monitor)
        if value is None:
            return
        
        self.trial.report(value,step=epoch)
        if self.trial.should_prune():
            raise optuna.TrialPruned()

def objective(rank, world_size, params, callback, return_dict):

    setup(rank, world_size)
    torch.manual_seed(42)
    device = torch.device(f"cuda:{rank}")

    in_dim = 10
    out_dim = 3
    num_train_samples=500
    num_val_samples=100
    num_epochs = 10
    batch_size = 64

    train_data = torch.rand(num_train_samples,in_dim).to(device)
    val_data = torch.rand(num_val_samples,in_dim).to(device)
    train_targets = torch.randint(0,out_dim,(num_train_samples,)).to(device)
    val_targets = torch.randint(0,out_dim,(num_val_samples,)).to(device)

    model = MLP(params["n_layers"],params["hidden_dim"],in_dim,out_dim).to(device)
    model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    loss_function = torch.nn.CrossEntropyLoss()

    out_dir = "./multirun/mlp-optuna-test"
    os.makedirs(out_dir, exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        train_outputs = model(train_data)
        loss = loss_function(train_outputs,train_targets)
        loss.backward()
        optimizer.step()
    
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_data)
            val_predictions = torch.argmax(val_outputs,dim=1)
            val_correct = (val_predictions == val_targets)
            acc = int(val_correct.sum())/len(val_targets)

        acc_tensor = torch.tensor(acc,device=device)
        dist.all_reduce(acc_tensor)
        acc_avg = acc_tensor.item()/world_size
        # print(rank,acc,acc_tensor,acc_avg)

        if rank==0:
            callback.on_epoch_end(epoch,{"accuracy":acc_avg})

    if rank==0:
        return_dict["result"] = acc_avg
    
    cleanup()

def ddp_objective(trial):
    params = {
        "n_layers": trial.suggest_int("n_layers", 1, 5),
        "hidden_dim": trial.suggest_int("hidden_dim", 32, 64),
    }

    world_size = 2  # Number of GPUs
    manager = mp.Manager()
    return_dict = manager.dict()

    callback = PruningCallback(trial,monitor="accuracy")
    mp.spawn(
        objective,
        args=(world_size, params, callback, return_dict),
        nprocs=world_size,
        join=True,
    )

    return return_dict["result"]


if __name__ == "__main__":

    sampler = TPESampler(seed=42)
    pruner = MedianPruner(n_startup_trials=3,n_warmup_steps=1)
    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    print(f"Pruner:{study.pruner}")
    print(f"Sampler:{study.sampler}")

    study.optimize(ddp_objective,n_trials=20)

    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics:")
    print(f"  Number of finished trials: {len(study.trials)}")
    print(f"  Number of pruned trials: {len([t for t in study.trials if t.state == TrialState.PRUNED])}")
    print(f"  Number of complete trials: {len([t for t in study.trials if t.state == TrialState.COMPLETE])}")

    print("  Number of pruned trials ---: ", len(pruned_trials))
    print("  Number of complete trials ---: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")
    print(f"  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

Additional context (optional)

@deepakpokkalla deepakpokkalla added the bug Issue/PR about behavior that is broken. Not for typos/CI but for example itself. label Dec 9, 2024
Copy link

This issue has not seen any recent activity.

@github-actions github-actions bot added the stale Exempt from stale bot labeling. label Dec 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Issue/PR about behavior that is broken. Not for typos/CI but for example itself. stale Exempt from stale bot labeling.
Projects
None yet
Development

No branches or pull requests

1 participant