Slice-timing correction
In this notebook, we demonstrate slice-timing correction of Xarray images using SciPy's `make_interp_spline`.
from pathlib import Path
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
plt.rc('image', cmap='gray')
from scipy.interpolate import make_interp_spline
import xarray as xr
%matplotlib inline
import xibabel as xib
We will use a test functional image from OpenNeuro ds000009:
import xibabel.testing
# Example BOLD run from OpenNeuro ds000009 shipped with xibabel's test data.
func_path = xib.testing.JC_EG_FUNC
# Presumably downloads/caches the file locally; the return value is discarded
# (assigning to `_` also keeps the notebook from echoing it).
_ = xib.testing.fetcher.get_file(func_path)
# Load as an xarray DataArray; with format='bids' the metadata (e.g.
# RepetitionTime, SliceEncodingDirection, SliceTiming) ends up in .attrs.
bold = xib.load(func_path, format='bids')
bold
<xarray.DataArray 'sub-07_task-balloonanalogrisktask_bold' (i: 64, j: 64, k: 34, time: 246)> Size: 274MB dask.array<array, shape=(64, 64, 34, 246), dtype=float64, chunksize=(64, 64, 34, 246), chunktype=numpy.ndarray> Coordinates: * i (i) int64 512B 0 1 2 3 4 5 6 7 8 9 ... 55 56 57 58 59 60 61 62 63 * j (j) int64 512B 0 1 2 3 4 5 6 7 8 9 ... 55 56 57 58 59 60 61 62 63 * k (k) int64 272B 0 1 2 3 4 5 6 7 8 9 ... 25 26 27 28 29 30 31 32 33 * time (time) float64 2kB 0.0 2.0 4.0 6.0 8.0 ... 484.0 486.0 488.0 490.0 Attributes: (12/13) xib-FrequencyEncodingDirection: i PhaseEncodingDirection: j SliceEncodingDirection: k RepetitionTime: 2.0 xib-affines: {'scanner': [[-3.0, 0.0, -0.0, 90.091979... EffectiveEchoSpacing: 0.000395 ... ... ImagingFrequency: 123.249608 SeriesDate: 18681110 SeriesNumber: 9 SeriesTime: 171509.015000 StudyID: 1 StudyTime: 160942.703000
This image has an interleaved, ascending slice order:
slice_times = bold.attrs['SliceTiming']
# Adapted from https://textbook.nipraxis.org/slice_timing.html
slice_idcs = np.arange(len(slice_times))
# slice_order[n] is the spatial index of the n-th slice to be acquired.
slice_order = np.argsort(slice_times)
# Inverting that permutation gives, for each spatial position, its
# acquisition rank (0 = acquired first).
acq_by_pos = np.argsort(slice_order)
# Width of the picture in pixels. np.repeat needs an integer repeat count;
# passing a float raises TypeError under NumPy's safe-casting rule.
n_x = int(len(acq_by_pos) * 1.5)
# One row per slice position, each row filled with that slice's acquisition rank.
picture = np.repeat(acq_by_pos[:, None], n_x, axis=1)
# Light gray ramp so that the black rank labels stay readable.
cm = matplotlib.colors.LinearSegmentedColormap.from_list(
    'light_gray', [[0.4] * 3, [1] * 3])
plt.imshow(picture, cmap=cm, origin='lower')
plt.box(on=False)
plt.xticks([])
plt.yticks(slice_idcs)
plt.tick_params(axis='y', which='both', left=False)
plt.ylabel('Position in space (0 = bottom)')
# Write each slice's acquisition rank in the middle of its row.
for space_order, acq_order in zip(slice_idcs, acq_by_pos):
    plt.text(n_x / 2, space_order, str(acq_order), va='center', ha='center')
plt.title('''\
Slice acquisition order (center) by position (left)
Acquisition order''');
To correct a slice, we need to determine the acquisition time of each of its voxels and then interpolate their time series at a different, common time. Here we use the onset of volume acquisition (volume time + 0 s):
# Index of the slice to correct along the slice-encoding axis 'k'
# (not to be confused with the spline degree `k=3` below).
k = 17
# Acquisition time of slice k relative to the start of each volume.
offset = slice_times[k]
slicek = bold[{'k': k}]
# Fit a cubic spline through the samples at their true acquisition times.
# make_interp_spline interpolates along axis 0, so transpose to put `time`
# (the last dim of slicek) first. Named `spline` rather than `filter` to
# avoid shadowing the builtin.
spline = make_interp_spline(slicek.time + offset, slicek.T, k=3)
# Evaluate at the volume onset times and transpose back to slicek's dim order.
filtered = spline(bold.time).T
Little change is evident from visually inspecting a time point within the slice:
# Materialize the slice and take its global min/max so that the original and
# corrected images share one gray scale and are directly comparable.
arrk = np.array(slicek)
mn, mx = arrk.min(), arrk.max()
# Time index 20 of slice k before correction...
plt.imshow(slicek[..., 20], vmin=mn, vmax=mx)
plt.show()
# ...and after correction, with the same intensity limits.
plt.imshow(filtered[..., 20], vmin=mn, vmax=mx)
<matplotlib.image.AxesImage at 0x13d9f20a0>
However, differences can be detected:
# Voxelwise difference (original minus corrected) at time index 20.
plt.imshow((slicek - filtered)[..., 20])
<matplotlib.image.AxesImage at 0x13d99a6d0>
Note that `filtered` is a plain NumPy array, so a wrapping function will need to explicitly create a new Xarray image:
# Evaluating the spline returns a plain ndarray; the xarray dims/coords are lost.
type(filtered)
numpy.ndarray
Now let's write a function that performs slice-timing correction and allows users to choose a target offset other than 0 s:
def slice_timing_correct(
    bold: xr.DataArray, target_offset: float = 0, order: int = 3
) -> xr.DataArray:
    """Resample every slice of a BOLD image to a common within-volume time.

    Slice ``s`` is taken to have been sampled at ``time + SliceTiming[s]``.
    A B-spline of degree ``order`` is fitted through each slice's time series
    at those times and evaluated at ``time + target_offset``, so all slices of
    a volume then refer to the same instant. Target times outside a slice's
    sampled range (e.g. the first volume onset when ``SliceTiming[s] > 0``)
    are extrapolated by the spline.

    Parameters
    ----------
    bold : xr.DataArray
        Image with a ``time`` dimension and coordinate, and a ``SliceTiming``
        entry in ``attrs`` (one acquisition offset per slice, in the same
        units as ``time``). An optional BIDS ``SliceEncodingDirection``
        (``'i'``, ``'j'`` or ``'k'``, optionally suffixed with ``'-'``) names
        the slice axis; it defaults to ``'k'``.
    target_offset : float, optional
        Offset from each volume's time at which all slices are resampled.
        The default 0 corresponds to volume acquisition onset.
    order : int, optional
        Spline degree passed to ``make_interp_spline`` as ``k`` (default 3).

    Returns
    -------
    xr.DataArray
        Corrected image with the same dims and coordinates as ``bold``,
        except that ``time`` is shifted by ``target_offset``. ``attrs`` are
        copied from ``bold`` without ``SliceTiming``. Integer input is
        promoted to a floating dtype so interpolated values are not truncated.

    Raises
    ------
    KeyError
        If ``bold.attrs`` has no ``SliceTiming`` entry.
    ValueError
        If the number of ``SliceTiming`` entries differs from the number of
        slices along the slice-encoding axis.
    """
    # If absent, BIDS specifies that SliceEncodingDirection is k
    slice_dir = bold.attrs.get('SliceEncodingDirection', 'k')
    slice_axis = slice_dir[0]
    slice_timing = bold.attrs['SliceTiming']
    # A trailing '-' means the first SliceTiming entry belongs to the slice
    # with the largest index; reversing restores increasing-index order.
    if slice_dir[1:] == '-':
        slice_timing = slice_timing[::-1]
    n_slices = bold.sizes[slice_axis]
    if len(slice_timing) != n_slices:
        raise ValueError(
            f"SliceTiming has {len(slice_timing)} entries but axis "
            f"{slice_axis!r} has {n_slices} slices"
        )

    sample_times = np.asarray(bold['time'])
    target_times = sample_times + target_offset

    # Interpolated values are fractional: never write them into an integer array.
    out_dtype = np.result_type(bold.dtype, np.float32)
    # New output array, with the time coordinate moved to the target instants.
    # Plain values are used so the new coordinate does not carry bold's old index.
    output = xr.zeros_like(bold, dtype=out_dtype).assign_coords(time=target_times)
    # Copy so that dropping SliceTiming does not mutate bold.attrs.
    output.attrs = bold.attrs.copy()
    output.attrs.pop('SliceTiming')

    # One spline model per slice. Interpolating along the `time` axis in place
    # (rather than transposing) works whatever position `time` has in the dims.
    for idx, slice_time in enumerate(slice_timing):
        slice_data = bold[{slice_axis: idx}]
        spline = make_interp_spline(
            sample_times + slice_time,
            np.asarray(slice_data),
            k=order,
            axis=slice_data.get_axis_num('time'),
        )
        # BSpline evaluation keeps the interpolated axis at the same position,
        # so the result matches output's dim order for this slice.
        output[{slice_axis: idx}] = spline(target_times)
    return output
# Correct all slices to volume onset (default target_offset=0) and display.
stc = slice_timing_correct(bold)
stc
<xarray.DataArray 'sub-07_task-balloonanalogrisktask_bold' (i: 64, j: 64, k: 34, time: 246)> Size: 274MB dask.array<setitem, shape=(64, 64, 34, 246), dtype=float64, chunksize=(64, 64, 34, 246), chunktype=numpy.ndarray> Coordinates: * i (i) int64 512B 0 1 2 3 4 5 6 7 8 9 ... 55 56 57 58 59 60 61 62 63 * j (j) int64 512B 0 1 2 3 4 5 6 7 8 9 ... 55 56 57 58 59 60 61 62 63 * k (k) int64 272B 0 1 2 3 4 5 6 7 8 9 ... 25 26 27 28 29 30 31 32 33 * time (time) float64 2kB 0.0 2.0 4.0 6.0 8.0 ... 484.0 486.0 488.0 490.0 Attributes: xib-FrequencyEncodingDirection: i PhaseEncodingDirection: j SliceEncodingDirection: k RepetitionTime: 2.0 xib-affines: {'scanner': [[-3.0, 0.0, -0.0, 90.091979... EffectiveEchoSpacing: 0.000395 ImagingFrequency: 123.249608 SeriesDate: 18681110 SeriesNumber: 9 SeriesTime: 171509.015000 StudyID: 1 StudyTime: 160942.703000
# Corrected slice 17 at time index 20, the same view as the single-slice example above.
plt.imshow(stc.isel(k=17, time=20))
<matplotlib.image.AxesImage at 0x13da71a90>