1use std::fmt;
5
6use itertools::Either;
7use vortex::error::VortexExpect;
8use vortex::error::VortexResult;
9use vortex::error::vortex_ensure;
10use vortex::error::vortex_ensure_eq;
11
/// Metadata describing a tensor with a fixed shape.
///
/// Holds the logical shape, optional per-dimension names, and an optional permutation that
/// describes how the logical dimensions are laid out physically.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FixedShapeTensorMetadata {
    /// Size of each logical dimension; its length is the tensor's number of dimensions.
    logical_shape: Vec<usize>,

    /// Optional name for each logical dimension. When present it has the same length as
    /// `logical_shape` (checked by `with_dim_names`).
    dim_names: Option<Vec<String>>,

    /// Optional logical-to-physical dimension mapping: logical dimension `l` is stored as
    /// physical dimension `permutation[l]`. When present it contains every index in
    /// `0..logical_shape.len()` exactly once (checked by `with_permutation`). `None` means the
    /// physical layout is row-major over `logical_shape`.
    permutation: Option<Vec<usize>>,
}
37
38impl FixedShapeTensorMetadata {
39 pub fn new(shape: Vec<usize>) -> Self {
44 Self {
45 logical_shape: shape,
46 dim_names: None,
47 permutation: None,
48 }
49 }
50
51 pub fn with_dim_names(mut self, names: Vec<String>) -> VortexResult<Self> {
56 if !names.is_empty() {
57 vortex_ensure_eq!(
58 names.len(),
59 self.logical_shape.len(),
60 "dim_names length ({}) must match logical_shape length ({})",
61 names.len(),
62 self.logical_shape.len()
63 );
64 self.dim_names = Some(names);
65 }
66
67 Ok(self)
68 }
69
70 pub fn with_permutation(mut self, permutation: Vec<usize>) -> VortexResult<Self> {
76 if !permutation.is_empty() {
77 vortex_ensure_eq!(
78 permutation.len(),
79 self.logical_shape.len(),
80 "permutation length ({}) must match logical_shape length ({})",
81 permutation.len(),
82 self.logical_shape.len()
83 );
84
85 let mut seen = vec![false; permutation.len()];
87 for &p in &permutation {
88 vortex_ensure!(
89 p < permutation.len(),
90 "permutation index {p} is out of range for {} dimensions",
91 permutation.len()
92 );
93 vortex_ensure!(!seen[p], "permutation contains duplicate index {p}");
94 seen[p] = true;
95 }
96
97 self.permutation = Some(permutation);
98 }
99
100 Ok(self)
101 }
102
103 pub fn ndim(&self) -> usize {
105 self.logical_shape.len()
106 }
107
108 pub fn logical_shape(&self) -> &[usize] {
110 &self.logical_shape
111 }
112
113 pub fn dim_names(&self) -> Option<&[String]> {
115 self.dim_names.as_deref()
116 }
117
118 pub fn permutation(&self) -> Option<&[usize]> {
120 self.permutation.as_deref()
121 }
122
123 pub fn physical_shape(&self) -> impl Iterator<Item = usize> + '_ {
131 let ndim = self.logical_shape.len();
132 let permutation = self.permutation.as_deref();
133
134 match permutation {
135 None => Either::Left(self.logical_shape.iter().copied()),
136 Some(perm) => Either::Right(
137 (0..ndim).map(move |p| self.logical_shape[Self::inverse_perm(perm, p)]),
138 ),
139 }
140 }
141
142 pub fn strides(&self) -> impl Iterator<Item = usize> + '_ {
151 let ndim = self.logical_shape.len();
152 let permutation = self.permutation.as_deref();
153
154 match permutation {
155 None => Either::Left(
156 (0..ndim).map(|i| self.logical_shape[i + 1..].iter().product::<usize>()),
157 ),
158 Some(permutation) => {
159 Either::Right((0..ndim).map(|i| self.permuted_stride(i, permutation)))
160 }
161 }
162 }
163
164 fn permuted_stride(&self, i: usize, perm: &[usize]) -> usize {
170 let phys = perm[i];
171
172 perm.iter()
175 .enumerate()
176 .filter(|&(_, &p)| p > phys)
177 .map(|(l, _)| self.logical_shape[l])
178 .product::<usize>()
179 }
180
181 fn inverse_perm(perm: &[usize], p: usize) -> usize {
188 perm.iter()
189 .position(|&pi| pi == p)
190 .vortex_expect("permutation must contain every physical position exactly once")
191 }
192}
193
194impl fmt::Display for FixedShapeTensorMetadata {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 write!(f, "Tensor(")?;
197
198 match &self.dim_names {
199 Some(names) => {
200 for (i, (dim, name)) in self.logical_shape.iter().zip(names.iter()).enumerate() {
201 if i > 0 {
202 write!(f, ", ")?;
203 }
204 write!(f, "{name}: {dim}")?;
205 }
206 }
207 None => {
208 for (i, dim) in self.logical_shape.iter().enumerate() {
209 if i > 0 {
210 write!(f, ", ")?;
211 }
212 write!(f, "{dim}")?;
213 }
214 }
215 }
216
217 if let Some(perm) = &self.permutation {
218 for (i, p) in perm.iter().enumerate() {
219 if i > 0 {
220 write!(f, ", ")?;
221 }
222 write!(f, "{p}")?;
223 }
224 write!(f, "]")?;
225 }
226
227 write!(f, ")")
228 }
229}
230
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Reference implementation of permuted strides, computed the long way.
    ///
    /// It materializes the physical shape (`physical_shape[perm[l]] = shape[l]`) and takes that
    /// shape's row-major strides. Each logical dimension then reads its stride from its physical
    /// position.
    fn slow_strides(shape: &[usize], perm: &[usize]) -> Vec<usize> {
        let ndim = shape.len();

        let mut physical_shape = vec![0usize; ndim];
        for (&size, &p) in shape.iter().zip(perm) {
            physical_shape[p] = size;
        }

        // Row-major: the last physical dimension has stride 1, each earlier one is the
        // stride of the next times the next dimension's size.
        let mut physical_strides = vec![1usize; ndim];
        for i in (0..ndim.saturating_sub(1)).rev() {
            physical_strides[i] = physical_strides[i + 1] * physical_shape[i + 1];
        }

        perm.iter().map(|&p| physical_strides[p]).collect()
    }

    /// Without a permutation, strides are the row-major strides of the logical shape.
    #[rstest]
    #[case::scalar_0d(vec![], vec![])]
    #[case::vector_1d(vec![5], vec![1])]
    #[case::matrix_2d(vec![3, 4], vec![4, 1])]
    #[case::tensor_3d(vec![2, 3, 4], vec![12, 4, 1])]
    #[case::zero_dim( vec![3, 0, 4], vec![0, 4, 1])]
    fn strides_row_major(#[case] shape: Vec<usize>, #[case] expected: Vec<usize>) {
        let m = FixedShapeTensorMetadata::new(shape);
        assert_eq!(m.strides().collect::<Vec<_>>(), expected);
    }

    /// Hand-computed permuted strides, cross-checked against `slow_strides`.
    #[rstest]
    // Physical shape [4, 3], physical strides [3, 1] -> logical strides [1, 3].
    #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![1, 3])]
    // Physical shape [3, 4, 2], physical strides [8, 2, 1] -> logical strides [1, 8, 2].
    #[case::perm_3d_201( vec![2, 3, 4], vec![2, 0, 1], vec![1, 8, 2])]
    // Physical shape [4, 3, 0], physical strides [0, 0, 1] -> logical strides [0, 1, 0].
    #[case::zero_dim( vec![3, 0, 4], vec![1, 2, 0], vec![0, 1, 0])]
    fn strides_permuted(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
        #[case] expected: Vec<usize>,
    ) -> VortexResult<()> {
        let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
        let actual: Vec<usize> = m.strides().collect();
        assert_eq!(actual, expected);
        assert_eq!(actual, slow_strides(&shape, &perm));
        Ok(())
    }

    #[test]
    fn strides_identity_permutation_matches_row_major() -> VortexResult<()> {
        let row_major = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        let identity =
            FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2])?;
        assert_eq!(
            row_major.strides().collect::<Vec<_>>(),
            identity.strides().collect::<Vec<_>>(),
        );
        Ok(())
    }

    /// Permuted strides agree with the slow reference across a spread of 3-D and 4-D cases.
    #[rstest]
    #[case::perm_3d_120(vec![2, 3, 4], vec![1, 2, 0])]
    #[case::perm_3d_021(vec![2, 3, 4], vec![0, 2, 1])]
    #[case::identity_3d(vec![2, 3, 4], vec![0, 1, 2])]
    #[case::zero_lead( vec![0, 3, 4], vec![2, 0, 1])]
    #[case::rev_4d( vec![2, 3, 4, 5], vec![3, 2, 1, 0])]
    #[case::swap_4d( vec![2, 3, 4, 5], vec![1, 0, 3, 2])]
    #[case::half_4d( vec![2, 3, 4, 5], vec![2, 3, 0, 1])]
    fn strides_match_slow_reference(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
    ) -> VortexResult<()> {
        let m = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
        assert_eq!(m.strides().collect::<Vec<_>>(), slow_strides(&shape, &perm));
        Ok(())
    }

    /// Without a permutation, the physical shape is the logical shape.
    #[test]
    fn physical_shape_no_permutation() {
        let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        assert_eq!(m.physical_shape().collect::<Vec<_>>(), vec![2, 3, 4]);
    }

    /// With a permutation, logical dimension `l` lands at physical position `perm[l]`.
    #[rstest]
    #[case::transpose_2d(vec![3, 4], vec![1, 0], vec![4, 3])]
    #[case::perm_3d( vec![2, 3, 4], vec![2, 0, 1], vec![3, 4, 2])]
    #[case::identity( vec![2, 3, 4], vec![0, 1, 2], vec![2, 3, 4])]
    #[case::zero_dim( vec![3, 0, 4], vec![1, 2, 0], vec![4, 3, 0])]
    fn physical_shape_permuted(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
        #[case] expected: Vec<usize>,
    ) -> VortexResult<()> {
        let m = FixedShapeTensorMetadata::new(shape).with_permutation(perm)?;
        assert_eq!(m.physical_shape().collect::<Vec<_>>(), expected);
        Ok(())
    }

    #[test]
    fn dim_names_matching_length_are_stored() -> VortexResult<()> {
        let m = FixedShapeTensorMetadata::new(vec![2, 3])
            .with_dim_names(vec!["x".into(), "y".into()])?;
        assert_eq!(m.dim_names(), Some(&["x".to_string(), "y".to_string()][..]));
        Ok(())
    }

    #[test]
    fn permutation_is_stored() -> VortexResult<()> {
        let m = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![1, 0])?;
        assert_eq!(m.permutation(), Some(&[1_usize, 0][..]));
        Ok(())
    }

    /// Empty names / permutation vectors are accepted and leave the metadata unchanged.
    #[test]
    fn empty_dim_names_and_permutation_are_no_ops() -> VortexResult<()> {
        let m = FixedShapeTensorMetadata::new(vec![2, 3])
            .with_dim_names(vec![])?
            .with_permutation(vec![])?;
        assert_eq!(m.ndim(), 2);
        assert_eq!(m.dim_names(), None);
        assert_eq!(m.permutation(), None);
        assert_eq!(m.strides().collect::<Vec<_>>(), vec![3, 1]);
        Ok(())
    }

    #[test]
    fn dim_names_wrong_length() {
        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]);
        assert!(result.is_err());
    }

    #[test]
    fn permutation_wrong_length() {
        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0]);
        assert!(result.is_err());
    }

    #[test]
    fn permutation_out_of_range() {
        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 5]);
        assert!(result.is_err());
    }

    #[test]
    fn permutation_duplicate_index() {
        let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 0]);
        assert!(result.is_err());
    }
}