1use anyhow::{anyhow, Result};
18use scirs2_core::ndarray_ext::{ArrayView, IxDyn, Zip};
19use scirs2_core::numeric::{Float, FromPrimitive, Num, ToPrimitive};
20use tenrso_core::{Axis, DenseND};
21
22pub fn advanced_gather<T>(
45 input: &DenseND<T>,
46 axis: Axis,
47 indices: &DenseND<T>,
48 allow_negative: bool,
49) -> Result<DenseND<T>>
50where
51 T: Clone + Num + Float + ToPrimitive + FromPrimitive,
52{
53 let input_shape = input.shape();
54
55 if axis >= input_shape.len() {
57 return Err(anyhow!(
58 "Axis {} out of bounds for tensor with {} dimensions",
59 axis,
60 input_shape.len()
61 ));
62 }
63
64 let axis_size = input_shape[axis];
65 let indices_view = indices.view();
66
67 let converted_indices: Vec<usize> = indices_view
69 .iter()
70 .map(|&idx| {
71 let idx_i64 = idx
72 .to_i64()
73 .ok_or_else(|| anyhow!("Index value cannot be converted to integer"))?;
74
75 let final_idx = if idx_i64 < 0 {
76 if !allow_negative {
77 return Err(anyhow!("Negative index {} not allowed", idx_i64));
78 }
79 let positive_idx = (axis_size as i64 + idx_i64) as usize;
81 if positive_idx >= axis_size {
82 return Err(anyhow!(
83 "Negative index {} out of bounds for axis size {}",
84 idx_i64,
85 axis_size
86 ));
87 }
88 positive_idx
89 } else {
90 let idx_usize = idx_i64 as usize;
91 if idx_usize >= axis_size {
92 return Err(anyhow!(
93 "Index {} out of bounds for axis size {}",
94 idx_usize,
95 axis_size
96 ));
97 }
98 idx_usize
99 };
100
101 Ok(final_idx)
102 })
103 .collect::<Result<Vec<_>>>()?;
104
105 let mut output_shape = input_shape.to_vec();
107 output_shape[axis] = converted_indices.len();
108
109 let total_elements: usize = output_shape.iter().product();
111 let mut output_data = vec![T::zero(); total_elements];
112
113 let input_view = input.view();
115 gather_recursive(
116 &input_view,
117 &converted_indices,
118 axis,
119 &output_shape,
120 &mut output_data,
121 0,
122 &mut 0,
123 )?;
124
125 DenseND::from_vec(output_data, &output_shape)
126}
127
128fn gather_recursive<T>(
130 input: &ArrayView<T, IxDyn>,
131 indices: &[usize],
132 axis: Axis,
133 output_shape: &[usize],
134 output_data: &mut [T],
135 current_depth: usize,
136 output_idx: &mut usize,
137) -> Result<()>
138where
139 T: Clone + Num,
140{
141 if current_depth == output_shape.len() {
142 return Ok(());
143 }
144
145 let dim_size = output_shape[current_depth];
146
147 if current_depth == axis {
148 for &idx in indices {
150 let slice = input.index_axis(scirs2_core::ndarray_ext::Axis(current_depth), idx);
151 if current_depth == output_shape.len() - 1 {
152 output_data[*output_idx] = slice.iter().next().unwrap().clone();
154 *output_idx += 1;
155 } else {
156 gather_recursive(
157 &slice,
158 indices,
159 axis,
160 output_shape,
161 output_data,
162 current_depth + 1,
163 output_idx,
164 )?;
165 }
166 }
167 } else {
168 for i in 0..dim_size {
170 let slice = input.index_axis(scirs2_core::ndarray_ext::Axis(current_depth), i);
171 if current_depth == output_shape.len() - 1 {
172 output_data[*output_idx] = slice.iter().next().unwrap().clone();
174 *output_idx += 1;
175 } else {
176 gather_recursive(
177 &slice,
178 indices,
179 axis,
180 output_shape,
181 output_data,
182 current_depth + 1,
183 output_idx,
184 )?;
185 }
186 }
187 }
188
189 Ok(())
190}
191
192pub fn advanced_scatter<T>(
217 shape: &[usize],
218 axis: Axis,
219 indices: &DenseND<T>,
220 values: &DenseND<T>,
221 mode: ScatterMode,
222) -> Result<DenseND<T>>
223where
224 T: Clone + Num + Float + ToPrimitive + FromPrimitive + PartialOrd,
225{
226 if axis >= shape.len() {
228 return Err(anyhow!(
229 "Axis {} out of bounds for tensor with {} dimensions",
230 axis,
231 shape.len()
232 ));
233 }
234
235 let axis_size = shape[axis];
236 let indices_view = indices.view();
237
238 let converted_indices: Vec<usize> = indices_view
240 .iter()
241 .map(|&idx| {
242 let idx_i64 = idx
243 .to_i64()
244 .ok_or_else(|| anyhow!("Index value cannot be converted to integer"))?;
245
246 if idx_i64 < 0 {
247 return Err(anyhow!("Negative indices not supported in scatter"));
248 }
249
250 let idx_usize = idx_i64 as usize;
251 if idx_usize >= axis_size {
252 return Err(anyhow!(
253 "Index {} out of bounds for axis size {}",
254 idx_usize,
255 axis_size
256 ));
257 }
258
259 Ok(idx_usize)
260 })
261 .collect::<Result<Vec<_>>>()?;
262
263 let total_elements: usize = shape.iter().product();
265 let mut output_data = match mode {
266 ScatterMode::Replace => vec![T::zero(); total_elements],
267 ScatterMode::Add => vec![T::zero(); total_elements],
268 ScatterMode::Max => vec![T::from_f64(f64::NEG_INFINITY).unwrap(); total_elements],
269 ScatterMode::Min => vec![T::from_f64(f64::INFINITY).unwrap(); total_elements],
270 };
271
272 let values_view = values.view();
274 scatter_recursive(
275 &values_view,
276 &converted_indices,
277 axis,
278 shape,
279 &mut output_data,
280 0,
281 &mut 0,
282 mode,
283 )?;
284
285 DenseND::from_vec(output_data, shape)
286}
287
/// How `advanced_scatter` combines a scattered value with the value already stored at
/// its target position.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScatterMode {
    /// Overwrite the target. When several values hit the same position, the one
    /// scattered last wins. Untouched positions stay zero.
    Replace,
    /// Add the value to the target. Untouched positions stay zero.
    Add,
    /// Keep the larger of the target and the value. Untouched positions stay -infinity.
    Max,
    /// Keep the smaller of the target and the value. Untouched positions stay +infinity.
    Min,
}
300
301#[allow(clippy::too_many_arguments)]
303fn scatter_recursive<T>(
304 values: &ArrayView<T, IxDyn>,
305 indices: &[usize],
306 axis: Axis,
307 output_shape: &[usize],
308 output_data: &mut [T],
309 current_depth: usize,
310 values_idx: &mut usize,
311 mode: ScatterMode,
312) -> Result<()>
313where
314 T: Clone + Num + PartialOrd,
315{
316 if current_depth == output_shape.len() {
317 return Ok(());
318 }
319
320 let _dim_size = output_shape[current_depth];
321
322 if current_depth == axis {
323 for &out_idx in indices {
325 if current_depth == output_shape.len() - 1 {
326 let value = values.iter().nth(*values_idx).unwrap().clone();
328 let flat_idx = compute_flat_index(output_shape, &[out_idx], current_depth);
329
330 match mode {
331 ScatterMode::Replace => output_data[flat_idx] = value,
332 ScatterMode::Add => {
333 output_data[flat_idx] = output_data[flat_idx].clone() + value
334 }
335 ScatterMode::Max => {
336 if value > output_data[flat_idx] {
337 output_data[flat_idx] = value;
338 }
339 }
340 ScatterMode::Min => {
341 if value < output_data[flat_idx] {
342 output_data[flat_idx] = value;
343 }
344 }
345 }
346 *values_idx += 1;
347 }
348 }
349 }
350
351 Ok(())
352}
353
/// Row-major linear offset of the coordinate prefix `indices[0..=depth]`.
///
/// The prefix is treated as a coordinate into an array with dimensions
/// `shape[0..=depth]`; dimensions after `depth` are ignored. The result is therefore a
/// flat offset into the full buffer only when `depth == shape.len() - 1`. `shape[0]`
/// never affects the result.
///
/// Panics if `indices` or `shape` has fewer than `depth + 1` entries.
fn compute_flat_index(shape: &[usize], indices: &[usize], depth: usize) -> usize {
    // Horner's scheme: scale the running offset by each next extent, then add that coordinate.
    (1..=depth).fold(indices[0], |offset, dim| offset * shape[dim] + indices[dim])
}
368
369pub fn fancy_index_mask<T>(input: &DenseND<T>, mask: &DenseND<T>) -> Result<DenseND<T>>
379where
380 T: Clone + Num + PartialOrd,
381{
382 if input.shape() != mask.shape() {
383 return Err(anyhow!(
384 "Input and mask must have the same shape: {:?} vs {:?}",
385 input.shape(),
386 mask.shape()
387 ));
388 }
389
390 let input_view = input.view();
391 let mask_view = mask.view();
392 let zero = T::zero();
393
394 let mut selected = Vec::new();
396 Zip::from(&input_view)
397 .and(&mask_view)
398 .for_each(|val, mask_val| {
399 if *mask_val > zero {
400 selected.push(val.clone());
401 }
402 });
403
404 let output_shape = vec![selected.len()];
405 DenseND::from_vec(selected, &output_shape)
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_gather_1d() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], &[5]).unwrap();
        let indices = DenseND::from_vec(vec![0.0, 2.0, 4.0, 1.0], &[4]).unwrap();

        let result = advanced_gather(&input, 0, &indices, false).unwrap();

        assert_eq!(result.shape(), &[4]);
        let result_view = result.view();
        assert_eq!(result_view[[0]], 10.0);
        assert_eq!(result_view[[1]], 30.0);
        assert_eq!(result_view[[2]], 50.0);
        assert_eq!(result_view[[3]], 20.0);
    }

    #[test]
    fn test_advanced_gather_negative_indices() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], &[5]).unwrap();
        let indices = DenseND::from_vec(vec![-1.0, -2.0], &[2]).unwrap();

        let result = advanced_gather(&input, 0, &indices, true).unwrap();

        assert_eq!(result.shape(), &[2]);
        let result_view = result.view();
        // -1 and -2 address the last and second-to-last elements.
        assert_eq!(result_view[[0]], 50.0);
        assert_eq!(result_view[[1]], 40.0);
    }

    #[test]
    fn test_advanced_gather_negative_not_allowed() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let indices = DenseND::from_vec(vec![-1.0], &[1]).unwrap();

        assert!(advanced_gather(&input, 0, &indices, false).is_err());
    }

    #[test]
    fn test_advanced_gather_negative_out_of_bounds() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let indices = DenseND::from_vec(vec![-4.0], &[1]).unwrap();

        assert!(advanced_gather(&input, 0, &indices, true).is_err());
    }

    #[test]
    fn test_advanced_gather_out_of_bounds() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let indices = DenseND::from_vec(vec![0.0, 5.0], &[2]).unwrap();

        let result = advanced_gather(&input, 0, &indices, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_advanced_gather_axis_out_of_bounds() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let indices = DenseND::from_vec(vec![0.0], &[1]).unwrap();

        assert!(advanced_gather(&input, 1, &indices, false).is_err());
    }

    #[test]
    fn test_advanced_scatter_replace() {
        let shape = vec![5];
        let indices = DenseND::from_vec(vec![0.0, 2.0, 4.0], &[3]).unwrap();
        let values = DenseND::from_vec(vec![10.0, 30.0, 50.0], &[3]).unwrap();

        let result = advanced_scatter(&shape, 0, &indices, &values, ScatterMode::Replace).unwrap();

        assert_eq!(result.shape(), &[5]);
        let result_view = result.view();
        assert_eq!(result_view[[0]], 10.0);
        assert_eq!(result_view[[1]], 0.0);
        assert_eq!(result_view[[2]], 30.0);
        assert_eq!(result_view[[3]], 0.0);
        assert_eq!(result_view[[4]], 50.0);
    }

    #[test]
    fn test_advanced_scatter_add_accumulates_duplicates() {
        let indices = DenseND::from_vec(vec![0.0, 0.0, 2.0], &[3]).unwrap();
        let values = DenseND::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();

        let result = advanced_scatter(&[3], 0, &indices, &values, ScatterMode::Add).unwrap();

        let result_view = result.view();
        assert_eq!(result_view[[0]], 3.0);
        assert_eq!(result_view[[1]], 0.0);
        assert_eq!(result_view[[2]], 3.0);
    }

    #[test]
    fn test_advanced_scatter_max_and_min() {
        let indices = DenseND::from_vec(vec![1.0, 1.0], &[2]).unwrap();
        let values = DenseND::from_vec(vec![2.0, 5.0], &[2]).unwrap();

        let max = advanced_scatter(&[3], 0, &indices, &values, ScatterMode::Max).unwrap();
        let max_view = max.view();
        assert_eq!(max_view[[1]], 5.0);
        assert_eq!(max_view[[0]], f64::NEG_INFINITY);

        let min = advanced_scatter(&[3], 0, &indices, &values, ScatterMode::Min).unwrap();
        let min_view = min.view();
        assert_eq!(min_view[[1]], 2.0);
        assert_eq!(min_view[[2]], f64::INFINITY);
    }

    #[test]
    fn test_advanced_scatter_rejects_bad_indices() {
        let values = DenseND::from_vec(vec![1.0], &[1]).unwrap();

        let negative = DenseND::from_vec(vec![-1.0], &[1]).unwrap();
        assert!(advanced_scatter(&[3], 0, &negative, &values, ScatterMode::Replace).is_err());

        let too_big = DenseND::from_vec(vec![3.0], &[1]).unwrap();
        assert!(advanced_scatter(&[3], 0, &too_big, &values, ScatterMode::Replace).is_err());
    }

    #[test]
    fn test_fancy_index_mask() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0], &[5]).unwrap();
        let mask = DenseND::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], &[5]).unwrap();

        let result = fancy_index_mask(&input, &mask).unwrap();

        assert_eq!(result.shape(), &[3]);
        let result_view = result.view();
        assert_eq!(result_view[[0]], 10.0);
        assert_eq!(result_view[[1]], 30.0);
        assert_eq!(result_view[[2]], 50.0);
    }

    #[test]
    fn test_fancy_index_mask_all_false() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let mask = DenseND::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap();

        let result = fancy_index_mask(&input, &mask).unwrap();

        assert_eq!(result.shape(), &[0]);
    }

    #[test]
    fn test_fancy_index_mask_shape_mismatch() {
        let input = DenseND::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
        let mask = DenseND::from_vec(vec![1.0, 0.0], &[2]).unwrap();

        let result = fancy_index_mask(&input, &mask);
        assert!(result.is_err());
    }
}