Skip to main content

tract_proxy/
ndarray_interop.rs

/// Generate ndarray interop for [`tract_proxy::Tensor`][crate::Tensor] using the
/// caller crate's own `ndarray` version.
///
/// `tract-proxy` itself has no public `ndarray` dependency: the tensor interface
/// deals only in shapes, slices, bytes and primitive datums. If your
/// application wants the ergonomics of `ndarray`, invoke this macro once at the
/// root of your crate. The macro expands in your crate's scope, so the
/// `ndarray::*` types referenced in the generated code resolve against *your*
/// `ndarray` dependency.
///
/// The generated surface mirrors `tract::impl_ndarray_interop!` exactly: a
/// `Tract` trait with `fn tract(self) -> anyhow::Result<tract_proxy::Tensor>`
/// for `ndarray::ArrayBase`, and an `Ndarray` trait with `ndarray::<T>()` /
/// `ndarray0..ndarray6::<T>()` on `tract_proxy::Tensor`.
///
/// Both traits are emitted without `pub`. They are private to the module the
/// macro is invoked in (and its descendants), so invoking it at the crate root
/// makes them usable crate-wide.
///
/// # Behavior
///
/// * `tract()` copies the array into a new tensor.
///   - Arrays in standard (row-major, C-contiguous) layout are copied straight
///     from their backing slice.
///   - Every other layout (transposed, strided, or with negative strides) is
///     first gathered in logical order.
///   - Either way, the tensor receives the elements in the array's logical order.
/// * `ndarray::<T>()` borrows the tensor's data as a dynamic-rank view without
///   copying. `ndarray0..ndarray6` additionally convert that view to a fixed
///   rank.
///
/// # Errors
///
/// * `tract()` forwards any error returned by `Tensor::from_slice`.
/// * `ndarray::<T>()` forwards any error returned by the tensor's
///   `as_shape_and_slice::<T>()`. It returns an `ndarray::ShapeError` if the
///   reported shape does not fit the borrowed data.
/// * `ndarrayN::<T>()` fails like `ndarray::<T>()`. It additionally returns an
///   `ndarray::ShapeError` when the tensor's rank is not `N`.
///
/// # Invocation
///
/// Zero-argument form uses the `ndarray` crate from your crate's
/// dependencies:
///
/// ```ignore
/// tract_proxy::impl_ndarray_interop!();
/// ```
///
/// The explicit form takes the ndarray root as a `::`-separated path of
/// identifiers (a leading `::` is not accepted). This is useful if your
/// `Cargo.toml` renames the crate or pins multiple versions side by side:
///
/// ```ignore
/// tract_proxy::impl_ndarray_interop!(ndarray_017);
/// ```
#[macro_export]
macro_rules! impl_ndarray_interop {
    () => {
        $crate::impl_ndarray_interop!(ndarray);
    };
    ($($nd:ident)::+) => {
        /// Conversion from an `ndarray` array into a tract tensor.
        trait Tract {
            fn tract(self) -> $crate::__ndarray_interop::anyhow::Result<$crate::Tensor>;
        }

        impl<T, S, D> Tract for $($nd)::+::ArrayBase<S, D>
        where
            T: $crate::__ndarray_interop::Datum + Clone + 'static,
            S: $($nd)::+::RawData<Elem = T> + $($nd)::+::Data,
            D: $($nd)::+::Dimension,
        {
            fn tract(self) -> $crate::__ndarray_interop::anyhow::Result<$crate::Tensor> {
                use $crate::__ndarray_interop::TensorInterface as _;
                // `as_slice` only succeeds for standard (row-major) layout, where
                // memory order and logical order coincide. `as_slice_memory_order`
                // is not enough: it also accepts contiguous arrays with negative
                // strides (e.g. `slice(s![..;-1])`), whose memory order differs
                // from their logical order.
                if let Some(slice) = self.as_slice() {
                    $crate::Tensor::from_slice(self.shape(), slice)
                } else {
                    // Any other layout: gather the elements in logical order first.
                    let gathered: ::std::vec::Vec<T> = self.iter().cloned().collect();
                    $crate::Tensor::from_slice(self.shape(), &gathered)
                }
            }
        }

        /// Zero-copy `ndarray` views over a tract tensor's data.
        trait Ndarray {
            /// Dynamic-rank view of the tensor's elements as `T`.
            fn ndarray<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayViewD<'_, T>>;
            /// Rank-0 view; fails if the tensor is not rank 0.
            fn ndarray0<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView0<'_, T>>;
            /// Rank-1 view; fails if the tensor is not rank 1.
            fn ndarray1<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView1<'_, T>>;
            /// Rank-2 view; fails if the tensor is not rank 2.
            fn ndarray2<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView2<'_, T>>;
            /// Rank-3 view; fails if the tensor is not rank 3.
            fn ndarray3<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView3<'_, T>>;
            /// Rank-4 view; fails if the tensor is not rank 4.
            fn ndarray4<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView4<'_, T>>;
            /// Rank-5 view; fails if the tensor is not rank 5.
            fn ndarray5<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView5<'_, T>>;
            /// Rank-6 view; fails if the tensor is not rank 6.
            fn ndarray6<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView6<'_, T>>;
        }

        impl Ndarray for $crate::Tensor {
            fn ndarray<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayViewD<'_, T>> {
                use $crate::__ndarray_interop::TensorInterface as _;
                let (shape, data) = self.as_shape_and_slice::<T>()?;
                // The checked constructor validates the shape against the slice
                // length, so no `unsafe` is needed to borrow the tensor's data.
                Ok($($nd)::+::ArrayViewD::from_shape(shape, data)?)
            }
            fn ndarray0<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView0<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray1<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView1<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray2<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView2<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray3<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView3<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray4<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView4<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray5<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView5<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray6<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView6<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
        }
    };
}
135
136#[doc(hidden)]
137pub mod __ndarray_interop {
138    pub use ::anyhow;
139    pub use ::tract_api::{Datum, TensorInterface};
140}