Pooling

Pooling is a standard operation in convolutional neural networks (CNNs) used to downsample feature maps. It reduces the spatial dimensions (height and width) while keeping the number of channels unchanged.

Pooling is not a learnable operation: it has no weights and applies a fixed function (such as max or average) over small regions of the input.
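
A quick PyTorch check of both properties (the input shape below is an arbitrary choice for illustration): nn.MaxPool2d registers no parameters, and with a \(2\times 2\) window and stride 2 it halves height and width while leaving the batch and channel dimensions untouched.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)

# A pooling layer has no weights or biases to learn.
num_params = sum(p.numel() for p in pool.parameters())
print(num_params)  # 0

# Batch of 2 feature maps with 3 channels, each 32x32 (shape chosen for illustration).
x = torch.randn(2, 3, 32, 32)
y = pool(x)
print(y.shape)  # torch.Size([2, 3, 16, 16])
```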

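A pooling layer slides a small window (much like a convolution kernel) across the input and applies its function to the values inside each window, producing one output value per window position. Each channel is pooled independently, which is why the number of channels stays the same. Like convolution, pooling has: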

  • Kernel size: window size (\(k\times k\))

  • Stride: step size between window positions (often equal to the kernel size, so the windows do not overlap)

  • Padding: rarely used in pooling, but available
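
Together, these settings fix the output size. Assuming PyTorch's defaults (dilation 1, ceil_mode=False), an input of length \(n\) along one spatial axis, with kernel size \(k\), stride \(s\) and padding \(p\), gives \(\lfloor (n + 2p - k)/s \rfloor + 1\) outputs. The helper pooled_size below is a hypothetical name used only for this illustration; the loop checks it against nn.MaxPool2d on a \(7\times 7\) input.

```python
import torch
import torch.nn as nn

def pooled_size(n: int, k: int, s: int, p: int = 0) -> int:
    """Hypothetical helper: output length along one axis (floor mode, no dilation)."""
    return (n + 2 * p - k) // s + 1

x = torch.randn(1, 1, 7, 7)
results = []
for k, s, p in [(2, 2, 0), (3, 2, 1), (3, 1, 0)]:
    expected = pooled_size(7, k, s, p)
    # Width of the actual PyTorch output for the same settings.
    actual = nn.MaxPool2d(kernel_size=k, stride=s, padding=p)(x).shape[-1]
    results.append((expected, actual))
    print(f"k={k}, s={s}, p={p}: formula {expected}, PyTorch {actual}")
```

For these settings the formula gives 3, 4 and 5 outputs per axis, and PyTorch agrees.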

Max Pooling

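Max pooling outputs the maximum value inside each window. With a \(2\times 2\) window and a stride of 2, every non-overlapping \(2\times 2\) block of the input is replaced by its largest entry, which halves the height and width.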
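
To make the windowing explicit, here is a minimal from-scratch sketch in NumPy for a single 2D feature map without padding. The function max_pool2d is a hypothetical helper written for this illustration, not a library API; deep learning frameworks compute the same thing far more efficiently.

```python
import numpy as np

def max_pool2d(x: np.ndarray, k: int = 2, s: int = 2) -> np.ndarray:
    """Hypothetical helper: naive max pooling of one 2D map, no padding."""
    h, w = x.shape
    out_h = (h - k) // s + 1
    out_w = (w - k) // s + 1
    out = np.empty((out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # k x k window whose top-left corner is at (i*s, j*s); keep its largest value.
            window = x[i * s : i * s + k, j * s : j * s + k]
            out[i, j] = window.max()
    return out

x = np.arange(1, 17, dtype=float).reshape(4, 4)  # entries 1..16, row by row
print(max_pool2d(x))
```

On this input the sketch returns [[6, 8], [14, 16]], the same values nn.MaxPool2d produces for the identical tensor.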

In PyTorch, a max pooling layer is applied as follows.

import torch
import torch.nn as nn

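# Input layout: (batch, channels, height, width) = (1, 1, 4, 4)
x = torch.tensor([[[[1., 2., 3., 4.],
                    [5., 6., 7., 8.],
                    [9.,10.,11.,12.],
                    [13.,14.,15.,16.]]]])

# 2x2 windows with stride 2: the windows do not overlap
pool = nn.MaxPool2d(kernel_size=2, stride=2)
y = pool(x)

print(y)  # shape: (1, 1, 2, 2)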
tensor([[[[ 6.,  8.],
          [14., 16.]]]])

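Max pooling keeps only the strongest activation in each window. It compresses the input into a summary that records whether a feature was detected in each region of the previous feature map, while discarding exactly where inside the region it occurred. This makes the output somewhat robust to small shifts of the input.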
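
Because only the maximum of each window survives, max pooling does not care where inside a window a feature appears, but this holds only within that window. A small sketch with nn.MaxPool2d (the 4×4 inputs with a single nonzero entry are made up for illustration): moving the activation from (0, 0) to (1, 1) leaves the output unchanged, while moving it to (2, 2) crosses into another window and changes the output.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)

def single_activation(row: int, col: int) -> torch.Tensor:
    """4x4 map of zeros with one activation of 1.0 at (row, col)."""
    t = torch.zeros(1, 1, 4, 4)
    t[0, 0, row, col] = 1.0
    return t

a = single_activation(0, 0)
b = single_activation(1, 1)  # shifted, but still in the same top-left 2x2 window
c = single_activation(2, 2)  # shifted into the bottom-right window

print(torch.equal(pool(a), pool(b)))  # True
print(torch.equal(pool(a), pool(c)))  # False
```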

Average Pooling

Average pooling outputs the mean of the values inside each window.
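
Applied with PyTorch's nn.AvgPool2d to the same 4×4 input as in the max-pooling example, each output entry is the mean of one non-overlapping \(2\times 2\) block, for instance \((1 + 2 + 5 + 6)/4 = 3.5\) for the top-left block.

```python
import torch
import torch.nn as nn

x = torch.arange(1., 17.).reshape(1, 1, 4, 4)  # entries 1..16, row by row

pool = nn.AvgPool2d(kernel_size=2, stride=2)
y = pool(x)

# Each entry is the mean of one 2x2 block: [[3.5, 5.5], [11.5, 13.5]]
print(y)
```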

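Average pooling smooths the input feature map: every value in a window contributes equally to the output, so an isolated extreme value is diluted by its neighbors instead of being passed through unchanged as in max pooling.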
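
A minimal comparison on a made-up \(2\times 2\) map containing a single spike: max pooling keeps the spike at full strength, while average pooling spreads it over the four positions of the window.

```python
import torch
import torch.nn as nn

# 2x2 map of zeros with one isolated spike (values chosen for illustration).
x = torch.tensor([[[[0., 0.],
                    [0., 8.]]]])

# With only kernel_size given, PyTorch uses stride = kernel_size.
max_out = nn.MaxPool2d(2)(x)
avg_out = nn.AvgPool2d(2)(x)

print(max_out.item(), avg_out.item())  # 8.0 2.0
```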