// tract_core/ops/nn/data_formats.rs

1use crate::internal::*;
2use std::fmt;
3use tract_itertools::Itertools;
4
/// Axis layout of a data tensor.
///
/// `N` is the batch axis, `C` the channel axis and `HW` stands for the spatial axes (any
/// number of them). The `CHW` and `HWC` formats have no batch axis.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum DataFormat {
    /// Batch axis, then channel axis, then spatial axes. This is the default format.
    #[default]
    NCHW,
    /// Batch axis, then spatial axes, then channel axis.
    NHWC,
    /// Channel axis, then spatial axes; no batch axis.
    CHW,
    /// Spatial axes, then channel axis; no batch axis.
    HWC,
}
13
14impl DataFormat {
15    pub fn dispose_n_axis(&self) -> DataFormat {
16        match self {
17            DataFormat::NCHW => DataFormat::CHW,
18            DataFormat::NHWC => DataFormat::HWC,
19            _ => panic!("Attempt at removing N axis on {self:?}"),
20        }
21    }
22
23    pub fn shape<D, S>(&self, shape: S) -> TractResult<BaseDataShape<D, S>>
24    where
25        D: DimLike,
26        S: AsRef<[D]> + fmt::Debug,
27    {
28        let mut strides: TVec<D> = tvec![D::one()];
29        for dim in shape.as_ref().iter().skip(1).rev() {
30            let previous = strides.last().unwrap().clone();
31            strides.push(previous * dim)
32        }
33        strides.reverse();
34        Ok(BaseDataShape { fmt: *self, shape, strides })
35    }
36
37    pub fn from_n_c_hw<D, S>(&self, n: D, c: D, shape: S) -> TractResult<BaseDataShape<D, TVec<D>>>
38    where
39        D: DimLike,
40        S: AsRef<[D]> + fmt::Debug,
41    {
42        let mut me = tvec!();
43        if *self == DataFormat::NCHW || *self == DataFormat::NHWC {
44            me.push(n);
45        }
46        if *self == DataFormat::NCHW || *self == DataFormat::CHW {
47            me.push(c.clone());
48        }
49        me.extend(shape.as_ref().iter().cloned());
50        if *self == DataFormat::NHWC || *self == DataFormat::HWC {
51            me.push(c);
52        }
53        self.shape(me)
54    }
55
56    pub fn has_n(&self) -> bool {
57        *self == DataFormat::NHWC || *self == DataFormat::NCHW
58    }
59
60    pub fn c_is_last(&self) -> bool {
61        *self == DataFormat::NHWC || *self == DataFormat::HWC
62    }
63
64    pub fn h_axis(&self) -> usize {
65        self.has_n() as usize + !self.c_is_last() as usize
66    }
67
68    pub fn with_n(&self) -> DataFormat {
69        match self {
70            DataFormat::CHW => DataFormat::NCHW,
71            DataFormat::HWC => DataFormat::NHWC,
72            _ => *self,
73        }
74    }
75}
76
77pub type SymDataShape = BaseDataShape<TDim, TVec<TDim>>;
78pub type DataShape = BaseDataShape<usize, TVec<usize>>;
79
80#[derive(Clone, PartialEq, Eq, Hash)]
81pub struct BaseDataShape<D, S>
82where
83    D: DimLike,
84    S: AsRef<[D]> + fmt::Debug,
85{
86    pub fmt: DataFormat,
87    pub shape: S,
88    pub strides: TVec<D>,
89}
90
91impl<D, S> fmt::Debug for BaseDataShape<D, S>
92where
93    D: DimLike,
94    S: AsRef<[D]> + fmt::Debug,
95{
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        write!(
98            f,
99            "{:?} {} (strides: {})",
100            self.fmt,
101            self.shape.as_ref().iter().join(","),
102            self.strides.iter().join(",")
103        )
104    }
105}
106
107impl<D, S> BaseDataShape<D, S>
108where
109    D: DimLike,
110    S: AsRef<[D]> + fmt::Debug,
111{
112    #[inline]
113    pub fn rank(&self) -> usize {
114        self.shape.as_ref().len()
115    }
116
117    #[inline]
118    pub fn hw_rank(&self) -> usize {
119        self.rank() - 1 - self.n_axis().is_some() as usize
120    }
121
122    #[inline]
123    pub fn n_axis(&self) -> Option<usize> {
124        match self.fmt {
125            DataFormat::NHWC | DataFormat::NCHW => Some(0),
126            DataFormat::HWC | DataFormat::CHW => None,
127        }
128    }
129
130    #[inline]
131    pub fn c_axis(&self) -> usize {
132        match self.fmt {
133            DataFormat::NHWC | DataFormat::HWC => self.shape.as_ref().len() - 1,
134            DataFormat::NCHW => 1,
135            DataFormat::CHW => 0,
136        }
137    }
138
139    #[inline]
140    pub fn h_axis(&self) -> usize {
141        match self.fmt {
142            DataFormat::HWC => 0,
143            DataFormat::NHWC | DataFormat::CHW => 1,
144            DataFormat::NCHW => 2,
145        }
146    }
147
148    #[inline]
149    pub fn hw_axes(&self) -> ::std::ops::Range<usize> {
150        self.h_axis()..self.h_axis() + self.hw_rank()
151    }
152
153    #[inline]
154    pub fn n_dim(&self) -> Option<&D> {
155        self.n()
156    }
157
158    #[inline]
159    pub fn c_dim(&self) -> &D {
160        self.c()
161    }
162
163    #[inline]
164    pub fn hw_dims(&self) -> &[D] {
165        unsafe { self.shape.as_ref().get_unchecked(self.hw_axes()) }
166    }
167
168    #[inline]
169    pub fn n(&self) -> Option<&D> {
170        unsafe { self.n_axis().map(|axis| self.shape.as_ref().get_unchecked(axis)) }
171    }
172
173    #[inline]
174    pub fn c(&self) -> &D {
175        unsafe { self.shape.as_ref().get_unchecked(self.c_axis()) }
176    }
177
178    #[inline]
179    pub fn n_stride(&self) -> Option<&D> {
180        unsafe { self.n_axis().map(|axis| self.strides.get_unchecked(axis)) }
181    }
182
183    #[inline]
184    pub fn h_stride(&self) -> &D {
185        unsafe { self.hw_strides().get_unchecked(0) }
186    }
187
188    #[inline]
189    pub fn hw_strides(&self) -> &[D] {
190        unsafe { self.strides.get_unchecked(self.hw_axes()) }
191    }
192
193    #[inline]
194    pub fn w_stride(&self) -> &D {
195        unsafe { self.hw_strides().get_unchecked(self.hw_rank() - 1) }
196    }
197
198    #[inline]
199    pub fn c_stride(&self) -> &D {
200        unsafe { self.strides.get_unchecked(self.c_axis()) }
201    }
202}