1use core::marker::PhantomData;
2
3use concat_arrays::concat_arrays;
4
5use crate::{
6 prelude::*,
7 MutTensor,
8 TensorView,
9};
10
11#[derive(Debug, PartialEq, Eq)]
15pub struct MutTensorView<
16 'a,
17 const TOTAL_RANK: usize,
18 const DRANK: usize,
19 const SRANK: usize,
20 Scalar: IsCoreScalar + 'static,
21 STensor: IsStaticTensor<Scalar, SRANK, ROWS, COLS> + 'static,
22 const ROWS: usize,
23 const COLS: usize,
24> where
25 ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
26 ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>: ndarray::Dimension,
27{
28 pub elem_view_mut: ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; DRANK]>>,
30 pub scalar_view_mut: ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>>,
32}
33
34pub trait IsMutTensorLike<
36 'a,
37 const TOTAL_RANK: usize,
38 const DRANK: usize,
39 const SRANK: usize,
40 Scalar: IsCoreScalar + 'static,
41 STensor: IsStaticTensor<Scalar, SRANK, ROWS, COLS> + 'static,
42 const ROWS: usize,
43 const COLS: usize,
44>: IsTensorLike<'a, TOTAL_RANK, DRANK, SRANK, Scalar, STensor, ROWS, COLS> where
45 ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
46 ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>: ndarray::Dimension,
47{
48 fn elem_view_mut<'b: 'a>(
50 &'b mut self,
51 ) -> ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; DRANK]>>;
52
53 fn scalar_view_mut<'b: 'a>(
55 &'b mut self,
56 ) -> ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; TOTAL_RANK]>>;
57
58 fn get_mut(&'a mut self, idx: [usize; DRANK]) -> &'a mut STensor;
60}
61
/// Implements the inherent API of [`MutTensorView`] plus [`IsTensorLike`] and
/// [`IsMutTensorLike`] for one concrete rank split.
///
/// Invoked as `mut_view_is_view!(TOTAL_RANK, SRANK, DRANK)`. The three literals are substituted
/// for the `TOTAL_RANK`, `SRANK` and `DRANK` const parameters of the generated impls.
macro_rules! mut_view_is_view {
    ($scalar_rank:literal, $srank:literal, $drank:literal) => {
        impl<
                'a,
                Scalar: IsCoreScalar + 'static,
                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
                const ROWS: usize,
                const COLS: usize,
            > MutTensorView<'a, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
        {
            /// Returns a read-only [`TensorView`] that reborrows both views of `self`.
            pub fn view(
                &self,
            ) -> TensorView<'_, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS> {
                TensorView {
                    elem_view: self.elem_view_mut.view(),
                    scalar_view: self.scalar_view_mut.view(),
                }
            }

            /// Wraps a mutable element view and builds the matching scalar view over the same
            /// memory.
            ///
            /// The scalar view's shape is the element view's shape followed by
            /// `STensor::sdims()`. Its strides are the element strides scaled by
            /// `STensor::num_scalars()`, followed by `STensor::get_strides()`.
            ///
            /// # Panics
            ///
            /// Panics if `size_of::<STensor>() != size_of::<Scalar>() * ROWS * COLS`.
            pub fn new(
                mut elem_view_mut: ndarray::ArrayViewMut<
                    'a,
                    STensor,
                    ndarray::Dim<[ndarray::Ix; $drank]>,
                >,
            ) -> Self {
                use ndarray::ShapeBuilder;

                // The reinterpretation below is only meaningful if an `STensor` is exactly
                // `ROWS * COLS` scalars wide, so fail fast before doing any other work.
                assert_eq!(
                    core::mem::size_of::<STensor>(),
                    core::mem::size_of::<Scalar>() * ROWS * COLS
                );

                let dims: [usize; $drank] = elem_view_mut
                    .shape()
                    .try_into()
                    .expect("element view has exactly DRANK axes");
                // Presumably the `concat_arrays!` expansion trips this lint; the data is
                // plain `usize` arrays.
                #[allow(clippy::drop_non_drop)]
                let shape: [usize; $scalar_rank] = concat_arrays!(dims, STensor::sdims());

                let elem_strides: [isize; $drank] = elem_view_mut
                    .strides()
                    .try_into()
                    .expect("element view has exactly DRANK axes");
                // Scale while the stride is still signed, then reinterpret it as `usize` (the
                // form passed to `ShapeBuilder::strides`). Casting first and scaling afterwards
                // can overflow for negative strides, which panics in debug builds.
                let num_scalars = STensor::num_scalars() as isize;
                let dstrides: [usize; $drank] =
                    elem_strides.map(|stride| (stride * num_scalars) as usize);
                #[allow(clippy::drop_non_drop)]
                let strides = concat_arrays!(dstrides, STensor::get_strides());

                let ptr = elem_view_mut.as_mut_ptr().cast::<Scalar>();
                // SAFETY: `ptr` comes from `elem_view_mut`, which is mutably borrowed for `'a`.
                // The assertion above makes each `STensor` exactly `ROWS * COLS` scalars wide, and
                // the shape/strides are derived from the element view's own shape and strides.
                // Provided `STensor::sdims()` / `get_strides()` describe that in-element layout
                // (assumed, not checked here), every scalar index stays inside the borrowed
                // elements.
                // NOTE(review): for negative element strides, confirm that `from_shape_ptr`
                // treats `ptr` (the index-zero element) as intended by the ndarray version in use.
                // The new view aliases `elem_view_mut`; both end up in public fields.
                let scalar_view_mut =
                    unsafe { ndarray::ArrayViewMut::from_shape_ptr(shape.strides(strides), ptr) };

                Self {
                    elem_view_mut,
                    scalar_view_mut,
                }
            }

            /// Mutable reference to the scalar at the `TOTAL_RANK`-dimensional index `idx`.
            ///
            /// Borrows `self` only for the lifetime of the returned reference, so the view can
            /// be used again afterwards.
            ///
            /// # Panics
            ///
            /// Panics if `idx` is out of bounds.
            pub fn mut_scalar(&mut self, idx: [usize; $scalar_rank]) -> &mut Scalar {
                &mut self.scalar_view_mut[idx]
            }

            /// Calls `op(self_elem, other_elem)` for every element of `self` paired with the
            /// corresponding element of `view`, mutating `self` in place.
            ///
            /// Delegates to ndarray's `zip_mut_with`, which broadcasts `view` to the shape of
            /// `self`.
            ///
            /// # Panics
            ///
            /// Panics if `view` cannot be broadcast to the shape of `self`.
            pub fn map<
                'b,
                const OTHER_HRANK: usize,
                const OTHER_SRANK: usize,
                OtherScalar: IsCoreScalar + 'static,
                OtherSTensor: IsStaticTensor<OtherScalar, OTHER_SRANK, OTHER_ROWS, OTHER_COLS>
                    + 'static,
                const OTHER_ROWS: usize,
                const OTHER_COLS: usize,
                V: IsTensorView<
                    'b,
                    OTHER_HRANK,
                    $drank,
                    OTHER_SRANK,
                    OtherScalar,
                    OtherSTensor,
                    OTHER_ROWS,
                    OTHER_COLS,
                >,
                F: FnMut(&mut STensor, &OtherSTensor),
            >(
                &mut self,
                view: &'b V,
                op: F,
            ) where
                ndarray::Dim<[ndarray::Ix; OTHER_HRANK]>: ndarray::Dimension,
            {
                self.elem_view_mut.zip_mut_with(&view.elem_view(), op);
            }
        }

        impl<
                'a,
                Scalar: IsCoreScalar + 'static,
                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
                const ROWS: usize,
                const COLS: usize,
            > IsTensorLike<'a, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
            for MutTensorView<'a, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
        {
            /// Read-only element-level view (reborrowed from the mutable one).
            fn elem_view<'b: 'a>(
                &'b self,
            ) -> ndarray::ArrayView<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>> {
                self.view().elem_view
            }

            /// Element at `idx`, obtained through the read-only view.
            fn get(&self, idx: [usize; $drank]) -> STensor {
                self.view().get(idx)
            }

            /// Extent of each of the `DRANK` element axes.
            fn dims(&self) -> [usize; $drank] {
                self.view().dims()
            }

            /// Read-only scalar-level view (reborrowed from the mutable one).
            fn scalar_view<'b: 'a>(
                &'b self,
            ) -> ndarray::ArrayView<'a, Scalar, ndarray::Dim<[ndarray::Ix; $scalar_rank]>> {
                self.view().scalar_view
            }

            /// Scalar at `idx`, obtained through the read-only view.
            fn scalar_get(&'a self, idx: [usize; $scalar_rank]) -> Scalar {
                self.view().scalar_get(idx)
            }

            /// Extent of each of the `TOTAL_RANK` scalar axes.
            fn scalar_dims(&self) -> [usize; $scalar_rank] {
                self.view().scalar_dims()
            }

            /// Copies the viewed elements into a new owning [`MutTensor`].
            fn to_mut_tensor(
                &self,
            ) -> MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS> {
                MutTensor {
                    mut_array: self.elem_view_mut.to_owned(),
                    phantom: PhantomData,
                }
            }
        }

        impl<
                'a,
                Scalar: IsCoreScalar + 'static,
                STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
                const ROWS: usize,
                const COLS: usize,
            > IsMutTensorLike<'a, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
            for MutTensorView<'a, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
        {
            /// Mutable reborrow of the element-level view.
            fn elem_view_mut<'b: 'a>(
                &'b mut self,
            ) -> ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>> {
                self.elem_view_mut.view_mut()
            }

            /// Mutable reborrow of the scalar-level view.
            fn scalar_view_mut<'b: 'a>(
                &'b mut self,
            ) -> ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; $scalar_rank]>> {
                self.scalar_view_mut.view_mut()
            }

            /// Mutable reference to the element at `idx`; panics if out of bounds.
            fn get_mut(&'a mut self, idx: [usize; $drank]) -> &'a mut STensor {
                &mut self.elem_view_mut[idx]
            }
        }
    };
}
264
265mut_view_is_view!(1, 0, 1);
266mut_view_is_view!(2, 0, 2);
267mut_view_is_view!(2, 1, 1);
268mut_view_is_view!(3, 0, 3);
269mut_view_is_view!(3, 1, 2);
270mut_view_is_view!(3, 2, 1);
271mut_view_is_view!(4, 0, 4);
272mut_view_is_view!(4, 1, 3);
273mut_view_is_view!(4, 2, 2);
274mut_view_is_view!(5, 0, 5);
275mut_view_is_view!(5, 1, 4);
276mut_view_is_view!(5, 2, 3);