Skip to main content

tenflowers_core/
shape.rs

1#[cfg(feature = "serialize")]
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
/// Dimensions of a tensor, stored as one extent per axis.
///
/// An empty dimension list represents a scalar (rank 0).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Shape {
    /// Extent of each axis, in order; `dims.len()` is the rank.
    dims: Vec<usize>,
}
10
11impl Shape {
12    pub fn new(dims: Vec<usize>) -> Self {
13        Self { dims }
14    }
15
16    pub fn from_slice(dims: &[usize]) -> Self {
17        Self {
18            dims: dims.to_vec(),
19        }
20    }
21
22    pub fn rank(&self) -> usize {
23        self.dims.len()
24    }
25
26    pub fn len(&self) -> usize {
27        self.dims.len()
28    }
29
30    pub fn is_empty(&self) -> bool {
31        self.dims.is_empty()
32    }
33
34    pub fn size(&self) -> usize {
35        self.dims.iter().product()
36    }
37
38    pub fn elements(&self) -> usize {
39        self.size()
40    }
41
42    pub fn dims(&self) -> &[usize] {
43        &self.dims
44    }
45
46    pub fn is_scalar(&self) -> bool {
47        self.dims.is_empty()
48    }
49
50    pub fn is_compatible_with(&self, other: &Self) -> bool {
51        if self.rank() != other.rank() {
52            return false;
53        }
54        self.dims
55            .iter()
56            .zip(&other.dims)
57            .all(|(a, b)| *a == *b || *a == 1 || *b == 1)
58    }
59
60    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
61        let rank = self.rank().max(other.rank());
62        let mut result = vec![1; rank];
63
64        for i in 0..self.rank() {
65            result[rank - self.rank() + i] = self.dims[i];
66        }
67
68        for i in 0..other.rank() {
69            let idx = rank - other.rank() + i;
70            if result[idx] == 1 {
71                result[idx] = other.dims[i];
72            } else if other.dims[i] != 1 && result[idx] != other.dims[i] {
73                return None;
74            }
75        }
76
77        Some(Self::new(result))
78    }
79
80    /// Get an iterator over the dimensions
81    pub fn iter(&self) -> std::slice::Iter<'_, usize> {
82        self.dims.iter()
83    }
84
85    /// Convert dimensions to a vector
86    pub fn to_vec(&self) -> Vec<usize> {
87        self.dims.clone()
88    }
89}
90
91impl Index<usize> for Shape {
92    type Output = usize;
93
94    fn index(&self, index: usize) -> &Self::Output {
95        &self.dims[index]
96    }
97}
98
99impl IndexMut<usize> for Shape {
100    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
101        &mut self.dims[index]
102    }
103}
104
105impl std::fmt::Display for Shape {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        write!(f, "[")?;
108        for (i, dim) in self.dims.iter().enumerate() {
109            if i > 0 {
110                write!(f, ", ")?;
111            }
112            write!(f, "{dim}")?;
113        }
114        write!(f, "]")
115    }
116}