tract_proxy/
/// Generates ndarray <-> tract `Tensor` conversion traits in the invoking crate.
///
/// Each expansion defines two crate-local (private) traits:
///
/// * `Tract` — implemented for any `ArrayBase<S, D>` with readable storage
///   (`S: Data`). `array.tract()` builds a tract `Tensor` of the same shape from
///   the array's elements in logical (row-major) order.
/// * `Ndarray` — implemented for `Tensor`. `tensor.ndarray::<T>()` borrows the
///   tensor's data as a dynamic-rank `ArrayViewD<T>`, and `ndarray0` ..
///   `ndarray6` additionally convert it to a fixed-rank view.
///
/// All `ndarray` paths in the expansion are resolved at the invocation site, so
/// the caller decides which `ndarray` crate is used. Invoked without arguments,
/// the crate is referred to as `ndarray`. Pass a path such as
/// `impl_ndarray_interop!(deps::ndarray)` if it is reachable under another name.
///
/// # Errors
///
/// The generated methods return `anyhow::Result`:
///
/// * `tract()` forwards failures from `Tensor::from_slice`.
/// * The `ndarray*` methods forward failures from `as_shape_and_slice::<T>()`.
///   These are presumably datum-type mismatches; confirm in `tract_api`.
/// * The fixed-rank variants also fail when the tensor's rank differs from the
///   requested one.
#[macro_export]
macro_rules! impl_ndarray_interop {
    () => {
        $crate::impl_ndarray_interop!(ndarray);
    };
    ($($nd:ident)::+) => {
        /// Conversion of an ndarray array into a tract `Tensor`.
        trait Tract {
            /// Builds a `Tensor` of the same shape holding the array's elements
            /// in logical (row-major) order.
            fn tract(self) -> $crate::__ndarray_interop::anyhow::Result<$crate::Tensor>;
        }

        impl<T, S, D> Tract for $($nd)::+::ArrayBase<S, D>
        where
            T: $crate::__ndarray_interop::Datum + Clone + 'static,
            S: $($nd)::+::Data<Elem = T>,
            D: $($nd)::+::Dimension,
        {
            fn tract(self) -> $crate::__ndarray_interop::anyhow::Result<$crate::Tensor> {
                use $crate::__ndarray_interop::TensorInterface as _;
                // `as_slice` only yields a slice for standard (C-contiguous,
                // row-major) layout, i.e. when memory order equals logical order.
                // `as_slice_memory_order` plus a stride-ordering check is not
                // enough: a contiguous view with negative strides (such as a
                // reversed axis) would be copied in reverse.
                match self.as_slice() {
                    Some(slice) => $crate::Tensor::from_slice(self.shape(), slice),
                    None => {
                        // Any other layout: gather the elements in logical order.
                        let data: ::std::vec::Vec<T> = self.iter().cloned().collect();
                        $crate::Tensor::from_slice(self.shape(), &data)
                    }
                }
            }
        }

        /// Borrowed ndarray views over a tract `Tensor`'s data.
        trait Ndarray {
            /// Views the tensor's elements as `T`, with dynamic rank.
            fn ndarray<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayViewD<'_, T>>;
            /// Views the tensor as a rank-0 array; fails unless its rank is 0.
            fn ndarray0<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView0<'_, T>>;
            /// Views the tensor as a rank-1 array; fails unless its rank is 1.
            fn ndarray1<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView1<'_, T>>;
            /// Views the tensor as a rank-2 array; fails unless its rank is 2.
            fn ndarray2<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView2<'_, T>>;
            /// Views the tensor as a rank-3 array; fails unless its rank is 3.
            fn ndarray3<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView3<'_, T>>;
            /// Views the tensor as a rank-4 array; fails unless its rank is 4.
            fn ndarray4<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView4<'_, T>>;
            /// Views the tensor as a rank-5 array; fails unless its rank is 5.
            fn ndarray5<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView5<'_, T>>;
            /// Views the tensor as a rank-6 array; fails unless its rank is 6.
            fn ndarray6<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView6<'_, T>>;
        }

        impl Ndarray for $crate::Tensor {
            fn ndarray<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayViewD<'_, T>> {
                use $crate::__ndarray_interop::TensorInterface as _;
                let (shape, data) = self.as_shape_and_slice::<T>()?;
                // The checked constructor verifies that `shape` fits within
                // `data` and ties the view's lifetime to the borrowed slice, so
                // no `unsafe` raw-pointer construction is needed.
                Ok($($nd)::+::ArrayViewD::from_shape(shape, data)?)
            }
            fn ndarray0<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView0<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray1<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView1<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray2<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView2<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray3<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView3<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray4<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView4<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray5<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView5<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
            fn ndarray6<T: $crate::__ndarray_interop::Datum>(
                &self,
            ) -> $crate::__ndarray_interop::anyhow::Result<$($nd)::+::ArrayView6<'_, T>> {
                Ok(self.ndarray::<T>()?.into_dimensionality()?)
            }
        }
    };
}
135
136#[doc(hidden)]
137pub mod __ndarray_interop {
138 pub use ::anyhow;
139 pub use ::tract_api::{Datum, TensorInterface};
140}