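// ug/layout.rs
//! The shape of a tensor is a tuple with the size of each of its dimensions.
//! This module also defines `Layout`, which pairs a shape with one stride per dimension and a
//! starting offset.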
2use crate::{bail, Error, Result};
3
/// The size of each dimension of an n-dimensional value, e.g. `[2, 3]` for a 2x3 matrix.
/// An empty list of dimensions (rank 0) describes a scalar.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Shape(Vec<usize>);

/// The rank-0 shape of a scalar value.
pub const SCALAR: Shape = Shape(Vec::new());
8
9impl std::fmt::Debug for Shape {
10    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11        write!(f, "{:?}", &self.dims())
12    }
13}
14
15impl<const C: usize> From<&[usize; C]> for Shape {
16    fn from(dims: &[usize; C]) -> Self {
17        Self(dims.to_vec())
18    }
19}
20
21impl From<&[usize]> for Shape {
22    fn from(dims: &[usize]) -> Self {
23        Self(dims.to_vec())
24    }
25}
26
27impl From<&Shape> for Shape {
28    fn from(shape: &Shape) -> Self {
29        Self(shape.0.to_vec())
30    }
31}
32
33impl From<()> for Shape {
34    fn from(_: ()) -> Self {
35        Self(vec![])
36    }
37}
38
39impl From<usize> for Shape {
40    fn from(d1: usize) -> Self {
41        Self(vec![d1])
42    }
43}
44
45impl From<(usize,)> for Shape {
46    fn from(d1: (usize,)) -> Self {
47        Self(vec![d1.0])
48    }
49}
50
51impl From<(usize, usize)> for Shape {
52    fn from(d12: (usize, usize)) -> Self {
53        Self(vec![d12.0, d12.1])
54    }
55}
56
57impl From<(usize, usize, usize)> for Shape {
58    fn from(d123: (usize, usize, usize)) -> Self {
59        Self(vec![d123.0, d123.1, d123.2])
60    }
61}
62
63impl From<(usize, usize, usize, usize)> for Shape {
64    fn from(d1234: (usize, usize, usize, usize)) -> Self {
65        Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
66    }
67}
68
69impl From<(usize, usize, usize, usize, usize)> for Shape {
70    fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
71        Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
72    }
73}
74
75impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
76    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
77        Self(vec![d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5])
78    }
79}
80
81impl From<Vec<usize>> for Shape {
82    fn from(dims: Vec<usize>) -> Self {
83        Self(dims)
84    }
85}
86
/// Generates rank-specific dimension accessors.
///
/// For `extract_dims!(name, N, closure, Out)` this defines:
/// - a free function `name(&[usize]) -> Result<Out>` that fails unless the slice has exactly
///   `N` entries, and otherwise converts it with `closure`;
/// - a `Shape::name(&self)` method performing the same check on the shape's dimensions;
/// - a `TryInto<Out>` impl for `Shape` built on the same function.
macro_rules! extract_dims {
    ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
        pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
            if dims.len() == $cnt {
                return Ok($dims(dims));
            }
            bail!(
                "unexpected number of dims, expected {} got {} shape {:?}",
                $cnt,
                dims.len(),
                dims
            )
        }

        impl Shape {
            pub fn $fn_name(&self) -> Result<$out_type> {
                // Resolves to the free function above, not to this method.
                $fn_name(self.dims())
            }
        }

        impl std::convert::TryInto<$out_type> for Shape {
            type Error = crate::Error;
            fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
                $fn_name(&self.0)
            }
        }
    };
}
115
116impl Shape {
117    pub fn num_elements(&self) -> usize {
118        self.dims().iter().product()
119    }
120
121    pub fn from_dims(dims: &[usize]) -> Self {
122        Self(dims.to_vec())
123    }
124
125    /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
126    pub fn rank(&self) -> usize {
127        self.0.len()
128    }
129
130    pub fn into_dims(self) -> Vec<usize> {
131        self.0
132    }
133
134    /// The dimensions as a slice of `usize`.
135    pub fn dims(&self) -> &[usize] {
136        &self.0
137    }
138
139    /// The total number of elements, this is the product of all dimension sizes.
140    pub fn elem_count(&self) -> usize {
141        self.0.iter().product()
142    }
143
144    /// The strides given in number of elements for a contiguous n-dimensional
145    /// arrays using this shape.
146    pub fn stride_contiguous(&self) -> Vec<usize> {
147        let mut stride: Vec<_> = self
148            .0
149            .iter()
150            .rev()
151            .scan(1, |prod, u| {
152                let prod_pre_mult = *prod;
153                *prod *= u;
154                Some(prod_pre_mult)
155            })
156            .collect();
157        stride.reverse();
158        stride
159    }
160
161    /// Returns true if the strides are C contiguous (aka row major).
162    pub fn is_contiguous(&self, stride: &[usize]) -> bool {
163        if self.0.len() != stride.len() {
164            return false;
165        }
166        let mut acc = 1;
167        for (&stride, &dim) in stride.iter().zip(self.0.iter()).rev() {
168            if dim > 1 && stride != acc {
169                return false;
170            }
171            acc *= dim;
172        }
173        true
174    }
175
176    /// Returns true if the strides are Fortran contiguous (aka column major).
177    pub fn is_fortran_contiguous(&self, stride: &[usize]) -> bool {
178        if self.0.len() != stride.len() {
179            return false;
180        }
181        let mut acc = 1;
182        for (&stride, &dim) in stride.iter().zip(self.0.iter()) {
183            if dim > 1 && stride != acc {
184                return false;
185            }
186            acc *= dim;
187        }
188        true
189    }
190
191    /// Modifies the shape by adding a list of additional dimensions at the end of the existing
192    /// dimensions.
193    pub fn extend(mut self, additional_dims: &[usize]) -> Self {
194        self.0.extend(additional_dims);
195        self
196    }
197
198    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
199    /// broadcasted shape. This is to be used for binary pointwise ops.
200    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
201        let lhs = self;
202        let lhs_dims = lhs.dims();
203        let rhs_dims = rhs.dims();
204        let lhs_ndims = lhs_dims.len();
205        let rhs_ndims = rhs_dims.len();
206        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
207        let mut bcast_dims = vec![0; bcast_ndims];
208        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
209            let rev_idx = bcast_ndims - idx;
210            let l_value = if lhs_ndims < rev_idx { 1 } else { lhs_dims[lhs_ndims - rev_idx] };
211            let r_value = if rhs_ndims < rev_idx { 1 } else { rhs_dims[rhs_ndims - rev_idx] };
212            *bcast_value = if l_value == r_value {
213                l_value
214            } else if l_value == 1 {
215                r_value
216            } else if r_value == 1 {
217                l_value
218            } else {
219                bail!("shape mismatch in binary op '{op}', lhs: {lhs:?} rhs: {rhs:?}")
220            }
221        }
222        Ok(Shape::from(bcast_dims))
223    }
224}
225
226pub trait Dim {
227    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
228    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
229}
230
231impl Dim for usize {
232    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
233        let dim = *self;
234        if dim >= shape.dims().len() {
235            bail!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")
236        }
237        Ok(dim)
238    }
239
240    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
241        let dim = *self;
242        if dim > shape.dims().len() {
243            bail!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")
244        }
245        Ok(dim)
246    }
247}
248
249impl Dim for i32 {
250    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
251        let dim = *self;
252        if dim >= 0 {
253            (dim as usize).to_index(shape, op)
254        } else {
255            D::Minus((-dim) as usize).to_index(shape, op)
256        }
257    }
258
259    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
260        let dim = *self;
261        if dim >= 0 {
262            (dim as usize).to_index_plus_one(shape, op)
263        } else {
264            D::Minus((-dim) as usize).to_index_plus_one(shape, op)
265        }
266    }
267}
268
/// A dimension counted back from the end of a shape.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum D {
    /// The last dimension.
    Minus1,
    /// The second-to-last dimension.
    Minus2,
    /// The `n`-th dimension from the end: `Minus(1)` is the last one. `Minus(0)` never
    /// resolves to a valid index.
    Minus(usize),
}
275
276impl D {
277    fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
278        let dim = match self {
279            Self::Minus1 => -1,
280            Self::Minus2 => -2,
281            Self::Minus(u) => -(*u as i32),
282        };
283        Error::Msg(format!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")).bt()
284    }
285}
286
287impl Dim for D {
288    fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
289        let rank = shape.rank();
290        match self {
291            Self::Minus1 if rank >= 1 => Ok(rank - 1),
292            Self::Minus2 if rank >= 2 => Ok(rank - 2),
293            Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
294            _ => Err(self.out_of_range(shape, op)),
295        }
296    }
297
298    fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
299        let rank = shape.rank();
300        match self {
301            Self::Minus1 => Ok(rank),
302            Self::Minus2 if rank >= 1 => Ok(rank - 1),
303            Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
304            _ => Err(self.out_of_range(shape, op)),
305        }
306    }
307}
308
309pub trait Dims: Sized {
310    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
311
312    fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
313        let dims = self.to_indexes_internal(shape, op)?;
314        for (i, &dim) in dims.iter().enumerate() {
315            if dims[..i].contains(&dim) {
316                bail!("duplicate dim indexes in '{op}', dims: {dims:?}, shape: {shape:?}")
317            }
318            if dim >= shape.rank() {
319                bail!("dim out of range in '{op}', dim: {dim}, shape: {shape:?}")
320            }
321        }
322        Ok(dims)
323    }
324}
325
326impl Dims for Vec<usize> {
327    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
328        Ok(self)
329    }
330}
331
332impl<const N: usize> Dims for [usize; N] {
333    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
334        Ok(self.to_vec())
335    }
336}
337
338impl Dims for &[usize] {
339    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
340        Ok(self.to_vec())
341    }
342}
343
344impl Dims for () {
345    fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
346        Ok(vec![])
347    }
348}
349
350impl<D: Dim + Sized> Dims for D {
351    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
352        let dim = self.to_index(shape, op)?;
353        Ok(vec![dim])
354    }
355}
356
357impl<D: Dim> Dims for (D,) {
358    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
359        let dim = self.0.to_index(shape, op)?;
360        Ok(vec![dim])
361    }
362}
363
364impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
365    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
366        let d0 = self.0.to_index(shape, op)?;
367        let d1 = self.1.to_index(shape, op)?;
368        Ok(vec![d0, d1])
369    }
370}
371
372impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
373    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
374        let d0 = self.0.to_index(shape, op)?;
375        let d1 = self.1.to_index(shape, op)?;
376        let d2 = self.2.to_index(shape, op)?;
377        Ok(vec![d0, d1, d2])
378    }
379}
380
381impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
382    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
383        let d0 = self.0.to_index(shape, op)?;
384        let d1 = self.1.to_index(shape, op)?;
385        let d2 = self.2.to_index(shape, op)?;
386        let d3 = self.3.to_index(shape, op)?;
387        Ok(vec![d0, d1, d2, d3])
388    }
389}
390
391impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
392    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
393        let d0 = self.0.to_index(shape, op)?;
394        let d1 = self.1.to_index(shape, op)?;
395        let d2 = self.2.to_index(shape, op)?;
396        let d3 = self.3.to_index(shape, op)?;
397        let d4 = self.4.to_index(shape, op)?;
398        Ok(vec![d0, d1, d2, d3, d4])
399    }
400}
401
402impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
403    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
404        let d0 = self.0.to_index(shape, op)?;
405        let d1 = self.1.to_index(shape, op)?;
406        let d2 = self.2.to_index(shape, op)?;
407        let d3 = self.3.to_index(shape, op)?;
408        let d4 = self.4.to_index(shape, op)?;
409        let d5 = self.5.to_index(shape, op)?;
410        Ok(vec![d0, d1, d2, d3, d4, d5])
411    }
412}
413
414extract_dims!(dims0, 0, |_: &[usize]| (), ());
415extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
416extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
417extract_dims!(dims3, 3, |d: &[usize]| (d[0], d[1], d[2]), (usize, usize, usize));
418extract_dims!(dims4, 4, |d: &[usize]| (d[0], d[1], d[2], d[3]), (usize, usize, usize, usize));
419extract_dims!(
420    dims5,
421    5,
422    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
423    (usize, usize, usize, usize, usize)
424);
425
426pub trait ShapeWithOneHole {
427    fn into_shape(self, el_count: usize) -> Result<Shape>;
428}
429
430impl<S: Into<Shape>> ShapeWithOneHole for S {
431    fn into_shape(self, _el_count: usize) -> Result<Shape> {
432        Ok(self.into())
433    }
434}
435
436impl ShapeWithOneHole for ((),) {
437    fn into_shape(self, el_count: usize) -> Result<Shape> {
438        Ok(el_count.into())
439    }
440}
441
442fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
443    if prod_d == 0 {
444        bail!("cannot reshape tensor of {el_count} elements to {s:?}")
445    }
446    if el_count % prod_d != 0 {
447        bail!("cannot reshape tensor with {el_count} elements to {s:?}")
448    }
449    Ok(el_count / prod_d)
450}
451
452impl ShapeWithOneHole for ((), usize) {
453    fn into_shape(self, el_count: usize) -> Result<Shape> {
454        let ((), d1) = self;
455        Ok((hole_size(el_count, d1, &self)?, d1).into())
456    }
457}
458
459impl ShapeWithOneHole for (usize, ()) {
460    fn into_shape(self, el_count: usize) -> Result<Shape> {
461        let (d1, ()) = self;
462        Ok((d1, hole_size(el_count, d1, &self)?).into())
463    }
464}
465
466impl ShapeWithOneHole for ((), usize, usize) {
467    fn into_shape(self, el_count: usize) -> Result<Shape> {
468        let ((), d1, d2) = self;
469        Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
470    }
471}
472
473impl ShapeWithOneHole for (usize, (), usize) {
474    fn into_shape(self, el_count: usize) -> Result<Shape> {
475        let (d1, (), d2) = self;
476        Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
477    }
478}
479
480impl ShapeWithOneHole for (usize, usize, ()) {
481    fn into_shape(self, el_count: usize) -> Result<Shape> {
482        let (d1, d2, ()) = self;
483        Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
484    }
485}
486
487impl ShapeWithOneHole for ((), usize, usize, usize) {
488    fn into_shape(self, el_count: usize) -> Result<Shape> {
489        let ((), d1, d2, d3) = self;
490        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
491        Ok((d, d1, d2, d3).into())
492    }
493}
494
495impl ShapeWithOneHole for (usize, (), usize, usize) {
496    fn into_shape(self, el_count: usize) -> Result<Shape> {
497        let (d1, (), d2, d3) = self;
498        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
499        Ok((d1, d, d2, d3).into())
500    }
501}
502
503impl ShapeWithOneHole for (usize, usize, (), usize) {
504    fn into_shape(self, el_count: usize) -> Result<Shape> {
505        let (d1, d2, (), d3) = self;
506        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
507        Ok((d1, d2, d, d3).into())
508    }
509}
510
511impl ShapeWithOneHole for (usize, usize, usize, ()) {
512    fn into_shape(self, el_count: usize) -> Result<Shape> {
513        let (d1, d2, d3, ()) = self;
514        let d = hole_size(el_count, d1 * d2 * d3, &self)?;
515        Ok((d1, d2, d3, d).into())
516    }
517}
518
519impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
520    fn into_shape(self, el_count: usize) -> Result<Shape> {
521        let ((), d1, d2, d3, d4) = self;
522        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
523        Ok((d, d1, d2, d3, d4).into())
524    }
525}
526
527impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
528    fn into_shape(self, el_count: usize) -> Result<Shape> {
529        let (d1, (), d2, d3, d4) = self;
530        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
531        Ok((d1, d, d2, d3, d4).into())
532    }
533}
534
535impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
536    fn into_shape(self, el_count: usize) -> Result<Shape> {
537        let (d1, d2, (), d3, d4) = self;
538        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
539        Ok((d1, d2, d, d3, d4).into())
540    }
541}
542
543impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
544    fn into_shape(self, el_count: usize) -> Result<Shape> {
545        let (d1, d2, d3, (), d4) = self;
546        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
547        Ok((d1, d2, d3, d, d4).into())
548    }
549}
550
551impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
552    fn into_shape(self, el_count: usize) -> Result<Shape> {
553        let (d1, d2, d3, d4, ()) = self;
554        let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
555        Ok((d1, d2, d3, d4, d).into())
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// `stride_contiguous` yields row-major strides: the last dimension has stride 1, and
    /// every other stride is the product of the sizes of the dimensions after it.
    #[test]
    fn stride() {
        let cases: Vec<(Shape, Vec<usize>)> = vec![
            (Shape::from(()), vec![]),
            (Shape::from(42), vec![1]),
            (Shape::from((42, 1337)), vec![1337, 1]),
            (Shape::from((299, 792, 458)), vec![458 * 792, 458, 1]),
        ];
        for (shape, expected) in &cases {
            assert_eq!(shape.stride_contiguous(), *expected);
        }
    }
}
575
576#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
577pub struct Layout {
578    shape: Shape,
579    strides: Vec<usize>,
580    // TODO: We only have a single offset here, maybe we should have one offset per dimension?
581    offset: usize,
582}
583
584impl Layout {
585    pub fn can_be_compressed(&self) -> bool {
586        let strides = self.strides();
587        let dims = self.dims();
588        if dims.len() <= 1 {
589            return false;
590        }
591        for i in 0..dims.len() - 1 {
592            if strides[i] != strides[i + 1] * dims[i + 1] {
593                return false;
594            }
595        }
596        true
597    }
598
599    pub fn compress_all(&self) -> Result<Self> {
600        let strides = self.strides();
601        let dims = self.dims();
602        for i in 0..dims.len() - 1 {
603            if strides[i] != strides[i + 1] * dims[i + 1] {
604                bail!("cannot collapse dims, {self:?}")
605            }
606        }
607        let stride = strides.last().copied().unwrap_or(1);
608        let dim = self.num_elements();
609        Ok(Self { shape: Shape::from(dim), strides: vec![stride], offset: self.offset() })
610    }
611
612    pub fn from_shape<S: Into<Shape>>(shape: S) -> Self {
613        let shape = shape.into();
614        let mut strides = vec![];
615        let mut stride = 1;
616        for l in shape.dims().iter().rev() {
617            strides.push(stride);
618            stride *= l
619        }
620        strides.reverse();
621        Self { shape, strides, offset: 0 }
622    }
623
624    pub fn transpose(&self) -> Self {
625        let r = self.rank();
626        if r < 2 {
627            return self.clone();
628        }
629        let mut dims = self.dims().to_vec();
630        let mut strides = self.strides.to_vec();
631        dims.swap(r - 2, r - 1);
632        strides.swap(r - 2, r - 1);
633        Self { shape: dims.into(), offset: self.offset, strides }
634    }
635
636    pub fn num_elements(&self) -> usize {
637        self.shape.num_elements()
638    }
639
640    pub fn shape(&self) -> &Shape {
641        &self.shape
642    }
643
644    pub fn dims(&self) -> &[usize] {
645        self.shape.dims()
646    }
647
648    pub fn rank(&self) -> usize {
649        self.shape.rank()
650    }
651
652    pub fn strides(&self) -> &[usize] {
653        self.strides.as_slice()
654    }
655
656    pub fn offset(&self) -> usize {
657        self.offset
658    }
659
660    pub fn set_offset(&mut self, offset: usize) {
661        self.offset = offset
662    }
663
664    pub fn c_contiguous(&self) -> bool {
665        let mut prod_l = 1;
666        for (&s, &l) in self.strides.iter().zip(self.shape.dims().iter()).rev() {
667            if s != prod_l {
668                return false;
669            }
670            prod_l *= l
671        }
672        true
673    }
674}