tract_data/tensor/
litteral.rs1use super::Tensor;
2use crate::datum::Datum;
3use ndarray::*;
4use std::sync::Arc;
5
6pub fn arr4<A, const N: usize, const M: usize, const T: usize>(xs: &[[[[A; T]; M]; N]]) -> Array4<A>
7where
8 A: Clone,
9{
10 use ndarray::*;
11 let xs = xs.to_vec();
12 let dim = Ix4(xs.len(), N, M, T);
13 let len = xs.len();
14 let cap = xs.capacity();
15 let expand_len = len * N * M * T;
16 let ptr = Box::into_raw(xs.into_boxed_slice());
17 unsafe {
18 let v = if ::std::mem::size_of::<A>() == 0 {
19 Vec::from_raw_parts(ptr as *mut A, expand_len, expand_len)
20 } else if N == 0 || M == 0 || T == 0 {
21 Vec::new()
22 } else {
23 let expand_cap = cap * N * M * T;
24 Vec::from_raw_parts(ptr as *mut A, expand_len, expand_cap)
25 };
26 ArrayBase::from_shape_vec_unchecked(dim, v)
27 }
28}
29
30pub fn tensor0<A: Datum>(x: A) -> Tensor {
31 unsafe {
32 let mut tensor = Tensor::uninitialized::<A>(&[]).unwrap();
33 tensor.as_slice_mut_unchecked::<A>()[0] = x;
34 tensor
35 }
36}
37
38pub fn tensor1<A: Datum>(xs: &[A]) -> Tensor {
39 Tensor::from(arr1(xs))
40}
41
42pub fn tensor2<A: Datum, const N: usize>(xs: &[[A; N]]) -> Tensor {
43 Tensor::from(arr2(xs))
44}
45
46pub fn tensor3<A: Datum, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Tensor {
47 Tensor::from(arr3(xs))
48}
49
50pub fn tensor4<A: Datum, const N: usize, const M: usize, const T: usize>(
51 xs: &[[[[A; T]; M]; N]],
52) -> Tensor {
53 Tensor::from(arr4(xs))
54}
55
56pub fn rctensor0<A: Datum>(x: A) -> Arc<Tensor> {
57 Arc::new(Tensor::from(arr0(x)))
58}
59
60pub fn rctensor1<A: Datum>(xs: &[A]) -> Arc<Tensor> {
61 Arc::new(Tensor::from(arr1(xs)))
62}
63
64pub fn rctensor2<A: Datum, const N: usize>(xs: &[[A; N]]) -> Arc<Tensor> {
65 Arc::new(Tensor::from(arr2(xs)))
66}
67
68pub fn rctensor3<A: Datum, const N: usize, const M: usize>(xs: &[[[A; M]; N]]) -> Arc<Tensor> {
69 Arc::new(Tensor::from(arr3(xs)))
70}
71
72pub fn rctensor4<A: Datum, const N: usize, const M: usize, const T: usize>(
73 xs: &[[[[A; T]; M]; N]],
74) -> Arc<Tensor> {
75 Arc::new(Tensor::from(arr4(xs)))
76}