1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use super::Tensor;
use crate::datum::Datum;
use ndarray::*;
use std::sync::Arc;

pub fn arr4<A, V, U, T>(xs: &[V]) -> Array4<A>
where
    V: FixedInitializer<Elem = U> + Clone,
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
    A: Clone,
{
    //! Builds a four-dimensional array from a slice of nested fixed-size
    //! initializers, extending ndarray's `arr2`/`arr3` helpers by one axis.
    //!
    //! The resulting shape is `(xs.len(), V::len(), U::len(), T::len())` and
    //! elements are stored in row-major order: the innermost initializer
    //! varies fastest.
    //!
    //! # Panics
    //!
    //! Never in practice. The flattened element count always equals the
    //! product of the shape, so the `expect` below only guards that invariant.
    use ndarray::*;
    let dim = Ix4(xs.len(), V::len(), U::len(), T::len());
    // Flatten through the initializers' slice views rather than reinterpreting
    // a `Vec<V>` allocation as a `Vec<A>` with raw pointers. This needs no
    // `unsafe`, no layout assumptions about `FixedInitializer` implementors,
    // and still clones every element exactly once.
    let mut data = Vec::with_capacity(xs.len() * V::len() * U::len() * T::len());
    for v in xs {
        for u in v.as_init_slice() {
            for t in u.as_init_slice() {
                data.extend_from_slice(t.as_init_slice());
            }
        }
    }
    ArrayBase::from_shape_vec(dim, data)
        .expect("arr4: flattened element count always equals the product of the shape")
}

/// Wraps a single scalar value `x` in a zero-dimensional [`Tensor`].
pub fn tensor0<A: Datum>(x: A) -> Tensor {
    let scalar = arr0(x);
    scalar.into()
}

/// Builds a one-dimensional [`Tensor`] holding a copy of the elements of `xs`.
pub fn tensor1<A: Datum>(xs: &[A]) -> Tensor {
    let vector = arr1(xs);
    vector.into()
}

/// Builds a two-dimensional [`Tensor`] from a slice of fixed-size rows.
///
/// The shape is `(xs.len(), T::len())`.
pub fn tensor2<A: Datum, T>(xs: &[T]) -> Tensor
where
    T: FixedInitializer<Elem = A> + Clone,
{
    let matrix = arr2(xs);
    matrix.into()
}

/// Builds a three-dimensional [`Tensor`] from a slice of nested fixed-size
/// initializers.
///
/// The shape is `(xs.len(), U::len(), T::len())`.
pub fn tensor3<A: Datum, T, U>(xs: &[U]) -> Tensor
where
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    let cube = arr3(xs);
    cube.into()
}

/// Builds a four-dimensional [`Tensor`] from a slice of nested fixed-size
/// initializers, using this module's `arr4` helper.
///
/// The shape is `(xs.len(), V::len(), U::len(), T::len())`.
pub fn tensor4<A: Datum, T, U, V>(xs: &[V]) -> Tensor
where
    V: FixedInitializer<Elem = U> + Clone,
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    let hypercube = arr4(xs);
    hypercube.into()
}

pub fn rctensor0<A: Datum>(x: A) -> Arc<Tensor> {
    Arc::new(Tensor::from(arr0(x)))
}

pub fn rctensor1<A: Datum>(xs: &[A]) -> Arc<Tensor> {
    Arc::new(Tensor::from(arr1(xs)))
}

pub fn rctensor2<A: Datum, T>(xs: &[T]) -> Arc<Tensor>
where
    T: FixedInitializer<Elem = A> + Clone,
{
    Arc::new(Tensor::from(arr2(xs)))
}

pub fn rctensor3<A: Datum, T, U>(xs: &[U]) -> Arc<Tensor>
where
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    Arc::new(Tensor::from(arr3(xs)))
}

pub fn rctensor4<A: Datum, T, U, V>(xs: &[V]) -> Arc<Tensor>
where
    V: FixedInitializer<Elem = U> + Clone,
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    Arc::new(Tensor::from(arr4(xs)))
}