use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
#[cfg(feature = "gpu")]
use crate::gpu::{GpuBackend, GpuBuffer, GpuContext, GpuDataType};
use ::ndarray::{Array, ArrayBase, Dimension, IxDyn, RawData};
use std::any::TypeId;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};

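/// The kind of device that data can live on or that work can be dispatched to.
///
/// A minimal selection sketch (marked `ignore` since backend availability
/// depends on build features and the machine it runs on):
///
/// ```ignore
/// // Prefer the preferred GPU backend when it is usable, otherwise the CPU.
/// let device = if GpuBackend::preferred().is_available() {
///     DeviceType::Gpu(GpuBackend::preferred())
/// } else {
///     DeviceType::Cpu
/// };
/// assert!(device.is_available());
/// println!("running on {device}");
/// ```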
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    Cpu,
    Gpu(GpuBackend),
    Tpu,
}

impl DeviceType {
    pub fn is_available(&self) -> bool {
        match self {
            DeviceType::Cpu => true,
            DeviceType::Gpu(backend) => backend.is_available(),
            DeviceType::Tpu => false,
        }
    }

    pub fn name(&self) -> String {
        match self {
            DeviceType::Cpu => "CPU".to_string(),
            DeviceType::Gpu(backend) => format!("GPU ({backend})"),
            DeviceType::Tpu => "TPU".to_string(),
        }
    }
}

impl std::fmt::Display for DeviceType {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            DeviceType::Cpu => write!(f, "CPU"),
            DeviceType::Gpu(backend) => write!(f, "GPU ({backend})"),
            DeviceType::Tpu => write!(f, "TPU"),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
    HostToDevice,
    DeviceToHost,
    DeviceToDevice,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferMode {
    Synchronous,
    Asynchronous,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryLayout {
    RowMajor,
    ColumnMajor,
    Strided,
}

#[derive(Debug, Clone)]
pub struct TransferOptions {
    pub mode: TransferMode,
    pub layout: MemoryLayout,
    pub use_pinned_memory: bool,
    pub enable_streaming: bool,
    pub stream_id: Option<usize>,
}

impl Default for TransferOptions {
    fn default() -> Self {
        Self {
            mode: TransferMode::Synchronous,
            layout: MemoryLayout::RowMajor,
            use_pinned_memory: true,
            enable_streaming: true,
            stream_id: None,
        }
    }
}

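/// Builder for [`TransferOptions`].
///
/// A small usage sketch (nothing here touches a device, so it can also serve
/// as a template for tests):
///
/// ```ignore
/// let options = TransferOptionsBuilder::new()
///     .mode(TransferMode::Asynchronous)
///     .layout(MemoryLayout::RowMajor)
///     .memory(true)
///     .streaming(false)
///     .with_stream_id(Some(0))
///     .build();
/// assert_eq!(options.mode, TransferMode::Asynchronous);
/// assert_eq!(options.stream_id, Some(0));
/// ```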
#[derive(Debug, Clone)]
pub struct TransferOptionsBuilder {
    options: TransferOptions,
}

impl TransferOptionsBuilder {
    pub fn new() -> Self {
        Self {
            options: TransferOptions::default(),
        }
    }

    pub const fn mode(mut self, mode: TransferMode) -> Self {
        self.options.mode = mode;
        self
    }

    pub const fn layout(mut self, layout: MemoryLayout) -> Self {
        self.options.layout = layout;
        self
    }

    pub const fn memory(mut self, use_pinnedmemory: bool) -> Self {
        self.options.use_pinned_memory = use_pinnedmemory;
        self
    }

    pub const fn streaming(mut self, enablestreaming: bool) -> Self {
        self.options.enable_streaming = enablestreaming;
        self
    }

    pub const fn with_stream_id(mut self, streamid: Option<usize>) -> Self {
        self.options.stream_id = streamid;
        self
    }

    pub fn build(self) -> TransferOptions {
        self.options
    }
}

impl Default for TransferOptionsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
struct CacheKey {
    data_id: usize,
    device: DeviceType,
    type_id: TypeId,
    size: usize,
}

impl Hash for CacheKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.data_id.hash(state);
        self.device.hash(state);
        // Hash the stored element type so keys for different `T` do not collide.
        self.type_id.hash(state);
        self.size.hash(state);
    }
}

#[derive(Debug)]
pub struct TransferEvent {
    #[allow(dead_code)]
    device: DeviceType,
    #[allow(dead_code)]
    handle: Arc<Mutex<Box<dyn std::any::Any + Send + Sync>>>,
    completed: Arc<std::sync::atomic::AtomicBool>,
}

impl TransferEvent {
    #[allow(dead_code)]
    fn device(devicetype: DeviceType, handle: Box<dyn std::any::Any + Send + Sync>) -> Self {
        Self {
            device: devicetype,
            handle: Arc::new(Mutex::new(handle)),
            completed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Blocks until the transfer has finished.
    ///
    /// Transfers are currently performed synchronously, so waiting simply
    /// marks the event as complete.
    pub fn wait(&self) {
        self.completed
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` if the transfer has finished.
    pub fn is_complete(&self) -> bool {
        self.completed.load(std::sync::atomic::Ordering::SeqCst)
    }
}

struct CacheEntry<T: GpuDataType> {
    buffer: DeviceBuffer<T>,
    size: usize,
    last_access: std::time::Instant,
    #[allow(dead_code)]
    dirty: bool,
}

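/// Manages host/device transfers and a small cache of device buffers.
///
/// A round-trip sketch (marked `ignore`: it assumes a GPU backend is actually
/// available at run time and that the caller can propagate `CoreError` with `?`):
///
/// ```ignore
/// let manager = DeviceMemoryManager::new(64 * 1024 * 1024)?;
/// let host = ndarray::array![[1.0_f32, 2.0], [3.0, 4.0]];
///
/// let device = DeviceType::Gpu(GpuBackend::preferred());
/// let on_gpu = manager.transfer_to_device(&host, device, None)?;
/// let back = manager.transfer_to_host(&on_gpu, None)?;
/// assert_eq!(back, host);
/// ```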
pub struct DeviceMemoryManager {
    gpu_context: Option<GpuContext>,
    cache: Mutex<HashMap<CacheKey, Box<dyn std::any::Any + Send + Sync>>>,
    max_cache_size: usize,
    current_cache_size: std::sync::atomic::AtomicUsize,
    enable_caching: bool,
}

impl std::fmt::Debug for DeviceMemoryManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceMemoryManager")
            .field("gpu_context", &"<gpu_context>")
            .field("cache", &"<cache>")
            .field("max_cache_size", &self.max_cache_size)
            .field(
                "current_cache_size",
                &self
                    .current_cache_size
                    .load(std::sync::atomic::Ordering::Relaxed),
            )
            .field("enable_caching", &self.enable_caching)
            .finish()
    }
}

impl DeviceMemoryManager {
    pub fn new(max_cachesize: usize) -> Result<Self, CoreError> {
        let gpu_context = match GpuBackend::preferred() {
            backend if backend.is_available() => GpuContext::new(backend).ok(),
            _ => None,
        };

        Ok(Self {
            gpu_context,
            cache: Mutex::new(HashMap::new()),
            max_cache_size: max_cachesize,
            current_cache_size: std::sync::atomic::AtomicUsize::new(0),
            enable_caching: true,
        })
    }

    pub fn is_device_available(&self, device: DeviceType) -> bool {
        match device {
            DeviceType::Cpu => true,
            DeviceType::Gpu(_) => self.gpu_context.is_some(),
            DeviceType::Tpu => false,
        }
    }

    pub fn available_devices(&self) -> Vec<DeviceType> {
        let mut devices = vec![DeviceType::Cpu];

        if let Some(ref context) = self.gpu_context {
            devices.push(DeviceType::Gpu(context.backend()));
        }

        devices
    }

    pub fn transfer_to_device<T, S, D>(
        &self,
        array: &ArrayBase<S, D>,
        device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        S: RawData<Elem = T> + crate::ndarray::Data,
        D: Dimension,
    {
        let _options = options.unwrap_or_default();

        if !self.is_device_available(device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        // A CPU "transfer" is just an owned copy of the host array.
        if device == DeviceType::Cpu {
            return Ok(DeviceArray::new_cpu(array.to_owned()));
        }

        if let DeviceType::Gpu(backend) = device {
            if let Some(ref context) = self.gpu_context {
                if context.backend() != backend {
                    return Err(CoreError::DeviceError(
                        ErrorContext::new(format!(
                            "GPU backend mismatch: requested {}, available {}",
                            backend,
                            context.backend()
                        ))
                        .with_location(ErrorLocation::new(file!(), line!())),
                    ));
                }

                let flat_data = array.as_slice().ok_or_else(|| {
                    CoreError::DeviceError(
                        ErrorContext::new("Array is not contiguous".to_string())
                            .with_location(ErrorLocation::new(file!(), line!())),
                    )
                })?;

                // The host pointer, element type, and length identify the cached buffer.
                let data_id = flat_data.as_ptr() as usize;
                let key = CacheKey {
                    data_id,
                    device,
                    type_id: TypeId::of::<T>(),
                    size: flat_data.len(),
                };

                let buffer = if self.enable_caching {
                    let buffer = {
                        let mut cache = self.cache.lock().expect("cache mutex poisoned");
                        if let Some(entry) = cache.get_mut(&key) {
                            if let Some(entry) = entry.downcast_mut::<CacheEntry<T>>() {
                                entry.last_access = std::time::Instant::now();
                                entry.buffer.clone()
                            } else {
                                return Err(CoreError::DeviceError(
                                    ErrorContext::new("Cache entry type mismatch".to_string())
                                        .with_location(ErrorLocation::new(file!(), line!())),
                                ));
                            }
                        } else {
                            let gpubuffer = context.create_buffer_from_slice(flat_data);
                            let buffer = DeviceBuffer::new_gpu(gpubuffer);

                            let entry = CacheEntry {
                                buffer: buffer.clone(),
                                size: flat_data.len(),
                                last_access: std::time::Instant::now(),
                                dirty: false,
                            };

                            let buffersize = std::mem::size_of_val(flat_data);
                            self.current_cache_size
                                .fetch_add(buffersize, std::sync::atomic::Ordering::SeqCst);

                            cache.insert(key, Box::new(entry));
                            buffer
                        }
                    };

                    // Evict only after the cache lock above has been released,
                    // since eviction re-acquires the same lock.
                    self.evict_cache_entries_if_needed();
                    buffer
                } else {
                    let gpubuffer = context.create_buffer_from_slice(flat_data);
                    DeviceBuffer::new_gpu(gpubuffer)
                };

                return Ok(DeviceArray {
                    buffer,
                    shape: array.raw_dim(),
                    device,
                    phantom: PhantomData,
                });
            }
        }

        Err(CoreError::DeviceError(
            ErrorContext::new(format!("Unsupported device type for transfer: {device}"))
                .with_location(ErrorLocation::new(file!(), line!())),
        ))
    }

    pub fn transfer_to_host<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        options: Option<TransferOptions>,
    ) -> CoreResult<Array<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        let _options = options.unwrap_or_default();

        if devicearray.device == DeviceType::Cpu {
            if let Some(cpuarray) = devicearray.buffer.get_cpuarray() {
                let reshaped = cpuarray
                    .clone()
                    .to_shape(devicearray.shape.clone())
                    .map_err(|e| CoreError::ShapeError(ErrorContext::new(e.to_string())))?
                    .to_owned();
                return Ok(reshaped);
            }
        }

        if let DeviceType::Gpu(_) = devicearray.device {
            if let Some(gpubuffer) = devicearray.buffer.get_gpubuffer() {
                let size = devicearray.size();
                // Zero-initialized scratch space; safe for the plain numeric
                // element types that implement `GpuDataType`.
                let mut data = vec![unsafe { std::mem::zeroed() }; size];

                let _ = gpubuffer.copy_to_host(&mut data);

                return Array::from_shape_vec(devicearray.shape.clone(), data).map_err(|e| {
                    CoreError::DeviceError(
                        ErrorContext::new(format!("{e}"))
                            .with_location(ErrorLocation::new(file!(), line!())),
                    )
                });
            }
        }

        Err(CoreError::DeviceError(
            ErrorContext::new(format!(
                "Unsupported device type for transfer to host: {}",
                devicearray.device
            ))
            .with_location(ErrorLocation::new(file!(), line!())),
        ))
    }

    pub fn transfer_between_devices<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        target_device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        let options = options.unwrap_or_default();

        if devicearray.device == target_device {
            return Ok(devicearray.clone());
        }

        if target_device == DeviceType::Cpu {
            let hostarray = self.transfer_to_host(devicearray, Some(options))?;
            return Ok(DeviceArray::new_cpu(hostarray));
        }

        if devicearray.device == DeviceType::Cpu {
            if let Some(cpuarray) = devicearray.buffer.get_cpuarray() {
                let cpu_clone = cpuarray.clone();
                let reshaped = cpu_clone
                    .to_shape(devicearray.shape.clone())
                    .map_err(|e| CoreError::ShapeError(ErrorContext::new(e.to_string())))?;
                return self.transfer_to_device(
                    &reshaped.to_owned(),
                    target_device,
                    Some(options),
                );
            }
        }

        // Device-to-device transfers fall back to a round trip through host memory.
        let hostarray = self.transfer_to_host(devicearray, Some(options.clone()))?;
        self.transfer_to_device(&hostarray, target_device, Some(options))
    }

    fn evict_cache_entries_if_needed(&self) {
        let current_size = self
            .current_cache_size
            .load(std::sync::atomic::Ordering::SeqCst);
        if current_size <= self.max_cache_size {
            return;
        }

        let mut cache = self.cache.lock().expect("cache mutex poisoned");

        // Sort cached entries by last access time so the least recently used
        // buffers are evicted first. Only the common numeric element types are
        // inspected; anything else is treated as most recently used.
        let mut key_times: Vec<_> = cache
            .iter()
            .map(|(key, value)| {
                let access_time = match value.downcast_ref::<CacheEntry<f32>>() {
                    Some(entry) => entry.last_access,
                    None => match value.downcast_ref::<CacheEntry<f64>>() {
                        Some(entry) => entry.last_access,
                        None => match value.downcast_ref::<CacheEntry<i32>>() {
                            Some(entry) => entry.last_access,
                            None => match value.downcast_ref::<CacheEntry<u32>>() {
                                Some(entry) => entry.last_access,
                                None => std::time::Instant::now(),
                            },
                        },
                    },
                };
                (key.clone(), access_time)
            })
            .collect();

        key_times.sort_by(|a, b| a.1.cmp(&b.1));

        // Evict until the cache is back down to half of its configured capacity.
        let mut removed_size = 0;
        let target_size = current_size - self.max_cache_size / 2;
        for (key, _) in key_times {
            let entry = cache
                .remove(&key)
                .expect("cache entry disappeared during eviction");

            let entry_size = match entry.downcast_ref::<CacheEntry<f32>>() {
                Some(entry) => entry.size * std::mem::size_of::<f32>(),
                None => match entry.downcast_ref::<CacheEntry<f64>>() {
                    Some(entry) => entry.size * std::mem::size_of::<f64>(),
                    None => match entry.downcast_ref::<CacheEntry<i32>>() {
                        Some(entry) => entry.size * std::mem::size_of::<i32>(),
                        None => match entry.downcast_ref::<CacheEntry<u32>>() {
                            Some(entry) => entry.size * std::mem::size_of::<u32>(),
                            None => 0,
                        },
                    },
                },
            };

            removed_size += entry_size;

            if removed_size >= target_size {
                break;
            }
        }

        self.current_cache_size
            .fetch_sub(removed_size, std::sync::atomic::Ordering::SeqCst);
    }

    pub fn clear_cache(&self) {
        let mut cache = self.cache.lock().expect("cache mutex poisoned");
        cache.clear();
        self.current_cache_size
            .store(0, std::sync::atomic::Ordering::SeqCst);
    }

    pub fn execute_kernel<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        kernel_name: &str,
        params: HashMap<String, KernelParam>,
    ) -> CoreResult<()>
    where
        T: GpuDataType,
        D: Dimension,
    {
        if let DeviceType::Gpu(_) = devicearray.device {
            if let Some(ref context) = self.gpu_context {
                let kernel = context
                    .get_kernel(kernel_name)
                    .map_err(|e| CoreError::ComputationError(ErrorContext::new(e.to_string())))?;

                if let Some(gpubuffer) = devicearray.buffer.get_gpubuffer() {
                    kernel.set_buffer("input", gpubuffer);
                }

                for (name, param) in params {
                    match param {
                        KernelParam::Buffer(buffer) => {
                            if let Some(gpubuffer) = buffer.get_gpubuffer() {
                                kernel.set_buffer(&name, gpubuffer);
                            }
                        }
                        KernelParam::U32(value) => kernel.set_u32(&name, value),
                        KernelParam::I32(value) => kernel.set_i32(&name, value),
                        KernelParam::F32(value) => kernel.set_f32(&name, value),
                        KernelParam::F64(value) => kernel.set_f64(&name, value),
                    }
                }

                // Launch one thread per element, rounded up to whole work groups.
                let total_elements = devicearray.size();
                let work_group_size = 256;
                let num_groups = total_elements.div_ceil(work_group_size);

                kernel.dispatch([num_groups as u32, 1, 1]);

                return Ok(());
            }
        }

        Err(CoreError::DeviceError(
            ErrorContext::new(format!(
                "Unsupported device type for kernel execution: {}",
                devicearray.device
            ))
            .with_location(ErrorLocation::new(file!(), line!())),
        ))
    }
}

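/// A single kernel argument passed to [`DeviceMemoryManager::execute_kernel`].
///
/// A sketch of building a parameter map (marked `ignore`: `manager`, `on_gpu`,
/// and the `"scale"` kernel name are assumptions, not items defined by this
/// module):
///
/// ```ignore
/// let mut params = HashMap::new();
/// params.insert("factor".to_string(), KernelParam::F32(2.0));
/// params.insert("len".to_string(), KernelParam::U32(on_gpu.size() as u32));
/// manager.execute_kernel(&on_gpu, "scale", params)?;
/// ```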
#[derive(Debug, Clone)]
pub enum KernelParam {
    Buffer(DeviceBuffer<f32>),
    U32(u32),
    I32(i32),
    F32(f32),
    F64(f64),
}

#[derive(Clone)]
enum BufferLocation<T: GpuDataType> {
    Cpu(Arc<Array<T, IxDyn>>),
    Gpu(Arc<GpuBuffer<T>>),
}

impl<T> std::fmt::Debug for BufferLocation<T>
where
    T: GpuDataType + std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BufferLocation::Cpu(_) => write!(f, "Cpu(Array)"),
            BufferLocation::Gpu(_) => write!(f, "Gpu(GpuBuffer)"),
        }
    }
}

#[derive(Debug, Clone)]
pub struct DeviceBuffer<T: GpuDataType> {
    location: BufferLocation<T>,
}

impl<T: GpuDataType> DeviceBuffer<T> {
    fn new_cpu<D: Dimension>(array: Array<T, D>) -> Self {
        Self {
            location: BufferLocation::Cpu(Arc::new(array.into_dyn())),
        }
    }

    fn new_gpu(buffer: GpuBuffer<T>) -> Self {
        Self {
            location: BufferLocation::Gpu(Arc::new(buffer)),
        }
    }

    fn get_cpuarray(&self) -> Option<&Array<T, IxDyn>> {
        match self.location {
            BufferLocation::Cpu(ref array) => Some(array),
            _ => None,
        }
    }

    fn get_gpubuffer(&self) -> Option<&GpuBuffer<T>> {
        match self.location {
            BufferLocation::Gpu(ref buffer) => Some(buffer),
            _ => None,
        }
    }

    fn size(&self) -> usize {
        match self.location {
            BufferLocation::Cpu(ref array) => array.len(),
            BufferLocation::Gpu(ref buffer) => buffer.len(),
        }
    }
}

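/// An n-dimensional array whose data lives on a specific [`DeviceType`].
///
/// A CPU-only sketch of the accessors (no GPU required):
///
/// ```ignore
/// let arr = create_cpuarray(&ndarray::array![1.0_f32, 2.0, 3.0]);
/// assert!(arr.is_on_cpu());
/// assert_eq!(arr.size(), 3);
/// assert_eq!(arr.ndim(), 1);
/// assert!(arr.as_cpuarray().is_some());
/// ```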
#[derive(Debug, Clone)]
pub struct DeviceArray<T: GpuDataType, D: Dimension> {
    buffer: DeviceBuffer<T>,
    shape: D,
    device: DeviceType,
    phantom: PhantomData<T>,
}

impl<T: GpuDataType, D: Dimension> DeviceArray<T, D> {
    fn new_cpu<S: RawData<Elem = T> + crate::ndarray::Data>(array: ArrayBase<S, D>) -> Self {
        Self {
            buffer: DeviceBuffer::new_cpu(array.to_owned()),
            shape: array.raw_dim(),
            device: DeviceType::Cpu,
            phantom: PhantomData,
        }
    }

    pub fn device(&self) -> DeviceType {
        self.device
    }

    pub const fn shape(&self) -> &D {
        &self.shape
    }

    pub fn size(&self) -> usize {
        self.buffer.size()
    }

    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }

    pub fn is_on_cpu(&self) -> bool {
        self.device == DeviceType::Cpu
    }

    pub fn is_on_gpu(&self) -> bool {
        matches!(self.device, DeviceType::Gpu(_))
    }

    pub fn as_cpuarray(&self) -> Option<&Array<T, IxDyn>> {
        self.buffer.get_cpuarray()
    }

    pub fn as_gpubuffer(&self) -> Option<&GpuBuffer<T>> {
        self.buffer.get_gpubuffer()
    }
}

pub struct DeviceStream {
    #[allow(dead_code)]
    device: DeviceType,
    #[allow(dead_code)]
    handle: Arc<Mutex<Box<dyn std::any::Any + Send + Sync>>>,
}

impl DeviceStream {
    pub fn new(device: DeviceType) -> CoreResult<Self> {
        // Backend-specific stream creation is not implemented yet; the handle
        // is a placeholder.
        Ok(Self {
            device,
            handle: Arc::new(Mutex::new(Box::new(()))),
        })
    }

    /// Blocks until all work queued on this stream has finished.
    ///
    /// Currently a no-op because transfers and kernels run synchronously.
    pub fn synchronize(&self) {}
}

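/// A size-keyed pool of reusable device buffers.
///
/// A CPU-backed sketch (GPU allocation is not implemented by this pool yet):
///
/// ```ignore
/// let pool = DeviceMemoryPool::new(DeviceType::Cpu, 16 * 1024 * 1024);
/// let buffer: DeviceBuffer<f32> = pool.allocate(1024)?;
/// // Returning the buffer lets a later allocation of the same size reuse it.
/// pool.free(buffer);
/// ```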
pub struct DeviceMemoryPool {
    device: DeviceType,
    freebuffers: Mutex<HashMap<usize, Vec<Box<dyn std::any::Any + Send + Sync>>>>,
    max_poolsize: usize,
    current_poolsize: std::sync::atomic::AtomicUsize,
}

impl std::fmt::Debug for DeviceMemoryPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DeviceMemoryPool")
            .field("device", &self.device)
            .field("freebuffers", &"<freebuffers>")
            .field("max_poolsize", &self.max_poolsize)
            .field(
                "current_poolsize",
                &self
                    .current_poolsize
                    .load(std::sync::atomic::Ordering::Relaxed),
            )
            .finish()
    }
}

impl DeviceMemoryPool {
    pub fn new(device: DeviceType, max_poolsize: usize) -> Self {
        Self {
            device,
            freebuffers: Mutex::new(HashMap::new()),
            max_poolsize,
            current_poolsize: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    pub fn allocate<T: GpuDataType + num_traits::Zero>(
        &self,
        size: usize,
    ) -> CoreResult<DeviceBuffer<T>> {
        // Reuse a previously freed buffer of the same size if one is available.
        let mut freebuffers = self.freebuffers.lock().expect("pool mutex poisoned");
        if let Some(buffers) = freebuffers.get_mut(&size) {
            if let Some(buffer) = buffers.pop() {
                if let Ok(buffer) = buffer.downcast::<DeviceBuffer<T>>() {
                    return Ok(*buffer);
                }
            }
        }

        match self.device {
            DeviceType::Cpu => {
                let array = Array::<T, crate::ndarray::IxDyn>::zeros(IxDyn(&[size]));
                Ok(DeviceBuffer::new_cpu(array))
            }
            DeviceType::Gpu(_) => Err(CoreError::ImplementationError(
                ErrorContext::new("GPU memory allocation not implemented".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            )),
            DeviceType::Tpu => Err(CoreError::DeviceError(
                ErrorContext::new("TPU not supported".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            )),
        }
    }

    pub fn free<T: GpuDataType>(&self, buffer: DeviceBuffer<T>) {
        let size = buffer.size();
        let buffersize = size * std::mem::size_of::<T>();

        // Drop the buffer instead of pooling it once the pool is full.
        let current_size = self
            .current_poolsize
            .load(std::sync::atomic::Ordering::SeqCst);
        if current_size + buffersize > self.max_poolsize {
            return;
        }

        let mut freebuffers = self.freebuffers.lock().expect("pool mutex poisoned");
        freebuffers.entry(size).or_default().push(Box::new(buffer));

        self.current_poolsize
            .fetch_add(buffersize, std::sync::atomic::Ordering::SeqCst);
    }

    pub fn clear(&self) {
        let mut freebuffers = self.freebuffers.lock().expect("pool mutex poisoned");
        freebuffers.clear();
        self.current_poolsize
            .store(0, std::sync::atomic::Ordering::SeqCst);
    }
}

impl<T: GpuDataType, D: Dimension> DeviceArray<T, D> {
    pub fn map<F>(&self, f: F, manager: &DeviceMemoryManager) -> CoreResult<DeviceArray<T, D>>
    where
        F: Fn(T) -> T + Send + Sync,
        D: Clone,
    {
        if self.is_on_cpu() {
            if let Some(cpuarray) = self.as_cpuarray() {
                let mapped = cpuarray.map(|&x| f(x));
                return Ok(DeviceArray {
                    buffer: DeviceBuffer::new_cpu(mapped),
                    shape: self.shape.clone(),
                    device: DeviceType::Cpu,
                    phantom: PhantomData,
                });
            }
        }

        // Device arrays are mapped on the host for now: copy back, apply `f`,
        // then transfer the result to the original device.
        let hostarray = manager.transfer_to_host(self, None)?;
        let mapped = hostarray.map(|&x| f(x));
        manager.transfer_to_device(&mapped, self.device, None)
    }

    pub fn reduce<F>(&self, f: F, manager: &DeviceMemoryManager) -> CoreResult<T>
    where
        F: Fn(T, T) -> T + Send + Sync,
        T: Copy,
    {
        if self.is_on_cpu() {
            if let Some(cpuarray) = self.as_cpuarray() {
                if cpuarray.is_empty() {
                    return Err(CoreError::ValueError(
                        ErrorContext::new("Cannot reduce empty array".to_string())
                            .with_location(ErrorLocation::new(file!(), line!())),
                    ));
                }

                let first = cpuarray[0];
                let result = cpuarray.iter().skip(1).fold(first, |acc, &x| f(acc, x));
                return Ok(result);
            }
        }

        // Reductions on device arrays also fall back to the host.
        let hostarray = manager.transfer_to_host(self, None)?;
        if hostarray.is_empty() {
            return Err(CoreError::ValueError(
                ErrorContext::new("Cannot reduce empty array".to_string())
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        let first = *hostarray.iter().next().expect("array checked to be non-empty");
        let result = hostarray.iter().skip(1).fold(first, |acc, &x| f(acc, x));
        Ok(result)
    }
}

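/// Coordinates per-device memory managers and pools and routes transfers
/// between them.
///
/// A round-trip sketch (marked `ignore`: it silently falls back to the CPU
/// when no GPU backend is available, and assumes `?` can propagate `CoreError`):
///
/// ```ignore
/// let manager = CrossDeviceManager::new(256 * 1024 * 1024)?;
/// let host = ndarray::array![1.0_f32, 2.0, 3.0, 4.0];
///
/// let on_device = to_best_device(&host, &manager)?;
/// let round_trip = manager.to_host(&on_device, None)?;
/// assert_eq!(round_trip, host);
/// ```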
#[derive(Debug)]
pub struct CrossDeviceManager {
    memory_managers: HashMap<DeviceType, DeviceMemoryManager>,
    memory_pools: HashMap<DeviceType, DeviceMemoryPool>,
    active_transfers: Mutex<Vec<TransferEvent>>,
    #[allow(dead_code)]
    enable_caching: bool,
    #[allow(dead_code)]
    max_cache_size: usize,
}

impl CrossDeviceManager {
    pub fn new(max_cachesize: usize) -> CoreResult<Self> {
        let mut memory_managers = HashMap::new();
        let mut memory_pools = HashMap::new();

        let cpu_manager = DeviceMemoryManager::new(max_cachesize)?;
        memory_managers.insert(DeviceType::Cpu, cpu_manager);
        memory_pools.insert(
            DeviceType::Cpu,
            DeviceMemoryPool::new(DeviceType::Cpu, max_cachesize),
        );

        let gpu_backend = GpuBackend::preferred();
        if gpu_backend.is_available() {
            let gpu_device = DeviceType::Gpu(gpu_backend);
            let gpu_manager = DeviceMemoryManager::new(max_cachesize)?;
            memory_managers.insert(gpu_device, gpu_manager);
            memory_pools.insert(gpu_device, DeviceMemoryPool::new(gpu_device, max_cachesize));
        }

        Ok(Self {
            memory_managers,
            memory_pools,
            active_transfers: Mutex::new(Vec::new()),
            enable_caching: true,
            max_cache_size: max_cachesize,
        })
    }

    pub fn available_devices(&self) -> Vec<DeviceType> {
        self.memory_managers.keys().cloned().collect()
    }

    pub fn is_device_available(&self, device: DeviceType) -> bool {
        self.memory_managers.contains_key(&device)
    }

    pub fn to_device<T, S, D>(
        &self,
        array: &ArrayBase<S, D>,
        device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        S: RawData<Elem = T> + crate::ndarray::Data,
        D: Dimension,
    {
        if !self.is_device_available(device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        let manager = self
            .memory_managers
            .get(&device)
            .expect("device checked to be available");
        manager.transfer_to_device(array, device, options)
    }

    pub fn to_host<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        options: Option<TransferOptions>,
    ) -> CoreResult<Array<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        let manager = self
            .memory_managers
            .get(&devicearray.device)
            .ok_or_else(|| {
                CoreError::DeviceError(
                    ErrorContext::new(format!("Device {} is not available", devicearray.device))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;

        manager.transfer_to_host(devicearray, options)
    }

    pub fn transfer<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        target_device: DeviceType,
        options: Option<TransferOptions>,
    ) -> CoreResult<DeviceArray<T, D>>
    where
        T: GpuDataType,
        D: Dimension,
    {
        if !self.is_device_available(target_device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {target_device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        let manager = self
            .memory_managers
            .get(&devicearray.device)
            .ok_or_else(|| {
                CoreError::DeviceError(
                    ErrorContext::new(format!("Device {} is not available", devicearray.device))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;

        manager.transfer_between_devices(devicearray, target_device, options)
    }

    pub fn execute_kernel<T, D>(
        &self,
        devicearray: &DeviceArray<T, D>,
        kernel_name: &str,
        params: HashMap<String, KernelParam>,
    ) -> CoreResult<()>
    where
        T: GpuDataType,
        D: Dimension,
    {
        let manager = self
            .memory_managers
            .get(&devicearray.device)
            .ok_or_else(|| {
                CoreError::DeviceError(
                    ErrorContext::new(format!("Device {} is not available", devicearray.device))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;

        manager.execute_kernel(devicearray, kernel_name, params)
    }

    pub fn allocate<T: GpuDataType + num_traits::Zero>(
        &self,
        size: usize,
        device: DeviceType,
    ) -> CoreResult<DeviceBuffer<T>> {
        if !self.is_device_available(device) {
            return Err(CoreError::DeviceError(
                ErrorContext::new(format!("Device {device} is not available"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }

        let pool = self
            .memory_pools
            .get(&device)
            .expect("device checked to be available");
        pool.allocate(size)
    }

    pub fn free<T: GpuDataType>(&self, buffer: DeviceBuffer<T>, device: DeviceType) {
        if !self.is_device_available(device) {
            return;
        }

        let pool = self
            .memory_pools
            .get(&device)
            .expect("device checked to be available");
        pool.free(buffer);
    }

    pub fn clear(&self) {
        for manager in self.memory_managers.values() {
            manager.clear_cache();
        }

        for pool in self.memory_pools.values() {
            pool.clear();
        }

        let mut active_transfers = self
            .active_transfers
            .lock()
            .expect("transfer list mutex poisoned");
        active_transfers.clear();
    }

    pub fn synchronize(&self) {
        let mut active_transfers = self
            .active_transfers
            .lock()
            .expect("transfer list mutex poisoned");
        for event in active_transfers.drain(..) {
            event.wait();
        }
    }
}

#[allow(dead_code)]
pub fn create_cross_device_manager() -> CoreResult<CrossDeviceManager> {
    // Default to a 1 GiB transfer cache.
    CrossDeviceManager::new(1024 * 1024 * 1024)
}

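/// Extension trait that lets `ndarray` values be moved to a device directly.
///
/// A sketch using the CPU target, which always succeeds (GPU targets require
/// an available backend; the `ToHost` counterpart is shown for the return trip):
///
/// ```ignore
/// let manager = CrossDeviceManager::new(1024 * 1024)?;
/// let host = ndarray::array![1.0_f32, 2.0, 3.0];
///
/// let on_cpu = host.to_device(DeviceType::Cpu, &manager)?;
/// let back = on_cpu.to_host(&manager)?;
/// assert_eq!(back, host);
/// ```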
pub trait ToDevice<T, D>
where
    T: GpuDataType,
    D: Dimension,
{
    fn to_device(
        &self,
        device: DeviceType,
        manager: &CrossDeviceManager,
    ) -> CoreResult<DeviceArray<T, D>>;
}

impl<T, S, D> ToDevice<T, D> for ArrayBase<S, D>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    fn to_device(
        &self,
        device: DeviceType,
        manager: &CrossDeviceManager,
    ) -> CoreResult<DeviceArray<T, D>> {
        manager.to_device(self, device, None)
    }
}

pub trait ToHost<T, D>
where
    T: GpuDataType,
    D: Dimension,
{
    fn to_host(&self, manager: &CrossDeviceManager) -> CoreResult<Array<T, D>>;
}

impl<T, D> ToHost<T, D> for DeviceArray<T, D>
where
    T: GpuDataType,
    D: Dimension,
{
    fn to_host(&self, manager: &CrossDeviceManager) -> CoreResult<Array<T, D>> {
        manager.to_host(self, None)
    }
}

#[allow(dead_code)]
pub fn create_cpuarray<T, S, D>(array: &ArrayBase<S, D>) -> DeviceArray<T, D>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    DeviceArray::new_cpu(array.to_owned())
}

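/// Convenience constructor that places `array` on the first available GPU.
///
/// A sketch (marked `ignore`: it returns a `DeviceError` on machines without
/// a usable GPU backend):
///
/// ```ignore
/// let manager = create_cross_device_manager()?;
/// let on_gpu = create_gpuarray(&ndarray::array![1.0_f32, 2.0], &manager)?;
/// assert!(on_gpu.is_on_gpu());
/// ```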
#[allow(dead_code)]
pub fn create_gpuarray<T, S, D>(
    array: &ArrayBase<S, D>,
    manager: &CrossDeviceManager,
) -> CoreResult<DeviceArray<T, D>>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    for device in manager.available_devices() {
        if let DeviceType::Gpu(_) = device {
            return manager.to_device(array, device, None);
        }
    }

    Err(CoreError::DeviceError(
        ErrorContext::new("No GPU device available".to_string())
            .with_location(ErrorLocation::new(file!(), line!())),
    ))
}

#[allow(dead_code)]
pub fn to_best_device<T, S, D>(
    array: &ArrayBase<S, D>,
    manager: &CrossDeviceManager,
) -> CoreResult<DeviceArray<T, D>>
where
    T: GpuDataType,
    S: RawData<Elem = T> + crate::ndarray::Data,
    D: Dimension,
{
    // Prefer a GPU when one is registered with the manager.
    for device in manager.available_devices() {
        if let DeviceType::Gpu(_) = device {
            return manager.to_device(array, device, None);
        }
    }

    // Otherwise fall back to a CPU-resident array.
    Ok(create_cpuarray(array))
}