tch_plus/tensor/
convert.rs1use super::Tensor;
3use crate::{kind::Element, TchError};
4use half::{bf16, f16};
5use std::convert::{TryFrom, TryInto};
6
7impl<T: Element + Copy> TryFrom<&Tensor> for Vec<T> {
8 type Error = TchError;
9 fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
10 let size = tensor.size();
11 if size.len() != 1 {
12 Err(TchError::Convert(format!(
13 "Attempting to convert a Tensor with {} dimensions to flat vector",
14 size.len()
15 )))?;
16 }
17 let numel = size[0] as usize;
18 let mut vec = vec![T::ZERO; numel];
19 tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?;
20 Ok(vec)
21 }
22}
23
24impl<T: Element + Copy> TryFrom<&Tensor> for Vec<Vec<T>> {
25 type Error = TchError;
26 fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
27 let (s1, s2) = tensor.size2()?;
28 let s1 = s1 as usize;
29 let s2 = s2 as usize;
30 let num_elem = s1 * s2;
31 let mut all_elems = vec![T::ZERO; num_elem];
33 tensor.f_to_kind(T::KIND)?.f_copy_data(&mut all_elems, num_elem)?;
34 let out = (0..s1).map(|i1| (0..s2).map(|i2| all_elems[i1 * s2 + i2]).collect()).collect();
35 Ok(out)
36 }
37}
38
39impl<T: Element + Copy> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> {
40 type Error = TchError;
41 fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
42 let (s1, s2, s3) = tensor.size3()?;
43 let s1 = s1 as usize;
44 let s2 = s2 as usize;
45 let s3 = s3 as usize;
46 let num_elem = s1 * s2 * s3;
47 let mut all_elems = vec![T::ZERO; num_elem];
49 tensor.f_to_kind(T::KIND)?.f_copy_data(&mut all_elems, num_elem)?;
50 let out = (0..s1)
51 .map(|i1| {
52 (0..s2)
53 .map(|i2| (0..s3).map(|i3| all_elems[i1 * s2 * s3 + i2 * s3 + i3]).collect())
54 .collect()
55 })
56 .collect();
57 Ok(out)
58 }
59}
60
61impl<T: Element + Copy> TryFrom<Tensor> for Vec<T> {
62 type Error = TchError;
63 fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
64 Vec::<T>::try_from(&tensor)
65 }
66}
67
68impl<T: Element + Copy> TryFrom<Tensor> for Vec<Vec<T>> {
69 type Error = TchError;
70 fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
71 Vec::<Vec<T>>::try_from(&tensor)
72 }
73}
74
75impl<T: Element + Copy> TryFrom<Tensor> for Vec<Vec<Vec<T>>> {
76 type Error = TchError;
77 fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
78 Vec::<Vec<Vec<T>>>::try_from(&tensor)
79 }
80}
81
/// Generates `TryFrom<&Tensor>` and `TryFrom<Tensor>` impls that extract a single
/// scalar of type `$typ` from a tensor holding exactly one element.
macro_rules! from_tensor {
    ($typ:ident) => {
        impl TryFrom<&Tensor> for $typ {
            type Error = TchError;

            /// Moves the tensor to the CPU, casts it to `$typ`'s kind and reads its
            /// single element. Fails with `TchError::Convert` unless `numel() == 1`.
            fn try_from(tensor: &Tensor) -> Result<Self, Self::Error> {
                let count = tensor.numel();
                if count == 1 {
                    let on_cpu = tensor.f_to_device(crate::Device::Cpu)?;
                    let converted = on_cpu.f_to_kind($typ::KIND)?;
                    let mut slot = [$typ::ZERO];
                    converted.f_copy_data(&mut slot, count)?;
                    Ok(slot[0])
                } else {
                    Err(TchError::Convert(format!(
                        "expected exactly one element, got {}",
                        count
                    )))
                }
            }
        }

        impl TryFrom<Tensor> for $typ {
            type Error = TchError;

            /// Owned-tensor variant: delegates to the borrowed conversion.
            fn try_from(tensor: Tensor) -> Result<Self, Self::Error> {
                <$typ as TryFrom<&Tensor>>::try_from(&tensor)
            }
        }
    };
}
113
114from_tensor!(f64);
115from_tensor!(f32);
116from_tensor!(f16);
117from_tensor!(i64);
118from_tensor!(i32);
119from_tensor!(i16);
120from_tensor!(i8);
121from_tensor!(u8);
122from_tensor!(bool);
123from_tensor!(bf16);
124
125impl<T: Element + Copy> TryInto<ndarray::ArrayD<T>> for &Tensor {
126 type Error = TchError;
127
128 fn try_into(self) -> Result<ndarray::ArrayD<T>, Self::Error> {
129 let num_elem = self.numel();
130 let mut vec = vec![T::ZERO; num_elem];
131 self.f_to_kind(T::KIND)?.f_copy_data(&mut vec, num_elem)?;
132 let shape: Vec<usize> = self.size().iter().map(|s| *s as usize).collect();
133 Ok(ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), vec)?)
134 }
135}
136
137impl<T, D> TryFrom<&ndarray::ArrayBase<T, D>> for Tensor
138where
139 T: ndarray::Data,
140 T::Elem: Element,
141 D: ndarray::Dimension,
142{
143 type Error = TchError;
144
145 fn try_from(value: &ndarray::ArrayBase<T, D>) -> Result<Self, Self::Error> {
146 let slice = value
147 .as_slice()
148 .ok_or_else(|| TchError::Convert("cannot convert to slice".to_string()))?;
149 let tn = Self::f_from_slice(slice)?;
150 let shape: Vec<i64> = value.shape().iter().map(|s| *s as i64).collect();
151 tn.f_reshape(shape)
152 }
153}
154
155impl<T, D> TryFrom<ndarray::ArrayBase<T, D>> for Tensor
156where
157 T: ndarray::Data,
158 T::Elem: Element,
159 D: ndarray::Dimension,
160{
161 type Error = TchError;
162
163 fn try_from(value: ndarray::ArrayBase<T, D>) -> Result<Self, Self::Error> {
164 Self::try_from(&value)
165 }
166}
167
168impl<T: Element> TryFrom<&Vec<T>> for Tensor {
169 type Error = TchError;
170
171 fn try_from(value: &Vec<T>) -> Result<Self, Self::Error> {
172 Self::f_from_slice(value.as_slice())
173 }
174}
175
176impl<T: Element> TryFrom<Vec<T>> for Tensor {
177 type Error = TchError;
178
179 fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
180 Self::try_from(&value)
181 }
182}