//! Convolution layers for `yscv_model` (`layers/conv.rs`).
1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
/// 2D convolution layer (NHWC layout).
///
/// Supports both inference-mode (raw tensor) and graph-mode (autograd training).
/// Kernel shape: `[KH, KW, C_in, C_out]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Conv2dLayer {
    // Number of input channels (C_in).
    in_channels: usize,
    // Number of output channels (C_out).
    out_channels: usize,
    // Kernel height (KH).
    kernel_h: usize,
    // Kernel width (KW).
    kernel_w: usize,
    // Vertical stride; validated non-zero in `new`.
    stride_h: usize,
    // Horizontal stride; validated non-zero in `new`.
    stride_w: usize,
    // Filter tensor of shape `[KH, KW, C_in, C_out]`.
    weight: Tensor,
    // Optional per-output-channel bias of shape `[C_out]`.
    bias: Option<Tensor>,
    // Graph variable id for `weight`; `None` until `register_params` is called.
    weight_node: Option<NodeId>,
    // Graph variable id for `bias`; `None` if no bias or not yet registered.
    bias_node: Option<NodeId>,
}
23
24impl Conv2dLayer {
25    #[allow(clippy::too_many_arguments)]
26    pub fn new(
27        in_channels: usize,
28        out_channels: usize,
29        kernel_h: usize,
30        kernel_w: usize,
31        stride_h: usize,
32        stride_w: usize,
33        weight: Tensor,
34        bias: Option<Tensor>,
35    ) -> Result<Self, ModelError> {
36        let expected_weight = vec![kernel_h, kernel_w, in_channels, out_channels];
37        if weight.shape() != expected_weight {
38            return Err(ModelError::InvalidParameterShape {
39                parameter: "conv2d weight",
40                expected: expected_weight,
41                got: weight.shape().to_vec(),
42            });
43        }
44        if let Some(ref b) = bias
45            && b.shape() != [out_channels]
46        {
47            return Err(ModelError::InvalidParameterShape {
48                parameter: "conv2d bias",
49                expected: vec![out_channels],
50                got: b.shape().to_vec(),
51            });
52        }
53        if stride_h == 0 || stride_w == 0 {
54            return Err(ModelError::InvalidConv2dStride { stride_h, stride_w });
55        }
56        Ok(Self {
57            in_channels,
58            out_channels,
59            kernel_h,
60            kernel_w,
61            stride_h,
62            stride_w,
63            weight,
64            bias,
65            weight_node: None,
66            bias_node: None,
67        })
68    }
69
70    /// Creates a conv2d layer and registers its parameters as graph variables.
71    #[allow(clippy::too_many_arguments)]
72    pub fn new_in_graph(
73        graph: &mut Graph,
74        in_channels: usize,
75        out_channels: usize,
76        kernel_h: usize,
77        kernel_w: usize,
78        stride_h: usize,
79        stride_w: usize,
80        weight: Tensor,
81        bias: Option<Tensor>,
82    ) -> Result<Self, ModelError> {
83        let mut layer = Self::new(
84            in_channels,
85            out_channels,
86            kernel_h,
87            kernel_w,
88            stride_h,
89            stride_w,
90            weight,
91            bias,
92        )?;
93        layer.register_params(graph);
94        Ok(layer)
95    }
96
97    pub fn zero_init(
98        in_channels: usize,
99        out_channels: usize,
100        kernel_h: usize,
101        kernel_w: usize,
102        stride_h: usize,
103        stride_w: usize,
104        use_bias: bool,
105    ) -> Result<Self, ModelError> {
106        let weight = Tensor::zeros(vec![kernel_h, kernel_w, in_channels, out_channels])?;
107        let bias = if use_bias {
108            Some(Tensor::zeros(vec![out_channels])?)
109        } else {
110            None
111        };
112        Self::new(
113            in_channels,
114            out_channels,
115            kernel_h,
116            kernel_w,
117            stride_h,
118            stride_w,
119            weight,
120            bias,
121        )
122    }
123
124    /// Registers weight/bias tensors as graph variables for autograd training.
125    pub fn register_params(&mut self, graph: &mut Graph) {
126        self.weight_node = Some(graph.variable(self.weight.clone()));
127        self.bias_node = self.bias.as_ref().map(|b| graph.variable(b.clone()));
128    }
129
130    /// Synchronizes owned tensors from the graph (e.g. after optimizer step).
131    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
132        if let Some(w_id) = self.weight_node {
133            self.weight = graph.value(w_id)?.clone();
134        }
135        if let Some(b_id) = self.bias_node {
136            self.bias = Some(graph.value(b_id)?.clone());
137        }
138        Ok(())
139    }
140
    /// Number of input channels (C_in).
    pub fn in_channels(&self) -> usize {
        self.in_channels
    }
    /// Number of output channels (C_out).
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
    /// Kernel height (KH).
    pub fn kernel_h(&self) -> usize {
        self.kernel_h
    }
    /// Kernel width (KW).
    pub fn kernel_w(&self) -> usize {
        self.kernel_w
    }
    /// Vertical stride.
    pub fn stride_h(&self) -> usize {
        self.stride_h
    }
    /// Horizontal stride.
    pub fn stride_w(&self) -> usize {
        self.stride_w
    }
    /// Borrow of the filter tensor (`[KH, KW, C_in, C_out]`).
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
    /// Borrow of the optional bias tensor (`[C_out]`).
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
    /// Mutable borrow of the filter tensor.
    pub fn weight_mut(&mut self) -> &mut Tensor {
        &mut self.weight
    }
    /// Mutable borrow of the optional bias tensor.
    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
    /// Graph id of the weight variable, if `register_params` has run.
    pub fn weight_node(&self) -> Option<NodeId> {
        self.weight_node
    }
    /// Graph id of the bias variable, if registered and a bias exists.
    pub fn bias_node(&self) -> Option<NodeId> {
        self.bias_node
    }
177
178    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
179        let w_id = self
180            .weight_node
181            .ok_or(ModelError::ParamsNotRegistered { layer: "Conv2d" })?;
182        graph
183            .conv2d_nhwc(input, w_id, self.bias_node, self.stride_h, self.stride_w)
184            .map_err(Into::into)
185    }
186
187    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
188        yscv_kernels::conv2d_nhwc(
189            input,
190            &self.weight,
191            self.bias.as_ref(),
192            self.stride_h,
193            self.stride_w,
194        )
195        .map_err(Into::into)
196    }
197}
198
/// Depthwise 2D convolution layer (NHWC layout).
///
/// Each input channel is convolved with its own filter.
/// Kernel shape: `[KH, KW, C, 1]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DepthwiseConv2dLayer {
    // Number of channels (input count equals output count for depthwise).
    channels: usize,
    // Kernel height (KH).
    kernel_h: usize,
    // Kernel width (KW).
    kernel_w: usize,
    // Vertical stride; validated non-zero in `new`.
    stride_h: usize,
    // Horizontal stride; validated non-zero in `new`.
    stride_w: usize,
    // Filter tensor of shape `[KH, KW, C, 1]`.
    weight: Tensor,
    // Optional per-channel bias of shape `[C]`.
    bias: Option<Tensor>,
    // Graph variable id for `weight`; set by `register_params`.
    weight_node: Option<NodeId>,
    // Graph variable id for `bias`; set by `register_params` when a bias exists.
    bias_node: Option<NodeId>,
}
215
216impl DepthwiseConv2dLayer {
217    pub fn new(
218        channels: usize,
219        kernel_h: usize,
220        kernel_w: usize,
221        stride_h: usize,
222        stride_w: usize,
223        weight: Tensor,
224        bias: Option<Tensor>,
225    ) -> Result<Self, ModelError> {
226        let expected_weight = vec![kernel_h, kernel_w, channels, 1];
227        if weight.shape() != expected_weight {
228            return Err(ModelError::InvalidParameterShape {
229                parameter: "depthwise_conv2d weight",
230                expected: expected_weight,
231                got: weight.shape().to_vec(),
232            });
233        }
234        if let Some(ref b) = bias
235            && b.shape() != [channels]
236        {
237            return Err(ModelError::InvalidParameterShape {
238                parameter: "depthwise_conv2d bias",
239                expected: vec![channels],
240                got: b.shape().to_vec(),
241            });
242        }
243        if stride_h == 0 || stride_w == 0 {
244            return Err(ModelError::InvalidConv2dStride { stride_h, stride_w });
245        }
246        Ok(Self {
247            channels,
248            kernel_h,
249            kernel_w,
250            stride_h,
251            stride_w,
252            weight,
253            bias,
254            weight_node: None,
255            bias_node: None,
256        })
257    }
258
259    pub fn zero_init(
260        channels: usize,
261        kernel_h: usize,
262        kernel_w: usize,
263        stride_h: usize,
264        stride_w: usize,
265        use_bias: bool,
266    ) -> Result<Self, ModelError> {
267        let weight = Tensor::zeros(vec![kernel_h, kernel_w, channels, 1])?;
268        let bias = if use_bias {
269            Some(Tensor::zeros(vec![channels])?)
270        } else {
271            None
272        };
273        Self::new(
274            channels, kernel_h, kernel_w, stride_h, stride_w, weight, bias,
275        )
276    }
277
278    pub fn register_params(&mut self, graph: &mut Graph) {
279        self.weight_node = Some(graph.variable(self.weight.clone()));
280        self.bias_node = self.bias.as_ref().map(|b| graph.variable(b.clone()));
281    }
282
283    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
284        if let Some(w_id) = self.weight_node {
285            self.weight = graph.value(w_id)?.clone();
286        }
287        if let Some(b_id) = self.bias_node {
288            self.bias = Some(graph.value(b_id)?.clone());
289        }
290        Ok(())
291    }
292
    /// Number of channels (input count equals output count for depthwise).
    pub fn channels(&self) -> usize {
        self.channels
    }
    /// Kernel height (KH).
    pub fn kernel_h(&self) -> usize {
        self.kernel_h
    }
    /// Kernel width (KW).
    pub fn kernel_w(&self) -> usize {
        self.kernel_w
    }
    /// Vertical stride.
    pub fn stride_h(&self) -> usize {
        self.stride_h
    }
    /// Horizontal stride.
    pub fn stride_w(&self) -> usize {
        self.stride_w
    }
    /// Borrow of the filter tensor (`[KH, KW, C, 1]`).
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
    /// Borrow of the optional bias tensor (`[C]`).
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
    /// Mutable borrow of the filter tensor.
    pub fn weight_mut(&mut self) -> &mut Tensor {
        &mut self.weight
    }
    /// Mutable borrow of the optional bias tensor.
    pub fn bias_mut(&mut self) -> Option<&mut Tensor> {
        self.bias.as_mut()
    }
    /// Graph id of the weight variable, if `register_params` has run.
    pub fn weight_node(&self) -> Option<NodeId> {
        self.weight_node
    }
    /// Graph id of the bias variable, if registered and a bias exists.
    pub fn bias_node(&self) -> Option<NodeId> {
        self.bias_node
    }
326
327    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
328        let w_id = self.weight_node.ok_or(ModelError::ParamsNotRegistered {
329            layer: "DepthwiseConv2d",
330        })?;
331        graph
332            .depthwise_conv2d_nhwc(input, w_id, self.bias_node, self.stride_h, self.stride_w)
333            .map_err(Into::into)
334    }
335
336    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
337        yscv_kernels::depthwise_conv2d_nhwc(
338            input,
339            &self.weight,
340            self.bias.as_ref(),
341            self.stride_h,
342            self.stride_w,
343        )
344        .map_err(Into::into)
345    }
346}
347
/// Separable 2D convolution layer (NHWC layout).
///
/// Composed of a depthwise convolution followed by a pointwise (1x1) convolution.
/// Depthwise kernel shape: `[KH, KW, C_in, 1]`, pointwise kernel: `[1, 1, C_in, C_out]`.
#[derive(Debug, Clone, PartialEq)]
pub struct SeparableConv2dLayer {
    // Spatial per-channel filtering stage; constructed without a bias.
    depthwise: DepthwiseConv2dLayer,
    // 1x1 stride-1 channel-mixing stage; carries the optional bias.
    pointwise: Conv2dLayer,
}
357
358impl SeparableConv2dLayer {
359    #[allow(clippy::too_many_arguments)]
360    pub fn new(
361        in_channels: usize,
362        out_channels: usize,
363        kernel_h: usize,
364        kernel_w: usize,
365        stride_h: usize,
366        stride_w: usize,
367        depthwise_weight: Tensor,
368        pointwise_weight: Tensor,
369        bias: Option<Tensor>,
370    ) -> Result<Self, ModelError> {
371        let depthwise = DepthwiseConv2dLayer::new(
372            in_channels,
373            kernel_h,
374            kernel_w,
375            stride_h,
376            stride_w,
377            depthwise_weight,
378            None,
379        )?;
380        let pointwise = Conv2dLayer::new(
381            in_channels,
382            out_channels,
383            1,
384            1,
385            1,
386            1,
387            pointwise_weight,
388            bias,
389        )?;
390        Ok(Self {
391            depthwise,
392            pointwise,
393        })
394    }
395
396    pub fn zero_init(
397        in_channels: usize,
398        out_channels: usize,
399        kernel_h: usize,
400        kernel_w: usize,
401        stride_h: usize,
402        stride_w: usize,
403        use_bias: bool,
404    ) -> Result<Self, ModelError> {
405        let depthwise = DepthwiseConv2dLayer::zero_init(
406            in_channels,
407            kernel_h,
408            kernel_w,
409            stride_h,
410            stride_w,
411            false,
412        )?;
413        let pointwise = Conv2dLayer::zero_init(in_channels, out_channels, 1, 1, 1, 1, use_bias)?;
414        Ok(Self {
415            depthwise,
416            pointwise,
417        })
418    }
419
420    pub fn register_params(&mut self, graph: &mut Graph) {
421        self.depthwise.register_params(graph);
422        self.pointwise.register_params(graph);
423    }
424
425    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
426        self.depthwise.sync_from_graph(graph)?;
427        self.pointwise.sync_from_graph(graph)?;
428        Ok(())
429    }
430
    /// Number of input channels (taken from the depthwise stage).
    pub fn in_channels(&self) -> usize {
        self.depthwise.channels()
    }
    /// Number of output channels (taken from the pointwise stage).
    pub fn out_channels(&self) -> usize {
        self.pointwise.out_channels()
    }
    /// Spatial kernel height of the depthwise stage.
    pub fn kernel_h(&self) -> usize {
        self.depthwise.kernel_h()
    }
    /// Spatial kernel width of the depthwise stage.
    pub fn kernel_w(&self) -> usize {
        self.depthwise.kernel_w()
    }
    /// Vertical stride (applied in the depthwise stage).
    pub fn stride_h(&self) -> usize {
        self.depthwise.stride_h()
    }
    /// Horizontal stride (applied in the depthwise stage).
    pub fn stride_w(&self) -> usize {
        self.depthwise.stride_w()
    }
    /// Borrow of the depthwise stage.
    pub fn depthwise(&self) -> &DepthwiseConv2dLayer {
        &self.depthwise
    }
    /// Borrow of the pointwise stage.
    pub fn pointwise(&self) -> &Conv2dLayer {
        &self.pointwise
    }
    /// Mutable borrow of the depthwise stage.
    pub fn depthwise_mut(&mut self) -> &mut DepthwiseConv2dLayer {
        &mut self.depthwise
    }
    /// Mutable borrow of the pointwise stage.
    pub fn pointwise_mut(&mut self) -> &mut Conv2dLayer {
        &mut self.pointwise
    }
461
462    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
463        let dw_out = self.depthwise.forward(graph, input)?;
464        self.pointwise.forward(graph, dw_out)
465    }
466
467    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
468        let dw_out = self.depthwise.forward_inference(input)?;
469        self.pointwise.forward_inference(&dw_out)
470    }
471}
472
/// Deformable 2D convolution layer (NHWC layout).
///
/// Like standard conv2d but sampling positions are offset by learned offsets.
/// An internal offset convolution produces offsets from the input, then those
/// offsets are used to sample input with bilinear interpolation.
///
/// Weight shape: `[kH, kW, C_in, C_out]`.
/// Offset weight shape: `[kH, kW, C_in, kH*kW*2]` -- conv producing offsets.
#[derive(Debug, Clone, PartialEq)]
pub struct DeformableConv2dLayer {
    // Main filter tensor of shape `[kH, kW, C_in, C_out]`.
    pub weight: Tensor,
    // Filter producing two offset values per kernel tap: `[kH, kW, C_in, kH*kW*2]`.
    pub offset_weight: Tensor,
    // Optional per-output-channel bias of shape `[C_out]`.
    pub bias: Option<Tensor>,
    // Stride applied to both spatial dimensions; validated non-zero in `new`.
    pub stride: usize,
    // Zero-padding applied to both spatial dimensions.
    pub padding: usize,
    // Kernel height (kH).
    pub kernel_h: usize,
    // Kernel width (kW).
    pub kernel_w: usize,
    // Number of input channels (C_in).
    in_channels: usize,
    // Number of output channels (C_out).
    out_channels: usize,
    // Graph variable ids; all `None` until `register_params` is called.
    weight_node: Option<NodeId>,
    offset_weight_node: Option<NodeId>,
    bias_node: Option<NodeId>,
}
496
497impl DeformableConv2dLayer {
498    #[allow(clippy::too_many_arguments)]
499    pub fn new(
500        in_channels: usize,
501        out_channels: usize,
502        kernel_h: usize,
503        kernel_w: usize,
504        stride: usize,
505        padding: usize,
506        weight: Tensor,
507        offset_weight: Tensor,
508        bias: Option<Tensor>,
509    ) -> Result<Self, ModelError> {
510        let expected_weight = vec![kernel_h, kernel_w, in_channels, out_channels];
511        if weight.shape() != expected_weight {
512            return Err(ModelError::InvalidParameterShape {
513                parameter: "deformable_conv2d weight",
514                expected: expected_weight,
515                got: weight.shape().to_vec(),
516            });
517        }
518        let offset_out = kernel_h * kernel_w * 2;
519        let expected_offset_weight = vec![kernel_h, kernel_w, in_channels, offset_out];
520        if offset_weight.shape() != expected_offset_weight {
521            return Err(ModelError::InvalidParameterShape {
522                parameter: "deformable_conv2d offset_weight",
523                expected: expected_offset_weight,
524                got: offset_weight.shape().to_vec(),
525            });
526        }
527        if let Some(ref b) = bias
528            && b.shape() != [out_channels]
529        {
530            return Err(ModelError::InvalidParameterShape {
531                parameter: "deformable_conv2d bias",
532                expected: vec![out_channels],
533                got: b.shape().to_vec(),
534            });
535        }
536        if stride == 0 {
537            return Err(ModelError::InvalidConv2dStride {
538                stride_h: stride,
539                stride_w: stride,
540            });
541        }
542        Ok(Self {
543            weight,
544            offset_weight,
545            bias,
546            stride,
547            padding,
548            kernel_h,
549            kernel_w,
550            in_channels,
551            out_channels,
552            weight_node: None,
553            offset_weight_node: None,
554            bias_node: None,
555        })
556    }
557
558    pub fn zero_init(
559        in_channels: usize,
560        out_channels: usize,
561        kernel_h: usize,
562        kernel_w: usize,
563        stride: usize,
564        padding: usize,
565        use_bias: bool,
566    ) -> Result<Self, ModelError> {
567        let offset_out = kernel_h * kernel_w * 2;
568        let weight = Tensor::zeros(vec![kernel_h, kernel_w, in_channels, out_channels])?;
569        let offset_weight = Tensor::zeros(vec![kernel_h, kernel_w, in_channels, offset_out])?;
570        let bias = if use_bias {
571            Some(Tensor::zeros(vec![out_channels])?)
572        } else {
573            None
574        };
575        Self::new(
576            in_channels,
577            out_channels,
578            kernel_h,
579            kernel_w,
580            stride,
581            padding,
582            weight,
583            offset_weight,
584            bias,
585        )
586    }
587
    /// Number of input channels (C_in).
    pub fn in_channels(&self) -> usize {
        self.in_channels
    }
    /// Number of output channels (C_out).
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
    /// Kernel height (kH).
    pub fn kernel_h(&self) -> usize {
        self.kernel_h
    }
    /// Kernel width (kW).
    pub fn kernel_w(&self) -> usize {
        self.kernel_w
    }
    /// Stride applied to both spatial dimensions.
    pub fn stride(&self) -> usize {
        self.stride
    }
    /// Zero-padding applied to both spatial dimensions.
    pub fn padding(&self) -> usize {
        self.padding
    }
    /// Borrow of the main filter tensor (`[kH, kW, C_in, C_out]`).
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }
    /// Borrow of the offset-producing filter (`[kH, kW, C_in, kH*kW*2]`).
    pub fn offset_weight(&self) -> &Tensor {
        &self.offset_weight
    }
    /// Borrow of the optional bias tensor (`[C_out]`).
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Graph id of the main weight variable, if `register_params` has run.
    pub fn weight_node(&self) -> Option<NodeId> {
        self.weight_node
    }

    /// Registers weight, offset weight, and bias as graph variables.
    pub fn register_params(&mut self, graph: &mut Graph) {
        self.weight_node = Some(graph.variable(self.weight.clone()));
        self.offset_weight_node = Some(graph.variable(self.offset_weight.clone()));
        self.bias_node = self.bias.as_ref().map(|b| graph.variable(b.clone()));
    }
625
    /// Graph-mode forward pass; requires `register_params` to have been called.
    ///
    /// Mirrors `forward_inference`: offsets are computed from the (optionally
    /// padded) input, while the deformable op itself receives the *unpadded*
    /// input plus the padding amount.
    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        let w = self.weight_node.ok_or(ModelError::ParamsNotRegistered {
            layer: "DeformableConv2d",
        })?;
        let ow = self
            .offset_weight_node
            .ok_or(ModelError::ParamsNotRegistered {
                layer: "DeformableConv2d",
            })?;

        // Step 1: Compute offsets via standard conv2d of input with offset_weight.
        // If padding > 0, we need to pad the input first.
        let padded = if self.padding > 0 {
            // NOTE(review): assumes `graph.pad` takes (before, after) amounts per
            // NHWC dimension, so only H and W are padded here -- confirm against
            // the yscv_autograd pad API.
            let pad_per_dim = &[
                0,
                0,
                self.padding,
                self.padding,
                self.padding,
                self.padding,
                0,
                0,
            ];
            graph.pad(input, pad_per_dim, 0.0)?
        } else {
            input
        };
        let offsets = graph.conv2d_nhwc(padded, ow, None, self.stride, self.stride)?;

        // Step 2: Deformable conv using the computed offsets.
        graph
            .deformable_conv2d_nhwc(input, w, offsets, self.bias_node, self.stride, self.padding)
            .map_err(Into::into)
    }
660
661    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
662        // Step 1: Compute offsets by convolving input with offset_weight (standard conv2d).
663        // The offset conv uses the same kernel size, stride, and padding as the main conv
664        // so that the offset map has the same spatial dimensions as the output.
665        // Since conv2d_nhwc does not support padding, we pad the input manually.
666        let padded = if self.padding > 0 {
667            Self::pad_nhwc(input, self.padding)?
668        } else {
669            input.clone()
670        };
671        let offsets = yscv_kernels::conv2d_nhwc(
672            &padded,
673            &self.offset_weight,
674            None,
675            self.stride,
676            self.stride,
677        )?;
678
679        // Step 2: Apply deformable conv with those offsets.
680        Ok(yscv_kernels::deformable_conv2d_nhwc(
681            input,
682            &self.weight,
683            &offsets,
684            self.bias.as_ref(),
685            self.stride,
686            self.padding,
687        )?)
688    }
689
690    /// Zero-pads an NHWC tensor by `pad` on each spatial side.
691    fn pad_nhwc(input: &Tensor, pad: usize) -> Result<Tensor, ModelError> {
692        let batch = input.shape()[0];
693        let h = input.shape()[1];
694        let w = input.shape()[2];
695        let c = input.shape()[3];
696        let new_h = h + 2 * pad;
697        let new_w = w + 2 * pad;
698        let mut data = vec![0.0f32; batch * new_h * new_w * c];
699        let src = input.data();
700        for n in 0..batch {
701            for y in 0..h {
702                let src_row = n * h * w * c + y * w * c;
703                let dst_row = n * new_h * new_w * c + (y + pad) * new_w * c + pad * c;
704                data[dst_row..dst_row + w * c].copy_from_slice(&src[src_row..src_row + w * c]);
705            }
706        }
707        Tensor::from_vec(vec![batch, new_h, new_w, c], data).map_err(Into::into)
708    }
709}
710
/// 1D convolution layer (NLC layout: `[batch, length, channels]`).
///
/// Kernel shape: `[K, C_in, C_out]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Conv1dLayer {
    // Number of input channels (C_in).
    in_channels: usize,
    // Number of output channels (C_out).
    out_channels: usize,
    // Kernel length (K).
    kernel_size: usize,
    // Stride along the length dimension.
    stride: usize,
    // Kernel tensor of shape `[K, C_in, C_out]`.
    weight: Tensor,
    // Optional per-output-channel bias of shape `[C_out]`.
    bias: Option<Tensor>,
    // Graph variable id for `weight`; set by `register_params`.
    weight_node: Option<NodeId>,
    // Graph variable id for `bias`; set by `register_params` when a bias exists.
    bias_node: Option<NodeId>,
}
725
726impl Conv1dLayer {
727    pub fn new(
728        in_channels: usize,
729        out_channels: usize,
730        kernel_size: usize,
731        stride: usize,
732        weight: Tensor,
733        bias: Option<Tensor>,
734    ) -> Result<Self, ModelError> {
735        let expected = vec![kernel_size, in_channels, out_channels];
736        if weight.shape() != expected {
737            return Err(ModelError::InvalidParameterShape {
738                parameter: "conv1d weight",
739                expected,
740                got: weight.shape().to_vec(),
741            });
742        }
743        Ok(Self {
744            in_channels,
745            out_channels,
746            kernel_size,
747            stride,
748            weight,
749            bias,
750            weight_node: None,
751            bias_node: None,
752        })
753    }
754
    /// Kernel length (K).
    pub fn kernel_size(&self) -> usize {
        self.kernel_size
    }
    /// Borrow of the kernel tensor (`[K, C_in, C_out]`).
    pub fn kernel(&self) -> &Tensor {
        &self.weight
    }
    /// Number of input channels (C_in).
    pub fn in_channels(&self) -> usize {
        self.in_channels
    }
    /// Number of output channels (C_out).
    pub fn out_channels(&self) -> usize {
        self.out_channels
    }
    /// Stride along the length dimension.
    pub fn stride(&self) -> usize {
        self.stride
    }
    /// Graph id of the weight variable, if `register_params` has run.
    pub fn weight_node(&self) -> Option<NodeId> {
        self.weight_node
    }
    /// Graph id of the bias variable, if registered and a bias exists.
    pub fn bias_node(&self) -> Option<NodeId> {
        self.bias_node
    }
776
777    pub fn register_params(&mut self, graph: &mut Graph) {
778        self.weight_node = Some(graph.variable(self.weight.clone()));
779        self.bias_node = self.bias.as_ref().map(|b| graph.variable(b.clone()));
780    }
781
782    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
783        let w_id = self
784            .weight_node
785            .ok_or(ModelError::ParamsNotRegistered { layer: "Conv1d" })?;
786        graph
787            .conv1d_nlc(input, w_id, self.bias_node, self.stride)
788            .map_err(Into::into)
789    }
790
791    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
792        let shape = input.shape();
793        if shape.len() != 3 || shape[2] != self.in_channels {
794            return Err(ModelError::InvalidInputShape {
795                expected_features: self.in_channels,
796                got: shape.to_vec(),
797            });
798        }
799        let (batch, length, _) = (shape[0], shape[1], shape[2]);
800        let out_len = (length - self.kernel_size) / self.stride + 1;
801        let data = input.data();
802        let w = self.weight.data();
803
804        let mut out = vec![0.0f32; batch * out_len * self.out_channels];
805        for b in 0..batch {
806            for ol in 0..out_len {
807                let start = ol * self.stride;
808                for oc in 0..self.out_channels {
809                    let mut sum = 0.0f32;
810                    for k in 0..self.kernel_size {
811                        for ic in 0..self.in_channels {
812                            sum += data[(b * length + start + k) * self.in_channels + ic]
813                                * w[(k * self.in_channels + ic) * self.out_channels + oc];
814                        }
815                    }
816                    if let Some(ref bias) = self.bias {
817                        sum += bias.data()[oc];
818                    }
819                    out[(b * out_len + ol) * self.out_channels + oc] = sum;
820                }
821            }
822        }
823        Ok(Tensor::from_vec(
824            vec![batch, out_len, self.out_channels],
825            out,
826        )?)
827    }
828}
829
/// Transposed 2D convolution layer (NHWC layout).
///
/// Kernel shape: `[KH, KW, C_out, C_in]` (note: reversed from Conv2d).
#[derive(Debug, Clone, PartialEq)]
pub struct ConvTranspose2dLayer {
    // Number of input channels (C_in).
    in_channels: usize,
    // Number of output channels (C_out).
    out_channels: usize,
    // Kernel height (KH).
    kernel_h: usize,
    // Kernel width (KW).
    kernel_w: usize,
    // Vertical upsampling stride.
    stride_h: usize,
    // Horizontal upsampling stride.
    stride_w: usize,
    // Kernel tensor of shape `[KH, KW, C_out, C_in]`.
    weight: Tensor,
    // Optional per-output-channel bias of shape `[C_out]`.
    bias: Option<Tensor>,
    // Graph variable id for `weight`; set by `register_params`.
    weight_node: Option<NodeId>,
    // Graph variable id for `bias`; set by `register_params` when a bias exists.
    bias_node: Option<NodeId>,
}
846
847impl ConvTranspose2dLayer {
848    #[allow(clippy::too_many_arguments)]
849    pub fn new(
850        in_channels: usize,
851        out_channels: usize,
852        kernel_h: usize,
853        kernel_w: usize,
854        stride_h: usize,
855        stride_w: usize,
856        weight: Tensor,
857        bias: Option<Tensor>,
858    ) -> Result<Self, ModelError> {
859        let expected = vec![kernel_h, kernel_w, out_channels, in_channels];
860        if weight.shape() != expected {
861            return Err(ModelError::InvalidParameterShape {
862                parameter: "conv_transpose2d weight",
863                expected,
864                got: weight.shape().to_vec(),
865            });
866        }
867        Ok(Self {
868            in_channels,
869            out_channels,
870            kernel_h,
871            kernel_w,
872            stride_h,
873            stride_w,
874            weight,
875            bias,
876            weight_node: None,
877            bias_node: None,
878        })
879    }
880
    /// Borrow of the kernel tensor (`[KH, KW, C_out, C_in]`).
    pub fn kernel(&self) -> &Tensor {
        &self.weight
    }
    /// Stride accessor.
    ///
    /// NOTE(review): returns only `stride_h`; presumably callers use square
    /// strides — confirm before relying on this for `stride_w`.
    pub fn stride(&self) -> usize {
        self.stride_h
    }
    /// Graph id of the weight variable, if `register_params` has run.
    pub fn weight_node(&self) -> Option<NodeId> {
        self.weight_node
    }
    /// Graph id of the bias variable, if registered and a bias exists.
    pub fn bias_node(&self) -> Option<NodeId> {
        self.bias_node
    }
893
894    pub fn register_params(&mut self, graph: &mut Graph) {
895        self.weight_node = Some(graph.variable(self.weight.clone()));
896        self.bias_node = self.bias.as_ref().map(|b| graph.variable(b.clone()));
897    }
898
899    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
900        if let Some(w_id) = self.weight_node {
901            self.weight = graph.value(w_id)?.clone();
902        }
903        if let Some(b_id) = self.bias_node {
904            self.bias = Some(graph.value(b_id)?.clone());
905        }
906        Ok(())
907    }
908
909    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
910        let w_id = self.weight_node.ok_or(ModelError::ParamsNotRegistered {
911            layer: "ConvTranspose2d",
912        })?;
913        graph
914            .conv_transpose2d_nhwc(input, w_id, self.bias_node, self.stride_h, self.stride_w)
915            .map_err(Into::into)
916    }
917
    /// Inference-mode transposed convolution on a raw NHWC tensor.
    ///
    /// Output spatial size: `out = (in - 1) * stride + kernel` per dimension.
    /// NOTE(review): assumes `h, w >= 1`; a zero-sized spatial dim would
    /// underflow `(h - 1)` -- confirm upstream guarantees non-empty inputs.
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 4 || shape[3] != self.in_channels {
            return Err(ModelError::InvalidInputShape {
                expected_features: self.in_channels,
                got: shape.to_vec(),
            });
        }
        let (batch, h, w, _) = (shape[0], shape[1], shape[2], shape[3]);
        let out_h = (h - 1) * self.stride_h + self.kernel_h;
        let out_w = (w - 1) * self.stride_w + self.kernel_w;
        let data = input.data();
        let wt = self.weight.data();

        // Scatter-accumulate: each input element contributes a KHxKW patch to
        // the output, so contributions are summed (+=) rather than assigned.
        let mut out = vec![0.0f32; batch * out_h * out_w * self.out_channels];
        for b in 0..batch {
            for ih in 0..h {
                for iw in 0..w {
                    for ic in 0..self.in_channels {
                        let val = data[((b * h + ih) * w + iw) * self.in_channels + ic];
                        for kh in 0..self.kernel_h {
                            for kw in 0..self.kernel_w {
                                let oh = ih * self.stride_h + kh;
                                let ow = iw * self.stride_w + kw;
                                for oc in 0..self.out_channels {
                                    // Kernel layout is [KH, KW, C_out, C_in].
                                    let w_idx = ((kh * self.kernel_w + kw) * self.out_channels
                                        + oc)
                                        * self.in_channels
                                        + ic;
                                    out[((b * out_h + oh) * out_w + ow) * self.out_channels
                                        + oc] += val * wt[w_idx];
                                }
                            }
                        }
                    }
                }
            }
        }

        // Bias is added once per output position and channel.
        if let Some(ref bias) = self.bias {
            let bd = bias.data();
            for i in 0..(batch * out_h * out_w) {
                for oc in 0..self.out_channels {
                    out[i * self.out_channels + oc] += bd[oc];
                }
            }
        }

        Ok(Tensor::from_vec(
            vec![batch, out_h, out_w, self.out_channels],
            out,
        )?)
    }
971}
972
/// 3D convolution layer (BDHWC layout).
///
/// Wraps the `conv3d` kernel for volumetric data (video, medical imaging).
/// Kernel shape: `[KD, KH, KW, C_in, C_out]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Conv3dLayer {
    // Number of input channels (last axis of a BDHWC input).
    in_channels: usize,
    // Number of output channels produced per voxel.
    out_channels: usize,
    // Kernel extent along depth, height, and width respectively.
    kernel_d: usize,
    kernel_h: usize,
    kernel_w: usize,
    // Stride as (depth, height, width); every component must be non-zero.
    stride: (usize, usize, usize),
    // Padding as (depth, height, width); passed to the inference-mode
    // `yscv_kernels::conv3d` call.
    padding: (usize, usize, usize),
    // Filter tensor with shape [KD, KH, KW, C_in, C_out] (checked in `new`).
    weight: Tensor,
    // Optional per-output-channel bias with shape [C_out] (checked in `new`).
    bias: Option<Tensor>,
    // Graph node ids, populated by `register_params` for training mode;
    // `None` while the layer is used purely for inference.
    weight_node: Option<NodeId>,
    bias_node: Option<NodeId>,
}
991
992impl Conv3dLayer {
993    #[allow(clippy::too_many_arguments)]
994    pub fn new(
995        in_channels: usize,
996        out_channels: usize,
997        kernel_d: usize,
998        kernel_h: usize,
999        kernel_w: usize,
1000        stride: (usize, usize, usize),
1001        padding: (usize, usize, usize),
1002        weight: Tensor,
1003        bias: Option<Tensor>,
1004    ) -> Result<Self, ModelError> {
1005        let expected_weight = vec![kernel_d, kernel_h, kernel_w, in_channels, out_channels];
1006        if weight.shape() != expected_weight {
1007            return Err(ModelError::InvalidParameterShape {
1008                parameter: "conv3d weight",
1009                expected: expected_weight,
1010                got: weight.shape().to_vec(),
1011            });
1012        }
1013        if let Some(ref b) = bias
1014            && b.shape() != [out_channels]
1015        {
1016            return Err(ModelError::InvalidParameterShape {
1017                parameter: "conv3d bias",
1018                expected: vec![out_channels],
1019                got: b.shape().to_vec(),
1020            });
1021        }
1022        if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
1023            return Err(ModelError::InvalidConv2dStride {
1024                stride_h: stride.1,
1025                stride_w: stride.2,
1026            });
1027        }
1028        Ok(Self {
1029            in_channels,
1030            out_channels,
1031            kernel_d,
1032            kernel_h,
1033            kernel_w,
1034            stride,
1035            padding,
1036            weight,
1037            bias,
1038            weight_node: None,
1039            bias_node: None,
1040        })
1041    }
1042
1043    pub fn zero_init(
1044        in_channels: usize,
1045        out_channels: usize,
1046        kernel_d: usize,
1047        kernel_h: usize,
1048        kernel_w: usize,
1049        stride: (usize, usize, usize),
1050        padding: (usize, usize, usize),
1051        use_bias: bool,
1052    ) -> Result<Self, ModelError> {
1053        let weight = Tensor::zeros(vec![
1054            kernel_d,
1055            kernel_h,
1056            kernel_w,
1057            in_channels,
1058            out_channels,
1059        ])?;
1060        let bias = if use_bias {
1061            Some(Tensor::zeros(vec![out_channels])?)
1062        } else {
1063            None
1064        };
1065        Self::new(
1066            in_channels,
1067            out_channels,
1068            kernel_d,
1069            kernel_h,
1070            kernel_w,
1071            stride,
1072            padding,
1073            weight,
1074            bias,
1075        )
1076    }
1077
1078    pub fn register_params(&mut self, graph: &mut Graph) {
1079        self.weight_node = Some(graph.variable(self.weight.clone()));
1080        self.bias_node = self.bias.as_ref().map(|b| graph.variable(b.clone()));
1081    }
1082
1083    pub fn sync_from_graph(&mut self, graph: &Graph) -> Result<(), ModelError> {
1084        if let Some(w_id) = self.weight_node {
1085            self.weight = graph.value(w_id)?.clone();
1086        }
1087        if let Some(b_id) = self.bias_node {
1088            self.bias = Some(graph.value(b_id)?.clone());
1089        }
1090        Ok(())
1091    }
1092
1093    pub fn in_channels(&self) -> usize {
1094        self.in_channels
1095    }
1096    pub fn out_channels(&self) -> usize {
1097        self.out_channels
1098    }
1099    pub fn weight(&self) -> &Tensor {
1100        &self.weight
1101    }
1102    pub fn bias(&self) -> Option<&Tensor> {
1103        self.bias.as_ref()
1104    }
1105    pub fn weight_node(&self) -> Option<NodeId> {
1106        self.weight_node
1107    }
1108    pub fn bias_node(&self) -> Option<NodeId> {
1109        self.bias_node
1110    }
1111
1112    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
1113        let w_id = self
1114            .weight_node
1115            .ok_or(ModelError::ParamsNotRegistered { layer: "Conv3d" })?;
1116        graph
1117            .conv3d_ndhwc(
1118                input,
1119                w_id,
1120                self.bias_node,
1121                self.stride.0,
1122                self.stride.1,
1123                self.stride.2,
1124            )
1125            .map_err(Into::into)
1126    }
1127
1128    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
1129        let input_shape = input.shape();
1130        if input_shape.len() != 5 {
1131            return Err(ModelError::InvalidParameterShape {
1132                parameter: "conv3d input",
1133                expected: vec![0, 0, 0, 0, self.in_channels], // [B, D, H, W, C_in]
1134                got: input_shape.to_vec(),
1135            });
1136        }
1137        let kernel_shape = self.weight.shape();
1138        let (out_data, out_shape) = yscv_kernels::conv3d(
1139            input.data(),
1140            input_shape,
1141            self.weight.data(),
1142            kernel_shape,
1143            self.stride,
1144            self.padding,
1145        );
1146        let mut result = Tensor::from_vec(out_shape, out_data)?;
1147        if let Some(ref b) = self.bias {
1148            let c_out = self.out_channels;
1149            let data = result.data_mut();
1150            let bias_data = b.data();
1151            for pixel in data.chunks_mut(c_out) {
1152                for (v, &bv) in pixel.iter_mut().zip(bias_data.iter()) {
1153                    *v += bv;
1154                }
1155            }
1156        }
1157        Ok(result)
1158    }
1159}