1use yscv_autograd::{Graph, NodeId};
2use yscv_tensor::Tensor;
3
4use super::ModelLayer;
5use crate::ModelError;
6
/// Inverted-dropout layer: rescales activations during training and is a
/// no-op during inference.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DropoutLayer {
    // Drop probability; validated in `new` to be finite and in [0.0, 1.0).
    rate: f32,
    // When false, `forward` returns the input node unchanged.
    training: bool,
}
13
14impl DropoutLayer {
15 pub fn new(rate: f32) -> Result<Self, ModelError> {
16 if !rate.is_finite() || !(0.0..1.0).contains(&rate) {
17 return Err(ModelError::InvalidDropoutRate { rate });
18 }
19 Ok(Self {
20 rate,
21 training: true,
22 })
23 }
24
25 pub fn rate(&self) -> f32 {
26 self.rate
27 }
28
29 pub fn set_training(&mut self, training: bool) {
30 self.training = training;
31 }
32
33 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
34 if !self.training || self.rate == 0.0 {
35 return Ok(input);
36 }
37 let scale_factor = 1.0 / (1.0 - self.rate);
41 let scale = graph.constant(Tensor::scalar(scale_factor));
42 graph.mul(input, scale).map_err(Into::into)
43 }
44}
45
/// Stateless layer that flattens all non-batch dimensions into one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FlattenLayer;
49
50impl FlattenLayer {
51 pub fn new() -> Self {
52 Self
53 }
54
55 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
56 let shape = input.shape();
57 if shape.len() < 2 {
58 return Err(ModelError::InvalidFlattenShape {
59 got: shape.to_vec(),
60 });
61 }
62 let batch = shape[0];
63 let features: usize = shape[1..].iter().product();
64 input.reshape(vec![batch, features]).map_err(Into::into)
65 }
66
67 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
68 graph.flatten(input).map_err(Into::into)
69 }
70}
71
/// Stateless layer applying softmax over the last dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SoftmaxLayer;
75
76impl SoftmaxLayer {
77 pub fn new() -> Self {
78 Self
79 }
80
81 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
82 yscv_kernels::softmax_last_dim(input).map_err(Into::into)
83 }
84
85 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
86 graph.softmax(input).map_err(Into::into)
87 }
88}
89
/// Depth-to-space (pixel shuffle) layer for NHWC tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PixelShuffleLayer {
    // Spatial upscale factor `r`; channel count shrinks by `r * r`.
    upscale_factor: usize,
}
95
96impl PixelShuffleLayer {
97 pub fn new(upscale_factor: usize) -> Self {
98 Self { upscale_factor }
99 }
100
101 pub fn upscale_factor(&self) -> usize {
102 self.upscale_factor
103 }
104
105 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
106 let shape = input.shape();
107 if shape.len() != 4 {
108 return Err(ModelError::InvalidInputShape {
109 expected_features: 0,
110 got: shape.to_vec(),
111 });
112 }
113 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
114 let r = self.upscale_factor;
115 let out_c = c / (r * r);
116 let out_h = h * r;
117 let out_w = w * r;
118 let data = input.data();
119 let mut out = vec![0.0f32; batch * out_h * out_w * out_c];
120
121 for b in 0..batch {
122 for ih in 0..h {
123 for iw in 0..w {
124 for oc in 0..out_c {
125 for ry in 0..r {
126 for rx in 0..r {
127 let ic = oc * r * r + ry * r + rx;
128 let oh = ih * r + ry;
129 let ow = iw * r + rx;
130 out[((b * out_h + oh) * out_w + ow) * out_c + oc] =
131 data[((b * h + ih) * w + iw) * c + ic];
132 }
133 }
134 }
135 }
136 }
137 }
138 Ok(Tensor::from_vec(vec![batch, out_h, out_w, out_c], out)?)
139 }
140
141 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
142 graph
143 .pixel_shuffle(input, self.upscale_factor)
144 .map_err(Into::into)
145 }
146}
147
/// Spatial upsampling layer (nearest-neighbor or bilinear) for NHWC
/// tensors.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UpsampleLayer {
    // Integer scale applied to both H and W.
    scale_factor: usize,
    // true: bilinear interpolation (inference-only — `forward` rejects
    // it); false: nearest-neighbor.
    bilinear: bool,
}
154
155impl UpsampleLayer {
156 pub fn new(scale_factor: usize, bilinear: bool) -> Self {
157 Self {
158 scale_factor,
159 bilinear,
160 }
161 }
162
163 pub fn scale_factor(&self) -> usize {
164 self.scale_factor
165 }
166 pub fn is_bilinear(&self) -> bool {
167 self.bilinear
168 }
169
170 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
171 let shape = input.shape();
172 if shape.len() != 4 {
173 return Err(ModelError::InvalidInputShape {
174 expected_features: 0,
175 got: shape.to_vec(),
176 });
177 }
178 let (batch, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
179 let r = self.scale_factor;
180 let out_h = h * r;
181 let out_w = w * r;
182 let data = input.data();
183 let mut out = vec![0.0f32; batch * out_h * out_w * c];
184
185 if !self.bilinear {
186 for b in 0..batch {
188 for oh in 0..out_h {
189 let ih = oh / r;
190 for ow in 0..out_w {
191 let iw = ow / r;
192 let src = ((b * h + ih) * w + iw) * c;
193 let dst = ((b * out_h + oh) * out_w + ow) * c;
194 out[dst..dst + c].copy_from_slice(&data[src..src + c]);
195 }
196 }
197 }
198 } else {
199 for b in 0..batch {
201 for oh in 0..out_h {
202 let src_y = (oh as f32 + 0.5) / r as f32 - 0.5;
203 let y0 = (src_y.floor() as usize).min(h - 1);
204 let y1 = (y0 + 1).min(h - 1);
205 let fy = src_y - y0 as f32;
206 for ow in 0..out_w {
207 let src_x = (ow as f32 + 0.5) / r as f32 - 0.5;
208 let x0 = (src_x.floor() as usize).min(w - 1);
209 let x1 = (x0 + 1).min(w - 1);
210 let fx = src_x - x0 as f32;
211 for ch in 0..c {
212 let v00 = data[((b * h + y0) * w + x0) * c + ch];
213 let v10 = data[((b * h + y0) * w + x1) * c + ch];
214 let v01 = data[((b * h + y1) * w + x0) * c + ch];
215 let v11 = data[((b * h + y1) * w + x1) * c + ch];
216 out[((b * out_h + oh) * out_w + ow) * c + ch] =
217 v00 * (1.0 - fx) * (1.0 - fy)
218 + v10 * fx * (1.0 - fy)
219 + v01 * (1.0 - fx) * fy
220 + v11 * fx * fy;
221 }
222 }
223 }
224 }
225 }
226 Ok(Tensor::from_vec(vec![batch, out_h, out_w, c], out)?)
227 }
228
229 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
230 if self.bilinear {
231 return Err(ModelError::InferenceOnlyLayer);
233 }
234 graph
235 .upsample_nearest(input, self.scale_factor)
236 .map_err(Into::into)
237 }
238}
239
/// Residual block computing `F(x) + x` over an inner layer stack.
#[derive(Debug, Clone)]
pub struct ResidualBlock {
    // Layers forming F, applied in order; their composite output must be
    // addable to the input for the skip connection.
    layers: Vec<ModelLayer>,
}
246
247impl ResidualBlock {
248 pub fn new(layers: Vec<ModelLayer>) -> Self {
250 Self { layers }
251 }
252
253 pub fn layers(&self) -> &[ModelLayer] {
255 &self.layers
256 }
257
258 pub fn forward_inference(&self, input: &Tensor) -> Result<Tensor, ModelError> {
261 let mut current = input.clone();
262 for layer in &self.layers {
263 current = layer.forward_inference(¤t)?;
264 }
265 current.add(input).map_err(ModelError::Tensor)
266 }
267
268 pub fn forward(&self, graph: &mut Graph, input: NodeId) -> Result<NodeId, ModelError> {
270 let mut current = input;
271 for layer in &self.layers {
272 current = layer.forward(graph, current)?;
273 }
274 graph.add(current, input).map_err(Into::into)
275 }
276}
277
/// Mask prediction head: a stack of 3x3 convs with ReLU, a fixed 2x
/// nearest upsample, and a final 1x1 per-class conv.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskHead {
    // HWIO weights for the 3x3 conv stack (zero-initialized in `new`).
    conv_weights: Vec<Tensor>,
    // 1x1 conv mapping `channels` -> `num_classes`.
    class_conv: Tensor,
    // Width of the intermediate conv features.
    channels: usize,
    // Number of mask classes predicted per ROI.
    num_classes: usize,
    // Target mask resolution; stored but not read in this file —
    // presumably consumed by callers. TODO confirm.
    mask_size: usize,
}
292
293impl MaskHead {
294 pub fn new(
296 in_channels: usize,
297 channels: usize,
298 num_classes: usize,
299 mask_size: usize,
300 num_conv: usize,
301 ) -> Result<Self, ModelError> {
302 let mut conv_weights = Vec::with_capacity(num_conv);
303 for i in 0..num_conv {
304 let c_in = if i == 0 { in_channels } else { channels };
305 conv_weights.push(Tensor::zeros(vec![3, 3, c_in, channels])?);
306 }
307 let class_conv = Tensor::zeros(vec![1, 1, channels, num_classes])?;
308 Ok(Self {
309 conv_weights,
310 class_conv,
311 channels,
312 num_classes,
313 mask_size,
314 })
315 }
316
317 pub fn num_classes(&self) -> usize {
318 self.num_classes
319 }
320 pub fn mask_size(&self) -> usize {
321 self.mask_size
322 }
323 pub fn channels(&self) -> usize {
324 self.channels
325 }
326
327 pub fn forward_inference(&self, roi_features: &Tensor) -> Result<Tensor, ModelError> {
329 let mut x = roi_features.clone();
330 for w in &self.conv_weights {
332 x = yscv_kernels::conv2d_nhwc(&x, w, None, 1, 1)?;
333 let data = x.data_mut();
335 for v in data.iter_mut() {
336 *v = v.max(0.0);
337 }
338 }
339 let shape = x.shape();
341 if shape.len() == 4 {
342 let (n, h, w, c) = (shape[0], shape[1], shape[2], shape[3]);
343 let nh = h * 2;
344 let nw = w * 2;
345 let mut up = vec![0.0f32; n * nh * nw * c];
346 for bi in 0..n {
347 for yi in 0..nh {
348 for xi in 0..nw {
349 let sy = yi / 2;
350 let sx = xi / 2;
351 let src = bi * h * w * c + sy * w * c + sx * c;
352 let dst = bi * nh * nw * c + yi * nw * c + xi * c;
353 up[dst..dst + c].copy_from_slice(&x.data()[src..src + c]);
354 }
355 }
356 }
357 x = Tensor::from_vec(vec![n, nh, nw, c], up)?;
358 }
359 yscv_kernels::conv2d_nhwc(&x, &self.class_conv, None, 1, 1).map_err(Into::into)
361 }
362}