// yscv_model/layers/norm.rs

use yscv_autograd::{Graph, NodeId};
use yscv_kernels::BatchNorm2dParams;
use yscv_tensor::Tensor;

use crate::ModelError;

/// 2D batch normalization layer (NHWC layout).
///
/// Supports both inference mode (raw tensors) and graph mode (autograd training).
/// Stores learned gamma/beta and the running mean/variance.
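///
/// A minimal inference-mode sketch (the shape and epsilon below are
/// illustrative, not prescribed by this crate):
///
/// ```ignore
/// let layer = BatchNorm2dLayer::identity_init(3, 1e-5)?;
/// // NHWC input: batch = 1, height = 2, width = 2, channels = 3.
/// let input = Tensor::zeros(vec![1, 2, 2, 3])?;
/// let output = layer.forward_inference(&input)?;
/// assert_eq!(output.shape(), input.shape());
/// ```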
#[derive(Debug, Clone, PartialEq)]
pub struct BatchNorm2dLayer {
    num_features: usize,
    epsilon: f32,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    gamma_node: Option<NodeId>,
    beta_node: Option<NodeId>,
    mean_node: Option<NodeId>,
    var_node: Option<NodeId>,
}

impl BatchNorm2dLayer {
    pub fn new(
        num_features: usize,
        epsilon: f32,
        gamma: Tensor,
        beta: Tensor,
        running_mean: Tensor,
        running_var: Tensor,
    ) -> Result<Self, ModelError> {
        let expected = vec![num_features];
        for (name, t) in [
            ("gamma", &gamma),
            ("beta", &beta),
            ("running_mean", &running_mean),
            ("running_var", &running_var),
        ] {
            if t.shape() != expected {
                return Err(ModelError::InvalidParameterShape {
                    parameter: name,
                    expected: expected.clone(),
                    got: t.shape().to_vec(),
                });
            }
        }
        if !epsilon.is_finite() || epsilon <= 0.0 {
            return Err(ModelError::InvalidBatchNormEpsilon { epsilon });
        }
        Ok(Self {
            num_features,
            epsilon,
            gamma,
            beta,
            running_mean,
            running_var,
            gamma_node: None,
            beta_node: None,
            mean_node: None,
            var_node: None,
        })
    }

    /// Unit-scale/zero-shift initialization: zero running mean and unit running variance.
    pub fn identity_init(num_features: usize, epsilon: f32) -> Result<Self, ModelError> {
        let gamma = Tensor::ones(vec![num_features])?;
        let beta = Tensor::zeros(vec![num_features])?;
        let running_mean = Tensor::zeros(vec![num_features])?;
        let running_var = Tensor::ones(vec![num_features])?;
        Self::new(
            num_features,
            epsilon,
            gamma,
            beta,
            running_mean,
            running_var,
        )
    }

    /// Registers gamma/beta as graph variables (trainable), running stats as constants.
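    ///
    /// Sketch of the intended call order (training internals elided; `graph`
    /// and `input` are assumed to exist):
    ///
    /// ```ignore
    /// layer.register_params(&mut graph);           // gamma/beta become trainable
    /// let out = layer.forward(&mut graph, input)?; // graph-mode forward
    /// // ... backward pass and optimizer step ...
    /// layer.sync_from_graph(&graph)?;              // copy updated gamma/beta back
    /// ```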
    pub fn register_params(&mut self, graph: &mut Graph) {
        self.gamma_node = Some(graph.variable(self.gamma.clone()));
        self.beta_node = Some(graph.variable(self.beta.clone()));
        self.mean_node = Some(graph.constant(self.running_mean.clone()));
        self.var_node = Some(graph.constant(self.running_var.clone()));
    }

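    /// Copies the (possibly updated) gamma/beta values back from the graph
    /// into the layer's own tensors; does nothing for unregistered parameters.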
    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
        if let Some(g_id) = self.gamma_node {
            self.gamma = graph.value(g_id)?.clone();
        }
        if let Some(b_id) = self.beta_node {
            self.beta = graph.value(b_id)?.clone();
        }
        Ok(())
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }
    pub fn epsilon(&self) -> f32 {
        self.epsilon
    }
    pub fn gamma(&self) -> &Tensor {
        &self.gamma
    }
    pub fn beta(&self) -> &Tensor {
        &self.beta
    }
    pub fn running_mean(&self) -> &Tensor {
        &self.running_mean
    }
    pub fn running_var(&self) -> &Tensor {
        &self.running_var
    }
    pub fn gamma_mut(&mut self) -> &mut Tensor {
        &mut self.gamma
    }
    pub fn beta_mut(&mut self) -> &mut Tensor {
        &mut self.beta
    }
    pub fn running_mean_mut(&mut self) -> &mut Tensor {
        &mut self.running_mean
    }
    pub fn running_var_mut(&mut self) -> &mut Tensor {
        &mut self.running_var
    }
    pub fn gamma_node(&self) -> Option<NodeId> {
        self.gamma_node
    }
    pub fn beta_node(&self) -> Option<NodeId> {
        self.beta_node
    }

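    /// Graph-mode forward pass; returns `ParamsNotRegistered` unless
    /// `register_params` has been called first.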
    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let g_id = self.gamma_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        let b_id = self.beta_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        let m_id = self.mean_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        let v_id = self.var_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        graph
            .batch_norm2d_nhwc(input, g_id, b_id, m_id, v_id, self.epsilon)
            .map_err(Into::into)
    }

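    /// Inference-mode forward using the stored running statistics, applying
    /// `y = gamma * (x - mean) / sqrt(var + epsilon) + beta` per channel.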
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        yscv_kernels::batch_norm2d_nhwc(
            input,
            BatchNorm2dParams {
                gamma: &self.gamma,
                beta: &self.beta,
                mean: &self.running_mean,
                variance: &self.running_var,
                epsilon: self.epsilon,
            },
        )
        .map_err(Into::into)
    }
}

/// Layer normalization over the last dimension.
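///
/// A minimal construction sketch (feature size and epsilon are illustrative;
/// the `Graph` is assumed to be created elsewhere):
///
/// ```ignore
/// fn build(graph: &mut Graph) -> Result<LayerNormLayer, ModelError> {
///     // Normalizes the last dimension of size 64 with eps = 1e-5.
///     LayerNormLayer::new(graph, 64, 1e-5)
/// }
/// ```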
#[derive(Debug, Clone)]
pub struct LayerNormLayer {
    normalized_shape: usize,
    eps: f32,
    gamma: NodeId,
    beta: NodeId,
}

impl LayerNormLayer {
    pub fn new(graph: &mut Graph, normalized_shape: usize, eps: f32) -> Result<Self, ModelError> {
        let gamma = graph.variable(Tensor::ones(vec![normalized_shape])?);
        let beta = graph.variable(Tensor::zeros(vec![normalized_shape])?);
        Ok(Self {
            normalized_shape,
            eps,
            gamma,
            beta,
        })
    }

    pub fn normalized_shape(&self) -> usize {
        self.normalized_shape
    }
    pub fn gamma_node(&self) -> NodeId {
        self.gamma
    }
    pub fn beta_node(&self) -> NodeId {
        self.beta
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        graph
            .layer_norm(input, self.gamma, self.beta, self.eps)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, graph: &Graph, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        let last_dim = *shape.last().ok_or(ModelError::InvalidInputShape {
            expected_features: self.normalized_shape,
            got: shape.to_vec(),
        })?;
        if last_dim != self.normalized_shape {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.normalized_shape,
                got: shape.to_vec(),
            });
        }
        let data = input.data();
        let gamma = graph.value(self.gamma)?.data().to_vec();
        let beta = graph.value(self.beta)?.data().to_vec();
        let num_rows = data.len() / last_dim;
        let mut out = vec![0.0f32; data.len()];
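        // Per row: y = (x - mean) / sqrt(var + eps) * gamma + beta,
        // with mean and variance taken over the last dimension.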
        for row in 0..num_rows {
            let base = row * last_dim;
            let slice = &data[base..base + last_dim];
            let mean = slice.iter().sum::<f32>() / last_dim as f32;
            let var = slice.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / last_dim as f32;
            let inv_std = 1.0 / (var + self.eps).sqrt();
            for i in 0..last_dim {
                out[base + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
            }
        }
        Ok(Tensor::from_vec(shape.to_vec(), out)?)
    }
}

/// Group normalization: divides channels into groups and normalizes within each group.
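///
/// A minimal construction sketch (group/channel counts are illustrative;
/// `num_channels` must be divisible by `num_groups`):
///
/// ```ignore
/// fn build(graph: &mut Graph) -> Result<GroupNormLayer, ModelError> {
///     // 8 groups over 32 channels -> 4 channels per group.
///     GroupNormLayer::new(graph, 8, 32, 1e-5)
/// }
/// ```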
#[derive(Debug, Clone)]
pub struct GroupNormLayer {
    num_groups: usize,
    num_channels: usize,
    eps: f32,
    gamma: NodeId,
    beta: NodeId,
}

impl GroupNormLayer {
    pub fn new(
        graph: &mut Graph,
        num_groups: usize,
        num_channels: usize,
        eps: f32,
    ) -> Result<Self, ModelError> {
        // Reject zero groups up front so the later division by num_groups cannot panic.
        if num_groups == 0 || num_channels % num_groups != 0 {
            return Err(ModelError::InvalidInputShape {
                expected_features: num_groups,
                got: vec![num_channels],
            });
        }
        let gamma = graph.variable(Tensor::ones(vec![num_channels])?);
        let beta = graph.variable(Tensor::zeros(vec![num_channels])?);
        Ok(Self {
            num_groups,
            num_channels,
            eps,
            gamma,
            beta,
        })
    }

    pub fn num_groups(&self) -> usize {
        self.num_groups
    }
    pub fn num_channels(&self) -> usize {
        self.num_channels
    }
    pub fn gamma_node(&self) -> NodeId {
        self.gamma
    }
    pub fn beta_node(&self) -> NodeId {
        self.beta
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        graph
            .group_norm(input, self.gamma, self.beta, self.num_groups, self.eps)
            .map_err(Into::into)
    }

    /// Forward inference on NHWC input `[N, H, W, C]`.
    pub fn forward_inference(&self, graph: &Graph, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 4 || shape[3] != self.num_channels {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.num_channels,
                got: shape.to_vec(),
            });
        }
        let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
        let channels_per_group = c / self.num_groups;
        let spatial = h * w;
        let data = input.data();
        let gamma = graph.value(self.gamma)?.data().to_vec();
        let beta = graph.value(self.beta)?.data().to_vec();
        let mut out = vec![0.0f32; data.len()];

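        // NHWC flattened index for position (n, h, w, c): ((n * H + h) * W + w) * C + c.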
        for ni in 0..n {
            for gi in 0..self.num_groups {
                let c_start = gi * channels_per_group;
                let c_end = c_start + channels_per_group;
                let group_size = spatial * channels_per_group;
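                // First pass: sum over the group's spatial-by-channel block.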
                let mut sum = 0.0f32;
                for hi in 0..h {
                    for wi in 0..w {
                        let base = ((ni * h + hi) * w + wi) * c;
                        for ci in c_start..c_end {
                            sum += data[base + ci];
                        }
                    }
                }
                let mean = sum / group_size as f32;
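                // Second pass: accumulate squared deviations for the (biased) variance.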
                let mut var_sum = 0.0f32;
                for hi in 0..h {
                    for wi in 0..w {
                        let base = ((ni * h + hi) * w + wi) * c;
                        for ci in c_start..c_end {
                            let d = data[base + ci] - mean;
                            var_sum += d * d;
                        }
                    }
                }
                let inv_std = 1.0 / (var_sum / group_size as f32 + self.eps).sqrt();
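                // Third pass: normalize and apply the per-channel affine (gamma, beta).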
                for hi in 0..h {
                    for wi in 0..w {
                        let base = ((ni * h + hi) * w + wi) * c;
                        for ci in c_start..c_end {
                            out[base + ci] =
                                (data[base + ci] - mean) * inv_std * gamma[ci] + beta[ci];
                        }
                    }
                }
            }
        }
        Ok(Tensor::from_vec(shape.to_vec(), out)?)
    }
}

/// Instance normalization (normalizes per sample, per channel).
///
/// NHWC layout: `[batch, H, W, C]`.
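///
/// A minimal inference sketch (shape values are illustrative):
///
/// ```ignore
/// let layer = InstanceNormLayer::new(3, 1e-5)?;
/// let input = Tensor::zeros(vec![2, 4, 4, 3])?; // [batch, H, W, C]
/// let output = layer.forward_inference(&input)?;
/// assert_eq!(output.shape(), input.shape());
/// ```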
#[derive(Debug, Clone, PartialEq)]
pub struct InstanceNormLayer {
    num_features: usize,
    eps: f32,
    gamma: Tensor,
    beta: Tensor,
    gamma_node: Option<NodeId>,
    beta_node: Option<NodeId>,
}

impl InstanceNormLayer {
    pub fn new(num_features: usize, eps: f32) -> Result<Self, ModelError> {
        Ok(Self {
            num_features,
            eps,
            gamma: Tensor::ones(vec![num_features])?,
            beta: Tensor::zeros(vec![num_features])?,
            gamma_node: None,
            beta_node: None,
        })
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.gamma_node = Some(graph.variable(self.gamma.clone()));
        self.beta_node = Some(graph.variable(self.beta.clone()));
    }

    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
        if let Some(g_id) = self.gamma_node {
            self.gamma = graph.value(g_id)?.clone();
        }
        if let Some(b_id) = self.beta_node {
            self.beta = graph.value(b_id)?.clone();
        }
        Ok(())
    }

    pub fn gamma_node(&self) -> Option<NodeId> {
        self.gamma_node
    }
    pub fn beta_node(&self) -> Option<NodeId> {
        self.beta_node
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let g_id = self.gamma_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "InstanceNorm",
        })?;
        let b_id = self.beta_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "InstanceNorm",
        })?;
        graph
            .instance_norm_nhwc(input, g_id, b_id, self.eps)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 4 || shape[3] != self.num_features {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.num_features,
                got: shape.to_vec(),
            });
        }
        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
        let data = input.data();
        let g = self.gamma.data();
        let b = self.beta.data();
        let spatial = h * w;
        let mut out = vec![0.0f32; data.len()];

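        // Statistics are computed independently for each (sample, channel)
        // pair over the H * W spatial positions.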
        for n in 0..batch {
            for ch in 0..c {
                let mut sum = 0.0f32;
                for s in 0..spatial {
                    let idx = (n * h * w + s) * c + ch;
                    sum += data[idx];
                }
                let mean = sum / spatial as f32;
                let mut var_sum = 0.0f32;
                for s in 0..spatial {
                    let idx = (n * h * w + s) * c + ch;
                    let d = data[idx] - mean;
                    var_sum += d * d;
                }
                let inv_std = 1.0 / (var_sum / spatial as f32 + self.eps).sqrt();
                for s in 0..spatial {
                    let idx = (n * h * w + s) * c + ch;
                    out[idx] = (data[idx] - mean) * inv_std * g[ch] + b[ch];
                }
            }
        }
        Ok(Tensor::from_vec(shape.to_vec(), out)?)
    }
}