1use crate::{Result, TchError, Tensor};
56use std::ops::{
57 Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
58};
59
/// Marker value requesting that a new axis of size one be inserted at the
/// current position of an index expression.
///
/// It converts into [`TensorIndexer::InsertNewAxis`]. The type is zero-sized,
/// so it is `Copy` and the same value can be reused several times within a
/// single index tuple.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NewAxis;
62
63#[derive(Debug, PartialEq)]
64pub enum TensorIndexer {
65 Select(i64),
66 Narrow(Bound<i64>, Bound<i64>),
67 IndexSelect(Tensor),
68 InsertNewAxis,
69}
70
71impl From<NewAxis> for TensorIndexer {
72 fn from(_index: NewAxis) -> Self {
73 TensorIndexer::InsertNewAxis
74 }
75}
76
77impl From<i64> for TensorIndexer {
78 fn from(index: i64) -> Self {
79 TensorIndexer::Select(index)
80 }
81}
82
83impl From<&[i64]> for TensorIndexer {
84 fn from(index: &[i64]) -> Self {
85 let tensor = index.into();
86 TensorIndexer::IndexSelect(tensor)
87 }
88}
89
90impl From<Vec<i64>> for TensorIndexer {
91 fn from(index: Vec<i64>) -> Self {
92 let tensor = Tensor::from_slice(&index);
93 TensorIndexer::IndexSelect(tensor)
94 }
95}
96
97impl From<&Tensor> for TensorIndexer {
98 fn from(tensor: &Tensor) -> Self {
99 TensorIndexer::IndexSelect(tensor.shallow_clone())
100 }
101}
102
103macro_rules! impl_from_range {
104 ($range_type:ty) => {
105 impl From<$range_type> for TensorIndexer {
106 fn from(range: $range_type) -> Self {
107 use std::ops::Bound::*;
108
109 let start = match range.start_bound() {
110 Included(idx) => Included(*idx),
111 Excluded(idx) => Excluded(*idx),
112 Unbounded => Unbounded,
113 };
114
115 let end = match range.end_bound() {
116 Included(idx) => Included(*idx),
117 Excluded(idx) => Excluded(*idx),
118 Unbounded => Unbounded,
119 };
120
121 TensorIndexer::Narrow(start, end)
122 }
123 }
124 };
125}
126
127impl_from_range!(Range<i64>);
128impl_from_range!(RangeFrom<i64>);
129impl_from_range!(RangeFull);
130impl_from_range!(RangeInclusive<i64>);
131impl_from_range!(RangeTo<i64>);
132impl_from_range!(RangeToInclusive<i64>);
133
134pub trait IndexOp<T> {
135 fn i(&self, index: T) -> Tensor;
136 fn f_i(&self, index: T) -> Result<Tensor>;
137}
138
139impl<A> IndexOp<A> for Tensor
140where
141 A: Into<TensorIndexer>,
142{
143 fn i(&self, index: A) -> Tensor {
144 self.f_i(index).unwrap()
145 }
146
147 fn f_i(&self, index: A) -> Result<Tensor> {
148 self.f_indexer(&[index.into()])
149 }
150}
151
152impl<A> IndexOp<(A,)> for Tensor
153where
154 A: Into<TensorIndexer>,
155{
156 fn i(&self, index: (A,)) -> Tensor {
157 self.f_i(index).unwrap()
158 }
159
160 fn f_i(&self, index: (A,)) -> Result<Tensor> {
161 let idx_a = index.0.into();
162 self.f_indexer(&[idx_a])
163 }
164}
165
166impl<A, B> IndexOp<(A, B)> for Tensor
167where
168 A: Into<TensorIndexer>,
169 B: Into<TensorIndexer>,
170{
171 fn i(&self, index: (A, B)) -> Tensor {
172 self.f_i(index).unwrap()
173 }
174
175 fn f_i(&self, index: (A, B)) -> Result<Tensor> {
176 let idx_a = index.0.into();
177 let idx_b = index.1.into();
178 self.f_indexer(&[idx_a, idx_b])
179 }
180}
181
182impl<A, B, C> IndexOp<(A, B, C)> for Tensor
183where
184 A: Into<TensorIndexer>,
185 B: Into<TensorIndexer>,
186 C: Into<TensorIndexer>,
187{
188 fn i(&self, index: (A, B, C)) -> Tensor {
189 self.f_i(index).unwrap()
190 }
191
192 fn f_i(&self, index: (A, B, C)) -> Result<Tensor> {
193 let idx_a = index.0.into();
194 let idx_b = index.1.into();
195 let idx_c = index.2.into();
196 self.f_indexer(&[idx_a, idx_b, idx_c])
197 }
198}
199
200impl<A, B, C, D> IndexOp<(A, B, C, D)> for Tensor
201where
202 A: Into<TensorIndexer>,
203 B: Into<TensorIndexer>,
204 C: Into<TensorIndexer>,
205 D: Into<TensorIndexer>,
206{
207 fn i(&self, index: (A, B, C, D)) -> Tensor {
208 self.f_i(index).unwrap()
209 }
210
211 fn f_i(&self, index: (A, B, C, D)) -> Result<Tensor> {
212 let idx_a = index.0.into();
213 let idx_b = index.1.into();
214 let idx_c = index.2.into();
215 let idx_d = index.3.into();
216 self.f_indexer(&[idx_a, idx_b, idx_c, idx_d])
217 }
218}
219
220impl<A, B, C, D, E> IndexOp<(A, B, C, D, E)> for Tensor
221where
222 A: Into<TensorIndexer>,
223 B: Into<TensorIndexer>,
224 C: Into<TensorIndexer>,
225 D: Into<TensorIndexer>,
226 E: Into<TensorIndexer>,
227{
228 fn i(&self, index: (A, B, C, D, E)) -> Tensor {
229 self.f_i(index).unwrap()
230 }
231
232 fn f_i(&self, index: (A, B, C, D, E)) -> Result<Tensor> {
233 let idx_a = index.0.into();
234 let idx_b = index.1.into();
235 let idx_c = index.2.into();
236 let idx_d = index.3.into();
237 let idx_e = index.4.into();
238 self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e])
239 }
240}
241
242impl<A, B, C, D, E, F> IndexOp<(A, B, C, D, E, F)> for Tensor
243where
244 A: Into<TensorIndexer>,
245 B: Into<TensorIndexer>,
246 C: Into<TensorIndexer>,
247 D: Into<TensorIndexer>,
248 E: Into<TensorIndexer>,
249 F: Into<TensorIndexer>,
250{
251 fn i(&self, index: (A, B, C, D, E, F)) -> Tensor {
252 self.f_i(index).unwrap()
253 }
254
255 fn f_i(&self, index: (A, B, C, D, E, F)) -> Result<Tensor> {
256 let idx_a = index.0.into();
257 let idx_b = index.1.into();
258 let idx_c = index.2.into();
259 let idx_d = index.3.into();
260 let idx_e = index.4.into();
261 let idx_f = index.5.into();
262 self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
263 }
264}
265
266impl<A, B, C, D, E, F, G> IndexOp<(A, B, C, D, E, F, G)> for Tensor
267where
268 A: Into<TensorIndexer>,
269 B: Into<TensorIndexer>,
270 C: Into<TensorIndexer>,
271 D: Into<TensorIndexer>,
272 E: Into<TensorIndexer>,
273 F: Into<TensorIndexer>,
274 G: Into<TensorIndexer>,
275{
276 fn i(&self, index: (A, B, C, D, E, F, G)) -> Tensor {
277 self.f_i(index).unwrap()
278 }
279
280 fn f_i(&self, index: (A, B, C, D, E, F, G)) -> Result<Tensor> {
281 let idx_a = index.0.into();
282 let idx_b = index.1.into();
283 let idx_c = index.2.into();
284 let idx_d = index.3.into();
285 let idx_e = index.4.into();
286 let idx_f = index.5.into();
287 let idx_g = index.6.into();
288 self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
289 }
290}
291
292impl Tensor {
293 fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result<Tensor> {
294 use std::ops::Bound::*;
295 use TensorIndexer::*;
296
297 let n_newaxis = index_spec.iter().filter(|spec| *spec == &InsertNewAxis).count();
299
300 if index_spec.len() > self.size().len() + n_newaxis {
301 return Err(TchError::Shape(format!(
302 "too many indices for tensor of dimension {}",
303 self.size().len()
304 )));
305 }
306
307 for spec in index_spec.iter() {
309 use super::Kind::*;
310 if let IndexSelect(tensor) = spec {
311 if tensor.size().len() != 1 {
312 return Err(TchError::Shape(
313 "Multi-dimensional tensor is not supported for indexing".to_string(),
314 ));
315 }
316 match tensor.f_kind()? {
317 Int64 => {}
318 Int16 => {}
319 Int8 => {}
320 Int => {}
321 _ => {
322 return Err(TchError::Kind(format!("the kind of tensors used as indices must be one of {Int64:?}, {Int16:?}, {Int8:?}, {Int:?}")));
323 }
324 }
325 }
326 }
327
328 let mut curr_tensor = self.shallow_clone();
330 let mut curr_idx: i64 = 0;
331
332 for spec in index_spec.iter() {
333 let (next_tensor, next_idx) = match spec {
334 InsertNewAxis => (curr_tensor.unsqueeze(curr_idx), curr_idx + 1),
335 Select(index) => (
336 curr_tensor.select(curr_idx, *index),
337 curr_idx, ),
339 Narrow(start, end) => {
340 if let Some((start, length)) = match (start, end) {
341 (Unbounded, Unbounded) => None,
342 (Included(start), Unbounded) => {
343 let dim_len = curr_tensor.size()[curr_idx as usize];
344 Some((*start, dim_len - *start))
345 }
346 (Excluded(start), Unbounded) => {
347 let dim_len = curr_tensor.size()[curr_idx as usize];
348 Some((*start + 1, dim_len - *start - 1))
349 }
350 (Unbounded, Included(end)) => Some((0, *end + 1)),
351 (Unbounded, Excluded(end)) => Some((0, *end)),
352 (Included(start), Included(end)) => Some((*start, *end - *start + 1)),
353 (Included(start), Excluded(end)) => Some((*start, *end - *start)),
354 (Excluded(start), Included(end)) => Some((*start + 1, *end - *start)),
355 (Excluded(start), Excluded(end)) => Some((*start + 1, *end - *start - 1)),
356 } {
357 (curr_tensor.f_narrow(curr_idx, start, length.max(0))?, curr_idx + 1)
358 } else {
359 (curr_tensor, curr_idx + 1)
360 }
361 }
362 IndexSelect(index_tensor) => {
363 let index_tensor = index_tensor.to_device(curr_tensor.device());
364 (curr_tensor.index_select(curr_idx, &index_tensor), curr_idx + 1)
365 }
366 };
367 curr_tensor = next_tensor;
368 curr_idx = next_idx;
369 }
370 Ok(curr_tensor)
371 }
372}