1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use std::marker::Sized;

/// Indexing, slicing, joining and reshaping operations for tensor-like types.
///
/// Every method takes `&self` and returns new value(s) (`Self` or `Vec<Self>`)
/// rather than mutating the receiver.
///
/// The `Sized` supertrait is required because several methods take `&[Self]`
/// or return `Vec<Self>`.
// TODO: operations not yet provided: masked_select, narrow, nonzero,
// transpose (general, beyond `t`), unbind.
pub trait IndexSlicing: Sized {
    /// Concatenates tensors along the existing dimension `dim`.
    ///
    /// All inputs must have the same size in every dimension except `dim`.
    /// The output has that same size everywhere except `dim`, whose length
    /// is the sum of the inputs' lengths on `dim`. For example, concatenating
    /// `[tensor(5, 3, 2), tensor(5, 7, 2)]` on `dim = 1` yields
    /// `tensor(5, 10, 2)`.
    ///
    /// Unlike [`IndexSlicing::stack`], no new dimension is created.
    ///
    /// NOTE(review): whether `self` is part of the concatenation or only
    /// `tensors` is used is up to the implementor. Confirm before relying
    /// on either.
    fn cat(&self, tensors: &[Self], dim: usize) -> Self;

    /// Splits `self` along `dim` into `chunks` pieces.
    ///
    /// How a length that is not divisible by `chunks` is distributed is left
    /// to the implementor.
    fn chunk(&self, chunks: usize, dim: usize) -> Vec<Self>;

    /// Picks elements along `dim` at the positions given by `index` and
    /// gathers them into the output.
    ///
    /// Restriction: `self` and `index` must have the same size.
    fn gather(&self, dim: usize, index: &Self) -> Self;

    /// Selects the sub-tensors along `dim` at the positions listed in `index`
    /// and collects them into a new tensor.
    ///
    /// `index` presumably holds integer positions. Confirm with implementors.
    fn index_select(&self, dim: usize, index: &Self) -> Self;

    /// Inverse of [`IndexSlicing::index_select`]: removes the sub-tensors along
    /// `dim` at the positions listed in `index` and keeps the rest.
    fn index_exclude(&self, dim: usize, index: &Self) -> Self;

    /// Returns the same elements viewed under `new_shape`.
    ///
    /// Only the index boundaries change, not the element data. `new_shape`
    /// presumably must describe the same total number of elements as `self`.
    fn reshape(&self, new_shape: &[usize]) -> Self;

    /// Inverse of [`IndexSlicing::cat`]: splits `self` along `dim` into
    /// consecutive sections whose lengths on `dim` are given by `sections`.
    fn split(&self, sections: &[usize], dim: usize) -> Vec<Self>;

    /// Removes dimensions of length 1.
    ///
    /// Presumably `Some(d)` squeezes only dimension `d` and `None` squeezes
    /// every length-1 dimension. Confirm with implementors.
    fn squeeze(&self, dim: Option<usize>) -> Self;

    /// Stacks tensors of identical size along a new dimension inserted at
    /// `dim`.
    ///
    /// Unlike [`IndexSlicing::cat`], this creates a new dimension.
    ///
    /// NOTE(review): as with `cat`, whether `self` is included is up to the
    /// implementor.
    fn stack(&self, tensors: &[Self], dim: usize) -> Self;

    /// Transpose.
    ///
    /// Presumably intended for 2-D tensors (swapping the two dimensions).
    /// Confirm the behavior for other ranks with implementors.
    fn t(&self) -> Self;

    /// Returns a new tensor with the elements of `self` at the given indices.
    ///
    /// `self` is treated as if it were viewed as a 1-D tensor. The result
    /// takes the same shape as the indices.
    fn take(&self, index: &[usize]) -> Self;

    /// Reorders the dimensions of `self` according to `dims`.
    ///
    /// Presumably follows the PyTorch convention, where output dimension `i`
    /// is input dimension `dims[i]`. Confirm with implementors.
    fn permute(&self, dims: &[usize]) -> Self;

    /// Inserts a new dimension of length 1 at position `dim`.
    fn unsqueeze(&self, dim: usize) -> Self;

    /// Element-wise selection between `x` and `y`, with `self` as the
    /// condition.
    ///
    /// At each position, the value comes from `x` if `self` is zero or
    /// positive there, and from `y` if `self` is negative there. This is
    /// comparable to PyTorch's `where`, with the sign of `self` acting as
    /// the boolean.
    ///
    /// Restriction: `self`, `x` and `y` must all have the same size.
    fn conditional_select(&self, x: &Self, y: &Self) -> Self;

    /// Repeats `self` according to `sizes`.
    ///
    /// Presumably tiles `self` `sizes[i]` times along dimension `i`, as
    /// PyTorch's `Tensor.repeat` does. Confirm with implementors.
    fn repeat(&self, sizes: &[usize]) -> Self;
}