scirs2_core/memory_efficient/cross_device.rs

//! Cross-device memory management for efficient data transfer between CPU and GPU
//!
//! This module provides utilities for managing memory across different devices
//! (CPU, GPU, TPU) with efficient data transfer and synchronization. It includes:
//!
//! - Cross-device memory transfer with automatic format conversion
//! - Memory pools for efficient allocation
//! - Smart caching for frequently accessed data
//! - Efficient pinned memory for faster CPU-GPU transfers
//! - Asynchronous data transfer with event-based synchronization
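//!
//! The sketch below shows one intended usage pattern. It is illustrative only:
//! it assumes these items are re-exported at the paths shown and that a GPU
//! backend is actually available at runtime.
//!
//! ```ignore
//! use ndarray::array;
//! use scirs2_core::gpu::GpuBackend;
//! use scirs2_core::memory_efficient::cross_device::{
//!     create_cross_device_manager, DeviceType, ToDevice, ToHost,
//! };
//!
//! // Manager with the default 1 GB device-buffer cache.
//! let manager = create_cross_device_manager()?;
//!
//! // Move a host array to the preferred GPU backend, then copy it back.
//! let host = array![[1.0f32, 2.0], [3.0, 4.0]];
//! let on_device = host.to_device(DeviceType::Gpu(GpuBackend::preferred()), &manager)?;
//! let round_trip = on_device.to_host(&manager)?;
//! assert_eq!(host, round_trip);
//! ```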

use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
#[cfg(feature = "gpu")]
use crate::gpu::{GpuBackend, GpuBuffer, GpuContext, GpuDataType};
use ::ndarray::{Array, ArrayBase, Dimension, IxDyn, RawData};
use std::any::TypeId;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};

/// Device types supported by the cross-device memory management
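///
/// A quick illustrative check of which device to use (sketch only; the
/// backend returned by `GpuBackend::preferred()` depends on the build):
///
/// ```ignore
/// let gpu = DeviceType::Gpu(GpuBackend::preferred());
/// let device = if gpu.is_available() { gpu } else { DeviceType::Cpu };
/// println!("running on {}", device.name());
/// ```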
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU (host) memory
    Cpu,
    /// Discrete GPU memory
    Gpu(GpuBackend),
    /// TPU (Tensor Processing Unit) memory
    Tpu,
}

impl DeviceType {
    /// Check if the device is available on the current system
    pub fn is_available(&self) -> bool {
        match self {
            DeviceType::Cpu => true,
            DeviceType::Gpu(backend) => backend.is_available(),
            DeviceType::Tpu => false, // TPU support not yet implemented
        }
    }

    /// Get the name of the device
    pub fn name(&self) -> String {
        match self {
            DeviceType::Cpu => "CPU".to_string(),
            DeviceType::Gpu(backend) => format!("GPU ({backend})"),
            DeviceType::Tpu => "TPU".to_string(),
        }
    }
}

impl std::fmt::Display for DeviceType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            DeviceType::Cpu => write!(f, "CPU"),
            DeviceType::Gpu(backend) => write!(f, "GPU ({backend})"),
            DeviceType::Tpu => write!(f, "TPU"),
        }
    }
}

/// Memory transfer direction between devices
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
    /// Host to device transfer (e.g., CPU to GPU)
    HostToDevice,
    /// Device to host transfer (e.g., GPU to CPU)
    DeviceToHost,
    /// Device to device transfer (e.g., GPU to TPU)
    DeviceToDevice,
}

/// Transfer mode for cross-device operations
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferMode {
    /// Synchronous transfer (blocks until complete)
    Synchronous,
    /// Asynchronous transfer (returns immediately, track with events)
    Asynchronous,
}

/// Memory layout for device buffers
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryLayout {
    /// Row-major layout (C-style)
    RowMajor,
    /// Column-major layout (Fortran-style)
    ColumnMajor,
    /// Strided layout (custom strides)
    Strided,
}

/// Options for cross-device memory transfers
#[derive(Debug, Clone)]
pub struct TransferOptions {
    /// Transfer mode (synchronous or asynchronous)
    pub mode: TransferMode,
    /// Memory layout for the transfer
    pub layout: MemoryLayout,
    /// Whether to use pinned memory for the transfer
    pub use_pinned_memory: bool,
    /// Whether to enable streaming transfers for large buffers
    pub enable_streaming: bool,
    /// Stream ID for asynchronous transfers
    pub stream_id: Option<usize>,
}

impl Default for TransferOptions {
    fn default() -> Self {
        Self {
            mode: TransferMode::Synchronous,
            layout: MemoryLayout::RowMajor,
            use_pinned_memory: true,
            enable_streaming: true,
            stream_id: None,
        }
    }
}

/// Builder for transfer options
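///
/// A minimal usage sketch (illustrative; the option values shown are arbitrary):
///
/// ```ignore
/// let options = TransferOptionsBuilder::new()
///     .mode(TransferMode::Asynchronous)
///     .layout(MemoryLayout::RowMajor)
///     .pinned_memory(true)
///     .streaming(false)
///     .with_stream_id(Some(0))
///     .build();
/// ```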
#[derive(Debug, Clone)]
pub struct TransferOptionsBuilder {
    options: TransferOptions,
}

impl TransferOptionsBuilder {
    /// Create a new transfer options builder with default values
    pub fn new() -> Self {
        Self {
            options: TransferOptions::default(),
        }
    }

    /// Set the transfer mode
    pub const fn mode(mut self, mode: TransferMode) -> Self {
        self.options.mode = mode;
        self
    }

    /// Set the memory layout
    pub const fn layout(mut self, layout: MemoryLayout) -> Self {
        self.options.layout = layout;
        self
    }

    /// Set whether to use pinned memory
    pub const fn pinned_memory(mut self, use_pinned_memory: bool) -> Self {
        self.options.use_pinned_memory = use_pinned_memory;
        self
    }

    /// Set whether to enable streaming transfers
    pub const fn streaming(mut self, enable_streaming: bool) -> Self {
        self.options.enable_streaming = enable_streaming;
        self
    }

    /// Set the stream ID for asynchronous transfers
    pub const fn with_stream_id(mut self, stream_id: Option<usize>) -> Self {
        self.options.stream_id = stream_id;
        self
    }

    /// Build the transfer options
    pub fn build(self) -> TransferOptions {
        self.options
    }
}

impl Default for TransferOptionsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Cache key for the device memory cache
#[derive(Debug, Clone, PartialEq, Eq)]
struct CacheKey {
    /// Data identifier (usually the memory address of the host array)
    data_id: usize,
    /// Device type
    device: DeviceType,
    /// Element type ID
    type_id: TypeId,
    /// Size in elements
    size: usize,
}

impl Hash for CacheKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.data_id.hash(state);
        self.device.hash(state);
        self.type_id.hash(state);
        self.size.hash(state);
    }
}

/// Event for tracking asynchronous operations
#[derive(Debug)]
pub struct TransferEvent {
    /// Device associated with the event
    #[allow(dead_code)]
    device: DeviceType,
    /// Internal event handle (implementation-specific)
    #[allow(dead_code)]
    handle: Arc<Mutex<Box<dyn std::any::Any + Send + Sync>>>,
    /// Whether the event has been completed
    completed: Arc<std::sync::atomic::AtomicBool>,
}

impl TransferEvent {
    /// Create a new transfer event
    #[allow(dead_code)]
    fn new(device: DeviceType, handle: Box<dyn std::any::Any + Send + Sync>) -> Self {
        Self {
            device,
            handle: Arc::new(Mutex::new(handle)),
            completed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Wait for the event to complete
    pub fn wait(&self) {
        // In a real implementation, this would block until the event is complete
        // For now, just set the completed flag for demonstration
        self.completed
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Check if the event has completed
    pub fn is_complete(&self) -> bool {
        self.completed.load(std::sync::atomic::Ordering::SeqCst)
    }
}

/// Cache entry for the device memory cache
struct CacheEntry<T: GpuDataType> {
    /// Buffer on the device
    buffer: DeviceBuffer<T>,
    /// Size in elements
    size: usize,
    /// Last access time
    last_access: std::time::Instant,
    /// Whether the buffer is dirty (modified on device)
    #[allow(dead_code)]
    dirty: bool,
}

/// Device memory manager for cross-device operations
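///
/// A minimal usage sketch (illustrative; most callers go through
/// [`CrossDeviceManager`] rather than using this type directly):
///
/// ```ignore
/// use ndarray::array;
///
/// // 256 MiB cache for device buffers.
/// let manager = DeviceMemoryManager::new(256 * 1024 * 1024)?;
/// let host = array![1.0f32, 2.0, 3.0];
///
/// // Prefer a non-CPU device if one is available.
/// let device = manager
///     .available_devices()
///     .into_iter()
///     .find(|d| *d != DeviceType::Cpu)
///     .unwrap_or(DeviceType::Cpu);
///
/// let on_device = manager.transfer_to_device(&host, device, None)?;
/// let back = manager.transfer_to_host(&on_device, None)?;
/// ```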
pub struct DeviceMemoryManager {
    /// GPU context for accessing GPU functionality
    gpu_context: Option<GpuContext>,
    /// Cache of device buffers
    cache: Mutex<HashMap<CacheKey, Box<dyn std::any::Any + Send + Sync>>>,
    /// Maximum cache size in bytes
    max_cache_size: usize,
    /// Current cache size in bytes
    current_cache_size: std::sync::atomic::AtomicUsize,
    /// Whether the caching is enabled
    enable_caching: bool,
}

impl std::fmt::Debug for DeviceMemoryManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceMemoryManager")
            .field("gpu_context", &"<gpu_context>")
            .field("cache", &"<cache>")
            .field("max_cache_size", &self.max_cache_size)
            .field(
                "current_cache_size",
                &self
                    .current_cache_size
                    .load(std::sync::atomic::Ordering::Relaxed),
            )
            .field("enable_caching", &self.enable_caching)
            .finish()
    }
}

impl DeviceMemoryManager {
    /// Create a new device memory manager
    pub fn new(max_cache_size: usize) -> CoreResult<Self> {
        // Try to create a GPU context if a GPU is available
        let gpu_context = match GpuBackend::preferred() {
            backend if backend.is_available() => GpuContext::new(backend).ok(),
            _ => None,
        };

        Ok(Self {
            gpu_context,
            cache: Mutex::new(HashMap::new()),
            max_cache_size,
            current_cache_size: std::sync::atomic::AtomicUsize::new(0),
            enable_caching: true,
        })
    }

    /// Check if a device type is available
    pub fn is_device_available(&self, device: DeviceType) -> bool {
        match device {
            DeviceType::Cpu => true,
            DeviceType::Gpu(_) => self.gpu_context.is_some(),
            DeviceType::Tpu => false, // TPU not yet supported
        }
    }

    /// Get a list of available devices
    pub fn available_devices(&self) -> Vec<DeviceType> {
        let mut devices = vec![DeviceType::Cpu];

        if let Some(ref context) = self.gpu_context {
            devices.push(DeviceType::Gpu(context.backend()));
        }

        devices
    }

    /// Transfer data from host to device
    pub fn transfer_to_device<T, S, D>(
        &self,
        array: &ArrayBase<S, D>,
        device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        S: RawData<Elem = T> + crate::ndarray::Data,
        D: Dimension,
    {
        // Transfer options are not yet consulted by this reference implementation
        let _options = options.unwrap_or_default();

        // Check if the device is available
        if !self.is_device_available(device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        // For CPU, just create a view of the array
        if device == DeviceType::Cpu {
            return Ok(DeviceArray::new_cpu(array.to_owned()));
        }

        // For GPU, create a GPU buffer
        if let DeviceType::Gpu(backend) = device {
            if let Some(ref context) = self.gpu_context {
                if context.backend() != backend {
                    return Err(CoreError::DeviceError(
                        ErrorContext::new(format!(
                            "GPU backend mismatch: requested {}, available {}",
                            backend,
                            context.backend()
                        ))
                        .with_location(ErrorLocation::new(file!(), line!())),
                    ));
                }

                // Create a flat view of the array data
                let flat_data = array.as_slice().ok_or_else(|| {
                    CoreError::DeviceError(
                        ErrorContext::new("Array is not contiguous".to_string())
                            .with_location(ErrorLocation::new(file!(), line!())),
                    )
                })?;

                // Check if we have a cached buffer for this array
                let data_id = flat_data.as_ptr() as usize;
                let key = CacheKey {
                    data_id,
                    device,
                    type_id: TypeId::of::<T>(),
                    size: flat_data.len(),
                };

                let buffer = if self.enable_caching {
                    let mut cache = self.cache.lock().expect("Operation failed");
                    if let Some(entry) = cache.get_mut(&key) {
                        // We found a cached entry, cast it to the correct type
                        if let Some(entry) = entry.downcast_mut::<CacheEntry<T>>() {
                            // Update the last access time
                            entry.last_access = std::time::Instant::now();
                            entry.buffer.clone()
                        } else {
                            // This should never happen if our caching logic is correct
                            return Err(CoreError::DeviceError(
                                ErrorContext::new("Cache entry type mismatch".to_string())
                                    .with_location(ErrorLocation::new(file!(), line!())),
                            ));
                        }
                    } else {
                        // No cached entry, create a new buffer
                        let gpubuffer = context.create_buffer_from_slice(flat_data);
                        let buffer = DeviceBuffer::new_gpu(gpubuffer);

                        // Add to cache
                        let entry = CacheEntry {
                            buffer: buffer.clone(),
                            size: flat_data.len(),
                            last_access: std::time::Instant::now(),
                            dirty: false,
                        };

                        let buffersize = std::mem::size_of_val(flat_data);
                        self.current_cache_size
                            .fetch_add(buffersize, std::sync::atomic::Ordering::SeqCst);

                        cache.insert(key, Box::new(entry));
                        buffer
                    }
                } else {
                    // Caching is disabled, just create a new buffer
                    let gpubuffer = context.create_buffer_from_slice(flat_data);
                    DeviceBuffer::new_gpu(gpubuffer)
                };

                // If the cache has grown past its size limit, evict old entries.
                // This must happen after the cache lock above has been released,
                // because eviction locks the cache again.
                if self.enable_caching {
                    self.evict_cache_entries_if_needed();
                }

                return Ok(DeviceArray {
                    buffer,
                    shape: array.raw_dim(),
                    device,
                    phantom: PhantomData,
                });
            }
        }

        Err(CoreError::DeviceError(
            ErrorContext::new(format!("Unsupported device for transfer: {device}"))
                .with_location(ErrorLocation::new(file!(), line!())),
        ))
    }

    /// Transfer data from device to host
    pub fn transfer_to_host<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        options: Option<TransferOptions>,
    ) -> CoreResult<Array<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        // Transfer options are not yet consulted by this reference implementation
        let _options = options.unwrap_or_default();

        // For CPU arrays, just clone the data
        if devicearray.device == DeviceType::Cpu {
            if let Some(cpuarray) = devicearray.buffer.get_cpuarray() {
                let reshaped = cpuarray
                    .clone()
                    .to_shape(devicearray.shape.clone())
                    .map_err(|e| CoreError::ShapeError(ErrorContext::new(e.to_string())))?
                    .to_owned();
                return Ok(reshaped);
            }
        }

        // For GPU arrays, copy the data back to the host
        if let DeviceType::Gpu(_) = devicearray.device {
            if let Some(gpubuffer) = devicearray.buffer.get_gpubuffer() {
                let size = devicearray.size();
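                // SAFETY: `GpuDataType` implementors are assumed to be plain
                // numeric types, for which an all-zero bit pattern is a valid
                // initial value.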
                let mut data = vec![unsafe { std::mem::zeroed() }; size];

                // Copy data from GPU to host
                let _ = gpubuffer.copy_to_host(&mut data);

                // Reshape the data to match the original array shape
                return Array::from_shape_vec(devicearray.shape.clone(), data).map_err(|e| {
                    CoreError::DeviceError(
                        ErrorContext::new(format!("{e}"))
                            .with_location(ErrorLocation::new(file!(), line!())),
                    )
                });
            }
        }

        Err(CoreError::DeviceError(
            ErrorContext::new(format!(
                "Unsupported device type for transfer to host: {}",
                devicearray.device
            ))
            .with_location(ErrorLocation::new(file!(), line!())),
        ))
    }

    /// Transfer data between devices
    pub fn transfer_between_devices<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        target_device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        let options = options.unwrap_or_default();

        // If the source and target devices are the same, just clone the array
        if devicearray.device == target_device {
            return Ok(devicearray.clone());
        }

        // For transfers to CPU, use transfer_to_host
        if target_device == DeviceType::Cpu {
            let hostarray = self.transfer_to_host(devicearray, Some(options))?;
            return Ok(DeviceArray::new_cpu(hostarray));
        }

        // For transfers from CPU to another device, use transfer_to_device
        if devicearray.device == DeviceType::Cpu {
            if let Some(cpuarray) = devicearray.buffer.get_cpuarray() {
                // Reshape the CPU array to match the expected dimension type
                let cpu_clone = cpuarray.clone();
                let reshaped = cpu_clone
                    .to_shape(devicearray.shape.clone())
                    .map_err(|e| CoreError::ShapeError(ErrorContext::new(e.to_string())))?;
                return self.transfer_to_device(&reshaped.to_owned(), target_device, Some(options));
            }
        }

        // For transfers between GPUs (or future TPU support)
        // In a real implementation, we would use peer-to-peer transfers if available,
        // or copy through host memory if not

        // For now, we'll transfer through host memory
        let hostarray = self.transfer_to_host(devicearray, Some(options.clone()))?;
        self.transfer_to_device(&hostarray, target_device, Some(options))
    }

    /// Evict cache entries if the total size exceeds the limit
    fn evict_cache_entries_if_needed(&self) {
        let current_size = self
            .current_cache_size
            .load(std::sync::atomic::Ordering::SeqCst);
        if current_size <= self.max_cache_size {
            return;
        }

        let mut cache = self.cache.lock().expect("Operation failed");

        // Collect keys with their access times to avoid borrow conflicts
        let mut key_times: Vec<_> = cache
            .iter()
            .map(|(key, value)| {
                let access_time = match value.downcast_ref::<CacheEntry<f32>>() {
                    Some(entry) => entry.last_access,
                    None => match value.downcast_ref::<CacheEntry<f64>>() {
                        Some(entry) => entry.last_access,
                        None => match value.downcast_ref::<CacheEntry<i32>>() {
                            Some(entry) => entry.last_access,
                            None => match value.downcast_ref::<CacheEntry<u32>>() {
                                Some(entry) => entry.last_access,
                                None => std::time::Instant::now(), // Fallback, shouldn't happen
                            },
                        },
                    },
                };
                (key.clone(), access_time)
            })
            .collect();

        // Sort by access time (oldest first)
        key_times.sort_by(|a, b| a.1.cmp(&b.1));

        // Remove entries until we're under the limit
        let mut removed_size = 0;
        let bytes_to_remove = current_size - self.max_cache_size / 2; // enough to bring usage down to half the limit

        for (key, _) in key_times {
            let entry = cache.remove(&key).expect("Operation failed");

            // Calculate the size of the entry based on its type
            let entry_size = match entry.downcast_ref::<CacheEntry<f32>>() {
                Some(entry) => entry.size * std::mem::size_of::<f32>(),
                None => match entry.downcast_ref::<CacheEntry<f64>>() {
                    Some(entry) => entry.size * std::mem::size_of::<f64>(),
                    None => match entry.downcast_ref::<CacheEntry<i32>>() {
                        Some(entry) => entry.size * std::mem::size_of::<i32>(),
                        None => match entry.downcast_ref::<CacheEntry<u32>>() {
                            Some(entry) => entry.size * std::mem::size_of::<u32>(),
                            None => 0, // Fallback, shouldn't happen
                        },
                    },
                },
            };

            removed_size += entry_size;

            if removed_size >= bytes_to_remove {
                break;
            }
        }

        // Update the current cache size
        self.current_cache_size
            .fetch_sub(removed_size, std::sync::atomic::Ordering::SeqCst);
    }

    /// Clear the cache
    pub fn clear_cache(&self) {
        let mut cache = self.cache.lock().expect("Operation failed");
        cache.clear();
        self.current_cache_size
            .store(0, std::sync::atomic::Ordering::SeqCst);
    }

    /// Execute a kernel on a device array
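    ///
    /// A minimal call sketch (illustrative only; the kernel name and the
    /// parameter names are hypothetical and must match a kernel registered
    /// with the GPU context):
    ///
    /// ```ignore
    /// let mut params = HashMap::new();
    /// params.insert("scale".to_string(), KernelParam::F32(2.0));
    /// params.insert("len".to_string(), KernelParam::U32(gpu_array.size() as u32));
    /// manager.execute_kernel(&gpu_array, "scale_elements", params)?;
    /// ```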
    pub fn execute_kernel<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        kernel_name: &str,
        params: HashMap<String, KernelParam>,
    ) -> CoreResult<()>
    where
        T: GpuDataType,
        D: Dimension,
    {
        // Only GPU devices support kernel execution
        if let DeviceType::Gpu(_) = devicearray.device {
            if let Some(ref context) = self.gpu_context {
                // Get the kernel
                let kernel = context
                    .get_kernel(kernel_name)
                    .map_err(|e| CoreError::ComputationError(ErrorContext::new(e.to_string())))?;

                // Set the input buffer parameter
                if let Some(gpubuffer) = devicearray.buffer.get_gpubuffer() {
                    kernel.set_buffer("input", gpubuffer);
                }

                // Set other parameters
                for (name, param) in params {
                    match param {
                        KernelParam::Buffer(buffer) => {
                            if let Some(gpubuffer) = buffer.get_gpubuffer() {
                                kernel.set_buffer(&name, gpubuffer);
                            }
                        }
                        KernelParam::U32(value) => kernel.set_u32(&name, value),
                        KernelParam::I32(value) => kernel.set_i32(&name, value),
                        KernelParam::F32(value) => kernel.set_f32(&name, value),
                        KernelParam::F64(value) => kernel.set_f64(&name, value),
                    }
                }

                // Compute dispatch dimensions
                let total_elements = devicearray.size();
                let work_group_size = 256; // A common CUDA/OpenCL work group size
                let num_groups = total_elements.div_ceil(work_group_size);

                // Dispatch the kernel
                kernel.dispatch([num_groups as u32, 1, 1]);

                return Ok(());
            }
        }

        Err(CoreError::DeviceError(
            ErrorContext::new(format!(
                "Unsupported device type for kernel execution: {}",
                devicearray.device
            ))
            .with_location(ErrorLocation::new(file!(), line!())),
        ))
    }
}

/// Kernel parameter for GPU execution
#[derive(Debug, Clone)]
pub enum KernelParam {
    /// Buffer parameter
    Buffer(DeviceBuffer<f32>), // Note: In a real implementation, this would be generic
    /// U32 parameter
    U32(u32),
    /// I32 parameter
    I32(i32),
    /// F32 parameter
    F32(f32),
    /// F64 parameter
    F64(f64),
}

/// Buffer location (CPU or GPU)
#[derive(Clone)]
enum BufferLocation<T: GpuDataType> {
    /// CPU buffer
    Cpu(Arc<Array<T, IxDyn>>),
    /// GPU buffer
    Gpu(Arc<GpuBuffer<T>>),
}

impl<T> std::fmt::Debug for BufferLocation<T>
where
    T: GpuDataType + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BufferLocation::Cpu(_) => write!(f, "Cpu(Array)"),
            BufferLocation::Gpu(_) => write!(f, "Gpu(GpuBuffer)"),
        }
    }
}

/// Buffer for cross-device operations
#[derive(Debug, Clone)]
pub struct DeviceBuffer<T: GpuDataType> {
    /// Buffer data (CPU or GPU)
    location: BufferLocation<T>,
}

impl<T: GpuDataType> DeviceBuffer<T> {
    /// Create a new CPU buffer
    fn new_cpu<D: Dimension>(array: Array<T, D>) -> Self {
        Self {
            location: BufferLocation::Cpu(Arc::new(array.into_dyn())),
        }
    }

    /// Create a new GPU buffer
    fn new_gpu(buffer: GpuBuffer<T>) -> Self {
        Self {
            location: BufferLocation::Gpu(Arc::new(buffer)),
        }
    }

    /// Get the CPU array if available
    fn get_cpuarray(&self) -> Option<&Array<T, IxDyn>> {
        match self.location {
            BufferLocation::Cpu(ref array) => Some(array),
            _ => None,
        }
    }

    /// Get the GPU buffer if available
    fn get_gpubuffer(&self) -> Option<&GpuBuffer<T>> {
        match self.location {
            BufferLocation::Gpu(ref buffer) => Some(buffer),
            _ => None,
        }
    }

    /// Get the size of the buffer in elements
    fn size(&self) -> usize {
        match self.location {
            BufferLocation::Cpu(ref array) => array.len(),
            BufferLocation::Gpu(ref buffer) => buffer.len(),
        }
    }
}

/// Array residing on a specific device (CPU, GPU, TPU)
#[derive(Debug, Clone)]
pub struct DeviceArray<T: GpuDataType, D: Dimension> {
    /// Buffer containing the array data
    buffer: DeviceBuffer<T>,
    /// Shape of the array
    shape: D,
    /// Device where the array resides
    device: DeviceType,
    /// Phantom data for the element type
    phantom: PhantomData<T>,
}

impl<T: GpuDataType, D: Dimension> DeviceArray<T, D> {
    /// Create a new CPU array
    fn new_cpu<S: RawData<Elem = T> + crate::ndarray::Data>(array: ArrayBase<S, D>) -> Self {
        Self {
            buffer: DeviceBuffer::new_cpu(array.to_owned()),
            shape: array.raw_dim(),
            device: DeviceType::Cpu,
            phantom: PhantomData,
        }
    }

    /// Get the device where the array resides
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// Get the shape of the array
    pub const fn shape(&self) -> &D {
        &self.shape
    }

    /// Get the size of the array in elements
    pub fn size(&self) -> usize {
        self.buffer.size()
    }

    /// Get the number of dimensions
    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }

    /// Check if the array is on the CPU
    pub fn is_on_cpu(&self) -> bool {
        self.device == DeviceType::Cpu
    }

    /// Check if the array is on a GPU
    pub fn is_on_gpu(&self) -> bool {
        matches!(self.device, DeviceType::Gpu(_))
    }

    /// Get a reference to the underlying CPU array if available
    pub fn as_cpuarray(&self) -> Option<&Array<T, IxDyn>> {
        self.buffer.get_cpuarray()
    }

    /// Get a reference to the underlying GPU buffer if available
    pub fn as_gpubuffer(&self) -> Option<&GpuBuffer<T>> {
        self.buffer.get_gpubuffer()
    }
}

/// Stream for asynchronous operations
pub struct DeviceStream {
    /// Device associated with the stream
    #[allow(dead_code)]
    device: DeviceType,
    /// Internal stream handle (implementation-specific)
    #[allow(dead_code)]
    handle: Arc<Mutex<Box<dyn std::any::Any + Send + Sync>>>,
}

impl DeviceStream {
    /// Create a new device stream
    pub fn new(device: DeviceType) -> CoreResult<Self> {
        // In a real implementation, we would create a stream for the device
        // For now, just create a dummy stream
        Ok(Self {
            device,
            handle: Arc::new(Mutex::new(Box::new(()))),
        })
    }

    /// Synchronize the stream
    pub fn synchronize(&self) {
        // In a real implementation, this would wait for all operations to complete
    }
}

/// Memory pool for efficient allocation on a device
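///
/// A minimal usage sketch (illustrative; the pool size is arbitrary):
///
/// ```ignore
/// // 64 MiB pool of reusable CPU buffers.
/// let pool = DeviceMemoryPool::new(DeviceType::Cpu, 64 * 1024 * 1024);
/// let buffer: DeviceBuffer<f32> = pool.allocate(1024)?;
/// // ... use the buffer ...
/// pool.free(buffer); // returned to the pool for later reuse
/// ```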
pub struct DeviceMemoryPool {
    /// Device associated with the pool
    device: DeviceType,
    /// List of free buffers by size
    freebuffers: Mutex<HashMap<usize, Vec<Box<dyn std::any::Any + Send + Sync>>>>,
    /// Maximum pool size in bytes
    max_poolsize: usize,
    /// Current pool size in bytes
    current_poolsize: std::sync::atomic::AtomicUsize,
}

impl std::fmt::Debug for DeviceMemoryPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceMemoryPool")
            .field("device", &self.device)
            .field("freebuffers", &"<freebuffers>")
            .field("max_poolsize", &self.max_poolsize)
            .field(
                "current_poolsize",
                &self
                    .current_poolsize
                    .load(std::sync::atomic::Ordering::Relaxed),
            )
            .finish()
    }
}

impl DeviceMemoryPool {
    /// Create a new device memory pool
    pub fn new(device: DeviceType, max_poolsize: usize) -> Self {
        Self {
            device,
            freebuffers: Mutex::new(HashMap::new()),
            max_poolsize,
            current_poolsize: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Allocate a buffer of the given size
    pub fn allocate<T: GpuDataType + num_traits::Zero>(
        &self,
        size: usize,
    ) -> CoreResult<DeviceBuffer<T>> {
        // Check if we have a free buffer of the right size
        let mut freebuffers = self.freebuffers.lock().expect("Operation failed");
        if let Some(buffers) = freebuffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                // We found a free buffer, cast it to the correct element type
                match buffer.downcast::<DeviceBuffer<T>>() {
                    Ok(buffer) => {
                        // The reused buffer leaves the pool, so update the pool size
                        self.current_poolsize.fetch_sub(
                            size * std::mem::size_of::<T>(),
                            std::sync::atomic::Ordering::SeqCst,
                        );
                        return Ok(*buffer);
                    }
                    // Same element count but a different element type: put it back
                    Err(other) => buffers.push(other),
                }
            }
        }

        // No free buffer, allocate a new one
        match self.device {
            DeviceType::Cpu => {
                // Allocate CPU memory
                let array = Array::<T, crate::ndarray::IxDyn>::zeros(IxDyn(&[size]));
                Ok(DeviceBuffer::new_cpu(array))
            }
            DeviceType::Gpu(_) => {
                // Allocate GPU memory
                Err(CoreError::ImplementationError(
                    ErrorContext::new("GPU memory allocation not implemented".to_string())
                        .with_location(ErrorLocation::new(file!(), line!())),
                ))
            }
            DeviceType::Tpu => {
                // TPU not yet supported
                Err(CoreError::DeviceError(
                    ErrorContext::new("TPU not supported".to_string())
                        .with_location(ErrorLocation::new(file!(), line!())),
                ))
            }
        }
    }

    /// Free a buffer (return it to the pool)
    pub fn free<T: GpuDataType>(&self, buffer: DeviceBuffer<T>) {
        let size = buffer.size();
        let buffersize = size * std::mem::size_of::<T>();

        // Check if adding this buffer would exceed the pool size
        let current_size = self
            .current_poolsize
            .load(std::sync::atomic::Ordering::SeqCst);
        if current_size + buffersize > self.max_poolsize {
            // Pool is full, just let the buffer be dropped
            return;
        }

        // Add the buffer to the pool
        let mut freebuffers = self.freebuffers.lock().expect("Operation failed");
        freebuffers.entry(size).or_default().push(Box::new(buffer));

        // Update the pool size
        self.current_poolsize
            .fetch_add(buffersize, std::sync::atomic::Ordering::SeqCst);
    }

    /// Clear the pool
    pub fn clear(&self) {
        let mut freebuffers = self.freebuffers.lock().expect("Operation failed");
        freebuffers.clear();
        self.current_poolsize
            .store(0, std::sync::atomic::Ordering::SeqCst);
    }
}

/// Cross-device array operations
impl<T: GpuDataType, D: Dimension> DeviceArray<T, D> {
    /// Map the array elements using a function
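    ///
    /// A minimal sketch (illustrative; `device_array` and `manager` are assumed
    /// to have been created earlier):
    ///
    /// ```ignore
    /// let doubled = device_array.map(|x| x * 2.0, &manager)?;
    /// ```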
    pub fn map<F>(&self, f: F, manager: &DeviceMemoryManager) -> CoreResult<DeviceArray<T, D>>
    where
        F: Fn(T) -> T + Send + Sync,
        D: Clone,
    {
        // For CPU arrays, use ndarray's map function
        if self.is_on_cpu() {
            if let Some(cpuarray) = self.as_cpuarray() {
                let mapped = cpuarray.map(|&x| f(x));
                return Ok(DeviceArray {
                    buffer: DeviceBuffer::new_cpu(mapped),
                    shape: self.shape.clone(),
                    device: DeviceType::Cpu,
                    phantom: PhantomData,
                });
            }
        }

        // For GPU arrays, transfer to host, map, and transfer back
        // In a real implementation, we would use a GPU kernel
        let hostarray = manager.transfer_to_host(self, None)?;
        let mapped = hostarray.map(|&x| f(x));
        manager.transfer_to_device(&mapped, self.device, None)
    }

    /// Reduce the array using a binary operation
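    ///
    /// A minimal sketch (illustrative):
    ///
    /// ```ignore
    /// let sum = device_array.reduce(|a, b| a + b, &manager)?;
    /// ```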
    pub fn reduce<F>(&self, f: F, manager: &DeviceMemoryManager) -> CoreResult<T>
    where
        F: Fn(T, T) -> T + Send + Sync,
        T: Copy,
    {
        // For CPU arrays, use ndarray's fold function
        if self.is_on_cpu() {
            if let Some(cpuarray) = self.as_cpuarray() {
                if cpuarray.is_empty() {
                    return Err(CoreError::ValueError(
                        ErrorContext::new("Cannot reduce empty array".to_string())
                            .with_location(ErrorLocation::new(file!(), line!())),
                    ));
                }

                // Iterate rather than index so this works for arrays of any dimensionality
                let first = *cpuarray.iter().next().expect("non-empty checked above");
                let result = cpuarray.iter().skip(1).fold(first, |acc, &x| f(acc, x));
                return Ok(result);
            }
        }

        // For GPU arrays, transfer to host and reduce
        // In a real implementation, we would use a GPU reduction kernel
        let hostarray = manager.transfer_to_host(self, None)?;
        if hostarray.is_empty() {
            return Err(CoreError::ValueError(
                ErrorContext::new("Cannot reduce empty array".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        let first = *hostarray.iter().next().expect("Operation failed");
        let result = hostarray.iter().skip(1).fold(first, |acc, &x| f(acc, x));
        Ok(result)
    }
}

/// Cross-device manager for handling data transfers and operations
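///
/// A minimal usage sketch (illustrative; assumes at least one GPU backend is
/// available at runtime):
///
/// ```ignore
/// use ndarray::array;
///
/// let manager = CrossDeviceManager::new(512 * 1024 * 1024)?; // 512 MiB cache
/// let host = array![1.0f32, 2.0, 3.0, 4.0];
///
/// // Pick any available GPU and round-trip the data through it.
/// let gpu = manager
///     .available_devices()
///     .into_iter()
///     .find(|d| matches!(d, DeviceType::Gpu(_)))
///     .expect("no GPU available");
/// let on_gpu = manager.to_device(&host, gpu, None)?;
/// let back = manager.to_host(&on_gpu, None)?;
/// assert_eq!(host, back);
/// ```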
#[derive(Debug)]
pub struct CrossDeviceManager {
    /// Memory managers for each device
    memory_managers: HashMap<DeviceType, DeviceMemoryManager>,
    /// Memory pools for each device
    memory_pools: HashMap<DeviceType, DeviceMemoryPool>,
    /// Active data transfers
    active_transfers: Mutex<Vec<TransferEvent>>,
    /// Enable caching
    #[allow(dead_code)]
    enable_caching: bool,
    /// Maximum cache size in bytes
    #[allow(dead_code)]
    max_cache_size: usize,
}

impl CrossDeviceManager {
    /// Create a new cross-device manager
    pub fn new(max_cache_size: usize) -> CoreResult<Self> {
        let mut memory_managers = HashMap::new();
        let mut memory_pools = HashMap::new();

        // Create CPU memory manager and pool
        let cpu_manager = DeviceMemoryManager::new(max_cache_size)?;
        memory_managers.insert(DeviceType::Cpu, cpu_manager);
        memory_pools.insert(
            DeviceType::Cpu,
            DeviceMemoryPool::new(DeviceType::Cpu, max_cache_size),
        );

        // Try to create GPU memory manager and pool
        let gpu_backend = GpuBackend::preferred();
        if gpu_backend.is_available() {
            let gpu_device = DeviceType::Gpu(gpu_backend);
            let gpu_manager = DeviceMemoryManager::new(max_cache_size)?;
            memory_managers.insert(gpu_device, gpu_manager);
            memory_pools.insert(gpu_device, DeviceMemoryPool::new(gpu_device, max_cache_size));
        }

        Ok(Self {
            memory_managers,
            memory_pools,
            active_transfers: Mutex::new(Vec::new()),
            enable_caching: true,
            max_cache_size,
        })
    }

    /// Get a list of available devices
    pub fn available_devices(&self) -> Vec<DeviceType> {
        self.memory_managers.keys().cloned().collect()
    }

    /// Check if a device is available
    pub fn is_device_available(&self, device: DeviceType) -> bool {
        self.memory_managers.contains_key(&device)
    }

    /// Transfer data to a device
    pub fn to_device<T, S, D>(
        &self,
        array: &ArrayBase<S, D>,
        device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        S: RawData<Elem = T> + crate::ndarray::Data,
        D: Dimension,
    {
        // Check if the device is available
        if !self.is_device_available(device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        // Get the memory manager for the device
        let manager = self.memory_managers.get(&device).expect("Operation failed");
        manager.transfer_to_device(array, device, options)
    }

    /// Transfer data from a device to the host
    pub fn to_host<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        options: Option<TransferOptions>,
    ) -> CoreResult<Array<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        // Get the memory manager for the device
        let manager = self
            .memory_managers
            .get(&devicearray.device)
            .ok_or_else(|| {
                CoreError::DeviceError(
                    ErrorContext::new(format!("Device {} is not available", devicearray.device))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;

        manager.transfer_to_host(devicearray, options)
    }

    /// Transfer data between devices
    pub fn transfer<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        target_device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        // Check if the target device is available
        if !self.is_device_available(target_device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {target_device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        // Get the memory manager for the source device
        let manager = self
            .memory_managers
            .get(&devicearray.device)
            .ok_or_else(|| {
                CoreError::DeviceError(
                    ErrorContext::new(format!("Device {} is not available", devicearray.device))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;

        manager.transfer_between_devices(devicearray, target_device, options)
    }

    /// Execute a kernel on a device array
    pub fn execute_kernel<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        kernel_name: &str,
        params: HashMap<String, KernelParam>,
    ) -> CoreResult<()>
    where
        T: GpuDataType,
        D: Dimension,
    {
        // Get the memory manager for the device
        let manager = self
            .memory_managers
            .get(&devicearray.device)
            .ok_or_else(|| {
                CoreError::DeviceError(
                    ErrorContext::new(format!("Device {} is not available", devicearray.device))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;

        manager.execute_kernel(devicearray, kernel_name, params)
    }

    /// Allocate memory on a device
    pub fn allocate<T: GpuDataType + num_traits::Zero>(
        &self,
        size: usize,
        device: DeviceType,
    ) -> CoreResult<DeviceBuffer<T>> {
        // Check if the device is available
        if !self.is_device_available(device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        // Get the memory pool for the device
        let pool = self.memory_pools.get(&device).expect("Operation failed");
        pool.allocate(size)
    }

    /// Free memory on a device
    pub fn free<T: GpuDataType>(&self, buffer: DeviceBuffer<T>, device: DeviceType) {
        // Check if the device is available
        if !self.is_device_available(device) {
            return;
        }

        // Get the memory pool for the device
        let pool = self.memory_pools.get(&device).expect("Operation failed");
        pool.free(buffer);
    }

    /// Clear all caches and pools
    pub fn clear(&self) {
        // Clear all memory managers
        for manager in self.memory_managers.values() {
            manager.clear_cache();
        }

        // Clear all memory pools
        for pool in self.memory_pools.values() {
            pool.clear();
        }

        // Clear active transfers
        let mut active_transfers = self.active_transfers.lock().expect("Operation failed");
        active_transfers.clear();
    }

    /// Wait for all active transfers to complete
    pub fn synchronize(&self) {
        let mut active_transfers = self.active_transfers.lock().expect("Operation failed");
        for event in active_transfers.drain(..) {
            event.wait();
        }
    }
}

/// Create a cross-device manager with default settings
#[allow(dead_code)]
pub fn create_cross_device_manager() -> CoreResult<CrossDeviceManager> {
    CrossDeviceManager::new(1024 * 1024 * 1024) // 1 GB cache by default
}

/// Extension trait for arrays to simplify device transfers
pub trait ToDevice<T, D>
where
    T: GpuDataType,
    D: Dimension,
{
    /// Transfer the array to a device
    fn to_device(
        &self,
        device: DeviceType,
        manager: &CrossDeviceManager,
    ) -> CoreResult<DeviceArray<T, D>>;
}

impl<T, S, D> ToDevice<T, D> for ArrayBase<S, D>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    fn to_device(
        &self,
        device: DeviceType,
        manager: &CrossDeviceManager,
    ) -> CoreResult<DeviceArray<T, D>> {
        manager.to_device(self, device, None)
    }
}

/// Extension trait for device arrays to simplify host transfers
pub trait ToHost<T, D>
where
    T: GpuDataType,
    D: Dimension,
{
    /// Transfer the device array to the host
    fn to_host(&self, manager: &CrossDeviceManager) -> CoreResult<Array<T, D>>;
}

impl<T, D> ToHost<T, D> for DeviceArray<T, D>
where
    T: GpuDataType,
    D: Dimension,
{
    fn to_host(&self, manager: &CrossDeviceManager) -> CoreResult<Array<T, D>> {
        manager.to_host(self, None)
    }
}

// Convenience functions

/// Create a device array on the CPU
#[allow(dead_code)]
pub fn create_cpuarray<T, S, D>(array: &ArrayBase<S, D>) -> DeviceArray<T, D>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    DeviceArray::new_cpu(array.to_owned())
}

/// Create a device array on the GPU
#[allow(dead_code)]
pub fn create_gpuarray<T, S, D>(
    array: &ArrayBase<S, D>,
    manager: &CrossDeviceManager,
) -> CoreResult<DeviceArray<T, D>>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    // Find the first available GPU
    for device in manager.available_devices() {
        if let DeviceType::Gpu(_) = device {
            return manager.to_device(array, device, None);
        }
    }

    Err(CoreError::DeviceError(
        ErrorContext::new("No GPU device available".to_string())
            .with_location(ErrorLocation::new(file!(), line!())),
    ))
}

/// Transfer an array to the best available device
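///
/// A minimal sketch (illustrative): this prefers a GPU when one is available
/// and otherwise keeps the data on the CPU.
///
/// ```ignore
/// let manager = create_cross_device_manager()?;
/// let data = ndarray::array![1.0f32, 2.0, 3.0];
/// let placed = to_best_device(&data, &manager)?;
/// println!("array placed on {}", placed.device());
/// ```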
#[allow(dead_code)]
pub fn to_best_device<T, S, D>(
    array: &ArrayBase<S, D>,
    manager: &CrossDeviceManager,
) -> CoreResult<DeviceArray<T, D>>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    // Try to find a GPU first
    for device in manager.available_devices() {
        if let DeviceType::Gpu(_) = device {
            return manager.to_device(array, device, None);
        }
    }

    // Fall back to CPU
    Ok(create_cpuarray(array))
}