use crate::error::ImageDataError;
use crate::Alpha;
use crate::BoxImage;
use crate::Image;
use crate::ImageInfo;
use crate::PixelFormat;
/// A borrowed view of a [`tch::Tensor`] together with the image layout it was validated against.
///
/// Values are produced by the [`TensorAsImage`] methods and turned into an owned [`Image`]
/// through `From`.
pub struct TensorImage<'a> {
	/// Pixel format, width and height that the tensor was interpreted as.
	info: ImageInfo,

	/// `true` when the tensor was interpreted as channels-first (planar) data, so its channel
	/// axis has to be moved last before the bytes can be used as interlaced pixels.
	/// A 2-D single-channel tensor requested via `TensorPixelFormat::Planar` is also flagged `true`.
	planar: bool,

	/// The tensor that owns the pixel data.
	tensor: &'a tch::Tensor,
}
/// Describes how the axes of a tensor map onto image pixels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorPixelFormat {
	/// Channels-first data with shape (channels, height, width) in the given pixel format.
	/// Single-channel formats also accept a 2-D (height, width) tensor.
	Planar(PixelFormat),

	/// Channels-last data with shape (height, width, channels) in the given pixel format.
	/// Single-channel formats also accept a 2-D (height, width) tensor.
	Interlaced(PixelFormat),

	/// Infer layout and channel count (1, 3 or 4) from the tensor shape.
	/// The color format picks RGB or BGR ordering for 3- and 4-channel data.
	Guess(ColorFormat),
}
/// Channel ordering used when the pixel format of a tensor is guessed from its shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColorFormat {
	/// Red, green, blue (and alpha for 4-channel data).
	Rgb,

	/// Blue, green, red (and alpha for 4-channel data).
	Bgr,
}
/// Interpret a tensor as an image without copying it.
///
/// Only [`as_image`](TensorAsImage::as_image) has to be implemented. Every other method is a
/// shorthand that forwards to a more general method with a fixed pixel format.
pub trait TensorAsImage {
	/// Interpret the tensor as an image described by `pixel_format`.
	///
	/// # Errors
	/// Implementations return an error when the tensor cannot be interpreted that way.
	fn as_image(&self, pixel_format: TensorPixelFormat) -> Result<TensorImage<'_>, ImageDataError>;

	/// Interpret the tensor as channels-last data, shape (height, width, channels).
	fn as_interlaced(&self, pixel_format: PixelFormat) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_image(TensorPixelFormat::Interlaced(pixel_format))
	}

	/// Interpret the tensor as channels-first data, shape (channels, height, width).
	fn as_planar(&self, pixel_format: PixelFormat) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_image(TensorPixelFormat::Planar(pixel_format))
	}

	/// Guess the layout and channel count from the shape, using `color_format` for channel order.
	fn as_image_guess(&self, color_format: ColorFormat) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_image(TensorPixelFormat::Guess(color_format))
	}

	/// Guess the layout from the shape, assuming RGB(A) channel order.
	fn as_image_guess_rgb(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_image_guess(ColorFormat::Rgb)
	}

	/// Guess the layout from the shape, assuming BGR(A) channel order.
	fn as_image_guess_bgr(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_image_guess(ColorFormat::Bgr)
	}

	/// Interpret the tensor as interlaced 8-bit mono data.
	fn as_mono8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_interlaced(PixelFormat::Mono8)
	}

	/// Interpret the tensor as interlaced 8-bit RGB data.
	fn as_interlaced_rgb8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_interlaced(PixelFormat::Rgb8)
	}

	/// Interpret the tensor as interlaced 8-bit RGBA data with unpremultiplied alpha.
	fn as_interlaced_rgba8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_interlaced(PixelFormat::Rgba8(Alpha::Unpremultiplied))
	}

	/// Interpret the tensor as interlaced 8-bit BGR data.
	fn as_interlaced_bgr8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_interlaced(PixelFormat::Bgr8)
	}

	/// Interpret the tensor as interlaced 8-bit BGRA data with unpremultiplied alpha.
	fn as_interlaced_bgra8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_interlaced(PixelFormat::Bgra8(Alpha::Unpremultiplied))
	}

	/// Interpret the tensor as planar 8-bit RGB data.
	fn as_planar_rgb8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_planar(PixelFormat::Rgb8)
	}

	/// Interpret the tensor as planar 8-bit RGBA data with unpremultiplied alpha.
	fn as_planar_rgba8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_planar(PixelFormat::Rgba8(Alpha::Unpremultiplied))
	}

	/// Interpret the tensor as planar 8-bit BGR data.
	fn as_planar_bgr8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_planar(PixelFormat::Bgr8)
	}

	/// Interpret the tensor as planar 8-bit BGRA data with unpremultiplied alpha.
	fn as_planar_bgra8(&self) -> Result<TensorImage<'_>, ImageDataError> {
		self.as_planar(PixelFormat::Bgra8(Alpha::Unpremultiplied))
	}
}
impl TensorAsImage for tch::Tensor {
	/// Validate the tensor against `pixel_format` and wrap it in a [`TensorImage`].
	///
	/// # Errors
	/// Returns an error when the tensor elements are wider than one byte, or when the tensor
	/// shape does not match the requested (or any guessable) pixel format.
	fn as_image(&self, pixel_format: TensorPixelFormat) -> Result<TensorImage, ImageDataError> {
		// The pixel data is later copied out as raw bytes (numel * element size) and paired with
		// an 8-bit-per-channel ImageInfo (Mono8, Rgb8, Rgba8, Bgr8, Bgra8). A tensor with wider
		// elements (e.g. Float) would produce a buffer several times larger than the image it
		// claims to describe, so reject it up front.
		let kind = self.kind();
		if kind.elt_size_in_bytes() != 1 {
			return Err(format!(
				"unsupported tensor element type {:?}, expected a one-byte element type such as Uint8",
				kind
			)
			.into());
		}

		let (planar, info) = match pixel_format {
			TensorPixelFormat::Planar(pixel_format) => (true, tensor_info(self, pixel_format, true)?),
			TensorPixelFormat::Interlaced(pixel_format) => (false, tensor_info(self, pixel_format, false)?),
			TensorPixelFormat::Guess(color_format) => guess_tensor_info(self, color_format)?,
		};

		Ok(TensorImage {
			tensor: self,
			info,
			planar,
		})
	}
}
/// Copy the raw element bytes of `tensor` into a freshly allocated vector.
///
/// The buffer is sized for every element of the tensor at its element type's width, and
/// `copy_data_u8` fills it with `numel` elements.
fn tensor_to_byte_vec(tensor: &tch::Tensor) -> Vec<u8> {
	let element_count = tensor.numel();
	let byte_len = element_count * tensor.kind().elt_size_in_bytes();
	let mut bytes = vec![0u8; byte_len];
	tensor.copy_data_u8(&mut bytes, element_count);
	bytes
}
impl<'a> From<TensorImage<'a>> for Image {
	/// Copy the tensor's pixel data into an owned, interlaced [`Image`].
	fn from(other: TensorImage<'a>) -> Self {
		// Planar tensors are (channels, height, width): move the channel axis last so the copied
		// bytes are interlaced. `tensor_info` also accepts a 2-D (height, width) tensor for
		// single-channel formats in planar mode. Such a tensor has no channel axis, and a
		// three-axis permute of it fails (tch panics), so it is copied unchanged.
		let data = if other.planar && other.tensor.dim() == 3 {
			tensor_to_byte_vec(&other.tensor.permute([1, 2, 0]))
		} else {
			tensor_to_byte_vec(other.tensor)
		};
		BoxImage::new(other.info, data.into_boxed_slice()).into()
	}
}
impl<'a> From<Result<TensorImage<'a>, ImageDataError>> for Image {
	/// Convert a successful tensor view into an [`Image`], or carry the error along as
	/// `Image::Invalid` so callers can pass the result of `as_image*` straight through.
	fn from(other: Result<TensorImage<'a>, ImageDataError>) -> Self {
		other.map_or_else(Image::Invalid, Image::from)
	}
}
/// Validate the shape of `tensor` against a known `pixel_format` and build its [`ImageInfo`].
///
/// With `planar` set, a 3-D tensor must be (channels, height, width); otherwise it must be
/// (height, width, channels), where `channels == pixel_format.channels()`. Single-channel
/// formats additionally accept a 2-D (height, width) tensor in either mode.
///
/// # Errors
/// Returns a human-readable message when the channel count or number of dimensions does not fit.
fn tensor_info(tensor: &tch::Tensor, pixel_format: PixelFormat, planar: bool) -> Result<ImageInfo, String> {
	let expected_channels = pixel_format.channels();

	match tensor.dim() {
		3 => {
			let shape = tensor.size3().unwrap();
			// Normalise both layouts to (channels, height, width) so one check serves both.
			let (channels, height, width) = if planar { shape } else { (shape.2, shape.0, shape.1) };

			if channels == i64::from(expected_channels) {
				// NOTE(review): `as u32` silently truncates dimensions larger than u32::MAX.
				Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
			} else if planar {
				Err(format!("expected shape ({}, height, width), found {:?}", expected_channels, shape))
			} else {
				Err(format!("expected shape (height, width, {}), found {:?}", expected_channels, shape))
			}
		},
		2 if expected_channels == 1 => {
			let (height, width) = tensor.size2().unwrap();
			Ok(ImageInfo::new(pixel_format, width as u32, height as u32))
		},
		dimensions => Err(format!(
			"wrong number of dimensions ({}) for format ({:?})",
			dimensions, pixel_format
		)),
	}
}
/// Guess the layout and pixel format of `tensor` from its shape alone.
///
/// Returns `(planar, info)`, where `planar` says whether the channel axis must be moved last
/// before the data is interlaced. A 2-D tensor is mono. For 3-D tensors the channel count
/// (1, 3 or 4) is looked for in the last axis first and then the first axis, one count at a
/// time, so ambiguous shapes resolve in that order. `color_format` selects RGB(A) or BGR(A)
/// for 3- and 4-channel data.
///
/// # Errors
/// Returns a human-readable message for other dimension counts or unrecognised shapes.
fn guess_tensor_info(tensor: &tch::Tensor, color_format: ColorFormat) -> Result<(bool, ImageInfo), String> {
	match tensor.dim() {
		2 => {
			let (height, width) = tensor.size2().unwrap();
			Ok((false, ImageInfo::mono8(width as u32, height as u32)))
		},
		3 => {
			let shape = tensor.size3().unwrap();
			let (d0, d1, d2) = (shape.0 as u32, shape.1 as u32, shape.2 as u32);

			// Resolve (planar, channels, width, height) first. The arm order is significant:
			// it decides which interpretation wins for ambiguous shapes such as (3, h, 4).
			let (planar, channels, width, height) = match (d0, d1, d2) {
				(h, w, 1) => (false, 1, w, h),
				// One leading channel has the same memory layout as interlaced mono.
				(1, h, w) => (false, 1, w, h),
				(h, w, 3) => (false, 3, w, h),
				(3, h, w) => (true, 3, w, h),
				(h, w, 4) => (false, 4, w, h),
				(4, h, w) => (true, 4, w, h),
				_ => return Err(format!("unable to guess pixel format for tensor with shape {:?}, expected (height, width) or (height, width, channels) or (channels, height, width) where channels is either 1, 3 or 4", shape)),
			};

			let info = match (channels, color_format) {
				(1, _) => ImageInfo::mono8(width, height),
				(3, ColorFormat::Rgb) => ImageInfo::rgb8(width, height),
				(3, ColorFormat::Bgr) => ImageInfo::bgr8(width, height),
				// Only four channels remain here.
				(_, ColorFormat::Rgb) => ImageInfo::rgba8(width, height),
				(_, ColorFormat::Bgr) => ImageInfo::bgra8(width, height),
			};
			Ok((planar, info))
		},
		dimensions => Err(format!(
			"unable to guess pixel format for tensor with {} dimensions, expected 2 or 3 dimensions",
			dimensions
		)),
	}
}
/// Shape-validation tests for the tensor-to-image conversions.
#[cfg(test)]
mod test {
	use super::*;
	use assert2::assert;

	#[test]
	fn guess_tensor_info() {
		let data = tch::Tensor::from_slice(&(0..120).collect::<Vec<u8>>());
		assert!(data.reshape([12, 10, 1]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([1, 12, 10]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([12, 10]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::mono8(10, 12)));
		assert!(data.reshape([8, 5, 3]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 8)));
		assert!(data.reshape([8, 5, 3]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 8)));
		assert!(data.reshape([5, 6, 4]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgba8(6, 5)));
		assert!(data.reshape([5, 6, 4]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgra8(6, 5)));
		assert!(data.reshape([3, 8, 5]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 8)));
		assert!(data.reshape([3, 8, 5]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 8)));
		assert!(data.reshape([4, 5, 6]).as_image_guess_rgb().map(|x| x.info) == Ok(ImageInfo::rgba8(6, 5)));
		assert!(data.reshape([4, 5, 6]).as_image_guess_bgr().map(|x| x.info) == Ok(ImageInfo::bgra8(6, 5)));
		assert!(let Err(_) = data.reshape([120]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([2, 10, 6]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([6, 10, 2]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([8, 5, 3, 1]).as_image_guess_rgb().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 6, 1]).as_image_guess_rgb().map(|x| x.info));
	}

	#[test]
	fn tensor_info_interlaced_with_known_format() {
		let data = tch::Tensor::from_slice(&(0..60).collect::<Vec<u8>>());
		assert!(data.reshape([12, 5, 1]).as_mono8().map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(data.reshape([12, 5]).as_mono8().map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(let Err(_) = data.reshape([12, 5, 1, 1]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([6, 5, 2]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_mono8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([60]).as_mono8().map(|x| x.info));

		assert!(data.reshape([4, 5, 3]).as_interlaced_rgb8().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 4)));
		assert!(data.reshape([4, 5, 3]).as_interlaced_bgr8().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 4)));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_interlaced_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_interlaced_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_bgr8().map(|x| x.info));

		assert!(data.reshape([3, 5, 4]).as_interlaced_rgba8().map(|x| x.info) == Ok(ImageInfo::rgba8(5, 3)));
		assert!(data.reshape([3, 5, 4]).as_interlaced_bgra8().map(|x| x.info) == Ok(ImageInfo::bgra8(5, 3)));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_interlaced_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_interlaced_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_interlaced_bgra8().map(|x| x.info));
	}

	#[test]
	fn tensor_info_planar_with_known_format() {
		let data = tch::Tensor::from_slice(&(0..60).collect::<Vec<u8>>());
		assert!(data.reshape([1, 12, 5]).as_planar(PixelFormat::Mono8).map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(data.reshape([12, 5]).as_planar(PixelFormat::Mono8).map(|x| x.info) == Ok(ImageInfo::mono8(5, 12)));
		assert!(let Err(_) = data.reshape([12, 5, 1]).as_planar(PixelFormat::Mono8).map(|x| x.info));

		assert!(data.reshape([3, 4, 5]).as_planar_rgb8().map(|x| x.info) == Ok(ImageInfo::rgb8(5, 4)));
		assert!(data.reshape([3, 4, 5]).as_planar_bgr8().map(|x| x.info) == Ok(ImageInfo::bgr8(5, 4)));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3, 1]).as_planar_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([4, 5, 3]).as_planar_bgr8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_rgb8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_bgr8().map(|x| x.info));

		assert!(data.reshape([4, 3, 5]).as_planar_rgba8().map(|x| x.info) == Ok(ImageInfo::rgba8(5, 3)));
		assert!(data.reshape([4, 3, 5]).as_planar_bgra8().map(|x| x.info) == Ok(ImageInfo::bgra8(5, 3)));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4, 1]).as_planar_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([3, 5, 4]).as_planar_bgra8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_rgba8().map(|x| x.info));
		assert!(let Err(_) = data.reshape([15, 4]).as_planar_bgra8().map(|x| x.info));
	}
}