tch_plus/vision/
image.rs

1//! Utility functions to manipulate images.
2use crate::wrappers::image::{load_hwc, load_hwc_from_mem, resize_hwc, save_hwc};
3use crate::{Device, TchError, Tensor};
4use std::io;
5use std::path::Path;
6
7pub(crate) fn hwc_to_chw(tensor: &Tensor) -> Tensor {
8    tensor.permute([2, 0, 1])
9}
10
11pub(crate) fn chw_to_hwc(tensor: &Tensor) -> Tensor {
12    tensor.permute([1, 2, 0])
13}
14
15/// Loads an image from a file.
16///
17/// On success returns a tensor of shape [channel, height, width].
18pub fn load<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
19    let tensor = load_hwc(path)?;
20    Ok(hwc_to_chw(&tensor))
21}
22
23/// Loads an image from memory.
24///
25/// On success returns a tensor of shape [channel, height, width].
26pub fn load_from_memory(img_data: &[u8]) -> Result<Tensor, TchError> {
27    let tensor = load_hwc_from_mem(img_data)?;
28    Ok(hwc_to_chw(&tensor))
29}
30
31/// Saves an image to a file.
32///
33/// This expects as input a tensor of shape [channel, height, width].
34/// The image format is based on the filename suffix, supported suffixes
35/// are jpg, png, tga, and bmp.
36/// The tensor input should be of kind UInt8 with values ranging from
37/// 0 to 255.
38pub fn save<T: AsRef<Path>>(t: &Tensor, path: T) -> Result<(), TchError> {
39    let t = t.to_kind(crate::Kind::Uint8);
40    match t.size().as_slice() {
41        [1, _, _, _] => save_hwc(&chw_to_hwc(&t.squeeze_dim(0)).to_device(Device::Cpu), path),
42        [_, _, _] => save_hwc(&chw_to_hwc(&t).to_device(Device::Cpu), path),
43        sz => Err(TchError::FileFormat(format!("unexpected size for image tensor {sz:?}"))),
44    }
45}
46
47/// Resizes an image.
48///
49/// This expects as input a tensor of shape [channel, height, width] and returns
50/// a tensor of shape [channel, out_h, out_w].
51pub fn resize(t: &Tensor, out_w: i64, out_h: i64) -> Result<Tensor, TchError> {
52    Ok(hwc_to_chw(&resize_hwc(&chw_to_hwc(t), out_w, out_h)?))
53}
54
55pub fn resize_preserve_aspect_ratio_hwc(
56    t: &Tensor,
57    out_w: i64,
58    out_h: i64,
59) -> Result<Tensor, TchError> {
60    let tensor_size = t.size();
61    let (w, h) = (tensor_size[0], tensor_size[1]);
62    if w * out_h == h * out_w {
63        Ok(hwc_to_chw(&resize_hwc(t, out_w, out_h)?))
64    } else {
65        let (resize_w, resize_h) = {
66            let ratio_w = out_w as f64 / w as f64;
67            let ratio_h = out_h as f64 / h as f64;
68            let ratio = ratio_w.max(ratio_h);
69            ((ratio * h as f64) as i64, (ratio * w as f64) as i64)
70        };
71        let resize_w = i64::max(resize_w, out_w);
72        let resize_h = i64::max(resize_h, out_h);
73        let t = hwc_to_chw(&resize_hwc(t, resize_w, resize_h)?);
74        let t = if resize_w == out_w { t } else { t.f_narrow(2, (resize_w - out_w) / 2, out_w)? };
75        let t = if resize_h == out_h { t } else { t.f_narrow(1, (resize_h - out_h) / 2, out_h)? };
76        Ok(t)
77    }
78}
79
80/// Resize an image, preserve the aspect ratio by taking a center crop.
81///
82/// This expects as input a tensor of shape [channel, height, width] and returns
83pub fn resize_preserve_aspect_ratio(
84    t: &Tensor,
85    out_w: i64,
86    out_h: i64,
87) -> Result<Tensor, TchError> {
88    resize_preserve_aspect_ratio_hwc(&chw_to_hwc(t), out_w, out_h)
89}
90
91/// Loads and resize an image, preserve the aspect ratio by taking a center crop.
92pub fn load_and_resize<T: AsRef<Path>>(
93    path: T,
94    out_w: i64,
95    out_h: i64,
96) -> Result<Tensor, TchError> {
97    let tensor = load_hwc(path)?;
98    resize_preserve_aspect_ratio_hwc(&tensor, out_w, out_h)
99}
100
101/// Loads and resize an image from memory, preserve the aspect ratio by taking a center crop.
102pub fn load_and_resize_from_memory(
103    img_data: &[u8],
104    out_w: i64,
105    out_h: i64,
106) -> Result<Tensor, TchError> {
107    let tensor = load_hwc_from_mem(img_data)?;
108    resize_preserve_aspect_ratio_hwc(&tensor, out_w, out_h)
109}
110
111fn visit_dirs(dir: &Path, files: &mut Vec<std::fs::DirEntry>) -> Result<(), TchError> {
112    if dir.is_dir() {
113        for entry in std::fs::read_dir(dir)? {
114            let entry = entry?;
115            let path = entry.path();
116            if path.is_dir() {
117                visit_dirs(&path, files)?;
118            } else if entry
119                .file_name()
120                .to_str()
121                .map_or(false, |s| s.ends_with(".png") || s.ends_with(".jpg"))
122            {
123                files.push(entry);
124            }
125        }
126    }
127    Ok(())
128}
129
130/// Loads all the images in a directory.
131pub fn load_dir<T: AsRef<Path>>(path: T, out_w: i64, out_h: i64) -> Result<Tensor, TchError> {
132    let mut files: Vec<std::fs::DirEntry> = vec![];
133    visit_dirs(path.as_ref(), &mut files)?;
134    if files.is_empty() {
135        return Err(TchError::Io(io::Error::new(
136            io::ErrorKind::NotFound,
137            format!("no image found in {:?}", path.as_ref(),),
138        )));
139    }
140    let v: Vec<_> = files
141        .iter()
142        // Silently discard image errors.
143        .filter_map(|x| load_and_resize(x.path(), out_w, out_h).ok())
144        .collect();
145    Ok(Tensor::stack(&v, 0))
146}