//! Conv2d + BatchNorm2d fusion utilities for `yscv_model` (fusion.rs).

use yscv_autograd::Graph;
use yscv_tensor::Tensor;

use crate::{BatchNorm2dLayer, Conv2dLayer, ModelLayer, SequentialModel};

6/// Fuse Conv2d + BatchNorm2d into a single Conv2d with adjusted weights and bias.
7///
8/// BatchNorm during inference computes:
9///   `y = gamma * (x - mean) / sqrt(var + eps) + beta`
10///
11/// When preceded by Conv (`conv_out = W * x + b`), we can fold BN into conv:
12///   `W_fused = scale * W`  (per output channel)
13///   `b_fused = scale * (b - mean) + beta`
14/// where `scale = gamma / sqrt(var + eps)`.
15///
16/// The fused Conv2d produces the same output as running Conv2d followed by BatchNorm2d,
17/// eliminating the BatchNorm layer entirely and saving computation.
18///
19/// Conv2d weight layout is NHWC: `[KH, KW, C_in, C_out]`.
20pub fn fuse_conv_bn(conv: &Conv2dLayer, bn: &BatchNorm2dLayer) -> Conv2dLayer {
21    let out_channels = conv.out_channels();
22    assert_eq!(
23        out_channels,
24        bn.num_features(),
25        "Conv2d out_channels ({}) must match BatchNorm2d num_features ({})",
26        out_channels,
27        bn.num_features()
28    );
29
30    let gamma = bn.gamma().data();
31    let beta = bn.beta().data();
32    let running_mean = bn.running_mean().data();
33    let running_var = bn.running_var().data();
34    let eps = bn.epsilon();
35
36    // Compute per-channel scale: gamma / sqrt(var + eps)
37    let scale: Vec<f32> = (0..out_channels)
38        .map(|c| gamma[c] / (running_var[c] + eps).sqrt())
39        .collect();
40
41    // Fuse weights: multiply each output channel slice by its scale.
42    // Weight shape: [KH, KW, C_in, C_out]
43    let weight = conv.weight();
44    let w_data = weight.data();
45    let kh = conv.kernel_h();
46    let kw = conv.kernel_w();
47    let c_in = conv.in_channels();
48
49    let mut fused_w = vec![0.0f32; w_data.len()];
50    for i in 0..kh {
51        for j in 0..kw {
52            for ci in 0..c_in {
53                for co in 0..out_channels {
54                    let idx = ((i * kw + j) * c_in + ci) * out_channels + co;
55                    fused_w[idx] = w_data[idx] * scale[co];
56                }
57            }
58        }
59    }
60
61    // Fuse bias: scale * (old_bias - mean) + beta
62    let old_bias: Vec<f32> = match conv.bias() {
63        Some(b) => b.data().to_vec(),
64        None => vec![0.0; out_channels],
65    };
66
67    let fused_b: Vec<f32> = (0..out_channels)
68        .map(|c| scale[c] * (old_bias[c] - running_mean[c]) + beta[c])
69        .collect();
70
71    let fused_weight =
72        Tensor::from_vec(vec![kh, kw, c_in, out_channels], fused_w).expect("valid fused weight");
73    let fused_bias = Tensor::from_vec(vec![out_channels], fused_b).expect("valid fused bias");
74
75    Conv2dLayer::new(
76        c_in,
77        out_channels,
78        kh,
79        kw,
80        conv.stride_h(),
81        conv.stride_w(),
82        fused_weight,
83        Some(fused_bias),
84    )
85    .expect("fused Conv2dLayer construction should not fail")
86}
87
88/// Scan a `SequentialModel` and fuse Conv2d + BatchNorm2d patterns.
89///
90/// Returns a new optimized `SequentialModel` with fewer layers.
91/// Conv2d immediately followed by BatchNorm2d is replaced by a single fused Conv2d.
92/// All other layers (including ReLU after the fused Conv2d) are preserved as-is.
93pub fn optimize_sequential(model: &SequentialModel, graph: &mut Graph) -> SequentialModel {
94    let layers = model.layers();
95    let mut optimized = SequentialModel::new(graph);
96    let mut i = 0;
97
98    while i < layers.len() {
99        if i + 1 < layers.len()
100            && let (ModelLayer::Conv2d(conv), ModelLayer::BatchNorm2d(bn)) =
101                (&layers[i], &layers[i + 1])
102        {
103            let fused = fuse_conv_bn(conv, bn);
104            optimized
105                .add_conv2d(
106                    fused.in_channels(),
107                    fused.out_channels(),
108                    fused.kernel_h(),
109                    fused.kernel_w(),
110                    fused.stride_h(),
111                    fused.stride_w(),
112                    fused.weight().clone(),
113                    fused.bias().cloned(),
114                )
115                .expect("adding fused conv layer should not fail");
116            i += 2; // skip both Conv2d and BatchNorm2d
117            continue;
118        }
119
120        // Copy layer as-is using the appropriate add method.
121        push_layer(&mut optimized, graph, &layers[i]);
122        i += 1;
123    }
124
125    optimized
126}
127
/// Helper to push a single `ModelLayer` into a `SequentialModel` via the public API.
///
/// Layers whose tensors are reachable through accessors (Conv2d, BatchNorm2d,
/// DepthwiseConv2d, SeparableConv2d, …) are reconstructed exactly. Some layers
/// are NOT faithfully copied — see the NOTE(review) comments in the arms below
/// (zero-initialized Linear/LoRA/Embedding parameters, hardcoded epsilons).
fn push_layer(model: &mut SequentialModel, graph: &mut Graph, layer: &ModelLayer) {
    match layer {
        ModelLayer::Conv2d(l) => {
            model
                .add_conv2d(
                    l.in_channels(),
                    l.out_channels(),
                    l.kernel_h(),
                    l.kernel_w(),
                    l.stride_h(),
                    l.stride_w(),
                    l.weight().clone(),
                    l.bias().cloned(),
                )
                .expect("add_conv2d");
        }
        ModelLayer::BatchNorm2d(l) => {
            model
                .add_batch_norm2d(
                    l.num_features(),
                    l.epsilon(),
                    l.gamma().clone(),
                    l.beta().clone(),
                    l.running_mean().clone(),
                    l.running_var().clone(),
                )
                .expect("add_batch_norm2d");
        }
        ModelLayer::ReLU(_) => model.add_relu(),
        ModelLayer::LeakyReLU(l) => {
            model
                .add_leaky_relu(l.negative_slope())
                .expect("add_leaky_relu");
        }
        ModelLayer::Sigmoid(_) => model.add_sigmoid(),
        ModelLayer::Tanh(_) => model.add_tanh(),
        ModelLayer::Dropout(l) => {
            model.add_dropout(l.rate()).expect("add_dropout");
        }
        ModelLayer::Flatten(_) => model.add_flatten(),
        ModelLayer::Softmax(_) => model.add_softmax(),
        ModelLayer::GlobalAvgPool2d(_) => model.add_global_avg_pool2d(),
        ModelLayer::MaxPool2d(l) => {
            model
                .add_max_pool2d(l.kernel_h(), l.kernel_w(), l.stride_h(), l.stride_w())
                .expect("add_max_pool2d");
        }
        ModelLayer::AvgPool2d(l) => {
            model
                .add_avg_pool2d(l.kernel_h(), l.kernel_w(), l.stride_h(), l.stride_w())
                .expect("add_avg_pool2d");
        }
        ModelLayer::Linear(l) => {
            // Linear requires graph registration; use zero_init as a fallback
            // since we cannot retrieve the tensors without a graph reference.
            // NOTE(review): this discards the layer's trained weights — the
            // copied model is NOT numerically equivalent for Linear layers.
            // Confirm this is acceptable or add a weight-preserving add method.
            model
                .add_linear_zero(graph, l.in_features(), l.out_features())
                .expect("add_linear_zero");
        }
        ModelLayer::Embedding(l) => {
            // NOTE(review): the embedding weight is reset to all-zeros here;
            // trained embeddings are not carried over. Confirm intended, or
            // clone the source layer's weight if an accessor exists.
            let weight = Tensor::zeros(vec![l.num_embeddings(), l.embedding_dim()])
                .expect("embedding weight");
            model
                .add_embedding(graph, l.num_embeddings(), l.embedding_dim(), weight)
                .expect("add_embedding");
        }
        ModelLayer::LayerNorm(l) => {
            // NOTE(review): epsilon is hardcoded to 1e-5 rather than read from
            // the source layer — verify the source layer's epsilon is always
            // 1e-5 or expose an accessor.
            model
                .add_layer_norm(graph, l.normalized_shape(), 1e-5)
                .expect("add_layer_norm");
        }
        ModelLayer::GroupNorm(l) => {
            // NOTE(review): same hardcoded 1e-5 epsilon caveat as LayerNorm.
            model
                .add_group_norm(graph, l.num_groups(), l.num_channels(), 1e-5)
                .expect("add_group_norm");
        }
        ModelLayer::DepthwiseConv2d(l) => {
            model
                .add_depthwise_conv2d(
                    l.channels(),
                    l.kernel_h(),
                    l.kernel_w(),
                    l.stride_h(),
                    l.stride_w(),
                    l.weight().clone(),
                    l.bias().cloned(),
                )
                .expect("add_depthwise_conv2d");
        }
        ModelLayer::SeparableConv2d(l) => {
            model
                .add_separable_conv2d(
                    l.in_channels(),
                    l.out_channels(),
                    l.kernel_h(),
                    l.kernel_w(),
                    l.stride_h(),
                    l.stride_w(),
                    l.depthwise().weight().clone(),
                    l.pointwise().weight().clone(),
                    l.pointwise().bias().cloned(),
                )
                .expect("add_separable_conv2d");
        }
        ModelLayer::LoraLinear(l) => {
            // NOTE(review): despite reconstructing with matching dimensions,
            // this re-adds a zero-initialized plain Linear — the LoRA adapter
            // weights (and base weights) are dropped, not passed through.
            model
                .add_linear_zero(graph, l.in_features, l.out_features)
                .expect("add_linear_zero for lora");
        }
        // Inference-only layers: push as-is via raw layer insertion, which
        // clones the layer wholesale and therefore preserves its parameters.
        ModelLayer::Conv1d(_)
        | ModelLayer::Conv3d(_)
        | ModelLayer::ConvTranspose2d(_)
        | ModelLayer::AdaptiveAvgPool2d(_)
        | ModelLayer::AdaptiveMaxPool2d(_)
        | ModelLayer::InstanceNorm(_)
        | ModelLayer::PixelShuffle(_)
        | ModelLayer::Upsample(_)
        | ModelLayer::GELU(_)
        | ModelLayer::SiLU(_)
        | ModelLayer::Mish(_)
        | ModelLayer::PReLU(_)
        | ModelLayer::ResidualBlock(_)
        | ModelLayer::Rnn(_)
        | ModelLayer::Lstm(_)
        | ModelLayer::Gru(_)
        | ModelLayer::MultiHeadAttention(_)
        | ModelLayer::TransformerEncoder(_)
        | ModelLayer::FeedForward(_)
        | ModelLayer::DeformableConv2d(_) => {
            model.push_raw_layer(layer.clone());
        }
    }
}
263}