How to reconstruct a 2D image using only sine functions

In the previous notebook, we decomposed a one dimensional signal into its frequency components using the Discrete Time Fourier Transform (DTFT) and the Discrete Fourier Transform (DFT). In this notebook, we will do a similar analysis on 2D signals: we will decompose a 2D signal into its frequency components and reconstruct it from them.

2D DTFT

Remember the analysis equation of DTFT:

\[ \quad X(e^{j\omega}) = \sum_{n=-\infty}^{\infty}x[n]e^{-j\omega n}. \]

For 2D, this equation becomes the following:

\[ \quad X(e^{ju}, e^{jv}) = \sum_{m=-\infty}^{\infty} \sum_{n=-\infty}^{\infty}x[m,n]e^{-ju m - jvn}. \]

The implementation is straightforward (see the 1D DTFT implementation in the previous notebook) and is left to the reader as an exercise.

We are interested in applying the analysis equation to an image like this

[Figure: an MNIST image of the digit 3]

This is an image from the MNIST dataset, which you might remember from the notebook An application of convolution in machine learning. It is a 28-by-28 pixel, grayscale image. It is a spatially limited signal; that is, it has no values outside the 28-by-28 area. The DTFT, however, analyzes signals that extend from minus infinity to plus infinity. To apply it, we represent the image as a 2D discrete signal \(x[m,n]\) whose pixels occupy \(0 \le m \le 27\) and \(0 \le n \le 27\), and which is zero at all other locations (i.e. \(m<0\;\text{or}\; m>27\;\text{or}\;n<0\;\text{or}\;n>27\)).
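To make this zero-extension concrete, here is a minimal sketch (the function name x_extended is illustrative, and image is assumed to be a 28-by-28 numpy array of pixel values):

def x_extended(image, m, n):
    # Pixel value inside the 28-by-28 support, zero everywhere else
    if 0 <= m <= 27 and 0 <= n <= 27:
        return float(image[m, n])
    return 0.0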

Another way of applying the Fourier transform to this image is to use the Discrete Fourier Transform (DFT), which is designed for time-limited or spatially limited data such as images.

2D DFT

The analysis equation for 2D DFT is as follows:

\[ X[u,v] = \sum_{m=0}^{M-1}\sum_{n=0}^{N-1} x[m,n] e^{-j 2 \pi (\frac{m}{M}u + \frac{n}{N} v)}. \]
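Before turning to library routines, it may help to see this double sum written directly in code. The following is only a sketch (the name dft2_naive is illustrative, and the quadruple loop is far too slow for real images); it exists just to make the indexing in the equation concrete.

import numpy as np

def dft2_naive(x):
    # Direct evaluation of the 2D DFT analysis equation (for illustration only)
    M, N = x.shape
    X = np.zeros((M, N), dtype=np.complex128)
    for u in range(M):
        for v in range(N):
            for m in range(M):
                for n in range(N):
                    X[u, v] += x[m, n] * np.exp(-2j*np.pi*(m*u/M + n*v/N))
    return X

# For a small array, this matches numpy's FFT-based routine:
# np.allclose(dft2_naive(img), np.fft.fft2(img)) evaluates to True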

The 2D DFT is implemented in numpy's fft2. FFT stands for Fast Fourier Transform, a fast algorithm for computing the DFT. The complex exponential on the right-hand side represents waves of different frequencies and orientations. To see this, let us plot its real part for some parameters. Remember that Euler's formula expands a complex exponential into sine and cosine waves: \(e^{j\theta} = \cos(\theta) + j\sin(\theta)\). So, the real part of the complex exponential above is \(\cos(2\pi (\frac{m}{M}u + \frac{n}{N} v))\). Below we create and plot cosine waves for different \(u\) and \(v\).

import numpy as np
from matplotlib import pyplot as plt 

def generate_wave(u,v):
    # Sample the real part of the complex exponential, i.e. the cosine wave
    # cos(2*pi*(u*m/M + v*n/N)), on a 50-by-50 grid (M = nrows, N = ncols)
    nrows = 50
    ncols = 50
    wave = np.zeros((nrows,ncols), dtype=np.float64)
    for m in range(nrows):
        for n in range(ncols):
            wave[m,n] = np.cos(2*np.pi*(u*m/nrows + v*n/ncols))
    return wave

plt.imshow(generate_wave(5,5), cmap='gray');
plt.xticks([]), plt.yticks([]);
[Figure: cosine wave for u=5, v=5]
plt.imshow(generate_wave(0,3), cmap='gray');
plt.xticks([]), plt.yticks([]);
[Figure: cosine wave for u=0, v=3]

You are encouraged to play with the values of u and v above to see how the orientation and frequency of the wave change.
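For example, a small sketch like the following (the uv_pairs list is just an illustrative choice) plots waves for several u, v combinations side by side:

# Compare orientations and frequencies for a few (u, v) pairs
uv_pairs = [(1, 0), (0, 1), (3, 3), (8, 2)]
fig, axes = plt.subplots(1, len(uv_pairs), figsize=(12, 3))
for ax, (u, v) in zip(axes, uv_pairs):
    ax.imshow(generate_wave(u, v), cmap='gray')
    ax.set_title(f'u={u}, v={v}')
    ax.set_xticks([]), ax.set_yticks([])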

Now let us apply 2D DFT to a simple image, an image of a white square on a black background as seen below.

image = np.zeros((13,13))
image[4:9,4:9] = 1
plt.imshow(image, cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: a 13-by-13 image with a white square at the center]

The two lines below compute the DFT of this image. Then we plot its magnitude spectrum.

ft = np.fft.fft2(image)
ft = np.fft.fftshift(ft)
plt.imshow(np.abs(ft), cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: magnitude spectrum of the square image]

The second line above (the fftshift) shifts the zero-frequency component to the center of the spectrum. The low frequency components are around the center, and the frequency gets higher as you move away from the center.
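As a quick sanity check, setting \(u=v=0\) in the analysis equation shows that the zero-frequency component \(X[0,0]\) is simply the sum of all pixel values, and after the fftshift it sits at the center of the spectrum. A small sketch, reusing the image and ft arrays from above:

# The DC (zero-frequency) coefficient equals the sum of all pixel values;
# fftshift has moved it from index [0, 0] to the center of the array
center = (image.shape[0] // 2, image.shape[1] // 2)
print(np.abs(ft[center]), image.sum())   # both are 25.0 for this 13-by-13 square image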

We can reconstruct the original image back from its Fourier transform as follows.

recons_image = np.fft.ifft2(ft)
plt.imshow(np.abs(recons_image), cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: the reconstructed square, identical to the original]
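Note that above we applied ifft2 to the shifted spectrum and displayed only its magnitude, which hides the phase modulation introduced by the shift. A sketch of a more explicit round trip, which undoes the shift with np.fft.ifftshift first (reusing the ft and image arrays above):

# Undo the fftshift before inverting so that the reconstruction matches the
# original exactly (up to floating point error), not only in magnitude
recons_exact = np.fft.ifft2(np.fft.ifftshift(ft))
print(np.allclose(recons_exact.real, image))   # True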

Now we will do another reconstruction but this time we will use only the low frequency components. We can achieve this by multiplying the Fourier transform with a mask as follows.

nrows, ncols = image.shape

# Build a circular low-pass mask: 1 within a radius of 4 pixels of the
# center of the (shifted) spectrum, 0 elsewhere
x = np.arange(nrows)
y = np.arange(ncols)
X, Y = np.meshgrid(x, y)
cy, cx = nrows/2, ncols/2
mask = np.sqrt((X-cx)**2+(Y-cy)**2)   # distance of each pixel from the center
mask[mask<=4] = 1                     # keep the low frequencies (near the center)...
mask[mask>4] = 0                      # ...and discard the rest

plt.imshow(mask, cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: the low-pass mask, a white disk at the center]

To form the mask, we computed the distance of each pixel from the center and then thresholded these distances at 4: locations (pixels) whose distance is less than or equal to 4 are set to 1, all others to 0. Now we can use this mask to reconstruct our image using only the low frequencies, that is, the frequencies at the locations where the mask value is 1.

recons_image = np.fft.ifft2(ft*mask)
plt.imshow(np.abs(recons_image), cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: blurry reconstruction of the square from low frequencies only]

This result shows that when we use only the low frequency components, the reconstruction is blurry. We do not see the sharp edges of the square, as expected. Low frequency waves cannot reconstruct a sharp edge.
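The same effect is easy to see in one dimension. The short sketch below (an aside, not part of the original notebook) reconstructs a 1D step edge from only its lowest-frequency DFT coefficients; the sharp transitions turn into smooth ramps:

# A step edge reconstructed from only its lowest-frequency DFT coefficients
# comes out smoothed; the sharp transitions are lost
step = np.zeros(64)
step[16:48] = 1
F = np.fft.fft(step)
keep = np.zeros(64)
keep[:4] = 1
keep[-3:] = 1                        # keep frequencies k = 0, ±1, ±2, ±3
smooth = np.fft.ifft(F * keep).real
plt.plot(step, label='original')
plt.plot(smooth, label='low frequencies only')
plt.legend();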

Now let us reverse the mask:

# Invert the mask: swap the zeros and ones
mask2 = mask.copy()
mask2[mask==1] = 0
mask2[mask==0] = 1
plt.imshow(mask2, cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: the high-pass mask, the complement of the low-pass mask]

When we apply this mask to the Fourier transform and reconstruct the image, we only get the edges of the white square. This is expected because now we are using only the high frequency components, which are good for reconstructing edges, i.e. sharp transitions.

recons_image = np.fft.ifft2(ft*mask2)
plt.imshow(np.abs(recons_image), cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: reconstruction of the square from high frequencies only; only the edges are visible]

You are encouraged to play with the threshold value (4) of the mask above and see how it changes the reconstructions; a sketch of such an experiment follows.
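Here is one way such an experiment could look. This is only a sketch; make_mask is an illustrative helper, not part of the original code. It reconstructs the square with low-pass masks of increasing radius, and the result gets sharper as more frequencies are kept:

def make_mask(shape, radius):
    # Circular low-pass mask: 1 within `radius` pixels of the spectrum center, 0 elsewhere
    nrows, ncols = shape
    X, Y = np.meshgrid(np.arange(ncols), np.arange(nrows))
    dist = np.sqrt((X - ncols/2)**2 + (Y - nrows/2)**2)
    return (dist <= radius).astype(np.float64)

radii = [2, 4, 6, 9]
fig, axes = plt.subplots(1, len(radii), figsize=(10, 3))
for ax, r in zip(axes, radii):
    recon = np.fft.ifft2(ft * make_mask(image.shape, r))
    ax.imshow(np.abs(recon), cmap='gray')
    ax.set_title(f'radius = {r}')
    ax.set_xticks([]), ax.set_yticks([])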

Now let us apply DFT to a real image. We will use an MNIST image:

from PIL import Image
import requests
from io import BytesIO

# Download an MNIST digit from the book's website and convert it to a numpy array
response = requests.get('https://384book.net/_images/3.png')
image = Image.open(BytesIO(response.content))
image = np.asarray(image)
image = image[:,:,0]   # keep a single channel as the grayscale image

plt.imshow(image, cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: the MNIST image of the digit 3]

Here is what its DFT magnitude spectrum looks like:

ft = np.fft.fft2(image)
ft = np.fft.fftshift(ft)
plt.imshow(np.abs(ft), cmap='gray');
plt.xticks([]), plt.yticks([]);
[Figure: magnitude spectrum of the MNIST image]

If we apply a mask to this spectrum to keep the low frequency components only, we get back a blurry “3”:

nrows, ncols = image.shape

# Build a circular low-pass mask: 1 within a radius of 4 pixels of the
# center of the (shifted) spectrum, 0 elsewhere
x = np.arange(nrows)
y = np.arange(ncols)
X, Y = np.meshgrid(x, y)
cy, cx = nrows/2, ncols/2
mask = np.sqrt((X-cx)**2+(Y-cy)**2)   # distance of each pixel from the center
mask[mask<=4] = 1                     # keep the low frequencies (near the center)...
mask[mask>4] = 0                      # ...and discard the rest

plt.imshow(mask, cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: the low-pass mask for the MNIST image]
recons_image = np.fft.ifft2(ft*mask)
plt.imshow(np.abs(recons_image), cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: blurry reconstruction of the digit 3 from low frequencies only]

In contrast, if we adjust the mask so that it zeros out the low frequency components and keeps only the high frequency components, we get the edges in the image:

# Invert the mask: swap the zeros and ones
mask2 = mask.copy()
mask2[mask==1] = 0
mask2[mask==0] = 1
plt.imshow(mask2, cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: the high-pass mask for the MNIST image]
recons_image = np.fft.ifft2(ft*mask2)
plt.imshow(np.abs(recons_image), cmap='gray')
plt.xticks([]), plt.yticks([]);
[Figure: reconstruction of the digit 3 from high frequencies only; only the edges remain]
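Finally, because the two masks are exact complements (mask + mask2 is 1 everywhere) and the inverse DFT is linear, the low-frequency and high-frequency reconstructions add up to the full reconstruction. A short sketch to verify this, reusing the arrays defined above:

# Added together as complex arrays (before taking the magnitude), the low-pass
# and high-pass reconstructions give back the full inverse DFT
low  = np.fft.ifft2(ft * mask)
high = np.fft.ifft2(ft * mask2)
full = np.fft.ifft2(ft)
print(np.allclose(low + high, full))   # True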