show_image/features/
tch.rs

1//! Support for the [`tch`][::tch] crate.
2//!
3//! This module adds support for displaying [`tch::Tensor`] as images.
4//! The main interface is provided by an extension trait [`TensorAsImage`],
5//! which allows you to wrap a tensor in a [`TensorImage`].
6//! The wrapper struct adds some required meta-data for interpreting the tensor data as an image.
7//!
8//! The meta-data has to be supplied by the user, or it can be guessed automatically based on the tensor shape.
9//! When guessing, you do need to specify if you want to interpret multi-channel tensors as RGB or BGR.
10//! An extension trait [`TensorAsImage`] is provided to construct the wrapper with the proper meta-data.
11//!
//! It is not always possible to interpret a tensor as the requested image format,
//! so all functions in the extension trait return a [`Result`].
//! The [`Into<Image>`] trait is implemented for [`TensorImage`] and for [`Result`]`<`[`TensorImage`]`, `[`ImageDataError`]`>`,
//! so you can pass the result directly to `set_image()` to set the image of a window.
16//!
17//! Both planar and interlaced tensors are supported.
18//! If you specify the format manually, you must also specify if the tensor contains interlaced or planar data.
19//! If you let the library guess, it will try to deduce it automatically based on the tensor shape.
20//!
21//! # Example
22//! ```no_run
23//! use show_image::{create_window, WindowOptions};
24//! use show_image::tch::TensorAsImage;
25//!
26//! let tensor = tch::vision::imagenet::load_image("/path/to/image.png").unwrap();
27//! let window = create_window("image", WindowOptions::default())?;
28//! window.set_image("image-001", tensor.as_image_guess_rgb())?;
29//! # Result::<(), Box<dyn std::error::Error>>::Ok(())
30//! ```
31
32use crate::error::ImageDataError;
33use crate::Alpha;
34use crate::BoxImage;
35use crate::Image;
36use crate::ImageInfo;
37use crate::PixelFormat;
38
39/// Wrapper for [`tch::Tensor`] that implements `Into<Image>`.
40pub struct TensorImage<'a> {
41	tensor: &'a tch::Tensor,
42	info: ImageInfo,
43	planar: bool,
44}
45
46/// The pixel format of a tensor, or a color format to guess the pixel format.
47#[derive(Copy, Clone, Debug, Eq, PartialEq)]
48pub enum TensorPixelFormat {
49	/// The tensor has planar pixel data.
50	Planar(PixelFormat),
51
52	/// The tensor has interlaced pixel data.
53	Interlaced(PixelFormat),
54
55	/// The library should guess if the pixel data is planar or interlaced.
56	Guess(ColorFormat),
57}
58
/// Channel order to assume when the pixel format of a tensor is guessed from its shape.
///
/// Only matters for tensors with 3 or 4 channels; single-channel tensors are always monochrome.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ColorFormat {
	/// Treat 3 channels as RGB and 4 channels as RGBA.
	Rgb,

	/// Treat 3 channels as BGR and 4 channels as BGRA.
	Bgr,
}
68
69/// Extension trait to allow displaying tensors as image.
70///
71/// The tensor data will always be copied.
72/// Additionally, the data will be converted to 8 bit integers,
73/// and planar data will be converted to interlaced data.
74///
75/// The original tensor is unaffected, but the conversion can be expensive.
76/// If you also need to convert the tensor, consider doing so before displaying it.
77#[allow(clippy::needless_lifetimes)]
78pub trait TensorAsImage {
79	/// Wrap the tensor in a [`TensorImage`] that implements `Into<Image>`.
80	///
81	/// This function requires you to specify the pixel format of the tensor,
82	/// or a preferred color format to have the library guess based on the tensor shape.
83	///
84	/// See the other functions in the trait for easier shorthands.
85	fn as_image<'a>(&'a self, pixel_format: TensorPixelFormat) -> Result<TensorImage<'a>, ImageDataError>;
86
87	/// Wrap the tensor with a known pixel format in a [`TensorImage`], assuming it holds interlaced pixel data.
88	fn as_interlaced<'a>(&'a self, pixel_format: PixelFormat) -> Result<TensorImage<'a>, ImageDataError> {
89		self.as_image(TensorPixelFormat::Interlaced(pixel_format))
90	}
91
92	/// Wrap the tensor with a known pixel format in a [`TensorImage`], assuming it holds planar pixel data.
93	fn as_planar<'a>(&'a self, pixel_format: PixelFormat) -> Result<TensorImage<'a>, ImageDataError> {
94		self.as_image(TensorPixelFormat::Planar(pixel_format))
95	}
96
97	/// Wrap the tensor in a [`TensorImage`].
98	///
99	/// The pixel format of the tensor will be guessed based on the shape.
100	/// The `color_format` argument determines if tensors with 3 or 4 channels are interpreted as RGB or BGR.
101	fn as_image_guess<'a>(&'a self, color_format: ColorFormat) -> Result<TensorImage<'a>, ImageDataError> {
102		self.as_image(TensorPixelFormat::Guess(color_format))
103	}
104
105	/// Wrap the tensor in a [`TensorImage`].
106	///
107	/// The pixel format of the tensor will be guessed based on the shape.
108	/// Tensors with 3 or 4 channels will be interpreted as RGB.
109	fn as_image_guess_rgb<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
110		self.as_image_guess(ColorFormat::Rgb)
111	}
112
113	/// Wrap the tensor in a [`TensorImage`].
114	///
115	/// The pixel format of the tensor will be guessed based on the shape.
116	/// Tensors with 3 or 4 channels will be interpreted as BGR.
117	fn as_image_guess_bgr<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
118		self.as_image_guess(ColorFormat::Bgr)
119	}
120
121	/// Wrap the tensor in a [`TensorImage`], assuming it holds monochrome data.
122	fn as_mono8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
123		self.as_interlaced(PixelFormat::Mono8)
124	}
125
126	/// Wrap the tensor in a [`TensorImage`], assuming it holds interlaced RGB data.
127	fn as_interlaced_rgb8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
128		self.as_interlaced(PixelFormat::Rgb8)
129	}
130
131	/// Wrap the tensor in a [`TensorImage`], assuming it holds interlaced RGBA data.
132	fn as_interlaced_rgba8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
133		self.as_interlaced(PixelFormat::Rgba8(Alpha::Unpremultiplied))
134	}
135
136	/// Wrap the tensor in a [`TensorImage`], assuming it holds interlaced BGR data.
137	fn as_interlaced_bgr8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
138		self.as_interlaced(PixelFormat::Bgr8)
139	}
140
141	/// Wrap the tensor in a [`TensorImage`], assuming it holds interlaced BGRA data.
142	fn as_interlaced_bgra8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
143		self.as_interlaced(PixelFormat::Bgra8(Alpha::Unpremultiplied))
144	}
145
146	/// Wrap the tensor in a [`TensorImage`], assuming it holds planar RGB data.
147	fn as_planar_rgb8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
148		self.as_planar(PixelFormat::Rgb8)
149	}
150
151	/// Wrap the tensor in a [`TensorImage`], assuming it holds planar RGBA data.
152	fn as_planar_rgba8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
153		self.as_planar(PixelFormat::Rgba8(Alpha::Unpremultiplied))
154	}
155
156	/// Wrap the tensor in a [`TensorImage`], assuming it holds planar BGR data.
157	fn as_planar_bgr8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
158		self.as_planar(PixelFormat::Bgr8)
159	}
160
161	/// Wrap the tensor in a [`TensorImage`], assuming it holds planar BGRA data.
162	fn as_planar_bgra8<'a>(&'a self) -> Result<TensorImage<'a>, ImageDataError> {
163		self.as_planar(PixelFormat::Bgra8(Alpha::Unpremultiplied))
164	}
165}
166
167impl TensorAsImage for tch::Tensor {
168	fn as_image(&self, pixel_format: TensorPixelFormat) -> Result<TensorImage, ImageDataError> {
169		let (planar, info) = match pixel_format {
170			TensorPixelFormat::Planar(pixel_format) => (true, tensor_info(self, pixel_format, true)?),
171			TensorPixelFormat::Interlaced(pixel_format) => (false, tensor_info(self, pixel_format, false)?),
172			TensorPixelFormat::Guess(color_format) => guess_tensor_info(self, color_format)?,
173		};
174		Ok(TensorImage {
175			tensor: self,
176			info,
177			planar,
178		})
179	}
180}
181
182fn tensor_to_byte_vec(tensor: &tch::Tensor) -> Vec<u8> {
183	let size = tensor.numel() * tensor.kind().elt_size_in_bytes();
184	let mut data = vec![0u8; size];
185	tensor.copy_data_u8(&mut data, tensor.numel());
186	data
187}
188
189impl<'a> From<TensorImage<'a>> for Image {
190	fn from(other: TensorImage<'a>) -> Self {
191		let data = if other.planar {
192			tensor_to_byte_vec(&other.tensor.permute([1, 2, 0]))
193		} else {
194			tensor_to_byte_vec(other.tensor)
195		};
196
197		BoxImage::new(other.info, data.into_boxed_slice()).into()
198	}
199}
200
201impl<'a> From<Result<TensorImage<'a>, ImageDataError>> for Image {
202	fn from(other: Result<TensorImage<'a>, ImageDataError>) -> Self {
203		match other {
204			Ok(x) => x.into(),
205			Err(e) => Image::Invalid(e),
206		}
207	}
208}
209
210/// Compute the image info of a tensor, given a known pixel format.
211#[allow(clippy::branches_sharing_code)] // Stop lying, clippy.
212fn tensor_info(tensor: &tch::Tensor, pixel_format: PixelFormat, planar: bool) -> Result<ImageInfo, String> {
213	let expected_channels = pixel_format.channels();
214	let dimensions = tensor.dim();
215
216	if dimensions == 3 {
217		let shape = tensor.size3().unwrap();
218		if planar {
219			let (channels, height, width) = shape;
220			if channels != i64::from(expected_channels) {
221				Err(format!("expected shape ({}, height, width), found {:?}", expected_channels, shape))
222			} else {
223				Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
224			}
225		} else {
226			let (height, width, channels) = shape;
227			if channels != i64::from(expected_channels) {
228				Err(format!("expected shape (height, width, {}), found {:?}", expected_channels, shape))
229			} else {
230				Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
231			}
232		}
233	} else if dimensions == 2 && expected_channels == 1 {
234		let (height, width) = tensor.size2().unwrap();
235		Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
236	} else {
237		Err(format!(
238			"wrong number of dimensions ({}) for format ({:?})",
239			dimensions, pixel_format
240		))
241	}
242}
243
244/// Guess the image info of a tensor.
245fn guess_tensor_info(tensor: &tch::Tensor, color_format: ColorFormat) -> Result<(bool, ImageInfo), String> {
246	let dimensions = tensor.dim();
247
248	if dimensions == 2 {
249		let (height, width) = tensor.size2().unwrap();
250		Ok((false, ImageInfo::mono8(width as u32, height as u32)))
251	} else if dimensions == 3 {
252		let shape = tensor.size3().unwrap();
253		match (shape.0 as u32, shape.1 as u32, shape.2 as u32, color_format) {
254			(h, w, 1, _) => Ok((false, ImageInfo::mono8(w, h))),
255			(1, h, w, _) => Ok((false, ImageInfo::mono8(w, h))), // "planar" doesn't do anything here, so call it interlaced
256			(h, w, 3, ColorFormat::Rgb) => Ok((false, ImageInfo::rgb8(w, h))),
257			(h, w, 3, ColorFormat::Bgr) => Ok((false, ImageInfo::bgr8(w, h))),
258			(3, h, w, ColorFormat::Rgb) => Ok((true, ImageInfo::rgb8(w, h))),
259			(3, h, w, ColorFormat::Bgr) => Ok((true, ImageInfo::bgr8(w, h))),
260			(h, w, 4, ColorFormat::Rgb) => Ok((false, ImageInfo::rgba8(w, h))),
261			(h, w, 4, ColorFormat::Bgr) => Ok((false, ImageInfo::bgra8(w, h))),
262			(4, h, w, ColorFormat::Rgb) => Ok((true, ImageInfo::rgba8(w, h))),
263			(4, h, w, ColorFormat::Bgr) => Ok((true, ImageInfo::bgra8(w, h))),
264			_ => Err(format!("unable to guess pixel format for tensor with shape {:?}, expected (height, width) or (height, width, channels) or (channels, height, width) where channels is either 1, 3 or 4", shape))
265		}
266	} else {
267		Err(format!(
268			"unable to guess pixel format for tensor with {} dimensions, expected 2 or 3 dimensions",
269			dimensions
270		))
271	}
272}
273
#[cfg(test)]
mod test {
	use super::*;
	use assert2::assert;

	#[test]
	fn guess_tensor_info() {
		let data = tch::Tensor::from_slice(&(0..120).collect::<Vec<u8>>());

		// Guess monochrome from compatible data.
		assert!(data.reshape([12, 10, 1]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([1, 12, 10]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([12, 10]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));

		// Guess RGB[A]/BGR[A] from interlaced data.
		assert!(data.reshape([8, 5, 3]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 8)));
		assert!(data.reshape([8, 5, 3]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 8)));
		assert!(data.reshape([5, 6, 4]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgba8(6, 5)));
		assert!(data.reshape([5, 6, 4]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgra8(6, 5)));

		// Guess RGB[A]/BGR[A] from planar data.
		assert!(data.reshape([3, 8, 5]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 8)));
		assert!(data.reshape([3, 8, 5]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 8)));
		assert!(data.reshape([4, 5, 6]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgba8(6, 5)));
		assert!(data.reshape([4, 5, 6]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgra8(6, 5)));

		// Fail to guess on other dimensions
		assert!(let Err(_) = data.reshape([120]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([2, 10, 6]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([6, 10, 2]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([8, 5, 3, 1]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 6, 1]).as_image_guess_rgb().map(|x| x.info));
	}

	#[test]
	fn tensor_info_interlaced_with_known_format() {
		let data = tch::Tensor::from_slice(&(0..60).collect::<Vec<u8>>());

		// Monochrome
		assert!(data.reshape([12, 5, 1]).as_mono8().map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(data.reshape([12, 5]).as_mono8().map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(let Err(_) = data.reshape([12, 5, 1, 1]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([6, 5, 2]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([60]).as_mono8().map(|x| x.info));

		// RGB/BGR: every rejected shape is checked for both channel orders.
		assert!(data.reshape([4, 5, 3]).as_interlaced_rgb8().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 4)));
		assert!(data.reshape([4, 5, 3]).as_interlaced_bgr8().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 4)));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_interlaced_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_interlaced_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_bgr8().map(|x| x.info));

		// RGBA/BGRA
		assert!(data.reshape([3, 5, 4]).as_interlaced_rgba8().map(|x| x.info) == Ok(ImageInfo::rgba8(5, 3)));
		assert!(data.reshape([3, 5, 4]).as_interlaced_bgra8().map(|x| x.info) == Ok(ImageInfo::bgra8(5, 3)));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_interlaced_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_interlaced_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_bgra8().map(|x| x.info));
	}

	#[test]
	fn tensor_info_planar_with_known_format() {
		let data = tch::Tensor::from_slice(&(0..60).collect::<Vec<u8>>());

		// Monochrome: a 2D tensor is accepted for single-channel formats even when planar is requested.
		assert!(data.reshape([1, 12, 5]).as_planar(PixelFormat::Mono8).map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(data.reshape([12, 5]).as_planar(PixelFormat::Mono8).map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(let Err(_) = data.reshape([12, 5, 1]).as_planar(PixelFormat::Mono8).map(|x| x.info));

		// RGB/BGR: every rejected shape is checked for both channel orders.
		assert!(data.reshape([3, 4, 5]).as_planar_rgb8().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 4)));
		assert!(data.reshape([3, 4, 5]).as_planar_bgr8().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 4)));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_planar_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_planar_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_bgr8().map(|x| x.info));

		// RGBA/BGRA
		assert!(data.reshape([4, 3, 5]).as_planar_rgba8().map(|x| x.info) == Ok(ImageInfo::rgba8(5, 3)));
		assert!(data.reshape([4, 3, 5]).as_planar_bgra8().map(|x| x.info) == Ok(ImageInfo::bgra8(5, 3)));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_planar_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_planar_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_bgra8().map(|x| x.info));
	}
}