// tch_plus/tensor/index.rs

1//! Indexing operations
2//!
3//! This module defines the `i` indexing operation. This can be used in various
4//! scenarios.
5//!
6//! Using an integer index returns the slice obtained by selecting elements with
7//! the specified index. Negative values can be used for the index, and `..` can
8//! be used to get all the indexes from a given dimension.
9//!
10//! ```ignore
11//! use crate::tch::{IndexOp, Tensor};
12//! let tensor = Tensor::from_slice(&[1, 2, 3, 4, 5, 6]).view((2, 3));
13//! let t = tensor.i(1);
14//! let t = tensor.i((.., -2));
15//! ```
16//!
17//! Indexes like `1..`, `..1`, or `1..2`, can be used to narrow a dimension.
18//!
19//! ```ignore
20//! use crate::tch::{IndexOp, Tensor};
21//! let tensor = Tensor::from_slice(&[1, 2, 3, 4, 5, 6]).view((2, 3));
22//! let t = tensor.i((.., 1..));
23//! assert_eq!(t.size(), [2, 2]);
24//! assert_eq!(Vec::<i64>::from(t.contiguous().view(-1)), [2, 3, 5, 6]);
25//! let t = tensor.i((..1, ..));
26//! assert_eq!(t.size(), [1, 3]);
27//! assert_eq!(Vec::<i64>::from(t.contiguous().view(-1)), [1, 2, 3]);
28//! let t = tensor.i((.., 1..2));
29//! assert_eq!(t.size(), [2, 1]);
30//! assert_eq!(Vec::<i64>::from(t.contiguous().view(-1)), [2, 5]);
31//! let t = tensor.i((.., 1..=2));
32//! assert_eq!(t.size(), [2, 2]);
33//! assert_eq!(Vec::<i64>::from(t.contiguous().view(-1)), [2, 3, 5, 6]);
34//! ```
35//!
36//! The `NewAxis` index can be used to insert a dimension.
37//!
38//! ```ignore
39//! use crate::tch::{IndexOp, NewAxis, Tensor};
40//! let tensor = Tensor::from_slice(&[1, 2, 3, 4, 5, 6]).view((2, 3));
41//! let t = tensor.i((NewAxis,));
42//! assert_eq!(t.size(), &[1, 2, 3]);
43//! let t = tensor.i((.., .., NewAxis));
44//! assert_eq!(t.size(), &[2, 3, 1]);
45//! ```
46//!
47//! Unlike NumPy, the `i` operation does not support advanced indexing.
//! The result can be different from NumPy with the same set of arguments.
//! For example, `tensor.i((..1, vec![0, 3], vec![2, 1, 3]))` does narrowing
//! on the first dimension, and index selection on the second and third
51//! The analogous NumPy indexing `array[:1, [0, 3], [2, 1, 3]]` throws
52//! shape mismatch error due to advanced indexing rule. Another distinction
53//! is that `i` guarantees the input and result tensor shares the same
54//! underlying storage, while NumPy may copy the tensor in certain scenarios.
55use crate::{Result, TchError, Tensor};
56use std::ops::{
57    Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
58};
59
/// Marker type used as an indexer to insert a new dimension of size one
/// at the corresponding position (analogous to NumPy's `np.newaxis`).
#[derive(Debug, PartialEq, Eq)]
pub struct NewAxis;
62
/// A single indexing operation applied to one tensor dimension.
#[derive(Debug, PartialEq)]
pub enum TensorIndexer {
    /// Select the slice at the given index; the indexed dimension is
    /// squeezed out of the result (see `f_indexer`).
    Select(i64),
    /// Narrow the dimension to the sub-range described by the two bounds.
    Narrow(Bound<i64>, Bound<i64>),
    /// Keep only the positions listed in a 1-D integer tensor.
    IndexSelect(Tensor),
    /// Insert a new dimension of size one at this position.
    InsertNewAxis,
}
70
71impl From<NewAxis> for TensorIndexer {
72    fn from(_index: NewAxis) -> Self {
73        TensorIndexer::InsertNewAxis
74    }
75}
76
77impl From<i64> for TensorIndexer {
78    fn from(index: i64) -> Self {
79        TensorIndexer::Select(index)
80    }
81}
82
83impl From<&[i64]> for TensorIndexer {
84    fn from(index: &[i64]) -> Self {
85        let tensor = index.into();
86        TensorIndexer::IndexSelect(tensor)
87    }
88}
89
90impl From<Vec<i64>> for TensorIndexer {
91    fn from(index: Vec<i64>) -> Self {
92        let tensor = Tensor::from_slice(&index);
93        TensorIndexer::IndexSelect(tensor)
94    }
95}
96
97impl From<&Tensor> for TensorIndexer {
98    fn from(tensor: &Tensor) -> Self {
99        TensorIndexer::IndexSelect(tensor.shallow_clone())
100    }
101}
102
/// Implements `From<$range_type>` for `TensorIndexer` by copying the
/// range's start/end bounds into a `Narrow` indexer.
macro_rules! impl_from_range {
    ($range_type:ty) => {
        impl From<$range_type> for TensorIndexer {
            fn from(range: $range_type) -> Self {
                // `Bound::cloned` turns the borrowed `Bound<&i64>` returned
                // by `start_bound`/`end_bound` into an owned `Bound<i64>`,
                // which is exactly what `Narrow` stores.
                TensorIndexer::Narrow(
                    range.start_bound().cloned(),
                    range.end_bound().cloned(),
                )
            }
        }
    };
}
126
// Allow every standard range form (`a..b`, `a..`, `..`, `a..=b`, `..b`,
// `..=b`) to be used directly as a `Narrow` indexer.
impl_from_range!(Range<i64>);
impl_from_range!(RangeFrom<i64>);
impl_from_range!(RangeFull);
impl_from_range!(RangeInclusive<i64>);
impl_from_range!(RangeTo<i64>);
impl_from_range!(RangeToInclusive<i64>);
133
/// Indexing operation; see the module-level documentation for the
/// supported index forms.
pub trait IndexOp<T> {
    /// Indexes the tensor, panicking on an invalid index spec.
    fn i(&self, index: T) -> Tensor;
    /// Fallible variant of [`Self::i`]; returns an error instead of
    /// panicking on an invalid index spec.
    fn f_i(&self, index: T) -> Result<Tensor>;
}
138
139impl<A> IndexOp<A> for Tensor
140where
141    A: Into<TensorIndexer>,
142{
143    fn i(&self, index: A) -> Tensor {
144        self.f_i(index).unwrap()
145    }
146
147    fn f_i(&self, index: A) -> Result<Tensor> {
148        self.f_indexer(&[index.into()])
149    }
150}
151
152impl<A> IndexOp<(A,)> for Tensor
153where
154    A: Into<TensorIndexer>,
155{
156    fn i(&self, index: (A,)) -> Tensor {
157        self.f_i(index).unwrap()
158    }
159
160    fn f_i(&self, index: (A,)) -> Result<Tensor> {
161        let idx_a = index.0.into();
162        self.f_indexer(&[idx_a])
163    }
164}
165
166impl<A, B> IndexOp<(A, B)> for Tensor
167where
168    A: Into<TensorIndexer>,
169    B: Into<TensorIndexer>,
170{
171    fn i(&self, index: (A, B)) -> Tensor {
172        self.f_i(index).unwrap()
173    }
174
175    fn f_i(&self, index: (A, B)) -> Result<Tensor> {
176        let idx_a = index.0.into();
177        let idx_b = index.1.into();
178        self.f_indexer(&[idx_a, idx_b])
179    }
180}
181
182impl<A, B, C> IndexOp<(A, B, C)> for Tensor
183where
184    A: Into<TensorIndexer>,
185    B: Into<TensorIndexer>,
186    C: Into<TensorIndexer>,
187{
188    fn i(&self, index: (A, B, C)) -> Tensor {
189        self.f_i(index).unwrap()
190    }
191
192    fn f_i(&self, index: (A, B, C)) -> Result<Tensor> {
193        let idx_a = index.0.into();
194        let idx_b = index.1.into();
195        let idx_c = index.2.into();
196        self.f_indexer(&[idx_a, idx_b, idx_c])
197    }
198}
199
200impl<A, B, C, D> IndexOp<(A, B, C, D)> for Tensor
201where
202    A: Into<TensorIndexer>,
203    B: Into<TensorIndexer>,
204    C: Into<TensorIndexer>,
205    D: Into<TensorIndexer>,
206{
207    fn i(&self, index: (A, B, C, D)) -> Tensor {
208        self.f_i(index).unwrap()
209    }
210
211    fn f_i(&self, index: (A, B, C, D)) -> Result<Tensor> {
212        let idx_a = index.0.into();
213        let idx_b = index.1.into();
214        let idx_c = index.2.into();
215        let idx_d = index.3.into();
216        self.f_indexer(&[idx_a, idx_b, idx_c, idx_d])
217    }
218}
219
220impl<A, B, C, D, E> IndexOp<(A, B, C, D, E)> for Tensor
221where
222    A: Into<TensorIndexer>,
223    B: Into<TensorIndexer>,
224    C: Into<TensorIndexer>,
225    D: Into<TensorIndexer>,
226    E: Into<TensorIndexer>,
227{
228    fn i(&self, index: (A, B, C, D, E)) -> Tensor {
229        self.f_i(index).unwrap()
230    }
231
232    fn f_i(&self, index: (A, B, C, D, E)) -> Result<Tensor> {
233        let idx_a = index.0.into();
234        let idx_b = index.1.into();
235        let idx_c = index.2.into();
236        let idx_d = index.3.into();
237        let idx_e = index.4.into();
238        self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e])
239    }
240}
241
242impl<A, B, C, D, E, F> IndexOp<(A, B, C, D, E, F)> for Tensor
243where
244    A: Into<TensorIndexer>,
245    B: Into<TensorIndexer>,
246    C: Into<TensorIndexer>,
247    D: Into<TensorIndexer>,
248    E: Into<TensorIndexer>,
249    F: Into<TensorIndexer>,
250{
251    fn i(&self, index: (A, B, C, D, E, F)) -> Tensor {
252        self.f_i(index).unwrap()
253    }
254
255    fn f_i(&self, index: (A, B, C, D, E, F)) -> Result<Tensor> {
256        let idx_a = index.0.into();
257        let idx_b = index.1.into();
258        let idx_c = index.2.into();
259        let idx_d = index.3.into();
260        let idx_e = index.4.into();
261        let idx_f = index.5.into();
262        self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
263    }
264}
265
266impl<A, B, C, D, E, F, G> IndexOp<(A, B, C, D, E, F, G)> for Tensor
267where
268    A: Into<TensorIndexer>,
269    B: Into<TensorIndexer>,
270    C: Into<TensorIndexer>,
271    D: Into<TensorIndexer>,
272    E: Into<TensorIndexer>,
273    F: Into<TensorIndexer>,
274    G: Into<TensorIndexer>,
275{
276    fn i(&self, index: (A, B, C, D, E, F, G)) -> Tensor {
277        self.f_i(index).unwrap()
278    }
279
280    fn f_i(&self, index: (A, B, C, D, E, F, G)) -> Result<Tensor> {
281        let idx_a = index.0.into();
282        let idx_b = index.1.into();
283        let idx_c = index.2.into();
284        let idx_d = index.3.into();
285        let idx_e = index.4.into();
286        let idx_f = index.5.into();
287        let idx_g = index.6.into();
288        self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
289    }
290}
291
impl Tensor {
    /// Shared implementation behind every `IndexOp` impl: applies each
    /// `TensorIndexer` in `index_spec` to `self`, left to right, and
    /// returns the resulting view (same underlying storage, no copy).
    ///
    /// Errors when there are more dimension-consuming indexers than tensor
    /// dimensions, or when an `IndexSelect` tensor is not 1-D / not of an
    /// integer kind.
    fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result<Tensor> {
        use std::ops::Bound::*;
        use TensorIndexer::*;

        // `InsertNewAxis` entries do not consume an input dimension, so the
        // rank check only counts the remaining indexers.
        let n_newaxis = index_spec.iter().filter(|spec| *spec == &InsertNewAxis).count();

        if index_spec.len() > self.size().len() + n_newaxis {
            return Err(TchError::Shape(format!(
                "too many indices for tensor of dimension {}",
                self.size().len()
            )));
        }

        // Validate any index tensors up front: they must be 1-D and of an
        // integer kind, mirroring PyTorch's index_select requirements.
        for spec in index_spec.iter() {
            use super::Kind::*;
            if let IndexSelect(tensor) = spec {
                if tensor.size().len() != 1 {
                    return Err(TchError::Shape(
                        "Multi-dimensional tensor is not supported for indexing".to_string(),
                    ));
                }
                match tensor.f_kind()? {
                    Int64 => {}
                    Int16 => {}
                    Int8 => {}
                    Int => {}
                    _ => {
                        return Err(TchError::Kind(format!("the kind of tensors used as indices must be one of {Int64:?}, {Int16:?}, {Int8:?}, {Int:?}")));
                    }
                }
            }
        }

        // Apply the indexers left to right; `curr_idx` tracks which dimension
        // of the intermediate tensor the next indexer operates on.
        let mut curr_tensor = self.shallow_clone();
        let mut curr_idx: i64 = 0;

        for spec in index_spec.iter() {
            let (next_tensor, next_idx) = match spec {
                InsertNewAxis => (curr_tensor.unsqueeze(curr_idx), curr_idx + 1),
                Select(index) => (
                    curr_tensor.select(curr_idx, *index),
                    curr_idx, // not advanced because select() squeezes dimension
                ),
                Narrow(start, end) => {
                    // Translate the bound pair into a (start, length) pair for
                    // `narrow`. A fully unbounded range (`..`) yields `None`
                    // and leaves the tensor untouched; unbounded ends are
                    // resolved against the current dimension length.
                    if let Some((start, length)) = match (start, end) {
                        (Unbounded, Unbounded) => None,
                        (Included(start), Unbounded) => {
                            let dim_len = curr_tensor.size()[curr_idx as usize];
                            Some((*start, dim_len - *start))
                        }
                        (Excluded(start), Unbounded) => {
                            let dim_len = curr_tensor.size()[curr_idx as usize];
                            Some((*start + 1, dim_len - *start - 1))
                        }
                        (Unbounded, Included(end)) => Some((0, *end + 1)),
                        (Unbounded, Excluded(end)) => Some((0, *end)),
                        (Included(start), Included(end)) => Some((*start, *end - *start + 1)),
                        (Included(start), Excluded(end)) => Some((*start, *end - *start)),
                        (Excluded(start), Included(end)) => Some((*start + 1, *end - *start)),
                        (Excluded(start), Excluded(end)) => Some((*start + 1, *end - *start - 1)),
                    } {
                        // length.max(0): an empty/inverted range becomes an
                        // empty slice rather than a negative-length error.
                        (curr_tensor.f_narrow(curr_idx, start, length.max(0))?, curr_idx + 1)
                    } else {
                        (curr_tensor, curr_idx + 1)
                    }
                }
                IndexSelect(index_tensor) => {
                    // The index tensor must live on the same device as the
                    // data tensor for index_select to succeed.
                    let index_tensor = index_tensor.to_device(curr_tensor.device());
                    (curr_tensor.index_select(curr_idx, &index_tensor), curr_idx + 1)
                }
            };
            curr_tensor = next_tensor;
            curr_idx = next_idx;
        }
        Ok(curr_tensor)
    }
}
372}