1use crate::Scalar;
4use crate::error::{CoreError, Result};
5
6use super::{Tensor, compute_strides};
7
/// A half-open, strided index range `[start, stop)` along a single axis.
///
/// The range selects `start`, `start + step`, `start + 2 * step`, ... for as
/// long as the position stays strictly below `stop`. A range whose `stop` is
/// less than or equal to its `start` selects nothing.
#[cfg_attr(
    feature = "serde-support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SliceRange {
    /// First selected position (inclusive).
    pub start: usize,
    /// End of the range (exclusive).
    pub stop: usize,
    /// Distance between consecutive selected positions; expected to be > 0.
    pub step: usize,
}

impl SliceRange {
    /// Creates a range covering `[start, stop)` that advances by `step`.
    ///
    /// `step` must be greater than zero. This is only enforced with a
    /// `debug_assert!`, so a zero step is accepted silently in release builds.
    // `start` / `stop` / `step` are the conventional names for slice bounds,
    // so the pedantic similar-names lint is silenced here.
    #[allow(clippy::similar_names)]
    pub fn new(start: usize, stop: usize, step: usize) -> Self {
        debug_assert!(step > 0, "slice step must be > 0");
        Self { start, stop, step }
    }

    /// Creates a contiguous range `[start, stop)` with a step of 1.
    pub fn range(start: usize, stop: usize) -> Self {
        Self::new(start, stop, 1)
    }

    /// Creates the range `[0, len)` with a step of 1, i.e. a whole axis of
    /// extent `len`.
    pub fn full(len: usize) -> Self {
        Self::new(0, len, 1)
    }

    /// Number of positions this range selects.
    ///
    /// Returns 0 when `stop <= start`. Otherwise the span is divided by
    /// `step`, rounding up, which panics (division by zero) if `step` is 0.
    fn len(&self) -> usize {
        if self.stop <= self.start {
            0
        } else {
            (self.stop - self.start).div_ceil(self.step)
        }
    }
}
89
90impl<T: Scalar> Tensor<T> {
91 pub fn slice(&self, ranges: &[SliceRange]) -> Result<Self> {
108 if ranges.len() != self.ndim() {
109 return Err(CoreError::InvalidArgument {
110 reason: "number of slice ranges must match tensor rank",
111 });
112 }
113
114 for (d, r) in ranges.iter().enumerate() {
116 if r.stop > self.shape[d] {
117 return Err(CoreError::IndexOutOfBounds {
118 index: vec![r.stop],
119 shape: self.shape.clone(),
120 });
121 }
122 if r.step == 0 {
123 return Err(CoreError::InvalidArgument {
124 reason: "slice step must be > 0",
125 });
126 }
127 }
128
129 let new_shape: Vec<usize> = ranges.iter().map(SliceRange::len).collect();
130 let new_numel: usize = new_shape.iter().product();
131
132 if new_numel == 0 {
133 return Tensor::from_vec(vec![], new_shape);
134 }
135
136 let mut data = Vec::with_capacity(new_numel);
137 let mut index = vec![0usize; self.ndim()];
138
139 for (d, r) in ranges.iter().enumerate() {
141 index[d] = r.start;
142 }
143
144 for _ in 0..new_numel {
146 let flat = index
147 .iter()
148 .zip(self.strides.iter())
149 .map(|(&i, &s)| i * s)
150 .sum::<usize>();
151 data.push(self.data[flat]);
152
153 for d in (0..self.ndim()).rev() {
155 index[d] += ranges[d].step;
156 if index[d] < ranges[d].stop {
157 break;
158 }
159 index[d] = ranges[d].start;
160 }
161 }
162
163 Tensor::from_vec(data, new_shape)
164 }
165
166 pub fn select(&self, axis: usize, index: usize) -> Result<Self> {
180 if axis >= self.ndim() {
181 return Err(CoreError::AxisOutOfBounds {
182 axis,
183 ndim: self.ndim(),
184 });
185 }
186 if index >= self.shape[axis] {
187 return Err(CoreError::IndexOutOfBounds {
188 index: vec![index],
189 shape: self.shape.clone(),
190 });
191 }
192
193 let mut ranges: Vec<SliceRange> = self
194 .shape
195 .iter()
196 .map(|&len| SliceRange::full(len))
197 .collect();
198 ranges[axis] = SliceRange::new(index, index + 1, 1);
199
200 let sliced = self.slice(&ranges)?;
201 let mut new_shape: Vec<usize> = sliced.shape().to_vec();
203 new_shape.remove(axis);
204 if new_shape.is_empty() {
205 Ok(Tensor::scalar(sliced.data[0]))
206 } else {
207 let strides = compute_strides(&new_shape);
208 Ok(Tensor {
209 data: sliced.data,
210 shape: new_shape,
211 strides,
212 })
213 }
214 }
215}
216
217impl<T: Scalar> Tensor<T> {
222 pub fn index_select(&self, axis: usize, indices: &[usize]) -> Result<Self> {
237 if axis >= self.ndim() {
238 return Err(CoreError::AxisOutOfBounds {
239 axis,
240 ndim: self.ndim(),
241 });
242 }
243 for &idx in indices {
244 if idx >= self.shape[axis] {
245 return Err(CoreError::IndexOutOfBounds {
246 index: vec![idx],
247 shape: self.shape.clone(),
248 });
249 }
250 }
251
252 let mut new_shape = self.shape.clone();
253 new_shape[axis] = indices.len();
254 let new_numel: usize = new_shape.iter().product();
255
256 if new_numel == 0 {
257 return Tensor::from_vec(vec![], new_shape);
258 }
259
260 let mut data = Vec::with_capacity(new_numel);
261
262 let ndim = self.ndim();
264 let mut out_idx = vec![0usize; ndim];
265
266 for _ in 0..new_numel {
267 let mut src_flat = 0;
269 for d in 0..ndim {
270 let src_coord = if d == axis {
271 indices[out_idx[d]]
272 } else {
273 out_idx[d]
274 };
275 src_flat += src_coord * self.strides[d];
276 }
277 data.push(self.data[src_flat]);
278
279 for d in (0..ndim).rev() {
281 out_idx[d] += 1;
282 if out_idx[d] < new_shape[d] {
283 break;
284 }
285 out_idx[d] = 0;
286 }
287 }
288
289 Tensor::from_vec(data, new_shape)
290 }
291
292 pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
306 if mask.len() != self.numel() {
307 return Err(CoreError::InvalidArgument {
308 reason: "mask length must equal tensor element count",
309 });
310 }
311
312 let data: Vec<T> = self
313 .data
314 .iter()
315 .zip(mask.iter())
316 .filter(|&(_, &m)| m)
317 .map(|(&v, _)| v)
318 .collect();
319
320 let len = data.len();
321 Tensor::from_vec(data, vec![len])
322 }
323
324 pub fn masked_select_along(&self, mask: &[bool]) -> Result<Self> {
340 if self.ndim() == 0 {
341 return Err(CoreError::InvalidArgument {
342 reason: "cannot mask-select along axis 0 of a scalar tensor",
343 });
344 }
345 if mask.len() != self.shape[0] {
346 return Err(CoreError::InvalidArgument {
347 reason: "mask length must equal shape[0]",
348 });
349 }
350
351 let row_size: usize = self.strides[0]; let selected: usize = mask.iter().filter(|&&m| m).count();
353
354 let mut new_shape = self.shape.clone();
355 new_shape[0] = selected;
356 let new_numel: usize = new_shape.iter().product();
357
358 let mut data = Vec::with_capacity(new_numel);
359 for (i, &m) in mask.iter().enumerate() {
360 if m {
361 let start = i * row_size;
362 let end = start + row_size;
363 data.extend_from_slice(&self.data[start..end]);
364 }
365 }
366
367 Tensor::from_vec(data, new_shape)
368 }
369
370 pub fn index_put(&mut self, axis: usize, indices: &[usize], values: &Tensor<T>) -> Result<()> {
385 if axis >= self.ndim() {
386 return Err(CoreError::AxisOutOfBounds {
387 axis,
388 ndim: self.ndim(),
389 });
390 }
391 let mut expected_shape = self.shape.clone();
393 expected_shape[axis] = indices.len();
394 if values.shape() != expected_shape.as_slice() {
395 return Err(CoreError::DimensionMismatch {
396 expected: expected_shape,
397 got: values.shape().to_vec(),
398 });
399 }
400 for &idx in indices {
401 if idx >= self.shape[axis] {
402 return Err(CoreError::IndexOutOfBounds {
403 index: vec![idx],
404 shape: self.shape.clone(),
405 });
406 }
407 }
408
409 let ndim = self.ndim();
410 let val_numel = values.numel();
411
412 if val_numel == 0 {
413 return Ok(());
414 }
415
416 let mut out_idx = vec![0usize; ndim];
417
418 for vi in 0..val_numel {
419 let mut dst_flat = 0;
421 for d in 0..ndim {
422 let dst_coord = if d == axis {
423 indices[out_idx[d]]
424 } else {
425 out_idx[d]
426 };
427 dst_flat += dst_coord * self.strides[d];
428 }
429 self.data[dst_flat] = values.data[vi];
430
431 for d in (0..ndim).rev() {
433 out_idx[d] += 1;
434 if out_idx[d] < expected_shape[d] {
435 break;
436 }
437 out_idx[d] = 0;
438 }
439 }
440
441 Ok(())
442 }
443
444 pub fn masked_put(&mut self, mask: &[bool], values: &[T]) -> Result<()> {
458 if mask.len() != self.numel() {
459 return Err(CoreError::InvalidArgument {
460 reason: "mask length must equal tensor element count",
461 });
462 }
463 let true_count = mask.iter().filter(|&&m| m).count();
464 if values.len() != true_count {
465 return Err(CoreError::InvalidArgument {
466 reason: "values length must equal number of true entries in mask",
467 });
468 }
469
470 let mut vi = 0;
471 for (i, &m) in mask.iter().enumerate() {
472 if m {
473 self.data[i] = values[vi];
474 vi += 1;
475 }
476 }
477
478 Ok(())
479 }
480
481 pub fn gather(&self, axis: usize, indices: &Tensor<usize>) -> Result<Tensor<T>> {
499 if axis >= self.ndim() {
500 return Err(CoreError::AxisOutOfBounds {
501 axis,
502 ndim: self.ndim(),
503 });
504 }
505 if indices.ndim() != self.ndim() {
506 return Err(CoreError::DimensionMismatch {
507 expected: self.shape.clone(),
508 got: indices.shape().to_vec(),
509 });
510 }
511 for d in 0..self.ndim() {
513 if d != axis && indices.shape()[d] != self.shape[d] {
514 return Err(CoreError::DimensionMismatch {
515 expected: self.shape.clone(),
516 got: indices.shape().to_vec(),
517 });
518 }
519 }
520
521 let out_shape = indices.shape().to_vec();
522 let out_numel: usize = out_shape.iter().product();
523
524 if out_numel == 0 {
525 return Tensor::from_vec(vec![], out_shape);
526 }
527
528 let ndim = self.ndim();
529 let mut data = Vec::with_capacity(out_numel);
530 let mut out_idx = vec![0usize; ndim];
531 let idx_strides = compute_strides(&out_shape);
532
533 for _ in 0..out_numel {
534 let idx_flat: usize = out_idx
536 .iter()
537 .zip(idx_strides.iter())
538 .map(|(&i, &s)| i * s)
539 .sum();
540 let gather_idx = indices.data[idx_flat];
541
542 if gather_idx >= self.shape[axis] {
543 return Err(CoreError::IndexOutOfBounds {
544 index: vec![gather_idx],
545 shape: self.shape.clone(),
546 });
547 }
548
549 let src_flat: usize = out_idx
551 .iter()
552 .enumerate()
553 .zip(self.strides.iter())
554 .map(|((d, &oi), &s)| {
555 let coord = if d == axis { gather_idx } else { oi };
556 coord * s
557 })
558 .sum();
559 data.push(self.data[src_flat]);
560
561 for d in (0..ndim).rev() {
563 out_idx[d] += 1;
564 if out_idx[d] < out_shape[d] {
565 break;
566 }
567 out_idx[d] = 0;
568 }
569 }
570
571 Tensor::from_vec(data, out_shape)
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;

    // ---- slice / select -------------------------------------------------

    #[test]
    fn test_slice_basic() {
        // 3x3 grid holding 1..=9; take rows 0..2 and columns 1..3.
        let grid = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();
        let window = [SliceRange::range(0, 2), SliceRange::range(1, 3)];

        let sub = grid.slice(&window).unwrap();

        assert_eq!(sub.shape(), &[2, 2]);
        assert_eq!(sub.as_slice(), &[2, 3, 5, 6]);
    }

    #[test]
    fn test_slice_with_step() {
        // Every third element of 0..10.
        let seq = Tensor::<i32>::arange(10);
        let every_third = [SliceRange::new(0, 10, 3)];

        let picked = seq.slice(&every_third).unwrap();

        assert_eq!(picked.shape(), &[4]);
        assert_eq!(picked.as_slice(), &[0, 3, 6, 9]);
    }

    #[test]
    fn test_slice_full() {
        // Slicing every axis in full reproduces the tensor.
        let original = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let everything = [SliceRange::full(2), SliceRange::full(2)];

        let copy = original.slice(&everything).unwrap();

        assert_eq!(copy, original);
    }

    #[test]
    fn test_select_row() {
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();

        let second_row = matrix.select(0, 1).unwrap();

        assert_eq!(second_row.shape(), &[3]);
        assert_eq!(second_row.as_slice(), &[4, 5, 6]);
    }

    #[test]
    fn test_select_col() {
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();

        let first_col = matrix.select(1, 0).unwrap();

        assert_eq!(first_col.shape(), &[2]);
        assert_eq!(first_col.as_slice(), &[1, 4]);
    }

    #[test]
    fn test_select_to_scalar() {
        // Removing the only axis leaves a rank-0 tensor.
        let single = Tensor::from_vec(vec![42], vec![1]).unwrap();

        let scalar = single.select(0, 0).unwrap();

        assert_eq!(scalar.ndim(), 0);
        assert_eq!(scalar.as_slice(), &[42]);
    }

    #[test]
    fn test_slice_out_of_bounds() {
        // A stop of 3 exceeds the extent of 2 on axis 0.
        let square = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let too_far = [SliceRange::range(0, 3), SliceRange::full(2)];

        assert!(square.slice(&too_far).is_err());
    }

    #[test]
    fn test_select_axis_oob() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();

        assert!(vector.select(1, 0).is_err());
    }

    #[test]
    fn test_select_index_oob() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();

        assert!(vector.select(0, 5).is_err());
    }

    // ---- index_select ---------------------------------------------------

    #[test]
    fn test_index_select_1d() {
        let vector = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();

        let picked = vector.index_select(0, &[4, 0, 2]).unwrap();

        assert_eq!(picked.shape(), &[3]);
        assert_eq!(picked.as_slice(), &[50, 10, 30]);
    }

    #[test]
    fn test_index_select_2d_axis0() {
        // Rows 2 and 0 of a 3x3 grid, in that order.
        let grid = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();

        let rows = grid.index_select(0, &[2, 0]).unwrap();

        assert_eq!(rows.shape(), &[2, 3]);
        assert_eq!(rows.as_slice(), &[7, 8, 9, 1, 2, 3]);
    }

    #[test]
    fn test_index_select_2d_axis1() {
        // Columns 2 and 0 of a 2x3 matrix, in that order.
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();

        let cols = matrix.index_select(1, &[2, 0]).unwrap();

        assert_eq!(cols.shape(), &[2, 2]);
        assert_eq!(cols.as_slice(), &[3, 1, 6, 4]);
    }

    // ---- masked selection -----------------------------------------------

    #[test]
    fn test_masked_select_flat() {
        let vector = Tensor::from_vec(vec![10, 20, 30, 40, 50], vec![5]).unwrap();
        let keep = vec![true, false, true, false, true];

        let kept = vector.masked_select(&keep).unwrap();

        assert_eq!(kept.shape(), &[3]);
        assert_eq!(kept.as_slice(), &[10, 30, 50]);
    }

    #[test]
    fn test_masked_select_along_rows() {
        // Drop the first row of a 3x2 matrix.
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![3, 2]).unwrap();
        let keep_rows = vec![false, true, true];

        let kept = matrix.masked_select_along(&keep_rows).unwrap();

        assert_eq!(kept.shape(), &[2, 2]);
        assert_eq!(kept.as_slice(), &[3, 4, 5, 6]);
    }

    // ---- in-place writes ------------------------------------------------

    #[test]
    fn test_index_put() {
        // Replace rows 0 and 2; row 1 must stay untouched.
        let mut grid = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], vec![3, 3]).unwrap();
        let replacement_rows =
            Tensor::from_vec(vec![10, 20, 30, 70, 80, 90], vec![2, 3]).unwrap();

        grid.index_put(0, &[0, 2], &replacement_rows).unwrap();

        assert_eq!(grid.as_slice(), &[10, 20, 30, 4, 5, 6, 70, 80, 90]);
    }

    #[test]
    fn test_masked_put() {
        let mut vector = Tensor::from_vec(vec![1, 2, 3, 4, 5], vec![5]).unwrap();
        let targets = vec![false, true, false, true, false];

        vector.masked_put(&targets, &[99, 88]).unwrap();

        assert_eq!(vector.as_slice(), &[1, 99, 3, 88, 5]);
    }

    // ---- error paths ----------------------------------------------------

    #[test]
    fn test_index_out_of_bounds() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();

        // Index beyond the axis extent.
        assert!(vector.index_select(0, &[5]).is_err());
        // Axis beyond the tensor rank.
        assert!(vector.index_select(1, &[0]).is_err());
    }
}