// tract_core/ops/nn/data_formats.rs
use crate::internal::*;
use std::fmt;
use tract_itertools::Itertools;

/// Axis layout of a tensor, naming which axis plays which role.
///
/// The letters of each variant give the order of the axes: `N` (when present)
/// comes first, `C` is a single axis, and `HW` stands for every remaining
/// axis. Presumably these follow the usual neural-network convention of batch,
/// channels and spatial axes.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DataFormat {
    /// `N` axis, then `C`, then the `HW` axes. This is the default format.
    #[default]
    NCHW,
    /// `N` axis, then the `HW` axes, then `C` last.
    NHWC,
    /// `C` axis, then the `HW` axes; no `N` axis.
    CHW,
    /// The `HW` axes, then `C` last; no `N` axis.
    HWC,
}

14impl DataFormat {
15 pub fn dispose_n_axis(&self) -> DataFormat {
16 match self {
17 DataFormat::NCHW => DataFormat::CHW,
18 DataFormat::NHWC => DataFormat::HWC,
19 _ => panic!("Attempt at removing N axis on {self:?}"),
20 }
21 }
22
23 pub fn shape<D, S>(&self, shape: S) -> TractResult<BaseDataShape<D, S>>
24 where
25 D: DimLike,
26 S: AsRef<[D]> + fmt::Debug,
27 {
28 let mut strides: TVec<D> = tvec![D::one()];
29 for dim in shape.as_ref().iter().skip(1).rev() {
30 let previous = strides.last().unwrap().clone();
31 strides.push(previous * dim)
32 }
33 strides.reverse();
34 Ok(BaseDataShape { fmt: *self, shape, strides })
35 }
36
37 pub fn from_n_c_hw<D, S>(&self, n: D, c: D, shape: S) -> TractResult<BaseDataShape<D, TVec<D>>>
38 where
39 D: DimLike,
40 S: AsRef<[D]> + fmt::Debug,
41 {
42 let mut me = tvec!();
43 if *self == DataFormat::NCHW || *self == DataFormat::NHWC {
44 me.push(n);
45 }
46 if *self == DataFormat::NCHW || *self == DataFormat::CHW {
47 me.push(c.clone());
48 }
49 me.extend(shape.as_ref().iter().cloned());
50 if *self == DataFormat::NHWC || *self == DataFormat::HWC {
51 me.push(c);
52 }
53 self.shape(me)
54 }
55
56 pub fn has_n(&self) -> bool {
57 *self == DataFormat::NHWC || *self == DataFormat::NCHW
58 }
59
60 pub fn c_is_last(&self) -> bool {
61 *self == DataFormat::NHWC || *self == DataFormat::HWC
62 }
63
64 pub fn h_axis(&self) -> usize {
65 self.has_n() as usize + !self.c_is_last() as usize
66 }
67
68 pub fn with_n(&self) -> DataFormat {
69 match self {
70 DataFormat::CHW => DataFormat::NCHW,
71 DataFormat::HWC => DataFormat::NHWC,
72 _ => *self,
73 }
74 }
75}
76
77pub type SymDataShape = BaseDataShape<TDim, TVec<TDim>>;
78pub type DataShape = BaseDataShape<usize, TVec<usize>>;
79
80#[derive(Clone, PartialEq, Eq, Hash)]
81pub struct BaseDataShape<D, S>
82where
83 D: DimLike,
84 S: AsRef<[D]> + fmt::Debug,
85{
86 pub fmt: DataFormat,
87 pub shape: S,
88 pub strides: TVec<D>,
89}
90
91impl<D, S> fmt::Debug for BaseDataShape<D, S>
92where
93 D: DimLike,
94 S: AsRef<[D]> + fmt::Debug,
95{
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 write!(
98 f,
99 "{:?} {} (strides: {})",
100 self.fmt,
101 self.shape.as_ref().iter().join(","),
102 self.strides.iter().join(",")
103 )
104 }
105}
106
107impl<D, S> BaseDataShape<D, S>
108where
109 D: DimLike,
110 S: AsRef<[D]> + fmt::Debug,
111{
112 #[inline]
113 pub fn rank(&self) -> usize {
114 self.shape.as_ref().len()
115 }
116
117 #[inline]
118 pub fn hw_rank(&self) -> usize {
119 self.rank() - 1 - self.n_axis().is_some() as usize
120 }
121
122 #[inline]
123 pub fn n_axis(&self) -> Option<usize> {
124 match self.fmt {
125 DataFormat::NHWC | DataFormat::NCHW => Some(0),
126 DataFormat::HWC | DataFormat::CHW => None,
127 }
128 }
129
130 #[inline]
131 pub fn c_axis(&self) -> usize {
132 match self.fmt {
133 DataFormat::NHWC | DataFormat::HWC => self.shape.as_ref().len() - 1,
134 DataFormat::NCHW => 1,
135 DataFormat::CHW => 0,
136 }
137 }
138
139 #[inline]
140 pub fn h_axis(&self) -> usize {
141 match self.fmt {
142 DataFormat::HWC => 0,
143 DataFormat::NHWC | DataFormat::CHW => 1,
144 DataFormat::NCHW => 2,
145 }
146 }
147
148 #[inline]
149 pub fn hw_axes(&self) -> ::std::ops::Range<usize> {
150 self.h_axis()..self.h_axis() + self.hw_rank()
151 }
152
153 #[inline]
154 pub fn n_dim(&self) -> Option<&D> {
155 self.n()
156 }
157
158 #[inline]
159 pub fn c_dim(&self) -> &D {
160 self.c()
161 }
162
163 #[inline]
164 pub fn hw_dims(&self) -> &[D] {
165 unsafe { self.shape.as_ref().get_unchecked(self.hw_axes()) }
166 }
167
168 #[inline]
169 pub fn n(&self) -> Option<&D> {
170 unsafe { self.n_axis().map(|axis| self.shape.as_ref().get_unchecked(axis)) }
171 }
172
173 #[inline]
174 pub fn c(&self) -> &D {
175 unsafe { self.shape.as_ref().get_unchecked(self.c_axis()) }
176 }
177
178 #[inline]
179 pub fn n_stride(&self) -> Option<&D> {
180 unsafe { self.n_axis().map(|axis| self.strides.get_unchecked(axis)) }
181 }
182
183 #[inline]
184 pub fn h_stride(&self) -> &D {
185 unsafe { self.hw_strides().get_unchecked(0) }
186 }
187
188 #[inline]
189 pub fn hw_strides(&self) -> &[D] {
190 unsafe { self.strides.get_unchecked(self.hw_axes()) }
191 }
192
193 #[inline]
194 pub fn w_stride(&self) -> &D {
195 unsafe { self.hw_strides().get_unchecked(self.hw_rank() - 1) }
196 }
197
198 #[inline]
199 pub fn c_stride(&self) -> &D {
200 unsafe { self.strides.get_unchecked(self.c_axis()) }
201 }
202}