tch_plus/tensor/mod.rs

1//! A Torch tensor.
2use crate::{Device, Kind, TchError};
3
4mod convert;
5pub mod display;
6pub mod index;
7mod iter;
8mod npy;
9mod ops;
10mod safetensors;
11
12pub use super::wrappers::tensor::{
13    autocast, no_grad, no_grad_guard, with_grad, NoGradGuard, Reduction, Tensor,
14};
15pub use index::{IndexOp, NewAxis, TensorIndexer};
16
/// Conversion of shape-like values into a list of tensor dimensions.
///
/// Implemented for fixed-size arrays `[i64; N]` of any length, `()` (the empty,
/// zero-dimensional shape), `&[i64]`, a single `i64`, `usize` or `i32`
/// dimension, and tuples of one to four `i64` values.
pub trait Shape {
    /// Returns the dimensions as a boxed slice.
    fn to_shape(&self) -> Box<[i64]>;
}

// One const-generic impl covers arrays of every length; a macro previously
// generated impls only for lengths 0 through 6.
impl<const N: usize> Shape for [i64; N] {
    fn to_shape(&self) -> Box<[i64]> {
        Box::new(*self)
    }
}

impl Shape for () {
    /// The unit value maps to an empty list of dimensions.
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([])
    }
}

impl Shape for &[i64] {
    /// Copies the slice into a new boxed slice.
    fn to_shape(&self) -> Box<[i64]> {
        (*self).into()
    }
}

impl Shape for i64 {
    /// A single dimension of the given size.
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([*self])
    }
}

impl Shape for usize {
    /// A single dimension of the given size.
    ///
    /// # Panics
    ///
    /// Panics if the value exceeds `i64::MAX`. A plain `as` cast would silently
    /// wrap it to a negative dimension (e.g. `usize::MAX` becomes `-1`, which
    /// `view` interprets as "infer this dimension").
    fn to_shape(&self) -> Box<[i64]> {
        assert!(
            *self as u64 <= i64::MAX as u64,
            "shape dimension {} does not fit in an i64",
            self
        );
        Box::new([*self as i64])
    }
}

impl Shape for i32 {
    /// A single dimension of the given size (lossless widening to `i64`).
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([i64::from(*self)])
    }
}

impl Shape for (i64,) {
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([self.0])
    }
}

impl Shape for (i64, i64) {
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([self.0, self.1])
    }
}

impl Shape for (i64, i64, i64) {
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([self.0, self.1, self.2])
    }
}

impl Shape for (i64, i64, i64, i64) {
    fn to_shape(&self) -> Box<[i64]> {
        Box::new([self.0, self.1, self.2, self.3])
    }
}
89        Box::new([self.0, self.1, self.2, self.3])
90    }
91}
92
93impl Tensor {
94    pub fn f_view<T: Shape>(&self, s: T) -> Result<Tensor, TchError> {
95        self.f_view_(&s.to_shape())
96    }
97
98    pub fn view<T: Shape>(&self, s: T) -> Tensor {
99        self.view_(&s.to_shape())
100    }
101
102    pub fn f_zero_pad1d(&self, left: i64, right: i64) -> Result<Tensor, TchError> {
103        if self.dim() != 3 {
104            return Err(TchError::Shape(format!(
105                "expected a 3 dimension tensor, got {:?}",
106                self.size()
107            )));
108        }
109        self.f_constant_pad_nd([left, right])
110    }
111
112    pub fn zero_pad1d(&self, left: i64, right: i64) -> Tensor {
113        self.f_zero_pad1d(left, right).unwrap()
114    }
115
116    pub fn f_zero_pad2d(
117        &self,
118        left: i64,
119        right: i64,
120        top: i64,
121        bottom: i64,
122    ) -> Result<Tensor, TchError> {
123        if self.dim() != 4 {
124            return Err(TchError::Shape(format!(
125                "expected a 4 dimension tensor, got {:?}",
126                self.size()
127            )));
128        }
129        self.f_constant_pad_nd([left, right, top, bottom])
130    }
131
132    pub fn zero_pad2d(&self, left: i64, right: i64, top: i64, bottom: i64) -> Tensor {
133        self.f_zero_pad2d(left, right, top, bottom).unwrap()
134    }
135}
136
137impl<T: crate::kind::Element> From<&[T]> for Tensor {
138    fn from(v: &[T]) -> Tensor {
139        Tensor::from_slice(v)
140    }
141}
142
143impl<T: crate::kind::Element> From<T> for Tensor {
144    fn from(v: T) -> Tensor {
145        Tensor::from_slice(&[v]).view(())
146    }
147}
148impl Tensor {
149    /// Casts a tensor to a specified kind.
150    pub fn to_kind(&self, kind: Kind) -> Tensor {
151        self.totype(kind)
152    }
153
154    pub fn f_to_kind(&self, kind: Kind) -> Result<Tensor, TchError> {
155        self.f_totype(kind)
156    }
157
158    pub fn nll_loss(&self, targets: &Tensor) -> Tensor {
159        self.g_nll_loss::<Tensor>(targets, None, Reduction::Mean, -100)
160    }
161}
162
163impl Tensor {
164    /// Computes the cross-entropy loss based on some logits and targets.
165    pub fn cross_entropy_for_logits(&self, targets: &Tensor) -> Tensor {
166        self.log_softmax(-1, Kind::Float).nll_loss(targets)
167    }
168
169    /// Returns the average accuracy for some given logits assuming that
170    /// targets represent ground-truth.
171    pub fn accuracy_for_logits(&self, targets: &Tensor) -> Tensor {
172        self.argmax(-1, false).eq_tensor(targets).to_kind(Kind::Float).mean(Kind::Float)
173    }
174
175    pub fn random_batch(&self, batch_size: i64) -> Tensor {
176        let len: i64 = self.size()[0];
177        let index = Tensor::randint(len, [batch_size], (Kind::Int64, self.device()));
178        self.index_select(0, &index)
179    }
180
181    pub fn random_batch2(t1: &Tensor, t2: &Tensor, batch_size: i64) -> (Tensor, Tensor) {
182        let len1: i64 = t1.size()[0];
183        let len2: i64 = t2.size()[0];
184        if len1 != len2 {
185            panic!("random_batch2: shape mismatch {:?} {:?}", t1.size(), t2.size())
186        }
187        let device1 = t1.device();
188        let device2 = t2.device();
189        if device1 != device2 {
190            panic!("random_batch2: device mismatch {device1:?} {device2:?}")
191        }
192        let index = Tensor::randint(len1, [batch_size], (Kind::Int64, device1));
193        let batch1 = t1.index_select(0, &index);
194        let batch2 = t2.index_select(0, &index);
195        (batch1, batch2)
196    }
197
198    /// Moves a tensor to a specified device.
199    pub fn to_device(&self, device: Device) -> Tensor {
200        self.to(device)
201    }
202
203    pub fn f_to_device(&self, device: Device) -> Result<Tensor, TchError> {
204        self.f_to(device)
205    }
206
207    pub fn avg_pool2d_default(&self, ksize: i64) -> Tensor {
208        self.avg_pool2d([ksize, ksize], [ksize, ksize], [0, 0], false, true, 1)
209    }
210
211    pub fn max_pool2d_default(&self, ksize: i64) -> Tensor {
212        self.max_pool2d([ksize, ksize], [ksize, ksize], [0, 0], [1, 1], false)
213    }
214
215    /// Flattens a tensor.
216    ///
217    /// This returns a flattened version of the given tensor. The first dimension
218    /// is preserved as it is assumed to be the mini-batch dimension.
219    pub fn flat_view(&self) -> Tensor {
220        self.view((self.size()[0], -1))
221    }
222
223    /// Converts a tensor to a one-hot encoded version.
224    ///
225    /// If the input has a size [N1, N2, ..., Nk], the returned tensor has a size
226    /// [N1, ..., Nk, labels]. The returned tensor uses float values.
227    /// Elements of the input vector are expected to be between 0 and labels-1.
228    pub fn onehot(&self, labels: i64) -> Tensor {
229        Tensor::zeros([self.size(), vec![labels]].concat(), (Kind::Float, self.device()))
230            .scatter_value_(-1, &self.unsqueeze(-1).to_kind(Kind::Int64), 1.0)
231    }
232
233    /// Copies a tensor to a newly allocated tensor using the same shape and device.
234    pub fn copy(&self) -> Tensor {
235        let mut result = self.zeros_like();
236        result.copy_(self);
237        result
238    }
239
240    /// Copies the data from a two dimensional slice in a tensor object.
241    pub fn from_slice2<T, U>(v: &[U]) -> Tensor
242    where
243        T: crate::kind::Element,
244        U: AsRef<[T]>,
245    {
246        let inner: Vec<Tensor> = v.iter().map(|v| Tensor::from_slice(v.as_ref())).collect();
247        Tensor::stack(&inner, 0)
248    }
249
250    pub fn to_mkldnn(&self) -> Tensor {
251        self.g_to_mkldnn(self.kind())
252    }
253}