Skip to content

Commit cc0a5e1

Browse files
authored
Merge pull request #374 from fooof-tools/upds
[ENH] - Upds
2 parents 89ad00c + 12564c2 commit cc0a5e1

File tree

8 files changed

+79
-66
lines changed

8 files changed

+79
-66
lines changed

specparam/plts/event.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def plot_event_model(event, **plot_kwargs):
7878
color=PARAM_COLORS['presence'], ax=next(axes))
7979
next(axes).axis('off')
8080

81-
# 03: goodness of fit
81+
# 03: metrics
8282
for ind, glabel in enumerate(event.results.metrics.labels):
8383
plot_param_over_time_yshade(\
8484
None, event.results.event_time_results[glabel],

specparam/plts/group.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def plot_group_model(group, **plot_kwargs):
5353

5454
# Goodness of fit plot
5555
ax1 = plt.subplot(gs[0, 1])
56-
plot_group_goodness(group, ax1, **scatter_kwargs, custom_styler=None)
56+
plot_group_metrics(group, ax1, **scatter_kwargs, custom_styler=None)
5757

5858
# Center frequencies plot
5959
ax2 = plt.subplot(gs[1, :])
@@ -79,17 +79,17 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs):
7979
if group.modes.aperiodic.name == 'knee':
8080
plot_scatter_2(group.results.get_params('aperiodic', 'exponent'), 'Exponent',
8181
group.results.get_params('aperiodic', 'knee'), 'Knee',
82-
'Aperiodic Fit', ax=ax)
82+
'Aperiodic Parameters', ax=ax)
8383
else:
8484
plot_scatter_1(group.results.get_params('aperiodic', 'exponent'), 'Exponent',
85-
'Aperiodic Fit', ax=ax)
85+
'Aperiodic Parameters', ax=ax)
8686

8787

8888
@savefig
8989
@style_plot
9090
@check_dependency(plt, 'matplotlib')
91-
def plot_group_goodness(group, ax=None, **plot_kwargs):
92-
"""Plot goodness of fit results, in a scatter plot.
91+
def plot_group_metrics(group, ax=None, **plot_kwargs):
92+
"""Plot metrics results, in a scatter plot.
9393
9494
Parameters
9595
----------
@@ -101,17 +101,26 @@ def plot_group_goodness(group, ax=None, **plot_kwargs):
101101
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
102102
"""
103103

104-
# Get indices of metrics to plot
105-
err_ind = find_first_ind(group.results.metrics.labels, 'error')
106-
err_label = group.results.metrics.labels[err_ind]
107-
gof_ind = find_first_ind(group.results.metrics.labels, 'gof')
108-
gof_label = group.results.metrics.labels[gof_ind]
104+
if len(group.results.metrics) == 0:
105+
ax.set(xticks=[], yticks=[])
109106

110-
plot_scatter_2(group.results.get_metrics(err_label),
111-
group.results.metrics.flabels[err_ind],
112-
group.results.get_metrics(gof_label),
113-
group.results.metrics.flabels[gof_ind],
114-
'Fit Quality', ax=ax)
107+
if len(group.results.metrics) == 1:
108+
plot_scatter_1(group.results.get_metrics(group.results.metrics.labels[0]),
109+
group.results.metrics.flabels[0],
110+
'Metrics', ax=ax)
111+
112+
elif len(group.results.metrics) >= 2:
113+
ind1 = 0
114+
ind2 = 1
115+
if 'error' in group.results.metrics.categories:
116+
ind1 = find_first_ind(group.results.metrics.labels, 'error')
117+
if 'gof' in group.results.metrics.categories:
118+
ind2 = find_first_ind(group.results.metrics.labels, 'gof')
119+
plot_scatter_2(group.results.get_metrics(group.results.metrics.labels[ind1]),
120+
group.results.metrics.flabels[ind1],
121+
group.results.get_metrics(group.results.metrics.labels[ind2]),
122+
group.results.metrics.flabels[ind2],
123+
'Metrics', ax=ax)
115124

116125

117126
@savefig
@@ -130,5 +139,5 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs):
130139
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
131140
"""
132141

133-
plot_hist(group.results.get_params('peak', 0)[:, 0], 'Center Frequency',
134-
'Peaks - Center Frequencies', x_lims=group.data.freq_range, ax=ax)
142+
plot_hist(group.results.get_params('peak', 'cf')[:, 0], 'Center Frequency',
143+
'Peak Parameters - Center Frequencies', x_lims=group.data.freq_range, ax=ax)

specparam/plts/time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def plot_time_model(time, **plot_kwargs):
6666
colors=[PARAM_COLORS[plabel] for plabel in time.modes.periodic.params.labels],
6767
title='Periodic Parameters - ' + blabel, ax=next(axes))
6868

69-
# 03: goodness of fit
69+
# 03: metrics
7070
err_ind = find_first_ind(time.results.metrics.labels, 'error')
7171
gof_ind = find_first_ind(time.results.metrics.labels, 'gof')
7272
plot_params_over_time(None, \

specparam/reports/save.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from specparam.io.utils import create_file_path
44
from specparam.modutils.dependencies import safe_import, check_dependency
55
from specparam.plts.templates import plot_text
6-
from specparam.plts.group import (plot_group_aperiodic, plot_group_goodness,
6+
from specparam.plts.group import (plot_group_aperiodic, plot_group_metrics,
77
plot_group_peak_frequencies)
88
from specparam.reports.strings import (gen_settings_str, gen_model_results_str,
99
gen_group_results_str, gen_time_results_str,
@@ -99,7 +99,7 @@ def save_group_report(group, file_name, file_path=None, add_settings=True):
9999

100100
# Goodness of fit plot
101101
ax2 = plt.subplot(grid[1, 1])
102-
plot_group_goodness(group, ax2, custom_styler=None)
102+
plot_group_metrics(group, ax2, custom_styler=None)
103103

104104
# Peak center frequencies plot
105105
ax3 = plt.subplot(grid[2, :])

specparam/reports/strings.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,6 @@ def gen_methods_report_str(concise=False):
323323
return output
324324

325325

326-
# TODO: UPDATE
327326
def gen_methods_text_str(model=None):
328327
"""Generate a string representation of a template methods report.
329328
@@ -334,32 +333,43 @@ def gen_methods_text_str(model=None):
334333
If None, the text is returned as a template, without values.
335334
"""
336335

337-
template = (
336+
if model:
337+
settings_names = list(model.algorithm.settings.values.keys())
338+
settings_values = list(model.algorithm.settings.values.values())
339+
else:
340+
settings_names = []
341+
settings_values = []
342+
343+
template = [
338344
"The periodic & aperiodic spectral parameterization algorithm (version {}) "
339345
"was used to parameterize neural power spectra. "
340346
"The model was fit with {} aperiodic mode and {} periodic mode. "
341347
"Settings for the algorithm were set as: "
342-
"peak width limits : {}; "
343-
"max number of peaks : {}; "
344-
"minimum peak height : {}; "
345-
"peak threshold : {}; ."
348+
]
349+
350+
if settings_names:
351+
settings_strs = [el + ' : {}, ' for el in settings_names]
352+
settings_strs[-1] = settings_strs[-1][:-2] + '. '
353+
template.extend(settings_strs)
354+
else:
355+
template.extend('XX. ')
356+
357+
template.extend([
346358
"Power spectra were parameterized across the frequency range "
347359
"{} to {} Hz."
348-
)
360+
])
349361

350-
if model:
351-
freq_range = model.data.freq_range if model.data.has_data else ('XX', 'XX')
362+
if model and model.data.has_data:
363+
freq_range = model.data.freq_range
352364
else:
353365
freq_range = ('XX', 'XX')
354366

355-
methods_str = template.format(MODULE_VERSION,
356-
model.modes.aperiodic.name if model else 'XX',
357-
model.modes.periodic.name if model else 'XX',
358-
model.algorithm.settings.peak_width_limits if model else 'XX',
359-
model.algorithm.settings.max_n_peaks if model else 'XX',
360-
model.algorithm.settings.min_peak_height if model else 'XX',
361-
model.algorithm.settings.peak_threshold if model else 'XX',
362-
*freq_range)
367+
methods_str = ''.join(template).format(\
368+
MODULE_VERSION,
369+
model.modes.aperiodic.name if model else 'XX',
370+
model.modes.periodic.name if model else 'XX',
371+
*settings_values,
372+
*freq_range)
363373

364374
return methods_str
365375

@@ -401,21 +411,18 @@ def gen_model_results_str(model, concise=False):
401411
_report_str_model(model),
402412
'',
403413

404-
# Aperiodic parameters
405414
'Aperiodic Parameters (\'{}\' mode)'.format(model.modes.aperiodic.name),
406415
'(' + ', '.join(model.modes.aperiodic.params.labels) + ')',
407416
', '.join(['{:2.4f}'] * \
408417
len(model.results.params.aperiodic.params)).format(*model.results.params.aperiodic.params),
409418
'',
410419

411-
# Peak parameters
412420
'Peak Parameters (\'{}\' mode) {} peaks found'.format(\
413421
model.modes.periodic.name, model.results.n_peaks),
414422
*[peak_str.format(*op) for op in model.results.params.periodic.params],
415423
'',
416424

417-
# Metrics
418-
'Model fit quality metrics:',
425+
'Model metrics:',
419426
*['{:>18s} is {:1.4f} {:8s}'.format('{:s} ({:s})'.format(*key.split('_')), res, ' ') \
420427
for key, res in model.results.metrics.results.items()],
421428
'',
@@ -460,28 +467,31 @@ def gen_group_results_str(group, concise=False):
460467
_report_str_model(group),
461468
'',
462469

463-
# Aperiodic parameters
464470
'Aperiodic Parameters (\'{}\' mode)'.format(group.modes.aperiodic.name),
465471
*[el for el in [\
466472
'{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \
467473
*compute_arr_desc(group.results.get_params('aperiodic', label))) \
468474
for label in group.modes.aperiodic.params.labels]],
469475
'',
470476

471-
# Peak Parameters
472477
'Peak Parameters (\'{}\' mode) {} total peaks found'.format(\
473478
group.modes.periodic.name, sum(group.results.n_peaks)),
474479
'',
480+
]
475481

476-
# Metrics
477-
'Model fit quality metrics:',
478-
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
479-
'{:s} ({:s})'.format(*label.split('_')),
480-
*compute_arr_desc(group.results.get_metrics(label))) \
481-
for label in group.results.metrics.labels],
482-
'',
482+
if len(group.results.metrics) > 0:
483+
str_lst.extend([
484+
'Model metrics:',
485+
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
486+
'{:s} ({:s})'.format(*label.split('_')),
487+
*compute_arr_desc(group.results.get_metrics(label))) \
488+
for label in group.results.metrics.labels],
489+
'',
490+
])
491+
492+
str_lst.extend([
483493
DIVIDER,
484-
]
494+
])
485495

486496
output = _format(str_lst, concise)
487497

@@ -525,15 +535,13 @@ def gen_time_results_str(time, concise=False):
525535
_report_str_model(time),
526536
'',
527537

528-
# Aperiodic parameters
529538
'Aperiodic Parameters (\'{}\' mode)'.format(time.modes.aperiodic.name),
530539
*[el for el in [\
531540
'{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \
532541
*compute_arr_desc(time.results.time_results[label])) \
533542
for label in time.modes.aperiodic.params.labels]],
534543
'',
535544

536-
# Peak Parameters
537545
'Peak Parameters (\'{}\' mode) - mean values across windows'.format(\
538546
time.modes.periodic.name),
539547
*[peak_str.format(*[band_label] + \
@@ -543,8 +551,7 @@ def gen_time_results_str(time, concise=False):
543551
for band_label in time.results.bands.labels],
544552
'',
545553

546-
# Metrics
547-
'Model fit quality metrics (values across windows):',
554+
'Model metrics (values across windows):',
548555
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
549556
'{:s} ({:s})'.format(*key.split('_')),
550557
*compute_arr_desc(time.results.time_results[key])) \
@@ -597,15 +604,13 @@ def gen_event_results_str(event, concise=False):
597604
_report_str_model(event),
598605
'',
599606

600-
# Aperiodic parameters
601607
'Aperiodic Parameters (\'{}\' mode)'.format(event.modes.aperiodic.name),
602608
*[el for el in [\
603609
'{:8s} - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'.format(label, \
604610
*compute_arr_desc(np.mean(event.results.event_time_results[label]))) \
605611
for label in event.modes.aperiodic.params.labels]],
606612
'',
607613

608-
# Peak Parameters
609614
'Peak Parameters (\'{}\' mode) - mean values across windows'.format(\
610615
event.modes.periodic.name),
611616
*[peak_str.format(*[band_label] + \
@@ -616,8 +621,7 @@ def gen_event_results_str(event, concise=False):
616621
for band_label in event.results.bands.labels],
617622
'',
618623

619-
# Metrics
620-
'Model fit quality metrics (values across events):',
624+
'Model metrics (values across events):',
621625
*['{:>18s} - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'.format(\
622626
'{:s} ({:s})'.format(*key.split('_')),
623627
*compute_arr_desc(np.mean(event.results.event_time_results[key], 1))) \

specparam/results/results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def add_results(self, results):
160160

161161

162162
def get_results(self):
163-
"""Return model fit parameters and goodness of fit metrics.
163+
"""Return model fit parameters and metrics.
164164
165165
Returns
166166
-------

specparam/tests/models/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,15 @@ def test_fit_knee():
110110
assert np.allclose(gauss, tfm.results.params.periodic.get_params('fit')[ii], [2.0, 0.5, 1.0])
111111

112112
def test_fit_default_metrics():
113-
"""Test goodness of fit & error metrics, post model fitting."""
113+
"""Test computing metrics, post model fitting."""
114114

115115
tfm = SpectralModel(verbose=False)
116116

117117
# Hack fake data with known properties: total error magnitude 2
118118
tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5])
119119
tfm.results.model.modeled_spectrum = np.array([1, 2, 5, 4, 5])
120120

121-
# Check default goodness of fit and error measures
121+
# Check default metrics
122122
tfm.results.metrics.compute_metrics(tfm.data, tfm.results)
123123
assert np.isclose(tfm.results.metrics.results['error_mae'], 0.4)
124124
assert np.isclose(tfm.results.metrics.results['gof_rsquared'], 0.75757575)

specparam/tests/plts/test_group.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def test_plot_group_aperiodic(tfg, skip_if_no_mpl):
3131
file_name='test_plot_group_aperiodic.png')
3232

3333
@plot_test
34-
def test_plot_group_goodness(tfg, skip_if_no_mpl):
34+
def test_plot_group_metrics(tfg, skip_if_no_mpl):
3535

36-
plot_group_goodness(tfg, file_path=TEST_PLOTS_PATH,
37-
file_name='test_plot_group_goodness.png')
36+
plot_group_metrics(tfg, file_path=TEST_PLOTS_PATH,
37+
file_name='test_plot_group_metrics.png')
3838

3939
@plot_test
4040
def test_plot_group_peak_frequencies(tfg, skip_if_no_mpl):

0 commit comments

Comments
 (0)