# Source code for mr_utils.cs.greedy.cosamp

'''Compressive sampling matching pursuit (CoSaMP) algorithm.

This implementation currently does not handle complex signals.
'''

import logging

import numpy as np

from mr_utils.utils.printtable import Table

# Configure the root logger at import time: messages are prefixed with their
# level name and everything from DEBUG upward is let through.
logging.basicConfig(
    format='%(levelname)s: %(message)s',
    level=logging.DEBUG,
)

def cosamp(
        A, y, k, lstsq='exact', tol=1e-8, maxiter=500, x=None, disp=False):
    '''Compressive sampling matching pursuit (CoSaMP) algorithm.

    Parameters
    ==========
    A : array_like
        Measurement matrix (2-D; the number of columns is the length of
        the signal being estimated).
    y : array_like
        Measurements (i.e., y = Ax).  Non-floating inputs (e.g. integers)
        are promoted to float64 so least squares results are not
        truncated.
    k : int
        Number of expected nonzero coefficients.  Must be at least 1.
    lstsq : {'exact', 'lm', 'gd'}, optional
        How to solve intermediate least squares problem.  Only 'exact'
        is currently implemented.
    tol : float, optional
        Stopping criteria: iterations end once the relative residual
        norm, ||y - A x_hat||/||y||, drops below this value.
    maxiter : int, optional
        Maximum number of iterations.
    x : array_like, optional
        True signal we are trying to estimate.  Only used to report the
        MSE column when disp is True; if omitted, zeros are used.
    disp : bool, optional
        Whether or not to display iterations.

    Returns
    =======
    x_hat : array_like
        Estimate of x, with at most k nonzero entries.

    Raises
    ======
    ValueError
        If A is not 2-dimensional or if k is smaller than 1.
    NotImplementedError
        If lstsq is anything other than 'exact'.

    Notes
    =====
    lstsq function

    - 'exact' solves it using numpy's linalg.lstsq method.
    - 'lm' (Levenberg-Marquardt) is not implemented yet.
    - 'gd' (a few iterations of gradient descent) is not implemented
      yet.

    If y is identically zero the all-zero estimate is returned
    immediately, since the relative residual is undefined in that case.

    Implements Algorithm 8.7 from [1]_.

    References
    ==========
    .. [1] Eldar, Yonina C., and Gitta Kutyniok, eds. Compressed sensing:
           theory and applications. Cambridge University Press, 2012.
    '''

    # Accept anything array_like, as documented
    A = np.asarray(A)
    y = np.asarray(y)
    if A.ndim != 2:
        raise ValueError('A must be a 2-dimensional measurement matrix.')

    # With k < 1 the slices below would select every index and prune
    # nothing, silently returning a dense least squares solution.
    if k < 1:
        raise ValueError('k must be a positive integer.')

    # Length of the original signal
    N = A.shape[1]

    # Work in a floating (or complex) dtype; an integer y would otherwise
    # make the estimate an integer array and truncate every coefficient.
    if not np.issubdtype(y.dtype, np.inexact):
        y = y.astype(np.float64)
    dtype = y.dtype

    # Initializations
    x_hat = np.zeros(N, dtype=dtype)
    r = y
    ynorm = np.linalg.norm(y)

    if x is None:
        x = np.zeros(x_hat.shape, dtype=dtype)
    else:
        x = np.asarray(x)
        if x.size < x_hat.size:
            # NOTE(review): pads a single leading zero, so this only lines
            # up when x is exactly one sample short -- confirm with callers.
            x = np.hstack(([0], x))

    # Decide how we want to solve the intermediate least squares problem
    if lstsq == 'exact':
        def lstsq_fun(A0, b):
            return np.linalg.lstsq(A0, b, rcond=None)[0]
    elif lstsq in ('lm', 'gd'):
        # Earlier experiments with these solvers did not work very well,
        # so they are intentionally left unimplemented.
        raise NotImplementedError(
            "lstsq='%s' is not implemented; use 'exact'." % lstsq)
    else:
        raise NotImplementedError(
            'Unknown lstsq option %r; use \'exact\'.' % (lstsq,))

    # Zero measurements: the estimate is zero and the relative residual
    # would be 0/0 (nan), which never satisfies the stopping test.
    if ynorm == 0:
        return x_hat

    # Start up a table
    if disp:
        table = Table(
            ['iter', 'norm', 'MSE'],
            [len(repr(maxiter)), 8, 8],
            ['d', 'e', 'e'])
        hdr = table.header()
        for line in hdr.split('\n'):
            logging.info(line)

    for ii in range(maxiter):

        # Get step direction
        g = np.dot(A.conj().T, r)

        # Add 2*k largest elements of g to support set
        candidates = np.argsort(np.abs(g))[-(2*k):]
        Tn = np.union1d(x_hat.nonzero()[0], candidates)

        # Solve the least squares problem restricted to the support, then
        # keep only the k largest-magnitude coefficients
        xn = np.zeros(N, dtype=dtype)
        xn[Tn] = lstsq_fun(A[:, Tn], y)
        xn[np.argsort(np.abs(xn))[:-k]] = 0
        x_hat = xn

        # Compute new residual
        r = y - np.dot(A, x_hat)

        # Compute stopping criteria
        stop_criteria = np.linalg.norm(r)/ynorm

        # Show MSE at current iteration if we wanted it
        if disp:
            logging.info(
                table.row(
                    [ii, stop_criteria, np.mean((np.abs(x - x_hat)**2))]))

        # Check stopping criteria
        if stop_criteria < tol:
            break

    return x_hat