1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
use super::Tensor;
use crate::datum::Datum;
use ndarray::*;
use std::sync::Arc;

/// Builds a four-dimensional array from a slice of triply nested
/// fixed-size initializers.
///
/// The shape of the result is `(xs.len(), V::len(), U::len(), T::len())`.
/// The input is cloned with `to_vec`, and the clone's buffer is then
/// reinterpreted as a flat `Vec<A>` rather than copied element by element.
pub fn arr4<A, V, U, T>(xs: &[V]) -> Array4<A>
where
    V: FixedInitializer<Elem = U> + Clone,
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
    A: Clone,
{
    use ndarray::*;

    let mut rows = xs.to_vec();
    let (d1, d2, d3) = (V::len(), U::len(), T::len());
    let outer = rows.len();
    let shape = Ix4(outer, d1, d2, d3);
    let base = rows.as_mut_ptr() as *mut A;
    let row_cap = rows.capacity();
    let flat_len = outer * d1 * d2 * d3;

    // Ownership of the buffer is handed over to the flat vector built below,
    // so the original vector must not run its destructor.
    ::std::mem::forget(rows);

    let flat: Vec<A> = if ::std::mem::size_of::<A>() == 0 {
        // SAFETY: `A` is zero-sized, so the vector only tracks a length;
        // length and capacity are both set to the element count.
        unsafe { Vec::from_raw_parts(base, flat_len, flat_len) }
    } else if d1 == 0 || d2 == 0 || d3 == 0 {
        // One inner extent is zero: there are no elements to take over.
        Vec::new()
    } else {
        let flat_cap = row_cap * d1 * d2 * d3;
        // SAFETY: assumes each `V` is laid out as `d1 * d2 * d3` contiguous
        // values of `A` (what implementing the unsafe `FixedInitializer`
        // trait is presumed to guarantee -- confirm against ndarray), so the
        // `Vec<V>` buffer is reinterpreted with length and capacity scaled by
        // that factor.
        unsafe { Vec::from_raw_parts(base, flat_len, flat_cap) }
    };

    // SAFETY: `flat_len` is the product of the four extents in `shape`.
    unsafe { ArrayBase::from_shape_vec_unchecked(shape, flat) }
}

/// Converts the array produced by `arr0(x)` into a [`Tensor`].
pub fn tensor0<A: Datum>(x: A) -> Tensor {
    let array = arr0(x);
    Tensor::from(array)
}

/// Converts the array produced by `arr1(xs)` into a [`Tensor`].
pub fn tensor1<A: Datum>(xs: &[A]) -> Tensor {
    let array = arr1(xs);
    Tensor::from(array)
}

/// Converts the array produced by `arr2(xs)` into a [`Tensor`].
pub fn tensor2<A: Datum, T>(xs: &[T]) -> Tensor
where
    T: FixedInitializer<Elem = A> + Clone,
{
    let array = arr2(xs);
    Tensor::from(array)
}

/// Converts the array produced by `arr3(xs)` into a [`Tensor`].
pub fn tensor3<A: Datum, T, U>(xs: &[U]) -> Tensor
where
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    let array = arr3(xs);
    Tensor::from(array)
}

/// Converts the array produced by [`arr4`] into a [`Tensor`].
pub fn tensor4<A: Datum, T, U, V>(xs: &[V]) -> Tensor
where
    V: FixedInitializer<Elem = U> + Clone,
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    let array = arr4(xs);
    Tensor::from(array)
}

/// Same as [`tensor0`], with the result wrapped in an [`Arc`].
pub fn rctensor0<A: Datum>(x: A) -> Arc<Tensor> {
    Arc::new(tensor0(x))
}

/// Same as [`tensor1`], with the result wrapped in an [`Arc`].
pub fn rctensor1<A: Datum>(xs: &[A]) -> Arc<Tensor> {
    Arc::new(tensor1(xs))
}

/// Same as [`tensor2`], with the result wrapped in an [`Arc`].
pub fn rctensor2<A: Datum, T>(xs: &[T]) -> Arc<Tensor>
where
    T: FixedInitializer<Elem = A> + Clone,
{
    Arc::new(tensor2(xs))
}

/// Same as [`tensor3`], with the result wrapped in an [`Arc`].
pub fn rctensor3<A: Datum, T, U>(xs: &[U]) -> Arc<Tensor>
where
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    Arc::new(tensor3(xs))
}

/// Same as [`tensor4`], with the result wrapped in an [`Arc`].
pub fn rctensor4<A: Datum, T, U, V>(xs: &[V]) -> Arc<Tensor>
where
    V: FixedInitializer<Elem = U> + Clone,
    U: FixedInitializer<Elem = T> + Clone,
    T: FixedInitializer<Elem = A> + Clone,
{
    Arc::new(tensor4(xs))
}