# Source code for mr_utils.cs.thresholding.normalized_iht

'''Normalized iterative hard thresholding algorithm.'''

import logging

import numpy as np
from skimage.measure import compare_mse

from mr_utils.utils.printtable import Table

# Configure the root logger when this module is imported: records are
# formatted as 'LEVEL: message' and the threshold is set to DEBUG.
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.DEBUG)

def _hard_threshold(v, k):
    '''Return a copy of v with all but its k largest-magnitude entries zeroed.'''
    out = np.array(v, copy=True)
    # argsort is ascending, so every index except the last k is cleared.
    # Working on indices (rather than a magnitude cut-off) keeps at most k
    # entries even when several magnitudes tie.
    out[np.argsort(np.abs(out))[:-k]] = 0
    return out


def _step_bound(A, d, c):
    '''Return (1 - c)*||d||^2/||A d||^2, or +inf when A d vanishes.'''
    denom = np.linalg.norm(np.dot(A, d))**2
    if denom == 0:
        # No finite bound: any step size is acceptable for this move.
        return np.inf
    return (1 - c)*np.linalg.norm(d)**2/denom


def nIHT(A, y, k, c=0.1, kappa=None, x=None, maxiter=200, tol=1e-8,
         disp=False):
    '''Normalized iterative hard thresholding.

    Parameters
    ==========
    A : array_like
        Measurement matrix
    y : array_like
        Measurements (i.e., y = Ax)
    k : int
        Number of nonzero coefficients preserved after thresholding.
        Must satisfy 1 <= k <= A.shape[1].
    c : float, optional
        Small, fixed constant in (0, 1).  Tunable.
    kappa : float, optional
        Constant, > 1/(1 - c).  Defaults to 2/(1 - c).
    x : array_like, optional
        True signal we want to estimate.  Only used for the MSE column
        shown when disp is True.
    maxiter : int, optional
        Maximum number of iterations (of the outer loop).
    tol : float, optional
        Stopping criteria: iteration stops once ||r||/||y|| < tol.
    disp : bool, optional
        Whether or not to display iteration info.

    Returns
    =======
    x_hat : array_like
        Estimate of x.

    Raises
    ======
    AssertionError
        If c is outside (0, 1) or kappa is not > 1/(1 - c).
    ValueError
        If k is not in [1, A.shape[1]].

    Notes
    =====
    Implements Algorithm 8.6 from [1]_.

    If y is identically zero the zero vector is returned immediately, and
    iteration stops early if the gradient vanishes; both cases would
    otherwise produce a 0/0 step size.

    References
    ==========
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
           theory and applications. Cambridge University Press, 2012.
    '''

    # Basic checks
    assert 0 < c < 1, 'c must be in (0,1)'

    # length of measurement vector and original signal
    _n, N = A.shape

    # k = 0 would make the slice [:-k] empty and silently disable the
    # thresholding; k > N cannot be satisfied at all.
    if not 1 <= k <= N:
        raise ValueError('k must be in [1, %d], got %r' % (N, k))

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    if disp:
        table = Table(
            ['iter', 'norm', 'MSE'],
            [len(repr(maxiter)), 8, 8],
            ['d', 'e', 'e'])
        hdr = table.header()
        for line in hdr.split('\n'):
            logging.info(line)

    # Find suitable kappa if the user didn't give us one
    if kappa is None:
        # kappa must be > 1/(1 - c), so try 2 times the lower bound
        kappa = 2/(1 - c)
    else:
        assert kappa > 1/(1 - c), 'kappa must be > 1/(1 - c)'

    # Initializations
    x_hat = np.zeros(N)

    # y = 0 is reproduced exactly by x = 0, and the relative residual
    # ||r||/||y|| is undefined, so there is nothing to iterate on.
    norm_y = np.linalg.norm(y)
    if norm_y == 0:
        return x_hat

    # Initial calculation of support
    # NOTE(review): T is only ever this initial support and the step size
    # below uses the full gradient; the algorithm as published appears to
    # restrict the gradient to the current support and update it each
    # iteration.  Both are kept as in the original implementation --
    # confirm against the reference before changing either in isolation.
    T = np.nonzero(_hard_threshold(A.T.dot(y), k))

    # Do the iterative part of the thresholding...
    hit_maxiter = True
    stop_criteria = None
    for ii in range(int(maxiter)):

        # Compute residual
        r = y - np.dot(A, x_hat)

        # Check stopping criteria
        stop_criteria = np.linalg.norm(r)/norm_y
        if stop_criteria < tol:
            hit_maxiter = False
            break

        # Let's check out what's going on
        if disp:
            logging.info(table.row(
                [ii, stop_criteria, compare_mse(x, x_hat)]))

        # Compute step size
        g = np.dot(A.T, r)
        Ag_sq = np.linalg.norm(np.dot(A, g))**2
        if Ag_sq == 0:
            # A g = 0 only happens when g = 0: no descent direction is
            # left, so stop instead of dividing 0 by 0.
            hit_maxiter = False
            break
        mu = np.linalg.norm(g)**2/Ag_sq

        # Hard thresholding
        xn = _hard_threshold(x_hat + mu*g, k)

        # Compute support of xn
        Tn = np.nonzero(xn)

        # Decide what to do: if the support moved, shrink the step until
        # it satisfies mu <= (1 - c)*||xn - x_hat||^2/||A(xn - x_hat)||^2
        if not np.array_equal(Tn, T):
            cond = _step_bound(A, xn - x_hat, c)
            while mu > cond:
                mu /= kappa*(1 - c)
                xn = _hard_threshold(x_hat + mu*g, k)
                cond = _step_bound(A, xn - x_hat, c)
        x_hat = xn

    # Regroup and debrief...
    if hit_maxiter:
        # Only reached when the loop ran out of iterations without a break
        logging.warning(
            'Hit maximum iteration count, estimate may not be accurate!')
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g', stop_criteria)

    return x_hat
# Script entry guard: running this module directly does nothing.
if __name__ == '__main__': pass