-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FeaturesLineWidget #200
Open
zoccoler
wants to merge
10
commits into
matplotlib:main
Choose a base branch
from
zoccoler:line_and_featuresline_widgets
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
FeaturesLineWidget #200
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
35e5808
Create line.py
zoccoler bfb655a
Update line.py
zoccoler e2eb48f
remove minimum size
zoccoler 22f7ce6
Merge branch 'main' into line_plot_widget
zoccoler 86ac6d2
WIP: re-factor line widgets
zoccoler 5891f7f
remove unused packages
zoccoler c2a0586
add examples
zoccoler 4607213
revert changes in base
zoccoler 7fff195
remove line widget
zoccoler ece6ed1
replace label by object_id
zoccoler File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import napari | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
from napari_matplotlib.line import FeaturesLineWidget | ||
|
||
labels = np.array([[0, 0, 1, 0], | ||
[0, 2, 1, 0], | ||
[2, 2, 2, 0], | ||
[3, 3, 2, 0], | ||
[0, 3, 0, 0]]) | ||
table = pd.DataFrame(data=np.array([np.array([1, 2, 3]), np.array([2, 5, 3]), np.array([0, 1, 0.5])]).T, | ||
columns=['label', 'measurement1', 'measurement2']) | ||
|
||
viewer = napari.Viewer() | ||
viewer.add_labels(labels, features=table, name='labels') | ||
|
||
plotter_widget = FeaturesLineWidget(viewer) | ||
viewer.window.add_dock_widget(plotter_widget) | ||
|
||
napari.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import napari | ||
import numpy as np | ||
import pandas as pd | ||
from napari_matplotlib.line import FeaturesLineWidget | ||
|
||
labels = np.array([[0, 0, 1, 0], | ||
[0, 2, 1, 0], | ||
[2, 2, 2, 0], | ||
[3, 3, 2, 0], | ||
[0, 3, 0, 0]]) | ||
|
||
table = pd.DataFrame(data=np.array([ | ||
np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]), | ||
np.array([2, 5, 3, 3, 6, 4, 4, 7, 3]), | ||
np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]),]).T, | ||
columns=['label', | ||
'mean_intensity', | ||
'frame']) | ||
|
||
viewer = napari.Viewer() | ||
viewer.add_labels(labels, features=table, name='labels') | ||
|
||
plotter_widget = FeaturesLineWidget(viewer) | ||
viewer.window.add_dock_widget(plotter_widget) | ||
|
||
napari.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
from typing import Any, Dict, List, Optional, Tuple, Union | ||
from cycler import cycler | ||
import napari | ||
import numpy as np | ||
import numpy.typing as npt | ||
from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget | ||
|
||
from .base import NapariMPLWidget | ||
from .util import Interval | ||
|
||
__all__ = ["LineBaseWidget", "MetadataLineWidget", "FeaturesLineWidget"] | ||
|
||
|
||
class LineBaseWidget(NapariMPLWidget): | ||
""" | ||
Base class for widgets that do line plots of two datasets against each other. | ||
""" | ||
|
||
def __init__(self, napari_viewer: napari.viewer.Viewer, parent: Optional[QWidget] = None, | ||
): | ||
super().__init__(napari_viewer, parent=parent) | ||
self.add_single_axes() | ||
|
||
def clear(self) -> None: | ||
""" | ||
Clear the axes. | ||
""" | ||
self.axes.clear() | ||
|
||
def draw(self) -> None: | ||
""" | ||
Plot lines for the currently selected layers. | ||
""" | ||
x, y, x_axis_name, y_axis_name = self._get_data() | ||
self.axes.plot(x, y) | ||
self.axes.set_xlabel(x_axis_name) | ||
self.axes.set_ylabel(y_axis_name) | ||
|
||
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: | ||
"""Get the plot data. | ||
|
||
This must be implemented on the subclass. | ||
|
||
Returns | ||
------- | ||
data : np.ndarray | ||
The list containing the line plot data. | ||
x_axis_name : str | ||
The label to display on the x axis | ||
y_axis_name: str | ||
The label to display on the y axis | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class FeaturesLineWidget(LineBaseWidget): | ||
""" | ||
Widget to do line plots of two features from a layer, grouped by object_id. | ||
""" | ||
|
||
n_layers_input = Interval(1, 1) | ||
# Currently working with Labels layer | ||
input_layer_types = ( | ||
napari.layers.Labels, | ||
) | ||
|
||
def __init__( | ||
self, | ||
napari_viewer: napari.viewer.Viewer, | ||
parent: Optional[QWidget] = None, | ||
): | ||
super().__init__(napari_viewer, parent=parent) | ||
|
||
self.layout().addLayout(QVBoxLayout()) | ||
|
||
self._selectors: Dict[str, QComboBox] = {} | ||
# Add split-by selector | ||
self._selectors["object_id"] = QComboBox() | ||
self._selectors["object_id"].currentTextChanged.connect(self._draw) | ||
self.layout().addWidget(QLabel(f"object_id:")) | ||
self.layout().addWidget(self._selectors["object_id"]) | ||
|
||
for dim in ["x", "y"]: | ||
self._selectors[dim] = QComboBox() | ||
# Re-draw when combo boxes are updated | ||
self._selectors[dim].currentTextChanged.connect(self._draw) | ||
|
||
self.layout().addWidget(QLabel(f"{dim}-axis:")) | ||
self.layout().addWidget(self._selectors[dim]) | ||
|
||
self._update_layers(None) | ||
|
||
@property | ||
def x_axis_key(self) -> Union[str, None]: | ||
""" | ||
Key for the x-axis data. | ||
""" | ||
if self._selectors["x"].count() == 0: | ||
return None | ||
else: | ||
return self._selectors["x"].currentText() | ||
|
||
@x_axis_key.setter | ||
def x_axis_key(self, key: str) -> None: | ||
self._selectors["x"].setCurrentText(key) | ||
self._draw() | ||
|
||
@property | ||
def y_axis_key(self) -> Union[str, None]: | ||
""" | ||
Key for the y-axis data. | ||
""" | ||
if self._selectors["y"].count() == 0: | ||
return None | ||
else: | ||
return self._selectors["y"].currentText() | ||
|
||
@y_axis_key.setter | ||
def y_axis_key(self, key: str) -> None: | ||
self._selectors["y"].setCurrentText(key) | ||
self._draw() | ||
|
||
@property | ||
def object_id_axis_key(self) -> Union[str, None]: | ||
""" | ||
Key for the object_id factor. | ||
""" | ||
if self._selectors["object_id"].count() == 0: | ||
return None | ||
else: | ||
return self._selectors["object_id"].currentText() | ||
|
||
@object_id_axis_key.setter | ||
def object_id_axis_key(self, key: str) -> None: | ||
self._selectors["object_id"].setCurrentText(key) | ||
self._draw() | ||
|
||
def _get_valid_axis_keys(self) -> List[str]: | ||
""" | ||
Get the valid axis keys from the layer FeatureTable. | ||
|
||
Returns | ||
------- | ||
axis_keys : List[str] | ||
The valid axis keys in the FeatureTable. If the table is empty | ||
or there isn't a table, returns an empty list. | ||
""" | ||
if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): | ||
return [] | ||
else: | ||
return self.layers[0].features.keys() | ||
|
||
def _check_valid_object_id_data_and_set_color_cycle(self): | ||
# If no features, return False | ||
# If no object_id_axis_key, return False | ||
if self.layers[0].features is None \ | ||
or len(self.layers[0].features) == 0 \ | ||
or self.object_id_axis_key is None: | ||
return False | ||
feature_table = self.layers[0].features | ||
# Return True if object_ids from table match labels from layer, otherwise False | ||
object_ids_from_table = np.unique(feature_table[self.object_id_axis_key].values).astype(int) | ||
labels_from_layer = np.unique(self.layers[0].data)[1:] # exclude zero | ||
if np.array_equal(object_ids_from_table, labels_from_layer): | ||
# Set color cycle | ||
self._set_color_cycle(object_ids_from_table.tolist()) | ||
return True | ||
return False | ||
|
||
def _ready_to_plot(self) -> bool: | ||
""" | ||
Return True if selected layer has a feature table we can plot with, | ||
the two columns to be plotted have been selected, and object | ||
identifier (usually 'labels') in the table. | ||
""" | ||
if not hasattr(self.layers[0], "features"): | ||
return False | ||
|
||
feature_table = self.layers[0].features | ||
valid_keys = self._get_valid_axis_keys() | ||
valid_object_id_data = self._check_valid_object_id_data_and_set_color_cycle() | ||
|
||
return ( | ||
feature_table is not None | ||
and len(feature_table) > 0 | ||
and self.x_axis_key in valid_keys | ||
and self.y_axis_key in valid_keys | ||
and self.object_id_axis_key in valid_keys | ||
and valid_object_id_data | ||
) | ||
|
||
def draw(self) -> None: | ||
""" | ||
Plot lines for two features from the currently selected layer, grouped by object_id. | ||
""" | ||
if self._ready_to_plot(): | ||
# draw calls _get_data and then plots the data | ||
super().draw() | ||
|
||
def _set_color_cycle(self, labels): | ||
""" | ||
Set the color cycle for the plot from the colors in the Labels layer. | ||
""" | ||
colors = [self.layers[0].get_color(label) for label in labels] | ||
napari_labels_cycler = (cycler(color=colors)) | ||
self.axes.set_prop_cycle(napari_labels_cycler) | ||
|
||
def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: | ||
""" | ||
Get the plot data from the ``features`` attribute of the first | ||
selected layer grouped by object_id. | ||
|
||
Returns | ||
------- | ||
data : List[np.ndarray] | ||
List contains X and Y columns from the FeatureTable. Returns | ||
an empty array if nothing to plot. | ||
x_axis_name : str | ||
The title to display on the x axis. Returns | ||
an empty string if nothing to plot. | ||
y_axis_name: str | ||
The title to display on the y axis. Returns | ||
an empty string if nothing to plot. | ||
""" | ||
feature_table = self.layers[0].features | ||
|
||
# Sort features by object_id and x_axis_key | ||
feature_table = feature_table.sort_values(by=[self.object_id_axis_key, self.x_axis_key]) | ||
# Get data for each object_id (usually label) | ||
grouped = feature_table.groupby(self.object_id_axis_key) | ||
x = np.array([sub_df[self.x_axis_key].values for label, sub_df in grouped]).T.squeeze() | ||
y = np.array([sub_df[self.y_axis_key].values for label, sub_df in grouped]).T.squeeze() | ||
|
||
x_axis_name = str(self.x_axis_key) | ||
y_axis_name = str(self.y_axis_key) | ||
|
||
return x, y, x_axis_name, y_axis_name | ||
|
||
def on_update_layers(self) -> None: | ||
""" | ||
Called when the layer selection changes by ``self.update_layers()``. | ||
""" | ||
# Clear combobox | ||
for dim in ["object_id", "x", "y"]: | ||
while self._selectors[dim].count() > 0: | ||
self._selectors[dim].removeItem(0) | ||
# Add keys for newly selected layer | ||
self._selectors[dim].addItems(self._get_valid_axis_keys()) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume the line
self.viewer.dims.current_step[-3]
is a bug fix below for datasets that have >= 4 dimensions?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can drop it for now since it does no affect this PR anymore. It had to do with other dimensions, yes