Linear position-invariant filtering. default tip
authorStefan van der Walt <stefan@sun.ac.za>
Wed Aug 06 12:23:18 2008 +0200 (3 years ago)
changeset 0e97c0a6dd0ea
Linear position-invariant filtering.
lpi_filter.py
tests/data/camera.png
tests/test_lpi_filter.py
       1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
       2 +++ b/lpi_filter.py	Wed Aug 06 12:23:18 2008 +0200
       3 @@ -0,0 +1,158 @@
       4 +"""
       5 +:author: Stefan van der Walt, 2008
       6 +:license: modified BSD
       7 +"""
       8 +
       9 +__all__ = ['LPIFilter2D']
      10 +__docformat__ = 'restructuredtext en'
      11 +
      12 +import numpy as np
      13 +from scipy.fftpack import fftshift, ifftshift
      14 +
      15 +eps = np.finfo(float).eps
      16 +
      17 +class LPIFilter2D(object):
      18 +    """Linear Position-Invariant Filter (2-dimensional)
      19 +
      20 +    """
      21 +    def __init__(self,impulse_response,**filter_params):
      22 +        """
      23 +        *Parameters*:
      24 +            impulse_response : callable f(r,c,**filter_params)
      25 +                Function that yields the impulse response.  `r` and
      26 +                `c` are 1-dimensional vectors that represent row and
      27 +                column positions, in other words coordinates are
      28 +                (r[0],c[0]),(r[0],c[1]) etc.  `**filter_params` are
      29 +                passed through.
      30 +
      31 +                In other words, example would be called like this:
      32 +
      33 +                r = [0,0,0,1,1,1,2,2,2]
      34 +                c = [0,1,2,0,1,2,0,1,2]
      35 +                impulse_response(r,c,**filter_params)
      36 +
      37 +        *Example*:
      38 +
      39 +           Gaussian filter:
      40 +
      41 +           >>> def filt_func(r,c):
      42 +                   return np.exp(-np.hypot(r,c)/1)
      43 +
      44 +           >>> filter = LPIFilter2D(filt_func)
      45 +
      46 +
      47 +        """
      48 +        self.impulse_response = impulse_response
      49 +        self.filter_params = filter_params
      50 +        self._cache = None
      51 +
      52 +    def _pad(self,data,shape):
      53 +        """Pad the data to the given shape with zeros.
      54 +
      55 +        *Parameters*:
      56 +            data : 2-d ndarray
      57 +                Input data
      58 +            shape : (2,) tuple
      59 +
      60 +        """
      61 +        out = np.zeros(shape)
      62 +        out[[slice(0,n) for n in data.shape]] = data
      63 +        return out
      64 +
      65 +    def _prepare(self,data):
      66 +        """Calculate filter and data FFT in preparation for filtering.
      67 +
      68 +        """
      69 +        dshape = np.array(data.shape)
      70 +        dshape += (dshape %2 == 0) # all filter dimensions must be uneven
      71 +        oshape = np.array(data.shape)*2-1
      72 +
      73 +        if self._cache is None or np.any(self._cache.shape != oshape):
      74 +            coords = np.mgrid[[slice(0,float(n)) for n in dshape]]
      75 +            # this steps over two sets of coordinates,
      76 +            # not over the coordinates individually
      77 +            for k,coord in enumerate(coords):
      78 +                coord -= (dshape[k]-1)/2.
      79 +            coords = coords.reshape(2,-1).T # coordinate pairs (r,c)
      80 +
      81 +            f = self.impulse_response(coords[:,0],coords[:,1],
      82 +                                      **self.filter_params).reshape(dshape)
      83 +
      84 +            f = self._pad(f,oshape)
      85 +            F = np.dual.fftn(f)
      86 +            self._cache = F
      87 +        else:
      88 +            F = self._cache
      89 +
      90 +        data = self._pad(data,oshape)
      91 +        G = np.dual.fftn(data)
      92 +
      93 +        return F,G
      94 +
      95 +    def _min_limit(self,x,val=eps):
      96 +        mask = np.abs(x) < eps
      97 +        x[mask] = np.sign(x[mask])*eps
      98 +
      99 +    def _centre(self,x,oshape):
     100 +        """Return an array of oshape from the centre of x.
     101 +
     102 +        """
     103 +        start = (np.array(x.shape) - np.array(oshape))/2.+1
     104 +        out = x[[slice(s,s+n) for s,n in zip(start,oshape)]]
     105 +        return out
     106 +
     107 +    def __call__(self,data):
     108 +        """Apply the filter to the given data.
     109 +
     110 +        *Parameters*:
     111 +            data : (M,N) ndarray
     112 +
     113 +        """
     114 +        F,G = self._prepare(data)
     115 +        out = np.dual.ifftn(F*G)
     116 +        out = np.abs(self._centre(out,data.shape))
     117 +        return out
     118 +
     119 +    def inverse(self,data,max_gain=2):
     120 +        """Apply the filter in reverse to the given data.
     121 +
     122 +        *Parameters*:
     123 +            data : (M,N) ndarray
     124 +                Input data.
     125 +            max_gain : float
     126 +                Limit the filter gain.  Often, the filter contains
     127 +                zeros, which would cause the inverse filter to have
     128 +                infinite gain.  High gain causes amplification of
     129 +                artefacts, so a conservative limit is recommended.
     130 +
     131 +        """
     132 +        F,G = self._prepare(data)
     133 +        self._min_limit(F)
     134 +
     135 +        F = 1/F
     136 +        mask = np.abs(F) > max_gain
     137 +        F[mask] = np.sign(F[mask])*max_gain
     138 +
     139 +        return self._centre(np.abs(ifftshift(np.dual.ifftn(G*F))),data.shape)
     140 +
     141 +    def wiener(self,data,K=0.25):
     142 +        """Minimum Mean Square Error (Wiener) inverse filter.
     143 +
     144 +        *Parameters*:
     145 +            data : (M,N) ndarray
     146 +                Input data.
     147 +            K : float or (M,N) ndarray
     148 +                Ratio between power spectrum of noise and undegraded
     149 +                image.
     150 +
     151 +        """
     152 +        F,G = self._prepare(data)
     153 +        self._min_limit(F)
     154 +
     155 +        H_mag_sqr = np.abs(F)**2
     156 +        F = 1/F * H_mag_sqr / (H_mag_sqr + K)
     157 +
     158 +        return self._centre(np.abs(ifftshift(np.dual.ifftn(G*F))),data.shape)
     159 +
     160 +    def constrained_least_squares(self,data,lam):
     161 +        pass
     1.1 Binary file tests/data/camera.png has changed
     2.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     2.2 +++ b/tests/test_lpi_filter.py	Wed Aug 06 12:23:18 2008 +0200
     2.3 @@ -0,0 +1,81 @@
     2.4 +import os.path
     2.5 +
     2.6 +from unittest import TestCase
     2.7 +
     2.8 +import numpy as np
     2.9 +from numpy import testing
    2.10 +from numpy.testing import assert_equal
    2.11 +
    2.12 +from PIL import Image
    2.13 +
    2.14 +from lpi_filter import *
    2.15 +
    2.16 +data_dir = os.path.join(os.path.dirname(__file__), './data/')
    2.17 +
    2.18 +def imread(fname,flatten=False):
    2.19 +    """Return a copy of a PIL image as a numpy array.
    2.20 +
    2.21 +    *Parameters*:
    2.22 +        im : PIL image
    2.23 +            Input image.
    2.24 +        flatten : bool
    2.25 +            If true, convert the output to grey-scale.
    2.26 +
    2.27 +    *Returns*:
    2.28 +        img_array : ndarray
    2.29 +            The different colour bands/channels are stored in the
    2.30 +            third dimension, such that a grey-image is MxN, an
    2.31 +            RGB-image MxNx3 and an RGBA-image MxNx4.
    2.32 +
    2.33 +    """
    2.34 +    im = Image.open(fname)
    2.35 +    if flatten:
    2.36 +        im = im.convert('F')
    2.37 +    return np.array(im)
    2.38 +
    2.39 +
    2.40 +class TestLPIFilter2D():
    2.41 +    img = imread(os.path.join(data_dir + 'camera.png'),
    2.42 +                 flatten=True)[:-101,:-100]
    2.43 +
    2.44 +    def filt_func(self,r,c):
    2.45 +        return np.exp(-np.hypot(r,c)/1)
    2.46 +
    2.47 +    def setUp(self):
    2.48 +        self.f = LPIFilter2D(self.filt_func)
    2.49 +
    2.50 +    def tst_shape(self, x):
    2.51 +        X = self.f(x)
    2.52 +        assert_equal(X.shape,x.shape)
    2.53 +
    2.54 +    def test_ip_shape(self):
    2.55 +        rows,columns = self.img.shape[:2]
    2.56 +
    2.57 +        for c_slice in [slice(0,columns),slice(0,columns-5),
    2.58 +                        slice(0,columns-100)]:
    2.59 +            yield (self.tst_shape,self.img[:,c_slice])
    2.60 +
    2.61 +    def test_inverse(self):
    2.62 +        F = self.f(self.img)
    2.63 +        g = self.f.inverse(F)
    2.64 +        assert_equal(g.shape,self.img.shape)
    2.65 +
    2.66 +        g1 = self.f.inverse(F[::-1,::-1])
    2.67 +        assert ((g-g1[::-1,::-1]).sum() < 55)
    2.68 +
    2.69 +        # test cache
    2.70 +        g1 = self.f.inverse(F[::-1,::-1])
    2.71 +        assert ((g-g1[::-1,::-1]).sum() < 55)
    2.72 +
    2.73 +
    2.74 +    def test_wiener(self):
    2.75 +        F = self.f(self.img)
    2.76 +        g = self.f.wiener(F)
    2.77 +        assert_equal(g.shape,self.img.shape)
    2.78 +
    2.79 +        g1 = self.f.wiener(F[::-1,::-1])
    2.80 +        assert ((g-g1[::-1,::-1]).sum() < 1)
    2.81 +
    2.82 +
    2.83 +if __name__ == "__main__":
    2.84 +    NumpyTest().run()