sklears_utils/array_utils/shape_ops.rs

1//! Shape manipulation operations for arrays
2
3use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::{s, Array1, Array2};
5use scirs2_core::numeric::Zero;
6
7/// Reshape 1D array to 2D with specified dimensions
8pub fn reshape_1d_to_2d<T: Clone>(
9    array: &Array1<T>,
10    rows: usize,
11    cols: usize,
12) -> UtilsResult<Array2<T>> {
13    if rows * cols != array.len() {
14        return Err(UtilsError::ShapeMismatch {
15            expected: vec![rows * cols],
16            actual: vec![array.len()],
17        });
18    }
19
20    let reshaped = array
21        .clone()
22        .into_shape_with_order((rows, cols))
23        .map_err(|_| UtilsError::ShapeMismatch {
24            expected: vec![rows, cols],
25            actual: vec![array.len()],
26        })?;
27
28    Ok(reshaped)
29}
30
31/// Flatten 2D array to 1D
32pub fn flatten_2d<T: Clone>(array: &Array2<T>) -> Array1<T> {
33    let mut flattened = Vec::with_capacity(array.len());
34    for row in array.rows() {
35        for item in row {
36            flattened.push(item.clone());
37        }
38    }
39    Array1::from_vec(flattened)
40}
41
/// Return `true` when two shapes are compatible under broadcasting.
///
/// The shapes are aligned at their trailing dimension and compared pairwise
/// towards the front. Each aligned pair must be equal or contain a `1`.
/// Leading dimensions that exist in only the longer shape always pass, because
/// the shorter shape is treated as having size `1` there.
pub fn is_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
65
66/// Compute broadcasted shape
67pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> UtilsResult<Vec<usize>> {
68    if !is_broadcastable(shape1, shape2) {
69        return Err(UtilsError::ShapeMismatch {
70            expected: shape1.to_vec(),
71            actual: shape2.to_vec(),
72        });
73    }
74
75    let max_len = shape1.len().max(shape2.len());
76    let mut result = Vec::with_capacity(max_len);
77
78    for i in 0..max_len {
79        let dim1 = if i < shape1.len() {
80            shape1[shape1.len() - 1 - i]
81        } else {
82            1
83        };
84        let dim2 = if i < shape2.len() {
85            shape2[shape2.len() - 1 - i]
86        } else {
87            1
88        };
89
90        result.push(dim1.max(dim2));
91    }
92
93    result.reverse();
94    Ok(result)
95}
96
97/// Transpose 2D array
98pub fn transpose<T: Clone>(array: &Array2<T>) -> Array2<T> {
99    array.t().to_owned()
100}
101
102/// Stack 1D arrays along specified axis to form 2D array
103pub fn stack_1d<T: Clone + Zero>(arrays: &[&Array1<T>], axis: usize) -> UtilsResult<Array2<T>> {
104    if arrays.is_empty() {
105        return Err(UtilsError::EmptyInput);
106    }
107
108    if axis > 1 {
109        return Err(UtilsError::InvalidParameter(
110            "Axis must be 0 or 1 for stacking 1D arrays into 2D".to_string(),
111        ));
112    }
113
114    let first_len = arrays[0].len();
115    for array in arrays.iter() {
116        if array.len() != first_len {
117            return Err(UtilsError::ShapeMismatch {
118                expected: vec![first_len],
119                actual: vec![array.len()],
120            });
121        }
122    }
123
124    let mut result = if axis == 0 {
125        // Stack along rows (each array becomes a row)
126        Array2::zeros((arrays.len(), first_len))
127    } else {
128        // Stack along columns (each array becomes a column)
129        Array2::zeros((first_len, arrays.len()))
130    };
131
132    for (i, array) in arrays.iter().enumerate() {
133        if axis == 0 {
134            for (j, value) in array.iter().enumerate() {
135                result[[i, j]] = value.clone();
136            }
137        } else {
138            for (j, value) in array.iter().enumerate() {
139                result[[j, i]] = value.clone();
140            }
141        }
142    }
143
144    Ok(result)
145}
146
147/// Concatenate 2D arrays along specified axis
148pub fn concatenate_2d<T: Clone + Zero>(
149    arrays: &[&Array2<T>],
150    axis: usize,
151) -> UtilsResult<Array2<T>> {
152    if arrays.is_empty() {
153        return Err(UtilsError::EmptyInput);
154    }
155
156    if axis > 1 {
157        return Err(UtilsError::InvalidParameter(
158            "Axis must be 0 or 1 for 2D arrays".to_string(),
159        ));
160    }
161
162    let first_shape = arrays[0].raw_dim();
163
164    if axis == 0 {
165        // Concatenate along rows - all arrays must have same number of columns
166        let ncols = first_shape[1];
167        for array in arrays.iter() {
168            if array.ncols() != ncols {
169                return Err(UtilsError::ShapeMismatch {
170                    expected: vec![array.nrows(), ncols],
171                    actual: vec![array.nrows(), array.ncols()],
172                });
173            }
174        }
175
176        let total_rows: usize = arrays.iter().map(|arr| arr.nrows()).sum();
177        let mut result = Array2::zeros((total_rows, ncols));
178
179        let mut row_offset = 0;
180        for array in arrays {
181            let nrows = array.nrows();
182            for i in 0..nrows {
183                for j in 0..ncols {
184                    result[[row_offset + i, j]] = array[[i, j]].clone();
185                }
186            }
187            row_offset += nrows;
188        }
189
190        Ok(result)
191    } else {
192        // Concatenate along columns - all arrays must have same number of rows
193        let nrows = first_shape[0];
194        for array in arrays.iter() {
195            if array.nrows() != nrows {
196                return Err(UtilsError::ShapeMismatch {
197                    expected: vec![nrows, array.ncols()],
198                    actual: vec![array.nrows(), array.ncols()],
199                });
200            }
201        }
202
203        let total_cols: usize = arrays.iter().map(|arr| arr.ncols()).sum();
204        let mut result = Array2::zeros((nrows, total_cols));
205
206        let mut col_offset = 0;
207        for array in arrays {
208            let ncols = array.ncols();
209            for i in 0..nrows {
210                for j in 0..ncols {
211                    result[[i, col_offset + j]] = array[[i, j]].clone();
212                }
213            }
214            col_offset += ncols;
215        }
216
217        Ok(result)
218    }
219}
220
221/// Split 2D array along specified axis
222pub fn split_2d<T: Clone>(
223    array: &Array2<T>,
224    indices_or_sections: &[usize],
225    axis: usize,
226) -> UtilsResult<Vec<Array2<T>>> {
227    if axis > 1 {
228        return Err(UtilsError::InvalidParameter(
229            "Axis must be 0 or 1 for 2D arrays".to_string(),
230        ));
231    }
232
233    let mut splits = Vec::new();
234    let mut start = 0;
235
236    if axis == 0 {
237        // Split along rows
238        for &split_point in indices_or_sections {
239            if split_point > array.nrows() {
240                return Err(UtilsError::InvalidParameter(
241                    "Split index exceeds array dimensions".to_string(),
242                ));
243            }
244
245            let section = array.slice(s![start..split_point, ..]).to_owned();
246            splits.push(section);
247            start = split_point;
248        }
249
250        // Add remaining section
251        if start < array.nrows() {
252            let section = array.slice(s![start.., ..]).to_owned();
253            splits.push(section);
254        }
255    } else {
256        // Split along columns
257        for &split_point in indices_or_sections {
258            if split_point > array.ncols() {
259                return Err(UtilsError::InvalidParameter(
260                    "Split index exceeds array dimensions".to_string(),
261                ));
262            }
263
264            let section = array.slice(s![.., start..split_point]).to_owned();
265            splits.push(section);
266            start = split_point;
267        }
268
269        // Add remaining section
270        if start < array.ncols() {
271            let section = array.slice(s![.., start..]).to_owned();
272            splits.push(section);
273        }
274    }
275
276    Ok(splits)
277}
278
279/// Tile 2D array with specified repetitions
280pub fn tile_2d<T: Clone + Zero>(array: &Array2<T>, reps: (usize, usize)) -> UtilsResult<Array2<T>> {
281    if reps.0 == 0 || reps.1 == 0 {
282        return Err(UtilsError::InvalidParameter(
283            "Repetitions must be positive".to_string(),
284        ));
285    }
286
287    let (nrows, ncols) = array.dim();
288    let new_shape = (nrows * reps.0, ncols * reps.1);
289    let mut result = Array2::zeros(new_shape);
290
291    for rep_row in 0..reps.0 {
292        for rep_col in 0..reps.1 {
293            let row_offset = rep_row * nrows;
294            let col_offset = rep_col * ncols;
295
296            for i in 0..nrows {
297                for j in 0..ncols {
298                    result[[row_offset + i, col_offset + j]] = array[[i, j]].clone();
299                }
300            }
301        }
302    }
303
304    Ok(result)
305}
306
307/// Pad 2D array with specified padding
308pub fn pad_2d<T: Clone + Zero>(
309    array: &Array2<T>,
310    padding: ((usize, usize), (usize, usize)),
311    constant_value: Option<T>,
312) -> UtilsResult<Array2<T>> {
313    let (nrows, ncols) = array.dim();
314    let (row_padding, col_padding) = padding;
315
316    let new_nrows = nrows + row_padding.0 + row_padding.1;
317    let new_ncols = ncols + col_padding.0 + col_padding.1;
318
319    let fill_value = constant_value.unwrap_or_else(T::zero);
320    let mut result = Array2::from_elem((new_nrows, new_ncols), fill_value);
321
322    // Copy original array to the center
323    for i in 0..nrows {
324        for j in 0..ncols {
325            result[[i + row_padding.0, j + col_padding.0]] = array[[i, j]].clone();
326        }
327    }
328
329    Ok(result)
330}