1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
use crate::datum::Blob;
use crate::dim::TDim;
use crate::prelude::*;
use crate::TractResult;
use ndarray::*;

/// Element types that can be stacked (concatenated along an existing axis) and
/// allocated as `ndarray` arrays.
///
/// In this file it is implemented through two macros:
/// - `impl_stack_views_by_copy!` for plain integer types, which uses `ndarray::stack`
///   and uninitialised allocation;
/// - `impl_stack_views_by_clone!` for `Blob`, `String` and `TDim`, which clones
///   elements into a default-filled array.
pub trait ArrayDatum: Sized {
    /// Concatenates `tensors` along the existing axis `axis` into a new tensor.
    ///
    /// # Errors
    ///
    /// Propagates the stacking error reported by the implementation. The copy-based
    /// implementations forward `ndarray::stack`'s `ShapeError`.
    ///
    /// # Safety
    ///
    /// Every tensor must hold elements of type `Self`. The copy-based implementations
    /// view the tensors with `to_array_view_unchecked` and do not verify the datum type.
    unsafe fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor>;

    /// Concatenates `views` along the existing axis `axis` into a new owned array.
    ///
    /// # Errors
    ///
    /// Propagates the stacking error reported by the implementation.
    ///
    /// # Safety
    ///
    /// The implementations in this file perform no unchecked operation here. The
    /// method appears to be `unsafe` for symmetry with the rest of the trait.
    /// NOTE(review): confirm that no other implementor relies on extra preconditions.
    unsafe fn stack_views(axis: usize, views: &[ArrayViewD<Self>]) -> TractResult<ArrayD<Self>>;

    /// Allocates an array of the given `shape` whose contents may be uninitialised.
    ///
    /// # Safety
    ///
    /// The copy-based implementations return memory from `ArrayBase::uninitialized`, so
    /// every element must be written before it is read. The clone-based
    /// implementations fill the array with `Default::default()` instead.
    unsafe fn uninitialized_array<S, D, Sh>(shape: Sh) -> ArrayBase<S, D>
    where
        Sh: ShapeBuilder<Dim = D>,
        S: DataOwned<Elem = Self>,
        D: Dimension;
}

/// Implements `ArrayDatum` for a plain-data element type by delegating stacking to
/// `ndarray::stack` and allocating without initialisation.
macro_rules! impl_stack_views_by_copy(
    ($t: ty) => {
        impl ArrayDatum for $t {
            /// Views each tensor as `$t` (without a datum-type check) and stacks the views.
            unsafe fn stack_tensors(axis: usize, tensors: &[impl std::borrow::Borrow<Tensor>]) -> TractResult<Tensor> {
                let mut views: TVec<ArrayViewD<$t>> = TVec::new();
                for tensor in tensors {
                    // SAFETY: the caller guarantees every tensor holds `$t` elements
                    // (see the trait's `# Safety` contract); no check happens here.
                    views.push(tensor.borrow().to_array_view_unchecked::<$t>());
                }
                let stacked = Self::stack_views(axis, &views)?;
                Ok(stacked.into_tensor())
            }

            /// Concatenates along `axis`, forwarding any `ShapeError` from `ndarray::stack`.
            unsafe fn stack_views(axis: usize, views: &[ArrayViewD<$t>]) -> TractResult<ArrayD<$t>> {
                let stacked = ndarray::stack(ndarray::Axis(axis), views)?;
                Ok(stacked)
            }

            /// Allocates without initialising the elements; callers must write
            /// every element before reading it.
            unsafe fn uninitialized_array<S, D, Sh>(shape: Sh) -> ArrayBase<S, D>
            where
                Sh: ShapeBuilder<Dim = D>,
                S: DataOwned<Elem = Self>,
                D: Dimension,
            {
                ArrayBase::<S, D>::uninitialized(shape)
            }
        }
    };
);

/// Implements `ArrayDatum` for an element type that must be cloned rather than
/// bit-copied: stacking clones elements into a default-filled array, and
/// "uninitialised" allocation fills with `Default::default()`.
macro_rules! impl_stack_views_by_clone(
    ($t: ty) => {
        impl ArrayDatum for $t {
            /// Views each tensor as `$t` and stacks the views.
            ///
            /// The datum type is checked here (via `to_array_view`), so a mismatch is
            /// reported as an error.
            unsafe fn stack_tensors(axis: usize, tensors:&[impl std::borrow::Borrow<Tensor>]) -> TractResult<Tensor> {
                let arrays = tensors.iter().map(|t| t.borrow().to_array_view::<$t>()).collect::<TractResult<TVec<_>>>()?;
                let views = arrays.iter().map(|a| a.view()).collect::<TVec<_>>();
                Self::stack_views(axis, &views).map(|a| a.into_tensor())
            }

            /// Concatenates `views` along the existing axis `axis` by cloning elements.
            ///
            /// Applies the same validation as `ndarray::stack`, which is used by the
            /// copy-based implementation, so both paths reject the same inputs with a
            /// `ShapeError`:
            /// - `Unsupported` for an empty `views`;
            /// - `OutOfBounds` for an `axis` beyond the views' rank;
            /// - `IncompatibleShape` when ranks or non-axis extents differ.
            unsafe fn stack_views(axis: usize, views:&[ArrayViewD<$t>]) -> TractResult<ArrayD<$t>> {
                let first = match views.first() {
                    Some(view) => view.shape(),
                    None => {
                        return Err(ndarray::ShapeError::from_kind(ndarray::ErrorKind::Unsupported).into())
                    }
                };
                if axis >= first.len() {
                    return Err(ndarray::ShapeError::from_kind(ndarray::ErrorKind::OutOfBounds).into());
                }
                // Without this check, `assign` below would panic on mismatched extents,
                // or silently broadcast a lower-rank view into the output slice.
                let same_except_axis = |shape: &[usize]| {
                    shape.len() == first.len()
                        && shape
                            .iter()
                            .zip(first)
                            .enumerate()
                            .all(|(ix, (a, b))| ix == axis || a == b)
                };
                if !views.iter().all(|v| same_except_axis(v.shape())) {
                    return Err(ndarray::ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape).into());
                }

                let mut shape = first.to_vec();
                shape[axis] = views.iter().map(|v| v.shape()[axis]).sum();
                // Default-filled rather than uninitialised: `$t` owns heap data, so every
                // slot must hold a valid value before it is overwritten (and dropped).
                let mut array = ndarray::Array::default(&*shape);
                let mut offset = 0;
                for view in views {
                    let len = view.shape()[axis];
                    array
                        .slice_axis_mut(ndarray::Axis(axis), (offset..(offset + len)).into())
                        .assign(view);
                    offset += len;
                }
                Ok(array)
            }

            /// Allocates an array filled with `Default::default()`. These types cannot
            /// safely be left uninitialised.
            unsafe fn uninitialized_array<S, D, Sh>(shape: Sh) -> ArrayBase<S, D> where
                Sh: ShapeBuilder<Dim = D>,
                S: DataOwned<Elem=Self>,
                D: Dimension {
                    ArrayBase::<S,D>::default(shape)
                }
        }
    };
);

// Plain integer types: stacked via `ndarray::stack`, allocated uninitialised.
impl_stack_views_by_copy! { i8 }
impl_stack_views_by_copy! { i16 }
impl_stack_views_by_copy! { i32 }
impl_stack_views_by_copy! { i64 }

// Heap-owning types: stacked by cloning into a default-filled array.
impl_stack_views_by_clone! { Blob }
impl_stack_views_by_clone! { String }
impl_stack_views_by_clone! { TDim }