1use core::{
2 fmt::Debug,
3 marker::PhantomData,
4};
5
6use ndarray::{
7 Dim,
8 Ix,
9};
10use sophus_autodiff::linalg::{
11 SMat,
12 SVec,
13};
14
15use crate::{
16 prelude::*,
17 ArcTensor,
18 MutTensorView,
19 TensorView,
20};
21
/// Owned, mutable tensor: an `ndarray` with `DRANK` dynamic dimensions whose elements are
/// statically-sized `STensor` values built from `Scalar`.
///
/// Const/type parameters:
/// - `TOTAL_RANK`: rank of the scalar view (`scalar_view`) of the tensor; every alias in this
///   file sets it to `DRANK + SRANK`.
/// - `DRANK`: number of dynamic dimensions of `mut_array`.
/// - `SRANK`: static rank of one element (the aliases use 0 for scalars, 1 for `SVec`, 2 for
///   `SMat`).
/// - `Scalar`: the underlying scalar type.
/// - `STensor`: the element type stored in `mut_array`.
/// - `ROWS`, `COLS`: static element shape (`1, 1` for scalars, `R, 1` for vectors, `R, C` for
///   matrices in the aliases).
#[derive(Default, Debug, Clone)]
pub struct MutTensor<
    const TOTAL_RANK: usize,
    const DRANK: usize,
    const SRANK: usize,
    Scalar: IsCoreScalar + 'static,
    STensor: IsStaticTensor<Scalar, SRANK, ROWS, COLS> + 'static,
    const ROWS: usize,
    const COLS: usize,
> where
    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
{
    /// Owned element storage, indexed by `DRANK` dynamic indices.
    pub mut_array: ndarray::Array<STensor, Dim<[Ix; DRANK]>>,
    /// Zero-sized marker tying `Scalar` and `STensor` to the type.
    pub phantom: PhantomData<(Scalar, STensor)>,
}
42
/// Conversion from a tensor whose elements are `ROWS`-vectors to a tensor whose elements are
/// single-column matrices, raising the static rank by one.
///
/// The rank parameters describe the source tensor (`TOTAL_RANK`, `DRANK`, `SRANK`) and the
/// result (`HYBER_RANK_PLUS1`, `SRANK_PLUS1`); presumably the result's total and static rank —
/// confirm against the impls before relying on the exact meaning.
pub trait InnerVecToMat<
    const TOTAL_RANK: usize,
    const DRANK: usize,
    const SRANK: usize,
    const HYBER_RANK_PLUS1: usize,
    const SRANK_PLUS1: usize,
    Scalar: IsCoreScalar + 'static,
    const ROWS: usize,
> where
    // A `ROWS`-vector must itself be a static tensor of rank `SRANK_PLUS1` with shape ROWS x 1.
    SVec<Scalar, ROWS>: IsStaticTensor<Scalar, SRANK_PLUS1, ROWS, 1>,
    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
{
    /// The tensor type produced by the conversion.
    type Output;

    /// Consumes `self` and returns the matrix-valued tensor.
    fn inner_vec_to_mat(self) -> Self::Output;
}
62
/// Conversion from a tensor of plain scalars to a tensor of one-element vectors, raising the
/// static rank by one.
///
/// The rank parameters describe the source tensor (`TOTAL_RANK`, `DRANK`, `SRANK`) and the
/// result (`HYBER_RANK_PLUS1`, `SRANK_PLUS1`); presumably the result's total and static rank —
/// confirm against the impls before relying on the exact meaning.
pub trait InnerScalarToVec<
    const TOTAL_RANK: usize,
    const DRANK: usize,
    const SRANK: usize,
    const HYBER_RANK_PLUS1: usize,
    const SRANK_PLUS1: usize,
    Scalar: IsCoreScalar + 'static,
> where
    // A one-element vector must be a static tensor of rank `SRANK_PLUS1` with shape 1 x 1.
    SVec<Scalar, 1>: IsStaticTensor<Scalar, SRANK_PLUS1, 1, 1>,
    ndarray::Dim<[ndarray::Ix; DRANK]>: ndarray::Dimension,
{
    /// The tensor type produced by the conversion.
    type Output;

    /// Consumes `self` and returns the vector-valued tensor.
    fn inner_scalar_to_vec(self) -> Self::Output;
}
81
82impl<Scalar: IsCoreScalar + 'static, const ROWS: usize> InnerVecToMat<3, 1, 2, 4, 2, Scalar, ROWS>
83 for MutTensorXR<3, 2, 1, Scalar, ROWS>
84{
85 type Output = MutTensorXRC<4, 2, 2, Scalar, ROWS, 1>;
86
87 fn inner_vec_to_mat(self) -> MutTensorXRC<4, 2, 2, Scalar, ROWS, 1> {
88 MutTensorXRC::<4, 2, 2, Scalar, ROWS, 1> {
89 mut_array: self.mut_array,
90 phantom: PhantomData,
91 }
92 }
93}
94
95impl<Scalar: IsCoreScalar + 'static> InnerScalarToVec<2, 0, 2, 3, 1, Scalar>
96 for MutTensorX<2, Scalar>
97{
98 type Output = MutTensorXR<3, 2, 1, Scalar, 1>;
99
100 fn inner_scalar_to_vec(self) -> MutTensorXR<3, 2, 1, Scalar, 1> {
101 MutTensorXR::<3, 2, 1, Scalar, 1> {
102 mut_array: self.mut_array.map(|x| SVec::<Scalar, 1>::new(x.clone())),
103 phantom: PhantomData,
104 }
105 }
106}
107
/// Mutable tensor of plain scalars with `DRANK` dynamic dimensions and no static dimensions.
pub type MutTensorX<const DRANK: usize, Scalar> = MutTensor<DRANK, DRANK, 0, Scalar, Scalar, 1, 1>;

/// Mutable tensor whose elements are `R`-dimensional vectors (`SVec<Scalar, R>`, shape R x 1).
pub type MutTensorXR<
    const TOTAL_RANK: usize,
    const DRANK: usize,
    const SRANK: usize,
    Scalar,
    const R: usize,
> = MutTensor<TOTAL_RANK, DRANK, SRANK, Scalar, SVec<Scalar, R>, R, 1>;

/// Mutable tensor whose elements are `R x C` matrices (`SMat<Scalar, R, C>`).
pub type MutTensorXRC<
    const TOTAL_RANK: usize,
    const DRANK: usize,
    const SRANK: usize,
    Scalar,
    const R: usize,
    const C: usize,
> = MutTensor<TOTAL_RANK, DRANK, SRANK, Scalar, SMat<Scalar, R, C>, R, C>;
129
// Naming scheme: each `D` is one dynamic dimension, a trailing `R` means vector elements of
// length `R`, a trailing `RC` means `R x C` matrix elements.

/// One dynamic dimension of scalars.
pub type MutTensorD<Scalar> = MutTensorX<1, Scalar>;

/// Two dynamic dimensions of scalars.
pub type MutTensorDD<Scalar> = MutTensorX<2, Scalar>;

/// One dynamic dimension of `R`-vectors.
pub type MutTensorDR<Scalar, const R: usize> = MutTensorXR<2, 1, 1, Scalar, R>;

/// Three dynamic dimensions of scalars.
pub type MutTensorDDD<Scalar> = MutTensorX<3, Scalar>;

/// Two dynamic dimensions of `R`-vectors.
pub type MutTensorDDR<Scalar, const R: usize> = MutTensorXR<3, 2, 1, Scalar, R>;

/// One dynamic dimension of `R x C` matrices.
pub type MutTensorDRC<Scalar, const R: usize, const C: usize> = MutTensorXRC<3, 1, 2, Scalar, R, C>;

/// Four dynamic dimensions of scalars.
pub type MutTensorDDDD<Scalar> = MutTensorX<4, Scalar>;

/// Three dynamic dimensions of `R`-vectors.
pub type MutTensorDDDR<Scalar, const R: usize> = MutTensorXR<4, 3, 1, Scalar, R>;

/// Two dynamic dimensions of `R x C` matrices.
pub type MutTensorDDRC<Scalar, const R: usize, const C: usize> =
    MutTensorXRC<4, 2, 2, Scalar, R, C>;

/// Five dynamic dimensions of scalars.
pub type MutTensorDDDDD<Scalar> = MutTensorX<5, Scalar>;

/// Four dynamic dimensions of `R`-vectors.
pub type MutTensorDDDDR<Scalar, const R: usize> = MutTensorXR<5, 4, 1, Scalar, R>;

/// Three dynamic dimensions of `R x C` matrices.
pub type MutTensorDDDRC<Scalar, const R: usize, const C: usize> =
    MutTensorXRC<5, 3, 2, Scalar, R, C>;
167
168macro_rules! mut_tensor_is_view {
169 ($scalar_rank:literal, $srank:literal, $drank:literal) => {
170
171
172 impl<
173 'a,
174 Scalar: IsCoreScalar + 'static,
175 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
176 const ROWS: usize,
177 const COLS: usize,
178 > IsTensorLike<'a, $scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
179 for MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
180 {
181 fn elem_view<'b:'a>(
182 &'b self,
183 ) -> ndarray::ArrayView<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>> {
184 self.view().elem_view
185 }
186
187 fn get(& self, idx: [usize; $drank]) -> STensor {
188 self.view().get(idx)
189 }
190
191 fn dims(&self) -> [usize; $drank] {
192 self.view().dims()
193 }
194
195 fn scalar_view<'b:'a>(
196 &'b self,
197 ) -> ndarray::ArrayView<'a, Scalar, ndarray::Dim<[ndarray::Ix; $scalar_rank]>> {
198 self.view().scalar_view
199 }
200
201 fn scalar_get(&'a self, idx: [usize; $scalar_rank]) -> Scalar {
202 self.view().scalar_get(idx)
203 }
204
205 fn scalar_dims(&self) -> [usize; $scalar_rank] {
206 self.view().scalar_dims()
207 }
208
209 fn to_mut_tensor(
210 &self,
211 ) -> MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS> {
212 MutTensor {
213 mut_array: self.elem_view().to_owned(),
214 phantom: PhantomData::default(),
215 }
216 }
217 }
218
219 impl<
220 'a,
221 Scalar: IsCoreScalar + 'static,
222 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
223 const ROWS: usize,
224 const COLS: usize,
225
226 >
227 IsMutTensorLike<'a,
228 $scalar_rank, $drank, $srank,
229 Scalar, STensor,
230 ROWS, COLS
231 >
232 for MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
233 {
234 fn elem_view_mut<'b:'a>(
235 &'b mut self,
236 ) -> ndarray::ArrayViewMut<'a, STensor, ndarray::Dim<[ndarray::Ix; $drank]>>{
237 self.mut_view().elem_view_mut
238 }
239 fn get_mut(& mut self, idx: [usize; $drank]) -> &mut STensor{
240 &mut self.mut_array[idx]
241 }
242
243 fn scalar_view_mut<'b:'a>(
244 &'b mut self,
245 ) -> ndarray::ArrayViewMut<'a, Scalar, ndarray::Dim<[ndarray::Ix; $scalar_rank]>>{
246 self.mut_view().scalar_view_mut
247 }
248 }
249
250 impl<'a, Scalar: IsCoreScalar+ 'static,
251 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
252 const ROWS: usize,
253 const COLS: usize,
254
255 > PartialEq for
256 MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
257 {
258 fn eq(&self, other: &Self) -> bool {
259 self.view().scalar_view == other.view().scalar_view
260 }
261 }
262
263 impl<'a, Scalar: IsCoreScalar+ 'static,
264 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
265 const ROWS: usize,
266 const COLS: usize,
267
268 >
269 MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
270 {
271
272 pub fn from_shape(size: [usize; $drank]) -> Self {
274 MutTensor::<$scalar_rank, $drank, $srank, Scalar, STensor,
275 ROWS, COLS>::from_shape_and_val(
276 size, num_traits::Zero::zero()
277 )
278 }
279
280 pub fn from_map2<
283 'b,
284 const OTHER_HRANK: usize, const OTHER_SRANK: usize,
285 OtherScalar: IsCoreScalar + 'static,
286 OtherSTensor: IsStaticTensor<
287 OtherScalar, OTHER_SRANK, OTHER_ROWS, OTHER_COLS
288 > + 'static,
289 const OTHER_ROWS: usize, const OTHER_COLS: usize,
290 V : IsTensorView::<'b,
291 OTHER_HRANK, $drank, OTHER_SRANK,
292 OtherScalar, OtherSTensor,
293 OTHER_ROWS, OTHER_COLS
294 >,
295 const OTHER_HRANK2: usize, const OTHER_SRANK2: usize,
296 OtherScalar2: IsCoreScalar + 'static,
297 OtherSTensor2: IsStaticTensor<
298 OtherScalar2, OTHER_SRANK2, OTHER_ROWS2, OTHER_COLS2,
299 > + 'static,
300 const OTHER_ROWS2: usize, const OTHER_COLS2: usize,
301 V2 : IsTensorView::<'b,
302 OTHER_HRANK2, $drank, OTHER_SRANK2,
303 OtherScalar2, OtherSTensor2,
304 OTHER_ROWS2, OTHER_COLS2
305 >,
306 F: FnMut(&OtherSTensor, &OtherSTensor2)->STensor
307 >(
308 view: &'b V,
309 view2: &'b V2,
310 mut op: F,
311 )
312 -> Self
313 where
314 ndarray::Dim<[ndarray::Ix; OTHER_HRANK]>: ndarray::Dimension,
315 ndarray::Dim<[ndarray::Ix; OTHER_HRANK2]>: ndarray::Dimension
316
317 {
318 let mut out = Self::from_shape(view.dims());
319 ndarray::Zip::from(&mut out.elem_view_mut())
320 .and(&view.elem_view())
321 .and(&view2.elem_view())
322 .for_each(
323 |out, v, v2|{
324 *out = op(v, v2);
325 });
326 out
327 }
328 }
329
330 impl<'a, Scalar: IsCoreScalar+ 'static,
331 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
332 const ROWS: usize,
333 const COLS: usize,
334
335 >
336 MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
337 {
338
339
340 pub fn mut_view<'b: 'a>(
342 &'b mut self,
343 ) -> MutTensorView<'a,
344 $scalar_rank, $drank, $srank,
345 Scalar, STensor,
346 ROWS, COLS>
347 {
348 MutTensorView::<
349 'a,
350 $scalar_rank, $drank, $srank,
351 Scalar, STensor, ROWS, COLS>::new
352 (
353 self.mut_array.view_mut()
354 )
355 }
356
357 pub fn view<'b: 'a>(&'b self
359 ) -> TensorView<'a, $scalar_rank, $drank, $srank, Scalar, STensor,
360 ROWS, COLS> {
361 TensorView::<'a, $scalar_rank, $drank, $srank, Scalar, STensor,
362 ROWS, COLS>::new(
363 self.mut_array.view())
364 }
365
366
367 pub fn from_shape_and_val
369 (
370 shape: [usize; $drank],
371 val: STensor,
372 ) -> Self
373 {
374 Self{
375 mut_array: ndarray::Array::<STensor, Dim<[Ix; $drank]>>::from_elem(shape, val),
376 phantom: PhantomData::default()
377 }
378 }
379
380 pub fn make_copy_from(
382 v: &TensorView<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
383 ) -> Self
384 {
385 IsTensorLike::to_mut_tensor(v)
386 }
387
388 pub fn to_shared(self)
390 -> ArcTensor::<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
391 {
392 ArcTensor::<
393 $scalar_rank,
394 $drank, $srank,
395 Scalar, STensor,
396 ROWS, COLS>::from_mut_tensor(self)
397 }
398
399 pub fn from_map<
402 'b,
403 const OTHER_HRANK: usize, const OTHER_SRANK: usize,
404 OtherScalar: IsCoreScalar+ 'static,
405 OtherSTensor: IsStaticTensor<
406 OtherScalar, OTHER_SRANK,
407 OTHER_ROWS, OTHER_COLS
408 > + 'static,
409 const OTHER_ROWS: usize, const OTHER_COLS: usize,
410 V : IsTensorView::<
411 'b,
412 OTHER_HRANK, $drank, OTHER_SRANK,
413 OtherScalar, OtherSTensor,
414 OTHER_ROWS, OTHER_COLS
415 >,
416 F: FnMut(&OtherSTensor)-> STensor
417 > (
418 view: &'b V,
419 op: F,
420 )
421 -> Self where
422 ndarray::Dim<[ndarray::Ix; OTHER_HRANK]>: ndarray::Dimension,
423 ndarray::Dim<[ndarray::Ix; $drank]>: ndarray::Dimension,
424 {
425 Self {
426 mut_array: view.elem_view().map(op),
427 phantom: PhantomData::default()
428 }
429 }
430
431
432
433
434
435 }
436 };
437}
438
439mut_tensor_is_view!(1, 0, 1);
440mut_tensor_is_view!(2, 0, 2);
441mut_tensor_is_view!(2, 1, 1);
442mut_tensor_is_view!(3, 0, 3);
443mut_tensor_is_view!(3, 1, 2);
444mut_tensor_is_view!(3, 2, 1);
445mut_tensor_is_view!(4, 0, 4);
446mut_tensor_is_view!(4, 1, 3);
447mut_tensor_is_view!(4, 2, 2);
448mut_tensor_is_view!(5, 0, 5);
449mut_tensor_is_view!(5, 1, 4);
450mut_tensor_is_view!(5, 2, 3);
451
452macro_rules! mut_tensor_is_view_drank_1 {
453 ($scalar_rank:literal, $srank:literal) => {
454 impl<
455 'a,
456 Scalar: IsCoreScalar + 'static,
457 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
458 const ROWS: usize,
459 const COLS: usize,
460 > MutTensor<$scalar_rank, 1, $srank, Scalar, STensor, ROWS, COLS>
461 {
462 pub fn from_fn<F: FnMut([usize; 1]) -> STensor>(shape: [usize; 1], mut op: F) -> Self {
464 Self {
465 mut_array: ndarray::Array::<STensor, Dim<[Ix; 1]>>::from_shape_fn(
466 shape,
467 |idx| op([idx]),
468 ),
469 phantom: PhantomData::default(),
470 }
471 }
472 }
473 };
474}
475
476mut_tensor_is_view_drank_1!(1, 0);
477mut_tensor_is_view_drank_1!(2, 1);
478mut_tensor_is_view_drank_1!(3, 2);
479
480macro_rules! mut_tensor_is_view_drank_2_plus {
481 ($scalar_rank:literal, $srank:literal, $drank:literal) => {
482 impl<
483 'a,
484 Scalar: IsCoreScalar + 'static,
485 STensor: IsStaticTensor<Scalar, $srank, ROWS, COLS> + 'static,
486 const ROWS: usize,
487 const COLS: usize,
488 > MutTensor<$scalar_rank, $drank, $srank, Scalar, STensor, ROWS, COLS>
489 {
490 pub fn from_fn<F: FnMut([usize; $drank]) -> STensor>(
492 shape: [usize; $drank],
493 mut op: F,
494 ) -> Self {
495 Self {
496 mut_array: ndarray::Array::<STensor, Dim<[Ix; $drank]>>::from_shape_fn(
497 shape,
498 |idx| op(idx.try_into().unwrap()),
499 ),
500 phantom: PhantomData::default(),
501 }
502 }
503 }
504 };
505}
506
507mut_tensor_is_view_drank_2_plus!(2, 0, 2);
508mut_tensor_is_view_drank_2_plus!(3, 0, 3);
509mut_tensor_is_view_drank_2_plus!(3, 1, 2);
510mut_tensor_is_view_drank_2_plus!(4, 0, 4);
511mut_tensor_is_view_drank_2_plus!(4, 1, 3);
512mut_tensor_is_view_drank_2_plus!(4, 2, 2);
513mut_tensor_is_view_drank_2_plus!(5, 0, 5);
514mut_tensor_is_view_drank_2_plus!(5, 1, 4);
515mut_tensor_is_view_drank_2_plus!(5, 2, 3);
516
#[test]
fn mut_tensor_tests() {
    use log::info;
    #[cfg(feature = "simd")]
    use sophus_autodiff::linalg::BatchMatF64;

    /// Lifts a scalar `v` into the 3-vector `[v, 0.2 * v, 0.3 * v]`.
    fn lift(v: &f32) -> SVec<f32, 3> {
        let mut value = SVec::<f32, 3>::default();
        value[0] = *v;
        value[1] = 0.2 * *v;
        value[2] = 0.3 * *v;
        value
    }

    // Construction from a shape; `dims()` must report that shape back.
    {
        let _rank1_tensor = MutTensorD::<u8>::default();
        let shape = [2];
        let tensor_f32 = MutTensorD::from_shape_and_val(shape, 0.0);
        assert_eq!(tensor_f32.view().dims(), shape);
    }
    {
        let _rank2_tensor = MutTensorDD::<u8>::default();
        let shape = [3, 2];
        let tensor_f32 = MutTensorDD::<f32>::from_shape(shape);
        assert_eq!(tensor_f32.view().dims(), shape);
    }
    {
        let _rank3_tensor = MutTensorDDD::<u8>::default();
        let shape = [3, 2, 4];
        let tensor_f32 = MutTensorDDD::<f32>::from_shape(shape);
        assert_eq!(tensor_f32.view().dims(), shape);
    }

    // `from_map`: lift scalar tensors into vector-valued tensors and check that the shape is
    // preserved and the elements are the lifted values.
    {
        let expected = lift(&1.0);

        let shape = [3];
        let tensor_f32 = MutTensorD::from_shape_and_val(shape, 1.0);
        let pattern = MutTensorDR::<f32, 3>::from_map(&tensor_f32.view(), lift);
        info!("p :{}", pattern.mut_array);
        assert_eq!(pattern.view().dims(), shape);
        for i in 0..shape[0] {
            for (r, &e) in expected.iter().enumerate() {
                assert_eq!(pattern.view().scalar_get([i, r]), e);
            }
        }

        let shape = [3, 2];
        let tensor_f32 = MutTensorDD::from_shape_and_val(shape, 1.0);
        let pattern = MutTensorDDR::from_map(&tensor_f32.view(), lift);
        info!("p :{}", pattern.mut_array);
        info!("p :{}", pattern.view().scalar_view());
        assert_eq!(pattern.view().dims(), shape);
        assert_eq!(pattern.view().scalar_get([2, 1, 2]), expected[2]);

        let shape = [3, 2, 4];
        let tensor_f32 = MutTensorDDD::from_shape_and_val(shape, 1.0);
        let pattern = MutTensorDDDR::from_map(&tensor_f32.view(), lift);
        info!("p :{}", pattern.mut_array);
        info!("p :{}", pattern.view().scalar_view());
        assert_eq!(pattern.view().dims(), shape);
        assert_eq!(pattern.view().scalar_get([2, 1, 3, 0]), expected[0]);
    }

    #[cfg(feature = "simd")]
    {
        let shape = [3];

        let _tensor_u8 = MutTensorD::from_shape_and_val(shape, 0);
        let _tensor_f64 = MutTensorDRC::from_shape_and_val(shape, SMat::<f64, 4, 4>::zeros());
        let _tensor_batched_f32 =
            MutTensorDRC::from_shape_and_val(shape, BatchMatF64::<2, 3, 4>::zeros());
    }

    // Matrix elements: scalar index is `[dynamic.., row, col]` over column-major storage.
    {
        let shape = [1];
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let data_mat = SMat::<f32, 3, 2>::from_vec(data.to_vec());
        let tensor_f32 = MutTensorDRC::from_shape_and_val(shape, data_mat);
        assert_eq!(tensor_f32.dims(), shape);
        assert_eq!(tensor_f32.view().scalar_get([0, 0, 0]), data[0]);
        assert_eq!(tensor_f32.view().scalar_get([0, 1, 0]), data[1]);
        assert_eq!(tensor_f32.view().scalar_get([0, 2, 0]), data[2]);
        assert_eq!(tensor_f32.view().scalar_get([0, 0, 1]), data[3]);
        assert_eq!(tensor_f32.view().scalar_get([0, 1, 1]), data[4]);
        assert_eq!(tensor_f32.view().scalar_get([0, 2, 1]), data[5]);
    }

    // `from_fn`: every element is computed from its own index (rank 1 and rank 2 paths).
    {
        let ramp = MutTensorD::<f32>::from_fn([4], |[i]: [usize; 1]| i as f32);
        assert_eq!(ramp.view().dims(), [4]);
        assert_eq!(ramp.view().scalar_get([3]), 3.0);

        let grid = MutTensorDD::<f32>::from_fn([3, 2], |[i, j]: [usize; 2]| (10 * i + j) as f32);
        assert_eq!(grid.view().dims(), [3, 2]);
        assert_eq!(grid.view().scalar_get([0, 1]), 1.0);
        assert_eq!(grid.view().scalar_get([2, 1]), 21.0);
    }

    // `from_map2`: element-wise combination of two equally shaped tensors.
    {
        let grid = MutTensorDD::<f32>::from_fn([3, 2], |[i, j]: [usize; 2]| (10 * i + j) as f32);
        let half = MutTensorDD::<f32>::from_shape_and_val([3, 2], 0.5);
        let sum =
            MutTensorDD::<f32>::from_map2(&grid.view(), &half.view(), |a: &f32, b: &f32| a + b);
        assert_eq!(sum.view().dims(), [3, 2]);
        assert_eq!(sum.view().scalar_get([0, 0]), 0.5);
        assert_eq!(sum.view().scalar_get([2, 1]), 21.5);
    }
}