1use crate::{Device, Kind, TchError};
3
4mod convert;
5pub mod display;
6pub mod index;
7mod iter;
8mod npy;
9mod ops;
10mod safetensors;
11
12pub use super::wrappers::tensor::{
13 autocast, no_grad, no_grad_guard, with_grad, NoGradGuard, Reduction, Tensor,
14};
15pub use index::{IndexOp, NewAxis, TensorIndexer};
16
/// A trait for values that can be interpreted as a tensor shape, i.e. a list
/// of dimension sizes. Implemented for integers, tuples, arrays, and slices
/// of `i64` so that shape arguments can be passed in whichever form is most
/// convenient.
pub trait Shape {
    /// Returns the dimension sizes described by this value.
    fn to_shape(&self) -> Box<[i64]>;
}
20
21macro_rules! impl_shape {
22 ($v:expr) => {
23 impl Shape for [i64; $v] {
24 fn to_shape(&self) -> Box<[i64]> {
25 Box::new(*self)
26 }
27 }
28 };
29}
30
31impl_shape!(0);
32impl_shape!(1);
33impl_shape!(2);
34impl_shape!(3);
35impl_shape!(4);
36impl_shape!(5);
37impl_shape!(6);
38
39impl Shape for () {
40 fn to_shape(&self) -> Box<[i64]> {
41 Box::new([])
42 }
43}
44
45impl Shape for &[i64] {
46 fn to_shape(&self) -> Box<[i64]> {
47 (*self).into()
48 }
49}
50
51impl Shape for i64 {
52 fn to_shape(&self) -> Box<[i64]> {
53 Box::new([*self])
54 }
55}
56
57impl Shape for usize {
58 fn to_shape(&self) -> Box<[i64]> {
59 Box::new([*self as i64])
60 }
61}
62
63impl Shape for i32 {
64 fn to_shape(&self) -> Box<[i64]> {
65 Box::new([i64::from(*self)])
66 }
67}
68
69impl Shape for (i64,) {
70 fn to_shape(&self) -> Box<[i64]> {
71 Box::new([self.0])
72 }
73}
74
75impl Shape for (i64, i64) {
76 fn to_shape(&self) -> Box<[i64]> {
77 Box::new([self.0, self.1])
78 }
79}
80
81impl Shape for (i64, i64, i64) {
82 fn to_shape(&self) -> Box<[i64]> {
83 Box::new([self.0, self.1, self.2])
84 }
85}
86
87impl Shape for (i64, i64, i64, i64) {
88 fn to_shape(&self) -> Box<[i64]> {
89 Box::new([self.0, self.1, self.2, self.3])
90 }
91}
92
93impl Tensor {
94 pub fn f_view<T: Shape>(&self, s: T) -> Result<Tensor, TchError> {
95 self.f_view_(&s.to_shape())
96 }
97
98 pub fn view<T: Shape>(&self, s: T) -> Tensor {
99 self.view_(&s.to_shape())
100 }
101
102 pub fn f_zero_pad1d(&self, left: i64, right: i64) -> Result<Tensor, TchError> {
103 if self.dim() != 3 {
104 return Err(TchError::Shape(format!(
105 "expected a 3 dimension tensor, got {:?}",
106 self.size()
107 )));
108 }
109 self.f_constant_pad_nd([left, right])
110 }
111
112 pub fn zero_pad1d(&self, left: i64, right: i64) -> Tensor {
113 self.f_zero_pad1d(left, right).unwrap()
114 }
115
116 pub fn f_zero_pad2d(
117 &self,
118 left: i64,
119 right: i64,
120 top: i64,
121 bottom: i64,
122 ) -> Result<Tensor, TchError> {
123 if self.dim() != 4 {
124 return Err(TchError::Shape(format!(
125 "expected a 4 dimension tensor, got {:?}",
126 self.size()
127 )));
128 }
129 self.f_constant_pad_nd([left, right, top, bottom])
130 }
131
132 pub fn zero_pad2d(&self, left: i64, right: i64, top: i64, bottom: i64) -> Tensor {
133 self.f_zero_pad2d(left, right, top, bottom).unwrap()
134 }
135}
136
/// Creates a one-dimensional tensor holding a copy of the slice's values.
impl<T: crate::kind::Element> From<&[T]> for Tensor {
    fn from(v: &[T]) -> Tensor {
        Tensor::from_slice(v)
    }
}
142
/// Creates a zero-dimensional (scalar) tensor from a single value.
impl<T: crate::kind::Element> From<T> for Tensor {
    fn from(v: T) -> Tensor {
        // Build a one-element tensor, then view it with the empty shape `()`
        // to obtain a 0-dim scalar tensor.
        Tensor::from_slice(&[v]).view(())
    }
}
148impl Tensor {
149 pub fn to_kind(&self, kind: Kind) -> Tensor {
151 self.totype(kind)
152 }
153
154 pub fn f_to_kind(&self, kind: Kind) -> Result<Tensor, TchError> {
155 self.f_totype(kind)
156 }
157
158 pub fn nll_loss(&self, targets: &Tensor) -> Tensor {
159 self.g_nll_loss::<Tensor>(targets, None, Reduction::Mean, -100)
160 }
161}
162
163impl Tensor {
164 pub fn cross_entropy_for_logits(&self, targets: &Tensor) -> Tensor {
166 self.log_softmax(-1, Kind::Float).nll_loss(targets)
167 }
168
169 pub fn accuracy_for_logits(&self, targets: &Tensor) -> Tensor {
172 self.argmax(-1, false).eq_tensor(targets).to_kind(Kind::Float).mean(Kind::Float)
173 }
174
175 pub fn random_batch(&self, batch_size: i64) -> Tensor {
176 let len: i64 = self.size()[0];
177 let index = Tensor::randint(len, [batch_size], (Kind::Int64, self.device()));
178 self.index_select(0, &index)
179 }
180
181 pub fn random_batch2(t1: &Tensor, t2: &Tensor, batch_size: i64) -> (Tensor, Tensor) {
182 let len1: i64 = t1.size()[0];
183 let len2: i64 = t2.size()[0];
184 if len1 != len2 {
185 panic!("random_batch2: shape mismatch {:?} {:?}", t1.size(), t2.size())
186 }
187 let device1 = t1.device();
188 let device2 = t2.device();
189 if device1 != device2 {
190 panic!("random_batch2: device mismatch {device1:?} {device2:?}")
191 }
192 let index = Tensor::randint(len1, [batch_size], (Kind::Int64, device1));
193 let batch1 = t1.index_select(0, &index);
194 let batch2 = t2.index_select(0, &index);
195 (batch1, batch2)
196 }
197
198 pub fn to_device(&self, device: Device) -> Tensor {
200 self.to(device)
201 }
202
203 pub fn f_to_device(&self, device: Device) -> Result<Tensor, TchError> {
204 self.f_to(device)
205 }
206
207 pub fn avg_pool2d_default(&self, ksize: i64) -> Tensor {
208 self.avg_pool2d([ksize, ksize], [ksize, ksize], [0, 0], false, true, 1)
209 }
210
211 pub fn max_pool2d_default(&self, ksize: i64) -> Tensor {
212 self.max_pool2d([ksize, ksize], [ksize, ksize], [0, 0], [1, 1], false)
213 }
214
215 pub fn flat_view(&self) -> Tensor {
220 self.view((self.size()[0], -1))
221 }
222
223 pub fn onehot(&self, labels: i64) -> Tensor {
229 Tensor::zeros([self.size(), vec![labels]].concat(), (Kind::Float, self.device()))
230 .scatter_value_(-1, &self.unsqueeze(-1).to_kind(Kind::Int64), 1.0)
231 }
232
233 pub fn copy(&self) -> Tensor {
235 let mut result = self.zeros_like();
236 result.copy_(self);
237 result
238 }
239
240 pub fn from_slice2<T, U>(v: &[U]) -> Tensor
242 where
243 T: crate::kind::Element,
244 U: AsRef<[T]>,
245 {
246 let inner: Vec<Tensor> = v.iter().map(|v| Tensor::from_slice(v.as_ref())).collect();
247 Tensor::stack(&inner, 0)
248 }
249
250 pub fn to_mkldnn(&self) -> Tensor {
251 self.g_to_mkldnn(self.kind())
252 }
253}