1use ndarray::{s, Array2, Array3, Array4, ArrayD};
13use scirs2_neural::layers::{BatchNorm, Conv2D, MaxPool2D, PaddingMode, Sequential};
14use scirs2_neural::losses::CrossEntropyLoss;
15use scirs2_neural::prelude::*;
16
17type StdResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
19use rand::prelude::*;
20use rand::rngs::SmallRng;
21
/// Hyper-parameters for the segmentation task and the U-Net layout.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly,
/// for example in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentationConfig {
    /// Number of label classes the model predicts per pixel.
    pub num_classes: usize,
    /// Image size as `(height, width)`.
    pub input_size: (usize, usize),
    /// Output channel count of each encoder block, in encoder order.
    pub encoder_channels: Vec<usize>,
    /// Output channel count of each decoder block, in decoder order.
    pub decoder_channels: Vec<usize>,
    /// Whether decoder blocks are fed the matching encoder feature maps.
    pub skip_connections: bool,
}

impl Default for SegmentationConfig {
    /// Four classes on 128x128 images with a four-level encoder and decoder
    /// and skip connections enabled.
    fn default() -> Self {
        Self {
            num_classes: 4,
            input_size: (128, 128),
            encoder_channels: vec![64, 128, 256, 512],
            decoder_channels: vec![256, 128, 64, 32],
            skip_connections: true,
        }
    }
}
43
44pub struct SegmentationDataset {
46 config: SegmentationConfig,
47 rng: SmallRng,
48}
49
50impl SegmentationDataset {
51 pub fn new(config: SegmentationConfig, seed: u64) -> Self {
52 Self {
53 config,
54 rng: SmallRng::seed_from_u64(seed),
55 }
56 }
57
58 pub fn generate_sample(&mut self) -> (Array3<f32>, Array2<usize>) {
60 let (height, width) = self.config.input_size;
61 let mut image = Array3::<f32>::zeros((3, height, width)); let mut mask = Array2::<usize>::zeros((height, width));
63
64 for c in 0..3 {
66 for i in 0..height {
67 for j in 0..width {
68 image[[c, i, j]] = self.rng.random_range(0.1..0.3);
69 }
70 }
71 }
72
73 let num_shapes = self.rng.random_range(3..8);
75
76 for _ in 0..num_shapes {
77 let shape_type = self.rng.random_range(0..3);
78 let class_id = self.rng.random_range(1..self.config.num_classes);
79
80 let color = match class_id {
81 1 => [0.8, 0.2, 0.2], 2 => [0.2, 0.8, 0.2], 3 => [0.2, 0.2, 0.8], _ => [0.8, 0.8, 0.2], };
86
87 match shape_type {
88 0 => {
89 let rect_width = self.rng.random_range(15..40);
91 let rect_height = self.rng.random_range(15..40);
92 let start_x = self.rng.random_range(0..(width.saturating_sub(rect_width)));
93 let start_y = self
94 .rng
95 .random_range(0..(height.saturating_sub(rect_height)));
96
97 for i in start_y..(start_y + rect_height).min(height) {
98 for j in start_x..(start_x + rect_width).min(width) {
99 mask[[i, j]] = class_id;
100 for c in 0..3 {
101 image[[c, i, j]] = color[c] + self.rng.random_range(-0.1..0.1);
102 }
103 }
104 }
105 }
106 1 => {
107 let radius = self.rng.random_range(8..25) as f32;
109 let center_x = self
110 .rng
111 .random_range(radius as usize..(width - radius as usize))
112 as f32;
113 let center_y = self
114 .rng
115 .random_range(radius as usize..(height - radius as usize))
116 as f32;
117
118 for i in 0..height {
119 for j in 0..width {
120 let dx = j as f32 - center_x;
121 let dy = i as f32 - center_y;
122 if dx * dx + dy * dy <= radius * radius {
123 mask[[i, j]] = class_id;
124 for c in 0..3 {
125 image[[c, i, j]] = color[c] + self.rng.random_range(-0.1..0.1);
126 }
127 }
128 }
129 }
130 }
131 _ => {
132 let size = self.rng.random_range(15..35);
134 let center_x = self.rng.random_range(size / 2..(width - size / 2));
135 let center_y = self.rng.random_range(size / 2..(height - size / 2));
136
137 for i in (center_y.saturating_sub(size / 2))..(center_y + size / 2).min(height)
138 {
139 let row_width = (size as f32
140 * (1.0 - (i as f32 - center_y as f32).abs() / (size as f32 / 2.0)))
141 as usize;
142 for j in (center_x.saturating_sub(row_width / 2))
143 ..(center_x + row_width / 2).min(width)
144 {
145 mask[[i, j]] = class_id;
146 for c in 0..3 {
147 image[[c, i, j]] = color[c] + self.rng.random_range(-0.1..0.1);
148 }
149 }
150 }
151 }
152 }
153 }
154
155 (image, mask)
156 }
157
158 pub fn generate_batch(&mut self, batch_size: usize) -> (Array4<f32>, Array3<usize>) {
160 let (height, width) = self.config.input_size;
161 let mut images = Array4::<f32>::zeros((batch_size, 3, height, width));
162 let mut masks = Array3::<usize>::zeros((batch_size, height, width));
163
164 for i in 0..batch_size {
165 let (image, mask) = self.generate_sample();
166 images.slice_mut(s![i, .., .., ..]).assign(&image);
167 masks.slice_mut(s![i, .., ..]).assign(&mask);
168 }
169
170 (images, masks)
171 }
172}
173
174pub struct EncoderBlock {
176 conv1: Conv2D<f32>,
177 bn1: BatchNorm<f32>,
178 conv2: Conv2D<f32>,
179 bn2: BatchNorm<f32>,
180 pool: MaxPool2D<f32>,
181}
182
183impl EncoderBlock {
184 pub fn new(in_channels: usize, out_channels: usize, rng: &mut SmallRng) -> StdResult<Self> {
185 Ok(Self {
186 conv1: Conv2D::new(
187 in_channels,
188 out_channels,
189 (3, 3),
190 (1, 1),
191 PaddingMode::Same,
192 rng,
193 )?,
194 bn1: BatchNorm::new(out_channels, 0.1, 1e-5, rng)?,
195 conv2: Conv2D::new(
196 out_channels,
197 out_channels,
198 (3, 3),
199 (1, 1),
200 PaddingMode::Same,
201 rng,
202 )?,
203 bn2: BatchNorm::new(out_channels, 0.1, 1e-5, rng)?,
204 pool: MaxPool2D::new((2, 2), (2, 2), None)?,
205 })
206 }
207
208 pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<(ArrayD<f32>, ArrayD<f32>)> {
209 let x = self.conv1.forward(input)?;
211 let x = self.bn1.forward(&x)?;
212 let x = self.conv2.forward(&x)?;
216 let skip = self.bn2.forward(&x)?;
217 let pooled = self.pool.forward(&skip)?;
221
222 Ok((pooled, skip))
223 }
224}
225
226pub struct DecoderBlock {
228 conv1: Conv2D<f32>,
229 bn1: BatchNorm<f32>,
230 conv2: Conv2D<f32>,
231 bn2: BatchNorm<f32>,
232}
233
234impl DecoderBlock {
235 pub fn new(in_channels: usize, out_channels: usize, rng: &mut SmallRng) -> StdResult<Self> {
236 Ok(Self {
237 conv1: Conv2D::new(
238 in_channels,
239 out_channels,
240 (3, 3),
241 (1, 1),
242 PaddingMode::Same,
243 rng,
244 )?,
245 bn1: BatchNorm::new(out_channels, 0.1, 1e-5, rng)?,
246 conv2: Conv2D::new(
247 out_channels,
248 out_channels,
249 (3, 3),
250 (1, 1),
251 PaddingMode::Same,
252 rng,
253 )?,
254 bn2: BatchNorm::new(out_channels, 0.1, 1e-5, rng)?,
255 })
256 }
257
258 pub fn forward(
259 &self,
260 input: &ArrayD<f32>,
261 skip: Option<&ArrayD<f32>>,
262 ) -> StdResult<ArrayD<f32>> {
263 let upsampled = self.upsample(input)?;
265
266 let x = if let Some(skip_tensor) = skip {
268 self.concatenate(&upsampled, skip_tensor)?
269 } else {
270 upsampled
271 };
272
273 let x = self.conv1.forward(&x)?;
275 let x = self.bn1.forward(&x)?;
276 let x = self.conv2.forward(&x)?;
280 let x = self.bn2.forward(&x)?;
281 Ok(x)
284 }
285
286 fn upsample(&self, input: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
287 let shape = input.shape();
289 let batch_size = shape[0];
290 let channels = shape[1];
291 let height = shape[2];
292 let width = shape[3];
293
294 let mut upsampled = Array4::<f32>::zeros((batch_size, channels, height * 2, width * 2));
295
296 for b in 0..batch_size {
297 for c in 0..channels {
298 for i in 0..height {
299 for j in 0..width {
300 let value = input[[b, c, i, j]];
301 upsampled[[b, c, i * 2, j * 2]] = value;
302 upsampled[[b, c, i * 2, j * 2 + 1]] = value;
303 upsampled[[b, c, i * 2 + 1, j * 2]] = value;
304 upsampled[[b, c, i * 2 + 1, j * 2 + 1]] = value;
305 }
306 }
307 }
308 }
309
310 Ok(upsampled.into_dyn())
311 }
312
313 fn concatenate(&self, input1: &ArrayD<f32>, input2: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
314 let shape1 = input1.shape();
316 let shape2 = input2.shape();
317
318 if shape1[0] != shape2[0] || shape1[2] != shape2[2] || shape1[3] != shape2[3] {
319 return Err("Shapes incompatible for concatenation".into());
320 }
321
322 let batch_size = shape1[0];
323 let channels1 = shape1[1];
324 let channels2 = shape2[1];
325 let height = shape1[2];
326 let width = shape1[3];
327
328 let mut concatenated =
329 Array4::<f32>::zeros((batch_size, channels1 + channels2, height, width));
330
331 for b in 0..batch_size {
333 for c in 0..channels1 {
334 for i in 0..height {
335 for j in 0..width {
336 concatenated[[b, c, i, j]] = input1[[b, c, i, j]];
337 }
338 }
339 }
340 }
341
342 for b in 0..batch_size {
344 for c in 0..channels2 {
345 for i in 0..height {
346 for j in 0..width {
347 concatenated[[b, channels1 + c, i, j]] = input2[[b, c, i, j]];
348 }
349 }
350 }
351 }
352
353 Ok(concatenated.into_dyn())
354 }
355}
356
357pub struct UNetModel {
359 encoders: Vec<EncoderBlock>,
360 decoders: Vec<DecoderBlock>,
361 bottleneck: Sequential<f32>,
362 final_conv: Conv2D<f32>,
363 config: SegmentationConfig,
364}
365
366impl UNetModel {
367 pub fn new(config: SegmentationConfig, rng: &mut SmallRng) -> StdResult<Self> {
368 let mut encoders = Vec::new();
369 let mut decoders = Vec::new();
370
371 let mut in_channels = 3; for &out_channels in &config.encoder_channels {
374 encoders.push(EncoderBlock::new(in_channels, out_channels, rng)?);
375 in_channels = out_channels;
376 }
377
378 let bottleneck_channels = config.encoder_channels.last().copied().unwrap_or(512);
380 let mut bottleneck = Sequential::new();
381 bottleneck.add(Conv2D::new(
382 bottleneck_channels,
383 bottleneck_channels * 2,
384 (3, 3),
385 (1, 1),
386 PaddingMode::Same,
387 rng,
388 )?);
389 bottleneck.add(BatchNorm::new(bottleneck_channels * 2, 0.1, 1e-5, rng)?);
390 bottleneck.add(Conv2D::new(
391 bottleneck_channels * 2,
392 bottleneck_channels,
393 (3, 3),
394 (1, 1),
395 PaddingMode::Same,
396 rng,
397 )?);
398 bottleneck.add(BatchNorm::new(bottleneck_channels, 0.1, 1e-5, rng)?);
399
400 in_channels = bottleneck_channels;
402 for (i, &out_channels) in config.decoder_channels.iter().enumerate() {
403 let decoder_in_channels =
404 if config.skip_connections && i < config.encoder_channels.len() {
405 let encoder_idx = config.encoder_channels.len() - 1 - i;
407 in_channels + config.encoder_channels[encoder_idx]
408 } else {
409 in_channels
410 };
411 decoders.push(DecoderBlock::new(decoder_in_channels, out_channels, rng)?);
412 in_channels = out_channels;
413 }
414
415 let final_channels = config.decoder_channels.last().copied().unwrap_or(32);
417 let final_conv = Conv2D::new(
418 final_channels,
419 config.num_classes,
420 (1, 1),
421 (1, 1),
422 PaddingMode::Same,
423 rng,
424 )?;
425
426 Ok(Self {
427 encoders,
428 decoders,
429 bottleneck,
430 final_conv,
431 config,
432 })
433 }
434
435 pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
436 let mut x = input.clone();
437 let mut skip_connections = Vec::new();
438
439 for encoder in &self.encoders {
441 let (encoded, skip) = encoder.forward(&x)?;
442 skip_connections.push(skip);
443 x = encoded;
444 }
445
446 x = self.bottleneck.forward(&x)?;
448
449 skip_connections.reverse(); for (i, decoder) in self.decoders.iter().enumerate() {
452 let skip = if self.config.skip_connections && i < skip_connections.len() {
453 Some(&skip_connections[i])
454 } else {
455 None
456 };
457 x = decoder.forward(&x, skip)?;
458 }
459
460 let output = self.final_conv.forward(&x)?;
462
463 Ok(output)
464 }
465}
466
467pub struct SegmentationMetrics {
469 num_classes: usize,
470}
471
472impl SegmentationMetrics {
473 pub fn new(num_classes: usize) -> Self {
474 Self { num_classes }
475 }
476
477 pub fn pixel_accuracy(&self, predictions: &Array3<usize>, ground_truth: &Array3<usize>) -> f32 {
479 let mut correct = 0;
480 let mut total = 0;
481
482 for (pred, gt) in predictions.iter().zip(ground_truth.iter()) {
483 if pred == gt {
484 correct += 1;
485 }
486 total += 1;
487 }
488
489 if total > 0 {
490 correct as f32 / total as f32
491 } else {
492 0.0
493 }
494 }
495
496 pub fn mean_iou(&self, predictions: &Array3<usize>, ground_truth: &Array3<usize>) -> f32 {
498 let mut class_ious = Vec::new();
499
500 for class_id in 0..self.num_classes {
501 let iou = self.class_iou(predictions, ground_truth, class_id);
502 if !iou.is_nan() {
503 class_ious.push(iou);
504 }
505 }
506
507 if class_ious.is_empty() {
508 0.0
509 } else {
510 class_ious.iter().sum::<f32>() / class_ious.len() as f32
511 }
512 }
513
514 fn class_iou(
516 &self,
517 predictions: &Array3<usize>,
518 ground_truth: &Array3<usize>,
519 class_id: usize,
520 ) -> f32 {
521 let mut intersection = 0;
522 let mut union = 0;
523
524 for (pred, gt) in predictions.iter().zip(ground_truth.iter()) {
525 let pred_match = *pred == class_id;
526 let gt_match = *gt == class_id;
527
528 if pred_match && gt_match {
529 intersection += 1;
530 }
531 if pred_match || gt_match {
532 union += 1;
533 }
534 }
535
536 if union > 0 {
537 intersection as f32 / union as f32
538 } else {
539 f32::NAN
540 }
541 }
542
543 pub fn confusion_matrix(
545 &self,
546 predictions: &Array3<usize>,
547 ground_truth: &Array3<usize>,
548 ) -> Array2<usize> {
549 let mut matrix = Array2::<usize>::zeros((self.num_classes, self.num_classes));
550
551 for (pred, gt) in predictions.iter().zip(ground_truth.iter()) {
552 if *pred < self.num_classes && *gt < self.num_classes {
553 matrix[[*gt, *pred]] += 1;
554 }
555 }
556
557 matrix
558 }
559}
560
561fn logits_to_predictions(logits: &ArrayD<f32>) -> Array3<usize> {
563 let shape = logits.shape();
564 let batch_size = shape[0];
565 let num_classes = shape[1];
566 let height = shape[2];
567 let width = shape[3];
568
569 let mut predictions = Array3::<usize>::zeros((batch_size, height, width));
570
571 for b in 0..batch_size {
572 for i in 0..height {
573 for j in 0..width {
574 let mut best_class = 0;
575 let mut best_score = logits[[b, 0, i, j]];
576
577 for c in 1..num_classes {
578 let score = logits[[b, c, i, j]];
579 if score > best_score {
580 best_score = score;
581 best_class = c;
582 }
583 }
584
585 predictions[[b, i, j]] = best_class;
586 }
587 }
588 }
589
590 predictions
591}
592
593fn masks_to_targets(masks: &Array3<usize>, num_classes: usize) -> ArrayD<f32> {
595 let shape = masks.shape();
596 let batch_size = shape[0];
597 let height = shape[1];
598 let width = shape[2];
599
600 let mut targets = Array4::<f32>::zeros((batch_size, num_classes, height, width));
601
602 for b in 0..batch_size {
603 for i in 0..height {
604 for j in 0..width {
605 let class_id = masks[[b, i, j]];
606 if class_id < num_classes {
607 targets[[b, class_id, i, j]] = 1.0;
608 }
609 }
610 }
611 }
612
613 targets.into_dyn()
614}
615
616fn train_segmentation_model() -> StdResult<()> {
618 println!("šØ Starting Semantic Segmentation Training");
619
620 let mut rng = SmallRng::seed_from_u64(42);
621 let config = SegmentationConfig::default();
622
623 println!("š Starting model training...");
624
625 println!("šļø Building U-Net segmentation model...");
627 let model = UNetModel::new(config.clone(), &mut rng)?;
628 println!("ā
Model created with {} classes", config.num_classes);
629
630 let mut dataset = SegmentationDataset::new(config.clone(), 123);
632
633 let loss_fn = CrossEntropyLoss::new(1e-7);
635
636 let metrics = SegmentationMetrics::new(config.num_classes);
638
639 println!("š Training configuration:");
640 println!(" - Input size: {:?}", config.input_size);
641 println!(" - Number of classes: {}", config.num_classes);
642 println!(" - Encoder channels: {:?}", config.encoder_channels);
643 println!(" - Decoder channels: {:?}", config.decoder_channels);
644 println!(" - Skip connections: {}", config.skip_connections);
645
646 let num_epochs = 15;
648 let batch_size = 2; let _learning_rate = 0.001;
650
651 for epoch in 0..num_epochs {
652 println!("\nš Epoch {}/{}", epoch + 1, num_epochs);
653
654 let mut epoch_loss = 0.0;
655 let num_batches = 10; for batch_idx in 0..num_batches {
658 let (images, masks) = dataset.generate_batch(batch_size);
660 let images_dyn = images.into_dyn();
661
662 let logits = model.forward(&images_dyn)?;
664
665 let targets = masks_to_targets(&masks, config.num_classes);
667
668 let batch_loss = loss_fn.forward(&logits, &targets)?;
670 epoch_loss += batch_loss;
671
672 if batch_idx % 5 == 0 {
673 print!(
674 "š Batch {}/{} - Loss: {:.4} \r",
675 batch_idx + 1,
676 num_batches,
677 batch_loss
678 );
679 }
680 }
681
682 let avg_loss = epoch_loss / num_batches as f32;
683 println!(
684 "ā
Epoch {} completed - Average Loss: {:.4}",
685 epoch + 1,
686 avg_loss
687 );
688
689 if (epoch + 1) % 5 == 0 {
691 println!("š Running evaluation...");
692
693 let (val_images, val_masks) = dataset.generate_batch(batch_size);
695 let val_images_dyn = val_images.into_dyn();
696
697 let val_logits = model.forward(&val_images_dyn)?;
699 let predictions = logits_to_predictions(&val_logits);
700
701 let pixel_acc = metrics.pixel_accuracy(&predictions, &val_masks);
703 let miou = metrics.mean_iou(&predictions, &val_masks);
704
705 println!("š Validation metrics:");
706 println!(" - Pixel Accuracy: {:.4}", pixel_acc);
707 println!(" - Mean IoU: {:.4}", miou);
708
709 println!(" - Class-wise IoU:");
711 for class_id in 0..config.num_classes {
712 let class_iou = metrics.class_iou(&predictions, &val_masks, class_id);
713 if !class_iou.is_nan() {
714 println!(" Class {}: {:.4}", class_id, class_iou);
715 }
716 }
717 }
718 }
719
720 println!("\nš Semantic segmentation training completed!");
721
722 println!("š¬ Final evaluation...");
724 let (test_images, test_masks) = dataset.generate_batch(4);
725 let test_images_dyn = test_images.into_dyn();
726
727 let test_logits = model.forward(&test_images_dyn)?;
728 let final_predictions = logits_to_predictions(&test_logits);
729
730 let final_pixel_acc = metrics.pixel_accuracy(&final_predictions, &test_masks);
731 let final_miou = metrics.mean_iou(&final_predictions, &test_masks);
732
733 println!("š Final metrics:");
734 println!(" - Pixel Accuracy: {:.4}", final_pixel_acc);
735 println!(" - Mean IoU: {:.4}", final_miou);
736
737 let confusion = metrics.confusion_matrix(&final_predictions, &test_masks);
739 println!(" - Confusion Matrix:");
740 for i in 0..config.num_classes {
741 print!(" [");
742 for j in 0..config.num_classes {
743 print!("{:4}", confusion[[i, j]]);
744 }
745 println!("]");
746 }
747
748 println!("\nš Model Analysis:");
750 println!(" - Architecture: U-Net with skip connections");
751 println!(
752 " - Parameters: ~{:.1}K (estimated)",
753 (config.encoder_channels.iter().sum::<usize>()
754 + config.decoder_channels.iter().sum::<usize>())
755 / 1000
756 );
757 println!(" - Memory efficient: ā
(skip connections preserve spatial info)");
758 println!(" - JIT optimized: ā
");
759
760 Ok(())
761}
762
763fn main() -> StdResult<()> {
764 println!("šØ Semantic Segmentation Complete Example");
765 println!("==========================================");
766 println!();
767 println!("This example demonstrates:");
768 println!("⢠Building a U-Net style segmentation model");
769 println!("⢠Encoder-decoder architecture with skip connections");
770 println!("⢠Synthetic dataset with geometric shapes");
771 println!("⢠Pixel-wise classification and evaluation");
772 println!("⢠Segmentation metrics (IoU, mIoU, pixel accuracy)");
773 println!("⢠JIT compilation for performance optimization");
774 println!();
775
776 train_segmentation_model()?;
777
778 println!("\nš” Key Concepts Demonstrated:");
779 println!(" š¹ U-Net encoder-decoder architecture");
780 println!(" š¹ Skip connections for spatial information preservation");
781 println!(" š¹ Pixel-wise classification loss");
782 println!(" š¹ Intersection over Union (IoU) metrics");
783 println!(" š¹ Confusion matrix analysis");
784 println!(" š¹ Upsampling and feature concatenation");
785 println!();
786 println!("š For production use:");
787 println!(" ⢠Implement proper upsampling (transpose convolution)");
788 println!(" ⢠Add data augmentation (rotation, flipping, scaling)");
789 println!(" ⢠Use pre-trained encoders (ResNet, EfficientNet)");
790 println!(" ⢠Implement focal loss for class imbalance");
791 println!(" ⢠Add multi-scale training and testing");
792 println!(" ⢠Use real datasets (Cityscapes, ADE20K, Pascal VOC)");
793 println!(" ⢠Implement DeepLabV3+, PSPNet, or other SOTA architectures");
794
795 Ok(())
796}
797
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has the documented class count and image
    /// size and non-empty channel lists.
    #[test]
    fn test_segmentation_config() {
        let SegmentationConfig {
            num_classes,
            input_size,
            encoder_channels,
            decoder_channels,
            ..
        } = SegmentationConfig::default();

        assert_eq!(num_classes, 4);
        assert_eq!(input_size, (128, 128));
        assert!(!encoder_channels.is_empty());
        assert!(!decoder_channels.is_empty());
    }

    /// A generated sample has the configured shape and only valid class ids.
    #[test]
    fn test_dataset_generation() {
        let config = SegmentationConfig::default();
        let (height, width) = config.input_size;
        let mut dataset = SegmentationDataset::new(config.clone(), 42);

        let (image, mask) = dataset.generate_sample();
        assert_eq!(image.shape(), &[3, height, width]);
        assert_eq!(mask.shape(), &[height, width]);

        assert!(mask.iter().all(|&class_id| class_id < config.num_classes));
    }

    /// Identical predictions and ground truth give perfect accuracy and mIoU.
    #[test]
    fn test_segmentation_metrics() {
        let metrics = SegmentationMetrics::new(3);

        let ground_truth = Array3::<usize>::from_shape_fn((1, 4, 4), |(_, i, j)| (i + j) % 3);
        let predictions = ground_truth.clone();

        assert_eq!(metrics.pixel_accuracy(&predictions, &ground_truth), 1.0);
        assert_eq!(metrics.mean_iou(&predictions, &ground_truth), 1.0);
    }

    /// A small model maps a 16x16 input to per-class output of the same size.
    #[test]
    fn test_model_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = SegmentationConfig {
            num_classes: 4,
            input_size: (16, 16),
            encoder_channels: vec![16, 32, 64, 128],
            decoder_channels: vec![64, 32, 16, 8],
            skip_connections: true,
        };
        let (height, width) = config.input_size;

        let model = UNetModel::new(config.clone(), &mut rng)?;

        let batch_size = 1;
        let input = Array4::<f32>::ones((batch_size, 3, height, width)).into_dyn();
        let output = model.forward(&input)?;

        let expected = [batch_size, config.num_classes, height, width];
        for (axis, &want) in expected.iter().enumerate() {
            assert_eq!(output.shape()[axis], want);
        }

        Ok(())
    }

    /// When the score equals the class index, the highest class always wins.
    #[test]
    fn test_logits_to_predictions() {
        let logits =
            Array4::<f32>::from_shape_fn((1, 3, 2, 2), |(_, c, _, _)| c as f32).into_dyn();

        let predictions = logits_to_predictions(&logits);

        assert!(predictions.iter().all(|&pred| pred == 2));
    }

    /// Each pixel's target is 1.0 in its own class channel and 0.0 elsewhere.
    #[test]
    fn test_masks_to_targets() {
        let masks = Array3::<usize>::from_shape_fn((1, 2, 2), |(_, i, j)| (i + j) % 3);
        let targets = masks_to_targets(&masks, 3);

        assert_eq!(targets.shape(), &[1, 3, 2, 2]);

        for ((b, i, j), &class_id) in masks.indexed_iter() {
            for c in 0..3 {
                let expected = if c == class_id { 1.0 } else { 0.0 };
                assert_eq!(targets[[b, c, i, j]], expected);
            }
        }
    }
}