Skip to main content

yscv_imgproc/ops/
io.rs

1use std::path::Path;
2
3use image::{DynamicImage, ImageFormat, ImageReader};
4use yscv_tensor::Tensor;
5
6use super::super::ImgProcError;
7
8/// Read an image from disk and return it as a `[H, W, C]` `f32` tensor with values in `[0, 1]`.
9///
10/// Supported formats: PNG, JPEG, BMP, WebP.
11/// RGB images produce shape `[H, W, 3]`, grayscale images are converted to RGB.
12///
13/// Accepts any path-like type (`&str`, `String`, `&Path`, `PathBuf`, etc.).
14pub fn imread(path: impl AsRef<Path>) -> Result<Tensor, ImgProcError> {
15    let path = path.as_ref();
16    let reader = ImageReader::open(path).map_err(|e| ImgProcError::Io {
17        message: format!("{}: {e}", path.display()),
18    })?;
19    let img = reader
20        .decode()
21        .map_err(|e| ImgProcError::ImageDecode {
22            message: format!("{}: {e}", path.display()),
23        })?
24        .into_rgb8();
25    let (w, h) = (img.width() as usize, img.height() as usize);
26    let raw = img.into_raw();
27    let data: Vec<f32> = raw.iter().map(|&b| b as f32 / 255.0).collect();
28    Tensor::from_vec(vec![h, w, 3], data).map_err(Into::into)
29}
30
31/// Read an image as grayscale and return it as a `[H, W, 1]` `f32` tensor with values in `[0, 1]`.
32///
33/// Accepts any path-like type (`&str`, `String`, `&Path`, `PathBuf`, etc.).
34pub fn imread_gray(path: impl AsRef<Path>) -> Result<Tensor, ImgProcError> {
35    let path = path.as_ref();
36    let reader = ImageReader::open(path).map_err(|e| ImgProcError::Io {
37        message: format!("{}: {e}", path.display()),
38    })?;
39    let img = reader
40        .decode()
41        .map_err(|e| ImgProcError::ImageDecode {
42            message: format!("{}: {e}", path.display()),
43        })?
44        .into_luma8();
45    let (w, h) = (img.width() as usize, img.height() as usize);
46    let raw = img.into_raw();
47    let data: Vec<f32> = raw.iter().map(|&b| b as f32 / 255.0).collect();
48    Tensor::from_vec(vec![h, w, 1], data).map_err(Into::into)
49}
50
51/// Write a `[H, W, C]` `f32` tensor (values in `[0, 1]`) to disk.
52///
53/// Format is inferred from the file extension. Channels must be 1 (grayscale) or 3 (RGB).
54///
55/// Accepts any path-like type (`&str`, `String`, `&Path`, `PathBuf`, etc.).
56pub fn imwrite(path: impl AsRef<Path>, image: &Tensor) -> Result<(), ImgProcError> {
57    let path = path.as_ref();
58    let shape = image.shape();
59    if shape.len() != 3 {
60        return Err(ImgProcError::InvalidImageShape {
61            expected_rank: 3,
62            got: shape.to_vec(),
63        });
64    }
65    let (h, w, c) = (shape[0], shape[1], shape[2]);
66    if c != 1 && c != 3 {
67        return Err(ImgProcError::InvalidChannelCount {
68            expected: 3,
69            got: c,
70        });
71    }
72
73    let format = ImageFormat::from_path(path).map_err(|_| ImgProcError::UnsupportedFormat {
74        path: path.display().to_string(),
75    })?;
76
77    let bytes: Vec<u8> = image
78        .data()
79        .iter()
80        .map(|&v| (v.clamp(0.0, 1.0) * 255.0 + 0.5) as u8)
81        .collect();
82
83    let dyn_img = if c == 1 {
84        DynamicImage::ImageLuma8(
85            image::GrayImage::from_raw(w as u32, h as u32, bytes).ok_or_else(|| {
86                ImgProcError::ImageEncode {
87                    message: "failed to construct grayscale image buffer".into(),
88                }
89            })?,
90        )
91    } else {
92        DynamicImage::ImageRgb8(
93            image::RgbImage::from_raw(w as u32, h as u32, bytes).ok_or_else(|| {
94                ImgProcError::ImageEncode {
95                    message: "failed to construct RGB image buffer".into(),
96                }
97            })?,
98        )
99    };
100
101    dyn_img
102        .save_with_format(path, format)
103        .map_err(|e| ImgProcError::ImageEncode {
104            message: format!("{}: {e}", path.display()),
105        })
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Path for a scratch file inside a crate-specific temporary directory.
    fn scratch_path(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join("yscv_io_test");
        let _ = fs::create_dir_all(&dir);
        dir.join(name)
    }

    /// Assert that two sample buffers agree to within one 8-bit quantisation step.
    ///
    /// Writing rounds each sample to the nearest multiple of 1/255, so a round trip is off by
    /// at most 0.5/255; a full step leaves headroom for float rounding.
    fn assert_close_u8(expected: &[f32], actual: &[f32]) {
        assert_eq!(expected.len(), actual.len());
        for (a, b) in expected.iter().zip(actual) {
            assert!((a - b).abs() < 1.0 / 255.0, "a={a} b={b}");
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_imread_imwrite_roundtrip_rgb() {
        let path = scratch_path("test_rgb.png");

        // A small 2x3 RGB image tensor.
        let data: Vec<f32> = vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, // row 0: R, G, B
            0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5, // row 1
        ];
        let img = Tensor::from_vec(vec![2, 3, 3], data).unwrap();
        imwrite(&path, &img).unwrap();

        let loaded = imread(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 3, 3]);
        assert_close_u8(img.data(), loaded.data());

        let _ = fs::remove_file(&path);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_imread_gray() {
        let path = scratch_path("test_gray.png");

        let data: Vec<f32> = vec![0.0, 0.5, 1.0, 0.25];
        let img = Tensor::from_vec(vec![2, 2, 1], data).unwrap();
        imwrite(&path, &img).unwrap();

        let loaded = imread_gray(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 2, 1]);
        assert_close_u8(img.data(), loaded.data());

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_imwrite_invalid_rank() {
        // Rejected before any file access, so the path is never touched.
        let path = Path::new("yscv_invalid_rank.png");
        let img = Tensor::from_vec(vec![4], vec![0.0; 4]).unwrap();
        assert!(matches!(
            imwrite(path, &img),
            Err(ImgProcError::InvalidImageShape { .. })
        ));
    }

    #[test]
    fn test_imwrite_invalid_channel_count() {
        let img = Tensor::from_vec(vec![2, 2, 2], vec![0.0; 8]).unwrap();
        assert!(matches!(
            imwrite(Path::new("yscv_invalid_channels.png"), &img),
            Err(ImgProcError::InvalidChannelCount { got: 2, .. })
        ));
    }

    #[test]
    fn test_imwrite_unknown_extension() {
        let img = Tensor::from_vec(vec![1, 1, 3], vec![0.0; 3]).unwrap();
        assert!(matches!(
            imwrite(Path::new("yscv_unknown.not_an_image_ext"), &img),
            Err(ImgProcError::UnsupportedFormat { .. })
        ));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_imread_nonexistent_file() {
        let result = imread(std::env::temp_dir().join("nonexistent_yscv_test_image.png"));
        assert!(matches!(result, Err(ImgProcError::Io { .. })));
    }
}