Skip to main content

trustformers_core/tensor/
conversions.rs

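//! Tensor conversion functions.
//!
//! This module implements conversions between tensor element data types
//! (`Tensor::to_dtype` and its convenience wrappers) and extraction of tensor
//! contents into flat Rust vectors (`Tensor::to_vec_f32`, `Tensor::to_vec_u8`).
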
6use super::{DType, Tensor};
7use crate::errors::{Result, TrustformersError};
8use scirs2_core::{Complex32, Complex64};
9
10impl Tensor {
11    /// Convert tensor to a different data type.
12    ///
13    /// # Arguments
14    ///
15    /// * `dtype` - Target data type
16    ///
17    /// # Returns
18    ///
19    /// A tensor with the new data type.
20    pub fn to_dtype(&self, dtype: DType) -> Result<Tensor> {
21        match (self, dtype) {
22            (Tensor::F32(a), DType::F64) => {
23                let result = a.mapv(|x| x as f64);
24                Ok(Tensor::F64(result))
25            },
26            (Tensor::F32(a), DType::I64) => {
27                let result = a.mapv(|x| x as i64);
28                Ok(Tensor::I64(result))
29            },
30            (Tensor::F32(a), DType::C32) => {
31                let result = a.mapv(|x| Complex32::new(x, 0.0));
32                Ok(Tensor::C32(result))
33            },
34            (Tensor::F32(a), DType::C64) => {
35                let result = a.mapv(|x| Complex64::new(x as f64, 0.0));
36                Ok(Tensor::C64(result))
37            },
38            (Tensor::F64(a), DType::F32) => {
39                let result = a.mapv(|x| x as f32);
40                Ok(Tensor::F32(result))
41            },
42            (Tensor::F64(a), DType::I64) => {
43                let result = a.mapv(|x| x as i64);
44                Ok(Tensor::I64(result))
45            },
46            (Tensor::F64(a), DType::C32) => {
47                let result = a.mapv(|x| Complex32::new(x as f32, 0.0));
48                Ok(Tensor::C32(result))
49            },
50            (Tensor::F64(a), DType::C64) => {
51                let result = a.mapv(|x| Complex64::new(x, 0.0));
52                Ok(Tensor::C64(result))
53            },
54            (Tensor::I64(a), DType::F32) => {
55                let result = a.mapv(|x| x as f32);
56                Ok(Tensor::F32(result))
57            },
58            (Tensor::I64(a), DType::F64) => {
59                let result = a.mapv(|x| x as f64);
60                Ok(Tensor::F64(result))
61            },
62            (Tensor::C32(a), DType::F32) => {
63                let result = a.mapv(|x| x.re);
64                Ok(Tensor::F32(result))
65            },
66            (Tensor::C32(a), DType::F64) => {
67                let result = a.mapv(|x| x.re as f64);
68                Ok(Tensor::F64(result))
69            },
70            (Tensor::C64(a), DType::F32) => {
71                let result = a.mapv(|x| x.re as f32);
72                Ok(Tensor::F32(result))
73            },
74            (Tensor::C64(a), DType::F64) => {
75                let result = a.mapv(|x| x.re);
76                Ok(Tensor::F64(result))
77            },
78            (tensor, target_dtype) if tensor.dtype() == target_dtype => Ok(tensor.clone()),
79            #[cfg(all(target_os = "macos", feature = "metal"))]
80            (Tensor::Metal(_), _) => {
81                // Convert Metal tensor to CPU first, then apply dtype conversion
82                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
83                cpu_tensor.to_dtype(dtype)
84            },
85            _ => Err(TrustformersError::tensor_op_error(
86                &format!(
87                    "Conversion from {:?} to {:?} not supported",
88                    self.dtype(),
89                    dtype
90                ),
91                "Tensor::to_dtype",
92            )),
93        }
94    }
95
96    /// Convert tensor to vector of f32 values.
97    ///
98    /// # Returns
99    ///
100    /// A vector of f32 values.
101    pub fn to_vec_f32(&self) -> Result<Vec<f32>> {
102        match self {
103            Tensor::F32(a) => Ok(a.iter().cloned().collect()),
104            Tensor::F64(a) => Ok(a.iter().map(|&x| x as f32).collect()),
105            Tensor::I64(a) => Ok(a.iter().map(|&x| x as f32).collect()),
106            Tensor::C32(a) => Ok(a.iter().map(|x| x.re).collect()),
107            Tensor::C64(a) => Ok(a.iter().map(|x| x.re as f32).collect()),
108            #[cfg(all(target_os = "macos", feature = "metal"))]
109            Tensor::Metal(_) => {
110                // Convert Metal tensor to CPU first, then get vec
111                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
112                cpu_tensor.to_vec_f32()
113            },
114            _ => Err(TrustformersError::tensor_op_error(
115                "Cannot convert this tensor type to Vec<f32>",
116                "Tensor::to_vec_f32",
117            )),
118        }
119    }
120
121    /// Convert tensor to vector of u8 values.
122    ///
123    /// # Returns
124    ///
125    /// A vector of u8 values.
126    pub fn to_vec_u8(&self) -> Result<Vec<u8>> {
127        match self {
128            Tensor::F32(a) => Ok(a.iter().map(|&x| x as u8).collect()),
129            Tensor::F64(a) => Ok(a.iter().map(|&x| x as u8).collect()),
130            Tensor::I64(a) => Ok(a.iter().map(|&x| x as u8).collect()),
131            #[cfg(all(target_os = "macos", feature = "metal"))]
132            Tensor::Metal(_) => {
133                // Convert Metal tensor to CPU first, then get vec
134                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
135                cpu_tensor.to_vec_u8()
136            },
137            _ => Err(TrustformersError::tensor_op_error(
138                "Cannot convert this tensor type to Vec<u8>",
139                "Tensor::to_vec_u8",
140            )),
141        }
142    }
143
144    /// Convert tensor to F32 dtype (convenience method).
145    ///
146    /// # Returns
147    ///
148    /// A tensor with F32 dtype.
149    pub fn to_f32(&self) -> Result<Tensor> {
150        self.to_dtype(DType::F32)
151    }
152
153    /// Convert tensor to I64 dtype (convenience method).
154    ///
155    /// # Returns
156    ///
157    /// A tensor with I64 dtype.
158    pub fn to_i64(&self) -> Result<Tensor> {
159        self.to_dtype(DType::I64)
160    }
161}