Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FeaturesLineWidget #200

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added examples/features_line_widget.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/features_line_widget_over_time.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions examples/line_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import napari
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from napari_matplotlib.line import FeaturesLineWidget

labels = np.array([[0, 0, 1, 0],
[0, 2, 1, 0],
[2, 2, 2, 0],
[3, 3, 2, 0],
[0, 3, 0, 0]])
table = pd.DataFrame(data=np.array([np.array([1, 2, 3]), np.array([2, 5, 3]), np.array([0, 1, 0.5])]).T,
columns=['label', 'measurement1', 'measurement2'])

viewer = napari.Viewer()
viewer.add_labels(labels, features=table, name='labels')

plotter_widget = FeaturesLineWidget(viewer)
viewer.window.add_dock_widget(plotter_widget)

napari.run()
26 changes: 26 additions & 0 deletions examples/line_features_over_time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import napari
import numpy as np
import pandas as pd
from napari_matplotlib.line import FeaturesLineWidget

labels = np.array([[0, 0, 1, 0],
[0, 2, 1, 0],
[2, 2, 2, 0],
[3, 3, 2, 0],
[0, 3, 0, 0]])

table = pd.DataFrame(data=np.array([
np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]),
np.array([2, 5, 3, 3, 6, 4, 4, 7, 3]),
np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]),]).T,
columns=['label',
'mean_intensity',
'frame'])

viewer = napari.Viewer()
viewer.add_labels(labels, features=table, name='labels')

plotter_widget = FeaturesLineWidget(viewer)
viewer.window.add_dock_widget(plotter_widget)

napari.run()
3 changes: 3 additions & 0 deletions src/napari_matplotlib/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def current_z(self) -> int:
Current z-step of the napari viewer.
"""
return self.viewer.dims.current_step[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self.viewer.dims.current_step[0]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume the line self.viewer.dims.current_step[-3] is a bug fix below for datasets that have >= 4 dimensions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can drop it for now since it does no affect this PR anymore. It had to do with other dimensions, yes

if self.viewer.dims.ndim < 3:
return slice(None)
return self.viewer.dims.current_step[-3]

def _setup_callbacks(self) -> None:
"""
Expand Down
248 changes: 248 additions & 0 deletions src/napari_matplotlib/line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from cycler import cycler
import napari
import numpy as np
import numpy.typing as npt
from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget

from .base import NapariMPLWidget
from .util import Interval

__all__ = ["LineBaseWidget", "MetadataLineWidget", "FeaturesLineWidget"]


class LineBaseWidget(NapariMPLWidget):
"""
Base class for widgets that do line plots of two datasets against each other.
"""

def __init__(self, napari_viewer: napari.viewer.Viewer, parent: Optional[QWidget] = None,
):
super().__init__(napari_viewer, parent=parent)
self.add_single_axes()

def clear(self) -> None:
"""
Clear the axes.
"""
self.axes.clear()

def draw(self) -> None:
"""
Plot lines for the currently selected layers.
"""
x, y, x_axis_name, y_axis_name = self._get_data()
self.axes.plot(x, y)
self.axes.set_xlabel(x_axis_name)
self.axes.set_ylabel(y_axis_name)

def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""Get the plot data.

This must be implemented on the subclass.

Returns
-------
data : np.ndarray
The list containing the line plot data.
x_axis_name : str
The label to display on the x axis
y_axis_name: str
The label to display on the y axis
"""
raise NotImplementedError


class FeaturesLineWidget(LineBaseWidget):
"""
Widget to do line plots of two features from a layer, grouped by object_id.
"""

n_layers_input = Interval(1, 1)
# Currently working with Labels layer
input_layer_types = (
napari.layers.Labels,
)

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
parent: Optional[QWidget] = None,
):
super().__init__(napari_viewer, parent=parent)

self.layout().addLayout(QVBoxLayout())

self._selectors: Dict[str, QComboBox] = {}
# Add split-by selector
self._selectors["object_id"] = QComboBox()
self._selectors["object_id"].currentTextChanged.connect(self._draw)
self.layout().addWidget(QLabel(f"object_id:"))
self.layout().addWidget(self._selectors["object_id"])

for dim in ["x", "y"]:
self._selectors[dim] = QComboBox()
# Re-draw when combo boxes are updated
self._selectors[dim].currentTextChanged.connect(self._draw)

self.layout().addWidget(QLabel(f"{dim}-axis:"))
self.layout().addWidget(self._selectors[dim])

self._update_layers(None)

@property
def x_axis_key(self) -> Union[str, None]:
"""
Key for the x-axis data.
"""
if self._selectors["x"].count() == 0:
return None
else:
return self._selectors["x"].currentText()

@x_axis_key.setter
def x_axis_key(self, key: str) -> None:
self._selectors["x"].setCurrentText(key)
self._draw()

@property
def y_axis_key(self) -> Union[str, None]:
"""
Key for the y-axis data.
"""
if self._selectors["y"].count() == 0:
return None
else:
return self._selectors["y"].currentText()

@y_axis_key.setter
def y_axis_key(self, key: str) -> None:
self._selectors["y"].setCurrentText(key)
self._draw()

@property
def object_id_axis_key(self) -> Union[str, None]:
"""
Key for the object_id factor.
"""
if self._selectors["object_id"].count() == 0:
return None
else:
return self._selectors["object_id"].currentText()

@object_id_axis_key.setter
def object_id_axis_key(self, key: str) -> None:
self._selectors["object_id"].setCurrentText(key)
self._draw()

def _get_valid_axis_keys(self) -> List[str]:
"""
Get the valid axis keys from the layer FeatureTable.

Returns
-------
axis_keys : List[str]
The valid axis keys in the FeatureTable. If the table is empty
or there isn't a table, returns an empty list.
"""
if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")):
return []
else:
return self.layers[0].features.keys()

def _check_valid_object_id_data_and_set_color_cycle(self):
# If no features, return False
# If no object_id_axis_key, return False
if self.layers[0].features is None \
or len(self.layers[0].features) == 0 \
or self.object_id_axis_key is None:
return False
feature_table = self.layers[0].features
# Return True if object_ids from table match labels from layer, otherwise False
object_ids_from_table = np.unique(feature_table[self.object_id_axis_key].values).astype(int)
labels_from_layer = np.unique(self.layers[0].data)[1:] # exclude zero
if np.array_equal(object_ids_from_table, labels_from_layer):
# Set color cycle
self._set_color_cycle(object_ids_from_table.tolist())
return True
return False

def _ready_to_plot(self) -> bool:
"""
Return True if selected layer has a feature table we can plot with,
the two columns to be plotted have been selected, and object
identifier (usually 'labels') in the table.
"""
if not hasattr(self.layers[0], "features"):
return False

feature_table = self.layers[0].features
valid_keys = self._get_valid_axis_keys()
valid_object_id_data = self._check_valid_object_id_data_and_set_color_cycle()

return (
feature_table is not None
and len(feature_table) > 0
and self.x_axis_key in valid_keys
and self.y_axis_key in valid_keys
and self.object_id_axis_key in valid_keys
and valid_object_id_data
)

def draw(self) -> None:
"""
Plot lines for two features from the currently selected layer, grouped by object_id.
"""
if self._ready_to_plot():
# draw calls _get_data and then plots the data
super().draw()

def _set_color_cycle(self, labels):
"""
Set the color cycle for the plot from the colors in the Labels layer.
"""
colors = [self.layers[0].get_color(label) for label in labels]
napari_labels_cycler = (cycler(color=colors))
self.axes.set_prop_cycle(napari_labels_cycler)

def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
"""
Get the plot data from the ``features`` attribute of the first
selected layer grouped by object_id.

Returns
-------
data : List[np.ndarray]
List contains X and Y columns from the FeatureTable. Returns
an empty array if nothing to plot.
x_axis_name : str
The title to display on the x axis. Returns
an empty string if nothing to plot.
y_axis_name: str
The title to display on the y axis. Returns
an empty string if nothing to plot.
"""
feature_table = self.layers[0].features

# Sort features by object_id and x_axis_key
feature_table = feature_table.sort_values(by=[self.object_id_axis_key, self.x_axis_key])
# Get data for each object_id (usually label)
grouped = feature_table.groupby(self.object_id_axis_key)
x = np.array([sub_df[self.x_axis_key].values for label, sub_df in grouped]).T.squeeze()
y = np.array([sub_df[self.y_axis_key].values for label, sub_df in grouped]).T.squeeze()

x_axis_name = str(self.x_axis_key)
y_axis_name = str(self.y_axis_key)

return x, y, x_axis_name, y_axis_name

def on_update_layers(self) -> None:
"""
Called when the layer selection changes by ``self.update_layers()``.
"""
# Clear combobox
for dim in ["object_id", "x", "y"]:
while self._selectors[dim].count() > 0:
self._selectors[dim].removeItem(0)
# Add keys for newly selected layer
self._selectors[dim].addItems(self._get_valid_axis_keys())