wonnx_preprocessing/
image.rs

1use std::path::Path;
2
3use image::{imageops::FilterType, ImageBuffer, Pixel, Rgb};
4use ndarray::s;
5
6// Loads an image as (1,1,w,h) with pixels ranging 0...1 for 0..255 pixel values
7pub fn load_bw_image(
8    image_path: &Path,
9    width: usize,
10    height: usize,
11) -> ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<[usize; 4]>> {
12    let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(image_path)
13        .unwrap()
14        .resize_exact(width as u32, height as u32, FilterType::Nearest)
15        .to_rgb8();
16
17    // Python:
18    // # image[y, x, RGB]
19    // # x==0 --> left
20    // # y==0 --> top
21
22    // See https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb
23    // for pre-processing image.
24    // WARNING: Note order of declaration of arguments: (_,c,j,i)
25    ndarray::Array::from_shape_fn((1, 1, width, height), |(_, c, j, i)| {
26        let pixel = image_buffer.get_pixel(i as u32, j as u32);
27        let channels = pixel.channels();
28
29        // range [0, 255] -> range [0, 1]
30        (channels[c] as f32) / 255.0
31    })
32}
33
34// Loads an image as (1, 3, h, w)
35pub fn load_rgb_image(
36    image_path: &Path,
37    width: usize,
38    height: usize,
39) -> ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<[usize; 4]>> {
40    log::info!("load_rgb_image {:?} {}x{}", image_path, width, height);
41    let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(image_path)
42        .unwrap()
43        .resize_to_fill(width as u32, height as u32, FilterType::Nearest)
44        .to_rgb8();
45
46    // Python:
47    // # image[y, x, RGB]
48    // # x==0 --> left
49    // # y==0 --> top
50
51    // See https://github.com/onnx/models/blob/master/vision/classification/imagenet_inference.ipynb
52    // for pre-processing image.
53    // WARNING: Note order of declaration of arguments: (_,c,j,i)
54    let mut array = ndarray::Array::from_shape_fn((1, 3, height, width), |(_, c, j, i)| {
55        let pixel = image_buffer.get_pixel(i as u32, j as u32);
56        let channels = pixel.channels();
57
58        // range [0, 255] -> range [0, 1]
59        (channels[c] as f32) / 255.0
60    });
61
62    // Normalize channels to mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]
63    let mean = [0.485, 0.456, 0.406];
64    let std = [0.229, 0.224, 0.225];
65    for c in 0..3 {
66        let mut channel_array = array.slice_mut(s![0, c, .., ..]);
67        channel_array -= mean[c];
68        channel_array /= std[c];
69    }
70
71    // Batch of 1
72    array
73}