tenrso_exec/executor/
advanced_indexing.rs

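//! Advanced tensor indexing operations
//!
//! This module implements NumPy/PyTorch-style indexing patterns:
//! - Gather along an axis with bounds-checked, optionally negative (Python-style)
//!   indices ([`advanced_gather`])
//! - Scatter along an axis with replace/add/max/min accumulation
//!   ([`advanced_scatter`], [`ScatterMode`])
//! - Boolean-mask ("fancy") selection into a flat 1-D result ([`fancy_index_mask`])
//!
//! Index tensors may have any shape: their elements are read in row-major order
//! and treated as one flat list of positions (no index broadcasting is performed).
//!
//! # Use Cases
//!
//! - Embeddings lookup in transformers
//! - Attention mechanisms
//! - Sparse gradient updates
//! - Dynamic batching
//! - Index-based data shuffling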
16
17use anyhow::{anyhow, Result};
18use scirs2_core::ndarray_ext::{ArrayView, IxDyn, Zip};
19use scirs2_core::numeric::{Float, FromPrimitive, Num, ToPrimitive};
20use tenrso_core::{Axis, DenseND};
21
22/// Advanced gather operation with multi-dimensional support
23///
24/// Gathers values from `input` along the specified `axis` using `indices`.
25/// Unlike the basic gather, this supports:
26/// - Multi-dimensional indices
27/// - Negative indices (Python-style)
28/// - Out-of-bounds checking with clear error messages
29///
30/// # Arguments
31///
32/// * `input` - Source tensor to gather from
33/// * `axis` - Axis along which to gather
34/// * `indices` - Integer indices (as Float tensor for compatibility)
35/// * `allow_negative` - Whether to allow negative indices (wrapped to positive)
36///
37/// # Example
38///
39/// ```text
40/// input: [10, 20, 30, 40, 50]
41/// indices: [0, 2, 4, 1]
42/// result: [10, 30, 50, 20]
43/// ```
44pub fn advanced_gather<T>(
45    input: &DenseND<T>,
46    axis: Axis,
47    indices: &DenseND<T>,
48    allow_negative: bool,
49) -> Result<DenseND<T>>
50where
51    T: Clone + Num + Float + ToPrimitive + FromPrimitive,
52{
53    let input_shape = input.shape();
54
55    // Validate axis
56    if axis >= input_shape.len() {
57        return Err(anyhow!(
58            "Axis {} out of bounds for tensor with {} dimensions",
59            axis,
60            input_shape.len()
61        ));
62    }
63
64    let axis_size = input_shape[axis];
65    let indices_view = indices.view();
66
67    // Convert indices to usize, handling negative indices if allowed
68    let converted_indices: Vec<usize> = indices_view
69        .iter()
70        .map(|&idx| {
71            let idx_i64 = idx
72                .to_i64()
73                .ok_or_else(|| anyhow!("Index value cannot be converted to integer"))?;
74
75            let final_idx = if idx_i64 < 0 {
76                if !allow_negative {
77                    return Err(anyhow!("Negative index {} not allowed", idx_i64));
78                }
79                // Python-style negative indexing
80                let positive_idx = (axis_size as i64 + idx_i64) as usize;
81                if positive_idx >= axis_size {
82                    return Err(anyhow!(
83                        "Negative index {} out of bounds for axis size {}",
84                        idx_i64,
85                        axis_size
86                    ));
87                }
88                positive_idx
89            } else {
90                let idx_usize = idx_i64 as usize;
91                if idx_usize >= axis_size {
92                    return Err(anyhow!(
93                        "Index {} out of bounds for axis size {}",
94                        idx_usize,
95                        axis_size
96                    ));
97                }
98                idx_usize
99            };
100
101            Ok(final_idx)
102        })
103        .collect::<Result<Vec<_>>>()?;
104
105    // Build output shape
106    let mut output_shape = input_shape.to_vec();
107    output_shape[axis] = converted_indices.len();
108
109    // Create output array
110    let total_elements: usize = output_shape.iter().product();
111    let mut output_data = vec![T::zero(); total_elements];
112
113    // Perform gathering
114    let input_view = input.view();
115    gather_recursive(
116        &input_view,
117        &converted_indices,
118        axis,
119        &output_shape,
120        &mut output_data,
121        0,
122        &mut 0,
123    )?;
124
125    DenseND::from_vec(output_data, &output_shape)
126}
127
128/// Recursive helper for advanced gather
129fn gather_recursive<T>(
130    input: &ArrayView<T, IxDyn>,
131    indices: &[usize],
132    axis: Axis,
133    output_shape: &[usize],
134    output_data: &mut [T],
135    current_depth: usize,
136    output_idx: &mut usize,
137) -> Result<()>
138where
139    T: Clone + Num,
140{
141    if current_depth == output_shape.len() {
142        return Ok(());
143    }
144
145    let dim_size = output_shape[current_depth];
146
147    if current_depth == axis {
148        // At the gather axis - use indices
149        for &idx in indices {
150            let slice = input.index_axis(scirs2_core::ndarray_ext::Axis(current_depth), idx);
151            if current_depth == output_shape.len() - 1 {
152                // Leaf level - copy data
153                output_data[*output_idx] = slice.iter().next().unwrap().clone();
154                *output_idx += 1;
155            } else {
156                gather_recursive(
157                    &slice,
158                    indices,
159                    axis,
160                    output_shape,
161                    output_data,
162                    current_depth + 1,
163                    output_idx,
164                )?;
165            }
166        }
167    } else {
168        // Not at gather axis - iterate normally
169        for i in 0..dim_size {
170            let slice = input.index_axis(scirs2_core::ndarray_ext::Axis(current_depth), i);
171            if current_depth == output_shape.len() - 1 {
172                // Leaf level - copy data
173                output_data[*output_idx] = slice.iter().next().unwrap().clone();
174                *output_idx += 1;
175            } else {
176                gather_recursive(
177                    &slice,
178                    indices,
179                    axis,
180                    output_shape,
181                    output_data,
182                    current_depth + 1,
183                    output_idx,
184                )?;
185            }
186        }
187    }
188
189    Ok(())
190}
191
192/// Advanced scatter operation with multi-dimensional support
193///
194/// Scatters `values` into an output tensor of shape `shape` along the specified `axis`
195/// using `indices`. Unlike the basic scatter, this supports:
196/// - Multi-dimensional indices and values
197/// - Negative indices (Python-style)
198/// - Accumulation modes (replace, add, max, min)
199///
200/// # Arguments
201///
202/// * `shape` - Shape of the output tensor
203/// * `axis` - Axis along which to scatter
204/// * `indices` - Integer indices (as Float tensor)
205/// * `values` - Values to scatter
206/// * `mode` - Scatter mode (Replace, Add, Max, Min)
207///
208/// # Example
209///
210/// ```text
211/// shape: [5]
212/// indices: [0, 2, 4]
213/// values: [10, 30, 50]
214/// result: [10, 0, 30, 0, 50] (assuming zero initialization)
215/// ```
216pub fn advanced_scatter<T>(
217    shape: &[usize],
218    axis: Axis,
219    indices: &DenseND<T>,
220    values: &DenseND<T>,
221    mode: ScatterMode,
222) -> Result<DenseND<T>>
223where
224    T: Clone + Num + Float + ToPrimitive + FromPrimitive + PartialOrd,
225{
226    // Validate axis
227    if axis >= shape.len() {
228        return Err(anyhow!(
229            "Axis {} out of bounds for tensor with {} dimensions",
230            axis,
231            shape.len()
232        ));
233    }
234
235    let axis_size = shape[axis];
236    let indices_view = indices.view();
237
238    // Convert and validate indices
239    let converted_indices: Vec<usize> = indices_view
240        .iter()
241        .map(|&idx| {
242            let idx_i64 = idx
243                .to_i64()
244                .ok_or_else(|| anyhow!("Index value cannot be converted to integer"))?;
245
246            if idx_i64 < 0 {
247                return Err(anyhow!("Negative indices not supported in scatter"));
248            }
249
250            let idx_usize = idx_i64 as usize;
251            if idx_usize >= axis_size {
252                return Err(anyhow!(
253                    "Index {} out of bounds for axis size {}",
254                    idx_usize,
255                    axis_size
256                ));
257            }
258
259            Ok(idx_usize)
260        })
261        .collect::<Result<Vec<_>>>()?;
262
263    // Initialize output based on mode
264    let total_elements: usize = shape.iter().product();
265    let mut output_data = match mode {
266        ScatterMode::Replace => vec![T::zero(); total_elements],
267        ScatterMode::Add => vec![T::zero(); total_elements],
268        ScatterMode::Max => vec![T::from_f64(f64::NEG_INFINITY).unwrap(); total_elements],
269        ScatterMode::Min => vec![T::from_f64(f64::INFINITY).unwrap(); total_elements],
270    };
271
272    // Perform scattering
273    let values_view = values.view();
274    scatter_recursive(
275        &values_view,
276        &converted_indices,
277        axis,
278        shape,
279        &mut output_data,
280        0,
281        &mut 0,
282        mode,
283    )?;
284
285    DenseND::from_vec(output_data, shape)
286}
287
/// How [`advanced_scatter`] combines a scattered value with the value already
/// stored at its target position.
///
/// `ScatterMode::default()` is [`ScatterMode::Replace`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ScatterMode {
    /// Replace existing values with the scattered value (default)
    #[default]
    Replace,
    /// Add to existing values (accumulate)
    Add,
    /// Take maximum of existing and new values
    Max,
    /// Take minimum of existing and new values
    Min,
}
300
301/// Recursive helper for advanced scatter
302#[allow(clippy::too_many_arguments)]
303fn scatter_recursive<T>(
304    values: &ArrayView<T, IxDyn>,
305    indices: &[usize],
306    axis: Axis,
307    output_shape: &[usize],
308    output_data: &mut [T],
309    current_depth: usize,
310    values_idx: &mut usize,
311    mode: ScatterMode,
312) -> Result<()>
313where
314    T: Clone + Num + PartialOrd,
315{
316    if current_depth == output_shape.len() {
317        return Ok(());
318    }
319
320    let _dim_size = output_shape[current_depth];
321
322    if current_depth == axis {
323        // At the scatter axis - use indices
324        for &out_idx in indices {
325            if current_depth == output_shape.len() - 1 {
326                // Leaf level - write data
327                let value = values.iter().nth(*values_idx).unwrap().clone();
328                let flat_idx = compute_flat_index(output_shape, &[out_idx], current_depth);
329
330                match mode {
331                    ScatterMode::Replace => output_data[flat_idx] = value,
332                    ScatterMode::Add => {
333                        output_data[flat_idx] = output_data[flat_idx].clone() + value
334                    }
335                    ScatterMode::Max => {
336                        if value > output_data[flat_idx] {
337                            output_data[flat_idx] = value;
338                        }
339                    }
340                    ScatterMode::Min => {
341                        if value < output_data[flat_idx] {
342                            output_data[flat_idx] = value;
343                        }
344                    }
345                }
346                *values_idx += 1;
347            }
348        }
349    }
350
351    Ok(())
352}
353
/// Row-major linear offset of a multi-index, using only dimensions `0..=depth`.
///
/// Evaluated with Horner's scheme: start from the leading index and, for each
/// later dimension, scale the running offset by that dimension's extent before
/// adding its index. `shape[0]` never participates, as in the usual row-major
/// formula where the outermost extent does not affect offsets.
///
/// Panics if `indices.len() <= depth`, or if `depth > 0` and `shape.len() <= depth`.
fn compute_flat_index(shape: &[usize], indices: &[usize], depth: usize) -> usize {
    (1..=depth).fold(indices[0], |offset, dim| offset * shape[dim] + indices[dim])
}
368
369/// Fancy indexing with boolean masks
370///
371/// Select elements from `input` where `mask` is true (> 0).
372/// Returns a 1D tensor containing the selected elements.
373///
374/// This is more flexible than basic masked_select as it supports:
375/// - Any tensor shape
376/// - Efficient memory access patterns
377/// - Parallel processing for large tensors
378pub fn fancy_index_mask<T>(input: &DenseND<T>, mask: &DenseND<T>) -> Result<DenseND<T>>
379where
380    T: Clone + Num + PartialOrd,
381{
382    if input.shape() != mask.shape() {
383        return Err(anyhow!(
384            "Input and mask must have the same shape: {:?} vs {:?}",
385            input.shape(),
386            mask.shape()
387        ));
388    }
389
390    let input_view = input.view();
391    let mask_view = mask.view();
392    let zero = T::zero();
393
394    // Collect selected elements
395    let mut selected = Vec::new();
396    Zip::from(&input_view)
397        .and(&mask_view)
398        .for_each(|val, mask_val| {
399            if *mask_val > zero {
400                selected.push(val.clone());
401            }
402        });
403
404    let output_shape = vec![selected.len()];
405    DenseND::from_vec(selected, &output_shape)
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects a tensor's elements in logical (row-major) order for comparison.
    fn elements(tensor: &DenseND<f64>) -> Vec<f64> {
        tensor.view().iter().copied().collect()
    }

    #[test]
    fn test_advanced_gather_1d() {
        let source = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], &[5]).unwrap();
        let positions = DenseND::from_vec(vec![0.0, 2.0, 4.0, 1.0], &[4]).unwrap();

        let gathered = advanced_gather(&source, 0, &positions, false).unwrap();

        assert_eq!(gathered.shape(), &[4]);
        assert_eq!(elements(&gathered), [10.0, 30.0, 50.0, 20.0]);
    }

    #[test]
    fn test_advanced_gather_negative_indices() {
        let source = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], &[5]).unwrap();
        let positions = DenseND::from_vec(vec![-1.0, -2.0], &[2]).unwrap();

        let gathered = advanced_gather(&source, 0, &positions, true).unwrap();

        assert_eq!(gathered.shape(), &[2]);
        // -1 wraps to index 4 and -2 to index 3.
        assert_eq!(elements(&gathered), [50.0, 40.0]);
    }

    #[test]
    fn test_advanced_gather_out_of_bounds() {
        let source = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let positions = DenseND::from_vec(vec![0.0, 5.0], &[2]).unwrap();

        let outcome = advanced_gather(&source, 0, &positions, false);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_advanced_scatter_replace() {
        let shape = [5_usize];
        let positions = DenseND::from_vec(vec![0.0, 2.0, 4.0], &[3]).unwrap();
        let payload = DenseND::from_vec(vec![10.0, 30.0, 50.0], &[3]).unwrap();

        let scattered =
            advanced_scatter(&shape, 0, &positions, &payload, ScatterMode::Replace).unwrap();

        assert_eq!(scattered.shape(), &[5]);
        assert_eq!(elements(&scattered), [10.0, 0.0, 30.0, 0.0, 50.0]);
    }

    #[test]
    fn test_fancy_index_mask() {
        let source = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], &[5]).unwrap();
        let keep = DenseND::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], &[5]).unwrap();

        let picked = fancy_index_mask(&source, &keep).unwrap();

        assert_eq!(picked.shape(), &[3]);
        assert_eq!(elements(&picked), [10.0, 30.0, 50.0]);
    }

    #[test]
    fn test_fancy_index_mask_all_false() {
        let source = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let keep = DenseND::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap();

        let picked = fancy_index_mask(&source, &keep).unwrap();

        assert_eq!(picked.shape(), &[0]);
    }

    #[test]
    fn test_fancy_index_mask_shape_mismatch() {
        let source = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let keep = DenseND::from_vec(vec![1.0, 0.0], &[2]).unwrap();

        let outcome = fancy_index_mask(&source, &keep);

        assert!(outcome.is_err());
    }
}