# Source code for mr_utils.utils.wavelet

'''Wrappers for PyWavelets.'''

import warnings

import pywt
import numpy as np

def combine_chunks(wvlt, shape, dtype=float):
    '''Stitch together the output of PyWavelets wavedec2.

    Parameters
    ----------
    wvlt : list of array_like
        Output of pywt.wavedec2(): the approximation block followed
        by one (cH, cV, cD) tuple per decomposition level.
    shape : tuple
        Desired shape of the stitched array.
    dtype : np.dtype, optional
        Type of numpy array.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see notes).
    locations : list
        Indices telling us how we stitched it together so we can take
        it back apart.  The first entry is ``((a0, a1), (b0, b1))``
        for the approximation block; every following entry is a list
        of three such pairs for (cH, cV, cD).  Each pair of bounds is
        a half-open slice along axis 0 and axis 1 respectively.

    Notes
    -----
    .. code-block:: none

        We have tuples that look like this:
                                 -------------------
                                 |        |        |
                                 | cA(LL) | cH(LH) |
                                 |        |        |
        (cA, (cH, cV, cD)) <---> -------------------
                                 |        |        |
                                 | cV(HL) | cD(HH) |
                                 |        |        |
                                 -------------------

    In array-index terms, cH blocks are offset along axis 0, cV
    blocks along axis 1 and cD blocks along both, i.e. the diagram is
    drawn with axis 0 running left to right.

    Every block must be two dimensional.  Entries of the output that
    are not covered by any block stay zero.
    '''

    # Initialize
    wavelet_transform = np.zeros(shape, dtype=dtype)

    # The approximation block goes in the corner at the origin
    approx = wvlt[0]
    n0, n1 = approx.shape
    wavelet_transform[:n0, :n1] = approx
    locations = [((0, n0), (0, n1))]

    # Running offsets along axis 0 and axis 1 where the next level's
    # detail blocks start
    off0, off1 = n0, n1

    # Iterate over the detail tuples (cH, cV, cD), coarsest first
    for cH, cV, cD in wvlt[1:]:
        # cA is already in place, move on to cH
        h0, h1 = cH.shape
        wavelet_transform[off0:off0 + h0, :h1] = cH
        cH_loc = ((off0, off0 + h0), (0, h1))

        # Now get cV
        v0, v1 = cV.shape
        wavelet_transform[:v0, off1:off1 + v1] = cV
        cV_loc = ((0, v0), (off1, off1 + v1))

        # Fill in cD
        d0, d1 = cD.shape
        wavelet_transform[off0:off0 + d0, off1:off1 + d1] = cD
        cD_loc = ((off0, off0 + d0), (off1, off1 + d1))

        locations.append([cH_loc, cV_loc, cD_loc])

        # The next level starts where this level's cD block ends
        off0 += d0
        off1 += d1

    return wavelet_transform, locations
def split_chunks(coeffs, locations):
    '''Separate the stitched together transform into blocks again.

    Parameters
    ----------
    coeffs : array_like
        Stitched together wavelet transform.
    locations : list
        Indices where the coefficients for each block are located.
        The first entry is ``((a0, a1), (b0, b1))`` for the
        approximation block; every following entry holds three such
        pairs, for (cH, cV, cD).

    Returns
    -------
    coeff_list : list of array_like
        List of coefficients of wavelet decomposition (like the
        output of pywt.wavedec2()): the approximation block followed
        by one (cH, cV, cD) tuple per level.

    Notes
    -----
    coeffs, locations are the output of combine_chunks().

    The blocks are obtained by slicing coeffs, so for a numpy array
    they are views into coeffs rather than copies.
    '''

    def _block(bounds):
        # bounds is ((start, stop) along axis 0, (start, stop) along axis 1)
        (a0, a1), (b0, b1) = bounds
        return coeffs[a0:a1, b0:b1]

    # Approximation block first, then one detail tuple per level
    coeff_list = [_block(locations[0])]
    for cH_loc, cV_loc, cD_loc in locations[1:]:
        coeff_list.append((_block(cH_loc), _block(cV_loc), _block(cD_loc)))

    return coeff_list
def wavelet_forward(
        x, wavelet, mode='symmetric', level=None, axes=(-2, -1)):
    '''Wrapper for the multilevel 2D discrete wavelet transform.

    Parameters
    ----------
    x : array_like
        Input data.
    wavelet : str
        Wavelet to use.
    mode : str, optional
        Signal extension mode.
    level : int, optional
        Decomposition level (must be >= 0).
    axes : tuple, optional
        Axes over which to compute the DWT.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see combine_chunks).
        It has the shape of x and the dtype of the coefficients
        returned by pywt.wavedec2().
    locations : list
        Indices telling us how we stitched it together so we can take
        it back apart.

    Warns
    -----
    UserWarning
        If level exceeds the maximum useful level; the maximum level
        is used instead.

    Notes
    -----
    See PyWavelets documentation on pywt.wavedec2() for more
    information.

    If level=None (default) then it will be calculated using the
    dwt_max_level function.

    The coefficient blocks are stitched into an array of x.shape, so
    they have to fit into that shape.
    NOTE(review): cdf97_2d_forward() remarks that periodization seems
    to be the only mode for which the shapes line up -- confirm which
    wavelet/mode combinations are usable here.
    '''

    # x is documented as array_like, but we need .shape below
    x = np.asarray(x)

    # Make sure we don't go too deep (only along the transformed axes)
    max_level = pywt.dwtn_max_level(x.shape, wavelet, axes=axes)
    if level is not None and level > max_level:
        msg = ('Level %d cannot be achieved, using max level=%d!'
               '' % (level, max_level))
        warnings.warn(msg)
        level = max_level

    # Do the pywavelets thing
    wvlt = pywt.wavedec2(x, wavelet, mode=mode, level=level, axes=axes)

    # But wvlt is a bunch of tuples, and we want them all stitched
    # together.  Use the dtype of the coefficients, not of x: casting
    # the coefficients back to e.g. an integer input dtype would
    # truncate them.
    return combine_chunks(wvlt, x.shape, wvlt[0].dtype)
def wavelet_inverse(
        coeffs, locations, wavelet, mode='symmetric', axes=(-2, -1)):
    '''Wrapper for the multilevel 2D inverse DWT.

    Parameters
    ----------
    coeffs : array_like
        Combined (stitched together) coefficients.
    locations : list
        Indices where the coefficients for each block are located.
    wavelet : str
        Wavelet to use.
    mode : str, optional
        Signal extension mode.
    axes : tuple, optional
        Axes over which to compute the IDWT.

    Returns
    -------
    array_like
        Inverse transform of wavelet transform, the original image.

    Notes
    -----
    coeffs, locations are the output of wavelet_forward().
    '''

    # Take the stitched array apart into the list layout expected by
    # pywt.waverec2() and hand it straight over
    return pywt.waverec2(
        split_chunks(coeffs, locations), wavelet, mode=mode, axes=axes)
def cdf97_2d_forward(x, level, axes=(-2, -1)):
    '''Forward 2D Cohen–Daubechies–Feauveau 9/7 wavelet.

    Parameters
    ----------
    x : array_like
        2D signal.
    level : int
        Decomposition level.  None is passed through to
        pywt.wavedec2() unchanged.
    axes : tuple, optional
        Axes to perform wavelet decomposition across.

    Returns
    -------
    wavelet_transform : array_like
        The stitched together elements wvlt (see combine_chunks).
        It has the dtype of the coefficients returned by
        pywt.wavedec2().
    locations : list
        Indices telling us how we stitched it together so we can take
        it back apart.

    Warns
    -----
    UserWarning
        If level exceeds the maximum useful level; the maximum level
        is used instead.

    Notes
    -----
    Returns transform, same shape as input, with locations.
    Locations is a list of indices instructing cdf97_2d_inverse where
    the coefficients for each block are located.

    Biorthogonal 4/4 is the same as CDF 9/7 according to wikipedia
    [1]_.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/
           Cohen%E2%80%93Daubechies%E2%80%93Feauveau_wavelet#Numbering
    '''

    # x is documented as array_like, but we need .shape below
    x = np.asarray(x)

    # Make sure we don't go too deep
    max_level = pywt.dwtn_max_level(x.shape, 'bior4.4', axes=axes)
    if level is not None and level > max_level:
        msg = ('Level %d cannot be achieved, using max level=%d!'
               '' % (level, max_level))
        warnings.warn(msg)
        level = max_level

    # periodization seems to be the only way to get shapes to line up.
    cdf97 = pywt.wavedec2(
        x, wavelet='bior4.4', mode='periodization', level=level,
        axes=axes)

    # Now throw all the chunks together.  Use the dtype of the
    # coefficients, not of x: casting the coefficients back to e.g.
    # an integer input dtype would truncate them.
    return combine_chunks(cdf97, x.shape, cdf97[0].dtype)
def cdf97_2d_inverse(coeffs, locations, axes=(-2, -1)):
    '''Inverse 2D Cohen–Daubechies–Feauveau 9/7 wavelet.

    Parameters
    ----------
    coeffs : array_like
        Stitched together wavelet transform.
    locations : list
        Block locations, as returned by cdf97_2d_forward().
    axes : tuple, optional
        Axes to perform wavelet transform across.

    Returns
    -------
    array_like
        Inverse CDF97 transform.
    '''

    # Recover the per-level coefficient blocks from the stitched array
    blocks = split_chunks(coeffs, locations)

    # Same wavelet and extension mode as used by cdf97_2d_forward()
    return pywt.waverec2(
        blocks, 'bior4.4', mode='periodization', axes=axes)
# Nothing happens when this module is executed directly.
if __name__ == '__main__': pass