1use crate::{Tensor, TensorElement};
4use torsh_core::error::{Result, TorshError};
5
/// A single index specifier for one tensor dimension (NumPy-style indexing).
#[derive(Debug, Clone)]
pub enum TensorIndex {
    /// Select one position along the dimension (negative counts from the end);
    /// the dimension is dropped from the result.
    Index(i64),
    /// Half-open range `(start, stop, step)`; `None` means that bound's default.
    Range(Option<i64>, Option<i64>, Option<i64>),
    /// Keep the whole dimension unchanged (equivalent to `:`).
    All,
    /// Fancy indexing: pick the listed positions (negatives count from the end).
    List(Vec<i64>),
    /// Boolean-mask indexing: keep positions where the mask is `true`.
    Mask(Tensor<bool>),
    /// Expands to as many `All` entries as needed to cover unaddressed dims (`...`).
    Ellipsis,
    /// Insert a new axis of size 1 (like NumPy's `None`/`np.newaxis`).
    NewAxis,
}
24
25impl TensorIndex {
26 pub fn range(start: Option<i64>, stop: Option<i64>) -> Self {
28 TensorIndex::Range(start, stop, None)
29 }
30
31 pub fn range_step(start: Option<i64>, stop: Option<i64>, step: i64) -> Self {
33 TensorIndex::Range(start, stop, Some(step))
34 }
35}
36
37impl<T: TensorElement> Tensor<T> {
    /// NumPy-style indexing entry point.
    ///
    /// Validates the index list, expands `Ellipsis` / pads missing trailing
    /// dimensions with `All`, computes the output shape, and dispatches to
    /// basic (slice-only) or advanced (`List`/`Mask`) extraction.
    ///
    /// # Errors
    /// Fails when more consuming indices than dimensions are given, an index
    /// is out of bounds, a range step is zero, or a mask is malformed.
    pub fn index(&self, indices: &[TensorIndex]) -> Result<Self> {
        // NewAxis and Ellipsis do not consume an input dimension.
        let consuming_indices = indices
            .iter()
            .filter(|idx| !matches!(idx, TensorIndex::NewAxis | TensorIndex::Ellipsis))
            .count();

        if consuming_indices > self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Too many indices for tensor: tensor has {} dimensions but {} consuming indices were provided",
                self.ndim(),
                consuming_indices
            )));
        }

        // After this call no Ellipsis remains and every input dim has an entry.
        let expanded_indices = self.expand_ellipsis(indices)?;

        let mut output_shape = Vec::new();
        // One (start, stop, step) triple per index; List/Mask use placeholders.
        let mut slices = Vec::new();
        let mut input_dim_idx = 0;
        for index in expanded_indices.iter() {
            // NewAxis adds a size-1 output dim without consuming an input dim.
            if let TensorIndex::NewAxis = index {
                output_shape.push(1);
                slices.push((0, 1, 1));
                continue;
            }

            let dim_size = if input_dim_idx < self.ndim() {
                self.shape().dims()[input_dim_idx]
            } else {
                return Err(TorshError::InvalidArgument(format!(
                    "Index {} beyond tensor dimensions (tensor has {} dimensions)",
                    input_dim_idx,
                    self.ndim()
                )));
            };

            match index {
                TensorIndex::Index(idx) => {
                    // Negative indices count from the end of the dimension.
                    let idx = if *idx < 0 {
                        (dim_size as i64 + idx) as usize
                    } else {
                        *idx as usize
                    };

                    if idx >= dim_size {
                        return Err(TorshError::IndexOutOfBounds {
                            index: idx,
                            size: dim_size,
                        });
                    }

                    // Single index: no entry is pushed to output_shape, so the
                    // dimension is dropped from the result.
                    slices.push((idx, idx + 1, 1));
                    input_dim_idx += 1;
                }
                TensorIndex::Range(start, stop, step) => {
                    let step = step.unwrap_or(1);
                    if step == 0 {
                        return Err(TorshError::InvalidArgument(
                            "Step cannot be zero".to_string(),
                        ));
                    }

                    // Clamp negative/overlong bounds into [0, dim_size].
                    let start = start
                        .map(|s| {
                            if s < 0 {
                                (dim_size as i64 + s).max(0) as usize
                            } else {
                                s.min(dim_size as i64) as usize
                            }
                        })
                        .unwrap_or(0);

                    let stop = stop
                        .map(|s| {
                            if s < 0 {
                                (dim_size as i64 + s).max(0) as usize
                            } else {
                                s.min(dim_size as i64) as usize
                            }
                        })
                        .unwrap_or(dim_size);

                    // Ceil-divide the span by the step (both signs handled).
                    let size = if step > 0 {
                        ((stop as i64 - start as i64 + step - 1) / step).max(0) as usize
                    } else {
                        ((stop as i64 - start as i64 + step + 1) / step).max(0) as usize
                    };

                    output_shape.push(size);
                    // NOTE(review): `step as usize` wraps for negative steps, so
                    // reversed ranges do not look supported downstream — confirm.
                    slices.push((start, stop, step as usize));
                    input_dim_idx += 1;
                }
                TensorIndex::All => {
                    output_shape.push(dim_size);
                    slices.push((0, dim_size, 1));
                    input_dim_idx += 1;
                }
                TensorIndex::List(indices_list) => {
                    // Validate every listed index up front (negatives allowed).
                    for &idx in indices_list {
                        let normalized_idx = if idx < 0 {
                            (dim_size as i64 + idx) as usize
                        } else {
                            idx as usize
                        };

                        if normalized_idx >= dim_size {
                            return Err(TorshError::IndexOutOfBounds {
                                index: normalized_idx,
                                size: dim_size,
                            });
                        }
                    }

                    output_shape.push(indices_list.len());
                    // Placeholder triple; real lookup happens in the advanced path.
                    slices.push((0, indices_list.len(), 0));
                    input_dim_idx += 1;
                }
                TensorIndex::Mask(mask) => {
                    if mask.ndim() != 1 {
                        return Err(TorshError::InvalidArgument(
                            "Boolean mask must be 1D for single dimension indexing".to_string(),
                        ));
                    }

                    if mask.numel() != dim_size {
                        return Err(TorshError::ShapeMismatch {
                            expected: vec![dim_size],
                            got: mask.shape().dims().to_vec(),
                        });
                    }

                    // Output size along this dim equals the number of `true`s.
                    let mask_data = mask.to_vec()?;
                    let true_count = mask_data.iter().filter(|&&x| x).count();

                    output_shape.push(true_count);
                    // Placeholder triple; real lookup happens in the advanced path.
                    slices.push((0, true_count, 0));
                    input_dim_idx += 1;
                }
                TensorIndex::NewAxis => {
                    // Unreachable: handled by the early `continue` above.
                    return Err(TorshError::InvalidArgument(
                        "NewAxis should be handled before this point".to_string(),
                    ));
                }
                TensorIndex::Ellipsis => {
                    // Unreachable: `expand_ellipsis` has already removed these.
                    return Err(TorshError::InvalidArgument(
                        "Ellipsis should be expanded before processing".to_string(),
                    ));
                }
            }
        }

        // Every dimension consumed by `Index` -> scalar result kept as shape [1].
        if output_shape.is_empty() {
            output_shape.push(1);
        }

        // Any List/Mask index forces the advanced (gather-style) path.
        if expanded_indices
            .iter()
            .any(|idx| matches!(idx, TensorIndex::List(_) | TensorIndex::Mask(_)))
        {
            self.extract_advanced_indexing(&expanded_indices, &output_shape)
        } else {
            self.extract_basic_indexing(&expanded_indices, &output_shape, &slices)
        }
    }
222
    /// Copies elements for slice-only (basic) indexing.
    ///
    /// `slices` holds one `(start, stop, step)` triple per entry in `indices`
    /// (NewAxis entries carry the placeholder `(0, 1, 1)`). For every flat
    /// output position the multi-dimensional output coordinate is mapped back
    /// to a flat input offset using the input strides.
    fn extract_basic_indexing(
        &self,
        indices: &[TensorIndex],
        output_shape: &[usize],
        slices: &[(usize, usize, usize)],
    ) -> Result<Self> {
        let input_data = self.to_vec()?;

        let output_size = output_shape.iter().product();
        let mut output_data = Vec::with_capacity(output_size);

        let input_strides = self.compute_strides();
        let output_strides = compute_strides_from_shape(output_shape);

        for out_idx in 0..output_size {
            // Unflatten the output index into per-dimension coordinates.
            let mut out_indices = vec![0; output_shape.len()];
            let mut remaining = out_idx;
            for (i, &stride) in output_strides.iter().enumerate() {
                out_indices[i] = remaining / stride;
                remaining %= stride;
            }

            let mut input_flat_idx = 0;
            // `out_dim` walks output dims and `input_dim` walks input dims; they
            // diverge at NewAxis (output-only) and Index (input-only) entries.
            let mut out_dim = 0;
            let mut input_dim = 0;

            for (slice_idx, &(start, _, step)) in slices.iter().enumerate() {
                // NewAxis contributes an output dim but no input offset.
                if slice_idx < indices.len() && matches!(indices[slice_idx], TensorIndex::NewAxis) {
                    out_dim += 1;
                    continue;
                }

                if input_dim >= input_strides.len() {
                    break;
                }

                // A single Index pins the coordinate to `start`; ranges step
                // through `start + k * step`.
                let idx = if slice_idx < indices.len()
                    && matches!(indices[slice_idx], TensorIndex::Index(_))
                {
                    start
                } else {
                    start + out_indices[out_dim] * step
                };
                input_flat_idx += idx * input_strides[input_dim];

                // Index consumes an input dim without producing an output dim.
                if !(slice_idx < indices.len()
                    && matches!(indices[slice_idx], TensorIndex::Index(_)))
                {
                    out_dim += 1;
                }
                input_dim += 1;
            }

            output_data.push(input_data[input_flat_idx]);
        }

        Self::from_data(output_data, output_shape.to_vec(), self.device)
    }
286
    /// Copies elements for advanced indexing (at least one `List` or `Mask`).
    ///
    /// For each flat output position the coordinate along every indexed
    /// dimension is resolved individually: lists map through their stored
    /// positions, masks map the k-th output slot to the position of the k-th
    /// `true` entry. Dimensions beyond the supplied indices pass through.
    fn extract_advanced_indexing(
        &self,
        indices: &[TensorIndex],
        output_shape: &[usize],
    ) -> Result<Self> {
        let input_data = self.to_vec()?;

        let output_size = output_shape.iter().product();
        let mut output_data = Vec::with_capacity(output_size);

        let input_strides = self.compute_strides();
        let output_strides = compute_strides_from_shape(output_shape);

        for out_idx in 0..output_size {
            // Unflatten the output index into per-dimension coordinates.
            let mut out_indices = vec![0; output_shape.len()];
            let mut remaining = out_idx;
            for (i, &stride) in output_strides.iter().enumerate() {
                out_indices[i] = remaining / stride;
                remaining %= stride;
            }

            let mut input_flat_idx = 0;
            let mut out_dim = 0;

            for (dim_idx, index) in indices.iter().enumerate() {
                if dim_idx >= self.ndim() {
                    break;
                }

                let input_idx = match index {
                    TensorIndex::Index(idx) => {
                        let dim_size = self.shape().dims()[dim_idx];

                        // Negative indices count from the end.
                        if *idx < 0 {
                            (dim_size as i64 + idx) as usize
                        } else {
                            *idx as usize
                        }
                    }
                    TensorIndex::Range(start, _stop, step) => {
                        let dim_size = self.shape().dims()[dim_idx];
                        let step = step.unwrap_or(1);
                        let start = start
                            .map(|s| {
                                if s < 0 {
                                    (dim_size as i64 + s).max(0) as usize
                                } else {
                                    s.min(dim_size as i64) as usize
                                }
                            })
                            .unwrap_or(0);

                        // NOTE(review): `step as usize` wraps for negative steps
                        // — confirm they are rejected before reaching this path.
                        start + out_indices[out_dim] * (step as usize)
                    }
                    TensorIndex::All => out_indices[out_dim],
                    TensorIndex::List(indices_list) => {
                        // The output coordinate selects a slot in the list.
                        let list_idx = out_indices[out_dim];
                        if list_idx >= indices_list.len() {
                            return Err(TorshError::IndexOutOfBounds {
                                index: list_idx,
                                size: indices_list.len(),
                            });
                        }

                        let actual_idx = indices_list[list_idx];
                        let dim_size = self.shape().dims()[dim_idx];

                        if actual_idx < 0 {
                            (dim_size as i64 + actual_idx) as usize
                        } else {
                            actual_idx as usize
                        }
                    }
                    TensorIndex::Mask(mask) => {
                        // NOTE(review): the mask is re-read for every output
                        // element; hoisting this out of the loop would help.
                        let mask_data = mask.to_vec()?;

                        // Locate the k-th `true` in the mask, where k is the
                        // output coordinate on this dimension.
                        let target_true_idx = out_indices[out_dim];
                        let mut true_count = 0;
                        let mut found_idx = None;
                        for (i, &mask_val) in mask_data.iter().enumerate() {
                            if mask_val {
                                if true_count == target_true_idx {
                                    found_idx = Some(i);
                                    break;
                                }
                                true_count += 1;
                            }
                        }

                        match found_idx {
                            Some(idx) => idx,
                            None => {
                                return Err(TorshError::IndexOutOfBounds {
                                    index: target_true_idx,
                                    size: true_count,
                                });
                            }
                        }
                    }
                    TensorIndex::NewAxis => {
                        // NOTE(review): `out_dim` is not advanced past the
                        // size-1 NewAxis output dim here — confirm intended.
                        continue;
                    }
                    TensorIndex::Ellipsis => {
                        // Should already be expanded; treated as pass-through.
                        out_indices[out_dim]
                    }
                };

                input_flat_idx += input_idx * input_strides[dim_idx];

                // Index (dim dropped) does not advance the output dimension.
                if !matches!(index, TensorIndex::Index(_) | TensorIndex::NewAxis) {
                    out_dim += 1;
                }
            }

            // Dimensions past the supplied indices are copied through directly.
            // NOTE(review): `self.ndim() - indices.len()` underflows when
            // NewAxis entries make `indices` longer than ndim — confirm callers
            // cannot reach this with such input.
            for stride in input_strides
                .iter()
                .skip(indices.len())
                .take(self.ndim() - indices.len())
            {
                if out_dim < out_indices.len() {
                    input_flat_idx += out_indices[out_dim] * stride;
                    out_dim += 1;
                }
            }

            if input_flat_idx >= input_data.len() {
                return Err(TorshError::IndexOutOfBounds {
                    index: input_flat_idx,
                    size: input_data.len(),
                });
            }

            output_data.push(input_data[input_flat_idx]);
        }

        Self::from_data(output_data, output_shape.to_vec(), self.device)
    }
434
    /// Normalizes an index list so every input dimension has an explicit entry.
    ///
    /// A single `Ellipsis` is replaced with enough `All` entries to cover the
    /// dimensions not addressed elsewhere; when no ellipsis is present, missing
    /// trailing dimensions are padded with `All`.
    ///
    /// # Errors
    /// Fails when more than one `Ellipsis` is supplied.
    fn expand_ellipsis(&self, indices: &[TensorIndex]) -> Result<Vec<TensorIndex>> {
        let mut expanded = Vec::new();
        let mut found_ellipsis = false;

        // Entries that consume an input dimension (everything but `...`/NewAxis).
        let non_expanding_indices = indices
            .iter()
            .filter(|idx| !matches!(idx, TensorIndex::Ellipsis | TensorIndex::NewAxis))
            .count();

        for index in indices {
            match index {
                TensorIndex::Ellipsis => {
                    if found_ellipsis {
                        return Err(TorshError::InvalidArgument(
                            "Only one ellipsis (...) is allowed per indexing operation".to_string(),
                        ));
                    }
                    found_ellipsis = true;

                    // `...` stands for the dimensions no other index covers.
                    let ellipsis_dims = if self.ndim() >= non_expanding_indices {
                        self.ndim() - non_expanding_indices
                    } else {
                        0
                    };

                    for _ in 0..ellipsis_dims {
                        expanded.push(TensorIndex::All);
                    }
                }
                _ => {
                    expanded.push(index.clone());
                }
            }
        }

        // No ellipsis: pad the remaining trailing dimensions with `All`.
        if !found_ellipsis {
            let current_dims = expanded
                .iter()
                .filter(|idx| !matches!(idx, TensorIndex::NewAxis))
                .count();

            for _ in current_dims..self.ndim() {
                expanded.push(TensorIndex::All);
            }
        }

        Ok(expanded)
    }
488
489 pub fn get_1d(&self, index: usize) -> Result<T> {
491 if self.ndim() != 1 {
492 return Err(TorshError::InvalidShape(
493 "get_1d() can only be used on 1D tensors".to_string(),
494 ));
495 }
496
497 if index >= self.shape().dims()[0] {
498 return Err(TorshError::IndexOutOfBounds {
499 index,
500 size: self.shape().dims()[0],
501 });
502 }
503
504 let data = self.data()?;
505 Ok(data[index])
506 }
507
508 pub fn get_2d(&self, row: usize, col: usize) -> Result<T> {
510 if self.ndim() != 2 {
511 return Err(TorshError::InvalidShape(
512 "get_2d() can only be used on 2D tensors".to_string(),
513 ));
514 }
515
516 let shape = self.shape();
517 if row >= shape.dims()[0] || col >= shape.dims()[1] {
518 return Err(TorshError::IndexOutOfBounds {
519 index: row * shape.dims()[1] + col,
520 size: shape.numel(),
521 });
522 }
523
524 let data = self.to_vec()?;
525
526 let index = row * shape.dims()[1] + col;
527 Ok(data[index])
528 }
529
530 pub fn get_3d(&self, x: usize, y: usize, z: usize) -> Result<T> {
532 if self.ndim() != 3 {
533 return Err(TorshError::InvalidShape(
534 "get_3d() can only be used on 3D tensors".to_string(),
535 ));
536 }
537
538 let shape = self.shape();
539 if x >= shape.dims()[0] || y >= shape.dims()[1] || z >= shape.dims()[2] {
540 return Err(TorshError::IndexOutOfBounds {
541 index: x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z,
542 size: shape.numel(),
543 });
544 }
545
546 let data = self.to_vec()?;
547
548 let index = x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z;
549 Ok(data[index])
550 }
551
552 pub fn set_1d(&mut self, index: usize, value: T) -> Result<()> {
554 if self.ndim() != 1 {
555 return Err(TorshError::InvalidShape(
556 "set_1d() can only be used on 1D tensors".to_string(),
557 ));
558 }
559
560 if index >= self.shape().dims()[0] {
561 return Err(TorshError::IndexOutOfBounds {
562 index,
563 size: self.shape().dims()[0],
564 });
565 }
566
567 let mut data = self.to_vec()?;
568 data[index] = value;
569 *self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
570 Ok(())
571 }
572
573 pub fn set_2d(&mut self, row: usize, col: usize, value: T) -> Result<()> {
575 if self.ndim() != 2 {
576 return Err(TorshError::InvalidShape(
577 "set_2d() can only be used on 2D tensors".to_string(),
578 ));
579 }
580
581 let shape = self.shape();
582 if row >= shape.dims()[0] || col >= shape.dims()[1] {
583 return Err(TorshError::IndexOutOfBounds {
584 index: row * shape.dims()[1] + col,
585 size: shape.numel(),
586 });
587 }
588
589 let mut data = self.to_vec()?;
590 let index = row * shape.dims()[1] + col;
591 data[index] = value;
592 *self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
593 Ok(())
594 }
595
596 pub fn set_3d(&mut self, x: usize, y: usize, z: usize, value: T) -> Result<()> {
598 if self.ndim() != 3 {
599 return Err(TorshError::InvalidShape(
600 "set_3d() can only be used on 3D tensors".to_string(),
601 ));
602 }
603
604 let shape = self.shape();
605 if x >= shape.dims()[0] || y >= shape.dims()[1] || z >= shape.dims()[2] {
606 return Err(TorshError::IndexOutOfBounds {
607 index: x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z,
608 size: shape.numel(),
609 });
610 }
611
612 let mut data = self.to_vec()?;
613 let index = x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z;
614 data[index] = value;
615 *self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
616 Ok(())
617 }
618
619 pub fn select(&self, dim: i32, index: i64) -> Result<Self> {
621 let ndim = self.ndim() as i32;
622 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
623
624 if dim >= self.ndim() {
625 return Err(TorshError::InvalidArgument(format!(
626 "Dimension {} out of range for tensor with {} dimensions",
627 dim,
628 self.ndim()
629 )));
630 }
631
632 let dim_size = self.shape().dims()[dim] as i64;
633 let index = if index < 0 { dim_size + index } else { index };
634
635 if index < 0 || index >= dim_size {
636 return Err(TorshError::IndexOutOfBounds {
637 index: index as usize,
638 size: dim_size as usize,
639 });
640 }
641
642 let mut indices = Vec::new();
644 for d in 0..self.ndim() {
645 if d == dim {
646 indices.push(TensorIndex::Index(index));
647 } else {
648 indices.push(TensorIndex::All);
649 }
650 }
651
652 self.index(&indices)
654 }
655
656 pub fn slice_with_step(
658 &self,
659 dim: i32,
660 start: Option<i64>,
661 end: Option<i64>,
662 step: Option<i64>,
663 ) -> Result<Self> {
664 let ndim = self.ndim() as i32;
665 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
666
667 if dim >= self.ndim() {
668 return Err(TorshError::InvalidArgument(format!(
669 "Dimension {} out of range for tensor with {} dimensions",
670 dim,
671 self.ndim()
672 )));
673 }
674
675 let mut indices = Vec::new();
677 for d in 0..self.ndim() {
678 if d == dim {
679 indices.push(TensorIndex::Range(start, end, step));
680 } else {
681 indices.push(TensorIndex::All);
682 }
683 }
684
685 self.index(&indices)
687 }
688
689 pub fn narrow(&self, dim: i32, start: i64, length: usize) -> Result<Self> {
691 let ndim = self.ndim() as i32;
692 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
693
694 if dim >= self.ndim() {
695 return Err(TorshError::InvalidArgument(format!(
696 "Dimension {} out of range for tensor with {} dimensions",
697 dim,
698 self.ndim()
699 )));
700 }
701
702 let dim_size = self.shape().dims()[dim] as i64;
703 let start = if start < 0 { dim_size + start } else { start };
704
705 if start < 0 || start >= dim_size {
706 return Err(TorshError::InvalidArgument(format!(
707 "Start index {start} out of range for dimension {dim} with size {dim_size}"
708 )));
709 }
710
711 let end = start + length as i64;
712 if end > dim_size {
713 return Err(TorshError::InvalidArgument(format!(
714 "End index {end} out of range for dimension {dim} with size {dim_size}"
715 )));
716 }
717
718 let mut indices = Vec::new();
720 for d in 0..self.ndim() {
721 if d == dim {
722 indices.push(TensorIndex::Range(Some(start), Some(end), None));
723 } else {
724 indices.push(TensorIndex::All);
725 }
726 }
727
728 self.index(&indices)
730 }
731
732 pub fn masked_select(&self, mask: &Tensor<bool>) -> Result<Self> {
734 if self.shape() != mask.shape() {
735 return Err(TorshError::ShapeMismatch {
736 expected: self.shape().dims().to_vec(),
737 got: mask.shape().dims().to_vec(),
738 });
739 }
740
741 let self_data = self.data()?;
742 let mask_data = mask.data()?;
743
744 let mut selected_data = Vec::new();
746 for (i, &mask_val) in mask_data.iter().enumerate() {
747 if mask_val {
748 selected_data.push(self_data[i]);
749 }
750 }
751
752 Self::from_data(
754 selected_data.clone(),
755 vec![selected_data.len()],
756 self.device,
757 )
758 }
759
760 pub fn take(&self, indices: &Tensor<i64>) -> Result<Self> {
761 let self_data = self.data()?;
762
763 let indices_data = indices.data()?;
764
765 let self_size = self.shape().numel();
766 let output_shape = indices.shape().dims().to_vec();
767 let output_size = indices.shape().numel();
768 let mut output_data = Vec::with_capacity(output_size);
769
770 for &idx in indices_data.iter() {
772 let idx = if idx < 0 {
773 (self_size as i64 + idx) as usize
774 } else {
775 idx as usize
776 };
777
778 if idx >= self_size {
779 return Err(TorshError::IndexOutOfBounds {
780 index: idx,
781 size: self_size,
782 });
783 }
784
785 output_data.push(self_data[idx]);
786 }
787
788 Self::from_data(output_data, output_shape, self.device)
789 }
790
791 pub fn put(&self, indices: &Tensor<i64>, values: &Self) -> Result<Self> {
793 let self_data = self.data()?;
794
795 let indices_data = indices.data()?;
796 let values_data = values.data()?;
797
798 if indices.shape() != values.shape() {
800 return Err(TorshError::ShapeMismatch {
801 expected: indices.shape().dims().to_vec(),
802 got: values.shape().dims().to_vec(),
803 });
804 }
805
806 let self_size = self.shape().numel();
807 let mut output_data = self_data.clone();
808
809 for (i, &idx) in indices_data.iter().enumerate() {
811 let idx = if idx < 0 {
812 (self_size as i64 + idx) as usize
813 } else {
814 idx as usize
815 };
816
817 if idx >= self_size {
818 return Err(TorshError::IndexOutOfBounds {
819 index: idx,
820 size: self_size,
821 });
822 }
823
824 output_data[idx] = values_data[i];
825 }
826
827 Self::from_data(output_data, self.shape().dims().to_vec(), self.device)
828 }
829
830 pub fn index_select(&self, dim: i32, index: &Tensor<i64>) -> Result<Self> {
832 let ndim = self.ndim() as i32;
833 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
834
835 if dim >= self.ndim() {
836 return Err(TorshError::InvalidArgument(format!(
837 "Dimension {} out of range for tensor with {} dimensions",
838 dim,
839 self.ndim()
840 )));
841 }
842
843 if index.ndim() != 1 {
845 return Err(TorshError::InvalidShape(
846 "index_select expects a 1D index tensor".to_string(),
847 ));
848 }
849
850 let mut output_shape = self.shape().dims().to_vec();
852 output_shape[dim] = index.shape().dims()[0];
853
854 let output_size: usize = output_shape.iter().product();
855 let mut output_data = Vec::with_capacity(output_size);
856
857 let self_data = self.data()?;
858
859 let index_data = index.data()?;
860
861 let self_strides = self.compute_strides();
863 let _output_strides = Self::compute_strides_for_shape(&output_shape);
864
865 for out_idx in 0..output_size {
867 let mut indices = vec![0; self.ndim()];
869 let mut remaining = out_idx;
870 for i in (0..self.ndim()).rev() {
871 indices[i] = remaining % output_shape[i];
872 remaining /= output_shape[i];
873 }
874
875 let select_idx = indices[dim];
877 let selected_value = index_data[select_idx] as usize;
878
879 if selected_value >= self.shape().dims()[dim] {
880 return Err(TorshError::IndexOutOfBounds {
881 index: selected_value,
882 size: self.shape().dims()[dim],
883 });
884 }
885
886 indices[dim] = selected_value;
887
888 let src_flat_idx = indices
890 .iter()
891 .zip(&self_strides)
892 .map(|(idx, stride)| idx * stride)
893 .sum::<usize>();
894
895 output_data.push(self_data[src_flat_idx]);
896 }
897
898 Self::from_data(output_data, output_shape, self.device)
899 }
900
    /// Row-major (C-contiguous) strides for this tensor's current shape.
    pub(crate) fn compute_strides(&self) -> Vec<usize> {
        Self::compute_strides_for_shape(self.shape().dims())
    }
905
906 pub(crate) fn compute_strides_for_shape(shape: &[usize]) -> Vec<usize> {
908 let mut strides = vec![1; shape.len()];
909 for i in (0..shape.len() - 1).rev() {
910 strides[i] = strides[i + 1] * shape[i + 1];
911 }
912 strides
913 }
914}
915
/// Row-major strides for `shape`: `strides[i]` is the product of all
/// dimensions after `i`; the last stride is 1 and an empty (rank-0) shape
/// yields an empty vector.
fn compute_strides_from_shape(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // `saturating_sub` guards the empty-shape case: the original
    // `shape.len() - 1` underflowed usize and panicked for rank-0 shapes.
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
924
/// Builds a `Vec<TensorIndex>` of `Index` entries from integer expressions.
///
/// NOTE(review): the single-expression arm is subsumed by the variadic arm
/// below (arms are tried in order), so it is redundant but harmless.
#[macro_export]
macro_rules! idx {
    ($idx:expr) => {
        vec![TensorIndex::Index($idx)]
    };

    ($($idx:expr),+ $(,)?) => {
        vec![$(TensorIndex::Index($idx)),+]
    };
}
938
/// Builds a single `TensorIndex` slice specifier with NumPy-like syntax:
/// `s![..]` → `All`; `s![..stop]`; `s![start, stop]`; `s![start, stop, step]`;
/// `s![ellipsis]` → `Ellipsis`; `s![None]` → `NewAxis`.
#[macro_export]
macro_rules! s {
    // Whole dimension (`:`).
    (..) => {
        TensorIndex::All
    };

    // From the start up to `stop` (exclusive).
    (.. $stop:expr) => {
        TensorIndex::range(None, Some($stop))
    };

    // Half-open `[start, stop)` with the default step of 1.
    ($start:expr, $stop:expr) => {
        TensorIndex::range(Some($start), Some($stop))
    };

    // Half-open `[start, stop)` with an explicit step.
    ($start:expr, $stop:expr, $step:expr) => {
        TensorIndex::range_step(Some($start), Some($stop), $step)
    };

    // `...` — expands over the remaining dimensions.
    (ellipsis) => {
        TensorIndex::Ellipsis
    };

    // `None` — insert a new size-1 axis.
    (None) => {
        TensorIndex::NewAxis
    };
}
971
/// Builds a `TensorIndex::List` (fancy index) from the given positions.
#[macro_export]
macro_rules! fancy_idx {
    [$($idx:expr),+ $(,)?] => {
        TensorIndex::List(vec![$($idx),+])
    };
}
980
/// Wraps a `Tensor<bool>` in a `TensorIndex::Mask` for boolean-mask indexing.
#[macro_export]
macro_rules! mask_idx {
    [$mask:expr] => {
        TensorIndex::Mask($mask)
    };
}
988
989impl<T: TensorElement> Tensor<T> {
991 pub fn index_with_list(&self, dim: i32, indices: &[i64]) -> Result<Self> {
993 let ndim = self.ndim() as i32;
994 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
995
996 if dim >= self.ndim() {
997 return Err(TorshError::InvalidArgument(format!(
998 "Dimension {} out of range for tensor with {} dimensions",
999 dim,
1000 self.ndim()
1001 )));
1002 }
1003
1004 let mut index_spec = vec![TensorIndex::All; self.ndim()];
1005 index_spec[dim] = TensorIndex::List(indices.to_vec());
1006
1007 self.index(&index_spec)
1008 }
1009
1010 pub fn index_with_mask(&self, dim: i32, mask: &Tensor<bool>) -> Result<Self> {
1012 let ndim = self.ndim() as i32;
1013 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
1014
1015 if dim >= self.ndim() {
1016 return Err(TorshError::InvalidArgument(format!(
1017 "Dimension {} out of range for tensor with {} dimensions",
1018 dim,
1019 self.ndim()
1020 )));
1021 }
1022
1023 let mut index_spec = vec![TensorIndex::All; self.ndim()];
1024 index_spec[dim] = TensorIndex::Mask(mask.clone());
1025
1026 self.index(&index_spec)
1027 }
1028
1029 pub fn mask_select(&self, mask: &Tensor<bool>) -> Result<Self> {
1031 if self.shape() != mask.shape() {
1032 return Err(TorshError::ShapeMismatch {
1033 expected: self.shape().dims().to_vec(),
1034 got: mask.shape().dims().to_vec(),
1035 });
1036 }
1037
1038 let self_data = self.data()?;
1039
1040 let mask_data = mask.data()?;
1041
1042 let mut selected_data = Vec::new();
1044 for (i, &mask_val) in mask_data.iter().enumerate() {
1045 if mask_val {
1046 selected_data.push(self_data[i]);
1047 }
1048 }
1049
1050 Self::from_data(
1052 selected_data.clone(),
1053 vec![selected_data.len()],
1054 self.device,
1055 )
1056 }
1057
1058 pub fn where_condition<F>(&self, condition: F) -> Result<Tensor<bool>>
1060 where
1061 F: Fn(&T) -> bool,
1062 T: Clone,
1063 {
1064 let data = self.data()?;
1065
1066 let mask_data: Vec<bool> = data.iter().map(condition).collect();
1067
1068 Tensor::from_data(mask_data, self.shape().dims().to_vec(), self.device)
1069 }
1070
    /// Writes `src` values into a copy of `self`: for every position of
    /// `index`, the coordinate along `dim` is replaced by the index value
    /// (negatives count from the end) and the matching `src` element is
    /// stored there.
    ///
    /// # Errors
    /// Fails when `dim` is out of range, `index`/`src` shapes differ, the
    /// index rank differs from `self`'s rank, or a scatter index is out of
    /// bounds.
    pub fn scatter_indexed(&self, dim: i32, index: &Tensor<i64>, src: &Self) -> Result<Self> {
        let ndim = self.ndim() as i32;
        let dim = if dim < 0 { ndim + dim } else { dim } as usize;

        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        // Bindings keep the temporary shape values alive while dims are borrowed.
        let self_shape_binding = self.shape();
        let self_shape = self_shape_binding.dims();
        let index_shape_binding = index.shape();
        let index_shape = index_shape_binding.dims();
        let src_shape_binding = src.shape();
        let src_shape = src_shape_binding.dims();

        if index_shape != src_shape {
            return Err(TorshError::ShapeMismatch {
                expected: index_shape.to_vec(),
                got: src_shape.to_vec(),
            });
        }

        if index_shape.len() != self_shape.len() {
            return Err(TorshError::InvalidArgument(
                "Index tensor must have same number of dimensions as input tensor".to_string(),
            ));
        }

        let mut result_data = self.data()?.clone();
        let index_data = index.data()?;
        let src_data = src.data()?;
        let self_strides = self.compute_strides();

        let index_size = index_shape.iter().product();

        for flat_idx in 0..index_size {
            // Unflatten `flat_idx` into coordinates within `index`'s shape.
            let mut coords = Vec::new();
            let mut temp_idx = flat_idx;

            for &dim_size in index_shape.iter().rev() {
                coords.push(temp_idx % dim_size);
                temp_idx /= dim_size;
            }
            coords.reverse();

            // The index tensor supplies the destination coordinate on `dim`.
            let scatter_idx = index_data[flat_idx];
            let dim_size = self_shape[dim] as i64;
            let scatter_idx = if scatter_idx < 0 {
                dim_size + scatter_idx
            } else {
                scatter_idx
            };

            if scatter_idx < 0 || scatter_idx >= dim_size {
                return Err(TorshError::IndexOutOfBounds {
                    index: scatter_idx as usize,
                    size: dim_size as usize,
                });
            }

            coords[dim] = scatter_idx as usize;
            // Flatten the destination coordinate using `self`'s strides.
            let mut dest_idx = 0;
            for (coord, &stride) in coords.iter().zip(self_strides.iter()) {
                dest_idx += coord * stride;
            }

            result_data[dest_idx] = src_data[flat_idx];
        }

        Self::from_data(result_data, self_shape.to_vec(), self.device)
    }
1153}
1154
1155#[cfg(test)]
1156mod tests {
1157 use super::*;
1158 use crate::creation::{tensor_2d, zeros};
1159
    // Smoke-tests that every indexing macro expands and yields the expected arity.
    #[test]
    fn test_index_macros() {
        let indices = idx![5];
        assert_eq!(indices.len(), 1);

        let indices = idx![1, 2, 3];
        assert_eq!(indices.len(), 3);

        let _all = s![..];
        let _range = s![1, 5];
        let _range_step = s![1, 10, 2];
        let _to = s![..7];

        let _fancy = fancy_idx![0, 2, 1];
        let _ellipsis = s![ellipsis];
        let _newaxis = s![None];
    }
1181
    // Exercises element get/set round-trips plus out-of-bounds error paths.
    // NOTE(review): `set` is called through a shared binding, implying interior
    // mutability in Tensor — confirm.
    #[test]
    fn test_get_set() {
        let tensor = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();

        assert_eq!(tensor.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(tensor.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(tensor.get(&[1, 2]).unwrap(), 6.0);

        tensor.set(&[1, 1], 10.0).unwrap();
        assert_eq!(tensor.get(&[1, 1]).unwrap(), 10.0);

        assert!(tensor.get(&[2, 0]).is_err());
        assert!(tensor.set(&[0, 3], 0.0).is_err());
    }
1199
    // Gathers along dim 1: each row of the result permutes that row of the
    // input according to the matching row of the index tensor.
    #[test]
    fn test_gather() {
        let tensor = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]).unwrap();

        let indices = tensor_2d(&[&[0i64, 2, 1], &[1, 0, 2], &[2, 1, 0]]).unwrap();

        let result = tensor.gather(1, &indices).unwrap();

        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 1]).unwrap(), 3.0);
        assert_eq!(result.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 5.0);
        assert_eq!(result.get(&[1, 1]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 2]).unwrap(), 6.0);
        assert_eq!(result.get(&[2, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[2, 1]).unwrap(), 8.0);
        assert_eq!(result.get(&[2, 2]).unwrap(), 7.0);
    }
1221
    // Scatters src values into a zero tensor along dim 1 at positions chosen
    // by the index tensor, then checks the resulting placement.
    #[test]
    fn test_scatter() {
        let tensor = zeros::<f32>(&[3, 3]).unwrap();

        let indices = tensor_2d(&[&[0i64, 2, 1], &[1, 0, 2], &[2, 1, 0]]).unwrap();

        let src = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]).unwrap();

        let result = tensor.scatter(1, &indices, &src).unwrap();

        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 1]).unwrap(), 3.0);
        assert_eq!(result.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 5.0);
        assert_eq!(result.get(&[1, 1]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 2]).unwrap(), 6.0);
        assert_eq!(result.get(&[2, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[2, 1]).unwrap(), 8.0);
        assert_eq!(result.get(&[2, 2]).unwrap(), 7.0);
    }
1246
    // index_select along rows (dim 0) and columns (dim 1) of a 3x4 tensor.
    #[test]
    fn test_index_select() {
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
        ])
        .unwrap();

        // Rows 0 and 2 -> a 2x4 result.
        let row_indices = crate::creation::tensor_1d(&[0i64, 2]).unwrap();
        let result = tensor.index_select(0, &row_indices).unwrap();

        assert_eq!(result.shape().dims(), &[2, 4]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 3]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[1, 3]).unwrap(), 12.0);

        // Columns 1 and 3 -> a 3x2 result.
        let col_indices = crate::creation::tensor_1d(&[1i64, 3]).unwrap();
        let result = tensor.index_select(1, &col_indices).unwrap();

        assert_eq!(result.shape().dims(), &[3, 2]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 2.0);
        assert_eq!(result.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(result.get(&[2, 0]).unwrap(), 10.0);
        assert_eq!(result.get(&[2, 1]).unwrap(), 12.0);
    }
1277
    // List (fancy) indexing on dim 0, both via TensorIndex::List directly and
    // via the index_with_list convenience wrapper.
    #[test]
    fn test_list_indexing() {
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
        ])
        .unwrap();

        let indices = vec![TensorIndex::List(vec![0, 2]), TensorIndex::All];
        let result = tensor.index(&indices).unwrap();

        assert_eq!(result.shape().dims(), &[2, 4]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 3]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[1, 3]).unwrap(), 12.0);

        // The wrapper must agree with the explicit index spec.
        let result2 = tensor.index_with_list(0, &[0, 2]).unwrap();
        assert_eq!(result.shape(), result2.shape());
        assert_eq!(result.get(&[0, 0]).unwrap(), result2.get(&[0, 0]).unwrap());
    }
1303
    // Boolean-mask selection via mask_select and via index_with_mask; both
    // must keep exactly the true-masked elements in order.
    #[test]
    fn test_boolean_mask_indexing() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[10.0, 20.0, 30.0, 40.0, 50.0]).unwrap();

        let mask = Tensor::from_data(
            vec![true, false, true, false, true],
            vec![5],
            crate::DeviceType::Cpu,
        )
        .unwrap();

        let result = tensor.mask_select(&mask).unwrap();
        assert_eq!(result.shape().dims(), &[3]);
        assert_eq!(result.get(&[0]).unwrap(), 10.0);
        assert_eq!(result.get(&[1]).unwrap(), 30.0);
        assert_eq!(result.get(&[2]).unwrap(), 50.0);

        let result2 = tensor.index_with_mask(0, &mask).unwrap();
        assert_eq!(result2.shape().dims(), &[3]);
        assert_eq!(result2.get(&[0]).unwrap(), 10.0);
        assert_eq!(result2.get(&[1]).unwrap(), 30.0);
        assert_eq!(result2.get(&[2]).unwrap(), 50.0);
    }
1333
    // Builds a predicate mask with where_condition and feeds it to mask_select.
    #[test]
    fn test_where_condition() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        let mask = tensor.where_condition(|&x| x > 3.0).unwrap();

        // Scope the data borrow so the mask can be reused below.
        {
            let mask_data = mask.data().unwrap();
            assert!(!mask_data[0]);
            assert!(!mask_data[1]);
            assert!(!mask_data[2]);
            assert!(mask_data[3]);
            assert!(mask_data[4]);
        }
        let selected = tensor.mask_select(&mask).unwrap();
        assert_eq!(selected.shape().dims(), &[2]);
        assert_eq!(selected.get(&[0]).unwrap(), 4.0);
        assert_eq!(selected.get(&[1]).unwrap(), 5.0);
    }
1358
    // NewAxis inserts size-1 dimensions before/after/around the data dim.
    #[test]
    fn test_newaxis_indexing() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();

        let indices = vec![TensorIndex::NewAxis, TensorIndex::All];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3]);

        let indices = vec![TensorIndex::All, TensorIndex::NewAxis];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3, 1]);

        // Multiple NewAxis entries in one index spec.
        let indices = vec![
            TensorIndex::NewAxis,
            TensorIndex::All,
            TensorIndex::NewAxis,
            TensorIndex::NewAxis,
        ];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3, 1, 1]);
    }
1385
    // Ellipsis after a single index should cover the remaining dimensions.
    #[test]
    fn test_ellipsis_indexing() {
        let tensor = crate::creation::zeros::<f32>(&[2, 3, 4]).unwrap();

        let indices = vec![TensorIndex::Index(0), TensorIndex::Ellipsis];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3, 4]);

        let indices = vec![TensorIndex::Index(1), TensorIndex::Ellipsis];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3, 4]);
    }
1401
    // Mixed advanced indexing: list on dim 0 combined with a range on dim 1.
    #[test]
    fn test_complex_indexing() {
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
            &[13.0, 14.0, 15.0, 16.0],
        ])
        .unwrap();

        // Rows {0, 2, 3}, columns [1, 4) -> a 3x3 result.
        let indices = vec![
            TensorIndex::List(vec![0, 2, 3]),
            TensorIndex::Range(Some(1), Some(4), None),
        ];
        let result = tensor.index(&indices).unwrap();

        assert_eq!(result.shape().dims(), &[3, 3]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 2.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 10.0);
        assert_eq!(result.get(&[2, 2]).unwrap(), 16.0);
    }
1425
    // Negative values in Index, Range bounds, and List entries all count from
    // the end of the dimension.
    #[test]
    fn test_negative_indexing() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        // -1 selects the last element.
        let indices = vec![TensorIndex::Index(-1)];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.numel(), 1);
        assert_eq!(result.item().unwrap(), 5.0);

        // [-3, -1) -> elements 3.0 and 4.0.
        let indices = vec![TensorIndex::Range(Some(-3), Some(-1), None)];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[2]);
        assert_eq!(result.get(&[0]).unwrap(), 3.0);
        assert_eq!(result.get(&[1]).unwrap(), 4.0);

        // List entries may mix negative and non-negative positions.
        let indices = vec![TensorIndex::List(vec![-1, -2, 0])];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3]);
        assert_eq!(result.get(&[0]).unwrap(), 5.0);
        assert_eq!(result.get(&[1]).unwrap(), 4.0);
        assert_eq!(result.get(&[2]).unwrap(), 1.0);
    }
1453}