Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

adapt pytorch lighting 2.0 AKA lightning #5606

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions nni/common/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,12 @@ def _trace(self, model, dummy_input):
# only pytorch with version greater than 1.6.0 has the strict option
kw_args['strict'] = False
try:
import pytorch_lightning as pl
import lightning as pl
except ImportError:
is_lightning_module = False
try:
import pytorch_lightning as pl
except ImportError:
is_lightning_module = False
else:
if isinstance(model, pl.LightningModule):
is_lightning_module = True
Expand Down
22 changes: 17 additions & 5 deletions nni/compression/utils/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
from torch.utils.hooks import RemovableHandle

try:
import pytorch_lightning as pl
import lightning as pl
except ImportError:
LIGHTNING_INSTALLED = False
try:
import pytorch_lightning as pl
except ImportError:
LIGHTNING_INSTALLED = False
else:
LIGHTNING_INSTALLED = True
else:
LIGHTNING_INSTALLED = True

Expand Down Expand Up @@ -388,10 +393,17 @@ class LightningEvaluator(Evaluator):
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
err_msg_p = (
'Only support traced {}, please use nni.trace({}) to initialize the trainer. '
'for pytorch_lightning version > 2.0, please using {}'
)
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer', 'lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
err_msg = err_msg_p.format(
'pytorch_lightning.LightningDataModule',
'pytorch_lightning.LightningDataModule',
'lightning.LightningDataModule',
)
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
self.trainer = trainer
self.data_module = data_module
Expand Down
7 changes: 5 additions & 2 deletions nni/compression/utils/mask_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def fix_mask_conflict(masks, model, dummy_input, traced=None):
# only pytorch with version greater than 1.6.0 has the strict option
kw_args['strict'] = False
try:
import pytorch_lightning as pl
import lightning as pl
except ImportError:
is_lightning_module = False
try:
import pytorch_lightning as pl
except ImportError:
is_lightning_module = False
else:
if isinstance(model, pl.LightningModule):
is_lightning_module = True
Expand Down
8 changes: 6 additions & 2 deletions nni/nas/evaluator/pytorch/cgo/trainer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy
try:
import lightning as pl
from lightning.strategies import SingleDeviceStrategy
except ImportError:
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy


class BypassStrategy(SingleDeviceStrategy):
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/evaluator/pytorch/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from pathlib import Path
from typing import Any, Dict, Union, Optional, List, Type

import pytorch_lightning as pl
try:
import lightning as pl
except ImportError:
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/oneshot/pytorch/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import logging

import pytorch_lightning as pl
try:
import lightning as pl
except ImportError:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
Expand Down
5 changes: 4 additions & 1 deletion nni/nas/oneshot/pytorch/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
import logging
from typing import Any, Callable, TYPE_CHECKING

import pytorch_lightning as pl
try:
import lightning as pl
except ImportError:
import pytorch_lightning as pl
import torch

from nni.mutable import Sample
Expand Down