sklears_utils/array_utils/
core.rs

1//! Core array utilities and validation functions
2
3use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::{s, Array1, Array2};
5use scirs2_core::numeric::Zero;
6use std::collections::HashMap;
7
8/// Check that a 2D array is not empty
9pub fn check_array_2d<T>(array: &Array2<T>) -> UtilsResult<()> {
10    if array.is_empty() {
11        return Err(UtilsError::EmptyInput);
12    }
13    Ok(())
14}
15
16/// Check that a 1D array is not empty
17pub fn check_array_1d<T>(array: &Array1<T>) -> UtilsResult<()> {
18    if array.is_empty() {
19        return Err(UtilsError::EmptyInput);
20    }
21    Ok(())
22}
23
24/// Safe indexing for 1D arrays
25pub fn safe_indexing<T: Clone>(array: &Array1<T>, indices: &[usize]) -> UtilsResult<Array1<T>> {
26    let mut result = Vec::with_capacity(indices.len());
27
28    for &idx in indices {
29        if idx >= array.len() {
30            return Err(UtilsError::InvalidParameter(format!(
31                "Index {idx} out of bounds for array of length {}",
32                array.len()
33            )));
34        }
35        result.push(array[idx].clone());
36    }
37
38    Ok(Array1::from_vec(result))
39}
40
41/// Safe indexing for 2D arrays
42pub fn safe_indexing_2d<T: Clone>(array: &Array2<T>, indices: &[usize]) -> UtilsResult<Array2<T>> {
43    if array.is_empty() {
44        return Err(UtilsError::EmptyInput);
45    }
46
47    let ncols = array.ncols();
48    let mut result = Vec::with_capacity(indices.len() * ncols);
49
50    for &idx in indices {
51        if idx >= array.nrows() {
52            return Err(UtilsError::InvalidParameter(format!(
53                "Row index {idx} out of bounds for array with {} rows",
54                array.nrows()
55            )));
56        }
57        for col in 0..ncols {
58            result.push(array[[idx, col]].clone());
59        }
60    }
61
62    let result_len = result.len();
63    let result_array = Array1::from_vec(result)
64        .into_shape_with_order((indices.len(), ncols))
65        .map_err(|_| UtilsError::ShapeMismatch {
66            expected: vec![indices.len(), ncols],
67            actual: vec![result_len],
68        })?;
69
70    Ok(result_array)
71}
72
73/// Convert 2D array to 1D if possible, otherwise return error
74pub fn column_or_1d<T: Clone>(array: &Array2<T>) -> UtilsResult<Array1<T>> {
75    if array.ncols() == 1 {
76        Ok(array.column(0).to_owned())
77    } else if array.nrows() == 1 {
78        Ok(array.row(0).to_owned())
79    } else {
80        Err(UtilsError::ShapeMismatch {
81            expected: vec![1],
82            actual: vec![array.nrows(), array.ncols()],
83        })
84    }
85}
86
87/// Normalize array to unit norm (L2)
88pub fn normalize_array(array: &mut Array1<f64>) -> UtilsResult<()> {
89    if array.is_empty() {
90        return Err(UtilsError::EmptyInput);
91    }
92
93    let norm = array.iter().map(|&x| x * x).sum::<f64>().sqrt();
94    if norm > 1e-10 {
95        array.par_mapv_inplace(|x| x / norm);
96    }
97    Ok(())
98}
99
100/// Get unique labels in sorted order
101pub fn unique_labels<T: Clone + Ord>(labels: &Array1<T>) -> Vec<T> {
102    let mut unique: Vec<T> = labels.iter().cloned().collect();
103    unique.sort();
104    unique.dedup();
105    unique
106}
107
108/// Count occurrences of each label
109pub fn label_counts<T: Clone + Eq + std::hash::Hash>(labels: &Array1<T>) -> HashMap<T, usize> {
110    let mut counts = HashMap::new();
111    for label in labels.iter() {
112        *counts.entry(label.clone()).or_insert(0) += 1;
113    }
114    counts
115}
116
117/// Split array into chunks
118pub fn array_split<T: Clone>(array: &Array1<T>, n_splits: usize) -> UtilsResult<Vec<Array1<T>>> {
119    if n_splits == 0 {
120        return Err(UtilsError::InvalidParameter(
121            "Number of splits must be positive".to_string(),
122        ));
123    }
124
125    if array.is_empty() {
126        return Ok(vec![Array1::from_vec(vec![]); n_splits]);
127    }
128
129    let chunk_size = array.len() / n_splits;
130    let remainder = array.len() % n_splits;
131
132    let mut splits = Vec::with_capacity(n_splits);
133    let mut start = 0;
134
135    for i in 0..n_splits {
136        let current_chunk_size = if i < remainder {
137            chunk_size + 1
138        } else {
139            chunk_size
140        };
141
142        let end = start + current_chunk_size;
143        let chunk = array.slice(s![start..end]).to_owned();
144        splits.push(chunk);
145        start = end;
146    }
147
148    Ok(splits)
149}
150
151/// Concatenate arrays
152pub fn array_concatenate<T: Clone>(arrays: &[Array1<T>]) -> UtilsResult<Array1<T>> {
153    if arrays.is_empty() {
154        return Err(UtilsError::EmptyInput);
155    }
156
157    let mut result = Vec::new();
158    for array in arrays {
159        result.extend_from_slice(array.as_slice().unwrap());
160    }
161
162    Ok(Array1::from_vec(result))
163}
164
165/// Resize array to new size, padding with zeros if needed
166pub fn array_resize<T: Clone + Zero>(array: &Array1<T>, new_size: usize) -> Array1<T> {
167    let mut result = vec![T::zero(); new_size];
168    let copy_size = array.len().min(new_size);
169    result[..copy_size].clone_from_slice(&array.as_slice().unwrap()[..copy_size]);
170    Array1::from_vec(result)
171}
172
173/// Count unique elements
174pub fn array_unique_counts<T: Clone + Ord + std::hash::Hash>(
175    array: &Array1<T>,
176) -> HashMap<T, usize> {
177    let mut counts = HashMap::new();
178    for item in array.iter() {
179        *counts.entry(item.clone()).or_insert(0) += 1;
180    }
181    counts
182}
183
184/// Reverse array
185pub fn array_reverse<T: Clone>(array: &Array1<T>) -> Array1<T> {
186    let mut reversed: Vec<T> = array.iter().cloned().collect();
187    reversed.reverse();
188    Array1::from_vec(reversed)
189}