// tensor_rs/tensor_trait/index_slicing.rs

use std::marker::Sized;

/// Indexing, slicing, joining and reshaping operations on tensors.
///
/// Every method takes `&self` and returns its result as new value(s);
/// none of the methods take `&mut self`.
pub trait IndexSlicing: Sized {
    /// Concatenates tensors along the existing dimension `dim`.
    ///
    /// All inputs must have the same size on every dimension except `dim`.
    /// The output has the same size as the inputs on every other dimension,
    /// and its size on `dim` is the sum of the inputs' sizes on `dim`.
    /// For example, concatenating `tensor(5, 3, 2)` and `tensor(5, 7, 2)`
    /// along `dim = 1` gives a `tensor(5, 10, 2)`.
    ///
    /// NOTE(review): confirm with implementors whether `self` is the first
    /// input (i.e. the result is `self` followed by `tensors`) or whether
    /// only `tensors` are joined.
    fn cat(&self, tensors: &[Self], dim: usize) -> Self;

    /// Splits the tensor into `chunks` pieces along dimension `dim`.
    ///
    /// NOTE(review): how a size on `dim` that is not divisible by `chunks`
    /// is distributed among the pieces is not specified here; confirm.
    fn chunk(&self, chunks: usize, dim: usize) -> Vec<Self>;

    /// Picks elements along dimension `dim` at the positions stored in
    /// `index` and gathers them into the output.
    ///
    /// `self` and `index` must have the same size on every dimension other
    /// than `dim`.
    fn gather(&self, dim: usize, index: &Self) -> Self;

    /// The opposite of [`IndexSlicing::gather`]: returns a tensor equal to
    /// `self` except that the elements addressed by `index` along `dim` are
    /// replaced with the corresponding elements of `value`.
    fn spread(&self, dim: usize, index: &Self, value: &Self) -> Self;

    /// Selects the sub-tensors along dimension `dim` at the positions listed
    /// in `index` and collects them into the output.
    fn index_select(&self, dim: usize, index: &Self) -> Self;

    /// Inverse of [`IndexSlicing::index_select`]: removes the sub-tensors
    /// along dimension `dim` at the positions listed in `index` and keeps
    /// the rest.
    fn index_exclude(&self, dim: usize, index: &Self) -> Self;

    // TODO: PyTorch-style `masked_select`, `narrow` and `nonzero` are not
    // provided yet.

    /// Returns a tensor with the same elements as `self` laid out with shape
    /// `new_shape`; only the index boundaries change, not the data.
    ///
    /// NOTE(review): presumably the product of `new_shape` must equal the
    /// number of elements in `self`; confirm how implementors handle a
    /// mismatch.
    fn reshape(&self, new_shape: &[usize]) -> Self;

    /// Splits the tensor along dimension `dim` into pieces whose sizes on
    /// `dim` are given by `sections`; the inverse of [`IndexSlicing::cat`].
    ///
    /// NOTE(review): presumably `sections` must sum to the size of `dim`;
    /// confirm.
    fn split(&self, sections: &[usize], dim: usize) -> Vec<Self>;

    /// Removes dimensions of length 1.
    ///
    /// NOTE(review): presumably `Some(d)` removes only dimension `d` (when it
    /// has length 1) and `None` removes every length-1 dimension, as in
    /// PyTorch; confirm.
    fn squeeze(&self, dim: Option<usize>) -> Self;

    /// Stacks tensors of identical size along a new dimension inserted at
    /// `dim`.
    ///
    /// Unlike [`IndexSlicing::cat`], which joins along an existing
    /// dimension, `stack` creates a new one.
    ///
    /// NOTE(review): as with `cat`, confirm whether `self` is included as
    /// the first input.
    fn stack(&self, tensors: &[Self], dim: usize) -> Self;

    /// Transposes the tensor.
    ///
    /// NOTE(review): which dimensions are swapped (e.g. the two axes of a
    /// 2-D tensor, or the last two of a higher-rank one) is not specified
    /// here; confirm against implementors.
    fn t(&self) -> Self;

    /// Returns a new tensor with the elements of `self` at the given
    /// indices, treating `self` as if it were viewed as a 1-D tensor.
    /// The result takes the same shape as the indices.
    fn take(&self, index: &[usize]) -> Self;

    // TODO: PyTorch-style `transpose` and `unbind` are not provided yet.

    /// Reorders the dimensions of the tensor according to `dims`.
    ///
    /// NOTE(review): presumably dimension `i` of the output is dimension
    /// `dims[i]` of `self`, as in PyTorch; confirm.
    fn permute(&self, dims: &[usize]) -> Self;

    /// Inserts a dimension of length 1 at position `dim`.
    fn unsqueeze(&self, dim: usize) -> Self;

    /// Element-wise selection driven by `self` (the equivalent of PyTorch's
    /// `where`).
    ///
    /// At each position, takes the value from `x` if `self` is positive or
    /// zero there, and the value from `y` if `self` is negative.
    /// `self`, `x` and `y` must all have the same size.
    fn conditional_select(&self, x: &Self, y: &Self) -> Self;

    /// Repeats the tensor along every dimension; the number of repetitions
    /// per dimension is given by `sizes`.
    ///
    /// `sizes.len()` must equal the number of dimensions of `self`
    /// (`self.size().len()`).
    fn repeat(&self, sizes: &[usize]) -> Self;
}