PART II: Age regression from ventricle volume
This part of the coursework is about age regression from ventricle volume extracted from segmentations of brain MRI.
First, a brain segmentation method needs to be implemented that separates the brain into gray matter, white matter and CSF. The lateral ventricles are filled with CSF, so they form a subset of the CSF segmentation.
In order to isolate the ventricles from the rest of the CSF, an atlas image with a lateral ventricle mask is provided. However, the original brain images are not aligned with the atlas, so a registration method needs to be implemented to align each brain image with the atlas.
Once the images are registered, the lateral ventricle mask can be used to mask out the part of the brain segmentation that contains the ventricles. This allows the ventricle volume to be calculated in millilitres. A model for age regression from ventricle volume can then be trained and used to make predictions for testing data.
You will be provided with registered and segmented images, so you can work on the regression tasks without having implemented registration and segmentation. Both functions check whether registered images and segmentations are already available. To test your own implementations, rename the provided ‘reg’ and ‘seg’ data folders.
Read the descriptions and code carefully and look out for the cells marked with ‘TASK’.
The following cell contains helper code for obtaining the image filenames and for reading the age of each subject from a spreadsheet.
In [ ]:
import os
import re
import numpy
import xlrd
import SimpleITK as sitk

# Retrieve the list of patients
data_dir = './data/t1-images'
imageNames = sorted(next(os.walk(data_dir))[2])  # Retrieve all the image names

# Read the spreadsheet to retrieve the age information for each subject
ages = []
csvfilename = './data/meta/IXI.xls'
workbook = xlrd.open_workbook(csvfilename)
sheet = workbook.sheet_by_index(0)
# Column 0 holds the subject id, column 11 the age; row 0 is the header
idCells = sheet.col_slice(colx=0, start_rowx=1, end_rowx=None)
ageCells = sheet.col_slice(colx=11, start_rowx=1, end_rowx=None)
idAgeDic = dict((ii.value, ageCells[loopId].value) for loopId, ii in enumerate(idCells))
The next cell contains helper functions for image registration, such as image visualisation and plotting of the similarity metric during optimisation. Using them in your implementation is optional.
In [ ]:
import matplotlib.pyplot as plt
%matplotlib inline
from ipywidgets import interact, fixed
from IPython.display import clear_output

# Callback invoked by the interact IPython method for scrolling through the image stacks of
# the two images (moving and fixed)
def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):
    # Create a figure with two subplots and the specified size
    plt.subplots(1, 2, figsize=(10, 8))

    # Draw the fixed image in the first subplot
    plt.subplot(1, 2, 1)
    plt.imshow(fixed_npa[fixed_image_z, :, :], cmap=plt.cm.Greys_r)
    plt.title('fixed image')
    plt.axis('off')

    # Draw the moving image in the second subplot
    plt.subplot(1, 2, 2)
    plt.imshow(moving_npa[moving_image_z, :, :], cmap=plt.cm.Greys_r)
    plt.title('moving image')
    plt.axis('off')

# Callback invoked by the IPython interact method for scrolling and modifying the alpha blending
# of an image stack of two images that occupy the same physical space
def display_images_with_alpha(image_z, alpha, fixed, moving):
    img = (1.0 - alpha) * fixed[:, :, image_z] + alpha * moving[:, :, image_z]
    plt.imshow(sitk.GetArrayFromImage(img), cmap=plt.cm.Greys_r)
    plt.axis('off')

# Callback invoked when the StartEvent happens, sets up our new data
def start_plot():
    global metric_values, multires_iterations
    metric_values = []
    multires_iterations = []

# Callback invoked when the EndEvent happens, does cleanup of data and figure
def end_plot():
    global metric_values, multires_iterations
    del metric_values
    del multires_iterations
    # Close figure, we don't want to get a duplicate of the plot later on
    plt.close()

# Callback invoked when the IterationEvent happens, updates our data and displays a new figure
def plot_values(registration_method):
    global metric_values, multires_iterations
    metric_values.append(registration_method.GetMetricValue())
    # Clear the output area (wait=True, to reduce flickering), and plot current data
    clear_output(wait=True)
    # Plot the similarity metric values
    plt.plot(metric_values, 'r')
    plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()

# Callback invoked when the sitkMultiResolutionIterationEvent happens, updates the index into the
# metric_values list
def update_multires_iterations():
    global metric_values, multires_iterations
    multires_iterations.append(len(metric_values))
TASK 2.1: Image to atlas registration
In the next cell you are asked to implement an intensity-based image registration method that can later be used to align brain images with an atlas image.
The function below takes the filenames of the moving and reference images as arguments, saves the registered image in a specified folder and returns its filename.
In [ ]:
def rigid_registration(movingImageName, referenceImageName, regSaveDir='./data/reg/'):
    # Check if the registration directory exists
    if not os.path.exists(regSaveDir):
        os.makedirs(regSaveDir)

    # Registration name
    registered_image_name = regSaveDir + movingImageName.split('/')[-1]

    # Check if registered image is already saved to disk
    if not os.path.isfile(registered_image_name):

        # Load the images
        fixed_image = sitk.ReadImage(referenceImageName, sitk.sitkFloat32)
        moving_image = sitk.ReadImage(movingImageName, sitk.sitkFloat32)

        registration_method = sitk.ImageRegistrationMethod()

        # ADD CODE HERE

        # Connect all of the observers so that we can perform plotting during registration
        registration_method.AddCommand(sitk.sitkStartEvent, start_plot)
        registration_method.AddCommand(sitk.sitkEndEvent, end_plot)
        registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent, update_multires_iterations)
        registration_method.AddCommand(sitk.sitkIterationEvent, lambda: plot_values(registration_method))

        final_transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
                                                      sitk.Cast(moving_image, sitk.sitkFloat32))

        # Resample the moving image after the registration
        moving_resampled = sitk.Resample(moving_image, fixed_image, final_transform,
                                         sitk.sitkLinear, 0.0, moving_image.GetPixelIDValue())

        # Write the updated image
        sitk.WriteImage(moving_resampled, registered_image_name)

    return registered_image_name
TASK 2.2: Ventricle segmentation and volume calculation
In the next cell you are asked to implement a function that segments a brain image into gray matter, white matter, and CSF and then uses a lateral ventricle mask to extract ventricle volume in millilitres.
For brain segmentation, check out http://goo.gl/W6EO5u.
For inspiration on how to extract ventricle volume, check out https://sites.google.com/site/mrilateralventricle/. Note that you are not asked to replicate this approach, but you can extract ventricle volume in a similar way.
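Once the segmentation and the LV mask share the same voxel grid, the volume is a voxel count scaled by the voxel size. It equals the number of CSF voxels inside the mask, times the product of the voxel spacings in mm³, divided by 1000 (1 ml = 1000 mm³). The minimal sketch below works on NumPy arrays, such as those returned by `sitk.GetArrayFromImage`. It assumes CSF is encoded with label 1, which should be checked against your own segmentation output. `csf_volume_in_mask_ml` is a hypothetical name.

```python
import numpy

def csf_volume_in_mask_ml(label_array, mask_array, spacing_mm, csf_label=1):
    """Volume in ml of the voxels labelled csf_label that lie inside the mask.

    label_array and mask_array must be on the same voxel grid; spacing_mm is
    the voxel size along each axis in mm. csf_label=1 is an assumption.
    """
    csf_in_mask = (numpy.asarray(label_array) == csf_label) & (numpy.asarray(mask_array) > 0)
    voxel_volume_mm3 = float(numpy.prod(spacing_mm))
    # 1 ml = 1 cm^3 = 1000 mm^3
    return numpy.count_nonzero(csf_in_mask) * voxel_volume_mm3 / 1000.0
```

For example, 100 CSF voxels inside the mask at 2 mm isotropic spacing give 100 × 8 / 1000 = 0.8 ml. `sitk.GetArrayFromImage` returns axes in (z, y, x) order while `GetSpacing()` is (x, y, z), but the product of the spacings is unaffected by the order.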
In [ ]:
from nipy import load_image, save_image
from nipy.core.image.image_spaces import (make_xyz_image, xyz_affine)
from nipy.algorithms.segmentation import BrainT1Segmentation

def ventricle_volume(inputImageName, LVmaskImagename, segSaveDir='./data/seg/'):
    # Check if the segmentation directory exists
    if not os.path.exists(segSaveDir):
        os.makedirs(segSaveDir)

    # Segmentation filename
    fullSegmentationName = segSaveDir + inputImageName.split('/')[-1]

    # Check if segmentation is already saved to disk
    if not os.path.isfile(fullSegmentationName):

        # Read the image
        img_sitk = sitk.ReadImage(inputImageName)

        # Create image mask
        img = load_image(inputImageName)
        mask = img.get_data() > 0  # ignoring background pixels

        # Perform brain segmentation
        # ADD CODE HERE
        # HINT: ensure the resulting segmentation has the correct orientation and flipping
        # HINT: you can make use of make_xyz_image and xyz_affine to re-orient and numpy.swapaxes to flip arrays

        # Save the segmentation to disk to avoid recomputation
        sitk.WriteImage(labelimage, fullSegmentationName)

    # Compute ventricle volume in millilitres

    # Load the segmentation
    labelimage = sitk.ReadImage(fullSegmentationName)

    # Load the LV mask
    LVmaskimage = sitk.ReadImage(LVmaskImagename)

    # ADD CODE HERE
    # HINT: Extract the CSF mask from the labelimage
    # HINT: The LV mask has a different resolution, check out sitk.ResampleImageFilter

    return vol
The next cell performs registration, segmentation and ventricle volume computation on all images.
In [ ]:
LVmaskFilename = './data/atlas/LV_mask.nii.gz'
atlasFilename = './data/atlas/atlas.nii.gz'
volumes = []
labels = []

for idx, filename in enumerate(imageNames):
    if (idx + 1) % 50 == 0:
        print('Processed subjects {0} of {1}'.format(idx + 1, len(imageNames)))

    # Retrieve the subject id and its age
    regexp_result = re.search(r'mIXI\d+', filename)
    subjectId = int(regexp_result.group().split('mIXI')[1])
    labels.append(idAgeDic[subjectId])

    # Registration to atlas space
    fullFilename = data_dir + '/' + filename
    fullFilename_registered = rigid_registration(fullFilename, atlasFilename)

    # Segmentation and ventricle volume computation
    vol = ventricle_volume(fullFilename_registered, LVmaskFilename)
    volumes.append(vol)

volumes = numpy.array(volumes, dtype=numpy.float32).reshape(-1, 1)
labels = numpy.array(labels, dtype=numpy.float32)
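The loop above maps each filename to a spreadsheet row through the number that follows the `mIXI` prefix. The sketch below isolates that step. The example filename and the age in the lookup table are made up for illustration. The lookup works because xlrd returns numeric cells as floats: an integer key such as `2` still finds the float key `2.0` in a dict, since `2 == 2.0` and both hash equally. `subject_id_from_filename` is a hypothetical name.

```python
import re

def subject_id_from_filename(filename):
    """Extract the integer IXI subject id from names like 'mIXI002-...'."""
    match = re.search(r'mIXI(\d+)', filename)
    if match is None:
        raise ValueError('no mIXI<digits> prefix in {0}'.format(filename))
    return int(match.group(1))  # '002' -> 2

# Hypothetical lookup table: xlrd would store the id cell as the float 2.0.
id_age = {2.0: 35.8}
print(id_age[subject_id_from_filename('mIXI002-Guys-0828-T1.nii.gz')])  # 35.8
```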
The cell below plots the ventricle volume vs age.
In [ ]:
# Plot data
import matplotlib.pyplot as plt
%matplotlib inline
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(volumes, labels)
ax.grid()
plt.ylabel("Age")
plt.xlabel("Ventricle volume (ml)")
plt.show()
TASK 2.3: Age regression
In the next cell you are asked to implement (or copy & paste) two functions for training and applying a model for age regression, and another function for evaluating prediction quality.
In [ ]:
def trainRegressor(data, labels):
    # ADD CODE HERE
    return model

def applyRegressor(data, model):
    # ADD CODE HERE
    return labels

def evaluate(labels_true, labels_predicted, plot=False):
    if plot:
        plt.figure(figsize=(6, 6))
        plt.scatter(labels_true, labels_predicted)
        plt.plot([0, 100], [0, 100], '--k', linewidth=3)
        plt.axis('tight')
        plt.xlabel('True age', fontsize=15)
        plt.ylabel('Predicted age', fontsize=15)
        plt.tick_params(axis='both', which='major', labelsize=15)
        plt.grid(True)
        plt.show()

    # Age prediction errors
    prediction_errors = labels_true - labels_predicted

    # Mean absolute error
    mean_error = numpy.mean(numpy.abs(prediction_errors))
    print('Mean absolute error is {0}'.format(mean_error))

    # Root mean squared error
    root_mean_squared_error = numpy.sqrt(numpy.mean(numpy.power(prediction_errors, 2)))
    print('Root mean squared error is {0}'.format(root_mean_squared_error))

    return prediction_errors
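Below is a minimal sketch of `trainRegressor` and `applyRegressor` using scikit-learn. It fits a polynomial in ventricle volume. The default degree of 2 is an assumption and not something the coursework prescribes, so it is exposed as a parameter to compare against a plain linear fit (`degree=1`) on the plot above.

```python
import numpy
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

def trainRegressor(data, labels, degree=2):
    # data: ventricle volumes, shape (n_subjects, 1); labels: ages, shape (n_subjects,)
    # Polynomial expansion of the volume followed by ordinary least squares.
    model = make_pipeline(PolynomialFeatures(degree=degree), LinearRegression())
    model.fit(numpy.asarray(data).reshape(-1, 1), labels)
    return model

def applyRegressor(data, model):
    # Returns one predicted age per row of data.
    return model.predict(numpy.asarray(data).reshape(-1, 1))
```

Because `volumes` is already shaped `(-1, 1)`, the `reshape` calls are no-ops there. They only make the functions tolerant of 1D input.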
The next cell prepares the data for a very simple experiment in which the subjects are split into two halves, one for training and one for testing. The split alternates between subjects: even indices go to training and odd indices to testing.
In [ ]:
# Split data half/half into training and testing
trainingVolumes = volumes[0::2]
trainingLabels = labels[0::2]
testingVolumes = volumes[1::2]
testingLabels = labels[1::2]
print('Number of training images is {0}'.format(len(trainingVolumes)))
print('Number of testing images is {0}'.format(len(testingVolumes)))
TASK 2.4: Simple experiment
In the next cell you are asked to set up and execute a simple experiment using the above training and testing images. You need four steps: 1) train a regressor, 2) apply the regressor to the training data and visualise the model fit, 3) apply the regressor to the testing data, and 4) evaluate the prediction quality.
In [ ]:
# 1) Train a model
# ADD CODE HERE
# 2) Apply model on training data (and an artificial data range) and visualise
# ADD CODE HERE
# 3) Apply model on testing data
# ADD CODE HERE
# 4) Evaluate predictions
# ADD CODE HERE
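The four steps can be sketched as follows. To keep the example self-contained, it uses scikit-learn's `LinearRegression` directly instead of `trainRegressor`/`applyRegressor`. It also computes the mean absolute error itself instead of calling `evaluate`. `run_simple_experiment` is a hypothetical name. The figure is drawn but not shown, so the caller can add `plt.show()`.

```python
import numpy
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

def run_simple_experiment(train_x, train_y, test_x, test_y, n_grid=100):
    # 1) Train a model on the training volumes (n, 1) and ages (n,).
    model = LinearRegression().fit(train_x, train_y)

    # 2) Apply it to the training data and to an artificial, evenly spaced
    #    range of volumes; draw the fitted curve over the training scatter.
    train_pred = model.predict(train_x)
    grid = numpy.linspace(train_x.min(), train_x.max(), n_grid).reshape(-1, 1)
    plt.figure()
    plt.scatter(train_x.ravel(), train_y, label='training data')
    plt.plot(grid.ravel(), model.predict(grid), 'r', label='model fit')
    plt.xlabel('Ventricle volume (ml)')
    plt.ylabel('Age')
    plt.legend()

    # 3) Apply the model to the held-out testing data.
    test_pred = model.predict(test_x)

    # 4) Evaluate: mean absolute error on training and testing data.
    train_mae = numpy.mean(numpy.abs(train_y - train_pred))
    test_mae = numpy.mean(numpy.abs(test_y - test_pred))
    return test_pred, train_mae, test_mae
```

With the arrays prepared above, this would be called as `run_simple_experiment(trainingVolumes, trainingLabels, testingVolumes, testingLabels)`.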
TASK 2.5: Cross validation using k-folds
In the next cell you are asked to implement a k-fold cross validation such that every subject is used once for testing and prediction errors can be computed for all subjects.
In [ ]:
from sklearn.model_selection import KFold

def kfold_cross_validation(n_folds, vols, lbls):
    kf = KFold(n_splits=n_folds)
    predictions = numpy.array([])

    for foldId, (trainIds, testIds) in enumerate(kf.split(range(0, len(vols)))):
        print('Fold: {0}/{1}'.format(foldId + 1, n_folds))

        # ADD CODE HERE

        predictions = numpy.concatenate((predictions, testingLabels_predicted))

    return predictions
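The sketch below fills in the cross-validation loop body. Instead of concatenating, it writes each fold's predictions back at that fold's test indices. With the unshuffled `KFold` used above both approaches give the same order, because the test folds are contiguous and come in sequence. Index assignment, however, stays correct if `shuffle=True` is switched on. `LinearRegression` stands in for `trainRegressor`/`applyRegressor`, and `kfold_predictions` is a hypothetical name.

```python
import numpy
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

def kfold_predictions(n_folds, vols, lbls, shuffle=False, seed=0):
    vols = numpy.asarray(vols).reshape(-1, 1)
    lbls = numpy.asarray(lbls)
    # random_state may only be set when shuffling.
    kf = KFold(n_splits=n_folds, shuffle=shuffle, random_state=seed if shuffle else None)
    predictions = numpy.zeros(len(lbls))
    for train_ids, test_ids in kf.split(vols):
        # Fit only on the training folds ...
        model = LinearRegression().fit(vols[train_ids], lbls[train_ids])
        # ... and predict the subjects held out in this fold, in place.
        predictions[test_ids] = model.predict(vols[test_ids])
    return predictions
```

Every subject receives exactly one prediction, made by a model that never saw it during training. The result can therefore be passed straight to `evaluate(labels, predictions, True)`.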
The following cell runs a 2-fold cross validation and computes errors for all subjects.
In [ ]:
predictions = kfold_cross_validation(2, volumes, labels)
errors = evaluate(labels, predictions, True)