# Source code for mr_utils.coils.gs_comparison.gs_coil_combine_comparison

'''Functions to compare coil combination methods.

This actually might belong in the examples, and it needs to be checked to
make sure it still works.
'''

import numpy as np
from ismrmrdtools.simulation import generate_birdcage_sensitivities
from ismrmrdtools.coils import calculate_csm_walsh, calculate_csm_inati_iter
from skimage.measure import compare_nrmse

from mr_utils import view
from mr_utils.recon.ssfp import gs_recon
from mr_utils.utils import sos
from mr_utils.test_data.phantom import bssfp_2d_cylinder

def get_numerical_phantom_params(SNR=None):
    '''Build the preset parameter set for the numerical cylinder phantom.

    Parameters
    ----------
    SNR : float, optional
        Signal to noise ratio.  When given, the noise standard deviation
        is derived as ``avg_signal/SNR``, where ``avg_signal`` is the
        magnitude of the NaN-ignoring mean over the nonzero pixels of the
        reference image from `get_true_im_numerical_phantom`.  When
        ``None`` the phantom is noise free.

    Returns
    -------
    params : dictionary
        Parameter dictionary with the `noise_std`, `dim`, `pc_vals`, and
        `coil_nums` fields.
    '''

    # Noise free unless an SNR is requested
    noise_std = 0
    if SNR is not None:
        reference = get_true_im_numerical_phantom()
        signal = reference[np.nonzero(reference)]
        noise_std = np.abs(np.nanmean(signal))/SNR

    return {
        'noise_std': noise_std,
        'dim': 64,
        'pc_vals': [0, np.pi/2, np.pi, 3*np.pi/2],
        # Alternative coil sweep: [2, 4, 8, 16, 32]
        'coil_nums': [16],
    }
def get_true_im_numerical_phantom():
    '''Compute the reference bSSFP image of the simulated phantom.

    The geometric solution to the elliptical signal model still leaves
    some residual banding, so the reconstruction is repeated for several
    offsets of the phase-cycle set and the results are averaged to remove
    virtually all banding.  This keeps the contrast comparable to the
    banded phantoms.

    Returns
    -------
    true_im : array_like
        Banding free reference image with true bSSFP contrast.
    '''

    # Simulation settings (noise free)
    params = get_numerical_phantom_params(SNR=None)
    shape = (params['dim'], params['dim'])
    pc_vals = params['pc_vals']

    # Each offset shifts all four phase-cycles before the gs_recon
    offsets = [0, np.pi/6, np.pi/3, np.pi/4]

    true_im = np.zeros(shape, dtype='complex')
    for offset in offsets:
        banded = [
            bssfp_2d_cylinder(dims=shape, phase_cyc=(pc_vals[kk] + offset))
            for kk in range(4)]
        true_im += gs_recon(*banded)
    true_im /= len(offsets)

    # Add a copy of the image rotated by 90 degrees in the complex plane
    true_im += 1j*true_im
    return true_im
def get_coil_sensitivity_maps():
    '''Simulate birdcage coil sensitivity maps.

    One set of maps is generated for every coil count listed under
    `coil_nums` in the numerical phantom parameters.

    Returns
    -------
    csms : list
        List of coil sensitivity maps (arrays), one entry per coil count.
    '''

    params = get_numerical_phantom_params()
    size = params['dim']
    return [
        generate_birdcage_sensitivities(size, number_of_coils=count)
        for count in params['coil_nums']]
# Metric will be percent ripple
def ripple(im0):
    '''Calculate % ripple metric using local patches of a line.

    Parameters
    ----------
    im0 : array_like
        Image to calculate ripple of.  Must expose `shape` and support
        2D slicing (e.g., a numpy array).  It is not modified.

    Returns
    -------
    float
        Percent ripple, i.e. 100 times the mean over patches of
        ``(max - min)/mean`` of the patch magnitudes.  The patches are
        taken from the profile along axis 0 at the central index of
        axis 1.  NaN is returned when no patch could be evaluated.

    Notes
    -----
    Samples whose magnitude is not above one tenth of the profile's
    maximum magnitude are discarded before patches are formed.  Every
    patch is cut from the full thresholded profile; patches that run
    past its end are truncated, and empty ones are skipped.
    '''

    # We only want one line through the image
    profile = im0[:, int(im0.shape[1]/2)]
    magnitude = np.abs(profile)
    profile = profile[magnitude > np.max(magnitude)/10]

    # Choose a "patch" of the line over a distance assumed to be linear
    # and get the ripple for each patch.
    pad = 6
    center = int(profile.size/2)

    # NOTE(review): the patch count (profile.size mod pad) is kept from
    # the original implementation -- confirm this is the intended count.
    num_patches = np.mod(profile.size, pad)

    vals = []
    for ii in range(num_patches):
        # Clamp so a short profile cannot wrap around via negative index
        start = max(ii*pad + center - pad, 0)
        stop = ii*pad + center + pad
        patch = np.abs(profile[start:stop])
        if patch.size == 0:
            # Patch lies completely beyond the end of the profile
            continue
        vals.append((np.max(patch) - np.min(patch))/np.mean(patch))

    if not vals:
        return float('nan')
    return 100*np.mean(vals)
def ripple_normal(im):
    '''Calculate % ripple metric along the central line of an image.

    Parameters
    ----------
    im : array_like
        Image to calculate ripple of.

    Returns
    -------
    float
        Percent ripple.

    Notes
    -----
    The magnitude profile along axis 0 is taken at the central index of
    axis 1.  Only samples above one fifth of the profile's maximum are
    kept, and the ripple is ``100*(max - min)/mean`` of those samples.
    '''

    profile = np.abs(im[:, int(im.shape[1]/2)])

    # Keep the part of the profile that carries signal
    threshold = np.max(np.abs(profile))/5
    signal = profile[np.abs(profile) > threshold]

    spread = np.max(signal) - np.min(signal)
    fraction = spread/np.mean(signal)
    return 100*fraction
def comparison_knee(
        data_dir='/home/nicholas/Documents/rawdata/'
                 'SSFP_SPECTRA_dphiOffset_08022018'
):  # pylint: disable=R0914
    '''Coil by coil, Walsh method, and Inati iterative method for knee data.

    Loads the phase-cycled knee data set, then compares doing the
    gs_recon first and the coil combination second (Walsh, Inati, sum of
    squares) against combining the coils first (Walsh, Inati) and doing
    the gs_recon second.  The mean percent ripple of the two
    phase-cycle sets is printed for each approach and the
    Inati-then-recon images are displayed.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing `te3.npy`.  The default is the directory
        that was hard-coded when the file was generated.

    Notes
    -----
    `te3.npy` is expected to be laid out as (pc, x, y, coil) with eight
    phase-cycles.  Previously the path was built from the builtin `dir`
    because the assignment of the directory was commented out.
    '''

    # te3.npy was generated once from the raw data (averages collapsed):
    #
    # files = [
    #     'meas_MID362_TRUFI_STW_TE3_FID29379',
    #     'meas_MID363_TRUFI_STW_TE3_dphi_45_FID29380',
    #     'meas_MID364_TRUFI_STW_TE3_dphi_90_FID29381',
    #     'meas_MID365_TRUFI_STW_TE3_dphi_135_FID29382',
    #     'meas_MID366_TRUFI_STW_TE3_dphi_180_FID29383',
    #     'meas_MID367_TRUFI_STW_TE3_dphi_225_FID29384',
    #     'meas_MID368_TRUFI_STW_TE3_dphi_270_FID29385',
    #     'meas_MID369_TRUFI_STW_TE3_dphi_315_FID29386'
    # ]
    # pc_vals = [0, 45, 90, 135, 180, 225, 270, 315]
    # dims = (512, 256)
    # num_coils = 4
    # num_avgs = 16
    # pcs = np.zeros(
    #     (len(files), dims[0], dims[1], num_coils), dtype='complex')
    # for ii, file in enumerate(files):
    #     pcs[ii, ...] = np.mean(
    #         load_raw('%s/%s.dat' % (data_dir, file), use='s2i'), axis=-1)
    # np.save('%s/te3.npy' % data_dir, pcs)

    # pcs looks like (pc, x, y, coil)
    pcs = np.load('%s/te3.npy' % data_dir)
    pcs = np.fft.fftshift(np.fft.fft2(pcs, axes=(1, 2)), axes=(1, 2))

    # We have two sets: 0,90,180,270 and 45,135,225,315
    idx0 = [0, 2, 4, 6]
    idx1 = [1, 3, 5, 7]

    # Do recon then coil combine
    coils0 = np.zeros(
        (pcs.shape[-1], pcs.shape[1], pcs.shape[2]), dtype='complex')
    coils1 = coils0.copy()
    for ii in range(pcs.shape[-1]):
        coils0[ii, ...] = gs_recon(
            *[x.squeeze() for x in pcs[idx0, :, :, ii]])
        coils1[ii, ...] = gs_recon(
            *[x.squeeze() for x in pcs[idx1, :, :, ii]])

    # Then do the coil combine: Walsh
    csm_walsh, _ = calculate_csm_walsh(coils0)
    im_est_recon_then_walsh0 = np.sum(csm_walsh*np.conj(coils0), axis=0)
    csm_walsh, _ = calculate_csm_walsh(coils1)
    im_est_recon_then_walsh1 = np.sum(csm_walsh*np.conj(coils1), axis=0)

    rip0 = ripple(im_est_recon_then_walsh0)
    rip1 = ripple(im_est_recon_then_walsh1)
    print('recon then walsh: ', np.mean([rip0, rip1]))

    # Now try inati
    _csm_inati, im_est_recon_then_inati0 = calculate_csm_inati_iter(
        coils0, smoothing=5)
    _csm_inati, im_est_recon_then_inati1 = calculate_csm_inati_iter(
        coils1, smoothing=5)
    rip0 = ripple(im_est_recon_then_inati0)
    rip1 = ripple(im_est_recon_then_inati1)
    print('recon then inati: ', np.mean([rip0, rip1]))

    # Now try sos
    im_est_recon_then_sos0 = sos(coils0, axes=0)
    im_est_recon_then_sos1 = sos(coils1, axes=0)
    rip0 = ripple(im_est_recon_then_sos0)
    rip1 = ripple(im_est_recon_then_sos1)
    print('recon then sos: ', np.mean([rip0, rip1]))

    # Now the other way, combine then recon.
    # pcs0[0] holds the Walsh combined phase-cycles, pcs0[1] the Inati
    pcs0 = np.zeros(
        (2, pcs.shape[0], pcs.shape[1], pcs.shape[2]), dtype='complex')
    for ii in range(pcs.shape[0]):
        # Coil dimension goes first for the csm estimators
        coil_first = pcs[ii, ...].transpose((2, 0, 1))

        # Walsh it up
        csm_walsh, _ = calculate_csm_walsh(coil_first)
        pcs0[0, ii, ...] = np.sum(csm_walsh*np.conj(coil_first), axis=0)

        # Inati it up
        _csm_inati, pcs0[1, ii, ...] = calculate_csm_inati_iter(
            coil_first, smoothing=5)

    # Now perform gs_recon on each coil combined set
    # Walsh
    im_est_walsh_then_recon0 = gs_recon(
        *[x.squeeze() for x in pcs0[0, idx0, ...]])
    im_est_walsh_then_recon1 = gs_recon(
        *[x.squeeze() for x in pcs0[0, idx1, ...]])

    # Inati
    im_est_inati_then_recon0 = gs_recon(
        *[x.squeeze() for x in pcs0[1, idx0, ...]])
    im_est_inati_then_recon1 = gs_recon(
        *[x.squeeze() for x in pcs0[1, idx1, ...]])

    view(im_est_inati_then_recon0)
    view(im_est_inati_then_recon1)

    rip0 = ripple(im_est_walsh_then_recon0)
    rip1 = ripple(im_est_walsh_then_recon1)
    print('walsh then recon: ', np.mean([rip0, rip1]))

    rip0 = ripple(im_est_inati_then_recon0)
    rip1 = ripple(im_est_inati_then_recon1)
    print('inati then recon: ', np.mean([rip0, rip1]))
# pcs1[ii,...] = gs_recon(*[ x.squeeze() for x in pcs[idx1,...] ])
def comparison_numerical_phantom(SNR=None):  # pylint: disable=R0914,R0915
    '''Compare coil by coil, Walsh method, and Inati iterative method.

    For every set of simulated coil sensitivity maps, phase-cycled coil
    images of the numerical phantom are generated and five estimates of
    the banding free image are formed:

    0. gs_recon coil by coil, then sum of squares
    1. gs_recon coil by coil, then Walsh combination
    2. gs_recon coil by coil, then Inati combination
    3. Walsh combination of each phase-cycle, then gs_recon
    4. Inati combination of each phase-cycle, then gs_recon

    The mean error and mean percent ripple of each estimate are printed
    and the central row of estimates 1-4 from the last set of
    sensitivity maps is displayed.

    Parameters
    ----------
    SNR : float, optional
        Signal to noise ratio.  ``None`` simulates without noise.

    Returns
    -------
    err : array_like
        Normalized RMSE against the reference image, with shape
        (5, number of sensitivity map sets); rows are ordered as in the
        list above.

    Notes
    -----
    NaNs in the estimates are zeroed before the error and ripple are
    computed so they cannot propagate into the metrics.
    '''

    true_im = get_true_im_numerical_phantom()
    csms = get_coil_sensitivity_maps()
    params = get_numerical_phantom_params(SNR=SNR)
    pc_vals = params['pc_vals']
    dim = params['dim']
    noise_std = params['noise_std']
    lpcs = len(pc_vals)

    # We want to solve gs_recon for each coil we have in the pc set
    err = np.zeros((5, len(csms)))
    rip = err.copy()
    for ii, csm in enumerate(csms):

        # I have coil sensitivities, now I need images to apply them to.
        # coil_ims: (pc, coil, x, y)
        coil_ims = np.zeros(
            (lpcs, csm.shape[0], dim, dim), dtype='complex')
        for jj, pc in enumerate(pc_vals):
            im = bssfp_2d_cylinder(dims=(dim, dim), phase_cyc=pc)
            im += 1j*im
            coil_ims[jj, ...] = im*csm
            # Independent noise on the real and imaginary channels
            coil_ims[jj, ...] += np.random.normal(
                0, noise_std, coil_ims[jj, ...].shape)/2 + 1j*np.random.normal(
                    0, noise_std, coil_ims[jj, ...].shape)/2

        # Solve the gs_recon coil by coil
        coil_ims_gs = np.zeros((csm.shape[0], dim, dim), dtype='complex')
        for kk in range(csm.shape[0]):
            coil_ims_gs[kk, ...] = gs_recon(
                *[x.squeeze() for x in np.split(coil_ims[:, kk, ...], lpcs)])
        coil_ims_gs[np.isnan(coil_ims_gs)] = 0

        # Easy way out: combine all the coils using sos
        im_est_sos = sos(coil_ims_gs)

        # Take coil by coil solution and do Walsh on it to collapse the
        # coil dimension
        csm_walsh, _ = calculate_csm_walsh(coil_ims_gs)
        im_est_recon_then_walsh = np.sum(
            csm_walsh*np.conj(coil_ims_gs), axis=0)

        # Same, but with Inati
        _csm_inati, im_est_recon_then_inati = calculate_csm_inati_iter(
            coil_ims_gs)

        # Collapse the coil dimension of each phase-cycle using Walsh,
        # Inati
        pc_est_walsh = np.zeros((lpcs, dim, dim), dtype='complex')
        pc_est_inati = np.zeros((lpcs, dim, dim), dtype='complex')
        for jj in range(lpcs):
            # Walsh
            csm_walsh, _ = calculate_csm_walsh(coil_ims[jj, ...])
            pc_est_walsh[jj, ...] = np.sum(
                csm_walsh*np.conj(coil_ims[jj, ...]), axis=0)

            # Inati
            _csm_inati, pc_est_inati[jj, ...] = calculate_csm_inati_iter(
                coil_ims[jj, ...], smoothing=1)

        # Now solve the gs_recon using collapsed coils
        im_est_walsh = gs_recon(
            *[x.squeeze() for x in np.split(pc_est_walsh, lpcs)])
        im_est_inati = gs_recon(
            *[x.squeeze() for x in np.split(pc_est_inati, lpcs)])

        # Row order of err and rip
        estimates = [
            im_est_sos,
            im_est_recon_then_walsh,
            im_est_recon_then_inati,
            im_est_walsh,
            im_est_inati,
        ]

        # Zero the NaNs first so they cannot propagate into the metrics
        for est in estimates:
            est[np.isnan(est)] = 0

        # Compute error metrics; compare_nrmse takes the true image first
        # and normalizes by it
        for row, est in enumerate(estimates):
            err[row, ii] = compare_nrmse(true_im, est)
            rip[row, ii] = ripple_normal(est)

    print('SOS RMSE:', np.mean(err[0, :]))
    print('recon then walsh RMSE:', np.mean(err[1, :]))
    print('recon then inati RMSE:', np.mean(err[2, :]))
    print('walsh then recon RMSE:', np.mean(err[3, :]))
    print('inati then recon RMSE:', np.mean(err[4, :]))

    print('SOS ripple:', np.mean(rip[0, :]))
    print('recon then walsh ripple:', np.mean(rip[1, :]))
    print('recon then inati ripple:', np.mean(rip[2, :]))
    print('walsh then recon ripple:', np.mean(rip[3, :]))
    print('inati then recon ripple:', np.mean(rip[4, :]))

    # Show the central row of the estimates from the last csm set
    view(im_est_recon_then_walsh[int(dim/2), :])
    view(im_est_recon_then_inati[int(dim/2), :])
    view(im_est_walsh[int(dim/2), :])
    view(im_est_inati[int(dim/2), :])

    return err
# Nothing is executed when the module is run as a script.
if __name__ == '__main__':
    pass