tch_tensor_like/
lib.rs

#[cfg(feature = "derive")]
pub use tch_tensor_like_derive::TensorLike;

use std::{
    collections::{BTreeMap, HashMap, LinkedList, VecDeque},
    hash::{BuildHasher, Hash},
};
use tch::{Device, Kind, TchError, Tensor};
9
10pub trait TensorLike
11where
12    Self: Sized,
13{
14    fn f_to_device(&self, device: Device) -> Result<Self, TchError>;
15    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError>;
16    fn shallow_clone(&self) -> Self;
17
18    fn to_device(&self, device: Device) -> Self {
19        self.f_to_device(device).unwrap()
20    }
21
22    fn to_kind(&self, kind: Kind) -> Self {
23        self.f_to_kind(kind).unwrap()
24    }
25}
26
27// primitives
28
// Implements `TensorLike` for a `Copy` type that holds no tensors. The requested device or kind
// is ignored and every method hands back a copy of the value.
macro_rules! impl_for_primitive {
    ($prim:ty) => {
        impl TensorLike for $prim {
            fn shallow_clone(&self) -> Self {
                *self
            }

            fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
                Ok(self.shallow_clone())
            }

            fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
                Ok(self.shallow_clone())
            }
        }
    };
}
46
47impl_for_primitive!(bool);
48impl_for_primitive!(f32);
49impl_for_primitive!(f64);
50impl_for_primitive!(usize);
51impl_for_primitive!(u8);
52impl_for_primitive!(u16);
53impl_for_primitive!(u32);
54impl_for_primitive!(u64);
55impl_for_primitive!(u128);
56impl_for_primitive!(isize);
57impl_for_primitive!(i8);
58impl_for_primitive!(i16);
59impl_for_primitive!(i32);
60impl_for_primitive!(i64);
61impl_for_primitive!(i128);
62
63// reference
64
65impl<T> TensorLike for &T {
66    fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
67        Ok(*self)
68    }
69
70    fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
71        Ok(*self)
72    }
73
74    fn shallow_clone(&self) -> Self {
75        *self
76    }
77}
78
79// pointer
80
81impl<T> TensorLike for *const T {
82    fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
83        Ok(*self)
84    }
85
86    fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
87        Ok(*self)
88    }
89
90    fn shallow_clone(&self) -> Self {
91        *self
92    }
93}
94
95impl<T> TensorLike for *mut T {
96    fn f_to_device(&self, _device: Device) -> Result<Self, TchError> {
97        Ok(*self)
98    }
99
100    fn f_to_kind(&self, _kind: Kind) -> Result<Self, TchError> {
101        Ok(*self)
102    }
103
104    fn shallow_clone(&self) -> Self {
105        *self
106    }
107}
108
109// tuples
110
111impl<T1> TensorLike for (T1,)
112where
113    T1: TensorLike,
114{
115    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
116        Ok((self.0.f_to_device(device)?,))
117    }
118
119    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
120        Ok((self.0.f_to_kind(kind)?,))
121    }
122
123    fn shallow_clone(&self) -> Self {
124        (self.0.shallow_clone(),)
125    }
126}
127
128impl<T1, T2> TensorLike for (T1, T2)
129where
130    T1: TensorLike,
131    T2: TensorLike,
132{
133    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
134        Ok((self.0.f_to_device(device)?, self.1.f_to_device(device)?))
135    }
136
137    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
138        Ok((self.0.f_to_kind(kind)?, self.1.f_to_kind(kind)?))
139    }
140
141    fn shallow_clone(&self) -> Self {
142        (self.0.shallow_clone(), self.1.shallow_clone())
143    }
144}
145
146impl<T1, T2, T3> TensorLike for (T1, T2, T3)
147where
148    T1: TensorLike,
149    T2: TensorLike,
150    T3: TensorLike,
151{
152    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
153        Ok((
154            self.0.f_to_device(device)?,
155            self.1.f_to_device(device)?,
156            self.2.f_to_device(device)?,
157        ))
158    }
159
160    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
161        Ok((
162            self.0.f_to_kind(kind)?,
163            self.1.f_to_kind(kind)?,
164            self.2.f_to_kind(kind)?,
165        ))
166    }
167
168    fn shallow_clone(&self) -> Self {
169        (
170            self.0.shallow_clone(),
171            self.1.shallow_clone(),
172            self.2.shallow_clone(),
173        )
174    }
175}
176
177impl<T1, T2, T3, T4> TensorLike for (T1, T2, T3, T4)
178where
179    T1: TensorLike,
180    T2: TensorLike,
181    T3: TensorLike,
182    T4: TensorLike,
183{
184    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
185        Ok((
186            self.0.f_to_device(device)?,
187            self.1.f_to_device(device)?,
188            self.2.f_to_device(device)?,
189            self.3.f_to_device(device)?,
190        ))
191    }
192
193    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
194        Ok((
195            self.0.f_to_kind(kind)?,
196            self.1.f_to_kind(kind)?,
197            self.2.f_to_kind(kind)?,
198            self.3.f_to_kind(kind)?,
199        ))
200    }
201
202    fn shallow_clone(&self) -> Self {
203        (
204            self.0.shallow_clone(),
205            self.1.shallow_clone(),
206            self.2.shallow_clone(),
207            self.3.shallow_clone(),
208        )
209    }
210}
211
212impl<T1, T2, T3, T4, T5> TensorLike for (T1, T2, T3, T4, T5)
213where
214    T1: TensorLike,
215    T2: TensorLike,
216    T3: TensorLike,
217    T4: TensorLike,
218    T5: TensorLike,
219{
220    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
221        Ok((
222            self.0.f_to_device(device)?,
223            self.1.f_to_device(device)?,
224            self.2.f_to_device(device)?,
225            self.3.f_to_device(device)?,
226            self.4.f_to_device(device)?,
227        ))
228    }
229
230    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
231        Ok((
232            self.0.f_to_kind(kind)?,
233            self.1.f_to_kind(kind)?,
234            self.2.f_to_kind(kind)?,
235            self.3.f_to_kind(kind)?,
236            self.4.f_to_kind(kind)?,
237        ))
238    }
239
240    fn shallow_clone(&self) -> Self {
241        (
242            self.0.shallow_clone(),
243            self.1.shallow_clone(),
244            self.2.shallow_clone(),
245            self.3.shallow_clone(),
246            self.4.shallow_clone(),
247        )
248    }
249}
250
251// tensor
252
253impl TensorLike for Tensor {
254    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
255        self.f_to_device(device)
256    }
257
258    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
259        self.f_to_kind(kind)
260    }
261
262    fn shallow_clone(&self) -> Self {
263        self.shallow_clone()
264    }
265}
266
267// collections
268
269impl<T> TensorLike for Vec<T>
270where
271    T: TensorLike,
272{
273    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
274        self.iter()
275            .map(|tensor| tensor.f_to_device(device))
276            .collect()
277    }
278
279    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
280        self.iter().map(|tensor| tensor.f_to_kind(kind)).collect()
281    }
282
283    fn shallow_clone(&self) -> Self {
284        self.iter().map(|tensor| tensor.shallow_clone()).collect()
285    }
286}
287
288impl<T> TensorLike for LinkedList<T>
289where
290    T: TensorLike,
291{
292    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
293        self.iter()
294            .map(|tensor| tensor.f_to_device(device))
295            .collect()
296    }
297
298    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
299        self.iter().map(|tensor| tensor.f_to_kind(kind)).collect()
300    }
301
302    fn shallow_clone(&self) -> Self {
303        self.iter().map(|tensor| tensor.shallow_clone()).collect()
304    }
305}
306
307impl<T> TensorLike for VecDeque<T>
308where
309    T: TensorLike,
310{
311    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
312        self.iter()
313            .map(|tensor| tensor.f_to_device(device))
314            .collect()
315    }
316
317    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
318        self.iter().map(|tensor| tensor.f_to_kind(kind)).collect()
319    }
320
321    fn shallow_clone(&self) -> Self {
322        self.iter().map(|tensor| tensor.shallow_clone()).collect()
323    }
324}
325
326impl<K, T> TensorLike for HashMap<K, T>
327where
328    K: Eq + Hash + Clone,
329    T: TensorLike,
330{
331    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
332        self.iter()
333            .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_device(device)?)))
334            .collect()
335    }
336
337    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
338        self.iter()
339            .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_kind(kind)?)))
340            .collect()
341    }
342
343    fn shallow_clone(&self) -> Self {
344        self.iter()
345            .map(|(key, tensor)| (key.clone(), tensor.shallow_clone()))
346            .collect()
347    }
348}
349
350impl<K, T> TensorLike for BTreeMap<K, T>
351where
352    K: Ord + Clone,
353    T: TensorLike,
354{
355    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
356        self.iter()
357            .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_device(device)?)))
358            .collect()
359    }
360
361    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
362        self.iter()
363            .map(|(key, tensor)| Ok((key.clone(), tensor.f_to_kind(kind)?)))
364            .collect()
365    }
366
367    fn shallow_clone(&self) -> Self {
368        self.iter()
369            .map(|(key, tensor)| (key.clone(), tensor.shallow_clone()))
370            .collect()
371    }
372}
373
374// option
375
376impl<T> TensorLike for Option<T>
377where
378    T: TensorLike,
379{
380    fn f_to_device(&self, device: Device) -> Result<Self, TchError> {
381        self.as_ref()
382            .map(|tensor| tensor.f_to_device(device))
383            .transpose()
384    }
385
386    fn f_to_kind(&self, kind: Kind) -> Result<Self, TchError> {
387        self.as_ref()
388            .map(|tensor| tensor.f_to_kind(kind))
389            .transpose()
390    }
391
392    fn shallow_clone(&self) -> Self {
393        self.as_ref().map(|tensor| tensor.shallow_clone())
394    }
395}