use crate::internal::*;
use ndarray::prelude::*;
use std::alloc;
use std::fmt;
use std::mem::size_of;
use tract_linalg::align;
use tract_linalg::f16::f16;
#[cfg(feature = "serialize")]
use serde::ser::{Serialize, Serializer};
use std::sync::Arc;
use crate::datum::TryInto;
/// Reference-counted handle on a `Tensor`.
///
/// Cloning a `SharedTensor` clones the inner `Arc`, not the tensor data.
#[derive(Clone)]
pub struct SharedTensor(Arc<Tensor>);
impl SharedTensor {
pub fn as_tensor(&self) -> &Tensor {
self.0.as_ref()
}
pub fn to_tensor(self) -> Tensor {
Arc::try_unwrap(self.0).unwrap_or_else(|arc| arc.as_ref().clone())
}
pub fn to_array<'a, D: crate::datum::Datum>(self) -> TractResult<::ndarray::ArrayD<D>> {
self.to_tensor().into_array()
}
}
impl<M> From<M> for SharedTensor
where
Tensor: From<M>,
{
fn from(m: M) -> SharedTensor {
SharedTensor::from(Arc::new(m.into()))
}
}
impl From<Arc<Tensor>> for SharedTensor {
fn from(m: Arc<Tensor>) -> SharedTensor {
SharedTensor(m)
}
}
/// Lets a `SharedTensor` be used wherever a `&Tensor` is expected.
impl ::std::ops::Deref for SharedTensor {
    type Target = Tensor;

    fn deref(&self) -> &Tensor {
        &*self.0
    }
}
impl PartialEq for SharedTensor {
fn eq(&self, other: &SharedTensor) -> bool {
self.as_tensor() == other.as_tensor()
}
}
/// Debug output is delegated to the wrapped `Arc<Tensor>`.
impl fmt::Debug for SharedTensor {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
// NOTE(review): which `fmt` trait method is selected here depends on the traits
// brought into scope by the imports — presumably `Debug`; confirm.
self.0.fmt(fmt)
}
}
/// A typed, shaped buffer stored as raw bytes.
pub struct Tensor {
/// Set by `null_dt`; accessors bail with "Null tensor" when this is true.
null: bool,
/// Datum type of the elements stored in `data`.
dt: DatumType,
/// Dimensions of the tensor.
shape: TVec<usize>,
/// Alignment (in bytes) used when the data buffer is re-aligned, e.g. on clone.
alignment: usize,
/// Element storage, viewed as bytes.
data: Vec<u8>,
}
impl Clone for Tensor {
    /// Copies the tensor, re-aligning the copied bytes on the tensor's own alignment.
    ///
    /// # Panics
    /// Panics unless exactly one bit is set in the low 32 bits of the stored alignment.
    fn clone(&self) -> Tensor {
        let alignment = self.alignment;
        assert!(
            (alignment as u32).count_ones() == 1,
            "Invalid alignment in tensor ({})",
            alignment
        );
        let data = align::realign_slice(&self.data, alignment);
        Tensor { null: self.null, dt: self.dt, shape: self.shape.clone(), alignment, data }
    }
}
impl Default for Tensor {
    /// The default tensor is a rank-0 `f32` array holding `0`.
    fn default() -> Tensor {
        let zero = arr0(0f32);
        Tensor::from(zero)
    }
}
impl Tensor {
/// Allocates a tensor of datum type `T` and the given `shape` whose buffer is
/// aligned on `alignment` bytes and left uninitialized.
///
/// # Safety
/// The returned buffer holds uninitialized memory: the caller must write every
/// element before it is read.
///
/// # Panics
/// Panics if `alignment` is not a power of two. Aborts through
/// `handle_alloc_error` if the allocator cannot provide the memory.
///
/// # Errors
/// Fails if the byte size overflows `usize` or does not form a valid layout.
pub unsafe fn uninitialized_aligned<T: Datum>(
    shape: &[usize],
    alignment: usize,
) -> TractResult<Tensor> {
    assert!(alignment.is_power_of_two(), "Invalid alignment required ({})", alignment);
    // Checked product: a wrapped size would yield a buffer smaller than the shape implies.
    let len = shape
        .iter()
        .try_fold(size_of::<T>(), |bytes, &dim| bytes.checked_mul(dim))
        .ok_or_else(|| format!("Tensor byte size overflows usize for shape {:?}", shape))?;
    let data = if len == 0 {
        vec![]
    } else {
        let layout = alloc::Layout::from_size_align(len, alignment)
            .map_err(|e| format!("Memory layout error: {:?} ({}, {})", e, len, alignment))?;
        let aligned_buffer = alloc::alloc(layout);
        // `alloc` signals failure with a null pointer, which must never reach `Vec`.
        if aligned_buffer.is_null() {
            alloc::handle_alloc_error(layout);
        }
        // NOTE(review): `Vec<u8>` will release this buffer with an alignment of 1,
        // not `alignment` — confirm this is acceptable for the allocator in use.
        Vec::from_raw_parts(aligned_buffer, len, len)
    };
    Ok(Tensor { null: false, dt: T::datum_type(), shape: shape.into(), alignment, data })
}
/// Builds a tensor of datum type `T` and the given `shape` from a copy of the
/// raw bytes in `content`.
///
/// The copy is aligned on the datum type's alignment, which is also the
/// alignment recorded in the tensor, so later re-alignments (e.g. on clone)
/// agree with the initial one.
///
/// # Safety
/// `content` is not validated: the caller must make sure it holds valid `T`
/// values and that its length matches `shape`.
pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
    let dt = T::datum_type();
    let alignment = dt.alignment();
    let data = align::realign_slice(content, alignment);
    Ok(Tensor { null: false, dt, shape: shape.into(), alignment, data })
}
/// Consumes the tensor and returns it with its data re-aligned on `alignment`,
/// which also becomes the tensor's recorded alignment.
pub fn into_aligned(self, alignment: usize) -> TractResult<Tensor> {
    let Tensor { null, dt, shape, data, .. } = self;
    let data = align::realign_vec(data, alignment);
    Ok(Tensor { null, dt, shape, alignment, data })
}
pub unsafe fn null<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
Self::null_dt(T::datum_type(), shape)
}
/// Builds a null tensor: it carries a datum type and a shape but no data.
///
/// # Safety
/// The data buffer is empty whatever `shape` says; the result must only be
/// used through accessors that check `is_null`.
pub unsafe fn null_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
    let alignment = dt.alignment();
    Ok(Tensor { null: true, dt, shape: shape.into(), alignment, data: Vec::new() })
}
/// Returns the tensor's `null` flag (true for tensors built by `null` / `null_dt`).
pub fn is_null(&self) -> bool {
self.null
}
pub fn into_tensor(self) -> SharedTensor {
SharedTensor::from(self)
}
/// Dimensions of the tensor.
pub fn shape(&self) -> &[usize] {
    &self.shape[..]
}
pub fn into_shape(self, shape: &[usize]) -> TractResult<Tensor> {
Ok(Tensor { shape: shape.into(), ..self })
}
/// Datum type of the elements stored in the tensor.
pub fn datum_type(&self) -> DatumType {
self.dt
}
/// Renders the tensor as text, viewing its data as elements of type `D`.
///
/// The output starts with the datum type and shape, followed by the values
/// joined by `", "`. All values are listed when `force_full` is set or when
/// there are at most 12 of them; otherwise only the first 8 are listed,
/// followed by `...`.
///
/// # Errors
/// Fails when the tensor can not be viewed as an array of `D`.
pub fn dump_t<D: Datum>(&self, force_full: bool) -> TractResult<String> {
    use itertools::Itertools;
    let spec = TensorFact::dt_shape(D::datum_type(), &*self.shape);
    let view = self.to_array_view::<D>()?;
    let header = spec.format_dt_shape();
    let show_everything = force_full || view.len() <= 12;
    let rendered = if show_everything {
        format!("{} {}", header, view.iter().join(", "))
    } else {
        format!("{} {}...", header, view.iter().take(8).join(", "))
    };
    Ok(rendered)
}
/// Renders the tensor as text; `force_full` is forwarded to `dump_t`.
///
/// NOTE(review): `dispatch_datum!` presumably picks the `dump_t` instantiation
/// matching `self.dt` — confirm against the macro definition.
pub fn dump(&self, force_full: bool) -> TractResult<String> {
dispatch_datum!(Self::dump_t(self.dt)(self, force_full))
}
/// Compares two tensors value by value after casting both to `f32`.
///
/// With `approx` set, values may differ by a margin derived from the
/// magnitude and spread of `self`'s finite values (never below `1e-5`);
/// otherwise values must be equal. Two NaNs at the same position always
/// match. Tensors of different shapes, or of which only one is null, never
/// match.
///
/// When the values can not be compared as `f32` (null tensors, datum types
/// without a cast to `f32`), the result falls back to exact equality instead
/// of panicking.
pub fn close_enough(&self, other: &Self, approx: bool) -> bool {
    if self.is_null() != other.is_null() {
        return false;
    }
    let (lhs, rhs) = match (self.cast_to::<f32>(), other.cast_to::<f32>()) {
        (Ok(lhs), Ok(rhs)) => (lhs, rhs),
        _ => return self == other,
    };
    let (ma, mb) = match (lhs.to_array_view::<f32>(), rhs.to_array_view::<f32>()) {
        (Ok(ma), Ok(mb)) => (ma, mb),
        _ => return self == other,
    };
    if ma.shape() != mb.shape() {
        return false;
    }
    let count = ma.len() as f32;
    let avg = ma.iter().filter(|a| a.is_finite()).map(|&a| a.abs()).sum::<f32>() / count;
    let dev = (ma.iter().filter(|a| a.is_finite()).map(|&a| (a - avg).powi(2)).sum::<f32>()
        / count)
        .sqrt();
    let margin = if approx { (dev / 5.0).max(avg.abs() / 10_000.0).max(1e-5) } else { 0.0 };
    ma.iter()
        .zip(mb.iter())
        .all(|(&a, &b)| (a.is_nan() && b.is_nan()) || a == b || (b - a).abs() <= margin)
}
pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
if self.datum_type() != D::datum_type() {
bail!(
"Incompatible datum type. Required {:?}, got {:?}",
D::datum_type(),
self.datum_type()
);
}
if self.is_null() {
bail!("Null tensor")
}
let casted = unsafe { vec_to_datum::<D>(self.data) };
unsafe { Ok(ArrayD::from_shape_vec_unchecked(&*self.shape, casted)) }
}
/// Returns a raw const pointer on the start of the data, typed as `D`.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type or when the tensor is null.
pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
    let required = D::datum_type();
    if self.dt != required {
        bail!("Incompatible datum type. Required {:?}, got {:?}", required, self.dt);
    }
    if self.null {
        bail!("Null tensor")
    }
    let bytes: *const u8 = self.data.as_ptr();
    Ok(bytes as *const D)
}
/// Views the tensor data as a flat slice of `D`.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type or when the tensor is null.
pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
    if self.datum_type() != D::datum_type() {
        bail!(
            "Incompatible datum type. Required {:?}, got {:?}",
            D::datum_type(),
            self.datum_type()
        );
    }
    if self.is_null() {
        bail!("Null tensor")
    }
    // An empty `Vec<u8>` exposes a dangling pointer that is only aligned for `u8`;
    // `from_raw_parts` requires a pointer aligned for `D` even for a zero length.
    if self.data.is_empty() {
        let empty: &[D] = &[];
        return Ok(empty);
    }
    let len = self.data.len() / ::std::mem::size_of::<D>();
    unsafe { Ok(std::slice::from_raw_parts::<D>(self.data.as_ptr() as *const D, len)) }
}
/// Views the tensor as a read-only ndarray view of `D` with the tensor's shape.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type, when the tensor is null, or
/// when the data is empty and the shape can not be applied to an empty slice.
pub fn to_array_view<'a, D: Datum>(&'a self) -> TractResult<ArrayViewD<'a, D>> {
    let required = D::datum_type();
    if self.dt != required {
        bail!("Incompatible datum type. Required {:?}, got {:?}", required, self.dt);
    }
    if self.null {
        bail!("Null tensor")
    }
    // Empty data is handled apart so the buffer pointer is never used in that case.
    if self.data.is_empty() {
        return Ok(ArrayViewD::from_shape(&*self.shape, &[])?);
    }
    let view = unsafe { ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as _) };
    Ok(view)
}
/// Views the tensor data as a flat mutable slice of `D`.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type or when the tensor is null.
pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
    if self.datum_type() != D::datum_type() {
        bail!(
            "Incompatible datum type. Required {:?}, got {:?}",
            D::datum_type(),
            self.datum_type()
        );
    }
    if self.is_null() {
        bail!("Null tensor")
    }
    // An empty `Vec<u8>` exposes a dangling pointer that is only aligned for `u8`;
    // `from_raw_parts_mut` requires a pointer aligned for `D` even for a zero length.
    if self.data.is_empty() {
        let empty: &mut [D] = &mut [];
        return Ok(empty);
    }
    let len = self.data.len() / ::std::mem::size_of::<D>();
    unsafe { Ok(std::slice::from_raw_parts_mut::<D>(self.data.as_mut_ptr() as *mut D, len)) }
}
/// Returns a raw mutable pointer on the start of the data, typed as `D`.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type or when the tensor is null.
pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
    let required = D::datum_type();
    if self.dt != required {
        bail!("Incompatible datum type. Required {:?}, got {:?}", required, self.dt);
    }
    if self.null {
        bail!("Null tensor")
    }
    let bytes: *mut u8 = self.data.as_mut_ptr();
    Ok(bytes as *mut D)
}
/// Views the tensor as a mutable ndarray view of `D` with the tensor's shape.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type, when the tensor is null, or
/// when the data is empty and the shape can not be applied to an empty slice.
pub fn to_array_view_mut<'a, D: Datum>(&'a mut self) -> TractResult<ArrayViewMutD<'a, D>> {
    let required = D::datum_type();
    if self.dt != required {
        bail!("Incompatible datum type. Required {:?}, got {:?}", required, self.dt);
    }
    if self.null {
        bail!("Null tensor")
    }
    // Empty data is handled apart so the buffer pointer is never used in that case.
    if self.data.is_empty() {
        return Ok(ArrayViewMutD::from_shape(&*self.shape, &mut [])?);
    }
    let dims = self.shape.clone();
    let view = unsafe { ArrayViewMutD::from_shape_ptr(&*dims, self.data.as_mut_ptr() as _) };
    Ok(view)
}
/// Borrows the first element of the tensor as a `D`.
///
/// # Errors
/// Fails when `D` is not the tensor's datum type, when the tensor is null, or
/// when the tensor holds no element at all.
pub fn to_scalar<'a, D: Datum>(&'a self) -> TractResult<&D> {
    if self.datum_type() != D::datum_type() {
        bail!(
            "Incompatible datum type. Required {:?}, got {:?}",
            D::datum_type(),
            self.datum_type()
        );
    }
    if self.is_null() {
        bail!("Null tensor")
    }
    // Without at least one full element the pointer below must not be dereferenced.
    if self.data.len() < ::std::mem::size_of::<D>() {
        bail!("Can not read a scalar from an empty tensor")
    }
    unsafe { Ok(&*(self.data.as_ptr() as *const D)) }
}
/// Converts every element from `Source` to `Target`, stopping at the first
/// conversion failure.
fn cast_data<Source: Datum + TryInto<Target>, Target: Datum>(
    &self,
) -> TractResult<Vec<Target>> {
    let source = self.as_slice::<Source>()?;
    let mut converted = Vec::with_capacity(source.len());
    for item in source {
        converted.push(item.try_into()?);
    }
    Ok(converted)
}
/// Builds a new tensor of datum type `Target` with the same shape, holding the
/// converted elements of this tensor (read as `Source`).
fn cast<Source: Datum + TryInto<Target>, Target: Datum>(&self) -> TractResult<Tensor> {
    let converted = self.cast_data::<Source, Target>()?;
    let dt = Target::datum_type();
    Ok(Tensor {
        null: self.null,
        dt,
        shape: self.shape.clone(),
        alignment: dt.alignment(),
        data: vec_to_u8(converted),
    })
}
/// Casts the tensor to the datum type of `D`; see `cast_to_dt`, to which this delegates.
pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<Tensor>> {
self.cast_to_dt(D::datum_type())
}
/// Casts the tensor to datum type `dt`.
///
/// The tensor is borrowed as is when it already has that type; otherwise a
/// converted copy is returned.
///
/// # Errors
/// Fails for pairs of types that have no conversion listed below, or when
/// converting an element fails.
pub fn cast_to_dt(&self, dt: DatumType) -> TractResult<Cow<Tensor>> {
    use DatumType::*;
    if self.dt == dt {
        return Ok(Cow::Borrowed(self));
    }
    let converted = match (self.dt, dt) {
        // symbolic dimensions <-> integers
        (TDim, I32) => self.cast::<crate::dim::TDim, i32>()?,
        (TDim, I64) => self.cast::<crate::dim::TDim, i64>()?,
        (I32, TDim) => self.cast::<i32, crate::dim::TDim>()?,
        (I64, TDim) => self.cast::<i64, crate::dim::TDim>()?,

        // between floats
        (F16, F32) => self.cast::<f16, f32>()?,
        (F16, F64) => self.cast::<f16, f64>()?,
        (F32, F16) => self.cast::<f32, f16>()?,
        (F32, F64) => self.cast::<f32, f64>()?,
        (F64, F16) => self.cast::<f64, f16>()?,
        (F64, F32) => self.cast::<f64, f32>()?,

        // between signed integers
        (I8, I16) => self.cast::<i8, i16>()?,
        (I8, I32) => self.cast::<i8, i32>()?,
        (I8, I64) => self.cast::<i8, i64>()?,
        (I16, I8) => self.cast::<i16, i8>()?,
        (I16, I32) => self.cast::<i16, i32>()?,
        (I16, I64) => self.cast::<i16, i64>()?,
        (I32, I8) => self.cast::<i32, i8>()?,
        (I32, I16) => self.cast::<i32, i16>()?,
        (I32, I64) => self.cast::<i32, i64>()?,
        (I64, I8) => self.cast::<i64, i8>()?,
        (I64, I16) => self.cast::<i64, i16>()?,
        (I64, I32) => self.cast::<i64, i32>()?,

        // to f32
        (Bool, F32) => self.cast::<bool, f32>()?,
        (I8, F32) => self.cast::<i8, f32>()?,
        (I16, F32) => self.cast::<i16, f32>()?,
        (I32, F32) => self.cast::<i32, f32>()?,
        (I64, F32) => self.cast::<i64, f32>()?,

        // f32 <-> strings
        (F32, String) => self.cast::<f32, std::string::String>()?,
        (String, F32) => self.cast::<std::string::String, f32>()?,

        _ => bail!("Unsupported cast from {:?} to {:?}", self.dt, dt),
    };
    Ok(Cow::Owned(converted))
}
/// Compares both tensors as array views of `D`.
///
/// # Errors
/// Fails when either tensor can not be viewed as an array of `D`.
fn eq_t<D: Datum>(&self, other: &Tensor) -> TractResult<bool> {
    let lhs = self.to_array_view::<D>()?;
    let rhs = other.to_array_view::<D>()?;
    Ok(lhs == rhs)
}
/// Compares both tensors through `eq_t`.
///
/// NOTE(review): `dispatch_datum!` presumably picks the `eq_t` instantiation
/// matching `self.dt` — confirm against the macro definition.
fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
dispatch_datum!(Self::eq_t(self.dt)(self, other))
}
}
/// Tensors are equal when they share datum type and shape and their values
/// compare equal. A failing value comparison counts as "not equal".
impl PartialEq for Tensor {
    fn eq(&self, other: &Tensor) -> bool {
        let same_layout = self.dt == other.dt && self.shape == other.shape;
        if !same_layout {
            return false;
        }
        match self.eq_dt(other) {
            Ok(same_values) => same_values,
            Err(_) => false,
        }
    }
}
/// Debug output is the abbreviated dump of the tensor, or a description of the
/// error when the dump fails.
impl fmt::Debug for Tensor {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        match self.dump(false) {
            Ok(text) => write!(formatter, "{}", text),
            Err(e) => write!(formatter, "Error : {:?}", e),
        }
    }
}
/// Serializes a tensor as a `(type name, shape, flat values)` tuple.
///
/// The tensor is a struct carrying its datum type, so the element type is
/// picked by matching on `datum_type()`. A tensor that can not be viewed as an
/// array (e.g. a null tensor) is reported as a serialization error.
#[cfg(feature = "serialize")]
impl Serialize for Tensor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        macro_rules! serialize_inner {
            ($type:ty, $name:expr) => {{
                let view = self.to_array_view::<$type>().map_err(|e| {
                    <S::Error as serde::ser::Error>::custom(format!("{:?}", e))
                })?;
                let data = ($name, self.shape(), view.iter().cloned().collect::<Vec<_>>());
                data.serialize(serializer)
            }};
        }
        // NOTE(review): the variant list mirrors the names used by the previous
        // implementation — confirm it matches the `DatumType` definition.
        match self.datum_type() {
            DatumType::Bool => serialize_inner!(bool, "bool"),
            DatumType::U8 => serialize_inner!(u8, "u8"),
            DatumType::U16 => serialize_inner!(u16, "u16"),
            DatumType::I8 => serialize_inner!(i8, "i8"),
            DatumType::I16 => serialize_inner!(i16, "i16"),
            DatumType::I32 => serialize_inner!(i32, "i32"),
            DatumType::I64 => serialize_inner!(i64, "i64"),
            DatumType::F16 => serialize_inner!(f16, "f16"),
            DatumType::F32 => serialize_inner!(f32, "f32"),
            DatumType::F64 => serialize_inner!(f64, "f64"),
            DatumType::TDim => serialize_inner!(crate::dim::TDim, "TDim"),
            DatumType::String => serialize_inner!(std::string::String, "str"),
        }
    }
}
/// Reinterprets a vector of `T` as a vector of bytes over the same buffer,
/// without copying. Length and capacity are scaled by the size of `T`.
fn vec_to_u8<T: Datum>(mut data: Vec<T>) -> Vec<u8> {
    let item_size = ::std::mem::size_of::<T>();
    let bytes = data.as_mut_ptr() as *mut u8;
    let byte_len = data.len() * item_size;
    let byte_cap = data.capacity() * item_size;
    // The byte vector takes over the buffer: the original must not free it.
    ::std::mem::forget(data);
    unsafe { Vec::from_raw_parts(bytes, byte_len, byte_cap) }
}
/// Reinterprets a vector of bytes as a vector of `T` over the same buffer,
/// without copying. Length and capacity are divided by the size of `T`.
///
/// # Safety
/// The bytes must hold valid `T` values.
///
/// # Panics
/// Panics if the buffer is not aligned on the alignment of `T`'s datum type.
unsafe fn vec_to_datum<T: Datum>(mut data: Vec<u8>) -> Vec<T> {
    assert!(data.as_ptr() as usize % T::datum_type().alignment() == 0);
    let item_size = ::std::mem::size_of::<T>();
    let items = data.as_mut_ptr() as *mut T;
    let item_len = data.len() / item_size;
    let item_cap = data.capacity() / item_size;
    // The typed vector takes over the buffer: the original must not free it.
    ::std::mem::forget(data);
    Vec::from_raw_parts(items, item_len, item_cap)
}
/// Converts an owned ndarray array to a tensor of the same shape.
///
/// The array's own buffer is reused when the array is in standard layout;
/// otherwise the elements are cloned in iteration order.
impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
    fn from(it: Array<T, D>) -> Tensor {
        let shape = it.shape().into();
        let elements: Vec<T> = if it.is_standard_layout() {
            it.into_raw_vec()
        } else {
            it.iter().cloned().collect()
        };
        let dt = T::datum_type();
        Tensor { null: false, dt, shape, alignment: dt.alignment(), data: vec_to_u8(elements) }
    }
}
/// Builds a 4-dimensional array from a slice of nested fixed-size initializers.
///
/// The dimensions of the result are `(xs.len(), V::len(), U::len(), T::len())`.
/// `xs` is cloned, and the buffer of the clone is reinterpreted as a flat
/// vector of `A`.
pub fn arr4<A, V, U, T>(xs: &[V]) -> ::ndarray::Array4<A>
where
V: ::ndarray::FixedInitializer<Elem = U> + Clone,
U: ::ndarray::FixedInitializer<Elem = T> + Clone,
T: ::ndarray::FixedInitializer<Elem = A> + Clone,
A: Clone,
{
use ndarray::*;
let mut xs = xs.to_vec();
let dim = Ix4(xs.len(), V::len(), U::len(), T::len());
// Capture the raw parts of the cloned vector before giving up its ownership.
let ptr = xs.as_mut_ptr();
let len = xs.len();
let cap = xs.capacity();
// Number of `A` elements once the three nesting levels are flattened.
let expand_len = len * V::len() * U::len() * T::len();
// The buffer is handed over to the vector rebuilt below, so `xs` must not free it.
::std::mem::forget(xs);
unsafe {
// NOTE(review): the reinterpretation assumes each `FixedInitializer` is laid out
// as a plain array of its `Elem` — confirm for the ndarray version in use.
let v = if ::std::mem::size_of::<A>() == 0 {
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len)
} else if V::len() == 0 || U::len() == 0 || T::len() == 0 {
// No element at all: the forgotten buffer is not reused.
Vec::new()
} else {
let expand_cap = cap * V::len() * U::len() * T::len();
Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap)
};
ArrayBase::from_shape_vec_unchecked(dim, v)
}
}