1#[cfg(feature = "derive")]
2pub use tch_tensor_like_derive::TensorLike;
3
use std::{
    collections::{BTreeMap, HashMap, LinkedList, VecDeque},
    hash::{BuildHasher, Hash},
};
8use tch::{Device, Kind, TchError, Tensor};
9
10pub trait TensorLike
11where
12 Self: Sized,
13{
14 fn f_to_device(&self, device: Device) -> Result<Self, TchError>;
15 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError>;
16 fn shallow_clone(&self) -> Self;
17
18 fn to_device(&self, device: Device) -> Self {
19 self.f_to_device(device).unwrap()
20 }
21
22 fn to_kind(&self, kind: Kind) -> Self {
23 self.f_to_kind(kind).unwrap()
24 }
25}
26
27macro_rules! impl_for_primitive {
30 ($name:ty) => {
31 impl TensorLike for $name {
32 fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
33 Ok(*self)
34 }
35
36 fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
37 Ok(*self)
38 }
39
40 fn shallow_clone(&self) -> Self {
41 *self
42 }
43 }
44 };
45}
46
47impl_for_primitive!(bool);
48impl_for_primitive!(f32);
49impl_for_primitive!(f64);
50impl_for_primitive!(usize);
51impl_for_primitive!(u8);
52impl_for_primitive!(u16);
53impl_for_primitive!(u32);
54impl_for_primitive!(u64);
55impl_for_primitive!(u128);
56impl_for_primitive!(isize);
57impl_for_primitive!(i8);
58impl_for_primitive!(i16);
59impl_for_primitive!(i32);
60impl_for_primitive!(i64);
61impl_for_primitive!(i128);
62
63impl<T> TensorLike for &T {
66 fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
67 Ok(*self)
68 }
69
70 fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
71 Ok(*self)
72 }
73
74 fn shallow_clone(&self) -> Self {
75 *self
76 }
77}
78
79impl<T> TensorLike for *const T {
82 fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
83 Ok(*self)
84 }
85
86 fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
87 Ok(*self)
88 }
89
90 fn shallow_clone(&self) -> Self {
91 *self
92 }
93}
94
95impl<T> TensorLike for *mut T {
96 fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
97 Ok(*self)
98 }
99
100 fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
101 Ok(*self)
102 }
103
104 fn shallow_clone(&self) -> Self {
105 *self
106 }
107}
108
109impl<T1> TensorLike for (T1,)
112where
113 T1: TensorLike,
114{
115 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
116 Ok((self.0.f_to_device(device)?,))
117 }
118
119 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
120 Ok((self.0.f_to_kind(kind)?,))
121 }
122
123 fn shallow_clone(&self) -> Self {
124 (self.0.shallow_clone(),)
125 }
126}
127
128impl<T1, T2> TensorLike for (T1, T2)
129where
130 T1: TensorLike,
131 T2: TensorLike,
132{
133 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
134 Ok((self.0.f_to_device(device)?, self.1.f_to_device(device)?))
135 }
136
137 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
138 Ok((self.0.f_to_kind(kind)?, self.1.f_to_kind(kind)?))
139 }
140
141 fn shallow_clone(&self) -> Self {
142 (self.0.shallow_clone(), self.1.shallow_clone())
143 }
144}
145
146impl<T1, T2, T3> TensorLike for (T1, T2, T3)
147where
148 T1: TensorLike,
149 T2: TensorLike,
150 T3: TensorLike,
151{
152 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
153 Ok((
154 self.0.f_to_device(device)?,
155 self.1.f_to_device(device)?,
156 self.2.f_to_device(device)?,
157 ))
158 }
159
160 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
161 Ok((
162 self.0.f_to_kind(kind)?,
163 self.1.f_to_kind(kind)?,
164 self.2.f_to_kind(kind)?,
165 ))
166 }
167
168 fn shallow_clone(&self) -> Self {
169 (
170 self.0.shallow_clone(),
171 self.1.shallow_clone(),
172 self.2.shallow_clone(),
173 )
174 }
175}
176
177impl<T1, T2, T3, T4> TensorLike for (T1, T2, T3, T4)
178where
179 T1: TensorLike,
180 T2: TensorLike,
181 T3: TensorLike,
182 T4: TensorLike,
183{
184 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
185 Ok((
186 self.0.f_to_device(device)?,
187 self.1.f_to_device(device)?,
188 self.2.f_to_device(device)?,
189 self.3.f_to_device(device)?,
190 ))
191 }
192
193 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
194 Ok((
195 self.0.f_to_kind(kind)?,
196 self.1.f_to_kind(kind)?,
197 self.2.f_to_kind(kind)?,
198 self.3.f_to_kind(kind)?,
199 ))
200 }
201
202 fn shallow_clone(&self) -> Self {
203 (
204 self.0.shallow_clone(),
205 self.1.shallow_clone(),
206 self.2.shallow_clone(),
207 self.3.shallow_clone(),
208 )
209 }
210}
211
212impl<T1, T2, T3, T4, T5> TensorLike for (T1, T2, T3, T4, T5)
213where
214 T1: TensorLike,
215 T2: TensorLike,
216 T3: TensorLike,
217 T4: TensorLike,
218 T5: TensorLike,
219{
220 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
221 Ok((
222 self.0.f_to_device(device)?,
223 self.1.f_to_device(device)?,
224 self.2.f_to_device(device)?,
225 self.3.f_to_device(device)?,
226 self.4.f_to_device(device)?,
227 ))
228 }
229
230 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
231 Ok((
232 self.0.f_to_kind(kind)?,
233 self.1.f_to_kind(kind)?,
234 self.2.f_to_kind(kind)?,
235 self.3.f_to_kind(kind)?,
236 self.4.f_to_kind(kind)?,
237 ))
238 }
239
240 fn shallow_clone(&self) -> Self {
241 (
242 self.0.shallow_clone(),
243 self.1.shallow_clone(),
244 self.2.shallow_clone(),
245 self.3.shallow_clone(),
246 self.4.shallow_clone(),
247 )
248 }
249}
250
251impl TensorLike for Tensor {
254 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
255 self.f_to_device(device)
256 }
257
258 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
259 self.f_to_kind(kind)
260 }
261
262 fn shallow_clone(&self) -> Self {
263 self.shallow_clone()
264 }
265}
266
267impl<T> TensorLike for Vec<T>
270where
271 T: TensorLike,
272{
273 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
274 self.iter()
275 .map(|tensor| tensor.f_to_device(device))
276 .collect()
277 }
278
279 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
280 self.iter().map(|tensor| tensor.f_to_kind(kind)).collect()
281 }
282
283 fn shallow_clone(&self) -> Self {
284 self.iter().map(|tensor| tensor.shallow_clone()).collect()
285 }
286}
287
288impl<T> TensorLike for LinkedList<T>
289where
290 T: TensorLike,
291{
292 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
293 self.iter()
294 .map(|tensor| tensor.f_to_device(device))
295 .collect()
296 }
297
298 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
299 self.iter().map(|tensor| tensor.f_to_kind(kind)).collect()
300 }
301
302 fn shallow_clone(&self) -> Self {
303 self.iter().map(|tensor| tensor.shallow_clone()).collect()
304 }
305}
306
307impl<T> TensorLike for VecDeque<T>
308where
309 T: TensorLike,
310{
311 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
312 self.iter()
313 .map(|tensor| tensor.f_to_device(device))
314 .collect()
315 }
316
317 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
318 self.iter().map(|tensor| tensor.f_to_kind(kind)).collect()
319 }
320
321 fn shallow_clone(&self) -> Self {
322 self.iter().map(|tensor| tensor.shallow_clone()).collect()
323 }
324}
325
326impl<K, T> TensorLike for HashMap<K, T>
327where
328 K: Eq + Hash + Clone,
329 T: TensorLike,
330{
331 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
332 self.iter()
333 .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_device(device)?)))
334 .collect()
335 }
336
337 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
338 self.iter()
339 .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_kind(kind)?)))
340 .collect()
341 }
342
343 fn shallow_clone(&self) -> Self {
344 self.iter()
345 .map(|(key, tensor)| (key.clone(), tensor.shallow_clone()))
346 .collect()
347 }
348}
349
350impl<K, T> TensorLike for BTreeMap<K, T>
351where
352 K: Ord + Clone,
353 T: TensorLike,
354{
355 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
356 self.iter()
357 .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_device(device)?)))
358 .collect()
359 }
360
361 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
362 self.iter()
363 .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_kind(kind)?)))
364 .collect()
365 }
366
367 fn shallow_clone(&self) -> Self {
368 self.iter()
369 .map(|(key, tensor)| (key.clone(), tensor.shallow_clone()))
370 .collect()
371 }
372}
373
374impl<T> TensorLike for Option<T>
377where
378 T: TensorLike,
379{
380 fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
381 self.as_ref()
382 .map(|tensor| tensor.f_to_device(device))
383 .transpose()
384 }
385
386 fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
387 self.as_ref()
388 .map(|tensor| tensor.f_to_kind(kind))
389 .transpose()
390 }
391
392 fn shallow_clone(&self) -> Self {
393 self.as_ref().map(|tensor| tensor.shallow_clone())
394 }
395}