Skip to main content

tenflowers_core/tensor/
device.rs

1//! Device Management and Transfer Operations
2//!
3//! This module handles tensor device placement, transfers between devices
4//! (CPU/GPU), and device-specific operations. It provides efficient
5//! device-to-device data transfer capabilities.
6
7use super::core::{Tensor, TensorStorage};
8use crate::{Device, Result};
9
10// Impl block for methods that need Clone
11impl<T: Clone> Tensor<T> {
12    /// Transfer tensor to specified device
13    pub fn to(&self, device: Device) -> Result<Self>
14    where
15        T: Default + bytemuck::Pod + bytemuck::Zeroable + Send + Sync + 'static,
16    {
17        if self.device() == &device {
18            return Ok(self.clone());
19        }
20
21        match (&self.storage, &device) {
22            (TensorStorage::Cpu(_array), Device::Cpu) => Ok(self.clone()),
23            #[cfg(feature = "gpu")]
24            (TensorStorage::Cpu(array), Device::Gpu(id)) => {
25                let gpu_buffer = crate::gpu::buffer::GpuBuffer::from_cpu_array(array, *id)?;
26                Ok(Self {
27                    storage: TensorStorage::Gpu(gpu_buffer),
28                    shape: self.shape().clone(),
29                    device,
30                    requires_grad: self.requires_grad(),
31                    grad: None,
32                })
33            }
34            #[cfg(feature = "gpu")]
35            (TensorStorage::Gpu(buffer), Device::Cpu) => {
36                let array = buffer.to_cpu_array()?;
37                Ok(Self {
38                    storage: TensorStorage::Cpu(array),
39                    shape: self.shape().clone(),
40                    device,
41                    requires_grad: self.requires_grad(),
42                    grad: None,
43                })
44            }
45            #[allow(unreachable_patterns)]
46            _ => unreachable!(),
47        }
48    }
49
50    /// Transfer tensor to a different device
51    pub fn to_device(&self, target_device: Device) -> Result<Self>
52    where
53        T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
54    {
55        if self.device() == &target_device {
56            return Ok(self.clone());
57        }
58
59        self.transfer_to_device(target_device)
60    }
61
62    /// Move tensor to CPU
63    pub fn to_cpu(&self) -> Result<Self>
64    where
65        T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
66    {
67        self.to_device(Device::Cpu)
68    }
69
70    /// Get GPU context information from this tensor (if it's on GPU)
71    #[cfg(feature = "gpu")]
72    pub fn gpu_context_info(&self) -> Option<crate::device::context::GpuContextInfo> {
73        match &self.storage {
74            TensorStorage::Gpu(buffer) => Some(crate::device::context::GpuContextInfo {
75                device: buffer.device.clone(),
76                queue: buffer.queue.clone(),
77            }),
78            _ => None,
79        }
80    }
81
82    /// Move tensor to GPU with specified ID
83    #[cfg(feature = "gpu")]
84    pub fn to_gpu(&self, gpu_id: usize) -> Result<Self>
85    where
86        T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
87    {
88        self.to_device(Device::Gpu(gpu_id))
89    }
90
91    /// Internal device transfer implementation
92    fn transfer_to_device(&self, target_device: Device) -> Result<Self>
93    where
94        T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
95    {
96        use crate::device::context::DEVICE_MANAGER;
97
98        let _src_ctx = DEVICE_MANAGER.get_context(self.device())?;
99        let _dst_ctx = DEVICE_MANAGER.get_context(&target_device)?;
100
101        match (&self.storage, &target_device) {
102            // CPU to GPU transfer
103            #[cfg(feature = "gpu")]
104            (TensorStorage::Cpu(cpu_array), Device::Gpu(_)) => {
105                #[cfg(feature = "gpu")]
106                {
107                    let slice = cpu_array.as_slice().ok_or_else(|| {
108                        crate::TensorError::invalid_argument(
109                            "Cannot convert CPU array to slice".to_string(),
110                        )
111                    })?;
112
113                    let gpu_buffer =
114                        crate::gpu::buffer::GpuBuffer::from_slice(slice, &target_device)?;
115
116                    Ok(Self {
117                        storage: TensorStorage::Gpu(gpu_buffer),
118                        shape: self.shape().clone(),
119                        device: target_device,
120                        requires_grad: self.requires_grad(),
121                        grad: None, // Gradients are reset on device transfer
122                    })
123                }
124                #[cfg(not(feature = "gpu"))]
125                {
126                    Err(crate::TensorError::device_error_simple(
127                        "GPU support not compiled",
128                    ))
129                }
130            }
131
132            // GPU to CPU transfer
133            #[cfg(feature = "gpu")]
134            (TensorStorage::Gpu(gpu_buffer), Device::Cpu) => {
135                let cpu_data = gpu_buffer.to_cpu()?;
136                let array = scirs2_core::ndarray::ArrayD::from_shape_vec(
137                    scirs2_core::ndarray::IxDyn(self.shape().dims()),
138                    cpu_data,
139                )
140                .map_err(|e| crate::TensorError::invalid_shape_simple(e.to_string()))?;
141
142                Ok(Self {
143                    storage: TensorStorage::Cpu(array),
144                    shape: self.shape().clone(),
145                    device: target_device,
146                    requires_grad: self.requires_grad(),
147                    grad: None, // Gradients are reset on device transfer
148                })
149            }
150
151            // GPU to GPU transfer (device-to-device)
152            #[cfg(feature = "gpu")]
153            (TensorStorage::Gpu(src_buffer), Device::Gpu(_)) => {
154                let dst_buffer = src_buffer.transfer_to_device(&target_device)?;
155
156                Ok(Self {
157                    storage: TensorStorage::Gpu(dst_buffer),
158                    shape: self.shape().clone(),
159                    device: target_device,
160                    requires_grad: self.requires_grad(),
161                    grad: None, // Gradients are reset on device transfer
162                })
163            }
164
165            // CPU to CPU (should not happen due to early return)
166            (TensorStorage::Cpu(_), Device::Cpu) => Ok(self.clone()),
167
168            // ROCm patterns - Use CPU fallback for now
169            #[cfg(feature = "rocm")]
170            (TensorStorage::Cpu(_), Device::Rocm(_)) => {
171                // ROCm tensor transfer: Fallback to CPU for now
172                // Future: Implement native ROCm memory transfer kernels
173                eprintln!("Warning: ROCm tensor transfer using CPU fallback - native implementation pending");
174                Ok(self.clone()) // Keep on CPU until ROCm support is complete
175            }
176            #[cfg(feature = "rocm")]
177            (TensorStorage::Gpu(_), Device::Rocm(_)) => {
178                // GPU to ROCm transfer: Go through CPU as intermediate step
179                // Future: Implement direct GPU<->ROCm memory transfer
180                eprintln!("Warning: GPU to ROCm transfer using CPU fallback - native implementation pending");
181                self.to_cpu() // Transfer to CPU for now
182            }
183        }
184    }
185
186    /// Copy tensor data from another device (for collective operations)
187    pub fn copy_from_device(&mut self, src: &Self) -> Result<()>
188    where
189        T: Clone + Default + Send + Sync + 'static + bytemuck::Pod,
190    {
191        if self.shape() != src.shape() {
192            return Err(crate::TensorError::ShapeMismatch {
193                operation: "copy_from_device".to_string(),
194                expected: self.shape().to_string(),
195                got: src.shape().to_string(),
196                context: None,
197            });
198        }
199
200        let transferred = src.transfer_to_device(*self.device())?;
201        self.storage = transferred.storage;
202
203        Ok(())
204    }
205
    /// Check if tensor can be transferred to target device
    ///
    /// Every source/target pairing listed for the enabled feature set is
    /// reported as transferable. Note that the match has no catch-all
    /// `false` arm, so pairings not covered by the enabled features are
    /// rejected by the compiler rather than returning `false` at runtime.
    pub fn can_transfer_to(&self, target_device: Device) -> bool
    where
        T: bytemuck::Pod,
    {
        match (self.device(), &target_device) {
            (Device::Cpu, Device::Cpu) => true,
            #[cfg(feature = "gpu")]
            (Device::Cpu, Device::Gpu(_)) => true,
            #[cfg(feature = "gpu")]
            (Device::Gpu(_), Device::Cpu) => true,
            #[cfg(feature = "gpu")]
            (Device::Gpu(_), Device::Gpu(_)) => true,
            // ROCm is accepted in both directions; `transfer_to_device`
            // currently services these via a CPU fallback.
            #[cfg(feature = "rocm")]
            (Device::Rocm(_), _) => true,
            #[cfg(feature = "rocm")]
            (_, Device::Rocm(_)) => true,
        }
    }
224    }
225}