use yscv_autograd::{Graph, NodeId};
use yscv_kernels::BatchNorm2dParams;
use yscv_tensor::Tensor;

use crate::ModelError;

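/// 2D batch normalization for NHWC tensors.
///
/// Holds the learnable gamma/beta parameters together with the running mean
/// and variance statistics, plus the graph node ids assigned by `register_params`.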
#[derive(Debug, Clone, PartialEq)]
pub struct BatchNorm2dLayer {
    num_features: usize,
    epsilon: f32,
    gamma: Tensor,
    beta: Tensor,
    running_mean: Tensor,
    running_var: Tensor,
    gamma_node: Option<NodeId>,
    beta_node: Option<NodeId>,
    mean_node: Option<NodeId>,
    var_node: Option<NodeId>,
}

impl BatchNorm2dLayer {
    pub fn new(
        num_features: usize,
        epsilon: f32,
        gamma: Tensor,
        beta: Tensor,
        running_mean: Tensor,
        running_var: Tensor,
    ) -> Result<Self, ModelError> {
        let expected = vec![num_features];
        for (name, t) in [
            ("gamma", &gamma),
            ("beta", &beta),
            ("running_mean", &running_mean),
            ("running_var", &running_var),
        ] {
            if t.shape() != expected {
                return Err(ModelError::InvalidParameterShape {
                    parameter: name,
                    expected: expected.clone(),
                    got: t.shape().to_vec(),
                });
            }
        }
        if !epsilon.is_finite() || epsilon <= 0.0 {
            return Err(ModelError::InvalidBatchNormEpsilon { epsilon });
        }
        Ok(Self {
            num_features,
            epsilon,
            gamma,
            beta,
            running_mean,
            running_var,
            gamma_node: None,
            beta_node: None,
            mean_node: None,
            var_node: None,
        })
    }

    pub fn identity_init(num_features: usize, epsilon: f32) -> Result<Self, ModelError> {
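        // Identity transform: gamma = 1, beta = 0, running mean = 0, running variance = 1.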
        let gamma = Tensor::filled(vec![num_features], 1.0)?;
        let beta = Tensor::zeros(vec![num_features])?;
        let running_mean = Tensor::zeros(vec![num_features])?;
        let running_var = Tensor::filled(vec![num_features], 1.0)?;
        Self::new(
            num_features,
            epsilon,
            gamma,
            beta,
            running_mean,
            running_var,
        )
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
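        // Gamma and beta become trainable variables; the running statistics
        // are registered as constants and stay frozen during training.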
        self.gamma_node = Some(graph.variable(self.gamma.clone()));
        self.beta_node = Some(graph.variable(self.beta.clone()));
        self.mean_node = Some(graph.constant(self.running_mean.clone()));
        self.var_node = Some(graph.constant(self.running_var.clone()));
    }

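    /// Pulls the current gamma and beta values back out of the graph after training.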
    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
        if let Some(g_id) = self.gamma_node {
            self.gamma = graph.value(g_id)?.clone();
        }
        if let Some(b_id) = self.beta_node {
            self.beta = graph.value(b_id)?.clone();
        }
        Ok(())
    }

    pub fn num_features(&self) -> usize {
        self.num_features
    }
    pub fn epsilon(&self) -> f32 {
        self.epsilon
    }
    pub fn gamma(&self) -> &Tensor {
        &self.gamma
    }
    pub fn beta(&self) -> &Tensor {
        &self.beta
    }
    pub fn running_mean(&self) -> &Tensor {
        &self.running_mean
    }
    pub fn running_var(&self) -> &Tensor {
        &self.running_var
    }
    pub fn gamma_mut(&mut self) -> &mut Tensor {
        &mut self.gamma
    }
    pub fn beta_mut(&mut self) -> &mut Tensor {
        &mut self.beta
    }
    pub fn running_mean_mut(&mut self) -> &mut Tensor {
        &mut self.running_mean
    }
    pub fn running_var_mut(&mut self) -> &mut Tensor {
        &mut self.running_var
    }
    pub fn gamma_node(&self) -> Option<NodeId> {
        self.gamma_node
    }
    pub fn beta_node(&self) -> Option<NodeId> {
        self.beta_node
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let g_id = self.gamma_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        let b_id = self.beta_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        let m_id = self.mean_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        let v_id = self.var_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "BatchNorm2d",
        })?;
        graph
            .batch_norm2d_nhwc(input, g_id, b_id, m_id, v_id, self.epsilon)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        yscv_kernels::batch_norm2d_nhwc(
            input,
            BatchNorm2dParams {
                gamma: &self.gamma,
                beta: &self.beta,
                mean: &self.running_mean,
                variance: &self.running_var,
                epsilon: self.epsilon,
            },
        )
        .map_err(Into::into)
    }
}

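/// Layer normalization over the last input dimension:
/// y = (x - mean) / sqrt(var + eps) * gamma + beta.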
#[derive(Debug, Clone)]
pub struct LayerNormLayer {
    normalized_shape: usize,
    eps: f32,
    gamma: NodeId,
    beta: NodeId,
}

impl LayerNormLayer {
    pub fn new(graph: &mut Graph, normalized_shape: usize, eps: f32) -> Result<Self, ModelError> {
        let gamma = graph.variable(Tensor::ones(vec![normalized_shape])?);
        let beta = graph.variable(Tensor::zeros(vec![normalized_shape])?);
        Ok(Self {
            normalized_shape,
            eps,
            gamma,
            beta,
        })
    }

    pub fn normalized_shape(&self) -> usize {
        self.normalized_shape
    }
    pub fn gamma_node(&self) -> NodeId {
        self.gamma
    }
    pub fn beta_node(&self) -> NodeId {
        self.beta
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        graph
            .layer_norm(input, self.gamma, self.beta, self.eps)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, graph: &Graph, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        let last_dim = *shape.last().ok_or(ModelError::InvalidInputShape {
            expected_features: self.normalized_shape,
            got: shape.to_vec(),
        })?;
        if last_dim != self.normalized_shape {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.normalized_shape,
                got: shape.to_vec(),
            });
        }
        let data = input.data();
        let gamma = graph.value(self.gamma)?.data().to_vec();
        let beta = graph.value(self.beta)?.data().to_vec();
        let num_groups = data.len() / last_dim;
        let mut out = vec![0.0f32; data.len()];
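        // Treat the input as rows of length `last_dim`; normalize each row
        // with its own mean and variance.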
        for g in 0..num_groups {
            let base = g * last_dim;
            let slice = &data[base..base + last_dim];
            let mean = slice.iter().sum::<f32>() / last_dim as f32;
            let var =
                slice.iter().map(|&v| (v - mean) * (v - mean)).sum::<f32>() / last_dim as f32;
            let inv_std = 1.0 / (var + self.eps).sqrt();
            for i in 0..last_dim {
                out[base + i] = (slice[i] - mean) * inv_std * gamma[i] + beta[i];
            }
        }
        Ok(Tensor::from_vec(shape.to_vec(), out)?)
    }
}

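/// Group normalization for NHWC tensors: the channels are split into
/// `num_groups` groups and each group is normalized per sample.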
#[derive(Debug, Clone)]
pub struct GroupNormLayer {
    num_groups: usize,
    num_channels: usize,
    eps: f32,
    gamma: NodeId,
    beta: NodeId,
}

impl GroupNormLayer {
    pub fn new(
        graph: &mut Graph,
        num_groups: usize,
        num_channels: usize,
        eps: f32,
    ) -> Result<Self, ModelError> {
        if !num_channels.is_multiple_of(num_groups) {
            return Err(ModelError::InvalidInputShape {
                expected_features: num_groups,
                got: vec![num_channels],
            });
        }
        let gamma = graph.variable(Tensor::ones(vec![num_channels])?);
        let beta = graph.variable(Tensor::zeros(vec![num_channels])?);
        Ok(Self {
            num_groups,
            num_channels,
            eps,
            gamma,
            beta,
        })
    }

    pub fn num_groups(&self) -> usize {
        self.num_groups
    }
    pub fn num_channels(&self) -> usize {
        self.num_channels
    }
    pub fn gamma_node(&self) -> NodeId {
        self.gamma
    }
    pub fn beta_node(&self) -> NodeId {
        self.beta
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        graph
            .group_norm(input, self.gamma, self.beta, self.num_groups, self.eps)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, graph: &Graph, input: &Tensor) -> Result<Tensor, ModelError> {
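        // CPU reference path: per-(sample, group) mean and variance over
        // h * w * channels_per_group elements.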
        let shape = input.shape();
        if shape.len() != 4 || shape[3] != self.num_channels {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.num_channels,
                got: shape.to_vec(),
            });
        }
        let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
        let channels_per_group = c / self.num_groups;
        let spatial = h * w;
        let data = input.data();
        let gamma = graph.value(self.gamma)?.data().to_vec();
        let beta = graph.value(self.beta)?.data().to_vec();
        let mut out = vec![0.0f32; data.len()];

        for ni in 0..n {
            for gi in 0..self.num_groups {
                let c_start = gi * channels_per_group;
                let c_end = c_start + channels_per_group;
                let group_size = spatial * channels_per_group;
                let mut sum = 0.0f32;
                for hi in 0..h {
                    for wi in 0..w {
                        let base = ((ni * h + hi) * w + wi) * c;
                        for ci in c_start..c_end {
                            sum += data[base + ci];
                        }
                    }
                }
                let mean = sum / group_size as f32;
                let mut var_sum = 0.0f32;
                for hi in 0..h {
                    for wi in 0..w {
                        let base = ((ni * h + hi) * w + wi) * c;
                        for ci in c_start..c_end {
                            let d = data[base + ci] - mean;
                            var_sum += d * d;
                        }
                    }
                }
                let inv_std = 1.0 / (var_sum / group_size as f32 + self.eps).sqrt();
                for hi in 0..h {
                    for wi in 0..w {
                        let base = ((ni * h + hi) * w + wi) * c;
                        for ci in c_start..c_end {
                            out[base + ci] =
                                (data[base + ci] - mean) * inv_std * gamma[ci] + beta[ci];
                        }
                    }
                }
            }
        }
        Ok(Tensor::from_vec(shape.to_vec(), out)?)
    }
}

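/// Instance normalization for NHWC tensors.
///
/// Each (sample, channel) pair is normalized over its spatial positions,
/// with a learnable per-channel gamma and beta.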
#[derive(Debug, Clone, PartialEq)]
pub struct InstanceNormLayer {
    num_features: usize,
    eps: f32,
    gamma: Tensor,
    beta: Tensor,
    gamma_node: Option<NodeId>,
    beta_node: Option<NodeId>,
}

impl InstanceNormLayer {
    pub fn new(num_features: usize, eps: f32) -> Result<Self, ModelError> {
        Ok(Self {
            num_features,
            eps,
            gamma: Tensor::ones(vec![num_features])?,
            beta: Tensor::zeros(vec![num_features])?,
            gamma_node: None,
            beta_node: None,
        })
    }

    pub fn register_params(&mut self, graph: &mut Graph) {
        self.gamma_node = Some(graph.variable(self.gamma.clone()));
        self.beta_node = Some(graph.variable(self.beta.clone()));
    }

    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
        if let Some(g_id) = self.gamma_node {
            self.gamma = graph.value(g_id)?.clone();
        }
        if let Some(b_id) = self.beta_node {
            self.beta = graph.value(b_id)?.clone();
        }
        Ok(())
    }

    pub fn gamma_node(&self) -> Option<NodeId> {
        self.gamma_node
    }
    pub fn beta_node(&self) -> Option<NodeId> {
        self.beta_node
    }

    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let g_id = self.gamma_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "InstanceNorm",
        })?;
        let b_id = self.beta_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "InstanceNorm",
        })?;
        graph
            .instance_norm_nhwc(input, g_id, b_id, self.eps)
            .map_err(Into::into)
    }

    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 4 || shape[3] != self.num_features {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.num_features,
                got: shape.to_vec(),
            });
        }
        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
        let data = input.data();
        let g = self.gamma.data();
        let b = self.beta.data();
        let spatial = h * w;
        let mut out = vec![0.0f32; data.len()];

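        // Per-(sample, channel) statistics over the h * w spatial positions.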
        for n in 0..batch {
            for ch in 0..c {
                let mut sum = 0.0f32;
                for s in 0..spatial {
                    let idx = (n * h * w + s) * c + ch;
                    sum += data[idx];
                }
                let mean = sum / spatial as f32;
                let mut var_sum = 0.0f32;
                for s in 0..spatial {
                    let idx = (n * h * w + s) * c + ch;
                    let d = data[idx] - mean;
                    var_sum += d * d;
                }
                let inv_std = 1.0 / (var_sum / spatial as f32 + self.eps).sqrt();
                for s in 0..spatial {
                    let idx = (n * h * w + s) * c + ch;
                    out[idx] = (data[idx] - mean) * inv_std * g[ch] + b[ch];
                }
            }
        }
        Ok(Tensor::from_vec(shape.to_vec(), out)?)
    }
}