torsh_tensor/
indexing.rs

1//! Tensor indexing and slicing operations
2
3use crate::{Tensor, TensorElement};
4use torsh_core::error::{Result, TorshError};
5
6/// Index type for tensor indexing
7#[derive(Debug, Clone)]
8pub enum TensorIndex {
9    /// Single index
10    Index(i64),
11    /// Range of indices
12    Range(Option<i64>, Option<i64>, Option<i64>), // start, stop, step
13    /// All indices (:)
14    All,
15    /// List of indices (fancy indexing)
16    List(Vec<i64>),
17    /// Boolean mask
18    Mask(Tensor<bool>),
19    /// Ellipsis (...) - represents multiple ':' to fill remaining dimensions
20    Ellipsis,
21    /// Newaxis (None) - adds a dimension of size 1
22    NewAxis,
23}
24
25impl TensorIndex {
26    /// Create a range index
27    pub fn range(start: Option<i64>, stop: Option<i64>) -> Self {
28        TensorIndex::Range(start, stop, None)
29    }
30
31    /// Create a range index with step
32    pub fn range_step(start: Option<i64>, stop: Option<i64>, step: i64) -> Self {
33        TensorIndex::Range(start, stop, Some(step))
34    }
35}
36
37/// Indexing implementation
38impl<T: TensorElement> Tensor<T> {
39    /// Index into the tensor
40    pub fn index(&self, indices: &[TensorIndex]) -> Result<Self> {
41        // Validate number of indices (NewAxis and Ellipsis don't consume tensor dimensions)
42        let consuming_indices = indices
43            .iter()
44            .filter(|idx| !matches!(idx, TensorIndex::NewAxis | TensorIndex::Ellipsis))
45            .count();
46
47        if consuming_indices > self.ndim() {
48            return Err(TorshError::InvalidArgument(format!(
49                "Too many indices for tensor: tensor has {} dimensions but {} consuming indices were provided",
50                self.ndim(),
51                consuming_indices
52            )));
53        }
54
55        // Handle ellipsis by expanding indices first
56        let expanded_indices = self.expand_ellipsis(indices)?;
57
58        // Process each expanded index to determine the output shape and extraction logic
59        let mut output_shape = Vec::new();
60        let mut slices = Vec::new();
61        let mut input_dim_idx = 0; // Track which input tensor dimension we're accessing
62
63        for index in expanded_indices.iter() {
64            if let TensorIndex::NewAxis = index {
65                // NewAxis doesn't consume input dimensions, just adds a new dimension of size 1
66                output_shape.push(1);
67                slices.push((0, 1, 1));
68                // Don't increment input_dim_idx for NewAxis
69                continue;
70            }
71
72            // For all other indices, we need to get the dimension size from the input tensor
73            let dim_size = if input_dim_idx < self.ndim() {
74                self.shape().dims()[input_dim_idx]
75            } else {
76                return Err(TorshError::InvalidArgument(format!(
77                    "Index {} beyond tensor dimensions (tensor has {} dimensions)",
78                    input_dim_idx,
79                    self.ndim()
80                )));
81            };
82
83            match index {
84                TensorIndex::Index(idx) => {
85                    // Single index - this dimension is removed
86                    let idx = if *idx < 0 {
87                        (dim_size as i64 + idx) as usize
88                    } else {
89                        *idx as usize
90                    };
91
92                    if idx >= dim_size {
93                        return Err(TorshError::IndexOutOfBounds {
94                            index: idx,
95                            size: dim_size,
96                        });
97                    }
98
99                    slices.push((idx, idx + 1, 1));
100                    // Single index doesn't add an output dimension, but consumes input dimension
101                    input_dim_idx += 1;
102                }
103                TensorIndex::Range(start, stop, step) => {
104                    let step = step.unwrap_or(1);
105                    if step == 0 {
106                        return Err(TorshError::InvalidArgument(
107                            "Step cannot be zero".to_string(),
108                        ));
109                    }
110
111                    let start = start
112                        .map(|s| {
113                            if s < 0 {
114                                (dim_size as i64 + s).max(0) as usize
115                            } else {
116                                s.min(dim_size as i64) as usize
117                            }
118                        })
119                        .unwrap_or(0);
120
121                    let stop = stop
122                        .map(|s| {
123                            if s < 0 {
124                                (dim_size as i64 + s).max(0) as usize
125                            } else {
126                                s.min(dim_size as i64) as usize
127                            }
128                        })
129                        .unwrap_or(dim_size);
130
131                    let size = if step > 0 {
132                        ((stop as i64 - start as i64 + step - 1) / step).max(0) as usize
133                    } else {
134                        ((stop as i64 - start as i64 + step + 1) / step).max(0) as usize
135                    };
136
137                    output_shape.push(size);
138                    slices.push((start, stop, step as usize));
139                    input_dim_idx += 1;
140                }
141                TensorIndex::All => {
142                    output_shape.push(dim_size);
143                    slices.push((0, dim_size, 1));
144                    input_dim_idx += 1;
145                }
146                TensorIndex::List(indices_list) => {
147                    // Fancy indexing with list of indices
148                    for &idx in indices_list {
149                        let normalized_idx = if idx < 0 {
150                            (dim_size as i64 + idx) as usize
151                        } else {
152                            idx as usize
153                        };
154
155                        if normalized_idx >= dim_size {
156                            return Err(TorshError::IndexOutOfBounds {
157                                index: normalized_idx,
158                                size: dim_size,
159                            });
160                        }
161                    }
162
163                    output_shape.push(indices_list.len());
164                    // Store list indices as a special slice marker
165                    slices.push((0, indices_list.len(), 0)); // step=0 indicates list indexing
166                    input_dim_idx += 1;
167                }
168                TensorIndex::Mask(mask) => {
169                    // Boolean mask indexing - dimension is flattened
170                    if mask.ndim() != 1 {
171                        return Err(TorshError::InvalidArgument(
172                            "Boolean mask must be 1D for single dimension indexing".to_string(),
173                        ));
174                    }
175
176                    if mask.numel() != dim_size {
177                        return Err(TorshError::ShapeMismatch {
178                            expected: vec![dim_size],
179                            got: mask.shape().dims().to_vec(),
180                        });
181                    }
182
183                    // Count True values to determine output size
184                    let mask_data = mask.to_vec()?;
185                    let true_count = mask_data.iter().filter(|&&x| x).count();
186
187                    output_shape.push(true_count);
188                    // Store mask as special slice marker
189                    slices.push((0, true_count, 0)); // step=0 indicates mask indexing
190                    input_dim_idx += 1;
191                }
192                TensorIndex::NewAxis => {
193                    // This should not happen since NewAxis is handled earlier
194                    return Err(TorshError::InvalidArgument(
195                        "NewAxis should be handled before this point".to_string(),
196                    ));
197                }
198                TensorIndex::Ellipsis => {
199                    // This should not happen since ellipsis is expanded earlier
200                    return Err(TorshError::InvalidArgument(
201                        "Ellipsis should be expanded before processing".to_string(),
202                    ));
203                }
204            }
205        }
206
207        // If all indices were single indices, we need at least one dimension
208        if output_shape.is_empty() {
209            output_shape.push(1);
210        }
211
212        // Use specialized extraction logic for advanced indexing
213        if expanded_indices
214            .iter()
215            .any(|idx| matches!(idx, TensorIndex::List(_) | TensorIndex::Mask(_)))
216        {
217            self.extract_advanced_indexing(&expanded_indices, &output_shape)
218        } else {
219            self.extract_basic_indexing(&expanded_indices, &output_shape, &slices)
220        }
221    }
222
223    /// Extract data using basic indexing (ranges, single indices, all)
224    fn extract_basic_indexing(
225        &self,
226        indices: &[TensorIndex],
227        output_shape: &[usize],
228        slices: &[(usize, usize, usize)],
229    ) -> Result<Self> {
230        let input_data = self.to_vec()?;
231
232        let output_size = output_shape.iter().product();
233        let mut output_data = Vec::with_capacity(output_size);
234
235        let input_strides = self.compute_strides();
236        let output_strides = compute_strides_from_shape(output_shape);
237
238        for out_idx in 0..output_size {
239            // Convert flat index to multi-dimensional indices
240            let mut out_indices = vec![0; output_shape.len()];
241            let mut remaining = out_idx;
242            for (i, &stride) in output_strides.iter().enumerate() {
243                out_indices[i] = remaining / stride;
244                remaining %= stride;
245            }
246
247            // Map output indices to input indices using slices
248            let mut input_flat_idx = 0;
249            let mut out_dim = 0;
250            let mut input_dim = 0;
251
252            for (slice_idx, &(start, _, step)) in slices.iter().enumerate() {
253                // Skip NewAxis dimensions in input tensor
254                if slice_idx < indices.len() && matches!(indices[slice_idx], TensorIndex::NewAxis) {
255                    out_dim += 1;
256                    continue;
257                }
258
259                // Ensure we don't exceed input dimensions
260                if input_dim >= input_strides.len() {
261                    break;
262                }
263
264                let idx = if slice_idx < indices.len()
265                    && matches!(indices[slice_idx], TensorIndex::Index(_))
266                {
267                    start
268                } else {
269                    start + out_indices[out_dim] * step
270                };
271                input_flat_idx += idx * input_strides[input_dim];
272
273                if !(slice_idx < indices.len()
274                    && matches!(indices[slice_idx], TensorIndex::Index(_)))
275                {
276                    out_dim += 1;
277                }
278                input_dim += 1;
279            }
280
281            output_data.push(input_data[input_flat_idx]);
282        }
283
284        Self::from_data(output_data, output_shape.to_vec(), self.device)
285    }
286
287    /// Extract data using advanced indexing (lists, masks)
288    fn extract_advanced_indexing(
289        &self,
290        indices: &[TensorIndex],
291        output_shape: &[usize],
292    ) -> Result<Self> {
293        let input_data = self.to_vec()?;
294
295        let output_size = output_shape.iter().product();
296        let mut output_data = Vec::with_capacity(output_size);
297
298        let input_strides = self.compute_strides();
299        let output_strides = compute_strides_from_shape(output_shape);
300
301        for out_idx in 0..output_size {
302            // Convert flat index to multi-dimensional indices
303            let mut out_indices = vec![0; output_shape.len()];
304            let mut remaining = out_idx;
305            for (i, &stride) in output_strides.iter().enumerate() {
306                out_indices[i] = remaining / stride;
307                remaining %= stride;
308            }
309
310            // Map output indices to input indices using advanced indexing
311            let mut input_flat_idx = 0;
312            let mut out_dim = 0;
313
314            for (dim_idx, index) in indices.iter().enumerate() {
315                if dim_idx >= self.ndim() {
316                    break;
317                }
318
319                let input_idx = match index {
320                    TensorIndex::Index(idx) => {
321                        let dim_size = self.shape().dims()[dim_idx];
322
323                        if *idx < 0 {
324                            (dim_size as i64 + idx) as usize
325                        } else {
326                            *idx as usize
327                        }
328                    }
329                    TensorIndex::Range(start, _stop, step) => {
330                        let dim_size = self.shape().dims()[dim_idx];
331                        let step = step.unwrap_or(1);
332                        let start = start
333                            .map(|s| {
334                                if s < 0 {
335                                    (dim_size as i64 + s).max(0) as usize
336                                } else {
337                                    s.min(dim_size as i64) as usize
338                                }
339                            })
340                            .unwrap_or(0);
341
342                        start + out_indices[out_dim] * (step as usize)
343                    }
344                    TensorIndex::All => out_indices[out_dim],
345                    TensorIndex::List(indices_list) => {
346                        // Fancy indexing: use the list index
347                        let list_idx = out_indices[out_dim];
348                        if list_idx >= indices_list.len() {
349                            return Err(TorshError::IndexOutOfBounds {
350                                index: list_idx,
351                                size: indices_list.len(),
352                            });
353                        }
354
355                        let actual_idx = indices_list[list_idx];
356                        let dim_size = self.shape().dims()[dim_idx];
357
358                        if actual_idx < 0 {
359                            (dim_size as i64 + actual_idx) as usize
360                        } else {
361                            actual_idx as usize
362                        }
363                    }
364                    TensorIndex::Mask(mask) => {
365                        // Boolean mask indexing
366                        let mask_data = mask.to_vec()?;
367
368                        // Find the nth True value in the mask
369                        let target_true_idx = out_indices[out_dim];
370                        let mut true_count = 0;
371                        let mut found_idx = None;
372                        for (i, &mask_val) in mask_data.iter().enumerate() {
373                            if mask_val {
374                                if true_count == target_true_idx {
375                                    found_idx = Some(i);
376                                    break;
377                                }
378                                true_count += 1;
379                            }
380                        }
381
382                        match found_idx {
383                            Some(idx) => idx,
384                            None => {
385                                return Err(TorshError::IndexOutOfBounds {
386                                    index: target_true_idx,
387                                    size: true_count,
388                                });
389                            }
390                        }
391                    }
392                    TensorIndex::NewAxis => {
393                        // NewAxis doesn't consume input dimensions
394                        continue;
395                    }
396                    TensorIndex::Ellipsis => {
397                        // Ellipsis should be handled in shape computation
398                        out_indices[out_dim]
399                    }
400                };
401
402                input_flat_idx += input_idx * input_strides[dim_idx];
403
404                // Only advance output dimension for non-index operations
405                if !matches!(index, TensorIndex::Index(_) | TensorIndex::NewAxis) {
406                    out_dim += 1;
407                }
408            }
409
410            // Handle remaining dimensions
411            for stride in input_strides
412                .iter()
413                .skip(indices.len())
414                .take(self.ndim() - indices.len())
415            {
416                if out_dim < out_indices.len() {
417                    input_flat_idx += out_indices[out_dim] * stride;
418                    out_dim += 1;
419                }
420            }
421
422            if input_flat_idx >= input_data.len() {
423                return Err(TorshError::IndexOutOfBounds {
424                    index: input_flat_idx,
425                    size: input_data.len(),
426                });
427            }
428
429            output_data.push(input_data[input_flat_idx]);
430        }
431
432        Self::from_data(output_data, output_shape.to_vec(), self.device)
433    }
434
435    /// Expand ellipsis into explicit All indices
436    fn expand_ellipsis(&self, indices: &[TensorIndex]) -> Result<Vec<TensorIndex>> {
437        let mut expanded = Vec::new();
438        let mut found_ellipsis = false;
439
440        // Count non-ellipsis, non-newaxis indices to determine how many dimensions ellipsis should expand to
441        let non_expanding_indices = indices
442            .iter()
443            .filter(|idx| !matches!(idx, TensorIndex::Ellipsis | TensorIndex::NewAxis))
444            .count();
445
446        for index in indices {
447            match index {
448                TensorIndex::Ellipsis => {
449                    if found_ellipsis {
450                        return Err(TorshError::InvalidArgument(
451                            "Only one ellipsis (...) is allowed per indexing operation".to_string(),
452                        ));
453                    }
454                    found_ellipsis = true;
455
456                    // Calculate how many dimensions the ellipsis should expand to
457                    let ellipsis_dims = if self.ndim() >= non_expanding_indices {
458                        self.ndim() - non_expanding_indices
459                    } else {
460                        0
461                    };
462
463                    // Expand ellipsis to All indices
464                    for _ in 0..ellipsis_dims {
465                        expanded.push(TensorIndex::All);
466                    }
467                }
468                _ => {
469                    expanded.push(index.clone());
470                }
471            }
472        }
473
474        // If no ellipsis was found, add implicit trailing All indices for remaining dimensions
475        if !found_ellipsis {
476            let current_dims = expanded
477                .iter()
478                .filter(|idx| !matches!(idx, TensorIndex::NewAxis))
479                .count();
480
481            for _ in current_dims..self.ndim() {
482                expanded.push(TensorIndex::All);
483            }
484        }
485
486        Ok(expanded)
487    }
488
489    /// Get a single element (1D indexing)
490    pub fn get_1d(&self, index: usize) -> Result<T> {
491        if self.ndim() != 1 {
492            return Err(TorshError::InvalidShape(
493                "get_1d() can only be used on 1D tensors".to_string(),
494            ));
495        }
496
497        if index >= self.shape().dims()[0] {
498            return Err(TorshError::IndexOutOfBounds {
499                index,
500                size: self.shape().dims()[0],
501            });
502        }
503
504        let data = self.data()?;
505        Ok(data[index])
506    }
507
508    /// Get a single element (2D indexing)
509    pub fn get_2d(&self, row: usize, col: usize) -> Result<T> {
510        if self.ndim() != 2 {
511            return Err(TorshError::InvalidShape(
512                "get_2d() can only be used on 2D tensors".to_string(),
513            ));
514        }
515
516        let shape = self.shape();
517        if row >= shape.dims()[0] || col >= shape.dims()[1] {
518            return Err(TorshError::IndexOutOfBounds {
519                index: row * shape.dims()[1] + col,
520                size: shape.numel(),
521            });
522        }
523
524        let data = self.to_vec()?;
525
526        let index = row * shape.dims()[1] + col;
527        Ok(data[index])
528    }
529
530    /// Get a single element (3D indexing)
531    pub fn get_3d(&self, x: usize, y: usize, z: usize) -> Result<T> {
532        if self.ndim() != 3 {
533            return Err(TorshError::InvalidShape(
534                "get_3d() can only be used on 3D tensors".to_string(),
535            ));
536        }
537
538        let shape = self.shape();
539        if x >= shape.dims()[0] || y >= shape.dims()[1] || z >= shape.dims()[2] {
540            return Err(TorshError::IndexOutOfBounds {
541                index: x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z,
542                size: shape.numel(),
543            });
544        }
545
546        let data = self.to_vec()?;
547
548        let index = x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z;
549        Ok(data[index])
550    }
551
552    /// Set a single element (1D indexing)
553    pub fn set_1d(&mut self, index: usize, value: T) -> Result<()> {
554        if self.ndim() != 1 {
555            return Err(TorshError::InvalidShape(
556                "set_1d() can only be used on 1D tensors".to_string(),
557            ));
558        }
559
560        if index >= self.shape().dims()[0] {
561            return Err(TorshError::IndexOutOfBounds {
562                index,
563                size: self.shape().dims()[0],
564            });
565        }
566
567        let mut data = self.to_vec()?;
568        data[index] = value;
569        *self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
570        Ok(())
571    }
572
573    /// Set a single element (2D indexing)
574    pub fn set_2d(&mut self, row: usize, col: usize, value: T) -> Result<()> {
575        if self.ndim() != 2 {
576            return Err(TorshError::InvalidShape(
577                "set_2d() can only be used on 2D tensors".to_string(),
578            ));
579        }
580
581        let shape = self.shape();
582        if row >= shape.dims()[0] || col >= shape.dims()[1] {
583            return Err(TorshError::IndexOutOfBounds {
584                index: row * shape.dims()[1] + col,
585                size: shape.numel(),
586            });
587        }
588
589        let mut data = self.to_vec()?;
590        let index = row * shape.dims()[1] + col;
591        data[index] = value;
592        *self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
593        Ok(())
594    }
595
596    /// Set a single element (3D indexing)
597    pub fn set_3d(&mut self, x: usize, y: usize, z: usize, value: T) -> Result<()> {
598        if self.ndim() != 3 {
599            return Err(TorshError::InvalidShape(
600                "set_3d() can only be used on 3D tensors".to_string(),
601            ));
602        }
603
604        let shape = self.shape();
605        if x >= shape.dims()[0] || y >= shape.dims()[1] || z >= shape.dims()[2] {
606            return Err(TorshError::IndexOutOfBounds {
607                index: x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z,
608                size: shape.numel(),
609            });
610        }
611
612        let mut data = self.to_vec()?;
613        let index = x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z;
614        data[index] = value;
615        *self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
616        Ok(())
617    }
618
619    /// Select along a dimension
620    pub fn select(&self, dim: i32, index: i64) -> Result<Self> {
621        let ndim = self.ndim() as i32;
622        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
623
624        if dim >= self.ndim() {
625            return Err(TorshError::InvalidArgument(format!(
626                "Dimension {} out of range for tensor with {} dimensions",
627                dim,
628                self.ndim()
629            )));
630        }
631
632        let dim_size = self.shape().dims()[dim] as i64;
633        let index = if index < 0 { dim_size + index } else { index };
634
635        if index < 0 || index >= dim_size {
636            return Err(TorshError::IndexOutOfBounds {
637                index: index as usize,
638                size: dim_size as usize,
639            });
640        }
641
642        // Create index array for slicing
643        let mut indices = Vec::new();
644        for d in 0..self.ndim() {
645            if d == dim {
646                indices.push(TensorIndex::Index(index));
647            } else {
648                indices.push(TensorIndex::All);
649            }
650        }
651
652        // Use the existing index function
653        self.index(&indices)
654    }
655
656    /// Slice along a dimension with PyTorch-style parameters
657    pub fn slice_with_step(
658        &self,
659        dim: i32,
660        start: Option<i64>,
661        end: Option<i64>,
662        step: Option<i64>,
663    ) -> Result<Self> {
664        let ndim = self.ndim() as i32;
665        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
666
667        if dim >= self.ndim() {
668            return Err(TorshError::InvalidArgument(format!(
669                "Dimension {} out of range for tensor with {} dimensions",
670                dim,
671                self.ndim()
672            )));
673        }
674
675        // Create index array for slicing
676        let mut indices = Vec::new();
677        for d in 0..self.ndim() {
678            if d == dim {
679                indices.push(TensorIndex::Range(start, end, step));
680            } else {
681                indices.push(TensorIndex::All);
682            }
683        }
684
685        // Use the existing index function
686        self.index(&indices)
687    }
688
689    /// Narrow along a dimension
690    pub fn narrow(&self, dim: i32, start: i64, length: usize) -> Result<Self> {
691        let ndim = self.ndim() as i32;
692        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
693
694        if dim >= self.ndim() {
695            return Err(TorshError::InvalidArgument(format!(
696                "Dimension {} out of range for tensor with {} dimensions",
697                dim,
698                self.ndim()
699            )));
700        }
701
702        let dim_size = self.shape().dims()[dim] as i64;
703        let start = if start < 0 { dim_size + start } else { start };
704
705        if start < 0 || start >= dim_size {
706            return Err(TorshError::InvalidArgument(format!(
707                "Start index {start} out of range for dimension {dim} with size {dim_size}"
708            )));
709        }
710
711        let end = start + length as i64;
712        if end > dim_size {
713            return Err(TorshError::InvalidArgument(format!(
714                "End index {end} out of range for dimension {dim} with size {dim_size}"
715            )));
716        }
717
718        // Create index array for slicing
719        let mut indices = Vec::new();
720        for d in 0..self.ndim() {
721            if d == dim {
722                indices.push(TensorIndex::Range(Some(start), Some(end), None));
723            } else {
724                indices.push(TensorIndex::All);
725            }
726        }
727
728        // Use the existing index function
729        self.index(&indices)
730    }
731
732    /// Boolean indexing (masking)
733    pub fn masked_select(&self, mask: &Tensor<bool>) -> Result<Self> {
734        if self.shape() != mask.shape() {
735            return Err(TorshError::ShapeMismatch {
736                expected: self.shape().dims().to_vec(),
737                got: mask.shape().dims().to_vec(),
738            });
739        }
740
741        let self_data = self.data()?;
742        let mask_data = mask.data()?;
743
744        // Collect all elements where mask is true
745        let mut selected_data = Vec::new();
746        for (i, &mask_val) in mask_data.iter().enumerate() {
747            if mask_val {
748                selected_data.push(self_data[i]);
749            }
750        }
751
752        // Return 1D tensor with selected elements
753        Self::from_data(
754            selected_data.clone(),
755            vec![selected_data.len()],
756            self.device,
757        )
758    }
759
760    pub fn take(&self, indices: &Tensor<i64>) -> Result<Self> {
761        let self_data = self.data()?;
762
763        let indices_data = indices.data()?;
764
765        let self_size = self.shape().numel();
766        let output_shape = indices.shape().dims().to_vec();
767        let output_size = indices.shape().numel();
768        let mut output_data = Vec::with_capacity(output_size);
769
770        // Take elements at the given flat indices
771        for &idx in indices_data.iter() {
772            let idx = if idx < 0 {
773                (self_size as i64 + idx) as usize
774            } else {
775                idx as usize
776            };
777
778            if idx >= self_size {
779                return Err(TorshError::IndexOutOfBounds {
780                    index: idx,
781                    size: self_size,
782                });
783            }
784
785            output_data.push(self_data[idx]);
786        }
787
788        Self::from_data(output_data, output_shape, self.device)
789    }
790
791    /// Put values at indices
792    pub fn put(&self, indices: &Tensor<i64>, values: &Self) -> Result<Self> {
793        let self_data = self.data()?;
794
795        let indices_data = indices.data()?;
796        let values_data = values.data()?;
797
798        // Check that indices and values have the same shape
799        if indices.shape() != values.shape() {
800            return Err(TorshError::ShapeMismatch {
801                expected: indices.shape().dims().to_vec(),
802                got: values.shape().dims().to_vec(),
803            });
804        }
805
806        let self_size = self.shape().numel();
807        let mut output_data = self_data.clone();
808
809        // Put values at the given flat indices
810        for (i, &idx) in indices_data.iter().enumerate() {
811            let idx = if idx < 0 {
812                (self_size as i64 + idx) as usize
813            } else {
814                idx as usize
815            };
816
817            if idx >= self_size {
818                return Err(TorshError::IndexOutOfBounds {
819                    index: idx,
820                    size: self_size,
821                });
822            }
823
824            output_data[idx] = values_data[i];
825        }
826
827        Self::from_data(output_data, self.shape().dims().to_vec(), self.device)
828    }
829
830    /// Select indices along a dimension
831    pub fn index_select(&self, dim: i32, index: &Tensor<i64>) -> Result<Self> {
832        let ndim = self.ndim() as i32;
833        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
834
835        if dim >= self.ndim() {
836            return Err(TorshError::InvalidArgument(format!(
837                "Dimension {} out of range for tensor with {} dimensions",
838                dim,
839                self.ndim()
840            )));
841        }
842
843        // Index must be 1D
844        if index.ndim() != 1 {
845            return Err(TorshError::InvalidShape(
846                "index_select expects a 1D index tensor".to_string(),
847            ));
848        }
849
850        // Calculate output shape
851        let mut output_shape = self.shape().dims().to_vec();
852        output_shape[dim] = index.shape().dims()[0];
853
854        let output_size: usize = output_shape.iter().product();
855        let mut output_data = Vec::with_capacity(output_size);
856
857        let self_data = self.data()?;
858
859        let index_data = index.data()?;
860
861        // Compute strides
862        let self_strides = self.compute_strides();
863        let _output_strides = Self::compute_strides_for_shape(&output_shape);
864
865        // Select elements
866        for out_idx in 0..output_size {
867            // Convert flat index to multi-dimensional index
868            let mut indices = vec![0; self.ndim()];
869            let mut remaining = out_idx;
870            for i in (0..self.ndim()).rev() {
871                indices[i] = remaining % output_shape[i];
872                remaining /= output_shape[i];
873            }
874
875            // For the selected dimension, use the index from the index tensor
876            let select_idx = indices[dim];
877            let selected_value = index_data[select_idx] as usize;
878
879            if selected_value >= self.shape().dims()[dim] {
880                return Err(TorshError::IndexOutOfBounds {
881                    index: selected_value,
882                    size: self.shape().dims()[dim],
883                });
884            }
885
886            indices[dim] = selected_value;
887
888            // Compute flat index in source tensor
889            let src_flat_idx = indices
890                .iter()
891                .zip(&self_strides)
892                .map(|(idx, stride)| idx * stride)
893                .sum::<usize>();
894
895            output_data.push(self_data[src_flat_idx]);
896        }
897
898        Self::from_data(output_data, output_shape, self.device)
899    }
900
901    /// Compute strides for the tensor's shape
902    pub(crate) fn compute_strides(&self) -> Vec<usize> {
903        Self::compute_strides_for_shape(self.shape().dims())
904    }
905
906    /// Compute strides for a given shape
907    pub(crate) fn compute_strides_for_shape(shape: &[usize]) -> Vec<usize> {
908        let mut strides = vec![1; shape.len()];
909        for i in (0..shape.len() - 1).rev() {
910            strides[i] = strides[i + 1] * shape[i + 1];
911        }
912        strides
913    }
914}
915
/// Compute contiguous row-major strides (in elements) for `shape`.
///
/// The last axis has stride 1, and each earlier axis's stride is the product of all
/// later extents. An empty shape (a 0-d tensor) yields an empty vector. The previous
/// `0..shape.len() - 1` form underflowed and panicked for it.
fn compute_strides_from_shape(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (1..shape.len()).rev() {
        strides[i - 1] = strides[i] * shape[i];
    }
    strides
}
924
/// Build a `Vec<TensorIndex>` containing one `TensorIndex::Index` per listed expression.
///
/// `idx![5]` yields `vec![TensorIndex::Index(5)]` and `idx![1, 2, 3]` yields three
/// `Index` entries. A trailing comma is accepted. The expansion names `TensorIndex`
/// unqualified, so it must be in scope at the call site.
#[macro_export]
macro_rules! idx {
    // One or more positions; a single position is just a one-element repetition.
    ($($position:expr),+ $(,)?) => {
        vec![$(TensorIndex::Index($position)),+]
    };
}
938
/// Build a single slicing `TensorIndex`.
///
/// Supported forms:
/// - `s![..]` -> `TensorIndex::All`
/// - `s![..stop]` -> range from the start up to `stop`
/// - `s![start, stop]` -> range `start..stop`
/// - `s![start, stop, step]` -> stepped range
/// - `s![ellipsis]` -> `TensorIndex::Ellipsis`
/// - `s![None]` -> `TensorIndex::NewAxis`
///
/// The expansion names `TensorIndex` unqualified, so it must be in scope at the call site.
/// Arms are tried in order. The `ellipsis` / `None` arms are listed after the comma
/// forms; they still match because a lone expression fails the comma arms' `,`
/// requirement and matching falls through. Keep this in mind if arms are reordered.
#[macro_export]
macro_rules! s {
    // Full slice: s![..]
    (..) => {
        TensorIndex::All
    };

    // To end: s![..5]
    (.. $stop:expr) => {
        TensorIndex::range(None, Some($stop))
    };

    // Range (comma syntax): s![1, 5]
    ($start:expr, $stop:expr) => {
        TensorIndex::range(Some($start), Some($stop))
    };

    // Range with step (comma syntax): s![1, 5, 2]
    ($start:expr, $stop:expr, $step:expr) => {
        TensorIndex::range_step(Some($start), Some($stop), $step)
    };

    // Ellipsis: s![ellipsis]
    (ellipsis) => {
        TensorIndex::Ellipsis
    };

    // NewAxis: s![None]
    (None) => {
        TensorIndex::NewAxis
    };
}
971
/// Build a `TensorIndex::List` (fancy indexing) from one or more integer positions.
///
/// `fancy_idx![0, 2, 1]` expands to `TensorIndex::List(vec![0, 2, 1])`. A trailing
/// comma is accepted. `TensorIndex` must be in scope at the call site.
#[macro_export]
macro_rules! fancy_idx {
    ($($position:expr),+ $(,)?) => {
        TensorIndex::List(vec![$($position),+])
    };
}
980
/// Wrap a boolean mask tensor as `TensorIndex::Mask`.
///
/// `mask_idx![m]` expands to `TensorIndex::Mask(m)`; the argument expression is used
/// as-is. `TensorIndex` must be in scope at the call site.
#[macro_export]
macro_rules! mask_idx {
    ($mask_tensor:expr) => {
        TensorIndex::Mask($mask_tensor)
    };
}
988
989/// Convenient indexing syntax
990impl<T: TensorElement> Tensor<T> {
991    /// Advanced indexing with list of indices (fancy indexing)
992    pub fn index_with_list(&self, dim: i32, indices: &[i64]) -> Result<Self> {
993        let ndim = self.ndim() as i32;
994        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
995
996        if dim >= self.ndim() {
997            return Err(TorshError::InvalidArgument(format!(
998                "Dimension {} out of range for tensor with {} dimensions",
999                dim,
1000                self.ndim()
1001            )));
1002        }
1003
1004        let mut index_spec = vec![TensorIndex::All; self.ndim()];
1005        index_spec[dim] = TensorIndex::List(indices.to_vec());
1006
1007        self.index(&index_spec)
1008    }
1009
1010    /// Boolean mask indexing for a specific dimension
1011    pub fn index_with_mask(&self, dim: i32, mask: &Tensor<bool>) -> Result<Self> {
1012        let ndim = self.ndim() as i32;
1013        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
1014
1015        if dim >= self.ndim() {
1016            return Err(TorshError::InvalidArgument(format!(
1017                "Dimension {} out of range for tensor with {} dimensions",
1018                dim,
1019                self.ndim()
1020            )));
1021        }
1022
1023        let mut index_spec = vec![TensorIndex::All; self.ndim()];
1024        index_spec[dim] = TensorIndex::Mask(mask.clone());
1025
1026        self.index(&index_spec)
1027    }
1028
1029    /// Global boolean mask indexing (flattens to 1D result)
1030    pub fn mask_select(&self, mask: &Tensor<bool>) -> Result<Self> {
1031        if self.shape() != mask.shape() {
1032            return Err(TorshError::ShapeMismatch {
1033                expected: self.shape().dims().to_vec(),
1034                got: mask.shape().dims().to_vec(),
1035            });
1036        }
1037
1038        let self_data = self.data()?;
1039
1040        let mask_data = mask.data()?;
1041
1042        // Collect all elements where mask is true
1043        let mut selected_data = Vec::new();
1044        for (i, &mask_val) in mask_data.iter().enumerate() {
1045            if mask_val {
1046                selected_data.push(self_data[i]);
1047            }
1048        }
1049
1050        // Return 1D tensor with selected elements
1051        Self::from_data(
1052            selected_data.clone(),
1053            vec![selected_data.len()],
1054            self.device,
1055        )
1056    }
1057
1058    /// Create boolean mask from condition
1059    pub fn where_condition<F>(&self, condition: F) -> Result<Tensor<bool>>
1060    where
1061        F: Fn(&T) -> bool,
1062        T: Clone,
1063    {
1064        let data = self.data()?;
1065
1066        let mask_data: Vec<bool> = data.iter().map(condition).collect();
1067
1068        Tensor::from_data(mask_data, self.shape().dims().to_vec(), self.device)
1069    }
1070
1071    /// Scatter values along an axis using indices (indexing version)
1072    pub fn scatter_indexed(&self, dim: i32, index: &Tensor<i64>, src: &Self) -> Result<Self> {
1073        let ndim = self.ndim() as i32;
1074        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
1075
1076        if dim >= self.ndim() {
1077            return Err(TorshError::InvalidArgument(format!(
1078                "Dimension {} out of range for tensor with {} dimensions",
1079                dim,
1080                self.ndim()
1081            )));
1082        }
1083
1084        let self_shape_binding = self.shape();
1085        let self_shape = self_shape_binding.dims();
1086        let index_shape_binding = index.shape();
1087        let index_shape = index_shape_binding.dims();
1088        let src_shape_binding = src.shape();
1089        let src_shape = src_shape_binding.dims();
1090
1091        // Validate shapes
1092        if index_shape != src_shape {
1093            return Err(TorshError::ShapeMismatch {
1094                expected: index_shape.to_vec(),
1095                got: src_shape.to_vec(),
1096            });
1097        }
1098
1099        if index_shape.len() != self_shape.len() {
1100            return Err(TorshError::InvalidArgument(
1101                "Index tensor must have same number of dimensions as input tensor".to_string(),
1102            ));
1103        }
1104
1105        // Start with a copy of self
1106        let mut result_data = self.data()?.clone();
1107        let index_data = index.data()?;
1108        let src_data = src.data()?;
1109        let self_strides = self.compute_strides();
1110
1111        let index_size = index_shape.iter().product();
1112
1113        // Process each element in the index tensor
1114        for flat_idx in 0..index_size {
1115            // Convert flat index to multi-dimensional coordinates
1116            let mut coords = Vec::new();
1117            let mut temp_idx = flat_idx;
1118
1119            for &dim_size in index_shape.iter().rev() {
1120                coords.push(temp_idx % dim_size);
1121                temp_idx /= dim_size;
1122            }
1123            coords.reverse();
1124
1125            // Get the index value for the scatter dimension
1126            let scatter_idx = index_data[flat_idx];
1127            let dim_size = self_shape[dim] as i64;
1128            let scatter_idx = if scatter_idx < 0 {
1129                dim_size + scatter_idx
1130            } else {
1131                scatter_idx
1132            };
1133
1134            if scatter_idx < 0 || scatter_idx >= dim_size {
1135                return Err(TorshError::IndexOutOfBounds {
1136                    index: scatter_idx as usize,
1137                    size: dim_size as usize,
1138                });
1139            }
1140
1141            // Calculate destination index in result tensor
1142            coords[dim] = scatter_idx as usize;
1143            let mut dest_idx = 0;
1144            for (coord, &stride) in coords.iter().zip(self_strides.iter()) {
1145                dest_idx += coord * stride;
1146            }
1147
1148            result_data[dest_idx] = src_data[flat_idx];
1149        }
1150
1151        Self::from_data(result_data, self_shape.to_vec(), self.device)
1152    }
1153}
1154
// Unit tests for the indexing API: macros, element access, gather/scatter,
// index_select, list/mask indexing, NewAxis/Ellipsis handling and negative indices.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::{tensor_2d, zeros};

    /// Smoke test: every indexing macro form expands and type-checks.
    #[test]
    fn test_index_macros() {
        // Test single index
        let indices = idx![5];
        assert_eq!(indices.len(), 1);

        // Test multiple indices
        let indices = idx![1, 2, 3];
        assert_eq!(indices.len(), 3);

        // Test slice macros
        let _all = s![..];
        let _range = s![1, 5];
        let _range_step = s![1, 10, 2];
        let _to = s![..7];

        // Test advanced indexing macros
        let _fancy = fancy_idx![0, 2, 1];
        let _ellipsis = s![ellipsis];
        let _newaxis = s![None];
    }

    /// Element read/write by multi-dimensional coordinates, including bounds errors.
    #[test]
    fn test_get_set() {
        let tensor = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();

        // Test get
        assert_eq!(tensor.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(tensor.get(&[0, 1]).unwrap(), 2.0);
        assert_eq!(tensor.get(&[1, 2]).unwrap(), 6.0);

        // Test set
        tensor.set(&[1, 1], 10.0).unwrap();
        assert_eq!(tensor.get(&[1, 1]).unwrap(), 10.0);

        // Test out of bounds
        assert!(tensor.get(&[2, 0]).is_err());
        assert!(tensor.set(&[0, 3], 0.0).is_err());
    }

    /// gather along dim=1: result[i][j] = tensor[i][indices[i][j]].
    #[test]
    fn test_gather() {
        // Create a 3x3 tensor
        let tensor = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]).unwrap();

        // Create indices for gathering along dim=1
        let indices = tensor_2d(&[&[0i64, 2, 1], &[1, 0, 2], &[2, 1, 0]]).unwrap();

        let result = tensor.gather(1, &indices).unwrap();

        // Expected: [[1, 3, 2], [5, 4, 6], [9, 8, 7]]
        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 1]).unwrap(), 3.0);
        assert_eq!(result.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 5.0);
        assert_eq!(result.get(&[1, 1]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 2]).unwrap(), 6.0);
        assert_eq!(result.get(&[2, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[2, 1]).unwrap(), 8.0);
        assert_eq!(result.get(&[2, 2]).unwrap(), 7.0);
    }

    /// scatter along dim=1: result[i][indices[i][j]] = src[i][j].
    #[test]
    fn test_scatter() {
        // Create a 3x3 tensor of zeros
        let tensor = zeros::<f32>(&[3, 3]).unwrap();

        // Create indices for scattering along dim=1
        let indices = tensor_2d(&[&[0i64, 2, 1], &[1, 0, 2], &[2, 1, 0]]).unwrap();

        // Source values
        let src = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]).unwrap();

        let result = tensor.scatter(1, &indices, &src).unwrap();

        // Expected: [[1, 3, 2], [5, 4, 6], [9, 8, 7]]
        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 1]).unwrap(), 3.0);
        assert_eq!(result.get(&[0, 2]).unwrap(), 2.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 5.0);
        assert_eq!(result.get(&[1, 1]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 2]).unwrap(), 6.0);
        assert_eq!(result.get(&[2, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[2, 1]).unwrap(), 8.0);
        assert_eq!(result.get(&[2, 2]).unwrap(), 7.0);
    }

    /// index_select on rows (dim 0) and columns (dim 1) of a 3x4 tensor.
    #[test]
    fn test_index_select() {
        // Create a 3x4 tensor
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
        ])
        .unwrap();

        // Select rows 0 and 2
        let row_indices = crate::creation::tensor_1d(&[0i64, 2]).unwrap();
        let result = tensor.index_select(0, &row_indices).unwrap();

        assert_eq!(result.shape().dims(), &[2, 4]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 3]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[1, 3]).unwrap(), 12.0);

        // Select columns 1 and 3
        let col_indices = crate::creation::tensor_1d(&[1i64, 3]).unwrap();
        let result = tensor.index_select(1, &col_indices).unwrap();

        assert_eq!(result.shape().dims(), &[3, 2]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 2.0);
        assert_eq!(result.get(&[0, 1]).unwrap(), 4.0);
        assert_eq!(result.get(&[2, 0]).unwrap(), 10.0);
        assert_eq!(result.get(&[2, 1]).unwrap(), 12.0);
    }

    /// TensorIndex::List via `index`, and the `index_with_list` shortcut agreeing with it.
    #[test]
    fn test_list_indexing() {
        // Test fancy indexing with list of indices
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
        ])
        .unwrap();

        // Select rows 0 and 2 using list indexing
        let indices = vec![TensorIndex::List(vec![0, 2]), TensorIndex::All];
        let result = tensor.index(&indices).unwrap();

        assert_eq!(result.shape().dims(), &[2, 4]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 1.0);
        assert_eq!(result.get(&[0, 3]).unwrap(), 4.0);
        assert_eq!(result.get(&[1, 0]).unwrap(), 9.0);
        assert_eq!(result.get(&[1, 3]).unwrap(), 12.0);

        // Test index_with_list convenience method
        let result2 = tensor.index_with_list(0, &[0, 2]).unwrap();
        assert_eq!(result.shape(), result2.shape());
        assert_eq!(result.get(&[0, 0]).unwrap(), result2.get(&[0, 0]).unwrap());
    }

    /// Global (`mask_select`) and per-dimension (`index_with_mask`) boolean masking on 1-D data.
    #[test]
    fn test_boolean_mask_indexing() {
        use crate::creation::tensor_1d;

        // Create test tensor
        let tensor = tensor_1d(&[10.0, 20.0, 30.0, 40.0, 50.0]).unwrap();

        // Create boolean mask
        let mask = Tensor::from_data(
            vec![true, false, true, false, true],
            vec![5],
            crate::DeviceType::Cpu,
        )
        .unwrap();

        // Test mask_select (global mask)
        let result = tensor.mask_select(&mask).unwrap();
        assert_eq!(result.shape().dims(), &[3]);
        assert_eq!(result.get(&[0]).unwrap(), 10.0);
        assert_eq!(result.get(&[1]).unwrap(), 30.0);
        assert_eq!(result.get(&[2]).unwrap(), 50.0);

        // Test dimensional mask indexing
        let result2 = tensor.index_with_mask(0, &mask).unwrap();
        assert_eq!(result2.shape().dims(), &[3]);
        assert_eq!(result2.get(&[0]).unwrap(), 10.0);
        assert_eq!(result2.get(&[1]).unwrap(), 30.0);
        assert_eq!(result2.get(&[2]).unwrap(), 50.0);
    }

    /// where_condition builds a mask that mask_select can then consume.
    #[test]
    fn test_where_condition() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        // Create mask for values > 3.0
        let mask = tensor.where_condition(|&x| x > 3.0).unwrap();

        // The scoped block below releases whatever `mask.data()` returns before
        // `mask_select` reads the mask again. Keep the scoping if this test is edited.
        {
            let mask_data = mask.data().unwrap();
            assert!(!mask_data[0]); // 1.0 <= 3.0
            assert!(!mask_data[1]); // 2.0 <= 3.0
            assert!(!mask_data[2]); // 3.0 <= 3.0
            assert!(mask_data[3]); // 4.0 > 3.0
            assert!(mask_data[4]); // 5.0 > 3.0
        } // Explicitly drop the lock

        // Use the mask to select elements
        let selected = tensor.mask_select(&mask).unwrap();
        assert_eq!(selected.shape().dims(), &[2]);
        assert_eq!(selected.get(&[0]).unwrap(), 4.0);
        assert_eq!(selected.get(&[1]).unwrap(), 5.0);
    }

    /// NewAxis inserts size-1 dimensions at the position where it appears.
    #[test]
    fn test_newaxis_indexing() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();

        // Add new axis at beginning
        let indices = vec![TensorIndex::NewAxis, TensorIndex::All];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3]);

        // Add new axis at end
        let indices = vec![TensorIndex::All, TensorIndex::NewAxis];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3, 1]);

        // Add multiple new axes
        let indices = vec![
            TensorIndex::NewAxis,
            TensorIndex::All,
            TensorIndex::NewAxis,
            TensorIndex::NewAxis,
        ];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[1, 3, 1, 1]);
    }

    /// A trailing Ellipsis after an integer index keeps the remaining axes intact.
    #[test]
    fn test_ellipsis_indexing() {
        // Create 3D tensor
        let tensor = crate::creation::zeros::<f32>(&[2, 3, 4]).unwrap();

        // Integer index on axis 0 followed by an ellipsis: the ellipsis expands to the
        // two remaining axes, which are kept intact (this case places it last).
        let indices = vec![TensorIndex::Index(0), TensorIndex::Ellipsis];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3, 4]);

        // Same pattern with a different leading index (ellipsis again last).
        let indices = vec![TensorIndex::Index(1), TensorIndex::Ellipsis];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3, 4]);
    }

    /// List indexing on one axis combined with a range on the other.
    #[test]
    fn test_complex_indexing() {
        // Test combination of different indexing types
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
            &[13.0, 14.0, 15.0, 16.0],
        ])
        .unwrap();

        // Combine list indexing with range indexing
        let indices = vec![
            TensorIndex::List(vec![0, 2, 3]),
            TensorIndex::Range(Some(1), Some(4), None),
        ];
        let result = tensor.index(&indices).unwrap();

        assert_eq!(result.shape().dims(), &[3, 3]);
        assert_eq!(result.get(&[0, 0]).unwrap(), 2.0); // tensor[0, 1]
        assert_eq!(result.get(&[1, 0]).unwrap(), 10.0); // tensor[2, 1]
        assert_eq!(result.get(&[2, 2]).unwrap(), 16.0); // tensor[3, 3]
    }

    /// Negative values count from the end for single indices, ranges and lists.
    #[test]
    fn test_negative_indexing() {
        use crate::creation::tensor_1d;

        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();

        // Test negative single index
        let indices = vec![TensorIndex::Index(-1)];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.numel(), 1);
        assert_eq!(result.item().unwrap(), 5.0);

        // Test negative range
        let indices = vec![TensorIndex::Range(Some(-3), Some(-1), None)];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[2]);
        assert_eq!(result.get(&[0]).unwrap(), 3.0);
        assert_eq!(result.get(&[1]).unwrap(), 4.0);

        // Test negative list indexing
        let indices = vec![TensorIndex::List(vec![-1, -2, 0])];
        let result = tensor.index(&indices).unwrap();
        assert_eq!(result.shape().dims(), &[3]);
        assert_eq!(result.get(&[0]).unwrap(), 5.0); // -1 -> index 4
        assert_eq!(result.get(&[1]).unwrap(), 4.0); // -2 -> index 3
        assert_eq!(result.get(&[2]).unwrap(), 1.0); // 0 -> index 0
    }
}