1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
use crate::datum::Blob;
use crate::dim::TDim;
use crate::prelude::*;
use crate::TractResult;
use ndarray::*;

/// Array-level helpers needed per element type.
///
/// Implementations are generated below by two macros:
/// - `impl_stack_views_by_copy!`, for `Copy` integer types; it delegates to `ndarray::stack`.
/// - `impl_stack_views_by_clone!`, for `Blob`, `String` and `TDim`; it clones element-wise.
pub trait ArrayDatum: Sized {
    /// Concatenates the arrays held by `tensors` along the existing axis `axis`
    /// and wraps the result in a new tensor.
    ///
    /// # Errors
    ///
    /// Same conditions as [`ArrayDatum::stack_views`]. The clone-based
    /// implementations also fail if a tensor's datum type does not match `Self`.
    ///
    /// # Safety
    ///
    /// Every tensor must hold elements of type `Self`. The `Copy` implementations
    /// read the tensors through `to_array_view_unchecked`, which is assumed not
    /// to re-check the datum type.
    unsafe fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor>;

    /// Concatenates `views` along the existing axis `axis`. All views must have
    /// the same rank and agree on every dimension except `axis`.
    ///
    /// # Errors
    ///
    /// Returns a shape error in any of these cases:
    /// - `views` is empty;
    /// - `axis` is not a valid axis of the views;
    /// - the views' shapes are incompatible.
    ///
    /// # Safety
    ///
    /// NOTE(review): neither implementation below performs an unsafe operation
    /// here; `unsafe` appears to be kept for symmetry with the other methods.
    unsafe fn stack_views(axis: usize, views: &[ArrayViewD<Self>]) -> TractResult<ArrayD<Self>>;

    /// Allocates an owned array of shape `shape` for the caller to fill.
    ///
    /// # Safety
    ///
    /// The `Copy` implementations leave the elements uninitialized
    /// (`ArrayBase::uninitialized`): the caller must write every element before
    /// reading any of them. The clone-based implementations fill the array with
    /// `Default::default()`.
    unsafe fn uninitialized_array<S, D, Sh>(shape: Sh) -> ArrayBase<S, D>
    where
        Sh: ShapeBuilder<Dim = D>,
        S: DataOwned<Elem = Self>,
        D: Dimension;
}

/// Implements [`ArrayDatum`] for a `Copy` element type by delegating
/// concatenation (and its shape checks) to `ndarray::stack`.
macro_rules! impl_stack_views_by_copy {
    ($t: ty) => {
        impl ArrayDatum for $t {
            unsafe fn stack_tensors(
                axis: usize,
                tensors: &[impl std::borrow::Borrow<Tensor>],
            ) -> TractResult<Tensor> {
                // The trait's safety contract guarantees every tensor holds `$t`,
                // so the unchecked view accessor is used.
                let arrays = tensors
                    .iter()
                    .map(|t| t.borrow().to_array_view_unchecked::<$t>())
                    .collect::<TVec<_>>();
                Self::stack_views(axis, &arrays).map(|a| a.into_tensor())
            }

            unsafe fn stack_views(
                axis: usize,
                views: &[ArrayViewD<$t>],
            ) -> TractResult<ArrayD<$t>> {
                // NOTE(review): before ndarray 0.15, `ndarray::stack` concatenates
                // along an existing axis. From 0.15 on, it inserts a new axis,
                // and `concatenate` takes over the old behaviour. Re-check this
                // call when upgrading ndarray.
                Ok(ndarray::stack(ndarray::Axis(axis), views)?)
            }

            unsafe fn uninitialized_array<S, D, Sh>(shape: Sh) -> ArrayBase<S, D>
            where
                Sh: ShapeBuilder<Dim = D>,
                S: DataOwned<Elem = Self>,
                D: Dimension,
            {
                ArrayBase::<S, D>::uninitialized(shape)
            }
        }
    };
}

/// Implements [`ArrayDatum`] for an element type that is `Clone` (and
/// `Default`, used to pre-fill outputs) but is not handled by the `Copy` path.
macro_rules! impl_stack_views_by_clone {
    ($t: ty) => {
        impl ArrayDatum for $t {
            unsafe fn stack_tensors(
                axis: usize,
                tensors: &[impl std::borrow::Borrow<Tensor>],
            ) -> TractResult<Tensor> {
                // Checked accessor: a datum-type mismatch surfaces as an error.
                let arrays = tensors
                    .iter()
                    .map(|t| t.borrow().to_array_view::<$t>())
                    .collect::<TractResult<TVec<_>>>()?;
                let views = arrays.iter().map(|a| a.view()).collect::<TVec<_>>();
                Self::stack_views(axis, &views).map(|a| a.into_tensor())
            }

            unsafe fn stack_views(
                axis: usize,
                views: &[ArrayViewD<$t>],
            ) -> TractResult<ArrayD<$t>> {
                // Reject the same inputs `ndarray::stack` rejects on the `Copy` path.
                // Without these checks, `views[0]` / `shape[axis]` would panic, and
                // `assign` would silently broadcast a mismatched view.
                let first = match views.first() {
                    Some(first) => first,
                    None => {
                        return Err(
                            ndarray::ShapeError::from_kind(ndarray::ErrorKind::Unsupported).into()
                        )
                    }
                };
                if axis >= first.ndim() {
                    return Err(
                        ndarray::ShapeError::from_kind(ndarray::ErrorKind::OutOfBounds).into()
                    );
                }
                let compatible = views.iter().all(|v| {
                    v.ndim() == first.ndim()
                        && v
                            .shape()
                            .iter()
                            .zip(first.shape())
                            .enumerate()
                            .all(|(ix, (a, b))| ix == axis || a == b)
                });
                if !compatible {
                    return Err(ndarray::ShapeError::from_kind(
                        ndarray::ErrorKind::IncompatibleShape,
                    )
                    .into());
                }

                // Output shape: the common shape, with `axis` sized to the sum of
                // the inputs' lengths along it.
                let mut shape = first.shape().to_vec();
                shape[axis] = views.iter().map(|v| v.shape()[axis]).sum();
                let mut array = ndarray::Array::default(&*shape);
                // Copy each view into its slot along `axis`, one after the other.
                let mut offset = 0;
                for v in views {
                    let len = v.shape()[axis];
                    array
                        .slice_axis_mut(Axis(axis), (offset..(offset + len)).into())
                        .assign(v);
                    offset += len;
                }
                Ok(array)
            }

            unsafe fn uninitialized_array<S, D, Sh>(shape: Sh) -> ArrayBase<S, D>
            where
                Sh: ShapeBuilder<Dim = D>,
                S: DataOwned<Elem = Self>,
                D: Dimension,
            {
                // Uninitialized memory is not safe for types with drop glue, so
                // default-fill instead.
                ArrayBase::<S, D>::default(shape)
            }
        }
    };
}

impl_stack_views_by_copy!(i8);
impl_stack_views_by_copy!(i16);
impl_stack_views_by_copy!(i32);
impl_stack_views_by_copy!(i64);

impl_stack_views_by_clone!(Blob);
impl_stack_views_by_clone!(String);
impl_stack_views_by_clone!(TDim);