1use std::path::Path;
2
3use image::{DynamicImage, ImageFormat, ImageReader};
4use yscv_tensor::Tensor;
5
6use super::super::ImgProcError;
7
8pub fn imread(path: impl AsRef<Path>) -> Result<Tensor, ImgProcError> {
15 let path = path.as_ref();
16 let reader = ImageReader::open(path).map_err(|e| ImgProcError::Io {
17 message: format!("{}: {e}", path.display()),
18 })?;
19 let img = reader
20 .decode()
21 .map_err(|e| ImgProcError::ImageDecode {
22 message: format!("{}: {e}", path.display()),
23 })?
24 .into_rgb8();
25 let (w, h) = (img.width() as usize, img.height() as usize);
26 let raw = img.into_raw();
27 let data: Vec<f32> = raw.iter().map(|&b| b as f32 / 255.0).collect();
28 Tensor::from_vec(vec![h, w, 3], data).map_err(Into::into)
29}
30
31pub fn imread_gray(path: impl AsRef<Path>) -> Result<Tensor, ImgProcError> {
35 let path = path.as_ref();
36 let reader = ImageReader::open(path).map_err(|e| ImgProcError::Io {
37 message: format!("{}: {e}", path.display()),
38 })?;
39 let img = reader
40 .decode()
41 .map_err(|e| ImgProcError::ImageDecode {
42 message: format!("{}: {e}", path.display()),
43 })?
44 .into_luma8();
45 let (w, h) = (img.width() as usize, img.height() as usize);
46 let raw = img.into_raw();
47 let data: Vec<f32> = raw.iter().map(|&b| b as f32 / 255.0).collect();
48 Tensor::from_vec(vec![h, w, 1], data).map_err(Into::into)
49}
50
51pub fn imwrite(path: impl AsRef<Path>, image: &Tensor) -> Result<(), ImgProcError> {
57 let path = path.as_ref();
58 let shape = image.shape();
59 if shape.len() != 3 {
60 return Err(ImgProcError::InvalidImageShape {
61 expected_rank: 3,
62 got: shape.to_vec(),
63 });
64 }
65 let (h, w, c) = (shape[0], shape[1], shape[2]);
66 if c != 1 && c != 3 {
67 return Err(ImgProcError::InvalidChannelCount {
68 expected: 3,
69 got: c,
70 });
71 }
72
73 let format = ImageFormat::from_path(path).map_err(|_| ImgProcError::UnsupportedFormat {
74 path: path.display().to_string(),
75 })?;
76
77 let bytes: Vec<u8> = image
78 .data()
79 .iter()
80 .map(|&v| (v.clamp(0.0, 1.0) * 255.0 + 0.5) as u8)
81 .collect();
82
83 let dyn_img = if c == 1 {
84 DynamicImage::ImageLuma8(
85 image::GrayImage::from_raw(w as u32, h as u32, bytes).ok_or_else(|| {
86 ImgProcError::ImageEncode {
87 message: "failed to construct grayscale image buffer".into(),
88 }
89 })?,
90 )
91 } else {
92 DynamicImage::ImageRgb8(
93 image::RgbImage::from_raw(w as u32, h as u32, bytes).ok_or_else(|| {
94 ImgProcError::ImageEncode {
95 message: "failed to construct RGB image buffer".into(),
96 }
97 })?,
98 )
99 };
100
101 dyn_img
102 .save_with_format(path, format)
103 .map_err(|e| ImgProcError::ImageEncode {
104 message: format!("{}: {e}", path.display()),
105 })
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// Returns a scratch file path under the shared temp test directory.
    /// The process id is prefixed so concurrent test runs do not clobber each
    /// other's files.
    fn scratch_path(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join("yscv_io_test");
        let _ = fs::create_dir_all(&dir);
        dir.join(format!("{}_{name}", std::process::id()))
    }

    /// Asserts both tensors hold the same values up to 8-bit quantization error.
    fn assert_close(expected: &Tensor, actual: &Tensor) {
        assert_eq!(expected.data().len(), actual.data().len());
        for (a, b) in expected.data().iter().zip(actual.data().iter()) {
            // Writing quantizes to 1/255 steps, so allow a small tolerance.
            assert!((a - b).abs() < 2.0 / 255.0, "a={a} b={b}");
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_imread_imwrite_roundtrip_rgb() {
        let path = scratch_path("test_rgb.png");

        let data: Vec<f32> = vec![
            1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, //
            0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.5, 0.0, 0.5,
        ];
        let img = Tensor::from_vec(vec![2, 3, 3], data).unwrap();
        imwrite(&path, &img).unwrap();

        let loaded = imread(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 3, 3]);
        assert_close(&img, &loaded);

        let _ = fs::remove_file(&path);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_imread_gray() {
        let path = scratch_path("test_gray.png");

        let data: Vec<f32> = vec![0.0, 0.5, 1.0, 0.25];
        let img = Tensor::from_vec(vec![2, 2, 1], data).unwrap();
        imwrite(&path, &img).unwrap();

        let loaded = imread_gray(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 2, 1]);
        assert_close(&img, &loaded);

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn test_imwrite_invalid_rank() {
        // Rejected before the path is touched, so no file is created.
        let path = Path::new("invalid.png");
        let img = Tensor::from_vec(vec![4], vec![0.0; 4]).unwrap();
        assert!(matches!(
            imwrite(path, &img),
            Err(ImgProcError::InvalidImageShape { .. })
        ));
    }

    #[test]
    fn test_imwrite_invalid_channel_count() {
        // Two channels are neither grayscale nor RGB; rejected before any I/O.
        let path = Path::new("invalid_channels.png");
        let img = Tensor::from_vec(vec![2, 2, 2], vec![0.0; 8]).unwrap();
        assert!(matches!(
            imwrite(path, &img),
            Err(ImgProcError::InvalidChannelCount { got: 2, .. })
        ));
    }

    #[test]
    fn test_imwrite_unsupported_extension() {
        // The format lookup fails before anything is written.
        let path = Path::new("image.notanimageformat");
        let img = Tensor::from_vec(vec![1, 1, 3], vec![0.0; 3]).unwrap();
        assert!(matches!(
            imwrite(path, &img),
            Err(ImgProcError::UnsupportedFormat { .. })
        ));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_imread_nonexistent_file() {
        let path = scratch_path("nonexistent_yscv_test_image.png");
        let result = imread(&path);
        assert!(matches!(result, Err(ImgProcError::Io { .. })));
    }
}