1#[cfg(feature = "serialize")]
2use serde::{Deserialize, Serialize};
3use std::ops::{Index, IndexMut};
4
/// The per-axis extents of an n-dimensional array.
///
/// A shape with no axes describes a scalar, which holds exactly one element.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Shape {
    dims: Vec<usize>,
}

impl Shape {
    /// Creates a shape that takes ownership of the given per-axis extents.
    pub fn new(dims: Vec<usize>) -> Self {
        Self { dims }
    }

    /// Creates a shape by copying the given per-axis extents.
    pub fn from_slice(dims: &[usize]) -> Self {
        Self::new(dims.to_vec())
    }

    /// Returns the number of axes.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Returns the number of axes; identical to [`Shape::rank`].
    pub fn len(&self) -> usize {
        self.rank()
    }

    /// Returns `true` if the shape has no axes; identical to [`Shape::is_scalar`].
    ///
    /// Note this is about the number of axes, not elements: a shape such as
    /// `[0, 3]` has zero elements but is not empty.
    pub fn is_empty(&self) -> bool {
        self.dims.is_empty()
    }

    /// Returns the total number of elements, i.e. the product of all extents.
    ///
    /// A scalar shape (no axes) has one element; any zero extent yields zero.
    ///
    /// # Panics
    ///
    /// Panics if the element count does not fit in `usize`. Use
    /// [`Shape::checked_size`] to handle that case without panicking.
    pub fn size(&self) -> usize {
        self.checked_size().expect("shape element count overflows usize")
    }

    /// Returns the total number of elements, or `None` if it does not fit in
    /// `usize`.
    ///
    /// The result is exact: a shape containing a zero extent always reports
    /// `Some(0)`, even if the product of its other extents would overflow.
    pub fn checked_size(&self) -> Option<usize> {
        // Check for a zero extent first; otherwise an overflowing partial
        // product before the zero would wrongly report `None`.
        if self.dims.contains(&0) {
            return Some(0);
        }
        self.dims
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
    }

    /// Returns the total number of elements; identical to [`Shape::size`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Shape::size`].
    pub fn elements(&self) -> usize {
        self.size()
    }

    /// Returns the per-axis extents as a slice.
    pub fn dims(&self) -> &[usize] {
        &self.dims
    }

    /// Returns `true` if the shape has no axes (describes a scalar).
    pub fn is_scalar(&self) -> bool {
        self.dims.is_empty()
    }

    /// Returns `true` if `self` and `other` have the same rank and every pair
    /// of corresponding extents is equal or contains a `1`.
    ///
    /// Unlike [`Shape::broadcast_shape`], this does NOT right-align shapes of
    /// different ranks: shapes with differing ranks are always incompatible.
    pub fn is_compatible_with(&self, other: &Self) -> bool {
        self.rank() == other.rank()
            && self
                .dims
                .iter()
                .zip(&other.dims)
                .all(|(&a, &b)| a == b || a == 1 || b == 1)
    }

    /// Computes the shape obtained by broadcasting `self` against `other`.
    ///
    /// Shapes are aligned on their trailing axes, and missing leading axes are
    /// treated as extent `1`. For each aligned pair of extents, equal extents
    /// are kept. If one extent is `1`, the other is taken. Any other pair
    /// makes the shapes incompatible.
    ///
    /// Returns `None` when the shapes cannot be broadcast together.
    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
        let rank = self.rank().max(other.rank());
        // Walk both shapes from the last axis outwards, padding the shorter
        // one with 1s so that every aligned position has a pair.
        let lhs = self.dims.iter().rev().copied().chain(std::iter::repeat(1));
        let rhs = other.dims.iter().rev().copied().chain(std::iter::repeat(1));
        let mut dims = lhs
            .zip(rhs)
            .take(rank)
            .map(|(a, b)| match (a, b) {
                _ if a == b => Some(a),
                (1, d) | (d, 1) => Some(d),
                _ => None,
            })
            .collect::<Option<Vec<usize>>>()?;
        // Extents were produced last-axis-first; restore natural order.
        dims.reverse();
        Some(Self::new(dims))
    }

    /// Returns an iterator over the per-axis extents.
    pub fn iter(&self) -> std::slice::Iter<'_, usize> {
        self.dims.iter()
    }

    /// Returns an owned copy of the per-axis extents.
    pub fn to_vec(&self) -> Vec<usize> {
        self.dims.clone()
    }
}
90
91impl Index<usize> for Shape {
92 type Output = usize;
93
94 fn index(&self, index: usize) -> &Self::Output {
95 &self.dims[index]
96 }
97}
98
99impl IndexMut<usize> for Shape {
100 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
101 &mut self.dims[index]
102 }
103}
104
105impl std::fmt::Display for Shape {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 write!(f, "[")?;
108 for (i, dim) in self.dims.iter().enumerate() {
109 if i > 0 {
110 write!(f, ", ")?;
111 }
112 write!(f, "{dim}")?;
113 }
114 write!(f, "]")
115 }
116}