
Function fuse_conv_bn 

pub fn fuse_conv_bn(conv: &Conv2dLayer, bn: &BatchNorm2dLayer) -> Conv2dLayer

Fuse Conv2d + BatchNorm2d into a single Conv2d with adjusted weights and bias.

During inference, BatchNorm computes: y = gamma * (x - mean) / sqrt(var + eps) + beta

When preceded by a convolution (conv_out = W * x + b), the BN can be folded into the conv:

    W_fused = scale * W               (applied per output channel)
    b_fused = scale * (b - mean) + beta

where scale = gamma / sqrt(var + eps).

The fused Conv2d produces the same output as running Conv2d followed by BatchNorm2d, eliminating the BatchNorm layer entirely and saving computation.

Conv2d weights are stored HWIO-style (the filter layout matching NHWC activations): [KH, KW, C_in, C_out], so C_out is the fastest-varying axis.
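The folding above can be sketched in plain Rust. The struct fields below (`weight`, `bias`, `gamma`, `beta`, `mean`, `var`, `eps`, `c_out`) are simplified stand-ins, not the actual `Conv2dLayer`/`BatchNorm2dLayer` definitions; the real types may store shapes and parameters differently.

```rust
// Hypothetical, simplified layer structs for illustration only.
struct Conv2d {
    weight: Vec<f32>, // flat HWIO buffer: [KH, KW, C_in, C_out]
    bias: Vec<f32>,   // [C_out]
    c_out: usize,
}

struct BatchNorm2d {
    gamma: Vec<f32>, // [C_out]
    beta: Vec<f32>,  // [C_out]
    mean: Vec<f32>,  // [C_out] running mean
    var: Vec<f32>,   // [C_out] running variance
    eps: f32,
}

fn fuse(conv: &Conv2d, bn: &BatchNorm2d) -> Conv2d {
    // scale = gamma / sqrt(var + eps), one factor per output channel.
    let scale: Vec<f32> = (0..conv.c_out)
        .map(|c| bn.gamma[c] / (bn.var[c] + bn.eps).sqrt())
        .collect();

    // W_fused = scale * W. Because C_out is the fastest-varying axis of the
    // HWIO buffer, element i belongs to output channel i % c_out.
    let weight: Vec<f32> = conv
        .weight
        .iter()
        .enumerate()
        .map(|(i, w)| w * scale[i % conv.c_out])
        .collect();

    // b_fused = scale * (b - mean) + beta.
    let bias: Vec<f32> = (0..conv.c_out)
        .map(|c| scale[c] * (conv.bias[c] - bn.mean[c]) + bn.beta[c])
        .collect();

    Conv2d { weight, bias, c_out: conv.c_out }
}
```

Running the fused convolution on any input then matches conv-followed-by-BN exactly (up to floating-point rounding), which is why the BatchNorm layer can be dropped after fusion.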