1use std::{
2 marker::PhantomData,
3 mem::ManuallyDrop,
4 ops::{Index, IndexMut},
5};
6
7use derive_where::derive_where;
8use rand::{distributions::Standard, prelude::Distribution, Rng};
9use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
10use slop_algebra::{ExtensionField, Field};
11use slop_alloc::{
12 Backend, Buffer, CpuBackend, HasBackend, Init, TryReserveError, GLOBAL_CPU_BACKEND,
13};
14use slop_matrix::Matrix;
15
16use crate::{Dimensions, DimensionsError};
17
/// An n-dimensional tensor: a flat storage [`Buffer`] interpreted through [`Dimensions`].
///
/// Elements are located by mapping a multi-dimensional index to a flat offset with
/// `Dimensions::index_map`. Both fields are public, so the expected invariant is not enforced by
/// the type: once initialized, `storage` should hold `dimensions.total_len()` elements (this is
/// what `From<Buffer>` establishes and what `assume_init` sets up).
#[derive(Debug, Clone)]
#[derive_where(PartialEq, Eq; Buffer<T, A>)]
pub struct Tensor<T, A: Backend = CpuBackend> {
    /// Flat element storage, allocated by the backend `A`.
    pub storage: Buffer<T, A>,
    /// Shape (sizes and strides) used to interpret `storage`.
    pub dimensions: Dimensions,
}
24
25impl<T, A: Backend> Tensor<T, A> {
26 #[inline]
27 pub fn with_sizes_in(sizes: impl AsRef<[usize]>, allocator: A) -> Self {
28 Self::try_with_sizes_in(sizes, allocator).unwrap()
29 }
30
31 #[inline]
32 pub fn zeros_in(sizes: impl AsRef<[usize]>, allocator: A) -> Self {
33 let mut tensor = Self::with_sizes_in(sizes, allocator);
34 tensor.storage.write_bytes(0, tensor.total_len() * std::mem::size_of::<T>()).unwrap();
35 tensor
36 }
37
38 #[inline]
39 pub fn zeros_in_with_total_capacity(sizes: impl AsRef<[usize]>, allocator: A) -> Self {
40 let mut tensor = Self::with_sizes_in(sizes, allocator);
41 tensor.storage.write_bytes(0, tensor.total_len() * std::mem::size_of::<T>()).unwrap();
42 tensor
43 }
44
45 #[inline]
46 pub fn try_with_sizes_in(
47 sizes: impl AsRef<[usize]>,
48 allocator: A,
49 ) -> Result<Self, TryReserveError> {
50 let dimensions = Dimensions::try_from(sizes.as_ref()).unwrap();
51 Ok(Self {
52 storage: Buffer::try_with_capacity_in(dimensions.total_len(), allocator)?,
53 dimensions,
54 })
55 }
56
57 #[track_caller]
58 pub fn reshape_in_place(&mut self, sizes: impl AsRef<[usize]>) {
59 #[cold]
60 #[track_caller]
61 #[inline(never)]
62 fn dimension_fail(new_dimensions: &Dimensions, old_dimensions: &Dimensions) -> ! {
63 panic!(
64 "TensorView::reshape: dimension mismatch: {new_dimensions:?} vs {old_dimensions:?}"
65 );
66 }
67
68 let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
69 if self.dimensions.compatible(&dimensions).is_err() {
70 dimension_fail(&dimensions, &self.dimensions);
71 }
72 self.dimensions = dimensions;
73 }
74
75 #[inline]
76 #[track_caller]
77 pub fn reshape(mut self, sizes: impl AsRef<[usize]>) -> Self {
78 #[cold]
79 #[track_caller]
80 #[inline(never)]
81 fn dimension_fail(new_dimensions: &Dimensions, old_dimensions: &Dimensions) -> ! {
82 panic!(
83 "TensorView::reshape: dimension mismatch: {new_dimensions:?} vs {old_dimensions:?}"
84 );
85 }
86
87 let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
88 if self.dimensions.compatible(&dimensions).is_err() {
89 dimension_fail(&dimensions, &self.dimensions);
90 }
91 self.dimensions = dimensions;
92 self
93 }
94
95 #[inline]
99 pub unsafe fn reshape_unchecked(mut self, dimensions: Dimensions) {
100 self.dimensions = dimensions;
101 }
102
103 #[inline]
104 pub fn flatten_in_place(&mut self) {
105 self.reshape_in_place([self.dimensions.total_len()]);
106 }
107
108 #[inline]
109 pub fn flatten(mut self) -> Self {
110 self.flatten_in_place();
111 self
112 }
113
114 #[inline]
115 pub fn into_buffer(self) -> Buffer<T, A> {
116 self.storage
117 }
118
119 #[inline]
120 pub fn as_buffer(&self) -> &Buffer<T, A> {
121 &self.storage
122 }
123
124 #[inline]
125 pub fn as_mut_buffer(&mut self) -> &mut Buffer<T, A> {
126 &mut self.storage
127 }
128
129 #[inline]
130 pub fn backend(&self) -> &A {
131 self.storage.allocator()
132 }
133
134 #[inline]
135 pub fn shape(&self) -> &Dimensions {
136 &self.dimensions
137 }
138
139 #[inline]
141 pub fn sizes(&self) -> &[usize] {
142 self.dimensions.sizes()
143 }
144
145 #[inline]
146 pub fn strides(&self) -> &[usize] {
147 self.dimensions.strides()
148 }
149
150 #[inline]
151 pub fn as_ptr(&self) -> *const T {
152 self.storage.as_ptr()
153 }
154
155 #[inline]
159 pub unsafe fn owned_unchecked(&self) -> ManuallyDrop<Self> {
160 self.owned_unchecked_in(self.storage.allocator().clone())
161 }
162
163 #[inline]
167 pub unsafe fn owned_unchecked_in(&self, storage_allocator: A) -> ManuallyDrop<Self> {
168 let dimensions = self.dimensions.clone();
169 let storage = self.storage.owned_unchecked_in(storage_allocator);
170 let storage = ManuallyDrop::into_inner(storage);
171 ManuallyDrop::new(Self { storage, dimensions })
172 }
173
174 #[inline]
175 pub fn total_len(&self) -> usize {
176 self.dimensions.total_len()
177 }
178
179 pub fn as_mut_ptr(&mut self) -> *mut T {
180 self.storage.as_mut_ptr()
181 }
182
183 #[inline]
184 pub fn as_view(&'_ self) -> TensorView<'_, T, A> {
185 TensorView {
186 ptr: self.as_ptr(),
187 dimensions: self.dimensions.clone(),
188 backend: self.backend().clone(),
189 _marker: PhantomData,
190 }
191 }
192
193 #[inline]
194 pub fn as_view_mut(&'_ mut self) -> TensorViewMut<'_, T, A> {
195 TensorViewMut {
196 ptr: self.as_mut_ptr(),
197 dimensions: self.dimensions.clone(),
198 _marker: PhantomData,
199 }
200 }
201
202 #[inline]
203 pub fn get(&'_ self, index: usize) -> Option<TensorView<'_, T, A>> {
204 self.as_view().get(index)
205 }
206
207 #[inline]
208 pub fn get_mut(&'_ mut self, index: usize) -> Option<TensorViewMut<'_, T, A>> {
209 self.as_view_mut().get(index)
210 }
211
212 #[inline]
213 pub fn split(&'_ self) -> impl Iterator<Item = TensorView<'_, T, A>> {
214 self.as_view().split()
215 }
216
217 #[inline]
218 pub fn split_mut(&'_ mut self) -> impl Iterator<Item = TensorViewMut<'_, T, A>> {
219 self.as_view_mut().split_mut()
220 }
221
222 #[inline]
226 pub unsafe fn assume_init(&mut self) {
227 self.storage.set_len(self.storage.capacity());
228 }
229
230 pub fn flatten_to_base<F: Field>(self) -> Tensor<F, A>
231 where
232 T: ExtensionField<F>,
233 {
234 let [height, width]: [usize; 2] = self.sizes().try_into().unwrap();
235 let dimensions = Dimensions::try_from([height, T::D * width]).unwrap();
236 let data_storage = self.into_buffer().flatten_to_base();
237 Tensor { storage: data_storage, dimensions }
238 }
239
240 pub fn into_extension<ET: ExtensionField<T>>(self) -> Tensor<ET, A>
241 where
242 T: Field,
243 {
244 let [height, width]: [usize; 2] = self.sizes().try_into().unwrap();
245 let dimensions = Dimensions::try_from([height, width / ET::D]).unwrap();
246 let extension_storage = self.into_buffer().into_extension();
247 Tensor { storage: extension_storage, dimensions }
248 }
249}
250
251impl<T, A: Backend, I: AsRef<[usize]>> Index<I> for Tensor<T, A> {
252 type Output = Init<T, A>;
253
254 #[track_caller]
255 fn index(&self, index: I) -> &Self::Output {
256 #[cold]
257 #[track_caller]
258 #[inline(never)]
259 fn dimension_fail(index_len: usize, sizes_len: usize) -> ! {
260 panic!(
261 "Index length ({index_len}) does not match tensor dimensions length ({sizes_len})"
262 );
263 }
264
265 if index.as_ref().len() != self.dimensions.sizes().len() {
266 dimension_fail(index.as_ref().len(), self.dimensions.sizes().len());
267 }
268 let index = self.dimensions.index_map(index);
269 &self.storage[index]
270 }
271}
272
273impl<T, A: Backend, I: AsRef<[usize]>> IndexMut<I> for Tensor<T, A> {
274 fn index_mut(&mut self, index: I) -> &mut Self::Output {
275 let index = self.dimensions.index_map(index);
276 &mut self.storage[index]
277 }
278}
279
280impl<T, A: Backend> From<Buffer<T, A>> for Tensor<T, A> {
281 #[inline]
282 fn from(buffer: Buffer<T, A>) -> Self {
283 let dims = [buffer.len()].into_iter().collect();
284 Self { storage: buffer, dimensions: dims }
285 }
286}
287
288impl<T, A: Backend> HasBackend for Tensor<T, A> {
289 type Backend = A;
290
291 fn backend(&self) -> &Self::Backend {
292 self.backend()
293 }
294}
295
296impl<T> From<Vec<T>> for Tensor<T, CpuBackend> {
297 #[inline]
298 fn from(vec: Vec<T>) -> Self {
299 Self::from(Buffer::from(vec))
300 }
301}
302
303impl<T> FromIterator<T> for Tensor<T, CpuBackend> {
304 #[inline]
305 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
306 Self::from(iter.into_iter().collect::<Vec<_>>())
307 }
308}
309
310impl<T: Clone + Send + Sync> From<slop_matrix::dense::RowMajorMatrix<T>> for Tensor<T, CpuBackend> {
311 fn from(value: slop_matrix::dense::RowMajorMatrix<T>) -> Self {
312 let dimensions: Dimensions = [value.height(), value.width()].try_into().unwrap();
313 let storage = Buffer::from(value.values);
314 Self { storage, dimensions }
315 }
316}
317
318impl<T: Clone + Send + Sync> TryFrom<Tensor<T, CpuBackend>>
319 for slop_matrix::dense::RowMajorMatrix<T>
320{
321 type Error = DimensionsError;
322 fn try_from(value: Tensor<T, CpuBackend>) -> Result<Self, Self::Error> {
323 if value.sizes().len() != 2 {
324 return Err(DimensionsError::TooManyDimensions(value.sizes().len()));
325 }
326 let width = value.sizes()[1];
327 let values = value.storage.into_vec();
328 Ok(Self::new(values, width))
329 }
330}
331
332impl<T> Tensor<T, CpuBackend> {
333 pub fn rand<R: Rng>(rng: &mut R, sizes: impl AsRef<[usize]>) -> Self
334 where
335 Standard: Distribution<T>,
336 {
337 let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
338 let values = rng.sample_iter(Standard).take(dimensions.total_len()).collect::<Vec<_>>();
339 Self { storage: Buffer::from(values), dimensions }
340 }
341
342 #[inline]
343 pub fn with_sizes(sizes: impl AsRef<[usize]>) -> Self {
344 Tensor::with_sizes_in(sizes, GLOBAL_CPU_BACKEND)
345 }
346
347 #[inline]
348 pub fn as_slice(&self) -> &[T] {
349 &self.storage[..]
350 }
351
352 #[inline]
353 pub fn as_mut_slice(&mut self) -> &mut [T] {
354 &mut self.storage[..]
355 }
356}
357
/// A borrowed, read-only view into tensor data.
///
/// The view stores a raw pointer to its first element together with its own [`Dimensions`].
/// The lifetime `'a` ties it (through `PhantomData`) to the [`Tensor`] it borrows from.
#[derive(Debug)]
pub struct TensorView<'a, T, A: Backend = CpuBackend> {
    // Pointer to the first element covered by this view.
    ptr: *const T,
    // Shape (sizes and strides) of this view.
    dimensions: Dimensions,
    // Backend of the underlying storage (cloned from the tensor in `Tensor::as_view`).
    backend: A,
    _marker: PhantomData<&'a Tensor<T, A>>,
}
366
367impl<'a, T, A: Backend> TensorView<'a, T, A> {
368 #[inline]
369 pub fn as_ptr(&self) -> *const T {
370 self.ptr
371 }
372
373 #[inline]
374 pub fn sizes(&self) -> &[usize] {
375 self.dimensions.sizes()
376 }
377
378 #[inline]
379 pub fn backend(&self) -> &A {
380 &self.backend
381 }
382
383 #[inline]
384 pub unsafe fn from_raw_parts(ptr: *const T, dimensions: Dimensions, backend: A) -> Self {
388 Self { ptr, dimensions, backend, _marker: PhantomData }
389 }
390
391 #[inline]
392 pub fn strides(&self) -> &[usize] {
393 self.dimensions.strides()
394 }
395
396 #[inline]
397 pub fn total_len(&self) -> usize {
398 self.dimensions.total_len()
399 }
400
401 #[inline]
402 pub fn shape(&self) -> &Dimensions {
403 &self.dimensions
404 }
405
406 #[inline]
407 pub fn flatten(self) -> TensorView<'a, T, A> {
408 let total_len = self.total_len();
409 self.reshape([total_len])
410 }
411
412 #[inline]
413 #[track_caller]
414 pub fn reshape(self, sizes: impl AsRef<[usize]>) -> TensorView<'a, T, A> {
415 #[cold]
416 #[track_caller]
417 #[inline(never)]
418 fn dimension_fail(new_dimensions: &Dimensions, old_dimensions: &Dimensions) -> ! {
419 panic!(
420 "TensorView::reshape: dimension mismatch: {new_dimensions:?} vs {old_dimensions:?}"
421 );
422 }
423
424 let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
425 if self.dimensions.compatible(&dimensions).is_err() {
426 dimension_fail(&dimensions, &self.dimensions);
427 }
428 TensorView {
429 ptr: self.ptr,
430 dimensions,
431 backend: self.backend.clone(),
432 _marker: PhantomData,
433 }
434 }
435
436 #[inline]
437 pub fn get(mut self, index: usize) -> Option<Self> {
438 let size = self.dimensions.sizes_mut().remove(0);
439 if index >= size {
440 return None;
441 }
442 let stride = self.dimensions.strides_mut().remove(0);
443 let offset = index * stride;
444
445 let ptr = unsafe { self.ptr.add(offset) };
446 Some(Self {
447 ptr,
448 dimensions: self.dimensions,
449 backend: self.backend.clone(),
450 _marker: PhantomData,
451 })
452 }
453
454 pub fn split(self) -> impl Iterator<Item = Self> {
455 (0..self.dimensions.sizes()[0]).map(move |i| self.clone().get(i).unwrap())
456 }
457}
458
459impl<'a, T, A: Backend> Clone for TensorView<'a, T, A> {
460 fn clone(&self) -> Self {
461 Self {
462 ptr: self.ptr,
463 dimensions: self.dimensions.clone(),
464 backend: self.backend.clone(),
465 _marker: PhantomData,
466 }
467 }
468}
469
470impl<'a, T, A: Backend> From<&'a Tensor<T, A>> for TensorView<'a, T, A> {
471 fn from(tensor: &'a Tensor<T, A>) -> Self {
472 tensor.as_view()
473 }
474}
475
476impl<'a, T, A: Backend, I: AsRef<[usize]>> Index<I> for TensorView<'a, T, A> {
477 type Output = Init<T, A>;
478
479 #[inline]
480 fn index(&self, index: I) -> &Self::Output {
481 let index = self.dimensions.index_map(index);
482 unsafe {
483 let ptr = self.ptr.add(index) as *const Init<T, A>;
484 ptr.as_ref().unwrap()
485 }
486 }
487}
488
489impl<T> Default for Tensor<T, CpuBackend> {
490 fn default() -> Self {
491 Self::from(Buffer::default())
492 }
493}
494
/// A borrowed, mutable view into tensor data.
///
/// The view stores a raw pointer to its first element together with its own [`Dimensions`].
/// The lifetime `'a` ties it (through `PhantomData`) to the exclusively borrowed [`Tensor`].
/// Unlike [`TensorView`], it does not keep a copy of the backend.
#[derive(Debug)]
pub struct TensorViewMut<'a, T, A: Backend = CpuBackend> {
    // Pointer to the first element covered by this view.
    ptr: *mut T,
    // Shape (sizes and strides) of this view.
    dimensions: Dimensions,
    _marker: PhantomData<&'a mut Tensor<T, A>>,
}
503
504impl<'a, T, A: Backend> TensorViewMut<'a, T, A> {
505 #[inline]
506 pub fn as_mut_ptr(&mut self) -> *mut T {
507 self.ptr
508 }
509
510 #[inline]
511 pub fn sizes(&self) -> &[usize] {
512 self.dimensions.sizes()
513 }
514
515 #[inline]
516 pub fn shape(&self) -> &Dimensions {
517 &self.dimensions
518 }
519
520 #[inline]
521 pub fn strides(&self) -> &[usize] {
522 self.dimensions.strides()
523 }
524
525 #[inline]
526 pub fn flatten(self) -> TensorViewMut<'a, T, A> {
527 let total_len = self.total_len();
528 self.reshape([total_len])
529 }
530
531 #[inline]
532 pub fn reshape(self, sizes: impl AsRef<[usize]>) -> TensorViewMut<'a, T, A> {
533 let dimensions: Dimensions = sizes.as_ref().try_into().unwrap();
534 self.dimensions.compatible(&dimensions).unwrap();
535 TensorViewMut { ptr: self.ptr, dimensions, _marker: PhantomData }
536 }
537
538 #[inline]
539 pub fn get(mut self, index: usize) -> Option<Self> {
540 let size = self.dimensions.sizes_mut().remove(0);
541 if index >= size {
542 return None;
543 }
544 let stride = self.dimensions.strides_mut().remove(0);
545 let offset = index * stride;
546
547 let ptr = unsafe { self.ptr.add(offset) };
548 Some(Self { ptr, dimensions: self.dimensions, _marker: PhantomData })
549 }
550
551 #[inline]
552 pub fn split_mut(self) -> impl Iterator<Item = Self> {
553 (0..self.dimensions.sizes()[0]).map(move |i| {
554 let self_copy =
555 Self { ptr: self.ptr, dimensions: self.dimensions.clone(), _marker: PhantomData };
556 self_copy.get(i).unwrap()
557 })
558 }
559
560 #[inline]
561 pub fn total_len(&self) -> usize {
562 self.dimensions.total_len()
563 }
564}
565
566impl<'a, T> TensorView<'a, T, CpuBackend> {
567 #[inline]
568 pub fn as_slice(self) -> &'a [T] {
569 unsafe { std::slice::from_raw_parts(self.ptr, self.dimensions.total_len()) }
570 }
571}
572
573impl<'a, T> TensorViewMut<'a, T, CpuBackend> {
574 #[inline]
575 pub fn as_slice(self) -> &'a [T] {
576 unsafe { std::slice::from_raw_parts(self.ptr, self.dimensions.total_len()) }
577 }
578
579 #[inline]
580 pub fn as_mut_slice(self) -> &'a mut [T] {
581 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.dimensions.total_len()) }
582 }
583}
584
585impl<'a, T, A: Backend> From<&'a mut Tensor<T, A>> for TensorViewMut<'a, T, A> {
586 fn from(tensor: &'a mut Tensor<T, A>) -> Self {
587 tensor.as_view_mut()
588 }
589}
590
591impl<'a, T, A: Backend, I: AsRef<[usize]>> Index<I> for TensorViewMut<'a, T, A> {
592 type Output = Init<T, A>;
593
594 #[inline]
595 fn index(&self, index: I) -> &Self::Output {
596 let index = self.dimensions.index_map(index);
597 unsafe {
598 let ptr = self.ptr.add(index) as *const T as *const Init<T, A>;
599 ptr.as_ref().unwrap()
600 }
601 }
602}
603
604impl<'a, T, A: Backend, I: AsRef<[usize]>> IndexMut<I> for TensorViewMut<'a, T, A> {
605 #[inline]
606 fn index_mut(&mut self, index: I) -> &mut Self::Output {
607 let index = self.dimensions.index_map(index);
608 unsafe {
609 let ptr = self.ptr.add(index) as *mut Init<T, A>;
610 ptr.as_mut().unwrap()
611 }
612 }
613}
614
/// Builds a CPU [`Tensor`] from literal elements.
///
/// * `tensor![[a, b], [c, d]]`: one or more bracketed rows produce a 2-D tensor of shape
///   `[rows, row_len]`. A single bracketed row such as `tensor![[1, 2, 3]]` is therefore a
///   `[1, 3]` tensor.
/// * `tensor![a, b, c]`: a flat list produces a 1-D tensor of shape `[len]`.
///
/// # Panics
///
/// Panics if the rows of the 2-D form do not all have the same length.
#[macro_export]
macro_rules! tensor {
    // Two-dimensional form: every bracketed group is one row. This arm also covers a single
    // bracketed group, so a separate `([...])` arm could never be selected.
    ($([$($elem:expr),* $(,)?]),+ $(,)?) => {{
        let rows = vec![$(vec![$($elem,)*]),*];

        let row_len = rows[0].len();
        let rows_count = rows.len();
        assert!(
            rows.iter().all(|r| r.len() == row_len),
            "All sub-lists must have the same length to form a 2D tensor."
        );

        let flattened = rows.into_iter().flatten().collect::<Vec<_>>();
        $crate::Tensor::from(flattened).reshape([rows_count, row_len])
    }};

    // One-dimensional form: a flat, comma-separated list of elements.
    ($($elem:expr),+ $(,)?) => {{
        let v = vec![$($elem,)*];
        $crate::Tensor::from(v)
    }};
}
666
667impl<T: Serialize> Serialize for Tensor<T> {
671 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
672 let mut state = serializer.serialize_struct("Tensor", 2)?;
673 state.serialize_field("storage", &self.storage)?;
674 state.serialize_field("dimensions", &self.dimensions)?;
675 state.end()
676 }
677}
678
679impl<'de, T: Deserialize<'de>> Deserialize<'de> for Tensor<T> {
680 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
681 #[derive(Deserialize)]
682 #[serde(field_identifier, rename_all = "lowercase")]
683 enum Field {
684 Storage,
685 Dimensions,
686 }
687
688 struct TensorVisitor<T>(PhantomData<T>);
689
690 impl<'de, T: Deserialize<'de>> serde::de::Visitor<'de> for TensorVisitor<T> {
691 type Value = Tensor<T>;
692
693 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
694 formatter.write_str("struct Tensor")
695 }
696
697 fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
698 where
699 V: serde::de::SeqAccess<'de>,
700 {
701 let storage = seq
702 .next_element()?
703 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
704 let dimensions = seq
705 .next_element()?
706 .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
707 Ok(Tensor { storage, dimensions })
708 }
709
710 fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
711 where
712 V: serde::de::MapAccess<'de>,
713 {
714 let mut storage = None;
715 let mut dimensions = None;
716
717 while let Some(key) = map.next_key()? {
718 match key {
719 Field::Storage => {
720 if storage.is_some() {
721 return Err(serde::de::Error::duplicate_field("storage"));
722 }
723 storage = Some(map.next_value()?);
724 }
725 Field::Dimensions => {
726 if dimensions.is_some() {
727 return Err(serde::de::Error::duplicate_field("dimensions"));
728 }
729 dimensions = Some(map.next_value()?);
730 }
731 }
732 }
733
734 let storage = storage.ok_or_else(|| serde::de::Error::missing_field("storage"))?;
735 let dimensions =
736 dimensions.ok_or_else(|| serde::de::Error::missing_field("dimensions"))?;
737 Ok(Tensor { storage, dimensions })
738 }
739 }
740
741 deserializer.deserialize_struct(
742 "Tensor",
743 &["storage", "dimensions"],
744 TensorVisitor(PhantomData),
745 )
746 }
747}
748
#[cfg(test)]
mod tests {

    use slop_alloc::buffer;

    use super::*;

    /// The `[2, 5]` tensor holding `1..=10` that the indexing and serde tests share.
    fn one_to_ten() -> Tensor<u32> {
        Tensor::<u32>::from(buffer![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).reshape([2, 5])
    }

    #[test]
    fn test_tensor_element_index() {
        let tensor = one_to_ten();
        let mut expected = 1u32;
        for row in 0..2usize {
            for col in 0..5usize {
                assert_eq!(*tensor[[row, col]], expected);
                expected += 1;
            }
        }
    }

    #[test]
    fn test_tensor_slice_index() {
        let tensor = one_to_ten();

        let first_row = tensor.get(0).unwrap();
        assert_eq!(first_row.sizes(), [5]);
        assert_eq!(first_row.strides(), [1]);
        for (col, expected) in (1..=5u32).enumerate() {
            assert_eq!(*first_row[[col]], expected);
        }

        let second_row = tensor.get(1).unwrap();
        for (col, expected) in (6..=10u32).enumerate() {
            assert_eq!(*second_row[[col]], expected);
        }

        let cube = Tensor::<u32>::from((0..24).collect::<Vec<_>>()).reshape([2, 3, 4]);
        let mut expected = 0u32;
        for i in 0..2usize {
            for j in 0..3usize {
                for k in 0..4usize {
                    assert_eq!(*cube[[i, j, k]], expected);
                    expected += 1;
                }
            }
        }
    }

    #[test]
    fn test_p3_matrix_to_tensor() {
        let mut rng = rand::thread_rng();
        let original = slop_matrix::dense::RowMajorMatrix::<u32>::rand(&mut rng, 100, 400);
        let tensor = Tensor::from(original.clone());

        assert_eq!(tensor.sizes(), [100, 400]);

        let round_trip = slop_matrix::dense::RowMajorMatrix::<u32>::try_from(tensor).unwrap();
        assert_eq!(round_trip.values, original.values);
    }

    #[test]
    fn test_tensor_macro() {
        let flat = tensor![1, 2, 3, 4, 5, 6];
        assert_eq!(flat.sizes(), [6]);
        assert_eq!(flat.as_slice(), [1, 2, 3, 4, 5, 6]);

        let two_rows = tensor![[1, 2, 3], [4, 5, 6]];
        assert_eq!(two_rows.sizes(), [2, 3]);
        assert_eq!(two_rows.as_slice(), [1, 2, 3, 4, 5, 6]);

        let single_row = tensor![[1, 2, 3, 4, 5]];
        assert_eq!(single_row.sizes(), [1, 5]);
        assert_eq!(single_row.as_slice(), [1, 2, 3, 4, 5]);

        let single_column = tensor![[1], [2], [3], [4], [5]];
        assert_eq!(single_column.sizes(), [5, 1]);
        assert_eq!(single_column.as_slice(), [1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_tensor_serialize_deserialize() {
        let tensor = one_to_ten();
        let json = serde_json::to_string(&tensor).unwrap();
        let decoded: Tensor<u32> = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, tensor);
    }
}
856}