// tenflowers_core/tensor/device.rs
use super::core::{Tensor, TensorStorage};
use crate::{Device, Result};

10impl<T: Clone> Tensor<T> {
12 pub fn to(&self, device: Device) -> Result<Self>
14 where
15 T: Default + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
16 {
17 if self.device() == &device {
18 return Ok(self.clone());
19 }
20
21 match (&self.storage, &device) {
22 (TensorStorage::Cpu(_array), Device::Cpu) => Ok(self.clone()),
23 #[cfg(feature = "gpu")]
24 (TensorStorage::Cpu(array), Device::Gpu(id)) => {
25 let gpu_buffer = crate::gpu::buffer::GpuBuffer::from_cpu_array(array, *id)?;
26 Ok(Self {
27 storage: TensorStorage::Gpu(gpu_buffer),
28 shape: self.shape().clone(),
29 device,
30 requires_grad: self.requires_grad(),
31 grad: None,
32 })
33 }
34 #[cfg(feature = "gpu")]
35 (TensorStorage::Gpu(buffer), Device::Cpu) => {
36 let array = buffer.to_cpu_array()?;
37 Ok(Self {
38 storage: TensorStorage::Cpu(array),
39 shape: self.shape().clone(),
40 device,
41 requires_grad: self.requires_grad(),
42 grad: None,
43 })
44 }
45 #[allow(unreachable_patterns)]
46 _ => unreachable!(),
47 }
48 }
49
50 pub fn to_device(&self, target_device: Device) -> Result<Self>
52 where
53 T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
54 {
55 if self.device() == &target_device {
56 return Ok(self.clone());
57 }
58
59 self.transfer_to_device(target_device)
60 }
61
62 pub fn to_cpu(&self) -> Result<Self>
64 where
65 T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
66 {
67 self.to_device(Device::Cpu)
68 }
69
70 #[cfg(feature = "gpu")]
72 pub fn gpu_context_info(&self) -> Option<crate::device::context::GpuContextInfo> {
73 match &self.storage {
74 TensorStorage::Gpu(buffer) => Some(crate::device::context::GpuContextInfo {
75 device: buffer.device.clone(),
76 queue: buffer.queue.clone(),
77 }),
78 _ => None,
79 }
80 }
81
82 #[cfg(feature = "gpu")]
84 pub fn to_gpu(&self, gpu_id: usize) -> Result<Self>
85 where
86 T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
87 {
88 self.to_device(Device::Gpu(gpu_id))
89 }
90
91 fn transfer_to_device(&self, target_device: Device) -> Result<Self>
93 where
94 T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
95 {
96 use crate::device::context::DEVICE_MANAGER;
97
98 let _src_ctx = DEVICE_MANAGER.get_context(self.device())?;
99 let _dst_ctx = DEVICE_MANAGER.get_context(&target_device)?;
100
101 match (&self.storage, &target_device) {
102 #[cfg(feature = "gpu")]
104 (TensorStorage::Cpu(cpu_array), Device::Gpu(_)) => {
105 #[cfg(feature = "gpu")]
106 {
107 let slice = cpu_array.as_slice().ok_or_else(|| {
108 crate::TensorError::invalid_argument(
109 "Cannot convert CPU array to slice".to_string(),
110 )
111 })?;
112
113 let gpu_buffer =
114 crate::gpu::buffer::GpuBuffer::from_slice(slice, &target_device)?;
115
116 Ok(Self {
117 storage: TensorStorage::Gpu(gpu_buffer),
118 shape: self.shape().clone(),
119 device: target_device,
120 requires_grad: self.requires_grad(),
121 grad: None, })
123 }
124 #[cfg(not(feature = "gpu"))]
125 {
126 Err(crate::TensorError::device_error_simple(
127 "GPU support not compiled",
128 ))
129 }
130 }
131
132 #[cfg(feature = "gpu")]
134 (TensorStorage::Gpu(gpu_buffer), Device::Cpu) => {
135 let cpu_data = gpu_buffer.to_cpu()?;
136 let array = scirs2_core::ndarray::ArrayD::from_shape_vec(
137 scirs2_core::ndarray::IxDyn(self.shape().dims()),
138 cpu_data,
139 )
140 .map_err(|e| crate::TensorError::invalid_shape_simple(e.to_string()))?;
141
142 Ok(Self {
143 storage: TensorStorage::Cpu(array),
144 shape: self.shape().clone(),
145 device: target_device,
146 requires_grad: self.requires_grad(),
147 grad: None, })
149 }
150
151 #[cfg(feature = "gpu")]
153 (TensorStorage::Gpu(src_buffer), Device::Gpu(_)) => {
154 let dst_buffer = src_buffer.transfer_to_device(&target_device)?;
155
156 Ok(Self {
157 storage: TensorStorage::Gpu(dst_buffer),
158 shape: self.shape().clone(),
159 device: target_device,
160 requires_grad: self.requires_grad(),
161 grad: None, })
163 }
164
165 (TensorStorage::Cpu(_), Device::Cpu) => Ok(self.clone()),
167
168 #[cfg(feature = "rocm")]
170 (TensorStorage::Cpu(_), Device::Rocm(_)) => {
171 eprintln!("Warning: ROCm tensor transfer using CPU fallback - native implementation pending");
174 Ok(self.clone()) }
176 #[cfg(feature = "rocm")]
177 (TensorStorage::Gpu(_), Device::Rocm(_)) => {
178 eprintln!("Warning: GPU to ROCm transfer using CPU fallback - native implementation pending");
181 self.to_cpu() }
183 }
184 }
185
186 pub fn copy_from_device(&mut self, src: &Self) -> Result<()>
188 where
189 T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
190 {
191 if self.shape() != src.shape() {
192 return Err(crate::TensorError::ShapeMismatch {
193 operation: "copy_from_device".to_string(),
194 expected: self.shape().to_string(),
195 got: src.shape().to_string(),
196 context: None,
197 });
198 }
199
200 let transferred = src.transfer_to_device(*self.device())?;
201 self.storage = transferred.storage;
202
203 Ok(())
204 }
205
206 pub fn can_transfer_to(&self, target_device: Device) -> bool
208 where
209 T: bytemuck::Pod,
210 {
211 match (self.device(), &target_device) {
212 (Device::Cpu, Device::Cpu) => true,
213 #[cfg(feature = "gpu")]
214 (Device::Cpu, Device::Gpu(_)) => true,
215 #[cfg(feature = "gpu")]
216 (Device::Gpu(_), Device::Cpu) => true,
217 #[cfg(feature = "gpu")]
218 (Device::Gpu(_), Device::Gpu(_)) => true,
219 #[cfg(feature = "rocm")]
220 (Device::Rocm(_), _) => true,
221 #[cfg(feature = "rocm")]
222 (_, Device::Rocm(_)) => true,
223 }
224 }
225}