sklears_utils/array_utils/
indexing.rs1use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::{Array1, Array2};
5
6pub fn fancy_indexing_1d<T: Clone>(array: &Array1<T>, indices: &[usize]) -> UtilsResult<Array1<T>> {
8 let mut result = Vec::with_capacity(indices.len());
9
10 for &idx in indices {
11 if idx >= array.len() {
12 return Err(UtilsError::InvalidParameter(format!(
13 "Index {} out of bounds for array of length {}",
14 idx,
15 array.len()
16 )));
17 }
18 result.push(array[idx].clone());
19 }
20
21 Ok(Array1::from_vec(result))
22}
23
24pub fn fancy_indexing_2d<T: Clone>(
26 array: &Array2<T>,
27 row_indices: &[usize],
28 col_indices: &[usize],
29) -> UtilsResult<Array2<T>> {
30 if row_indices.len() != col_indices.len() {
31 return Err(UtilsError::ShapeMismatch {
32 expected: vec![row_indices.len()],
33 actual: vec![col_indices.len()],
34 });
35 }
36
37 let mut result = Vec::new();
38
39 for (&row_idx, &col_idx) in row_indices.iter().zip(col_indices.iter()) {
40 if row_idx >= array.nrows() {
41 return Err(UtilsError::InvalidParameter(format!(
42 "Row index {} out of bounds for array with {} rows",
43 row_idx,
44 array.nrows()
45 )));
46 }
47 if col_idx >= array.ncols() {
48 return Err(UtilsError::InvalidParameter(format!(
49 "Column index {} out of bounds for array with {} columns",
50 col_idx,
51 array.ncols()
52 )));
53 }
54 result.push(array[[row_idx, col_idx]].clone());
55 }
56
57 let result_len = result.len();
59 let result_array = Array1::from_vec(result)
60 .into_shape_with_order((result_len, 1))
61 .map_err(|_| UtilsError::ShapeMismatch {
62 expected: vec![result_len, 1],
63 actual: vec![result_len],
64 })?;
65
66 Ok(result_array)
67}
68
69pub fn boolean_indexing_1d<T: Clone>(
71 array: &Array1<T>,
72 mask: &Array1<bool>,
73) -> UtilsResult<Array1<T>> {
74 if array.len() != mask.len() {
75 return Err(UtilsError::ShapeMismatch {
76 expected: vec![array.len()],
77 actual: vec![mask.len()],
78 });
79 }
80
81 let mut result = Vec::new();
82 for (value, &include) in array.iter().zip(mask.iter()) {
83 if include {
84 result.push(value.clone());
85 }
86 }
87
88 Ok(Array1::from_vec(result))
89}
90
91pub fn boolean_indexing_2d<T: Clone>(
93 array: &Array2<T>,
94 mask: &Array1<bool>,
95) -> UtilsResult<Array2<T>> {
96 if array.nrows() != mask.len() {
97 return Err(UtilsError::ShapeMismatch {
98 expected: vec![array.nrows()],
99 actual: vec![mask.len()],
100 });
101 }
102
103 let mut result_rows = Vec::new();
104 let ncols = array.ncols();
105
106 for (row_idx, &include) in mask.iter().enumerate() {
107 if include {
108 let mut row = Vec::with_capacity(ncols);
109 for col_idx in 0..ncols {
110 row.push(array[[row_idx, col_idx]].clone());
111 }
112 result_rows.extend(row);
113 }
114 }
115
116 let n_selected_rows = mask.iter().filter(|&&x| x).count();
117
118 if n_selected_rows == 0 {
119 return Array2::from_shape_vec((0, ncols), vec![]).map_err(|_| UtilsError::ShapeMismatch {
120 expected: vec![0, ncols],
121 actual: vec![0],
122 });
123 }
124
125 let result_len = result_rows.len();
126 let result_array = Array1::from_vec(result_rows)
127 .into_shape_with_order((n_selected_rows, ncols))
128 .map_err(|_| UtilsError::ShapeMismatch {
129 expected: vec![n_selected_rows, ncols],
130 actual: vec![result_len],
131 })?;
132
133 Ok(result_array)
134}
135
136pub fn create_mask<T, F>(array: &Array1<T>, condition: F) -> Array1<bool>
138where
139 T: Clone,
140 F: Fn(&T) -> bool,
141{
142 array.mapv(|ref x| condition(x))
143}
144
145pub fn where_condition<T, F>(
147 condition: &Array1<bool>,
148 true_values: &Array1<T>,
149 false_values: &Array1<T>,
150) -> UtilsResult<Array1<T>>
151where
152 T: Clone,
153 F: Clone,
154{
155 if condition.len() != true_values.len() || condition.len() != false_values.len() {
156 return Err(UtilsError::ShapeMismatch {
157 expected: vec![condition.len()],
158 actual: vec![true_values.len(), false_values.len()],
159 });
160 }
161
162 let mut result = Vec::with_capacity(condition.len());
163 for ((cond, true_val), false_val) in condition
164 .iter()
165 .zip(true_values.iter())
166 .zip(false_values.iter())
167 {
168 if *cond {
169 result.push(true_val.clone());
170 } else {
171 result.push(false_val.clone());
172 }
173 }
174
175 Ok(Array1::from_vec(result))
176}
177
178pub fn slice_with_step<T: Clone>(
180 array: &Array1<T>,
181 start: Option<usize>,
182 end: Option<usize>,
183 step: usize,
184) -> UtilsResult<Array1<T>> {
185 if step == 0 {
186 return Err(UtilsError::InvalidParameter(
187 "Step cannot be zero".to_string(),
188 ));
189 }
190
191 let len = array.len();
192 let start_idx = start.unwrap_or(0).min(len);
193 let end_idx = end.unwrap_or(len).min(len);
194
195 if start_idx >= end_idx {
196 return Ok(Array1::from_vec(vec![]));
197 }
198
199 let mut result = Vec::new();
200 let mut idx = start_idx;
201
202 while idx < end_idx {
203 result.push(array[idx].clone());
204 idx += step;
205 }
206
207 Ok(Array1::from_vec(result))
208}
209
210pub fn argmax<T>(array: &Array1<T>) -> UtilsResult<usize>
212where
213 T: PartialOrd + Clone,
214{
215 if array.is_empty() {
216 return Err(UtilsError::EmptyInput);
217 }
218
219 let mut max_idx = 0;
220 let mut max_val = &array[0];
221
222 for (idx, val) in array.iter().enumerate().skip(1) {
223 if val > max_val {
224 max_val = val;
225 max_idx = idx;
226 }
227 }
228
229 Ok(max_idx)
230}
231
232pub fn argmin<T>(array: &Array1<T>) -> UtilsResult<usize>
234where
235 T: PartialOrd + Clone,
236{
237 if array.is_empty() {
238 return Err(UtilsError::EmptyInput);
239 }
240
241 let mut min_idx = 0;
242 let mut min_val = &array[0];
243
244 for (idx, val) in array.iter().enumerate().skip(1) {
245 if val < min_val {
246 min_val = val;
247 min_idx = idx;
248 }
249 }
250
251 Ok(min_idx)
252}
253
254pub fn argsort<T>(array: &Array1<T>) -> Vec<usize>
256where
257 T: PartialOrd + Clone,
258{
259 let mut indices: Vec<usize> = (0..array.len()).collect();
260 indices.sort_by(|&a, &b| array[a].partial_cmp(&array[b]).unwrap());
261 indices
262}
263
264pub fn take_1d<T: Clone>(array: &Array1<T>, indices: &[usize]) -> UtilsResult<Array1<T>> {
266 fancy_indexing_1d(array, indices)
267}
268
269pub fn put_1d<T: Clone>(array: &mut Array1<T>, indices: &[usize], values: &[T]) -> UtilsResult<()> {
271 if indices.len() != values.len() {
272 return Err(UtilsError::ShapeMismatch {
273 expected: vec![indices.len()],
274 actual: vec![values.len()],
275 });
276 }
277
278 for (&idx, value) in indices.iter().zip(values.iter()) {
279 if idx >= array.len() {
280 return Err(UtilsError::InvalidParameter(format!(
281 "Index {} out of bounds for array of length {}",
282 idx,
283 array.len()
284 )));
285 }
286 array[idx] = value.clone();
287 }
288
289 Ok(())
290}
291
292pub fn filter_array<T: Clone>(array: &Array1<T>, predicate: impl Fn(&T) -> bool) -> Array1<T> {
294 let filtered: Vec<T> = array.iter().filter(|&x| predicate(x)).cloned().collect();
295 Array1::from_vec(filtered)
296}
297
298pub fn compress_1d<T: Clone>(array: &Array1<T>, mask: &Array1<bool>) -> UtilsResult<Array1<T>> {
300 boolean_indexing_1d(array, mask)
301}