tenflowers_core/tensor/
indexing.rs1use super::core::{Tensor, TensorStorage};
8use std::ops::Index;
9
10impl<T: Clone> Index<&[usize]> for Tensor<T> {
12 type Output = T;
13
14 fn index(&self, index: &[usize]) -> &Self::Output {
15 match &self.storage {
16 TensorStorage::Cpu(arr) => {
17 if index.len() != arr.ndim() {
18 panic!(
19 "Index dimension mismatch: expected {} dimensions, got {}",
20 arr.ndim(),
21 index.len()
22 );
23 }
24 arr.get(index).expect("Index out of bounds")
25 }
26 #[cfg(feature = "gpu")]
27 TensorStorage::Gpu(_) => {
28 panic!("Direct indexing not supported for GPU tensors. Use .get() or convert to CPU first.")
29 }
30 }
31 }
32}
33
34impl<T: Clone> Index<usize> for Tensor<T> {
36 type Output = T;
37
38 fn index(&self, index: usize) -> &Self::Output {
39 match &self.storage {
40 TensorStorage::Cpu(arr) => {
41 if arr.ndim() != 1 {
42 panic!(
43 "Single index only supported for 1D tensors, but tensor has {} dimensions",
44 arr.ndim()
45 );
46 }
47 arr.get([index]).expect("Index out of bounds")
48 }
49 #[cfg(feature = "gpu")]
50 TensorStorage::Gpu(_) => {
51 panic!("Direct indexing not supported for GPU tensors. Use .get() or convert to CPU first.")
52 }
53 }
54 }
55}