trustformers_core/tensor/
conversions.rs1use super::{DType, Tensor};
7use crate::errors::{Result, TrustformersError};
8use scirs2_core::{Complex32, Complex64};
9
10impl Tensor {
11 pub fn to_dtype(&self, dtype: DType) -> Result<Tensor> {
21 match (self, dtype) {
22 (Tensor::F32(a), DType::F64) => {
23 let result = a.mapv(|x| x as f64);
24 Ok(Tensor::F64(result))
25 },
26 (Tensor::F32(a), DType::I64) => {
27 let result = a.mapv(|x| x as i64);
28 Ok(Tensor::I64(result))
29 },
30 (Tensor::F32(a), DType::C32) => {
31 let result = a.mapv(|x| Complex32::new(x, 0.0));
32 Ok(Tensor::C32(result))
33 },
34 (Tensor::F32(a), DType::C64) => {
35 let result = a.mapv(|x| Complex64::new(x as f64, 0.0));
36 Ok(Tensor::C64(result))
37 },
38 (Tensor::F64(a), DType::F32) => {
39 let result = a.mapv(|x| x as f32);
40 Ok(Tensor::F32(result))
41 },
42 (Tensor::F64(a), DType::I64) => {
43 let result = a.mapv(|x| x as i64);
44 Ok(Tensor::I64(result))
45 },
46 (Tensor::F64(a), DType::C32) => {
47 let result = a.mapv(|x| Complex32::new(x as f32, 0.0));
48 Ok(Tensor::C32(result))
49 },
50 (Tensor::F64(a), DType::C64) => {
51 let result = a.mapv(|x| Complex64::new(x, 0.0));
52 Ok(Tensor::C64(result))
53 },
54 (Tensor::I64(a), DType::F32) => {
55 let result = a.mapv(|x| x as f32);
56 Ok(Tensor::F32(result))
57 },
58 (Tensor::I64(a), DType::F64) => {
59 let result = a.mapv(|x| x as f64);
60 Ok(Tensor::F64(result))
61 },
62 (Tensor::C32(a), DType::F32) => {
63 let result = a.mapv(|x| x.re);
64 Ok(Tensor::F32(result))
65 },
66 (Tensor::C32(a), DType::F64) => {
67 let result = a.mapv(|x| x.re as f64);
68 Ok(Tensor::F64(result))
69 },
70 (Tensor::C64(a), DType::F32) => {
71 let result = a.mapv(|x| x.re as f32);
72 Ok(Tensor::F32(result))
73 },
74 (Tensor::C64(a), DType::F64) => {
75 let result = a.mapv(|x| x.re);
76 Ok(Tensor::F64(result))
77 },
78 (tensor, target_dtype) if tensor.dtype() == target_dtype => Ok(tensor.clone()),
79 #[cfg(all(target_os = "macos", feature = "metal"))]
80 (Tensor::Metal(_), _) => {
81 let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
83 cpu_tensor.to_dtype(dtype)
84 },
85 _ => Err(TrustformersError::tensor_op_error(
86 &format!(
87 "Conversion from {:?} to {:?} not supported",
88 self.dtype(),
89 dtype
90 ),
91 "Tensor::to_dtype",
92 )),
93 }
94 }
95
96 pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
102 match self {
103 Tensor::F32(a) => Ok(a.iter().cloned().collect()),
104 Tensor::F64(a) => Ok(a.iter().map(|&x| x as f32).collect()),
105 Tensor::I64(a) => Ok(a.iter().map(|&x| x as f32).collect()),
106 Tensor::C32(a) => Ok(a.iter().map(|x| x.re).collect()),
107 Tensor::C64(a) => Ok(a.iter().map(|x| x.re as f32).collect()),
108 #[cfg(all(target_os = "macos", feature = "metal"))]
109 Tensor::Metal(_) => {
110 let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
112 cpu_tensor.to_vec_f32()
113 },
114 _ => Err(TrustformersError::tensor_op_error(
115 "Cannot convert this tensor type to Vec<f32>",
116 "Tensor::to_vec_f32",
117 )),
118 }
119 }
120
121 pub fn to_vec_u8(&self) -> Result<Vec<u8>> {
127 match self {
128 Tensor::F32(a) => Ok(a.iter().map(|&x| x as u8).collect()),
129 Tensor::F64(a) => Ok(a.iter().map(|&x| x as u8).collect()),
130 Tensor::I64(a) => Ok(a.iter().map(|&x| x as u8).collect()),
131 #[cfg(all(target_os = "macos", feature = "metal"))]
132 Tensor::Metal(_) => {
133 let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
135 cpu_tensor.to_vec_u8()
136 },
137 _ => Err(TrustformersError::tensor_op_error(
138 "Cannot convert this tensor type to Vec<u8>",
139 "Tensor::to_vec_u8",
140 )),
141 }
142 }
143
144 pub fn to_f32(&self) -> Result<Tensor> {
150 self.to_dtype(DType::F32)
151 }
152
153 pub fn to_i64(&self) -> Result<Tensor> {
159 self.to_dtype(DType::I64)
160 }
161}