xdl_core/
dimension.rs

//! XDL array dimensions and indexing
2
3use crate::{XdlError, MAXRANK};
4use serde::{Deserialize, Serialize};
5
6/// XDL array dimension descriptor
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
8pub struct Dimension {
9    dimensions: Vec<usize>,
10}
11
12impl Dimension {
13    /// Create a scalar (0-dimensional)
14    pub fn scalar() -> Self {
15        Self { dimensions: vec![] }
16    }
17
18    /// Create from dimension vector
19    pub fn from_vec(dims: Vec<usize>) -> Result<Self, XdlError> {
20        if dims.len() > MAXRANK {
21            return Err(XdlError::DimensionError(format!(
22                "Too many dimensions: {} > {}",
23                dims.len(),
24                MAXRANK
25            )));
26        }
27
28        if dims.contains(&0) {
29            return Err(XdlError::DimensionError(
30                "Zero dimensions not allowed".to_string(),
31            ));
32        }
33
34        Ok(Self { dimensions: dims })
35    }
36
37    /// Create 1-dimensional array
38    pub fn from_size(size: usize) -> Result<Self, XdlError> {
39        if size == 0 {
40            return Err(XdlError::DimensionError(
41                "Zero size not allowed".to_string(),
42            ));
43        }
44        Ok(Self {
45            dimensions: vec![size],
46        })
47    }
48
49    /// Get number of dimensions (rank)
50    pub fn rank(&self) -> usize {
51        self.dimensions.len()
52    }
53
54    /// Get dimension sizes
55    pub fn dims(&self) -> &[usize] {
56        &self.dimensions
57    }
58
59    /// Get specific dimension size
60    pub fn dim(&self, index: usize) -> Option<usize> {
61        self.dimensions.get(index).copied()
62    }
63
64    /// Calculate total number of elements
65    pub fn n_elements(&self) -> usize {
66        if self.dimensions.is_empty() {
67            1 // scalar
68        } else {
69            self.dimensions.iter().product()
70        }
71    }
72
73    /// Check if this is a scalar
74    pub fn is_scalar(&self) -> bool {
75        self.dimensions.is_empty()
76    }
77
78    /// Check if this is a vector (1D array)
79    pub fn is_vector(&self) -> bool {
80        self.dimensions.len() == 1
81    }
82
83    /// Convert multidimensional index to linear index
84    pub fn linear_index(&self, indices: &[usize]) -> Result<usize, XdlError> {
85        if indices.len() != self.dimensions.len() {
86            return Err(XdlError::DimensionError(format!(
87                "Index rank {} doesn't match array rank {}",
88                indices.len(),
89                self.dimensions.len()
90            )));
91        }
92
93        let mut linear_idx = 0;
94        let mut stride = 1;
95
96        for (i, (&idx, &dim)) in indices.iter().zip(&self.dimensions).enumerate().rev() {
97            if idx >= dim {
98                return Err(XdlError::IndexError(format!(
99                    "Index {} out of range for dimension {} (size {})",
100                    idx, i, dim
101                )));
102            }
103            linear_idx += idx * stride;
104            stride *= dim;
105        }
106
107        Ok(linear_idx)
108    }
109
110    /// Convert linear index to multidimensional indices
111    pub fn multi_index(&self, linear_idx: usize) -> Result<Vec<usize>, XdlError> {
112        if linear_idx >= self.n_elements() {
113            return Err(XdlError::IndexError(format!(
114                "Linear index {} out of range for array with {} elements",
115                linear_idx,
116                self.n_elements()
117            )));
118        }
119
120        if self.is_scalar() {
121            return Ok(vec![]);
122        }
123
124        let mut indices = vec![0; self.dimensions.len()];
125        let mut remaining = linear_idx;
126
127        for i in (0..self.dimensions.len()).rev() {
128            let dim_size = self.dimensions[i];
129            indices[i] = remaining % dim_size;
130            remaining /= dim_size;
131        }
132
133        Ok(indices)
134    }
135
136    /// Reform to new dimensions (like IDL REFORM)
137    pub fn reform(&self, new_dims: Vec<usize>) -> Result<Self, XdlError> {
138        let new_n_elements: usize = new_dims.iter().product();
139        if new_n_elements != self.n_elements() {
140            return Err(XdlError::DimensionError(format!(
141                "Cannot reform array of {} elements to {} elements",
142                self.n_elements(),
143                new_n_elements
144            )));
145        }
146
147        Self::from_vec(new_dims)
148    }
149
150    /// Transpose dimensions
151    pub fn transpose(&self, perm: Option<&[usize]>) -> Result<Self, XdlError> {
152        if self.is_scalar() {
153            return Ok(self.clone());
154        }
155
156        let perm = if let Some(p) = perm {
157            if p.len() != self.dimensions.len() {
158                return Err(XdlError::DimensionError(
159                    "Permutation length doesn't match array rank".to_string(),
160                ));
161            }
162            p.to_vec()
163        } else {
164            // Default: reverse order
165            (0..self.dimensions.len()).rev().collect()
166        };
167
168        // Check permutation validity
169        let mut check = vec![false; self.dimensions.len()];
170        for &p in &perm {
171            if p >= self.dimensions.len() {
172                return Err(XdlError::DimensionError(
173                    "Invalid permutation index".to_string(),
174                ));
175            }
176            if check[p] {
177                return Err(XdlError::DimensionError(
178                    "Duplicate in permutation".to_string(),
179                ));
180            }
181            check[p] = true;
182        }
183
184        let new_dims = perm.iter().map(|&i| self.dimensions[i]).collect();
185        Ok(Self {
186            dimensions: new_dims,
187        })
188    }
189}
190
191impl Default for Dimension {
192    fn default() -> Self {
193        Self::scalar()
194    }
195}
196
197impl std::fmt::Display for Dimension {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        if self.is_scalar() {
200            write!(f, "scalar")
201        } else {
202            write!(
203                f,
204                "[{}]",
205                self.dimensions
206                    .iter()
207                    .map(|d| d.to_string())
208                    .collect::<Vec<_>>()
209                    .join(", ")
210            )
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    //! Unit tests for `Dimension`: construction, queries, index conversion,
    //! reshaping, transposition, and the error paths of each.

    use super::*;

    #[test]
    fn test_scalar_dimension() {
        let dim = Dimension::scalar();
        assert!(dim.is_scalar());
        assert!(!dim.is_vector());
        assert_eq!(dim.rank(), 0);
        assert_eq!(dim.n_elements(), 1);
        assert_eq!(dim.dim(0), None);
    }

    #[test]
    fn test_vector_dimension() {
        let dim = Dimension::from_size(10).unwrap();
        assert!(dim.is_vector());
        assert!(!dim.is_scalar());
        assert_eq!(dim.rank(), 1);
        assert_eq!(dim.n_elements(), 10);
        assert_eq!(dim.dim(0), Some(10));
        assert_eq!(dim.dim(1), None);
    }

    #[test]
    fn test_multi_dimension() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();
        assert_eq!(dim.rank(), 3);
        assert_eq!(dim.n_elements(), 60);
        assert_eq!(dim.dims(), &[3, 4, 5]);
    }

    #[test]
    fn test_indexing() {
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();

        // Test linear index conversion
        assert_eq!(dim.linear_index(&[0, 0]).unwrap(), 0);
        assert_eq!(dim.linear_index(&[2, 3]).unwrap(), 11);

        // Test multi index conversion
        assert_eq!(dim.multi_index(0).unwrap(), vec![0, 0]);
        assert_eq!(dim.multi_index(11).unwrap(), vec![2, 3]);
    }

    #[test]
    fn test_index_ordering() {
        // The first and last elements map to the same offsets under any axis
        // ordering, so use interior points to pin the ordering the
        // implementation currently uses (last index varies fastest).
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();
        assert_eq!(dim.linear_index(&[0, 1]).unwrap(), 1);
        assert_eq!(dim.linear_index(&[1, 0]).unwrap(), 4);
        assert_eq!(dim.multi_index(1).unwrap(), vec![0, 1]);
        assert_eq!(dim.multi_index(4).unwrap(), vec![1, 0]);
    }

    #[test]
    fn test_index_round_trip() {
        let dim = Dimension::from_vec(vec![2, 3, 4]).unwrap();
        for linear in 0..dim.n_elements() {
            let multi = dim.multi_index(linear).unwrap();
            assert_eq!(dim.linear_index(&multi).unwrap(), linear);
        }
    }

    #[test]
    fn test_scalar_indexing() {
        let dim = Dimension::scalar();
        assert_eq!(dim.linear_index(&[]).unwrap(), 0);
        assert!(dim.multi_index(0).unwrap().is_empty());
        assert!(dim.multi_index(1).is_err());
        assert!(dim.linear_index(&[0]).is_err());
    }

    #[test]
    fn test_index_errors() {
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();

        // Index equal to the extent is out of range
        assert!(dim.linear_index(&[3, 0]).is_err());
        assert!(dim.linear_index(&[0, 4]).is_err());

        // Wrong number of indices
        assert!(dim.linear_index(&[0]).is_err());
        assert!(dim.linear_index(&[0, 0, 0]).is_err());

        // Linear index past the last element
        assert!(dim.multi_index(12).is_err());
    }

    #[test]
    fn test_reform() {
        let dim = Dimension::from_vec(vec![3, 4]).unwrap();
        let reformed = dim.reform(vec![2, 6]).unwrap();
        assert_eq!(reformed.dims(), &[2, 6]);
        assert_eq!(reformed.n_elements(), 12);
    }

    #[test]
    fn test_transpose() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();
        let transposed = dim.transpose(None).unwrap();
        assert_eq!(transposed.dims(), &[5, 4, 3]);

        let custom_transpose = dim.transpose(Some(&[1, 0, 2])).unwrap();
        assert_eq!(custom_transpose.dims(), &[4, 3, 5]);
    }

    #[test]
    fn test_transpose_errors() {
        let dim = Dimension::from_vec(vec![3, 4, 5]).unwrap();

        // Permutation shorter than the rank
        assert!(dim.transpose(Some(&[0, 1])).is_err());

        // Repeated axis
        assert!(dim.transpose(Some(&[0, 0, 1])).is_err());

        // Axis index beyond the rank
        assert!(dim.transpose(Some(&[0, 1, 3])).is_err());

        // A scalar is returned unchanged
        let scalar = Dimension::scalar();
        assert_eq!(scalar.transpose(None).unwrap(), scalar);
    }

    #[test]
    fn test_default_and_display() {
        assert_eq!(Dimension::default(), Dimension::scalar());
        assert_eq!(Dimension::scalar().to_string(), "scalar");
        assert_eq!(Dimension::from_size(7).unwrap().to_string(), "[7]");
        assert_eq!(
            Dimension::from_vec(vec![3, 4, 5]).unwrap().to_string(),
            "[3, 4, 5]"
        );
    }

    #[test]
    fn test_error_cases() {
        // Zero dimension
        assert!(Dimension::from_vec(vec![3, 0, 5]).is_err());
        assert!(Dimension::from_size(0).is_err());

        // Too many dimensions
        assert!(Dimension::from_vec(vec![1; MAXRANK + 1]).is_err());

        // Invalid reform
        let dim = Dimension::from_size(10).unwrap();
        assert!(dim.reform(vec![3, 4]).is_err()); // 12 != 10
        assert!(dim.reform(vec![10, 0]).is_err()); // zero extent
    }
}