# Source code for mr_utils.cs.convex.gd_fourier_encoded_tv

'''Gradient descent algorithm for Fourier encoding model and TV constraint.'''

import logging

import numpy as np

from mr_utils.utils import dTV

# NOTE(review): this call runs at import time and requests DEBUG-level output
# with a "LEVEL: message" format on the root logger. Importing this module
# therefore touches process-wide logging configuration -- confirm that is
# intended for library use.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.DEBUG)

def _mse(x, y):
    '''Mean squared error between two real-valued arrays, as a float64.'''
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.mean(np.square(diff), dtype=np.float64)


def GD_FE_TV(
        kspace,
        samp,
        alpha=.5,
        lam=.01,
        do_reordering=False,
        im_true=None,
        ignore_residual=False,
        disp=False,
        maxiter=200):
    r'''Gradient descent for Fourier encoding model and TV constraint.

    Parameters
    ==========
    kspace : array_like
        Measured (undersampled) k-space data.
    samp : array_like
        Sampling mask, multiplied element-wise with the 2D FFT of the
        image estimate.
    alpha : float, optional
        Step size.
    lam : float, optional
        TV constraint weight.
    do_reordering : bool, optional
        Whether or not to reorder for sparsity constraint.  The
        reordering is computed from `im_true`, so `im_true` must be
        given when this is True.
    im_true : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to break out of loop if resid increases.
    disp : bool, optional
        Whether or not to display iteration info.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    =======
    m_hat : array_like
        Estimate of im_true (complex-valued, same shape as `kspace`).

    Raises
    ======
    ValueError
        If `do_reordering` is True but `im_true` is None.

    Notes
    =====
    Solves the problem:

    .. math::

        \min_I || d - S \odot \text{FT}(I) ||^2_2 + \lambda \text{TV}(I)

    where d is measured k-space, I is the image estimate, S is the
    undersampling mask, and TV is the total variation operator.

    If `im_true=None`, then MSE will not be calculated (it is reported
    as 0 in the iteration table).
    '''

    # The reordering is derived from the true image; fail early with a
    # clear message instead of an obscure error inside sort2d.
    if do_reordering and im_true is None:
        raise ValueError(
            'do_reordering=True requires im_true to be provided.')

    kspace = np.asarray(kspace)

    # MSE is only meaningful when a reference image is available
    have_truth = im_true is not None
    if have_truth:
        im_true = np.abs(im_true)
    else:
        logging.info('No true x provided, MSE will not be calculated.')

    # Get the reordering indicies ready
    if do_reordering:
        from mr_utils.utils.sort2d import sort2d
        from mr_utils.utils.orderings import inverse_permutation
        _, reordering = sort2d(im_true)
        inverse_reordering = inverse_permutation(reordering)

    # Get some display stuff happening
    if disp:
        from mr_utils.utils.printtable import Table
        table = Table(
            ['iter', 'norm', 'MSE'],
            [len(repr(maxiter)), 8, 8],
            ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)

    # Initialize.  The update below subtracts a complex quantity in place,
    # so the estimate must be complex even when kspace is real-valued.
    m_hat = np.zeros(
        kspace.shape, dtype=np.result_type(kspace.dtype, np.complex64))
    r = -kspace.copy()  # residual of the all-zero initial estimate
    prev_stop_criteria = np.inf
    norm_kspace = np.linalg.norm(kspace)

    for ii in range(int(maxiter)):

        # Fidelity term
        fidelity = np.fft.ifft2(r)

        # Sparsity term.  With reordering, the TV derivative is taken on
        # the permuted estimate and then mapped back to image order; the
        # estimate itself is left in image order throughout.
        if do_reordering:
            permuted = m_hat.flatten()[reordering].reshape(im_true.shape)
            second_term = dTV(permuted) \
                .flatten()[inverse_reordering].reshape(im_true.shape)
        else:
            second_term = dTV(m_hat)

        # Compute stop criteria: relative norm of the current residual
        stop_criteria = np.linalg.norm(r)/norm_kspace
        if not ignore_residual and stop_criteria > prev_stop_criteria:
            logging.warning(
                'Breaking out of loop after %d iterations. '
                'Norm of residual increased!', ii)
            break
        prev_stop_criteria = stop_criteria

        # Take the step
        m_hat -= alpha*(fidelity + lam*second_term)

        # Tell the user what happened
        if disp:
            mse = _mse(np.abs(m_hat), im_true) if have_truth else 0
            logging.info(table.row([ii, stop_criteria, mse]))

        # Compute residual
        r = np.fft.fft2(m_hat)*samp - kspace

    return m_hat