Skip to content

Commit 4672059

Browse files
committed
Refactor dims._assign_axes()
This commit refactors the dims.assign_axes() method to (1) drop the use of EnumeratedAxis (until additional metadata mechanisms are available) and (2) resolves a bug with handling singleton dimensions. (1) Some plugins, like BoneJ, look specifically for DefaultLinearAxis when performing some calculations. While EnumeratedAxis works technically, it is not supported (likely) by most plugins. Additionally, ImageJ2 seems to always return DefaultLinearAxis. Until we can match an EnmeratedAxis, it does not make a lot of sense to default to using it. Especially if there are issues downstream like BoneJ. (2) When calculating the scale for a singleton dimension, the coordinate array only has one value. Thus attempting to calculate the slope from the first two entries of the coord array is not possible. Instead, we assign the scale to 1 for all singleton dimensions.
1 parent 2930c20 commit 4672059

File tree

1 file changed

+19
-49
lines changed

1 file changed

+19
-49
lines changed

src/imagej/dims.py

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
Utility functions for querying and manipulating dimensional axis metadata.
33
"""
44
import logging
5-
from typing import List, Tuple, Union
5+
from typing import List, Sequence, Tuple
66

77
import numpy as np
88
import scyjava as sj
99
import xarray as xr
10-
from jpype import JException, JObject
10+
from jpype import JObject
1111

1212
from imagej._java import jc
1313
from imagej.images import is_arraylike as _is_arraylike
@@ -179,7 +179,7 @@ def prioritize_rai_axes_order(
179179

180180
def _assign_axes(
181181
xarr: xr.DataArray,
182-
) -> List[Union["jc.DefaultLinearAxis", "jc.EnumeratedAxis"]]:
182+
) -> Sequence["jc.DefaultLinearAxis"]:
183183
"""
184184
Obtain xarray axes names, origin, scale and convert into ImageJ Axis. Supports both
185185
DefaultLinearAxis and the newer EnumeratedAxis.
@@ -205,25 +205,8 @@ def _assign_axes(
205205
jc.Double(np.double(x)) for x in np.arrange(len(xarr.coords[dim]))
206206
]
207207

208-
# assign calibrated axis type -- checks for imagej metadata
209-
if "imagej" in xarr.attrs.keys():
210-
ij_dim = _convert_dim(dim, "java")
211-
if ij_dim + "_cal_axis_type" in xarr.attrs["imagej"].keys():
212-
scale_type = xarr.attrs["imagej"][ij_dim + "_cal_axis_type"]
213-
if scale_type == "linear":
214-
jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
215-
if scale_type == "enumerated":
216-
try:
217-
EnumeratedAxis = _get_enumerated_axis()
218-
except (JException, TypeError):
219-
EnumeratedAxis = None
220-
if EnumeratedAxis is not None:
221-
jaxis = EnumeratedAxis(ax_type, sj.to_java(doub_coords))
222-
else:
223-
jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
224-
else:
225-
# default to DefaultLinearAxis always if no `scale_type` key in attr
226-
jaxis = _get_linear_axis(ax_type, sj.to_java(doub_coords))
208+
# create linear axes for data
209+
jaxis = _get_default_linear_axis(ax_type, doub_coords)
227210

228211
axes[ax_num] = jaxis
229212

@@ -321,6 +304,20 @@ def _get_enumerated_axis():
321304
return sj.jimport("net.imagej.axis.EnumeratedAxis")
322305

323306

307+
def _get_default_linear_axis(axis_type: "jc.AxisType", values):
308+
"""
309+
Get an instance of a DefaultLinearAxis.
310+
"""
311+
origin = values[0]
312+
# calculate the slope using the values/coord array
313+
if len(values) <= 1:
314+
scale = 1
315+
else:
316+
scale = values[1] - values[0]
317+
318+
return jc.DefaultLinearAxis(axis_type, scale, origin)
319+
320+
324321
def _get_linear_axis(axis_type: "jc.AxisType", values):
325322
"""Get linear axis.
326323
@@ -483,30 +480,3 @@ def _to_ijdim(key: str) -> str:
483480
return ijdims[key]
484481
else:
485482
return key
486-
487-
488-
def _cal_axis_type_to_str(key) -> str:
489-
"""
490-
Convert a CalibratedAxis type (e.g. net.imagej.axis.DefaultLinearAxis) to
491-
a string.
492-
"""
493-
cal_axis_types = {
494-
jc.ChapmanRichardsAxis: "ChapmanRichardsAxis",
495-
jc.DefaultLinearAxis: "DefaultLinearAxis",
496-
jc.EnumeratedAxis: "EnumeratedAxis",
497-
jc.ExponentialAxis: "ExponentialAxis",
498-
jc.ExponentialRecoveryAxis: "ExponentialRecoveryAxis",
499-
jc.GammaVariateAxis: "GammaVariateAxis",
500-
jc.GaussianAxis: "GaussianAxis",
501-
jc.IdentityAxis: "IdentityAxis",
502-
jc.InverseRodbardAxis: "InverseRodbardAxis",
503-
jc.LogLinearAxis: "LogLinearAxis",
504-
jc.PolynomialAxis: "PolynomialAxis",
505-
jc.PowerAxis: "PowerAxis",
506-
jc.RodbardAxis: "RodbardAxis",
507-
}
508-
509-
if key.__class__ in cal_axis_types:
510-
return cal_axis_types[key.__class__]
511-
else:
512-
return "unknown"

0 commit comments

Comments
 (0)