sklears_utils/array_utils/indexing.rs

1//! Advanced indexing operations
2
3use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::{Array1, Array2};
5
6/// Fancy indexing for 1D arrays with multiple indices
7pub fn fancy_indexing_1d<T: Clone>(array: &Array1<T>, indices: &[usize]) -> UtilsResult<Array1<T>> {
8    let mut result = Vec::with_capacity(indices.len());
9
10    for &idx in indices {
11        if idx >= array.len() {
12            return Err(UtilsError::InvalidParameter(format!(
13                "Index {} out of bounds for array of length {}",
14                idx,
15                array.len()
16            )));
17        }
18        result.push(array[idx].clone());
19    }
20
21    Ok(Array1::from_vec(result))
22}
23
24/// Fancy indexing for 2D arrays
25pub fn fancy_indexing_2d<T: Clone>(
26    array: &Array2<T>,
27    row_indices: &[usize],
28    col_indices: &[usize],
29) -> UtilsResult<Array2<T>> {
30    if row_indices.len() != col_indices.len() {
31        return Err(UtilsError::ShapeMismatch {
32            expected: vec![row_indices.len()],
33            actual: vec![col_indices.len()],
34        });
35    }
36
37    let mut result = Vec::new();
38
39    for (&row_idx, &col_idx) in row_indices.iter().zip(col_indices.iter()) {
40        if row_idx >= array.nrows() {
41            return Err(UtilsError::InvalidParameter(format!(
42                "Row index {} out of bounds for array with {} rows",
43                row_idx,
44                array.nrows()
45            )));
46        }
47        if col_idx >= array.ncols() {
48            return Err(UtilsError::InvalidParameter(format!(
49                "Column index {} out of bounds for array with {} columns",
50                col_idx,
51                array.ncols()
52            )));
53        }
54        result.push(array[[row_idx, col_idx]].clone());
55    }
56
57    // Return as single column matrix
58    let result_len = result.len();
59    let result_array = Array1::from_vec(result)
60        .into_shape_with_order((result_len, 1))
61        .map_err(|_| UtilsError::ShapeMismatch {
62            expected: vec![result_len, 1],
63            actual: vec![result_len],
64        })?;
65
66    Ok(result_array)
67}
68
69/// Boolean indexing for 1D arrays
70pub fn boolean_indexing_1d<T: Clone>(
71    array: &Array1<T>,
72    mask: &Array1<bool>,
73) -> UtilsResult<Array1<T>> {
74    if array.len() != mask.len() {
75        return Err(UtilsError::ShapeMismatch {
76            expected: vec![array.len()],
77            actual: vec![mask.len()],
78        });
79    }
80
81    let mut result = Vec::new();
82    for (value, &include) in array.iter().zip(mask.iter()) {
83        if include {
84            result.push(value.clone());
85        }
86    }
87
88    Ok(Array1::from_vec(result))
89}
90
91/// Boolean indexing for 2D arrays (row-wise)
92pub fn boolean_indexing_2d<T: Clone>(
93    array: &Array2<T>,
94    mask: &Array1<bool>,
95) -> UtilsResult<Array2<T>> {
96    if array.nrows() != mask.len() {
97        return Err(UtilsError::ShapeMismatch {
98            expected: vec![array.nrows()],
99            actual: vec![mask.len()],
100        });
101    }
102
103    let mut result_rows = Vec::new();
104    let ncols = array.ncols();
105
106    for (row_idx, &include) in mask.iter().enumerate() {
107        if include {
108            let mut row = Vec::with_capacity(ncols);
109            for col_idx in 0..ncols {
110                row.push(array[[row_idx, col_idx]].clone());
111            }
112            result_rows.extend(row);
113        }
114    }
115
116    let n_selected_rows = mask.iter().filter(|&&x| x).count();
117
118    if n_selected_rows == 0 {
119        return Array2::from_shape_vec((0, ncols), vec![]).map_err(|_| UtilsError::ShapeMismatch {
120            expected: vec![0, ncols],
121            actual: vec![0],
122        });
123    }
124
125    let result_len = result_rows.len();
126    let result_array = Array1::from_vec(result_rows)
127        .into_shape_with_order((n_selected_rows, ncols))
128        .map_err(|_| UtilsError::ShapeMismatch {
129            expected: vec![n_selected_rows, ncols],
130            actual: vec![result_len],
131        })?;
132
133    Ok(result_array)
134}
135
136/// Create boolean mask from condition function
137pub fn create_mask<T, F>(array: &Array1<T>, condition: F) -> Array1<bool>
138where
139    T: Clone,
140    F: Fn(&T) -> bool,
141{
142    array.mapv(|ref x| condition(x))
143}
144
145/// Apply where condition (like numpy.where)
146pub fn where_condition<T, F>(
147    condition: &Array1<bool>,
148    true_values: &Array1<T>,
149    false_values: &Array1<T>,
150) -> UtilsResult<Array1<T>>
151where
152    T: Clone,
153    F: Clone,
154{
155    if condition.len() != true_values.len() || condition.len() != false_values.len() {
156        return Err(UtilsError::ShapeMismatch {
157            expected: vec![condition.len()],
158            actual: vec![true_values.len(), false_values.len()],
159        });
160    }
161
162    let mut result = Vec::with_capacity(condition.len());
163    for ((cond, true_val), false_val) in condition
164        .iter()
165        .zip(true_values.iter())
166        .zip(false_values.iter())
167    {
168        if *cond {
169            result.push(true_val.clone());
170        } else {
171            result.push(false_val.clone());
172        }
173    }
174
175    Ok(Array1::from_vec(result))
176}
177
178/// Slice with step (like Python's array\[start:end:step\])
179pub fn slice_with_step<T: Clone>(
180    array: &Array1<T>,
181    start: Option<usize>,
182    end: Option<usize>,
183    step: usize,
184) -> UtilsResult<Array1<T>> {
185    if step == 0 {
186        return Err(UtilsError::InvalidParameter(
187            "Step cannot be zero".to_string(),
188        ));
189    }
190
191    let len = array.len();
192    let start_idx = start.unwrap_or(0).min(len);
193    let end_idx = end.unwrap_or(len).min(len);
194
195    if start_idx >= end_idx {
196        return Ok(Array1::from_vec(vec![]));
197    }
198
199    let mut result = Vec::new();
200    let mut idx = start_idx;
201
202    while idx < end_idx {
203        result.push(array[idx].clone());
204        idx += step;
205    }
206
207    Ok(Array1::from_vec(result))
208}
209
210/// Find indices of maximum values
211pub fn argmax<T>(array: &Array1<T>) -> UtilsResult<usize>
212where
213    T: PartialOrd + Clone,
214{
215    if array.is_empty() {
216        return Err(UtilsError::EmptyInput);
217    }
218
219    let mut max_idx = 0;
220    let mut max_val = &array[0];
221
222    for (idx, val) in array.iter().enumerate().skip(1) {
223        if val > max_val {
224            max_val = val;
225            max_idx = idx;
226        }
227    }
228
229    Ok(max_idx)
230}
231
232/// Find indices of minimum values
233pub fn argmin<T>(array: &Array1<T>) -> UtilsResult<usize>
234where
235    T: PartialOrd + Clone,
236{
237    if array.is_empty() {
238        return Err(UtilsError::EmptyInput);
239    }
240
241    let mut min_idx = 0;
242    let mut min_val = &array[0];
243
244    for (idx, val) in array.iter().enumerate().skip(1) {
245        if val < min_val {
246            min_val = val;
247            min_idx = idx;
248        }
249    }
250
251    Ok(min_idx)
252}
253
254/// Sort indices (argsort)
255pub fn argsort<T>(array: &Array1<T>) -> Vec<usize>
256where
257    T: PartialOrd + Clone,
258{
259    let mut indices: Vec<usize> = (0..array.len()).collect();
260    indices.sort_by(|&a, &b| array[a].partial_cmp(&array[b]).unwrap());
261    indices
262}
263
264/// Take elements at indices
265pub fn take_1d<T: Clone>(array: &Array1<T>, indices: &[usize]) -> UtilsResult<Array1<T>> {
266    fancy_indexing_1d(array, indices)
267}
268
269/// Put values at indices
270pub fn put_1d<T: Clone>(array: &mut Array1<T>, indices: &[usize], values: &[T]) -> UtilsResult<()> {
271    if indices.len() != values.len() {
272        return Err(UtilsError::ShapeMismatch {
273            expected: vec![indices.len()],
274            actual: vec![values.len()],
275        });
276    }
277
278    for (&idx, value) in indices.iter().zip(values.iter()) {
279        if idx >= array.len() {
280            return Err(UtilsError::InvalidParameter(format!(
281                "Index {} out of bounds for array of length {}",
282                idx,
283                array.len()
284            )));
285        }
286        array[idx] = value.clone();
287    }
288
289    Ok(())
290}
291
292/// Filter array with condition function
293pub fn filter_array<T: Clone>(array: &Array1<T>, predicate: impl Fn(&T) -> bool) -> Array1<T> {
294    let filtered: Vec<T> = array.iter().filter(|&x| predicate(x)).cloned().collect();
295    Array1::from_vec(filtered)
296}
297
298/// Compress array with boolean mask
299pub fn compress_1d<T: Clone>(array: &Array1<T>, mask: &Array1<bool>) -> UtilsResult<Array1<T>> {
300    boolean_indexing_1d(array, mask)
301}