Source code for mr_utils.cs.convex.proximal_gd

'''Proximal Gradient Descent.

Flexible encoding model, flexible sparsity model, and flexible
reordering model.  This is the one I would use out of all the ones
I've coded up.  Might be slower than the others as there's a little
more checking to do each iteration.
'''

import logging
import importlib

import numpy as np
from pywt import threshold

# from mr_utils.utils.orderings import inverse_permutation

# Configure the root logger as soon as this module is imported: every
# record at DEBUG level or above is emitted as "<LEVEL>: <message>".
logging.basicConfig(
    level=logging.DEBUG,
    format='%(levelname)s: %(message)s',
)

def proximal_GD(
        y,
        forward_fun,
        inverse_fun,
        sparsify,
        unsparsify,
        reorder_fun=None,
        mode='soft',
        alpha=.5,
        alpha_start=.5,
        thresh_sep=True,
        selective=None,
        x=None,
        ignore_residual=False,
        ignore_mse=True,
        ignore_ssim=True,
        disp=False,
        maxiter=200,
        strikes=0):
    r'''Proximal gradient descent for generic encoding/sparsity model.

    Parameters
    ----------
    y : array_like
        Measured data (i.e., y = Ax).
    forward_fun : callable
        A, the forward transformation function.
    inverse_fun : callable
        A^H, the inverse transformation function.
    sparsify : callable
        Sparsifying transform.
    unsparsify : callable
        Inverse sparsifying transform.
    reorder_fun : callable, optional
        Reordering function.  Called with the current gradient step; the
        real part of its result is used as the flat index ordering for
        the real channel and the imaginary part for the imaginary
        channel.  The reordering is undone after thresholding.
    mode : {'soft', 'hard', 'garotte', 'greater', 'less'}, optional
        Thresholding mode, passed through to pywt.threshold.
    alpha : float or callable, optional
        Step size, used for thresholding.  If callable, it is invoked
        as ``alpha(current_alpha, iteration)`` at the end of every
        iteration to produce the next step size.
    alpha_start : float, optional
        Initial alpha to start with if alpha is callable.
    thresh_sep : bool, optional
        Whether or not to threshold real/imag individually (each with
        half of the step size).
    selective : callable, optional
        Function returning indices of update to keep at each iter.
        Called as ``selective(x_hat, update, iteration)``.
    x : array_like, optional
        The true image we are trying to reconstruct.
    ignore_residual : bool, optional
        Whether or not to break out of loop if resid increases.
    ignore_mse : bool, optional
        Whether or not to break out of loop if MSE increases.
    ignore_ssim : bool, optional
        Whether or not to break out of loop if SSIM increases.
    disp : bool, optional
        Whether or not to display iteration info.  When False a tqdm
        progress bar is shown instead (if tqdm can be imported).
    maxiter : int, optional
        Maximum number of iterations.
    strikes : int, optional
        Number of ending conditions tolerated before giving up.

    Returns
    -------
    x_hat : array_like
        Estimate of x.

    Notes
    -----
    Solves the problem:

    .. math::

        \min_x || y - Ax ||^2_2 + \lambda \text{Sparsify}(x)

    If `x=None`, then MSE will not be calculated.

    You probably want `mode='soft'`.  For the other options, see docs
    for pywt.threshold.

    `selective=None` will not throw away any updates.

    The residual, MSE and SSIM checks share a single strike counter.  A
    triggered check ends the loop only if the counter already exceeds
    `strikes`; otherwise the counter is incremented and iteration
    continues.
    '''

    # Make sure compare_mse, compare_ssim are defined
    if x is None:
        def compare_mse(_aa, _bb):
            return 0

        def compare_ssim(_aa, _bb):
            return 0

        xabs = 0
        logging.info(
            'No true x provided, MSE/SSIM will not be calculated.')
    else:
        # NOTE(review): newer scikit-image releases expose these metrics
        # from skimage.metrics instead -- confirm the pinned version.
        from skimage.measure import compare_mse, compare_ssim

        # Precompute absolute value of true image
        xabs = np.abs(x.astype(y.dtype))

    # Early-exit messages go through tqdm.write only while a tqdm bar is
    # actually in use (so the bar is not garbled); otherwise they are
    # ordinary log warnings.
    notify = logging.warning
    table = None
    if disp:
        # Don't use tqdm, print a table row per iteration instead
        from mr_utils.utils.printtable import Table
        table = Table(
            ['iter', 'norm', 'MSE', 'SSIM'],
            [len(repr(maxiter)), 8, 8, 8],
            ['d', 'e', 'e', 'e'])
        for line in table.header().split('\n'):
            logging.info(line)
        range_fun = range
    else:
        # Use tqdm to give us an idea of how fast we're going, but do
        # not make it a hard requirement.
        try:
            from tqdm import trange, tqdm
        except ImportError:
            range_fun = range
        else:
            notify = tqdm.write

            def range_fun(num):
                return trange(num, leave=False, desc='Proximal GD')

    # The metrics are needed when they are displayed or when they drive
    # a stopping condition.
    need_mse = disp or not ignore_mse
    need_ssim = disp or not ignore_ssim

    # Initialize
    x_hat = np.zeros(y.shape, dtype=y.dtype)
    r = -y.copy()
    prev_stop_criteria = np.inf
    initial_abs = np.abs(inverse_fun(y))
    cur_ssim = 0
    prev_ssim = compare_ssim(xabs, initial_abs)
    cur_mse = 0
    prev_mse = compare_mse(xabs, initial_abs)
    norm_y = np.linalg.norm(y)

    # A callable alpha is a schedule that starts from alpha_start; any
    # other value is used directly as the (constant) step size.
    alpha0 = alpha_start if callable(alpha) else alpha

    strike_count = 0

    def struck_out(iteration, reason):
        '''Register a stopping condition; True means leave the loop.'''
        nonlocal strike_count
        if strike_count > strikes:
            notify('Breaking out of loop after %d iterations. %s'
                   % (iteration, reason))
            return True
        strike_count += 1
        return False

    # Do the thing
    for ii in range_fun(int(maxiter)):

        # Compute stop criteria
        stop_criteria = np.linalg.norm(r)/norm_y
        if not ignore_residual and stop_criteria > prev_stop_criteria:
            if struck_out(ii, 'Norm of residual increased!'):
                break
        prev_stop_criteria = stop_criteria

        # Compute gradient descent step in prep for reordering
        grad_step = x_hat - inverse_fun(r)

        # Do reordering if we asked for it
        if reorder_fun is not None:
            reorder_idx = reorder_fun(grad_step)
            # Real/imag channels carry independent orderings; the same
            # index tuples are reused below to undo the reordering.
            where_r = np.unravel_index(
                reorder_idx.real.astype(int), y.shape)
            where_i = np.unravel_index(
                reorder_idx.imag.astype(int), y.shape)
            grad_step = (
                grad_step.real[where_r]
                + 1j*grad_step.imag[where_i]).reshape(y.shape)

        # Take the step, we would normally assign x_hat directly, but
        # because we might be reordering and selectively updating,
        # we'll store it in a temporary variable...
        if thresh_sep:
            tmp = sparsify(grad_step)
            # Take a half step in each of real/imag
            tmp_r = threshold(tmp.real, value=alpha0/2, mode=mode)
            tmp_i = threshold(tmp.imag, value=alpha0/2, mode=mode)
            update = unsparsify(tmp_r + 1j*tmp_i)
        else:
            update = unsparsify(
                threshold(sparsify(grad_step), value=alpha0, mode=mode))

        # Undo the reordering if we did it
        if reorder_fun is not None:
            update_r = np.zeros(y.shape)
            update_r[where_r] = update.real.flatten()
            update_i = np.zeros(y.shape)
            update_i[where_i] = update.imag.flatten()
            update = update_r + 1j*update_i

        # Update image estimate, only where selective says so if given
        if selective is not None:
            selective_idx = selective(x_hat, update, ii)
            x_hat[selective_idx] = update[selective_idx]
        else:
            x_hat = update

        # Evaluate the comparison metrics that are displayed or checked
        if need_mse or need_ssim:
            curxabs = np.abs(x_hat)
            if need_mse:
                cur_mse = compare_mse(curxabs, xabs)
            if need_ssim:
                cur_ssim = compare_ssim(curxabs, xabs)

        # Tell the user what happened
        if disp:
            logging.info(
                table.row([ii, stop_criteria, cur_mse, cur_ssim]))

        if not ignore_mse and cur_mse > prev_mse:
            if struck_out(ii, 'MSE increased!'):
                break
        prev_mse = cur_mse

        # NOTE(review): this stops when SSIM *increases*, as documented;
        # a larger SSIM normally means a better match -- confirm intent.
        if not ignore_ssim and cur_ssim > prev_ssim:
            if struck_out(ii, 'SSIM increased!'):
                break
        prev_ssim = cur_ssim

        # Compute residual
        r = forward_fun(x_hat) - y

        # Get next step size
        if callable(alpha):
            alpha0 = alpha(alpha0, ii)

    return x_hat