Pooling#
Pooling is a standard operation in convolutional neural networks (CNNs) used to downsample feature maps. It reduces the spatial dimensions (height and width) while keeping the number of channels unchanged.
Pooling has no learnable parameters: it applies a fixed function (such as max or average) over small regions of the input.
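A quick way to check that pooling has nothing to train is to count the parameters registered on a layer. The sketch below compares a pooling layer with a convolution of the same window size; the variable names `n_pool_params` and `n_conv_params` are illustrative and not part of the original text.

```python
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)

# Count the trainable values registered on each layer.
n_pool_params = sum(p.numel() for p in pool.parameters())
n_conv_params = sum(p.numel() for p in conv.parameters())

print(n_pool_params)  # 0: pooling applies a fixed function
print(n_conv_params)  # 5: 2*2 kernel weights + 1 bias
```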
A pooling layer slides a small window (like a kernel) across the input and applies that function to the values inside each window. Like convolution, pooling has:
Kernel size: window size (\(k\times k\))
Stride: step size (often equals the kernel size for downsampling)
Padding: rarely used in pooling, but available
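These three settings determine the output size. For an input of height \(H\), kernel size \(k\), stride \(s\), and padding \(p\), the output height is \(\lfloor (H + 2p - k)/s \rfloor + 1\) (and likewise for the width), assuming the PyTorch defaults of no dilation and `ceil_mode=False`. The sketch below checks this against `nn.MaxPool2d`; `pooled_size` is a hypothetical helper written for this illustration, not a library function.

```python
import torch
import torch.nn as nn

def pooled_size(size, kernel_size, stride, padding=0):
    # Hypothetical helper: output length along one spatial dimension,
    # assuming dilation=1 and ceil_mode=False (the PyTorch defaults).
    return (size + 2 * padding - kernel_size) // stride + 1

x = torch.randn(1, 3, 8, 8)  # (batch, channels, height, width)

# (kernel_size, stride, padding) combinations to compare
configs = [(2, 2, 0), (3, 1, 0), (3, 2, 1)]
for k, s, p in configs:
    y = nn.MaxPool2d(kernel_size=k, stride=s, padding=p)(x)
    expected = pooled_size(8, k, s, p)
    # The channel count (3) is unchanged; only height and width shrink.
    assert y.shape == (1, 3, expected, expected)
    print(k, s, p, tuple(y.shape))
```

With stride equal to the kernel size (the first configuration), the windows do not overlap and each spatial dimension is divided by \(k\).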
Max Pooling#
The example below shows max pooling with a \(2\times 2\) window.
In Python, using PyTorch, we can apply a max pooling layer as follows.
import torch
import torch.nn as nn

x = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                    [ 5.,  6.,  7.,  8.],
                    [ 9., 10., 11., 12.],
                    [13., 14., 15., 16.]]]])
pool = nn.MaxPool2d(kernel_size=2, stride=2)
y = pool(x)
print(y)  # shape: (1, 1, 2, 2)

tensor([[[[ 6.,  8.],
          [14., 16.]]]])
Max pooling keeps only the strongest activation in each window, so its main effect is to emphasize the most prominent features. It compresses the input to a summary that retains the peak responses of the previous feature map and discards the weaker ones.
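To make the window operation explicit, max pooling with non-overlapping windows can be sketched by hand: split the height and width into blocks of size \(k\) and take the maximum inside each block. This is a minimal illustration that assumes the stride equals the kernel size and that the input size is divisible by \(k\); it is not how the library implements the layer. The names `windows`, `y_manual`, and `y_layer` are illustrative.

```python
import torch
import torch.nn as nn

x = torch.arange(1., 17.).reshape(1, 1, 4, 4)  # values 1..16 in a 4x4 grid
k = 2  # window size; stride is assumed equal to k

n, c, h, w = x.shape
# Split each spatial axis into (number of windows, position inside window).
windows = x.reshape(n, c, h // k, k, w // k, k)
# Take the maximum over the two "inside window" axes.
y_manual = windows.amax(dim=(3, 5))

y_layer = nn.MaxPool2d(kernel_size=k, stride=k)(x)
print(torch.equal(y_manual, y_layer))  # True
```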
Average Pooling#
Average pooling computes the average of the values in each window. The example below shows the input on the left and the average pooling output on the right.
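Average pooling smooths the features in the input feature map: each output value blends all values in its window, so isolated peaks are dampened rather than preserved.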
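For comparison with the max pooling code earlier in this section, the same \(4\times 4\) input can be passed through `nn.AvgPool2d`. This is a small sketch that mirrors that example, with the same kernel size and stride.

```python
import torch
import torch.nn as nn

x = torch.tensor([[[[ 1.,  2.,  3.,  4.],
                    [ 5.,  6.,  7.,  8.],
                    [ 9., 10., 11., 12.],
                    [13., 14., 15., 16.]]]])

pool = nn.AvgPool2d(kernel_size=2, stride=2)
y = pool(x)
print(y)  # shape: (1, 1, 2, 2), values [[3.5, 5.5], [11.5, 13.5]]
```

Each output is the mean of one \(2\times 2\) window; for example, the top-left value is \((1 + 2 + 5 + 6)/4 = 3.5\), whereas max pooling returned 6 for the same window.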