# Source code for mr_utils.cs.thresholding.iterative_hard_thresholding

'''Simple iterative hard thresholding algorithm.'''

import logging
import numpy as np

# import matplotlib.pyplot as plt

from mr_utils.utils.printtable import Table

# Configure the root logger when this module is imported so that IHT's
# iteration table and warnings are emitted as "LEVEL: message" lines.
# NOTE(review): configuring logging at import time from a library module
# affects every importer; consider moving this under the __main__ guard.
logging.basicConfig(level=logging.DEBUG,
                    format='%(levelname)s: %(message)s')

def _hard_threshold(x, k):
    '''Keep the `k` largest-magnitude entries of `x`, zeroing the rest.

    Parameters
    ----------
    x : numpy.ndarray
        1-D array; modified in place.
    k : int
        Number of entries to keep. ``k <= 0`` zeroes every entry and
        ``k >= x.size`` leaves `x` untouched.

    Returns
    -------
    x : numpy.ndarray
        The same array object, for convenience.
    '''
    # Use an explicit drop count instead of slicing with `[:-k]`: for k == 0
    # `[:-0]` is an empty slice, which would silently keep *every* entry.
    n_drop = max(x.size - int(k), 0)
    if n_drop:
        x[np.argsort(np.abs(x))[:n_drop]] = 0
    return x


def IHT(A, y, k, mu=1, maxiter=500, tol=1e-8, x=None, disp=False):
    r'''Iterative hard thresholding algorithm (IHT).

    Parameters
    ----------
    A : array_like
        Measurement matrix, shape (m, N).
    y : array_like
        Measurements (i.e., y = Ax), length m.
    k : int
        Number of expected nonzero coefficients.
    mu : float, optional
        Step size.
    maxiter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Stopping criteria: iteration stops once
        ``||y - A x_hat||_2 / ||y||_2 < tol`` for the current estimate.
    x : array_like, optional
        True signal we are trying to estimate.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    -------
    x_hat : array_like
        Estimate of x, length N. Its dtype is the common floating/complex
        type of `A` and `y` (float64 if both are integer-valued).

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 \text{ s.t. } ||x||_0 \leq k

    If `disp=True`, then MSE will be calculated using provided x (x=0 is
    used, with a warning, when no x is given). When `disp=False`, a tqdm
    progress bar is shown if tqdm is installed.

    `mu=1` seems to satisfy Theorem 8.4 often, but might need to be
    adjusted (usually < 1). See normalized IHT for adaptive step size.

    ``k <= 0`` yields the zero vector; ``k >= N`` disables thresholding.
    If `y` is identically zero, the zero vector is returned immediately.

    Implements Algorithm 8.5 from [1]_.

    References
    ----------
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
           theory and applications. Cambridge University Press, 2012.
    '''
    # Accept any array_like input, as documented.
    A = np.asarray(A)
    y = np.asarray(y)
    maxiter = int(maxiter)

    # Number of measurements and length of the original signal
    _m, N = A.shape

    # Make sure we have everything we need for disp
    if disp and x is None:
        logging.warning('No true x provided, using x=0 for MSE calc.')
        x = np.zeros(N)

    # Work in a floating/complex dtype wide enough for both A and y. Using
    # y.dtype alone makes `x_hat += mu*...` raise a casting error when y is
    # integer-valued and mu is fractional, or when A is complex but y is real.
    dtype = np.result_type(A, y)
    if not np.issubdtype(dtype, np.inexact):
        dtype = np.float64

    # Initial estimate x_hat = 0, so the initial residual is y itself
    x_hat = np.zeros(N, dtype=dtype)
    r = y.astype(dtype)

    y_norm = np.linalg.norm(y)
    if y_norm == 0:
        # x_hat = 0 already satisfies y = A x exactly; continuing would only
        # evaluate 0/0 = nan in the stopping test and never terminate early.
        return x_hat

    if disp:
        table = Table(
            ['iter', 'norm', 'MSE'],
            [len(repr(maxiter)), 8, 8],
            ['d', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)
        iterations = range(maxiter)
    else:
        # The progress bar is purely cosmetic: fall back to a plain range
        # rather than failing when tqdm is not installed.
        try:
            from tqdm import trange
        except ImportError:
            iterations = range(maxiter)
        else:
            iterations = trange(maxiter, leave=False, desc='IHT')

    # Run until tol reached or maxiter reached
    converged = False
    for tt in iterations:
        # Update estimate using residual scaled by step size
        x_hat += mu*np.dot(A.conj().T, r)

        # Leave only k coefficients nonzero (hard threshold)
        _hard_threshold(x_hat, k)

        # Update the residual *before* measuring it so that the stopping
        # test and the logged norm refer to the current estimate.
        r = y - np.dot(A, x_hat)
        stop_criteria = np.linalg.norm(r)/y_norm

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(table.row(
                [tt, stop_criteria, np.mean(np.abs(x - x_hat)**2)]))

        if stop_criteria < tol:
            converged = True
            break

    # Only warn when the tolerance was genuinely never reached (the old
    # `tt == maxiter-1` check also fired on convergence at the last step).
    if not converged:
        logging.warning(('Hit maximum iteration count, estimate '
                         'may not be accurate!'))
    elif disp:
        logging.info('Final || r || . || y ||^-1 : %g', stop_criteria)

    return x_hat
if __name__ == '__main__':
    # Intentionally empty: running this module directly performs no work;
    # it only defines IHT for import by other code.
    pass