//! `slop_tensor` — inner.rs: the core [`Tensor`] type and its borrowed views.

1use std::{
2    marker::PhantomData,
3    mem::ManuallyDrop,
4    ops::{Index, IndexMut},
5};
6
7use derive_where::derive_where;
8use rand::{distributions::Standard, prelude::Distribution, Rng};
9use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
10use slop_algebra::{ExtensionField, Field};
11use slop_alloc::{
12    Backend, Buffer, CpuBackend, HasBackend, Init, TryReserveError, GLOBAL_CPU_BACKEND,
13};
14use slop_matrix::Matrix;
15
16use crate::{Dimensions, DimensionsError};
17
/// A dense tensor: a flat element [`Buffer`] plus [`Dimensions`] (sizes and strides)
/// describing how it is indexed.
///
/// `A` is the allocation backend; it defaults to [`CpuBackend`]. Equality is delegated to
/// the storage buffer via `derive_where`. The tests in this file index as
/// `[row, col, ...]` with unit innermost stride, i.e. row-major contiguous layout.
#[derive(Debug, Clone)]
#[derive_where(PartialEq, Eq; Buffer<T, A>)]
pub struct Tensor<T, A: Backend = CpuBackend> {
    // Flat element storage.
    pub storage: Buffer<T, A>,
    // Shape metadata (sizes and strides) used to map multi-indices into `storage`.
    pub dimensions: Dimensions,
}
24
25impl<T, A: Backend> Tensor<T, A> {
26    #[inline]
27    pub fn with_sizes_in(sizes: impl AsRef<[usize]>, allocator: A) -> Self {
28        Self::try_with_sizes_in(sizes, allocator).unwrap()
29    }
30
31    #[inline]
32    pub fn zeros_in(sizes: impl AsRef<[usize]>, allocator: A) -> Self {
33        let mut tensor = Self::with_sizes_in(sizes, allocator);
34        tensor.storage.write_bytes(0, tensor.total_len() * std::mem::size_of::<T>()).unwrap();
35        tensor
36    }
37
38    #[inline]
39    pub fn zeros_in_with_total_capacity(sizes: impl AsRef<[usize]>, allocator: A) -> Self {
40        let mut tensor = Self::with_sizes_in(sizes, allocator);
41        tensor.storage.write_bytes(0, tensor.total_len() * std::mem::size_of::<T>()).unwrap();
42        tensor
43    }
44
45    #[inline]
46    pub fn try_with_sizes_in(
47        sizes: impl AsRef<[usize]>,
48        allocator: A,
49    ) -> Result<Self, TryReserveError> {
50        let dimensions = Dimensions::try_from(sizes.as_ref()).unwrap();
51        Ok(Self {
52            storage: Buffer::try_with_capacity_in(dimensions.total_len(), allocator)?,
53            dimensions,
54        })
55    }
56
57    #[track_caller]
58    pub fn reshape_in_place(&mut self, sizes: impl AsRef<[usize]>) {
59        #[cold]
60        #[track_caller]
61        #[inline(never)]
62        fn dimension_fail(new_dimensions: &Dimensions, old_dimensions: &Dimensions) -> ! {
63            panic!(
64                "TensorView::reshape: dimension mismatch: {new_dimensions:?} vs {old_dimensions:?}"
65            );
66        }
67
68        let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
69        if self.dimensions.compatible(&dimensions).is_err() {
70            dimension_fail(&dimensions, &self.dimensions);
71        }
72        self.dimensions = dimensions;
73    }
74
75    #[inline]
76    #[track_caller]
77    pub fn reshape(mut self, sizes: impl AsRef<[usize]>) -> Self {
78        #[cold]
79        #[track_caller]
80        #[inline(never)]
81        fn dimension_fail(new_dimensions: &Dimensions, old_dimensions: &Dimensions) -> ! {
82            panic!(
83                "TensorView::reshape: dimension mismatch: {new_dimensions:?} vs {old_dimensions:?}"
84            );
85        }
86
87        let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
88        if self.dimensions.compatible(&dimensions).is_err() {
89            dimension_fail(&dimensions, &self.dimensions);
90        }
91        self.dimensions = dimensions;
92        self
93    }
94
95    /// # Safety
96    ///
97    /// The caller must ensure that the new dimensions are compatible with the existing dimensions.
98    #[inline]
99    pub unsafe fn reshape_unchecked(mut self, dimensions: Dimensions) {
100        self.dimensions = dimensions;
101    }
102
103    #[inline]
104    pub fn flatten_in_place(&mut self) {
105        self.reshape_in_place([self.dimensions.total_len()]);
106    }
107
108    #[inline]
109    pub fn flatten(mut self) -> Self {
110        self.flatten_in_place();
111        self
112    }
113
114    #[inline]
115    pub fn into_buffer(self) -> Buffer<T, A> {
116        self.storage
117    }
118
119    #[inline]
120    pub fn as_buffer(&self) -> &Buffer<T, A> {
121        &self.storage
122    }
123
124    #[inline]
125    pub fn as_mut_buffer(&mut self) -> &mut Buffer<T, A> {
126        &mut self.storage
127    }
128
129    #[inline]
130    pub fn backend(&self) -> &A {
131        self.storage.allocator()
132    }
133
134    #[inline]
135    pub fn shape(&self) -> &Dimensions {
136        &self.dimensions
137    }
138
139    /// Returns the dimensions of the tensor.
140    #[inline]
141    pub fn sizes(&self) -> &[usize] {
142        self.dimensions.sizes()
143    }
144
145    #[inline]
146    pub fn strides(&self) -> &[usize] {
147        self.dimensions.strides()
148    }
149
150    #[inline]
151    pub fn as_ptr(&self) -> *const T {
152        self.storage.as_ptr()
153    }
154
155    /// # Safety
156    ///
157    /// This function is unsafe because it enables bypassing the lifetime of the tensor.
158    #[inline]
159    pub unsafe fn owned_unchecked(&self) -> ManuallyDrop<Self> {
160        self.owned_unchecked_in(self.storage.allocator().clone())
161    }
162
163    /// # Safety
164    ///
165    /// This function is unsafe because it enables bypassing the lifetime of the tensor.
166    #[inline]
167    pub unsafe fn owned_unchecked_in(&self, storage_allocator: A) -> ManuallyDrop<Self> {
168        let dimensions = self.dimensions.clone();
169        let storage = self.storage.owned_unchecked_in(storage_allocator);
170        let storage = ManuallyDrop::into_inner(storage);
171        ManuallyDrop::new(Self { storage, dimensions })
172    }
173
174    #[inline]
175    pub fn total_len(&self) -> usize {
176        self.dimensions.total_len()
177    }
178
179    pub fn as_mut_ptr(&mut self) -> *mut T {
180        self.storage.as_mut_ptr()
181    }
182
183    #[inline]
184    pub fn as_view(&'_ self) -> TensorView<'_, T, A> {
185        TensorView {
186            ptr: self.as_ptr(),
187            dimensions: self.dimensions.clone(),
188            backend: self.backend().clone(),
189            _marker: PhantomData,
190        }
191    }
192
193    #[inline]
194    pub fn as_view_mut(&'_ mut self) -> TensorViewMut<'_, T, A> {
195        TensorViewMut {
196            ptr: self.as_mut_ptr(),
197            dimensions: self.dimensions.clone(),
198            _marker: PhantomData,
199        }
200    }
201
202    #[inline]
203    pub fn get(&'_ self, index: usize) -> Option<TensorView<'_, T, A>> {
204        self.as_view().get(index)
205    }
206
207    #[inline]
208    pub fn get_mut(&'_ mut self, index: usize) -> Option<TensorViewMut<'_, T, A>> {
209        self.as_view_mut().get(index)
210    }
211
212    #[inline]
213    pub fn split(&'_ self) -> impl Iterator<Item = TensorView<'_, T, A>> {
214        self.as_view().split()
215    }
216
217    #[inline]
218    pub fn split_mut(&'_ mut self) -> impl Iterator<Item = TensorViewMut<'_, T, A>> {
219        self.as_view_mut().split_mut()
220    }
221
222    /// # Safety
223    ///
224    /// See [std::mem::MaybeUninit::assume_init].
225    #[inline]
226    pub unsafe fn assume_init(&mut self) {
227        self.storage.set_len(self.storage.capacity());
228    }
229
230    pub fn flatten_to_base<F: Field>(self) -> Tensor<F, A>
231    where
232        T: ExtensionField<F>,
233    {
234        let [height, width]: [usize; 2] = self.sizes().try_into().unwrap();
235        let dimensions = Dimensions::try_from([height, T::D * width]).unwrap();
236        let data_storage = self.into_buffer().flatten_to_base();
237        Tensor { storage: data_storage, dimensions }
238    }
239
240    pub fn into_extension<ET: ExtensionField<T>>(self) -> Tensor<ET, A>
241    where
242        T: Field,
243    {
244        let [height, width]: [usize; 2] = self.sizes().try_into().unwrap();
245        let dimensions = Dimensions::try_from([height, width / ET::D]).unwrap();
246        let extension_storage = self.into_buffer().into_extension();
247        Tensor { storage: extension_storage, dimensions }
248    }
249}
250
251impl<T, A: Backend, I: AsRef<[usize]>> Index<I> for Tensor<T, A> {
252    type Output = Init<T, A>;
253
254    #[track_caller]
255    fn index(&self, index: I) -> &Self::Output {
256        #[cold]
257        #[track_caller]
258        #[inline(never)]
259        fn dimension_fail(index_len: usize, sizes_len: usize) -> ! {
260            panic!(
261                "Index length ({index_len}) does not match tensor dimensions length ({sizes_len})"
262            );
263        }
264
265        if index.as_ref().len() != self.dimensions.sizes().len() {
266            dimension_fail(index.as_ref().len(), self.dimensions.sizes().len());
267        }
268        let index = self.dimensions.index_map(index);
269        &self.storage[index]
270    }
271}
272
273impl<T, A: Backend, I: AsRef<[usize]>> IndexMut<I> for Tensor<T, A> {
274    fn index_mut(&mut self, index: I) -> &mut Self::Output {
275        let index = self.dimensions.index_map(index);
276        &mut self.storage[index]
277    }
278}
279
280impl<T, A: Backend> From<Buffer<T, A>> for Tensor<T, A> {
281    #[inline]
282    fn from(buffer: Buffer<T, A>) -> Self {
283        let dims = [buffer.len()].into_iter().collect();
284        Self { storage: buffer, dimensions: dims }
285    }
286}
287
288impl<T, A: Backend> HasBackend for Tensor<T, A> {
289    type Backend = A;
290
291    fn backend(&self) -> &Self::Backend {
292        self.backend()
293    }
294}
295
296impl<T> From<Vec<T>> for Tensor<T, CpuBackend> {
297    #[inline]
298    fn from(vec: Vec<T>) -> Self {
299        Self::from(Buffer::from(vec))
300    }
301}
302
303impl<T> FromIterator<T> for Tensor<T, CpuBackend> {
304    #[inline]
305    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
306        Self::from(iter.into_iter().collect::<Vec<_>>())
307    }
308}
309
impl<T: Clone + Send + Sync> From<slop_matrix::dense::RowMajorMatrix<T>> for Tensor<T, CpuBackend> {
    /// Converts a row-major matrix into a 2-D CPU tensor of shape `[height, width]`,
    /// reusing the matrix's value vector without copying elements.
    fn from(value: slop_matrix::dense::RowMajorMatrix<T>) -> Self {
        let dimensions: Dimensions = [value.height(), value.width()].try_into().unwrap();
        let storage = Buffer::from(value.values);
        Self { storage, dimensions }
    }
}
317
impl<T: Clone + Send + Sync> TryFrom<Tensor<T, CpuBackend>>
    for slop_matrix::dense::RowMajorMatrix<T>
{
    type Error = DimensionsError;
    /// Converts a 2-D CPU tensor back into a row-major matrix, reusing the storage.
    ///
    /// NOTE(review): any rank other than 2 — including rank 0 or 1 — is reported as
    /// `TooManyDimensions`, which is misleading for too-few dimensions; confirm whether a
    /// dedicated variant exists or is wanted.
    fn try_from(value: Tensor<T, CpuBackend>) -> Result<Self, Self::Error> {
        if value.sizes().len() != 2 {
            return Err(DimensionsError::TooManyDimensions(value.sizes().len()));
        }
        // Height is implied by `values.len() / width` inside `RowMajorMatrix::new`.
        let width = value.sizes()[1];
        let values = value.storage.into_vec();
        Ok(Self::new(values, width))
    }
}
331
impl<T> Tensor<T, CpuBackend> {
    /// Fills a tensor of the given shape with samples from the `Standard` distribution.
    pub fn rand<R: Rng>(rng: &mut R, sizes: impl AsRef<[usize]>) -> Self
    where
        Standard: Distribution<T>,
    {
        let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
        let values = rng.sample_iter(Standard).take(dimensions.total_len()).collect::<Vec<_>>();
        Self { storage: Buffer::from(values), dimensions }
    }

    /// Allocates an uninitialized CPU tensor using the global CPU backend.
    #[inline]
    pub fn with_sizes(sizes: impl AsRef<[usize]>) -> Self {
        Tensor::with_sizes_in(sizes, GLOBAL_CPU_BACKEND)
    }

    /// Borrows the flat element storage as a slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        &self.storage[..]
    }

    /// Mutably borrows the flat element storage as a slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.storage[..]
    }
}
357
/// A borrowed, immutable view over (part of) a [`Tensor`].
///
/// Stores a raw pointer plus shape metadata; the lifetime parameter ties the view to the
/// tensor it was created from.
#[derive(Debug)]
pub struct TensorView<'a, T, A: Backend = CpuBackend> {
    // First element the view covers.
    ptr: *const T,
    // Sizes and strides describing how `ptr` is indexed.
    dimensions: Dimensions,
    // Clone of the source tensor's backend handle.
    backend: A,
    /// Marker to ensure that the view is not used after the original tensor is freed.
    _marker: PhantomData<&'a Tensor<T, A>>,
}
366
367impl<'a, T, A: Backend> TensorView<'a, T, A> {
368    #[inline]
369    pub fn as_ptr(&self) -> *const T {
370        self.ptr
371    }
372
373    #[inline]
374    pub fn sizes(&self) -> &[usize] {
375        self.dimensions.sizes()
376    }
377
378    #[inline]
379    pub fn backend(&self) -> &A {
380        &self.backend
381    }
382
383    #[inline]
384    /// # Safety
385    ///
386    /// The caller must ensure that the pointer is valid for the given dimensions and backend.
387    pub unsafe fn from_raw_parts(ptr: *const T, dimensions: Dimensions, backend: A) -> Self {
388        Self { ptr, dimensions, backend, _marker: PhantomData }
389    }
390
391    #[inline]
392    pub fn strides(&self) -> &[usize] {
393        self.dimensions.strides()
394    }
395
396    #[inline]
397    pub fn total_len(&self) -> usize {
398        self.dimensions.total_len()
399    }
400
401    #[inline]
402    pub fn shape(&self) -> &Dimensions {
403        &self.dimensions
404    }
405
406    #[inline]
407    pub fn flatten(self) -> TensorView<'a, T, A> {
408        let total_len = self.total_len();
409        self.reshape([total_len])
410    }
411
412    #[inline]
413    #[track_caller]
414    pub fn reshape(self, sizes: impl AsRef<[usize]>) -> TensorView<'a, T, A> {
415        #[cold]
416        #[track_caller]
417        #[inline(never)]
418        fn dimension_fail(new_dimensions: &Dimensions, old_dimensions: &Dimensions) -> ! {
419            panic!(
420                "TensorView::reshape: dimension mismatch: {new_dimensions:?} vs {old_dimensions:?}"
421            );
422        }
423
424        let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
425        if self.dimensions.compatible(&dimensions).is_err() {
426            dimension_fail(&dimensions, &self.dimensions);
427        }
428        TensorView {
429            ptr: self.ptr,
430            dimensions,
431            backend: self.backend.clone(),
432            _marker: PhantomData,
433        }
434    }
435
436    #[inline]
437    pub fn get(mut self, index: usize) -> Option<Self> {
438        let size = self.dimensions.sizes_mut().remove(0);
439        if index >= size {
440            return None;
441        }
442        let stride = self.dimensions.strides_mut().remove(0);
443        let offset = index * stride;
444
445        let ptr = unsafe { self.ptr.add(offset) };
446        Some(Self {
447            ptr,
448            dimensions: self.dimensions,
449            backend: self.backend.clone(),
450            _marker: PhantomData,
451        })
452    }
453
454    pub fn split(self) -> impl Iterator<Item = Self> {
455        (0..self.dimensions.sizes()[0]).map(move |i| self.clone().get(i).unwrap())
456    }
457}
458
459impl<'a, T, A: Backend> Clone for TensorView<'a, T, A> {
460    fn clone(&self) -> Self {
461        Self {
462            ptr: self.ptr,
463            dimensions: self.dimensions.clone(),
464            backend: self.backend.clone(),
465            _marker: PhantomData,
466        }
467    }
468}
469
impl<'a, T, A: Backend> From<&'a Tensor<T, A>> for TensorView<'a, T, A> {
    /// Borrows the whole tensor as an immutable view.
    fn from(tensor: &'a Tensor<T, A>) -> Self {
        tensor.as_view()
    }
}
475
impl<'a, T, A: Backend, I: AsRef<[usize]>> Index<I> for TensorView<'a, T, A> {
    type Output = Init<T, A>;

    /// Multi-dimensional element access through the view.
    ///
    /// NOTE(review): unlike `Tensor`'s `Index` impl, there is no explicit rank check here;
    /// `index_map` is relied on to reject malformed indices — confirm it does.
    #[inline]
    fn index(&self, index: I) -> &Self::Output {
        let index = self.dimensions.index_map(index);
        // SAFETY(review): assumes `index_map` returns an in-bounds flat offset and that
        // `Init<T, A>` is layout-compatible with `T` — confirm against slop_alloc.
        unsafe {
            let ptr = self.ptr.add(index) as *const Init<T, A>;
            ptr.as_ref().unwrap()
        }
    }
}
488
489impl<T> Default for Tensor<T, CpuBackend> {
490    fn default() -> Self {
491        Self::from(Buffer::default())
492    }
493}
494
/// A borrowed, mutable view over (part of) a [`Tensor`].
#[derive(Debug)]
pub struct TensorViewMut<'a, T, A: Backend = CpuBackend> {
    // First element the view covers.
    ptr: *mut T,
    // Sizes and strides describing how `ptr` is indexed.
    dimensions: Dimensions,
    /// Marker to ensure that we get an exclusive reference, and that the view is not used
    /// after the original tensor is freed.
    _marker: PhantomData<&'a mut Tensor<T, A>>,
}
503
504impl<'a, T, A: Backend> TensorViewMut<'a, T, A> {
505    #[inline]
506    pub fn as_mut_ptr(&mut self) -> *mut T {
507        self.ptr
508    }
509
510    #[inline]
511    pub fn sizes(&self) -> &[usize] {
512        self.dimensions.sizes()
513    }
514
515    #[inline]
516    pub fn shape(&self) -> &Dimensions {
517        &self.dimensions
518    }
519
520    #[inline]
521    pub fn strides(&self) -> &[usize] {
522        self.dimensions.strides()
523    }
524
525    #[inline]
526    pub fn flatten(self) -> TensorViewMut<'a, T, A> {
527        let total_len = self.total_len();
528        self.reshape([total_len])
529    }
530
531    #[inline]
532    pub fn reshape(self, sizes: impl AsRef<[usize]>) -> TensorViewMut<'a, T, A> {
533        let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
534        self.dimensions.compatible(&dimensions).unwrap();
535        TensorViewMut { ptr: self.ptr, dimensions, _marker: PhantomData }
536    }
537
538    #[inline]
539    pub fn get(mut self, index: usize) -> Option<Self> {
540        let size = self.dimensions.sizes_mut().remove(0);
541        if index >= size {
542            return None;
543        }
544        let stride = self.dimensions.strides_mut().remove(0);
545        let offset = index * stride;
546
547        let ptr = unsafe { self.ptr.add(offset) };
548        Some(Self { ptr, dimensions: self.dimensions, _marker: PhantomData })
549    }
550
551    #[inline]
552    pub fn split_mut(self) -> impl Iterator<Item = Self> {
553        (0..self.dimensions.sizes()[0]).map(move |i| {
554            let self_copy =
555                Self { ptr: self.ptr, dimensions: self.dimensions.clone(), _marker: PhantomData };
556            self_copy.get(i).unwrap()
557        })
558    }
559
560    #[inline]
561    pub fn total_len(&self) -> usize {
562        self.dimensions.total_len()
563    }
564}
565
impl<'a, T> TensorView<'a, T, CpuBackend> {
    /// Borrows the viewed elements as a slice for the view's lifetime.
    #[inline]
    pub fn as_slice(self) -> &'a [T] {
        // SAFETY(review): assumes the view is contiguous and fully initialized, so that
        // `total_len` elements starting at `ptr` all belong to the viewed tensor — confirm
        // for non-contiguous sub-views.
        unsafe { std::slice::from_raw_parts(self.ptr, self.dimensions.total_len()) }
    }
}
572
impl<'a, T> TensorViewMut<'a, T, CpuBackend> {
    /// Borrows the viewed elements as an immutable slice for the view's lifetime.
    #[inline]
    pub fn as_slice(self) -> &'a [T] {
        // SAFETY(review): assumes the view is contiguous and fully initialized — confirm
        // for non-contiguous sub-views.
        unsafe { std::slice::from_raw_parts(self.ptr, self.dimensions.total_len()) }
    }

    /// Borrows the viewed elements as a mutable slice for the view's lifetime.
    #[inline]
    pub fn as_mut_slice(self) -> &'a mut [T] {
        // SAFETY(review): same contiguity/initialization assumption as `as_slice`; the view
        // was created from an exclusive borrow, so no other alias exists.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.dimensions.total_len()) }
    }
}
584
impl<'a, T, A: Backend> From<&'a mut Tensor<T, A>> for TensorViewMut<'a, T, A> {
    /// Borrows the whole tensor as a mutable view.
    fn from(tensor: &'a mut Tensor<T, A>) -> Self {
        tensor.as_view_mut()
    }
}
590
impl<'a, T, A: Backend, I: AsRef<[usize]>> Index<I> for TensorViewMut<'a, T, A> {
    type Output = Init<T, A>;

    /// Immutable multi-dimensional element access through the mutable view.
    #[inline]
    fn index(&self, index: I) -> &Self::Output {
        let index = self.dimensions.index_map(index);
        // SAFETY(review): assumes `index_map` returns an in-bounds flat offset and that
        // `Init<T, A>` is layout-compatible with `T` — confirm against slop_alloc.
        unsafe {
            let ptr = self.ptr.add(index) as *const T as *const Init<T, A>;
            ptr.as_ref().unwrap()
        }
    }
}
603
impl<'a, T, A: Backend, I: AsRef<[usize]>> IndexMut<I> for TensorViewMut<'a, T, A> {
    /// Mutable multi-dimensional element access through the view.
    #[inline]
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let index = self.dimensions.index_map(index);
        // SAFETY(review): same assumptions as the `Index` impl; `&mut self` guarantees
        // exclusivity for the produced reference.
        unsafe {
            let ptr = self.ptr.add(index) as *mut Init<T, A>;
            ptr.as_mut().unwrap()
        }
    }
}
614
// A macro to create a 1D or 2D tensor from a list of elements.
//
// NOTE(review): the 2D arm below also matches a single bracketed list, so
// `tensor![[1, 2, 3]]` produces a `[1, 3]` 2D tensor and the dedicated 1D-bracketed arm
// appears unreachable. The tests in this file rely on the 2D behavior, so the arm order is
// kept as-is — confirm the 1D-bracketed arm can be removed.
#[macro_export]
macro_rules! tensor {
    // ----- 2D pattern: e.g. tensor![[1,2,3], [4,5,6]] -----
    //
    // Matches a top-level array of sub-arrays: [ [a,b,c], [d,e,f], ... ].
    // Each sub-array is 1D. We gather them all in a Vec<Vec<_>>,
    // check that all rows have the same length, flatten them,
    // and reshape into a 2D Tensor.

    ($([$($elem:expr),* $(,)?]),+ $(,)?) => {{
        // Gather each sub-array into a temporary Vec<Vec<T>>.
        let rows = vec![
            $(
                vec![$($elem,)*]
            ),*
        ];

        // Check that all rows have the same length (ragged input panics at runtime).
        let row_len = rows[0].len();
        let rows_count = rows.len();
        if !rows.iter().all(|r| r.len() == row_len) {
            panic!("All sub-lists must have the same length to form a 2D tensor.");
        }

        // Flatten everything into a single Vec<T>.
        let flattened = rows.into_iter().flatten().collect::<Vec<_>>();

        // Build the Tensor and reshape it to [rows_count, row_len].
        $crate::Tensor::from(flattened).reshape([rows_count, row_len])
    }};

    // ----- 1D pattern with outer brackets: e.g. tensor!([1, 2, 3]) -----
    //
    // Intended to let a bare bracketed list produce a 1D Tensor (shape = [3]);
    // see the NOTE at the top about the 2D arm shadowing this one.

    ([$($elem:expr),* $(,)?]) => {{
        let v = vec![$($elem,)*];
        $crate::Tensor::from(v)
    }};

    // ----- 1D "bare" comma-separated: e.g. tensor![1, 2, 3] -----
    //
    // Matches a simple comma list at top-level.

    ($($elem:expr),+ $(,)?) => {{
        let v = vec![$($elem,)*];
        $crate::Tensor::from(v)
    }};
}
666
// Make a serialize and deserialize for Tensor<T> using the fact that we can serialize the buffer
// and the dimensions.

impl<T: Serialize> Serialize for Tensor<T> {
    /// Serializes a CPU tensor as a two-field struct: `storage` first, then `dimensions`.
    /// The field order must match the `visit_seq` path of the `Deserialize` impl below.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut state = serializer.serialize_struct("Tensor", 2)?;
        state.serialize_field("storage", &self.storage)?;
        state.serialize_field("dimensions", &self.dimensions)?;
        state.end()
    }
}
678
impl<'de, T: Deserialize<'de>> Deserialize<'de> for Tensor<T> {
    /// Deserializes a CPU tensor from the `{ storage, dimensions }` struct form produced by
    /// the matching `Serialize` impl, supporting both sequence (compact) and map
    /// (self-describing, e.g. JSON) encodings.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Key identifier enum; `rename_all = "lowercase"` matches the serialized field names.
        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "lowercase")]
        enum Field {
            Storage,
            Dimensions,
        }

        // Visitor carrying the element type `T`.
        struct TensorVisitor<T>(PhantomData<T>);

        impl<'de, T: Deserialize<'de>> serde::de::Visitor<'de> for TensorVisitor<T> {
            type Value = Tensor<T>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct Tensor")
            }

            // Sequence encoding: fields must appear in declaration order.
            fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
            where
                V: serde::de::SeqAccess<'de>,
            {
                let storage = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
                let dimensions = seq
                    .next_element()?
                    .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
                Ok(Tensor { storage, dimensions })
            }

            // Map encoding: fields may appear in any order; duplicates are rejected.
            fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
            where
                V: serde::de::MapAccess<'de>,
            {
                let mut storage = None;
                let mut dimensions = None;

                while let Some(key) = map.next_key()? {
                    match key {
                        Field::Storage => {
                            if storage.is_some() {
                                return Err(serde::de::Error::duplicate_field("storage"));
                            }
                            storage = Some(map.next_value()?);
                        }
                        Field::Dimensions => {
                            if dimensions.is_some() {
                                return Err(serde::de::Error::duplicate_field("dimensions"));
                            }
                            dimensions = Some(map.next_value()?);
                        }
                    }
                }

                let storage = storage.ok_or_else(|| serde::de::Error::missing_field("storage"))?;
                let dimensions =
                    dimensions.ok_or_else(|| serde::de::Error::missing_field("dimensions"))?;
                Ok(Tensor { storage, dimensions })
            }
        }

        deserializer.deserialize_struct(
            "Tensor",
            &["storage", "dimensions"],
            TensorVisitor(PhantomData),
        )
    }
}
748
#[cfg(test)]
mod tests {

    use slop_alloc::buffer;

    use super::*;

    // Elements laid out row-major: tensor[[r, c]] == storage[r * width + c].
    #[test]
    fn test_tensor_element_index() {
        let tensor = Tensor::<u32>::from(buffer![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).reshape([2, 5]);
        assert_eq!(*tensor[[0, 0]], 1);
        assert_eq!(*tensor[[0, 1]], 2);
        assert_eq!(*tensor[[0, 2]], 3);
        assert_eq!(*tensor[[0, 3]], 4);
        assert_eq!(*tensor[[0, 4]], 5);
        assert_eq!(*tensor[[1, 0]], 6);
        assert_eq!(*tensor[[1, 1]], 7);
        assert_eq!(*tensor[[1, 2]], 8);
        assert_eq!(*tensor[[1, 3]], 9);
        assert_eq!(*tensor[[1, 4]], 10);
    }

    // `get(i)` drops the leading dimension and yields a row view with unit stride.
    #[test]
    fn test_tensor_slice_index() {
        let tensor = Tensor::<u32>::from(buffer![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).reshape([2, 5]);

        let first_row = tensor.get(0).unwrap();
        assert_eq!(first_row.sizes(), [5]);
        assert_eq!(first_row.strides(), [1]);
        assert_eq!(*first_row[[0]], 1);
        assert_eq!(*first_row[[1]], 2);
        assert_eq!(*first_row[[2]], 3);
        assert_eq!(*first_row[[3]], 4);
        assert_eq!(*first_row[[4]], 5);

        let second_row = tensor.get(1).unwrap();
        assert_eq!(*second_row[[0]], 6);
        assert_eq!(*second_row[[1]], 7);
        assert_eq!(*second_row[[2]], 8);
        assert_eq!(*second_row[[3]], 9);
        assert_eq!(*second_row[[4]], 10);

        // Rank-3 case: indices map row-major over [2, 3, 4].
        let tensor = Tensor::<u32>::from((0..24).collect::<Vec<_>>()).reshape([2, 3, 4]);
        assert_eq!(*tensor[[0, 0, 0]], 0);
        assert_eq!(*tensor[[0, 0, 1]], 1);
        assert_eq!(*tensor[[0, 0, 2]], 2);
        assert_eq!(*tensor[[0, 0, 3]], 3);
        assert_eq!(*tensor[[0, 1, 0]], 4);
        assert_eq!(*tensor[[0, 1, 1]], 5);
        assert_eq!(*tensor[[0, 1, 2]], 6);
        assert_eq!(*tensor[[0, 1, 3]], 7);
        assert_eq!(*tensor[[0, 2, 0]], 8);
        assert_eq!(*tensor[[0, 2, 1]], 9);
        assert_eq!(*tensor[[0, 2, 2]], 10);
        assert_eq!(*tensor[[0, 2, 3]], 11);
        assert_eq!(*tensor[[1, 0, 0]], 12);
        assert_eq!(*tensor[[1, 0, 1]], 13);
        assert_eq!(*tensor[[1, 0, 2]], 14);
        assert_eq!(*tensor[[1, 0, 3]], 15);
        assert_eq!(*tensor[[1, 1, 0]], 16);
        assert_eq!(*tensor[[1, 1, 1]], 17);
        assert_eq!(*tensor[[1, 1, 2]], 18);
        assert_eq!(*tensor[[1, 1, 3]], 19);
        assert_eq!(*tensor[[1, 2, 0]], 20);
        assert_eq!(*tensor[[1, 2, 1]], 21);
        assert_eq!(*tensor[[1, 2, 2]], 22);
        assert_eq!(*tensor[[1, 2, 3]], 23);
    }

    // Round-trip: RowMajorMatrix -> Tensor -> RowMajorMatrix preserves the value vector.
    #[test]
    fn test_p3_matrix_to_tensor() {
        let mut rng = rand::thread_rng();
        let matrix = slop_matrix::dense::RowMajorMatrix::<u32>::rand(&mut rng, 100, 400);
        let tensor = Tensor::from(matrix.clone());

        assert_eq!(tensor.sizes(), [100, 400]);

        let matrix_back = slop_matrix::dense::RowMajorMatrix::<u32>::try_from(tensor).unwrap();
        assert_eq!(matrix_back.values, matrix.values);
    }

    // Note: `tensor![[..]]` with a single bracketed row yields a [1, n] 2D tensor — the 2D
    // macro arm matches before the 1D-bracketed arm.
    #[test]
    fn test_tensor_macro() {
        let tensor = tensor![1, 2, 3, 4, 5, 6];
        assert_eq!(tensor.sizes(), [6]);
        assert_eq!(tensor.as_slice(), [1, 2, 3, 4, 5, 6]);

        let tensor = tensor![[1, 2, 3], [4, 5, 6]];
        assert_eq!(tensor.sizes(), [2, 3]);
        assert_eq!(tensor.as_slice(), [1, 2, 3, 4, 5, 6]);

        let tensor = tensor![[1, 2, 3, 4, 5]];
        assert_eq!(tensor.sizes(), [1, 5]);
        assert_eq!(tensor.as_slice(), [1, 2, 3, 4, 5]);

        let tensor = tensor![[1], [2], [3], [4], [5]];
        assert_eq!(tensor.sizes(), [5, 1]);
        assert_eq!(tensor.as_slice(), [1, 2, 3, 4, 5]);
    }

    // JSON round-trip through the manual Serialize/Deserialize impls.
    #[test]
    fn test_tensor_serialize_deserialize() {
        let tensor = Tensor::<u32>::from(buffer![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).reshape([2, 5]);
        let serialized = serde_json::to_string(&tensor).unwrap();
        let deserialized: Tensor<u32> = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized, tensor);
    }
}