1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
use crate::dim::DimLike;
use std::fmt;
use std::marker::PhantomData;

/// Axis layout of a batched, channelled data tensor.
///
/// Both layouts put the batch (`N`) axis first. They differ in where the
/// channel (`C`) axis sits relative to the remaining `H`/`W`-style axes.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum DataFormat {
    /// Batch, channels, then the `H`/`W` axes: `[N, C, H, W, ...]`.
    NCHW,
    /// Batch, the `H`/`W` axes, then channels: `[N, H, W, ..., C]`.
    NHWC,
}

impl Default for DataFormat {
    /// Defaults to [`DataFormat::NCHW`].
    fn default() -> DataFormat {
        DataFormat::NCHW
    }
}

impl DataFormat {
    /// Wraps `shape` in a [`DataShape`] whose axes are read with this layout.
    ///
    /// No validation happens here. The accessors on [`DataShape`] panic if
    /// the shape is too short for the layout.
    pub fn shape<D, S>(&self, shape: S) -> DataShape<D, S>
    where
        D: DimLike,
        S: AsRef<[D]> + fmt::Debug,
    {
        DataShape {
            fmt: *self,
            shape,
            _phantom: PhantomData,
        }
    }
}

/// A shape paired with the [`DataFormat`] that says how to interpret its axes.
///
/// The shape is expected to have at least two axes (`N` and `C`). Every other
/// axis is treated as an `H`/`W`-style axis.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct DataShape<D, S>
where
    D: DimLike,
    S: AsRef<[D]> + fmt::Debug,
{
    /// Layout used to interpret `shape`.
    pub fmt: DataFormat,
    /// The underlying dimensions, in storage order.
    pub shape: S,
    _phantom: PhantomData<D>,
}

impl<D, S> DataShape<D, S>
where
    D: DimLike,
    S: AsRef<[D]> + fmt::Debug,
{
    /// Total number of axes.
    pub fn rank(&self) -> usize {
        self.shape.as_ref().len()
    }

    /// Number of `H`/`W` axes, i.e. every axis except `N` and `C`.
    ///
    /// # Panics
    ///
    /// Panics if the shape has fewer than two axes.
    pub fn hw_rank(&self) -> usize {
        // A bare `len() - 2` only panics in debug builds; in release it wraps
        // to a huge value that silently corrupts `hw_axes`/`hw_dims`.
        self.rank()
            .checked_sub(2)
            .expect("DataShape needs at least an N and a C axis (rank >= 2)")
    }

    /// Index of the batch (`N`) axis, which is `0` in both layouts.
    pub fn n_axis(&self) -> usize {
        0
    }

    /// Index of the channel (`C`) axis: `1` for NCHW, the last axis for NHWC.
    ///
    /// # Panics
    ///
    /// Panics for an NHWC shape with no axes at all.
    pub fn c_axis(&self) -> usize {
        match self.fmt {
            // Same wraparound concern as `hw_rank` for an empty shape.
            DataFormat::NHWC => self
                .rank()
                .checked_sub(1)
                .expect("NHWC DataShape has no axes, so it has no C axis"),
            DataFormat::NCHW => 1,
        }
    }

    /// Index of the first `H`/`W` axis: `2` for NCHW, `1` for NHWC.
    pub fn h_axis(&self) -> usize {
        match self.fmt {
            DataFormat::NHWC => 1,
            DataFormat::NCHW => 2,
        }
    }

    /// Contiguous range of the `H`/`W` axis indices.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`DataShape::hw_rank`].
    pub fn hw_axes(&self) -> ::std::ops::Range<usize> {
        let start = self.h_axis();
        start..start + self.hw_rank()
    }

    /// Size of the batch (`N`) axis.
    ///
    /// # Panics
    ///
    /// Panics if the shape has no axes.
    pub fn n_dim(&self) -> D {
        // Returned by value: this presumes `DimLike` implies `Copy` (confirm in `crate::dim`).
        self.shape.as_ref()[self.n_axis()]
    }

    /// Size of the channel (`C`) axis.
    ///
    /// # Panics
    ///
    /// Panics if the shape is too short to have a `C` axis.
    pub fn c_dim(&self) -> D {
        self.shape.as_ref()[self.c_axis()]
    }

    /// Sizes of the `H`/`W` axes, in storage order.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`DataShape::hw_rank`].
    pub fn hw_dims(&self) -> &[D] {
        &self.shape.as_ref()[self.hw_axes()]
    }
}