1use crate::error::ImageDataError;
33use crate::Alpha;
34use crate::BoxImage;
35use crate::Image;
36use crate::ImageInfo;
37use crate::PixelFormat;
38
39pub struct TensorImage<'a> {
41 tensor: &'a tch::Tensor,
42 info: ImageInfo,
43 planar: bool,
44}
45
/// How the data of a tensor should be interpreted as an image.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TensorPixelFormat {
	/// One plane per channel: a `(channels, height, width)` tensor
	/// (or `(height, width)` for single-channel formats).
	Planar(PixelFormat),

	/// The channels of each pixel are stored together: a `(height, width, channels)` tensor
	/// (or `(height, width)` for single-channel formats).
	Interlaced(PixelFormat),

	/// Infer the layout and pixel format from the tensor shape, using the given channel order
	/// for 3- and 4-channel tensors.
	Guess(ColorFormat),
}
58
/// Channel order assumed when guessing the pixel format of a tensor.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ColorFormat {
	/// 3-channel tensors are treated as RGB, 4-channel tensors as RGBA.
	Rgb,

	/// 3-channel tensors are treated as BGR, 4-channel tensors as BGRA.
	Bgr,
}
68
/// Interpret a tensor as an image.
///
/// Every method returns a [`TensorImage`] that borrows `self` for `'a`. The borrowed image can
/// be converted into an owned [`Image`] with `From`/`Into`.
// The explicit `'a` lifetimes could be elided (clippy flags them as needless); presumably they
// are kept so the signatures spell out that the returned image borrows the tensor.
#[allow(clippy::needless_lifetimes)]
pub trait TensorAsImage {
	/// Wrap the tensor as an image using the given pixel format and layout.
	///
	/// # Errors
	/// Implementations return an error when the tensor cannot be interpreted with
	/// `pixel_format`, e.g. when its shape does not match the format.
	fn as_image<'a>(&'a self, pixel_format: TensorPixelFormat) -> Result<TensorImage<'a>, ImageDataError>;

	/// Wrap the tensor as an interlaced image with the given pixel format.
	///
	/// Shorthand for `as_image(TensorPixelFormat::Interlaced(pixel_format))`.
	fn as_interlaced<'a>(&'a self, pixel_format: PixelFormat) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_image(TensorPixelFormat::Interlaced(pixel_format))
	}

	/// Wrap the tensor as a planar image with the given pixel format.
	///
	/// Shorthand for `as_image(TensorPixelFormat::Planar(pixel_format))`.
	fn as_planar<'a>(&'a self, pixel_format: PixelFormat) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_image(TensorPixelFormat::Planar(pixel_format))
	}

	/// Wrap the tensor as an image, letting the implementation infer layout and pixel format
	/// from the tensor shape.
	///
	/// `color_format` selects the channel order used for 3- and 4-channel tensors.
	/// Shorthand for `as_image(TensorPixelFormat::Guess(color_format))`.
	fn as_image_guess<'a>(&'a self, color_format: ColorFormat) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_image(TensorPixelFormat::Guess(color_format))
	}

	/// Guess layout and pixel format, treating 3/4-channel tensors as RGB/RGBA.
	///
	/// Shorthand for `as_image_guess(ColorFormat::Rgb)`.
	fn as_image_guess_rgb<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_image_guess(ColorFormat::Rgb)
	}

	/// Guess layout and pixel format, treating 3/4-channel tensors as BGR/BGRA.
	///
	/// Shorthand for `as_image_guess(ColorFormat::Bgr)`.
	fn as_image_guess_bgr<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_image_guess(ColorFormat::Bgr)
	}

	/// Wrap the tensor as a single-channel 8-bit image (`as_interlaced(PixelFormat::Mono8)`).
	fn as_mono8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_interlaced(PixelFormat::Mono8)
	}

	/// Wrap the tensor as an interlaced RGB8 image.
	fn as_interlaced_rgb8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_interlaced(PixelFormat::Rgb8)
	}

	/// Wrap the tensor as an interlaced RGBA8 image with unpremultiplied alpha.
	fn as_interlaced_rgba8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_interlaced(PixelFormat::Rgba8(Alpha::Unpremultiplied))
	}

	/// Wrap the tensor as an interlaced BGR8 image.
	fn as_interlaced_bgr8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_interlaced(PixelFormat::Bgr8)
	}

	/// Wrap the tensor as an interlaced BGRA8 image with unpremultiplied alpha.
	fn as_interlaced_bgra8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_interlaced(PixelFormat::Bgra8(Alpha::Unpremultiplied))
	}

	/// Wrap the tensor as a planar RGB8 image.
	fn as_planar_rgb8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_planar(PixelFormat::Rgb8)
	}

	/// Wrap the tensor as a planar RGBA8 image with unpremultiplied alpha.
	fn as_planar_rgba8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_planar(PixelFormat::Rgba8(Alpha::Unpremultiplied))
	}

	/// Wrap the tensor as a planar BGR8 image.
	fn as_planar_bgr8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_planar(PixelFormat::Bgr8)
	}

	/// Wrap the tensor as a planar BGRA8 image with unpremultiplied alpha.
	fn as_planar_bgra8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
		self.as_planar(PixelFormat::Bgra8(Alpha::Unpremultiplied))
	}
}
166
167impl TensorAsImage for tch::Tensor {
168 fn as_image(&self, pixel_format: TensorPixelFormat) -> Result<TensorImage, ImageDataError> {
169 let (planar, info) = match pixel_format {
170 TensorPixelFormat::Planar(pixel_format) => (true, tensor_info(self, pixel_format, true)?),
171 TensorPixelFormat::Interlaced(pixel_format) => (false, tensor_info(self, pixel_format, false)?),
172 TensorPixelFormat::Guess(color_format) => guess_tensor_info(self, color_format)?,
173 };
174 Ok(TensorImage {
175 tensor: self,
176 info,
177 planar,
178 })
179 }
180}
181
182fn tensor_to_byte_vec(tensor: &tch::Tensor) -> Vec<u8> {
183 let size = tensor.numel() * tensor.kind().elt_size_in_bytes();
184 let mut data = vec![0u8; size];
185 tensor.copy_data_u8(&mut data, tensor.numel());
186 data
187}
188
189impl<'a> From<TensorImage<'a>> for Image {
190 fn from(other: TensorImage<'a>) -> Self {
191 let data = if other.planar {
192 tensor_to_byte_vec(&other.tensor.permute([1, 2, 0]))
193 } else {
194 tensor_to_byte_vec(other.tensor)
195 };
196
197 BoxImage::new(other.info, data.into_boxed_slice()).into()
198 }
199}
200
201impl<'a> From<Result<TensorImage<'a>, ImageDataError>> for Image {
202 fn from(other: Result<TensorImage<'a>, ImageDataError>) -> Self {
203 match other {
204 Ok(x) => x.into(),
205 Err(e) => Image::Invalid(e),
206 }
207 }
208}
209
210#[allow(clippy::branches_sharing_code)] fn tensor_info(tensor: &tch::Tensor, pixel_format: PixelFormat, planar: bool) -> Result<ImageInfo, String> {
213 let expected_channels = pixel_format.channels();
214 let dimensions = tensor.dim();
215
216 if dimensions == 3 {
217 let shape = tensor.size3().unwrap();
218 if planar {
219 let (channels, height, width) = shape;
220 if channels != i64::from(expected_channels) {
221 Err(format!("expected shape ({}, height, width), found {:?}", expected_channels, shape))
222 } else {
223 Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
224 }
225 } else {
226 let (height, width, channels) = shape;
227 if channels != i64::from(expected_channels) {
228 Err(format!("expected shape (height, width, {}), found {:?}", expected_channels, shape))
229 } else {
230 Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
231 }
232 }
233 } else if dimensions == 2 && expected_channels == 1 {
234 let (height, width) = tensor.size2().unwrap();
235 Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
236 } else {
237 Err(format!(
238 "wrong number of dimensions ({}) for format ({:?})",
239 dimensions, pixel_format
240 ))
241 }
242}
243
244fn guess_tensor_info(tensor: &tch::Tensor, color_format: ColorFormat) -> Result<(bool, ImageInfo), String> {
246 let dimensions = tensor.dim();
247
248 if dimensions == 2 {
249 let (height, width) = tensor.size2().unwrap();
250 Ok((false, ImageInfo::mono8(width as u32, height as u32)))
251 } else if dimensions == 3 {
252 let shape = tensor.size3().unwrap();
253 match (shape.0 as u32, shape.1 as u32, shape.2 as u32, color_format) {
254 (h, w, 1, _) => Ok((false, ImageInfo::mono8(w, h))),
255 (1, h, w, _) => Ok((false, ImageInfo::mono8(w, h))), (h, w, 3, ColorFormat::Rgb) => Ok((false, ImageInfo::rgb8(w, h))),
257 (h, w, 3, ColorFormat::Bgr) => Ok((false, ImageInfo::bgr8(w, h))),
258 (3, h, w, ColorFormat::Rgb) => Ok((true, ImageInfo::rgb8(w, h))),
259 (3, h, w, ColorFormat::Bgr) => Ok((true, ImageInfo::bgr8(w, h))),
260 (h, w, 4, ColorFormat::Rgb) => Ok((false, ImageInfo::rgba8(w, h))),
261 (h, w, 4, ColorFormat::Bgr) => Ok((false, ImageInfo::bgra8(w, h))),
262 (4, h, w, ColorFormat::Rgb) => Ok((true, ImageInfo::rgba8(w, h))),
263 (4, h, w, ColorFormat::Bgr) => Ok((true, ImageInfo::bgra8(w, h))),
264 _ => Err(format!("unable to guess pixel format for tensor with shape {:?}, expected (height, width) or (height, width, channels) or (channels, height, width) where channels is either 1, 3 or 4", shape))
265 }
266 } else {
267 Err(format!(
268 "unable to guess pixel format for tensor with {} dimensions, expected 2 or 3 dimensions",
269 dimensions
270 ))
271 }
272}
273
#[cfg(test)]
mod test {
	use super::*;
	use assert2::assert;

	#[test]
	fn guess_tensor_info() {
		let data = tch::Tensor::from_slice(&(0..120).collect::<Vec<u8>>());

		// Mono: interlaced, planar and 2D shapes.
		assert!(data.reshape([12, 10, 1]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([1, 12, 10]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([12, 10]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));

		// Interlaced color.
		assert!(data.reshape([8, 5, 3]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 8)));
		assert!(data.reshape([8, 5, 3]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 8)));
		assert!(data.reshape([5, 6, 4]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgba8(6, 5)));
		assert!(data.reshape([5, 6, 4]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgra8(6, 5)));

		// Planar color.
		assert!(data.reshape([3, 8, 5]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 8)));
		assert!(data.reshape([3, 8, 5]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 8)));
		assert!(data.reshape([4, 5, 6]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgba8(6, 5)));
		assert!(data.reshape([4, 5, 6]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgra8(6, 5)));

		// Shapes that cannot be guessed.
		assert!(let Err(_) = data.reshape([120]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([2, 10, 6]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([6, 10, 2]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([8, 5, 3, 1]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 6, 1]).as_image_guess_rgb().map(|x| x.info));
	}

	#[test]
	fn tensor_info_interlaced_with_known_format() {
		let data = tch::Tensor::from_slice(&(0..60).collect::<Vec<u8>>());

		assert!(data.reshape([12, 5, 1]).as_mono8().map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(data.reshape([12, 5]).as_mono8().map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(let Err(_) = data.reshape([12, 5, 1, 1]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([6, 5, 2]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([60]).as_mono8().map(|x| x.info));

		assert!(data.reshape([4, 5, 3]).as_interlaced_rgb8().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 4)));
		assert!(data.reshape([4, 5, 3]).as_interlaced_bgr8().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 4)));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_interlaced_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_interlaced_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_bgr8().map(|x| x.info));

		assert!(data.reshape([3, 5, 4]).as_interlaced_rgba8().map(|x| x.info) == Ok(ImageInfo::rgba8(5, 3)));
		assert!(data.reshape([3, 5, 4]).as_interlaced_bgra8().map(|x| x.info) == Ok(ImageInfo::bgra8(5, 3)));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_interlaced_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_interlaced_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_bgra8().map(|x| x.info));
	}

	#[test]
	fn tensor_info_planar_with_known_format() {
		let data = tch::Tensor::from_slice(&(0..60).collect::<Vec<u8>>());

		// Single-channel planar: (1, height, width) and plain (height, width).
		assert!(data.reshape([1, 12, 5]).as_planar(PixelFormat::Mono8).map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(data.reshape([12, 5]).as_planar(PixelFormat::Mono8).map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));

		assert!(data.reshape([3, 4, 5]).as_planar_rgb8().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 4)));
		assert!(data.reshape([3, 4, 5]).as_planar_bgr8().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 4)));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_planar_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_planar_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_bgr8().map(|x| x.info));

		assert!(data.reshape([4, 3, 5]).as_planar_rgba8().map(|x| x.info) == Ok(ImageInfo::rgba8(5, 3)));
		assert!(data.reshape([4, 3, 5]).as_planar_bgra8().map(|x| x.info) == Ok(ImageInfo::bgra8(5, 3)));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_planar_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_planar_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_bgra8().map(|x| x.info));
	}
}