1use crate::{XdlError, MAXRANK};
4use serde::{Deserialize, Serialize};
5
/// Shape of an array: the size of each axis, in order.
///
/// An empty shape (rank 0) denotes a scalar with exactly one element.
/// `Dimension::from_vec` rejects zero-sized axes and more than `MAXRANK` axes.
///
/// NOTE(review): the derived `Deserialize` assigns `dimensions` directly and
/// does not run the `from_vec` checks, so a deserialized value may contain
/// zero-sized axes or exceed `MAXRANK` — confirm the input is trusted or add
/// validation after deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Dimension {
    /// Per-axis sizes; empty for a scalar.
    dimensions: Vec<usize>,
}
11
12impl Dimension {
13 pub fn scalar() -> Self {
15 Self { dimensions: vec![] }
16 }
17
18 pub fn from_vec(dims: Vec<usize>) -> Result<Self, XdlError> {
20 if dims.len() > MAXRANK {
21 return Err(XdlError::DimensionError(format!(
22 "Too many dimensions: {} > {}",
23 dims.len(),
24 MAXRANK
25 )));
26 }
27
28 if dims.contains(&0) {
29 return Err(XdlError::DimensionError(
30 "Zero dimensions not allowed".to_string(),
31 ));
32 }
33
34 Ok(Self { dimensions: dims })
35 }
36
37 pub fn from_size(size: usize) -> Result<Self, XdlError> {
39 if size == 0 {
40 return Err(XdlError::DimensionError(
41 "Zero size not allowed".to_string(),
42 ));
43 }
44 Ok(Self {
45 dimensions: vec![size],
46 })
47 }
48
49 pub fn rank(&self) -> usize {
51 self.dimensions.len()
52 }
53
54 pub fn dims(&self) -> &[usize] {
56 &self.dimensions
57 }
58
59 pub fn dim(&self, index: usize) -> Option<usize> {
61 self.dimensions.get(index).copied()
62 }
63
64 pub fn n_elements(&self) -> usize {
66 if self.dimensions.is_empty() {
67 1 } else {
69 self.dimensions.iter().product()
70 }
71 }
72
73 pub fn is_scalar(&self) -> bool {
75 self.dimensions.is_empty()
76 }
77
78 pub fn is_vector(&self) -> bool {
80 self.dimensions.len() == 1
81 }
82
83 pub fn linear_index(&self, indices: &[usize]) -> Result<usize, XdlError> {
85 if indices.len() != self.dimensions.len() {
86 return Err(XdlError::DimensionError(format!(
87 "Index rank {} doesn't match array rank {}",
88 indices.len(),
89 self.dimensions.len()
90 )));
91 }
92
93 let mut linear_idx = 0;
94 let mut stride = 1;
95
96 for (i, (&idx, &dim)) in indices.iter().zip(&self.dimensions).enumerate().rev() {
97 if idx >= dim {
98 return Err(XdlError::IndexError(format!(
99 "Index {} out of range for dimension {} (size {})",
100 idx, i, dim
101 )));
102 }
103 linear_idx += idx * stride;
104 stride *= dim;
105 }
106
107 Ok(linear_idx)
108 }
109
110 pub fn multi_index(&self, linear_idx: usize) -> Result<Vec<usize>, XdlError> {
112 if linear_idx >= self.n_elements() {
113 return Err(XdlError::IndexError(format!(
114 "Linear index {} out of range for array with {} elements",
115 linear_idx,
116 self.n_elements()
117 )));
118 }
119
120 if self.is_scalar() {
121 return Ok(vec![]);
122 }
123
124 let mut indices = vec![0; self.dimensions.len()];
125 let mut remaining = linear_idx;
126
127 for i in (0..self.dimensions.len()).rev() {
128 let dim_size = self.dimensions[i];
129 indices[i] = remaining % dim_size;
130 remaining /= dim_size;
131 }
132
133 Ok(indices)
134 }
135
136 pub fn reform(&self, new_dims: Vec<usize>) -> Result<Self, XdlError> {
138 let new_n_elements: usize = new_dims.iter().product();
139 if new_n_elements != self.n_elements() {
140 return Err(XdlError::DimensionError(format!(
141 "Cannot reform array of {} elements to {} elements",
142 self.n_elements(),
143 new_n_elements
144 )));
145 }
146
147 Self::from_vec(new_dims)
148 }
149
150 pub fn transpose(&self, perm: Option<&[usize]>) -> Result<Self, XdlError> {
152 if self.is_scalar() {
153 return Ok(self.clone());
154 }
155
156 let perm = if let Some(p) = perm {
157 if p.len() != self.dimensions.len() {
158 return Err(XdlError::DimensionError(
159 "Permutation length doesn't match array rank".to_string(),
160 ));
161 }
162 p.to_vec()
163 } else {
164 (0..self.dimensions.len()).rev().collect()
166 };
167
168 let mut check = vec![false; self.dimensions.len()];
170 for &p in &perm {
171 if p >= self.dimensions.len() {
172 return Err(XdlError::DimensionError(
173 "Invalid permutation index".to_string(),
174 ));
175 }
176 if check[p] {
177 return Err(XdlError::DimensionError(
178 "Duplicate in permutation".to_string(),
179 ));
180 }
181 check[p] = true;
182 }
183
184 let new_dims = perm.iter().map(|&i| self.dimensions[i]).collect();
185 Ok(Self {
186 dimensions: new_dims,
187 })
188 }
189}
190
191impl Default for Dimension {
192 fn default() -> Self {
193 Self::scalar()
194 }
195}
196
197impl std::fmt::Display for Dimension {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 if self.is_scalar() {
200 write!(f, "scalar")
201 } else {
202 write!(
203 f,
204 "[{}]",
205 self.dimensions
206 .iter()
207 .map(|d| d.to_string())
208 .collect::<Vec<_>>()
209 .join(", ")
210 )
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_dimension() {
        let scalar = Dimension::scalar();
        assert!(scalar.is_scalar());
        assert_eq!(scalar.rank(), 0);
        assert_eq!(scalar.n_elements(), 1);
    }

    #[test]
    fn test_vector_dimension() {
        let vector = Dimension::from_size(10).expect("10 is a valid vector length");
        assert!(vector.is_vector());
        assert_eq!(vector.rank(), 1);
        assert_eq!(vector.n_elements(), 10);
        assert_eq!(vector.dim(0), Some(10));
    }

    #[test]
    fn test_multi_dimension() {
        let cube = Dimension::from_vec(vec![3, 4, 5]).expect("3x4x5 is a valid shape");
        assert_eq!(cube.rank(), 3);
        assert_eq!(cube.n_elements(), 60);
        assert_eq!(cube.dims(), &[3, 4, 5]);
    }

    #[test]
    fn test_indexing() {
        let grid = Dimension::from_vec(vec![3, 4]).expect("3x4 is a valid shape");

        // First and last element, checked in both directions.
        let corners: [([usize; 2], usize); 2] = [([0, 0], 0), ([2, 3], 11)];
        for &(multi, linear) in corners.iter() {
            assert_eq!(grid.linear_index(&multi).unwrap(), linear);
            assert_eq!(grid.multi_index(linear).unwrap(), multi.to_vec());
        }
    }

    #[test]
    fn test_reform() {
        let grid = Dimension::from_vec(vec![3, 4]).expect("3x4 is a valid shape");
        let reshaped = grid.reform(vec![2, 6]).expect("2x6 keeps 12 elements");
        assert_eq!(reshaped.dims(), &[2, 6]);
        assert_eq!(reshaped.n_elements(), 12);
    }

    #[test]
    fn test_transpose() {
        let cube = Dimension::from_vec(vec![3, 4, 5]).expect("3x4x5 is a valid shape");

        let reversed = cube.transpose(None).expect("default permutation is valid");
        assert_eq!(reversed.dims(), &[5, 4, 3]);

        let swapped = cube
            .transpose(Some(&[1, 0, 2]))
            .expect("[1, 0, 2] is a valid permutation");
        assert_eq!(swapped.dims(), &[4, 3, 5]);
    }

    #[test]
    fn test_error_cases() {
        // A zero-length axis is rejected.
        assert!(Dimension::from_vec(vec![3, 0, 5]).is_err());

        // More axes than MAXRANK is rejected.
        let too_many_axes = vec![1; MAXRANK + 1];
        assert!(Dimension::from_vec(too_many_axes).is_err());

        // Reform must preserve the element count (10 != 12).
        let ten = Dimension::from_size(10).expect("10 is a valid vector length");
        assert!(ten.reform(vec![3, 4]).is_err());
    }
}