// sophus_tensor/mut_tensor_view.rs

1use core::marker::PhantomData;
2
3use concat_arrays::concat_arrays;
4
5use crate::{
6    prelude::*,
7    MutTensor,
8    TensorView,
9};
10
11/// Mutable tensor view
12///
13/// See TensorView for more details of the tensor structure
14#[derive(Debug, PartialEq, Eq)]
15pub struct MutTensorView<
16    'a,
17    const TOTAL_RANK: usize,
18    const DRANK: usize,
19    const SRANK: usize,
20    Scalar: IsCoreScalar + 'static,
21    STensor: IsStaticTensor<Scalar, SRANK, ROWS, COLS> + 'static,
22    const ROWS: usize,
23    const COLS: usize,
24> where
25    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
26    ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>: ndarray::Dimension,
27{
28    /// mutable ndarray view of the static tensors with shape [D1, D2, ...]
29    pub elem_view_mut: ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; DRANK]>>,
30    /// mutable ndarray view of the scalars with shape [D1, D2, ..., S0, S1, ...]
31    pub scalar_view_mut: ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>>,
32}
33
34/// A mutable tensor like object
35pub trait IsMutTensorLike<
36    'a,
37    const TOTAL_RANK: usize,
38    const DRANK: usize,
39    const SRANK: usize,
40    Scalar: IsCoreScalar + 'static,
41    STensor: IsStaticTensor<Scalar, SRANK, ROWS, COLS> + 'static,
42    const ROWS: usize,
43    const COLS: usize,
44>: IsTensorLike<'a, TOTAL_RANK, DRANK, SRANK, Scalar, STensor, ROWS, COLS> where
45    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
46    ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>: ndarray::Dimension,
47{
48    /// mutable ndarray view of the static tensors with shape [D1, D2, ...]
49    fn elem_view_mut<'b: 'a>(
50        &'b mut self,
51    ) -> ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; DRANK]>>;
52
53    /// mutable ndarray view of the scalars with shape [D1, D2, ..., S0, S1, ...]
54    fn scalar_view_mut<'b: 'a>(
55        &'b mut self,
56    ) -> ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>>;
57
58    /// mutable reference to the static tensor at index idx
59    fn get_mut(&'a mut self, idx: [usize; DRANK]) -> &'a mut STensor;
60}
61
/// Implements the inherent API, `IsTensorLike` and `IsMutTensorLike` for
/// `MutTensorView` at one concrete rank combination.
///
/// Arguments, in order: total (scalar) rank, static-tensor rank, dynamic rank.
/// The scalar shape is the dynamic dims concatenated with `STensor::sdims()`,
/// so callers must pass `total_rank == srank + drank`.
macro_rules! mut_view_is_view {
    ($total_rank:literal, $srank:literal, $drank:literal) => {
        impl<
                'a,
                Scalar: IsCoreScalar + 'static,
                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
                const ROWS: usize,
                const COLS: usize,
            > MutTensorView<'a, $total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
        {
            /// Returns a read-only `TensorView` that reborrows both ndarray views of `self`.
            pub fn view(
                &self,
            ) -> TensorView<'_, $total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS> {
                TensorView {
                    elem_view: self.elem_view_mut.view(),
                    scalar_view: self.scalar_view_mut.view(),
                }
            }

            /// Creates a mutable tensor view from a mutable ndarray view of static tensors.
            ///
            /// The scalar view is built over the same memory. Its shape is the element dims
            /// followed by `STensor::sdims()`. Its strides are the element strides scaled by
            /// `STensor::num_scalars()`, followed by `STensor::get_strides()`.
            ///
            /// # Panics
            ///
            /// Panics if `size_of::<STensor>() != size_of::<Scalar>() * ROWS * COLS`.
            pub fn new(
                mut elem_view_mut: ndarray::ArrayViewMut<
                    'a,
                    STensor,
                    ndarray::Dim<[ndarray::Ix; $drank]>,
                >,
            ) -> Self {
                use ndarray::ShapeBuilder;

                // Reinterpreting `STensor` storage as a run of `Scalar`s is only meaningful
                // if a static tensor occupies exactly ROWS * COLS scalars (no padding).
                assert_eq!(
                    core::mem::size_of::<STensor>(),
                    core::mem::size_of::<Scalar>() * ROWS * COLS
                );

                let dims: [usize; $drank] = elem_view_mut
                    .shape()
                    .try_into()
                    .expect("shape length always equals the dimension rank");
                // Kept from the original; presumably the `concat_arrays!` expansion trips
                // this lint.
                #[allow(clippy::drop_non_drop)]
                let shape: [usize; $total_rank] = concat_arrays!(dims, STensor::sdims());

                let elem_strides: [isize; $drank] = elem_view_mut
                    .strides()
                    .try_into()
                    .expect("strides length always equals the dimension rank");
                let num_scalars = STensor::num_scalars() as isize;
                // Scale in signed arithmetic. Element strides can be negative (e.g. a
                // reversed slice), and casting to `usize` before multiplying would overflow
                // (a panic in debug builds). The final `as usize` keeps the two's-complement
                // bit pattern; ndarray's `usize` stride slots are, as far as can be seen,
                // reinterpreted as `isize`. NOTE(review): confirm that `from_shape_ptr`
                // accepts negative strides in the ndarray version in use.
                let dstrides: [usize; $drank] =
                    elem_strides.map(|s| (s * num_scalars) as usize);
                #[allow(clippy::drop_non_drop)]
                let strides = concat_arrays!(dstrides, STensor::get_strides());

                let ptr = elem_view_mut.as_mut_ptr() as *mut Scalar;
                // SAFETY: `ptr` is the first element of `elem_view_mut`, which is valid and
                // mutably borrowed for `'a`. The assertion above guarantees each `STensor`
                // spans exactly ROWS * COLS scalars, so (assuming `IsStaticTensor` stores its
                // scalars contiguously with the strides from `get_strides()` — TODO confirm)
                // the scalar shape/strides address only bytes of the element view. The two
                // views alias; callers must not use references from both at the same time.
                let scalar_view_mut =
                    unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape.strides(strides), ptr) };

                Self {
                    elem_view_mut,
                    scalar_view_mut,
                }
            }

            /// Returns a mutable reference to the scalar at index `idx`.
            ///
            /// # Panics
            ///
            /// Panics if `idx` is out of bounds.
            pub fn mut_scalar(&mut self, idx: [usize; $total_rank]) -> &mut Scalar {
                &mut self.scalar_view_mut[idx]
            }

            /// Fills `self` element-wise from another tensor view.
            ///
            /// For each pair of corresponding static tensors, `op(dst, src)` is called with
            /// `dst` from `self` and `src` from `view`.
            ///
            /// # Panics
            ///
            /// Panics (via `ndarray::ArrayBase::zip_mut_with`) if the element shape of `view`
            /// cannot be broadcast to the element shape of `self`.
            pub fn map<
                'b,
                const OTHER_HRANK: usize,
                const OTHER_SRANK: usize,
                OtherScalar: IsCoreScalar + 'static,
                OtherSTensor: IsStaticTensor<
                        OtherScalar,
                        OTHER_SRANK,
                        OTHER_ROWS,
                        OTHER_COLS,
                    > + 'static,
                const OTHER_ROWS: usize,
                const OTHER_COLS: usize,
                V: IsTensorView<
                    'b,
                    OTHER_HRANK,
                    $drank,
                    OTHER_SRANK,
                    OtherScalar,
                    OtherSTensor,
                    OTHER_ROWS,
                    OTHER_COLS,
                >,
                F: FnMut(&mut STensor, &OtherSTensor),
            >(
                &mut self,
                view: &'b V,
                op: F,
            ) where
                ndarray::Dim<[ndarray::Ix; OTHER_HRANK]>: ndarray::Dimension,
            {
                self.elem_view_mut.zip_mut_with(&view.elem_view(), op);
            }
        }

        impl<
                'a,
                Scalar: IsCoreScalar + 'static,
                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
                const ROWS: usize,
                const COLS: usize,
            > IsTensorLike<'a, $total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
            for MutTensorView<'a, $total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
        {
            fn elem_view<'b: 'a>(
                &'b self,
            ) -> ndarray::ArrayView<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>> {
                self.view().elem_view
            }

            fn get(&self, idx: [usize; $drank]) -> STensor {
                self.view().get(idx)
            }

            fn dims(&self) -> [usize; $drank] {
                self.view().dims()
            }

            fn scalar_view<'b: 'a>(
                &'b self,
            ) -> ndarray::ArrayView<'a, Scalar, ndarray::Dim<[ndarray::Ix; $total_rank]>> {
                self.view().scalar_view
            }

            fn scalar_get(&'a self, idx: [usize; $total_rank]) -> Scalar {
                self.view().scalar_get(idx)
            }

            fn scalar_dims(&self) -> [usize; $total_rank] {
                self.view().scalar_dims()
            }

            /// Copies the viewed static tensors into a newly owned `MutTensor`.
            fn to_mut_tensor(
                &self,
            ) -> MutTensor<$total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS> {
                MutTensor {
                    mut_array: self.view().elem_view.to_owned(),
                    phantom: PhantomData,
                }
            }
        }

        impl<
                'a,
                Scalar: IsCoreScalar + 'static,
                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
                const ROWS: usize,
                const COLS: usize,
            > IsMutTensorLike<'a, $total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
            for MutTensorView<'a, $total_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
        {
            fn elem_view_mut<'b: 'a>(
                &'b mut self,
            ) -> ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>> {
                self.elem_view_mut.view_mut()
            }

            fn scalar_view_mut<'b: 'a>(
                &'b mut self,
            ) -> ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; $total_rank]>> {
                self.scalar_view_mut.view_mut()
            }

            // The `&'a mut self` receiver is dictated by the trait signature.
            fn get_mut(&'a mut self, idx: [usize; $drank]) -> &'a mut STensor {
                &mut self.elem_view_mut[idx]
            }
        }
    };
}
264
265mut_view_is_view!(1, 0, 1);
266mut_view_is_view!(2, 0, 2);
267mut_view_is_view!(2, 1, 1);
268mut_view_is_view!(3, 0, 3);
269mut_view_is_view!(3, 1, 2);
270mut_view_is_view!(3, 2, 1);
271mut_view_is_view!(4, 0, 4);
272mut_view_is_view!(4, 1, 3);
273mut_view_is_view!(4, 2, 2);
274mut_view_is_view!(5, 0, 5);
275mut_view_is_view!(5, 1, 4);
276mut_view_is_view!(5, 2, 3);