Skip to main content

scivex_core/tensor/
indexing.rs

1//! Slicing and advanced indexing for [`Tensor`].
2
3use crate::Scalar;
4use crate::error::{CoreError, Result};
5
6use super::{Tensor, compute_strides};
7
/// A strided, half-open index range for one axis when slicing a tensor.
///
/// Mirrors Python's `start:stop:step` slice notation: the selected indices
/// are `start, start + step, start + 2 * step, ...` for as long as they stay
/// strictly below `stop`.
///
/// The type is a plain value: it is `Copy`, comparable and hashable, so
/// ranges can be compared in tests or used as map keys.
///
/// # Examples
///
/// ```
/// # use scivex_core::tensor::indexing::SliceRange;
/// let r = SliceRange::new(0, 10, 2); // 0, 2, 4, 6, 8
/// assert_eq!(r.start, 0);
/// assert_eq!(r.step, 2);
/// assert_eq!(r, SliceRange::new(0, 10, 2));
/// ```
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SliceRange {
    /// First selected index (inclusive).
    pub start: usize,
    /// Upper bound of the selection (exclusive).
    pub stop: usize,
    /// Distance between consecutive selected indices; expected to be `> 0`.
    pub step: usize,
}

impl SliceRange {
    /// Create a new slice range.
    ///
    /// # Panics
    ///
    /// Panics in debug mode if `step == 0`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use scivex_core::tensor::indexing::SliceRange;
    /// let r = SliceRange::new(1, 5, 2);
    /// assert_eq!(r.start, 1);
    /// assert_eq!(r.stop, 5);
    /// ```
    // `start`, `stop` and `step` are the established slice vocabulary, so the
    // similar-looking names are intentional.
    #[allow(clippy::similar_names)]
    #[must_use]
    pub fn new(start: usize, stop: usize, step: usize) -> Self {
        debug_assert!(step > 0, "slice step must be > 0");
        Self { start, stop, step }
    }

    /// Shorthand for `start..stop` with step 1.
    ///
    /// # Examples
    ///
    /// ```
    /// # use scivex_core::tensor::indexing::SliceRange;
    /// let r = SliceRange::range(2, 5);
    /// assert_eq!(r.step, 1);
    /// ```
    #[must_use]
    pub fn range(start: usize, stop: usize) -> Self {
        Self::new(start, stop, 1)
    }

    /// Select the full extent of an axis. Requires knowing the axis length.
    ///
    /// # Examples
    ///
    /// ```
    /// # use scivex_core::tensor::indexing::SliceRange;
    /// let r = SliceRange::full(5);
    /// assert_eq!(r.start, 0);
    /// assert_eq!(r.stop, 5);
    /// assert_eq!(r.step, 1);
    /// ```
    #[must_use]
    pub fn full(len: usize) -> Self {
        Self::new(0, len, 1)
    }

    /// The number of elements this range selects.
    ///
    /// A range whose `stop` is not greater than its `start` selects nothing.
    /// Otherwise the count is `ceil((stop - start) / step)`.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if the range is non-empty and `step == 0`;
    /// callers are expected to reject a zero step beforehand.
    fn len(&self) -> usize {
        if self.stop <= self.start {
            0
        } else {
            (self.stop - self.start).div_ceil(self.step)
        }
    }
}
89
90impl<T: Scalar> Tensor<T> {
91    /// Extract a sub-tensor by slicing along each axis.
92    ///
93    /// `ranges` must have exactly `ndim` elements. Each [`SliceRange`]
94    /// specifies which indices to take along that axis.
95    ///
96    /// Returns a new tensor with copied data.
97    ///
98    /// # Examples
99    ///
100    /// ```
101    /// # use scivex_core::{Tensor, tensor::indexing::SliceRange};
102    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();
103    /// let s = t.slice(&[SliceRange::range(0, 2), SliceRange::range(1, 3)]).unwrap();
104    /// assert_eq!(s.shape(), &[2, 2]);
105    /// assert_eq!(s.as_slice(), &[2, 3, 5, 6]);
106    /// ```
107    pub fn slice(&self, ranges: &[SliceRange]) -> Result<Self> {
108        if ranges.len() != self.ndim() {
109            return Err(CoreError::InvalidArgument {
110                reason: "number of slice ranges must match tensor rank",
111            });
112        }
113
114        // Validate ranges
115        for (d, r) in ranges.iter().enumerate() {
116            if r.stop > self.shape[d] {
117                return Err(CoreError::IndexOutOfBounds {
118                    index: vec![r.stop],
119                    shape: self.shape.clone(),
120                });
121            }
122            if r.step == 0 {
123                return Err(CoreError::InvalidArgument {
124                    reason: "slice step must be > 0",
125                });
126            }
127        }
128
129        let new_shape: Vec<usize> = ranges.iter().map(SliceRange::len).collect();
130        let new_numel: usize = new_shape.iter().product();
131
132        if new_numel == 0 {
133            return Tensor::from_vec(vec![], new_shape);
134        }
135
136        let mut data = Vec::with_capacity(new_numel);
137        let mut index = vec![0usize; self.ndim()];
138
139        // Initialize index to start positions
140        for (d, r) in ranges.iter().enumerate() {
141            index[d] = r.start;
142        }
143
144        // Iterate over all output elements (odometer on the sliced indices)
145        for _ in 0..new_numel {
146            let flat = index
147                .iter()
148                .zip(self.strides.iter())
149                .map(|(&i, &s)| i * s)
150                .sum::<usize>();
151            data.push(self.data[flat]);
152
153            // Advance the odometer
154            for d in (0..self.ndim()).rev() {
155                index[d] += ranges[d].step;
156                if index[d] < ranges[d].stop {
157                    break;
158                }
159                index[d] = ranges[d].start;
160            }
161        }
162
163        Tensor::from_vec(data, new_shape)
164    }
165
166    /// Select a single index along the given axis, reducing dimensionality by 1.
167    ///
168    /// For a 2-D tensor, `select(0, i)` returns row `i` as a 1-D tensor.
169    ///
170    /// # Examples
171    ///
172    /// ```
173    /// # use scivex_core::Tensor;
174    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
175    /// let row = t.select(0, 1).unwrap();
176    /// assert_eq!(row.shape(), &[3]);
177    /// assert_eq!(row.as_slice(), &[4, 5, 6]);
178    /// ```
179    pub fn select(&self, axis: usize, index: usize) -> Result<Self> {
180        if axis >= self.ndim() {
181            return Err(CoreError::AxisOutOfBounds {
182                axis,
183                ndim: self.ndim(),
184            });
185        }
186        if index >= self.shape[axis] {
187            return Err(CoreError::IndexOutOfBounds {
188                index: vec![index],
189                shape: self.shape.clone(),
190            });
191        }
192
193        let mut ranges: Vec<SliceRange> = self
194            .shape
195            .iter()
196            .map(|&len| SliceRange::full(len))
197            .collect();
198        ranges[axis] = SliceRange::new(index, index + 1, 1);
199
200        let sliced = self.slice(&ranges)?;
201        // Remove the axis that was selected (it has size 1)
202        let mut new_shape: Vec<usize> = sliced.shape().to_vec();
203        new_shape.remove(axis);
204        if new_shape.is_empty() {
205            Ok(Tensor::scalar(sliced.data[0]))
206        } else {
207            let strides = compute_strides(&new_shape);
208            Ok(Tensor {
209                data: sliced.data,
210                shape: new_shape,
211                strides,
212            })
213        }
214    }
215}
216
217// ======================================================================
218// Fancy indexing: integer array indexing, boolean mask indexing, gather/scatter
219// ======================================================================
220
221impl<T: Scalar> Tensor<T> {
222    /// Select elements along an axis using an array of indices.
223    ///
224    /// Like numpy `np.take(arr, indices, axis)`. For a tensor of shape
225    /// `[d0, d1, ..., dk, ...]`, selecting along axis `k` with `indices` of
226    /// length `m` produces a tensor of shape `[d0, ..., m, ..., dn]`.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// # use scivex_core::Tensor;
232    /// let t = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
233    /// let s = t.index_select(0, &[4, 0, 2]).unwrap();
234    /// assert_eq!(s.as_slice(), &[50, 10, 30]);
235    /// ```
236    pub fn index_select(&self, axis: usize, indices: &[usize]) -> Result<Self> {
237        if axis >= self.ndim() {
238            return Err(CoreError::AxisOutOfBounds {
239                axis,
240                ndim: self.ndim(),
241            });
242        }
243        for &idx in indices {
244            if idx >= self.shape[axis] {
245                return Err(CoreError::IndexOutOfBounds {
246                    index: vec![idx],
247                    shape: self.shape.clone(),
248                });
249            }
250        }
251
252        let mut new_shape = self.shape.clone();
253        new_shape[axis] = indices.len();
254        let new_numel: usize = new_shape.iter().product();
255
256        if new_numel == 0 {
257            return Tensor::from_vec(vec![], new_shape);
258        }
259
260        let mut data = Vec::with_capacity(new_numel);
261
262        // Iterate over every output position
263        let ndim = self.ndim();
264        let mut out_idx = vec![0usize; ndim];
265
266        for _ in 0..new_numel {
267            // Map output index to source index
268            let mut src_flat = 0;
269            for d in 0..ndim {
270                let src_coord = if d == axis {
271                    indices[out_idx[d]]
272                } else {
273                    out_idx[d]
274                };
275                src_flat += src_coord * self.strides[d];
276            }
277            data.push(self.data[src_flat]);
278
279            // Advance odometer
280            for d in (0..ndim).rev() {
281                out_idx[d] += 1;
282                if out_idx[d] < new_shape[d] {
283                    break;
284                }
285                out_idx[d] = 0;
286            }
287        }
288
289        Tensor::from_vec(data, new_shape)
290    }
291
292    /// Select elements where `mask` is true, returning a flat 1-D tensor.
293    ///
294    /// Like numpy `arr[mask]`. The mask length must equal the total number of
295    /// elements in the tensor.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// # use scivex_core::Tensor;
301    /// let t = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
302    /// let s = t.masked_select(&[true, false, true, false, true]).unwrap();
303    /// assert_eq!(s.as_slice(), &[10, 30, 50]);
304    /// ```
305    pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
306        if mask.len() != self.numel() {
307            return Err(CoreError::InvalidArgument {
308                reason: "mask length must equal tensor element count",
309            });
310        }
311
312        let data: Vec<T> = self
313            .data
314            .iter()
315            .zip(mask.iter())
316            .filter(|&(_, &m)| m)
317            .map(|(&v, _)| v)
318            .collect();
319
320        let len = data.len();
321        Tensor::from_vec(data, vec![len])
322    }
323
324    /// Select rows (slices along axis 0) where `mask` is true.
325    ///
326    /// `mask.len()` must equal `self.shape()[0]`. The result has the same
327    /// number of dimensions, with dimension 0 reduced to the count of `true`
328    /// entries.
329    ///
330    /// # Examples
331    ///
332    /// ```
333    /// # use scivex_core::Tensor;
334    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![3, 2]).unwrap();
335    /// let s = t.masked_select_along(&[false, true, true]).unwrap();
336    /// assert_eq!(s.shape(), &[2, 2]);
337    /// assert_eq!(s.as_slice(), &[3, 4, 5, 6]);
338    /// ```
339    pub fn masked_select_along(&self, mask: &[bool]) -> Result<Self> {
340        if self.ndim() == 0 {
341            return Err(CoreError::InvalidArgument {
342                reason: "cannot mask-select along axis 0 of a scalar tensor",
343            });
344        }
345        if mask.len() != self.shape[0] {
346            return Err(CoreError::InvalidArgument {
347                reason: "mask length must equal shape[0]",
348            });
349        }
350
351        let row_size: usize = self.strides[0]; // number of elements per row
352        let selected: usize = mask.iter().filter(|&&m| m).count();
353
354        let mut new_shape = self.shape.clone();
355        new_shape[0] = selected;
356        let new_numel: usize = new_shape.iter().product();
357
358        let mut data = Vec::with_capacity(new_numel);
359        for (i, &m) in mask.iter().enumerate() {
360            if m {
361                let start = i * row_size;
362                let end = start + row_size;
363                data.extend_from_slice(&self.data[start..end]);
364            }
365        }
366
367        Tensor::from_vec(data, new_shape)
368    }
369
370    /// Set elements along an axis at the given indices.
371    ///
372    /// `values` must have the same shape as the result of
373    /// `self.index_select(axis, indices)`.
374    ///
375    /// # Examples
376    ///
377    /// ```
378    /// # use scivex_core::Tensor;
379    /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
380    /// let vals = Tensor::from_vec(vec![10, 20, 30], vec![1, 3]).unwrap();
381    /// t.index_put(0, &[0], &vals).unwrap();
382    /// assert_eq!(t.as_slice(), &[10, 20, 30, 4, 5, 6]);
383    /// ```
384    pub fn index_put(&mut self, axis: usize, indices: &[usize], values: &Tensor<T>) -> Result<()> {
385        if axis >= self.ndim() {
386            return Err(CoreError::AxisOutOfBounds {
387                axis,
388                ndim: self.ndim(),
389            });
390        }
391        // Validate expected shape
392        let mut expected_shape = self.shape.clone();
393        expected_shape[axis] = indices.len();
394        if values.shape() != expected_shape.as_slice() {
395            return Err(CoreError::DimensionMismatch {
396                expected: expected_shape,
397                got: values.shape().to_vec(),
398            });
399        }
400        for &idx in indices {
401            if idx >= self.shape[axis] {
402                return Err(CoreError::IndexOutOfBounds {
403                    index: vec![idx],
404                    shape: self.shape.clone(),
405                });
406            }
407        }
408
409        let ndim = self.ndim();
410        let val_numel = values.numel();
411
412        if val_numel == 0 {
413            return Ok(());
414        }
415
416        let mut out_idx = vec![0usize; ndim];
417
418        for vi in 0..val_numel {
419            // Map output index to source index in self
420            let mut dst_flat = 0;
421            for d in 0..ndim {
422                let dst_coord = if d == axis {
423                    indices[out_idx[d]]
424                } else {
425                    out_idx[d]
426                };
427                dst_flat += dst_coord * self.strides[d];
428            }
429            self.data[dst_flat] = values.data[vi];
430
431            // Advance odometer using expected_shape
432            for d in (0..ndim).rev() {
433                out_idx[d] += 1;
434                if out_idx[d] < expected_shape[d] {
435                    break;
436                }
437                out_idx[d] = 0;
438            }
439        }
440
441        Ok(())
442    }
443
444    /// Set elements where `mask` is true to corresponding values from `values`.
445    ///
446    /// `mask.len()` must equal `self.numel()`, and `values.len()` must equal
447    /// the number of `true` entries in the mask.
448    ///
449    /// # Examples
450    ///
451    /// ```
452    /// # use scivex_core::Tensor;
453    /// let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5], vec![5]).unwrap();
454    /// t.masked_put(&[false, true, false, true, false], &[99, 88]).unwrap();
455    /// assert_eq!(t.as_slice(), &[1, 99, 3, 88, 5]);
456    /// ```
457    pub fn masked_put(&mut self, mask: &[bool], values: &[T]) -> Result<()> {
458        if mask.len() != self.numel() {
459            return Err(CoreError::InvalidArgument {
460                reason: "mask length must equal tensor element count",
461            });
462        }
463        let true_count = mask.iter().filter(|&&m| m).count();
464        if values.len() != true_count {
465            return Err(CoreError::InvalidArgument {
466                reason: "values length must equal number of true entries in mask",
467            });
468        }
469
470        let mut vi = 0;
471        for (i, &m) in mask.iter().enumerate() {
472            if m {
473                self.data[i] = values[vi];
474                vi += 1;
475            }
476        }
477
478        Ok(())
479    }
480
481    /// Gather values along `axis` using an index tensor.
482    ///
483    /// The `indices` tensor must have the same number of dimensions as `self`,
484    /// and all dimensions except `axis` must match `self`'s shape. The output
485    /// has the same shape as `indices`.
486    ///
487    /// This is equivalent to PyTorch's `torch.gather`.
488    ///
489    /// # Examples
490    ///
491    /// ```
492    /// # use scivex_core::Tensor;
493    /// let t = Tensor::from_vec(vec![10, 20, 30, 40, 50, 60], vec![2, 3]).unwrap();
494    /// let idx = Tensor::from_vec(vec![2, 0, 1, 0], vec![2, 2]).unwrap();
495    /// let g = t.gather(1, &idx).unwrap();
496    /// assert_eq!(g.as_slice(), &[30, 10, 50, 40]);
497    /// ```
498    pub fn gather(&self, axis: usize, indices: &Tensor<usize>) -> Result<Tensor<T>> {
499        if axis >= self.ndim() {
500            return Err(CoreError::AxisOutOfBounds {
501                axis,
502                ndim: self.ndim(),
503            });
504        }
505        if indices.ndim() != self.ndim() {
506            return Err(CoreError::DimensionMismatch {
507                expected: self.shape.clone(),
508                got: indices.shape().to_vec(),
509            });
510        }
511        // All dims except axis must match
512        for d in 0..self.ndim() {
513            if d != axis && indices.shape()[d] != self.shape[d] {
514                return Err(CoreError::DimensionMismatch {
515                    expected: self.shape.clone(),
516                    got: indices.shape().to_vec(),
517                });
518            }
519        }
520
521        let out_shape = indices.shape().to_vec();
522        let out_numel: usize = out_shape.iter().product();
523
524        if out_numel == 0 {
525            return Tensor::from_vec(vec![], out_shape);
526        }
527
528        let ndim = self.ndim();
529        let mut data = Vec::with_capacity(out_numel);
530        let mut out_idx = vec![0usize; ndim];
531        let idx_strides = compute_strides(&out_shape);
532
533        for _ in 0..out_numel {
534            // Read the index value from the indices tensor
535            let idx_flat: usize = out_idx
536                .iter()
537                .zip(idx_strides.iter())
538                .map(|(&i, &s)| i * s)
539                .sum();
540            let gather_idx = indices.data[idx_flat];
541
542            if gather_idx >= self.shape[axis] {
543                return Err(CoreError::IndexOutOfBounds {
544                    index: vec![gather_idx],
545                    shape: self.shape.clone(),
546                });
547            }
548
549            // Compute flat index in self
550            let src_flat: usize = out_idx
551                .iter()
552                .enumerate()
553                .zip(self.strides.iter())
554                .map(|((d, &oi), &s)| {
555                    let coord = if d == axis { gather_idx } else { oi };
556                    coord * s
557                })
558                .sum();
559            data.push(self.data[src_flat]);
560
561            // Advance odometer
562            for d in (0..ndim).rev() {
563                out_idx[d] += 1;
564                if out_idx[d] < out_shape[d] {
565                    break;
566                }
567                out_idx[d] = 0;
568            }
569        }
570
571        Tensor::from_vec(data, out_shape)
572    }
573}
574
#[cfg(test)]
mod tests {
    //! Unit tests for slicing, selection and fancy indexing, covering both
    //! the happy paths and the documented error conditions.

    use super::*;

    // ------------------------------------------------------------------
    // Slicing and selection
    // ------------------------------------------------------------------

    #[test]
    fn test_slice_basic() {
        // [[1, 2, 3],
        //  [4, 5, 6],
        //  [7, 8, 9]]
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();

        // Slice rows 0..2, cols 1..3 -> [[2, 3], [5, 6]]
        let s = t
            .slice(&[SliceRange::range(0, 2), SliceRange::range(1, 3)])
            .unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.as_slice(), &[2, 3, 5, 6]);
    }

    #[test]
    fn test_slice_with_step() {
        let t = Tensor::<i32>::arange(10);
        // [0, 1, 2, ..., 9] with step 3 -> [0, 3, 6, 9]
        let s = t.slice(&[SliceRange::new(0, 10, 3)]).unwrap();
        assert_eq!(s.shape(), &[4]);
        assert_eq!(s.as_slice(), &[0, 3, 6, 9]);
    }

    #[test]
    fn test_slice_full() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let s = t
            .slice(&[SliceRange::full(2), SliceRange::full(2)])
            .unwrap();
        assert_eq!(s, t);
    }

    #[test]
    fn test_select_row() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let row = t.select(0, 1).unwrap();
        assert_eq!(row.shape(), &[3]);
        assert_eq!(row.as_slice(), &[4, 5, 6]);
    }

    #[test]
    fn test_select_col() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let col = t.select(1, 0).unwrap();
        assert_eq!(col.shape(), &[2]);
        assert_eq!(col.as_slice(), &[1, 4]);
    }

    #[test]
    fn test_select_to_scalar() {
        let t = Tensor::from_vec(vec![42], vec![1]).unwrap();
        let s = t.select(0, 0).unwrap();
        assert_eq!(s.ndim(), 0);
        assert_eq!(s.as_slice(), &[42]);
    }

    #[test]
    fn test_slice_out_of_bounds() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        assert!(
            t.slice(&[SliceRange::range(0, 3), SliceRange::full(2)])
                .is_err()
        );
    }

    #[test]
    fn test_slice_rank_mismatch() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        // One range for a rank-2 tensor.
        assert!(t.slice(&[SliceRange::full(2)]).is_err());
    }

    #[test]
    fn test_slice_zero_step() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        // Built as a literal because `SliceRange::new` debug-asserts `step > 0`.
        let zero_step = SliceRange {
            start: 0,
            stop: 2,
            step: 0,
        };
        assert!(t.slice(&[zero_step]).is_err());
    }

    #[test]
    fn test_select_axis_oob() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(t.select(1, 0).is_err());
    }

    #[test]
    fn test_select_index_oob() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(t.select(0, 5).is_err());
    }

    // ------------------------------------------------------------------
    // Fancy indexing tests
    // ------------------------------------------------------------------

    #[test]
    fn test_index_select_1d() {
        let t = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
        let s = t.index_select(0, &[4, 0, 2]).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.as_slice(), &[50, 10, 30]);
    }

    #[test]
    fn test_index_select_2d_axis0() {
        // [[1, 2, 3],
        //  [4, 5, 6],
        //  [7, 8, 9]]
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();
        // Select rows 2, 0 -> [[7, 8, 9], [1, 2, 3]]
        let s = t.index_select(0, &[2, 0]).unwrap();
        assert_eq!(s.shape(), &[2, 3]);
        assert_eq!(s.as_slice(), &[7, 8, 9, 1, 2, 3]);
    }

    #[test]
    fn test_index_select_2d_axis1() {
        // [[1, 2, 3],
        //  [4, 5, 6]]
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        // Select cols 2, 0 -> [[3, 1], [6, 4]]
        let s = t.index_select(1, &[2, 0]).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.as_slice(), &[3, 1, 6, 4]);
    }

    #[test]
    fn test_masked_select_flat() {
        let t = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
        let mask = vec![true, false, true, false, true];
        let s = t.masked_select(&mask).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.as_slice(), &[10, 30, 50]);
    }

    #[test]
    fn test_masked_select_mask_len_mismatch() {
        let t = Tensor::from_vec(vec![10, 20, 30], vec![3]).unwrap();
        assert!(t.masked_select(&[true, false]).is_err());
    }

    #[test]
    fn test_masked_select_along_rows() {
        // [[1, 2],
        //  [3, 4],
        //  [5, 6]]
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![3, 2]).unwrap();
        let mask = vec![false, true, true];
        let s = t.masked_select_along(&mask).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.as_slice(), &[3, 4, 5, 6]);
    }

    #[test]
    fn test_masked_select_along_mask_len_mismatch() {
        let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![3, 2]).unwrap();
        assert!(t.masked_select_along(&[true, false]).is_err());
    }

    #[test]
    fn test_index_put() {
        // [[1, 2, 3],
        //  [4, 5, 6],
        //  [7, 8, 9]]
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();
        // Replace rows 0 and 2
        let vals = Tensor::from_vec(vec![10, 20, 30, 70, 80, 90], vec![2, 3]).unwrap();
        t.index_put(0, &[0, 2], &vals).unwrap();
        assert_eq!(t.as_slice(), &[10, 20, 30, 4, 5, 6, 70, 80, 90]);
    }

    #[test]
    fn test_index_put_errors_leave_tensor_untouched() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();
        let one_row = Tensor::from_vec(vec![10, 20, 30], vec![1, 3]).unwrap();

        // Two indices but only one row of values -> shape mismatch.
        assert!(t.index_put(0, &[0, 2], &one_row).is_err());
        // Row index beyond the axis length.
        assert!(t.index_put(0, &[5], &one_row).is_err());
        // Axis beyond the tensor rank.
        assert!(t.index_put(2, &[0], &one_row).is_err());

        assert_eq!(t.as_slice(), &[1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_masked_put() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5], vec![5]).unwrap();
        let mask = vec![false, true, false, true, false];
        t.masked_put(&mask, &[99, 88]).unwrap();
        assert_eq!(t.as_slice(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn test_masked_put_length_mismatches() {
        let mut t = Tensor::from_vec(vec![1, 2, 3, 4, 5], vec![5]).unwrap();

        // Mask shorter than the tensor.
        assert!(t.masked_put(&[true, false], &[99]).is_err());
        // Two `true` entries but only one value.
        assert!(
            t.masked_put(&[false, true, false, true, false], &[99])
                .is_err()
        );

        assert_eq!(t.as_slice(), &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_index_out_of_bounds() {
        let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(t.index_select(0, &[5]).is_err());
        assert!(t.index_select(1, &[0]).is_err()); // axis OOB
    }

    #[test]
    fn test_gather_axis1() {
        // [[10, 20, 30],
        //  [40, 50, 60]]
        let t = Tensor::from_vec(vec![10, 20, 30, 40, 50, 60], vec![2, 3]).unwrap();
        let idx = Tensor::from_vec(vec![2, 0, 1, 0], vec![2, 2]).unwrap();
        // Row 0 picks cols 2, 0; row 1 picks cols 1, 0.
        let g = t.gather(1, &idx).unwrap();
        assert_eq!(g.shape(), &[2, 2]);
        assert_eq!(g.as_slice(), &[30, 10, 50, 40]);
    }

    #[test]
    fn test_gather_axis0() {
        // [[10, 20, 30],
        //  [40, 50, 60]]
        let t = Tensor::from_vec(vec![10, 20, 30, 40, 50, 60], vec![2, 3]).unwrap();
        let idx = Tensor::from_vec(vec![1, 0, 1], vec![1, 3]).unwrap();
        // Per column, pick rows 1, 0, 1.
        let g = t.gather(0, &idx).unwrap();
        assert_eq!(g.shape(), &[1, 3]);
        assert_eq!(g.as_slice(), &[40, 20, 60]);
    }

    #[test]
    fn test_gather_errors() {
        let t = Tensor::from_vec(vec![10, 20, 30, 40, 50, 60], vec![2, 3]).unwrap();

        // Index value 3 is outside axis 1 (length 3).
        let oob = Tensor::from_vec(vec![3, 0, 0, 0], vec![2, 2]).unwrap();
        assert!(t.gather(1, &oob).is_err());

        // Dimension 0 of the index tensor (3) does not match the source (2).
        let wrong_shape = Tensor::from_vec(vec![0, 0, 0, 0, 0, 0], vec![3, 2]).unwrap();
        assert!(t.gather(1, &wrong_shape).is_err());

        // Rank mismatch: 1-D indices for a 2-D source.
        let wrong_rank = Tensor::from_vec(vec![0, 1], vec![2]).unwrap();
        assert!(t.gather(1, &wrong_rank).is_err());

        // Axis beyond the tensor rank.
        let ok_idx = Tensor::from_vec(vec![0, 0, 0, 0], vec![2, 2]).unwrap();
        assert!(t.gather(2, &ok_idx).is_err());
    }
}