tch_plus/tensor/convert.rs

1//! Implement conversion traits for tensors
2use super::Tensor;
3use crate::{kind::Element, TchError};
4use half::{bf16, f16};
5use std::convert::{TryFrom, TryInto};
6
7impl<T: Element + Copy> TryFrom<&Tensor> for Vec<T> {
8    type Error = TchError;
9    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
10        let size = tensor.size();
11        if size.len() != 1 {
12            Err(TchError::Convert(format!(
13                "Attempting to convert a Tensor with {} dimensions to flat vector",
14                size.len()
15            )))?;
16        }
17        let numel = size[0] as usize;
18        let mut vec = vec![T::ZERO; numel];
19        tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?;
20        Ok(vec)
21    }
22}
23
24impl<T: Element + Copy> TryFrom<&Tensor> for Vec<Vec<T>> {
25    type Error = TchError;
26    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
27        let (s1, s2) = tensor.size2()?;
28        let s1 = s1 as usize;
29        let s2 = s2 as usize;
30        let num_elem = s1 * s2;
31        // TODO: Try to remove this intermediary copy.
32        let mut all_elems = vec![T::ZERO; num_elem];
33        tensor.f_to_kind(T::KIND)?.f_copy_data(&mut all_elems, num_elem)?;
34        let out = (0..s1).map(|i1| (0..s2).map(|i2| all_elems[i1 * s2 + i2]).collect()).collect();
35        Ok(out)
36    }
37}
38
39impl<T: Element + Copy> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
40    type Error = TchError;
41    fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
42        let (s1, s2, s3) = tensor.size3()?;
43        let s1 = s1 as usize;
44        let s2 = s2 as usize;
45        let s3 = s3 as usize;
46        let num_elem = s1 * s2 * s3;
47        // TODO: Try to remove this intermediary copy.
48        let mut all_elems = vec![T::ZERO; num_elem];
49        tensor.f_to_kind(T::KIND)?.f_copy_data(&mut all_elems, num_elem)?;
50        let out = (0..s1)
51            .map(|i1| {
52                (0..s2)
53                    .map(|i2| (0..s3).map(|i3| all_elems[i1 * s2 * s3 + i2 * s3 + i3]).collect())
54                    .collect()
55            })
56            .collect();
57        Ok(out)
58    }
59}
60
61impl<T: Element + Copy> TryFrom<Tensor> for Vec<T> {
62    type Error = TchError;
63    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
64        Vec::<T>::try_from(&tensor)
65    }
66}
67
68impl<T: Element + Copy> TryFrom<Tensor> for Vec<Vec<T>> {
69    type Error = TchError;
70    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
71        Vec::<Vec<T>>::try_from(&tensor)
72    }
73}
74
75impl<T: Element + Copy> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
76    type Error = TchError;
77    fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
78        Vec::<Vec<Vec<T>>>::try_from(&tensor)
79    }
80}
81
82macro_rules! from_tensor {
83    ($typ:ident) => {
84        impl TryFrom<&Tensor> for $typ {
85            type Error = TchError;
86
87            fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
88                let numel = tensor.numel();
89                if numel != 1 {
90                    return Err(TchError::Convert(format!(
91                        "expected exactly one element, got {}",
92                        numel
93                    )));
94                }
95                let mut vec = [$typ::ZERO; 1];
96                tensor
97                    .f_to_device(crate::Device::Cpu)?
98                    .f_to_kind($typ::KIND)?
99                    .f_copy_data(&mut vec, numel)?;
100                Ok(vec[0])
101            }
102        }
103
104        impl TryFrom<Tensor> for $typ {
105            type Error = TchError;
106
107            fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
108                $typ::try_from(&tensor)
109            }
110        }
111    };
112}
113
114from_tensor!(f64);
115from_tensor!(f32);
116from_tensor!(f16);
117from_tensor!(i64);
118from_tensor!(i32);
119from_tensor!(i16);
120from_tensor!(i8);
121from_tensor!(u8);
122from_tensor!(bool);
123from_tensor!(bf16);
124
125impl<T: Element + Copy> TryInto<ndarray::ArrayD<T>> for &Tensor {
126    type Error = TchError;
127
128    fn try_into(self) -> Result<ndarray::ArrayD<T>, Self::Error> {
129        let num_elem = self.numel();
130        let mut vec = vec![T::ZERO; num_elem];
131        self.f_to_kind(T::KIND)?.f_copy_data(&mut vec, num_elem)?;
132        let shape: Vec<usize> = self.size().iter().map(|s| *s as usize).collect();
133        Ok(ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), vec)?)
134    }
135}
136
137impl<T, D> TryFrom<&ndarray::ArrayBase<T, D>> for Tensor
138where
139    T: ndarray::Data,
140    T::Elem: Element,
141    D: ndarray::Dimension,
142{
143    type Error = TchError;
144
145    fn try_from(value: &ndarray::ArrayBase<T, D>) -> Result<Self, Self::Error> {
146        let slice = value
147            .as_slice()
148            .ok_or_else(|| TchError::Convert("cannot convert to slice".to_string()))?;
149        let tn = Self::f_from_slice(slice)?;
150        let shape: Vec<i64> = value.shape().iter().map(|s| *s as i64).collect();
151        tn.f_reshape(shape)
152    }
153}
154
155impl<T, D> TryFrom<ndarray::ArrayBase<T, D>> for Tensor
156where
157    T: ndarray::Data,
158    T::Elem: Element,
159    D: ndarray::Dimension,
160{
161    type Error = TchError;
162
163    fn try_from(value: ndarray::ArrayBase<T, D>) -> Result<Self, Self::Error> {
164        Self::try_from(&value)
165    }
166}
167
168impl<T: Element> TryFrom<&Vec<T>> for Tensor {
169    type Error = TchError;
170
171    fn try_from(value: &Vec<T>) -> Result<Self, Self::Error> {
172        Self::f_from_slice(value.as_slice())
173    }
174}
175
176impl<T: Element> TryFrom<Vec<T>> for Tensor {
177    type Error = TchError;
178
179    fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
180        Self::try_from(&value)
181    }
182}