Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added substation segementation dataset #2352

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open

Conversation

rijuld
Copy link

@rijuld rijuld commented Oct 17, 2024

No description provided.

@github-actions github-actions bot added documentation Improvements or additions to documentation datasets Geospatial or benchmark datasets testing Continuous integration testing labels Oct 17, 2024
@adamjstewart adamjstewart added this to the 0.7.0 milestone Oct 17, 2024
@adamjstewart
Copy link
Collaborator

Hi @rijuld, thanks for the contribution! If you're new to creating PyTorch datasets, I highly recommend reading the following tutorials:

The only difference between datasets in torchvision and NonGeoDatasets in TorchGeo is that our __getitem__ returns a dictionary instead of a tuple. Other than that, they share all the same basic components.

Most of your issues seem to be due to the use of args. I think you just need to remove this and explicitly list all parameters in the function signature. This will also simplify your testing code. Take a look at other existing datasets, we have about 75 examples to choose from. If you find one that is similar to your dataset, it shouldn't actually require that many changes to get them working.

@rijuld
Copy link
Author

rijuld commented Oct 22, 2024

Hi @adamjstewart , thanks a ton for the feedback! I will go through this tutorial.

image = image[:4, :, :, :] if self.use_timepoints else image[0]
return torch.from_numpy(image)

def _apply_transforms(self, image: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
Copy link
Collaborator

@nilsleh nilsleh Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @rijuld, thank you for contributing this dataset. As another pointer, torchgeo datasets usually have an accompanying datamodule that defines things like the train/val/test split, but also common data augmentations, like flips, color augmentations etc through the kornia package. So in essence, torchgeo datasets simply load a particular sample and the augmentations are applied on GPU over the batch.

For example in this dataset, the getitem method loads the image and mask, and then we have a corresponding datamodule where we define augmentations like resizing and others, which will automatically be applied with a lightning training setup. This helps streamlining the datasets and keep them "minimal" and also make use of existing augmentation implementations like Kornia.

Let me know if I can help with any further questions.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nilsleh , thank you for the detailed explanation!

That makes perfect sense. I will try to make this minimal, implement this today and reach out if I have any further questions.

Thanks again!

Copy link
Author

@rijuld rijuld Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @nilsleh,

Hope you're doing well! I wanted to clarify if it's essential to shift all data augmentations to the datamodule. If so, could you guide me on which specific parts of the dataset should be moved there?

I've already removed the geotransform and color transform and plan to add them to the datamodule in my next pull request. If there are other elements you’d suggest removing, I can address those too. Once these adjustments are made, would it be possible to merge this PR (pending review) without the datamodule updates?

Thank you very much for your help!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the late response. Adam prefers having all data normalization in the datamodule for consistency, But I also don't think it is terrible to do in the dataset. If you move it to the datamodule, you can use the kornina Normalize module, that you can add to the augmentation series. Then it will be applied to on_after_batch_transfer in the LightningDataModule.

@rijuld
Copy link
Author

rijuld commented Oct 30, 2024

@microsoft-github-policy-service agree

@rijuld rijuld requested a review from nilsleh October 30, 2024 17:55
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a bad data loader, it just doesn't match a single other data loader in TorchGeo. I highly recommend looking at some of the 80+ existing data loaders and unit tests for those data loaders already builtin before adding a new one from scratch. Especially for unit testing, you can probably just copy-n-paste most of the existing test code for a similar dataset.

torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
torchgeo/datasets/substation_seg.py Outdated Show resolved Hide resolved
docs/api/datasets.rst Outdated Show resolved Hide resolved
docs/api/datasets/non_geo_datasets.csv Outdated Show resolved Hide resolved
tests/datasets/test_substation_seg.py Outdated Show resolved Hide resolved
@adamjstewart
Copy link
Collaborator

Can you resolve the merge conflict? The order of __all__ changed in the main branch.

@adamjstewart
Copy link
Collaborator

Still a lot of missing test coverage, will review in more detail after AGU.

@github-actions github-actions bot added the datamodules PyTorch Lightning datamodules label Dec 5, 2024
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the download/extract logic isn't tested yet, and the datamodule also has no tests. To test the datamodule, you can add a tests/conf/substation.yaml file and add 1 line to tests/trainers/test_segmentation.py.

import numpy as np

# Parameters
SIZE = 228 # Image dimensions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually use much smaller fake images (32 x 32) to make the tests run faster

@@ -0,0 +1,157 @@
import os
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs the same copyright header

from torchgeo.datasets import SubstationDataset


class Args:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's get rid of this

self.timepoint_aggregation: str = 'median'


@pytest.fixture
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put all of these in a class like our other tests?



class SubstationDataset(NonGeoDataset):
"""Base class for Substation Dataset.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a base class, it's a class

from .utils import download_url, extract_archive


class SubstationDataset(NonGeoDataset):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't usually put "Dataset" in the class name

* https://doi.org/10.48550/arXiv.2409.17363
"""

directory: str = 'Substation'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to specify the types here, mypy can automatically infer them


def __init__(
self,
data_dir: str,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually call this variable root

def __init__(
self,
data_dir: str,
in_channels: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually call this bands and allow the user to specify a list of bands, which may not be in order

data_dir: str,
in_channels: int,
use_timepoints: bool,
image_files: list[str],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is usually automatically detected by the dataset class based on the downloaded data

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
datamodules PyTorch Lightning datamodules datasets Geospatial or benchmark datasets documentation Improvements or additions to documentation testing Continuous integration testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants