Skip to main content

tenflowers_core/tensor/
indexing.rs

1//! Tensor Indexing Operations
2//!
3//! This module provides direct indexing capabilities for tensors,
4//! allowing for convenient access to tensor elements using
5//! standard Rust indexing syntax.
6
7use super::core::{Tensor, TensorStorage};
8use std::ops::Index;
9
10// Index trait implementation for tensor[indices] syntax
11impl<T: Clone> Index<&[usize]> for Tensor<T> {
12    type Output = T;
13
14    fn index(&self, index: &[usize]) -> &Self::Output {
15        match &self.storage {
16            TensorStorage::Cpu(arr) => {
17                if index.len() != arr.ndim() {
18                    panic!(
19                        "Index dimension mismatch: expected {} dimensions, got {}",
20                        arr.ndim(),
21                        index.len()
22                    );
23                }
24                arr.get(index).expect("Index out of bounds")
25            }
26            #[cfg(feature = "gpu")]
27            TensorStorage::Gpu(_) => {
28                panic!("Direct indexing not supported for GPU tensors. Use .get() or convert to CPU first.")
29            }
30        }
31    }
32}
33
34// Index trait implementation for single-dimension indexing
35impl<T: Clone> Index<usize> for Tensor<T> {
36    type Output = T;
37
38    fn index(&self, index: usize) -> &Self::Output {
39        match &self.storage {
40            TensorStorage::Cpu(arr) => {
41                if arr.ndim() != 1 {
42                    panic!(
43                        "Single index only supported for 1D tensors, but tensor has {} dimensions",
44                        arr.ndim()
45                    );
46                }
47                arr.get([index]).expect("Index out of bounds")
48            }
49            #[cfg(feature = "gpu")]
50            TensorStorage::Gpu(_) => {
51                panic!("Direct indexing not supported for GPU tensors. Use .get() or convert to CPU first.")
52            }
53        }
54    }
55}