From b9b9fcdd3050958e8b3e6818e81829d1517cef16 Mon Sep 17 00:00:00 2001 From: Edward Evans Date: Fri, 27 Jan 2023 12:53:23 -0600 Subject: [PATCH 1/8] Add Linear and Enumerated axis classes to jc --- src/imagej/_java.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/imagej/_java.py b/src/imagej/_java.py index 1de17a29..0083b370 100644 --- a/src/imagej/_java.py +++ b/src/imagej/_java.py @@ -50,6 +50,14 @@ def MetadataWrapper(self): def LabelingIOService(self): return "io.scif.labeling.LabelingIOService" + @JavaClasses.java_import + def DefaultLinearAxis(self): + return "net.imagej.axis.DefaultLinearAxis" + + @JavaClasses.java_import + def EnumeratedAxis(self): + return "net.imagej.axis.EnumeratedAxis" + @JavaClasses.java_import def Dataset(self): return "net.imagej.Dataset" From d1745c868e97f10574aac2e1681a869bfaa0225d Mon Sep 17 00:00:00 2001 From: Edward Evans Date: Mon, 30 Jan 2023 14:16:10 -0600 Subject: [PATCH 2/8] Update _assign_axes to use "imagej" metadata This commit changes how the linear and enumerated axes are assigned. We now look for the "imagej" key in the xarray's global attributes. If the key is present we look for dim + "_axis_scale" to assign linear or enumerated axes. --- src/imagej/_java.py | 4 +++ src/imagej/dims.py | 74 +++++++++++++++++++++++++++------------------ 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/imagej/_java.py b/src/imagej/_java.py index 0083b370..6b4e8bf6 100644 --- a/src/imagej/_java.py +++ b/src/imagej/_java.py @@ -30,6 +30,10 @@ class MyJavaClasses(JavaClasses): significantly easier and more readable. """ + @JavaClasses.java_import + def Double(self): + return "java.lang.Double" + @JavaClasses.java_import def Throwable(self): return "java.lang.Throwable" diff --git a/src/imagej/dims.py b/src/imagej/dims.py index f009c3a0..1877cc07 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -2,7 +2,7 @@ Utility functions for querying and manipulating dimensional axis metadata. """ import logging -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import scyjava as sj @@ -177,49 +177,55 @@ def prioritize_rai_axes_order( return permute_order -def _assign_axes(xarr: xr.DataArray): +def _assign_axes( + xarr: xr.DataArray, +) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]: """ - Obtain xarray axes names, origin, and scale and convert into ImageJ Axis; - currently supports EnumeratedAxis - :param xarr: xarray that holds the units - :return: A list of ImageJ Axis with the specified origin and scale + Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both + DefaultLinearAxis and the newer EnumeratedAxis. + :param xarr: xarray that holds the data. + :return: A list of ImageJ Axis with the specified origin and scale. """ - Double = sj.jimport("java.lang.Double") - - axes = [""] * len(xarr.dims) - - # try to get EnumeratedAxis, if not then default to LinearAxis in the loop - try: - EnumeratedAxis = _get_enumerated_axis() - except (JException, TypeError): - EnumeratedAxis = None - + axes = [""] * xarr.ndim for dim in xarr.dims: - axis_str = _convert_dim(dim, direction="java") + axis_str = _convert_dim(dim, "java") ax_type = jc.Axes.get(axis_str) ax_num = _get_axis_num(xarr, dim) - scale = _get_scale(xarr.coords[dim]) + coords_arr = xarr.coords[dim].to_numpy() - if scale is None: + # check if coords/scale is numeric + if _is_numeric_scale(coords_arr): + doub_coords = [jc.Double(np.double(x)) for x in xarr.coords[dim]] + else: _logger.warning( f"The {ax_type.label} axis is non-numeric and is translated " "to a linear index." ) doub_coords = [ - Double(np.double(x)) for x in np.arange(len(xarr.coords[dim])) + jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim])) ] - else: - doub_coords = [Double(np.double(x)) for x in xarr.coords[dim]] - # EnumeratedAxis is a new axis made for xarray, so is only present in - # ImageJ versions that are released later than March 2020. - # This actually returns a LinearAxis if using an earlier version. - if EnumeratedAxis is not None: - java_axis = EnumeratedAxis(ax_type, sj.to_java(doub_coords)) + # assign axis scale type -- checks for imagej metadata + if "imagej" in xarr.attrs.keys(): + ij_dim = _convert_dim(dim, "java") + if ij_dim + "_axis_scale" in xarr.attrs["imagej"].keys(): + scale_type = xarr.attrs["imagej"][ij_dim + "_axis_scale"] + if scale_type == "linear": + jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) + if scale_type == "enumerated": + try: + EnumeratedAxis = _get_enumerated_axis() + except (JException, TypeError): + EnumeratedAxis = None + if EnumeratedAxis is not None: + jaxis = EnumeratedAxis(ax_type, sj.to_java(doub_coords)) + else: + jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) else: - java_axis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) + # default to DefaultLinearAxis always if no `scale_type` key in attr + jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) - axes[ax_num] = java_axis + axes[ax_num] = jaxis return axes @@ -295,6 +301,16 @@ def _get_scale(axis): return None +def _is_numeric_scale(coords_array: np.ndarray) -> bool: + """ + Checks if the coordinates array of the given axis is numeric. + + :param coords_array: A 1D NumPy array. + :return: bool + """ + return np.issubdtype(coords_array.dtype, np.number) + + def _get_enumerated_axis(): """Get EnumeratedAxis. From 2930c204265e148384ba464911905732d94e48d0 Mon Sep 17 00:00:00 2001 From: Edward Evans Date: Tue, 31 Jan 2023 12:36:46 -0600 Subject: [PATCH 3/8] Add metadata for all CalibratedAxis types Although the only calibrated axes used by nearly everyone are DefaultLinearAxis and EnumeratedAxis, I added metadata support for all CalibratedAxis types (e.g. PolynomialAxis etc...) just to be thorough. This metadata will be used for matching the Calibrated Axis type when going back to ImageJ/Java land. --- src/imagej/_java.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ src/imagej/dims.py | 33 ++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 3 deletions(-) diff --git a/src/imagej/_java.py b/src/imagej/_java.py index 6b4e8bf6..dd64798d 100644 --- a/src/imagej/_java.py +++ b/src/imagej/_java.py @@ -54,6 +54,10 @@ def MetadataWrapper(self): def LabelingIOService(self): return "io.scif.labeling.LabelingIOService" + @JavaClasses.java_import + def ChapmanRichardsAxis(self): + return "net.imagej.axis.ChapmanRichardsAxis" + @JavaClasses.java_import def DefaultLinearAxis(self): return "net.imagej.axis.DefaultLinearAxis" @@ -62,6 +66,50 @@ def DefaultLinearAxis(self): def EnumeratedAxis(self): return "net.imagej.axis.EnumeratedAxis" + @JavaClasses.java_import + def ExponentialAxis(self): + return "net.imagej.axis.ExponentialAxis" + + @JavaClasses.java_import + def ExponentialRecoveryAxis(self): + return "net.imagej.axis.ExponentialRecoveryAxis" + + @JavaClasses.java_import + def GammaVariateAxis(self): + return "net.imagej.axis.GammaVariateAxis" + + @JavaClasses.java_import + def GaussianAxis(self): + return "net.imagej.axis.GaussianAxis" + + @JavaClasses.java_import + def IdentityAxis(self): + return "net.imagej.axis.IdentityAxis" + + @JavaClasses.java_import + def InverseRodbardAxis(self): + return "net.imagej.axis.InverseRodbardAxis" + + @JavaClasses.java_import + def LogLinearAxis(self): + return "net.imagej.axis.LogLinearAxis" + + @JavaClasses.java_import + def PolynomialAxis(self): + return "net.imagej.axis.PolynomialAxis" + + @JavaClasses.java_import + def PowerAxis(self): + return "net.imagej.axis.PowerAxis" + + @JavaClasses.java_import + def RodbardAxis(self): + return "net.imagej.axis.RodbardAxis" + + @JavaClasses.java_import + def VariableAxis(self): + return "net.iamgej.axis.VariableAxis" + @JavaClasses.java_import def Dataset(self): return "net.imagej.Dataset" diff --git a/src/imagej/dims.py b/src/imagej/dims.py index 1877cc07..722a3743 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -205,11 +205,11 @@ def _assign_axes( jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim])) ] - # assign axis scale type -- checks for imagej metadata + # assign calibrated axis type -- checks for imagej metadata if "imagej" in xarr.attrs.keys(): ij_dim = _convert_dim(dim, "java") - if ij_dim + "_axis_scale" in xarr.attrs["imagej"].keys(): - scale_type = xarr.attrs["imagej"][ij_dim + "_axis_scale"] + if ij_dim + "_cal_axis_type" in xarr.attrs["imagej"].keys(): + scale_type = xarr.attrs["imagej"][ij_dim + "_cal_axis_type"] if scale_type == "linear": jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) if scale_type == "enumerated": @@ -483,3 +483,30 @@ def _to_ijdim(key: str) -> str: return ijdims[key] else: return key + + +def _cal_axis_type_to_str(key) -> str: + """ + Convert a CalibratedAxis type (e.g. net.imagej.axis.DefaultLinearAxis) to + a string. + """ + cal_axis_types = { + jc.ChapmanRichardsAxis: "ChapmanRichardsAxis", + jc.DefaultLinearAxis: "DefaultLinearAxis", + jc.EnumeratedAxis: "EnumeratedAxis", + jc.ExponentialAxis: "ExponentialAxis", + jc.ExponentialRecoveryAxis: "ExponentialRecoveryAxis", + jc.GammaVariateAxis: "GammaVariateAxis", + jc.GaussianAxis: "GaussianAxis", + jc.IdentityAxis: "IdentityAxis", + jc.InverseRodbardAxis: "InverseRodbardAxis", + jc.LogLinearAxis: "LogLinearAxis", + jc.PolynomialAxis: "PolynomialAxis", + jc.PowerAxis: "PowerAxis", + jc.RodbardAxis: "RodbardAxis", + } + + if key.__class__ in cal_axis_types: + return cal_axis_types[key.__class__] + else: + return "unknown" From 717229d96015733d9a5b71f170ab88a849604de4 Mon Sep 17 00:00:00 2001 From: Gabriel Selzer Date: Thu, 15 Jun 2023 16:20:19 -0500 Subject: [PATCH 4/8] Refactor code @elevans and I talked about this in person. I don't see much reason to remove EnumeratedAxis support, since it isn't hurting anything. Maybe we can just exert more pressure towards using DefaultLinearAxis, and we can use NumPy to check dimension linearity! --- src/imagej/_java.py | 48 ---------------- src/imagej/dims.py | 133 +++++++++++--------------------------------- 2 files changed, 33 insertions(+), 148 deletions(-) diff --git a/src/imagej/_java.py b/src/imagej/_java.py index dd64798d..6b4e8bf6 100644 --- a/src/imagej/_java.py +++ b/src/imagej/_java.py @@ -54,10 +54,6 @@ def MetadataWrapper(self): def LabelingIOService(self): return "io.scif.labeling.LabelingIOService" - @JavaClasses.java_import - def ChapmanRichardsAxis(self): - return "net.imagej.axis.ChapmanRichardsAxis" - @JavaClasses.java_import def DefaultLinearAxis(self): return "net.imagej.axis.DefaultLinearAxis" @@ -66,50 +62,6 @@ def DefaultLinearAxis(self): def EnumeratedAxis(self): return "net.imagej.axis.EnumeratedAxis" - @JavaClasses.java_import - def ExponentialAxis(self): - return "net.imagej.axis.ExponentialAxis" - - @JavaClasses.java_import - def ExponentialRecoveryAxis(self): - return "net.imagej.axis.ExponentialRecoveryAxis" - - @JavaClasses.java_import - def GammaVariateAxis(self): - return "net.imagej.axis.GammaVariateAxis" - - @JavaClasses.java_import - def GaussianAxis(self): - return "net.imagej.axis.GaussianAxis" - - @JavaClasses.java_import - def IdentityAxis(self): - return "net.imagej.axis.IdentityAxis" - - @JavaClasses.java_import - def InverseRodbardAxis(self): - return "net.imagej.axis.InverseRodbardAxis" - - @JavaClasses.java_import - def LogLinearAxis(self): - return "net.imagej.axis.LogLinearAxis" - - @JavaClasses.java_import - def PolynomialAxis(self): - return "net.imagej.axis.PolynomialAxis" - - @JavaClasses.java_import - def PowerAxis(self): - return "net.imagej.axis.PowerAxis" - - @JavaClasses.java_import - def RodbardAxis(self): - return "net.imagej.axis.RodbardAxis" - - @JavaClasses.java_import - def VariableAxis(self): - return "net.iamgej.axis.VariableAxis" - @JavaClasses.java_import def Dataset(self): return "net.imagej.Dataset" diff --git a/src/imagej/dims.py b/src/imagej/dims.py index 722a3743..f7137152 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -183,6 +183,14 @@ def _assign_axes( """ Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both DefaultLinearAxis and the newer EnumeratedAxis. + + Note that, in many cases, there are small discrepancies between the coordinates. + This can either be actually within the data, or it can be from floating point math + errors. In this case, we delegate to numpy.isclose to tell us whether our + coordinates are linear or not. If our coordinates are nonlinear, and the + EnumeratedAxis type is available, we will use it. Otherwise, this function + returns a DefaultLinearAxis. + :param xarr: xarray that holds the data. :return: A list of ImageJ Axis with the specified origin and scale. """ @@ -191,41 +199,37 @@ def _assign_axes( axis_str = _convert_dim(dim, "java") ax_type = jc.Axes.get(axis_str) ax_num = _get_axis_num(xarr, dim) - coords_arr = xarr.coords[dim].to_numpy() + coords_arr = xarr.coords[dim].to_numpy().astype(np.double) - # check if coords/scale is numeric - if _is_numeric_scale(coords_arr): - doub_coords = [jc.Double(np.double(x)) for x in xarr.coords[dim]] - else: + # coerce numeric scale + if not _is_numeric_scale(coords_arr): _logger.warning( f"The {ax_type.label} axis is non-numeric and is translated " "to a linear index." ) - doub_coords = [ - jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim])) - ] - - # assign calibrated axis type -- checks for imagej metadata - if "imagej" in xarr.attrs.keys(): - ij_dim = _convert_dim(dim, "java") - if ij_dim + "_cal_axis_type" in xarr.attrs["imagej"].keys(): - scale_type = xarr.attrs["imagej"][ij_dim + "_cal_axis_type"] - if scale_type == "linear": - jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) - if scale_type == "enumerated": - try: - EnumeratedAxis = _get_enumerated_axis() - except (JException, TypeError): - EnumeratedAxis = None - if EnumeratedAxis is not None: - jaxis = EnumeratedAxis(ax_type, sj.to_java(doub_coords)) - else: - jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) + coords_arr = [np.double(x) for x in np.arrange(len(xarr.coords[dim]))] + + # check scale linearity + diffs = np.diff(coords_arr) + linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0])) + + # For non-linear scales, use EnumeratedAxis + try: + EnumeratedAxis = sj.jimport("net.imagej.axis.EnumeratedAxis") + except (JException, TypeError): + EnumeratedAxis = None + # If we can use EnumeratedAxis for a nonlinear scale, then use it + if not linear and EnumeratedAxis: + j_coords = [jc.Double(x) for x in coords_arr] + axes[ax_num] = EnumeratedAxis(ax_type, sj.to_java(j_coords)) + # Otherwise, use DefaultLinearAxis else: - # default to DefaultLinearAxis always if no `scale_type` key in attr - jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords)) - - axes[ax_num] = jaxis + DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis") + scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1 + origin = coords_arr[0] if len(coords_arr) > 0 else 0 + axes[ax_num] = DefaultLinearAxis( + ax_type, jc.Double(scale), jc.Double(origin) + ) return axes @@ -280,27 +284,6 @@ def _get_axes_coords( return coords -def _get_scale(axis): - """ - Get the scale of an axis, assuming it is linear and so the scale is simply - second - first coordinate. - - :param axis: A 1D list like entry accessible with indexing, which contains the - axis coordinates - :return: The scale for this axis or None if it is a non-numeric scale. - """ - try: - # HACK: This axis length check is a work around for singleton dimensions. - # You can't calculate the slope of a singleton dimension. - # This section will be removed when axis-scale-logic is merged. - if len(axis) <= 1: - return 1 - else: - return axis.values[1] - axis.values[0] - except TypeError: - return None - - def _is_numeric_scale(coords_array: np.ndarray) -> bool: """ Checks if the coordinates array of the given axis is numeric. @@ -311,29 +294,6 @@ def _is_numeric_scale(coords_array: np.ndarray) -> bool: return np.issubdtype(coords_array.dtype, np.number) -def _get_enumerated_axis(): - """Get EnumeratedAxis. - - EnumeratedAxis is only in releases later than March 2020. If using - an older version of ImageJ without EnumeratedAxis, use - _get_linear_axis() instead. - """ - return sj.jimport("net.imagej.axis.EnumeratedAxis") - - -def _get_linear_axis(axis_type: "jc.AxisType", values): - """Get linear axis. - - This is used if no EnumeratedAxis is found. If EnumeratedAxis - is available, use _get_enumerated_axis() instead. - """ - DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis") - origin = values[0] - scale = values[1] - values[0] - axis = DefaultLinearAxis(axis_type, scale, origin) - return axis - - def _dataset_to_imgplus(rai: "jc.RandomAccessibleInterval") -> "jc.ImgPlus": """Get an ImgPlus from a Dataset. @@ -483,30 +443,3 @@ def _to_ijdim(key: str) -> str: return ijdims[key] else: return key - - -def _cal_axis_type_to_str(key) -> str: - """ - Convert a CalibratedAxis type (e.g. net.imagej.axis.DefaultLinearAxis) to - a string. - """ - cal_axis_types = { - jc.ChapmanRichardsAxis: "ChapmanRichardsAxis", - jc.DefaultLinearAxis: "DefaultLinearAxis", - jc.EnumeratedAxis: "EnumeratedAxis", - jc.ExponentialAxis: "ExponentialAxis", - jc.ExponentialRecoveryAxis: "ExponentialRecoveryAxis", - jc.GammaVariateAxis: "GammaVariateAxis", - jc.GaussianAxis: "GaussianAxis", - jc.IdentityAxis: "IdentityAxis", - jc.InverseRodbardAxis: "InverseRodbardAxis", - jc.LogLinearAxis: "LogLinearAxis", - jc.PolynomialAxis: "PolynomialAxis", - jc.PowerAxis: "PowerAxis", - jc.RodbardAxis: "RodbardAxis", - } - - if key.__class__ in cal_axis_types: - return cal_axis_types[key.__class__] - else: - return "unknown" From 8b281f98bff81ce7b1288b9dc0039c4fcdb0a1df Mon Sep 17 00:00:00 2001 From: Gabriel Selzer Date: Thu, 15 Jun 2023 16:28:46 -0500 Subject: [PATCH 5/8] Use try/except/finally Mwahahahahahaha --- src/imagej/dims.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/imagej/dims.py b/src/imagej/dims.py index f7137152..9e0d026b 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -215,19 +215,18 @@ def _assign_axes( # For non-linear scales, use EnumeratedAxis try: - EnumeratedAxis = sj.jimport("net.imagej.axis.EnumeratedAxis") + if not linear: + j_coords = [jc.Double(x) for x in coords_arr] + axes[ax_num] = jc.EnumeratedAxis(ax_type, sj.to_java(j_coords)) + continue except (JException, TypeError): - EnumeratedAxis = None - # If we can use EnumeratedAxis for a nonlinear scale, then use it - if not linear and EnumeratedAxis: - j_coords = [jc.Double(x) for x in coords_arr] - axes[ax_num] = EnumeratedAxis(ax_type, sj.to_java(j_coords)) - # Otherwise, use DefaultLinearAxis - else: - DefaultLinearAxis = sj.jimport("net.imagej.axis.DefaultLinearAxis") + # We don't have EnumeratedAxis available - use DefaultLinearAxis + pass + # For linear scales, use DefaultLinearAxis + finally: scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1 origin = coords_arr[0] if len(coords_arr) > 0 else 0 - axes[ax_num] = DefaultLinearAxis( + axes[ax_num] = jc.DefaultLinearAxis( ax_type, jc.Double(scale), jc.Double(origin) ) From 1bda47b0f7104e0618d4e63b0e98e00e368de26f Mon Sep 17 00:00:00 2001 From: Edward Evans Date: Fri, 16 Jun 2023 10:50:39 -0500 Subject: [PATCH 6/8] Use if/else statement instead of Finally The Finally block was always triggering, overwriting the EnumeratedAxis with DefaultLinearAxis. --- src/imagej/dims.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/imagej/dims.py b/src/imagej/dims.py index 9e0d026b..04bc7c0c 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -213,22 +213,15 @@ def _assign_axes( diffs = np.diff(coords_arr) linear: bool = diffs.size and np.all(np.isclose(diffs, diffs[0])) - # For non-linear scales, use EnumeratedAxis - try: - if not linear: + if not linear: + try: j_coords = [jc.Double(x) for x in coords_arr] axes[ax_num] = jc.EnumeratedAxis(ax_type, sj.to_java(j_coords)) - continue - except (JException, TypeError): - # We don't have EnumeratedAxis available - use DefaultLinearAxis - pass - # For linear scales, use DefaultLinearAxis - finally: - scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1 - origin = coords_arr[0] if len(coords_arr) > 0 else 0 - axes[ax_num] = jc.DefaultLinearAxis( - ax_type, jc.Double(scale), jc.Double(origin) - ) + except (JException, TypeError): + # if EnumeratedAxis not available - use DefaultLinearAxis + axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type) + else: + axes[ax_num] = _get_default_linear_axis(coords_arr, ax_type) return axes @@ -283,6 +276,18 @@ def _get_axes_coords( return coords +def _get_default_linear_axis(coords_arr: np.ndarray, ax_type: "jc.AxisType"): + """ + Create a new DefaultLinearAxis with the given coordinate array and axis type. + + :param coords_arr: A 1D NumPy array. + :return: An instance of net.imagej.axis.DefaultLinearAxis. + """ + scale = coords_arr[1] - coords_arr[0] if len(coords_arr) > 1 else 1 + origin = coords_arr[0] if len(coords_arr) > 0 else 0 + return jc.DefaultLinearAxis(ax_type, jc.Double(scale), jc.Double(origin)) + + def _is_numeric_scale(coords_array: np.ndarray) -> bool: """ Checks if the coordinates array of the given axis is numeric. From de50af9404510093a7447989b69560cdcc770590 Mon Sep 17 00:00:00 2001 From: Edward Evans Date: Fri, 16 Jun 2023 11:04:23 -0500 Subject: [PATCH 7/8] Fix handling of non-numeric scales Non-numeric scales were never handled correctly. If the coords are non-numeric, then trying to convert them to a numpy array without checking if they're numbers will fail. We now check if the coords are numeric or not and then replace them with a linear scale if they are. --- src/imagej/dims.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/imagej/dims.py b/src/imagej/dims.py index 04bc7c0c..03cf29e4 100644 --- a/src/imagej/dims.py +++ b/src/imagej/dims.py @@ -199,15 +199,17 @@ def _assign_axes( axis_str = _convert_dim(dim, "java") ax_type = jc.Axes.get(axis_str) ax_num = _get_axis_num(xarr, dim) - coords_arr = xarr.coords[dim].to_numpy().astype(np.double) + coords_arr = xarr.coords[dim] # coerce numeric scale if not _is_numeric_scale(coords_arr): _logger.warning( - f"The {ax_type.label} axis is non-numeric and is translated " + f"The {ax_type.getLabel()} axis is non-numeric and is translated " "to a linear index." ) - coords_arr = [np.double(x) for x in np.arrange(len(xarr.coords[dim]))] + coords_arr = [np.double(x) for x in np.arange(len(xarr.coords[dim]))] + else: + coords_arr = coords_arr.to_numpy().astype(np.double) # check scale linearity diffs = np.diff(coords_arr) From 68a77ee40ff5651026221cd833a2601c5145dad7 Mon Sep 17 00:00:00 2001 From: Edward Evans Date: Fri, 16 Jun 2023 11:13:28 -0500 Subject: [PATCH 8/8] Add linear, non-linear and non-numeric scale tests --- tests/test_image_conversion.py | 99 ++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/test_image_conversion.py b/tests/test_image_conversion.py index 977ce47f..34461308 100644 --- a/tests/test_image_conversion.py +++ b/tests/test_image_conversion.py @@ -1,4 +1,5 @@ import random +import string import numpy as np import pytest @@ -7,6 +8,7 @@ import imagej.dims as dims import imagej.images as images +from imagej._java import jc # -- Image helpers -- @@ -94,6 +96,75 @@ def get_xarr(option="C"): return xarr +def get_non_linear_coord_xarr(option="C"): + name: str = "non_linear_coord_data_array" + linear_coord_arr = np.arange(5) + # generate a 1D log scale array + non_linear_coord_arr = np.logspace(0, np.log10(100), num=30) + if option == "C": + xarr = xr.DataArray( + np.random.rand(30, 30, 5), + dims=["row", "col", "ch"], + coords={ + "row": non_linear_coord_arr, + "col": non_linear_coord_arr, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + elif option == "F": + xarr = xr.DataArray( + np.ndarray([30, 30, 5], order="F"), + dims=["row", "col", "ch"], + coords={ + "row": non_linear_coord_arr, + "col": non_linear_coord_arr, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + else: + xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name) + + return xarr + + +def get_non_numeric_coord_xarr(option="C"): + name: str = "non_numeric_coord_data_array" + non_numeric_coord_list = [random.choice(string.ascii_letters) for _ in range(30)] + linear_coord_arr = np.arange(5) + if option == "C": + xarr = xr.DataArray( + np.random.rand(30, 30, 5), + dims=["row", "col", "ch"], + coords={ + "row": non_numeric_coord_list, + "col": non_numeric_coord_list, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + elif option == "F": + xarr = xr.DataArray( + np.ndarray([30, 30, 5], order="F"), + dims=["row", "col", "ch"], + coords={ + "row": non_numeric_coord_list, + "col": non_numeric_coord_list, + "ch": linear_coord_arr, + }, + attrs={"Hello": "World"}, + name=name, + ) + else: + xarr = xr.DataArray(np.random.rand(30, 30, 5), name=name) + + return xarr + + # -- Helpers -- @@ -359,6 +430,34 @@ def test_no_coords_or_dims_in_xarr(ij_fixture): assert_inverted_xarr_equal_to_xarr(dataset, ij_fixture, xarr) +def test_linear_coord_on_xarr_conversion(ij_fixture): + xarr = get_xarr() + dataset = ij_fixture.py.to_java(xarr) + axes = dataset.dim_axes + # all axes should be DefaultLinearAxis + for ax in axes: + assert isinstance(ax, jc.DefaultLinearAxis) + + +def test_non_linear_coord_on_xarr_conversion(ij_fixture): + xarr = get_non_linear_coord_xarr() + dataset = ij_fixture.py.to_java(xarr) + axes = dataset.dim_axes + # axes [0, 1] should be EnumeratedAxis with axis 2 as DefaultLinearAxis + for i in range(2): + assert isinstance(axes[i], jc.EnumeratedAxis) + assert isinstance(axes[-1], jc.DefaultLinearAxis) + + +def test_non_numeric_coord_on_xarr_conversion(ij_fixture): + xarr = get_non_numeric_coord_xarr() + dataset = ij_fixture.py.to_java(xarr) + axes = dataset.dim_axes + # all axes should be DefaultLinearAxis + for ax in axes: + assert isinstance(ax, jc.DefaultLinearAxis) + + dataset_conversion_parameters = [ ( get_img,