//! yscv_model/layers/misc.rs — miscellaneous model layers
//! (dropout, flatten, softmax, pixel shuffle, upsample, residual block, mask head).

1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use super::ModelLayer;
5use crate::ModelError;
6
/// Dropout layer (training vs eval mode).
///
/// Holds the configured drop `rate` in `[0, 1)` and a `training` flag that
/// toggles between the training-time transform and an identity pass-through.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DropoutLayer {
    // Drop probability in [0, 1); validated by `DropoutLayer::new`.
    rate: f32,
    // When false, `forward` is the identity (eval mode). New layers start true.
    training: bool,
}
13
impl DropoutLayer {
    /// Creates a dropout layer with the given drop probability.
    ///
    /// # Errors
    /// Returns `ModelError::InvalidDropoutRate` when `rate` is not finite or
    /// falls outside the half-open range `[0.0, 1.0)` (1.0 is rejected so the
    /// `1 / (1 - rate)` factor in `forward` can never divide by zero).
    pub fn new(rate: f32) -> Result<Self, ModelError> {
        if !rate.is_finite() || !(0.0..1.0).contains(&rate) {
            return Err(ModelError::InvalidDropoutRate { rate });
        }
        Ok(Self {
            rate,
            // New layers start in training mode; call `set_training(false)` for eval.
            training: true,
        })
    }

    /// Returns the configured drop probability.
    pub fn rate(&self) -> f32 {
        self.rate
    }

    /// Switches between training mode (scaling applied) and eval mode (identity).
    pub fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Graph-mode forward pass.
    ///
    /// In eval mode, or when `rate == 0`, the input node is returned unchanged.
    /// Otherwise the input is multiplied by the constant `1 / (1 - rate)` —
    /// the inverted-dropout keep-scale — with no random masking applied.
    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        if !self.training || self.rate == 0.0 {
            return Ok(input);
        }
        // Deterministic proxy for dropout: scale by the inverted-dropout
        // factor 1/(1-rate) without random masking.
        // NOTE(review): without a mask this *inflates* the expected activation
        // by 1/(1-rate) rather than preserving it; an earlier comment here
        // claimed scaling by (1-rate). Confirm which behavior is intended.
        let scale_factor = 1.0 / (1.0 - self.rate);
        let scale = graph.constant(Tensor::scalar(scale_factor));
        graph.mul(input, scale).map_err(Into::into)
    }
}
45
/// Flatten layer: reshapes NHWC `[N, H, W, C]` to `[N, H*W*C]` for dense layer input.
///
/// Stateless zero-sized type. Accepts any rank >= 2 input, keeping dim 0 as
/// the batch and collapsing all remaining dims into one feature dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FlattenLayer;
49
50impl FlattenLayer {
51    pub fn new() -> Self {
52        Self
53    }
54
55    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
56        let shape = input.shape();
57        if shape.len() < 2 {
58            return Err(ModelError::InvalidFlattenShape {
59                got: shape.to_vec(),
60            });
61        }
62        let batch = shape[0];
63        let features: usize = shape[1..].iter().product();
64        input.reshape(vec![batch, features]).map_err(Into::into)
65    }
66
67    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
68        graph.flatten(input).map_err(Into::into)
69    }
70}
71
/// Softmax layer over the last dimension.
///
/// Stateless zero-sized type; delegates the computation to `yscv_kernels`
/// (eager) or to the autograd graph (training).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SoftmaxLayer;
75
76impl SoftmaxLayer {
77    pub fn new() -> Self {
78        Self
79    }
80
81    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
82        yscv_kernels::softmax_last_dim(input).map_err(Into::into)
83    }
84
85    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
86        graph.softmax(input).map_err(Into::into)
87    }
88}
89
/// Pixel shuffle / sub-pixel convolution: rearranges `[N, H, W, C*r^2]` -> `[N, H*r, W*r, C]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PixelShuffleLayer {
    // Spatial upscale factor `r`: each r x r block of output pixels is drawn
    // from r^2 consecutive input channels of the same output channel.
    upscale_factor: usize,
}
95
96impl PixelShuffleLayer {
97    pub fn new(upscale_factor: usize) -> Self {
98        Self { upscale_factor }
99    }
100
101    pub fn upscale_factor(&self) -> usize {
102        self.upscale_factor
103    }
104
105    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
106        let shape = input.shape();
107        if shape.len() != 4 {
108            return Err(ModelError::InvalidInputShape {
109                expected_features: 0,
110                got: shape.to_vec(),
111            });
112        }
113        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
114        let r = self.upscale_factor;
115        let out_c = c / (r * r);
116        let out_h = h * r;
117        let out_w = w * r;
118        let data = input.data();
119        let mut out = vec![0.0f32; batch * out_h * out_w * out_c];
120
121        for b in 0..batch {
122            for ih in 0..h {
123                for iw in 0..w {
124                    for oc in 0..out_c {
125                        for ry in 0..r {
126                            for rx in 0..r {
127                                let ic = oc * r * r + ry * r + rx;
128                                let oh = ih * r + ry;
129                                let ow = iw * r + rx;
130                                out[((b * out_h + oh) * out_w + ow) * out_c + oc] =
131                                    data[((b * h + ih) * w + iw) * c + ic];
132                            }
133                        }
134                    }
135                }
136            }
137        }
138        Ok(Tensor::from_vec(vec![batch, out_h, out_w, out_c], out)?)
139    }
140
141    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
142        graph
143            .pixel_shuffle(input, self.upscale_factor)
144            .map_err(Into::into)
145    }
146}
147
/// Upsample layer: nearest or bilinear upsampling.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UpsampleLayer {
    // Integer factor applied to both H and W.
    scale_factor: usize,
    // Bilinear interpolation when true; nearest-neighbor otherwise.
    bilinear: bool,
}
154
impl UpsampleLayer {
    /// Creates an upsample layer with integer `scale_factor`; `bilinear`
    /// selects bilinear interpolation instead of nearest-neighbor.
    pub fn new(scale_factor: usize, bilinear: bool) -> Self {
        Self {
            scale_factor,
            bilinear,
        }
    }

    /// Returns the integer upscale factor.
    pub fn scale_factor(&self) -> usize {
        self.scale_factor
    }
    /// True when bilinear interpolation is selected.
    pub fn is_bilinear(&self) -> bool {
        self.bilinear
    }

    /// Eager upsampling of an NHWC tensor by `scale_factor` in H and W.
    ///
    /// # Errors
    /// Returns `ModelError::InvalidInputShape` unless the input is rank-4.
    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
        let shape = input.shape();
        if shape.len() != 4 {
            return Err(ModelError::InvalidInputShape {
                expected_features: 0,
                got: shape.to_vec(),
            });
        }
        let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
        let r = self.scale_factor;
        let out_h = h * r;
        let out_w = w * r;
        let data = input.data();
        let mut out = vec![0.0f32; batch * out_h * out_w * c];

        if !self.bilinear {
            // Nearest: each output pixel copies the source at the
            // integer-divided coordinate; whole channel rows are copied.
            for b in 0..batch {
                for oh in 0..out_h {
                    let ih = oh / r;
                    for ow in 0..out_w {
                        let iw = ow / r;
                        let src = ((b * h + ih) * w + iw) * c;
                        let dst = ((b * out_h + oh) * out_w + ow) * c;
                        out[dst..dst + c].copy_from_slice(&data[src..src + c]);
                    }
                }
            }
        } else {
            // Bilinear with half-pixel centers (align_corners = false style):
            // src = (dst + 0.5) / r - 0.5, then blend the 4 nearest sources.
            for b in 0..batch {
                for oh in 0..out_h {
                    let src_y = (oh as f32 + 0.5) / r as f32 - 0.5;
                    // Negative src coords clamp to 0 (float -> usize `as`
                    // casts saturate); .min(h - 1) clamps the far edge.
                    let y0 = (src_y.floor() as usize).min(h - 1);
                    let y1 = (y0 + 1).min(h - 1);
                    let fy = src_y - y0 as f32;
                    for ow in 0..out_w {
                        let src_x = (ow as f32 + 0.5) / r as f32 - 0.5;
                        let x0 = (src_x.floor() as usize).min(w - 1);
                        let x1 = (x0 + 1).min(w - 1);
                        let fx = src_x - x0 as f32;
                        for ch in 0..c {
                            // 2x2 neighborhood, interpolated per channel.
                            let v00 = data[((b * h + y0) * w + x0) * c + ch];
                            let v10 = data[((b * h + y0) * w + x1) * c + ch];
                            let v01 = data[((b * h + y1) * w + x0) * c + ch];
                            let v11 = data[((b * h + y1) * w + x1) * c + ch];
                            out[((b * out_h + oh) * out_w + ow) * c + ch] =
                                v00 * (1.0 - fx) * (1.0 - fy)
                                    + v10 * fx * (1.0 - fy)
                                    + v01 * (1.0 - fx) * fy
                                    + v11 * fx * fy;
                        }
                    }
                }
            }
        }
        Ok(Tensor::from_vec(vec![batch, out_h, out_w, c], out)?)
    }

    /// Graph-mode forward.
    ///
    /// # Errors
    /// Bilinear mode is not implemented for the autograd graph and returns
    /// `ModelError::InferenceOnlyLayer`; nearest mode builds an upsample node.
    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
        if self.bilinear {
            // For bilinear, fall back to inference-only for now.
            return Err(ModelError::InferenceOnlyLayer);
        }
        graph
            .upsample_nearest(input, self.scale_factor)
            .map_err(Into::into)
    }
}
239
/// Residual block: runs input through a sequence of layers, then adds the
/// original input as a skip connection (`output = layers(input) + input`).
///
/// The inner layers presumably must preserve the input shape so the final
/// add is valid — TODO confirm against `Tensor::add` / `Graph::add` semantics.
#[derive(Debug, Clone)]
pub struct ResidualBlock {
    // Layers applied in order before the skip-connection add.
    layers: Vec<ModelLayer>,
}
246
247impl ResidualBlock {
248    /// Creates a new residual block wrapping the given layers.
249    pub fn new(layers: Vec<ModelLayer>) -> Self {
250        Self { layers }
251    }
252
253    /// Returns a reference to the inner layers.
254    pub fn layers(&self) -> &[ModelLayer] {
255        &self.layers
256    }
257
258    /// Runs inference: passes `input` through all inner layers sequentially,
259    /// then adds the skip connection (`output = layers_output + input`).
260    pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
261        let mut current = input.clone();
262        for layer in &self.layers {
263            current = layer.forward_inference(&current)?;
264        }
265        current.add(input).map_err(ModelError::Tensor)
266    }
267
268    /// Graph-mode forward: passes `input` through all inner layers, then adds skip connection.
269    pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
270        let mut current = input;
271        for layer in &self.layers {
272            current = layer.forward(graph, current)?;
273        }
274        graph.add(current, input).map_err(Into::into)
275    }
276}
277
/// Mask prediction head for instance segmentation (Mask R-CNN style).
///
/// Takes RoI-pooled features `[N, H, W, C]` and produces binary masks
/// `[N, mask_h, mask_w, num_classes]` via a series of conv layers + upsample.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskHead {
    /// Intermediate 3x3 conv weights: `[3, 3, c_in, channels]` each, where
    /// c_in is `in_channels` for the first and `channels` for the rest.
    conv_weights: Vec<Tensor>,
    /// Final 1×1 conv for class prediction: [1, 1, channels, num_classes]
    class_conv: Tensor,
    // Width of every intermediate conv layer.
    channels: usize,
    // Number of mask classes produced by the final 1x1 conv.
    num_classes: usize,
    // Configured output mask resolution; not referenced by the forward pass
    // in this file — TODO confirm where it is enforced.
    mask_size: usize,
}
292
impl MaskHead {
    /// Create a mask head with `num_conv` intermediate conv layers.
    ///
    /// All weights are zero-initialized here; presumably real values are
    /// loaded later from a checkpoint — TODO confirm against callers.
    ///
    /// NOTE(review): with `num_conv == 0` and `in_channels != channels`, the
    /// 1x1 class conv (built for `channels` inputs) would receive
    /// `in_channels` features in `forward_inference` — confirm callers keep
    /// these consistent.
    ///
    /// # Errors
    /// Propagates allocation failures from `Tensor::zeros`.
    pub fn new(
        in_channels: usize,
        channels: usize,
        num_classes: usize,
        mask_size: usize,
        num_conv: usize,
    ) -> Result<Self, ModelError> {
        let mut conv_weights = Vec::with_capacity(num_conv);
        for i in 0..num_conv {
            // First conv adapts in_channels; the rest map channels -> channels.
            let c_in = if i == 0 { in_channels } else { channels };
            conv_weights.push(Tensor::zeros(vec![3, 3, c_in, channels])?);
        }
        let class_conv = Tensor::zeros(vec![1, 1, channels, num_classes])?;
        Ok(Self {
            conv_weights,
            class_conv,
            channels,
            num_classes,
            mask_size,
        })
    }

    /// Number of mask classes predicted by the final conv.
    pub fn num_classes(&self) -> usize {
        self.num_classes
    }
    /// Configured output mask resolution.
    pub fn mask_size(&self) -> usize {
        self.mask_size
    }
    /// Channel width of the intermediate conv layers.
    pub fn channels(&self) -> usize {
        self.channels
    }

    /// Forward pass: conv layers → ReLU → upsample → class prediction.
    ///
    /// # Errors
    /// Propagates kernel/tensor errors from the conv and construction calls.
    pub fn forward_inference(&self, roi_features: &Tensor) -> Result<Tensor, ModelError> {
        let mut x = roi_features.clone();
        // Apply each 3x3 conv (stride 1, no bias) followed by in-place ReLU.
        for w in &self.conv_weights {
            x = yscv_kernels::conv2d_nhwc(&x, w, None, 1, 1)?;
            // In-place ReLU
            let data = x.data_mut();
            for v in data.iter_mut() {
                *v = v.max(0.0);
            }
        }
        // 2x spatial upsample via nearest-neighbor replication (the original
        // note says bilinear is intended eventually).
        // NOTE(review): non-rank-4 tensors silently skip this step — confirm
        // that is intentional rather than an error case.
        let shape = x.shape();
        if shape.len() == 4 {
            let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
            let nh = h * 2;
            let nw = w * 2;
            let mut up = vec![0.0f32; n * nh * nw * c];
            for bi in 0..n {
                for yi in 0..nh {
                    for xi in 0..nw {
                        // Each 2x2 output block copies one source pixel.
                        let sy = yi / 2;
                        let sx = xi / 2;
                        let src = bi * h * w * c + sy * w * c + sx * c;
                        let dst = bi * nh * nw * c + yi * nw * c + xi * c;
                        up[dst..dst + c].copy_from_slice(&x.data()[src..src + c]);
                    }
                }
            }
            x = Tensor::from_vec(vec![n, nh, nw, c], up)?;
        }
        // Final 1×1 class prediction conv
        yscv_kernels::conv2d_nhwc(&x, &self.class_conv, None, 1, 1).map_err(Into::into)
    }
}
362}