tensor_rs/tensor_trait/index_slicing.rs
1use std::marker::Sized;
2
/// Indexing, slicing and shape-manipulation operations on tensors.
///
/// Every operation takes `self` by shared reference and returns new
/// tensor(s) of the implementing type.
pub trait IndexSlicing where Self: Sized {
    /// Concatenates tensors along the existing dimension `dim`.
    ///
    /// All inputs must have the same size on every dimension except `dim`.
    /// The output has the inputs' size on every other dimension. On `dim`
    /// its size is the sum of the inputs' sizes.
    ///
    /// For example, concatenating `tensor(5, 3, 2)` and `tensor(5, 7, 2)`
    /// along `dim = 1` yields a `tensor(5, 10, 2)`.
    ///
    /// NOTE(review): presumably `self` is the first input, followed by the
    /// elements of `tensors` in order; confirm against implementations.
    fn cat(&self, tensors: &[Self], dim: usize) -> Self;

    /// Splits the tensor into `chunks` pieces along dimension `dim`.
    fn chunk(&self, chunks: usize, dim: usize) -> Vec<Self>;

    /// Picks elements along `dim` at the positions given by `index` and
    /// gathers them into the output.
    ///
    /// `self` and `index` must have the same size on every dimension
    /// other than `dim`.
    fn gather(&self, dim: usize, index: &Self) -> Self;

    /// The opposite of [`gather`](IndexSlicing::gather).
    ///
    /// Returns a tensor in which the entries of `self` selected by `index`
    /// along `dim` are replaced with the corresponding entries of `value`.
    fn spread(&self, dim: usize, index: &Self, value: &Self) -> Self;

    /// Selects the sub-tensors at the positions listed in `index` along
    /// `dim` and collects them into the output.
    fn index_select(&self, dim: usize, index: &Self) -> Self;

    /// Inverse of [`index_select`](IndexSlicing::index_select).
    ///
    /// Removes the sub-tensors at the positions listed in `index` along
    /// `dim` and keeps the rest.
    fn index_exclude(&self, dim: usize, index: &Self) -> Self;

    // TODO: not yet provided: masked_select, narrow, nonzero, transpose, unbind.

    /// Returns the tensor viewed under `new_shape`.
    ///
    /// Only the index boundaries change, not the stored elements.
    /// NOTE(review): presumably the product of `new_shape` must equal the
    /// element count of `self`; confirm.
    fn reshape(&self, new_shape: &[usize]) -> Self;

    /// Inverse of [`cat`](IndexSlicing::cat): splits the tensor along `dim`.
    ///
    /// `sections` gives the length on `dim` of each resulting piece.
    fn split(&self, sections: &[usize], dim: usize) -> Vec<Self>;

    /// Removes dimensions of length 1.
    ///
    /// NOTE(review): presumably `None` removes every length-1 dimension and
    /// `Some(d)` removes only dimension `d`; confirm.
    fn squeeze(&self, dim: Option<usize>) -> Self;

    /// Stacks tensors of identical size along a new dimension inserted at `dim`.
    ///
    /// Unlike [`cat`](IndexSlicing::cat), which joins along an existing
    /// dimension, `stack` creates a new dimension.
    fn stack(&self, tensors: &[Self], dim: usize) -> Self;

    /// Transposes the tensor.
    fn t(&self) -> Self;

    /// Returns a new tensor holding the elements of `self` at the given indices.
    ///
    /// `self` is treated as if it were viewed as a 1-D tensor. The result
    /// takes the same shape as the indices.
    fn take(&self, index: &[usize]) -> Self;

    /// Reorders the dimensions of the tensor according to `dims`.
    ///
    /// NOTE(review): presumably `dims` is a permutation of `0..rank` and
    /// output dimension `i` comes from input dimension `dims[i]` (PyTorch
    /// convention); confirm.
    fn permute(&self, dims: &[usize]) -> Self;

    /// Inserts a dimension of length 1 at position `dim`.
    fn unsqueeze(&self, dim: usize) -> Self;

    /// Element-wise selection using `self` as the condition.
    ///
    /// At each position the output takes the value from `x` when `self` is
    /// positive or zero. It takes the value from `y` when `self` is negative.
    /// `self`, `x` and `y` must all have the same size.
    ///
    /// This plays the role of PyTorch's `where`.
    fn conditional_select(&self, x: &Self, y: &Self) -> Self;

    /// Repeats the tensor along every dimension.
    ///
    /// `sizes[i]` is the repeat count for dimension `i`, so
    /// `sizes.len()` must equal the number of dimensions of `self`
    /// (`self.size().len()`).
    fn repeat(&self, sizes: &[usize]) -> Self;
}