use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::Debug;

use crate::array_protocol::{ArrayFunction, ArrayProtocol, GPUArray, NotImplemented};
use crate::error::{CoreError, CoreResult, ErrorContext};
use ndarray::{Array, Dimension};

/// Supported GPU compute backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GPUBackend {
    /// NVIDIA CUDA.
    CUDA,
    /// AMD ROCm.
    ROCm,
    /// Apple Metal.
    Metal,
    /// WebGPU.
    WebGPU,
    /// OpenCL.
    OpenCL,
}

impl Default for GPUBackend {
    fn default() -> Self {
        Self::CUDA
    }
}

/// Configuration for GPU-backed arrays.
#[derive(Debug, Clone)]
pub struct GPUConfig {
    /// Which GPU backend to use.
    pub backend: GPUBackend,
    /// Index of the device to run on.
    pub device_id: usize,
    /// Whether operations may execute asynchronously.
    pub async_ops: bool,
    /// Whether to use mixed-precision arithmetic.
    pub mixed_precision: bool,
    /// Fraction of device memory the array may occupy (0.0..=1.0).
    pub memory_fraction: f32,
}

impl Default for GPUConfig {
    fn default() -> Self {
        Self {
            backend: GPUBackend::default(),
            device_id: 0,
            async_ops: true,
            mixed_precision: false,
            memory_fraction: 0.9,
        }
    }
}

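// Illustrative sketch: because the fields are public and `GPUConfig`
// implements `Default`, a configuration can be customized with struct
// update syntax (the values below are arbitrary):
//
//     let config = GPUConfig {
//         backend: GPUBackend::Metal,
//         memory_fraction: 0.5,
//         ..GPUConfig::default()
//     };
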
/// An ndarray whose contents are mirrored on a GPU device.
///
/// In the current implementation the data lives in host memory and GPU
/// placement is simulated; `host_data` is the source of truth.
pub struct GPUNdarray<T, D: Dimension>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T>,
    D: Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    /// The host-side copy of the data.
    host_data: Array<T, D>,
    /// GPU configuration for this array.
    config: GPUConfig,
    /// Whether the data is currently (notionally) resident on the GPU.
    on_gpu: bool,
    /// Unique identifier for this array.
    id: String,
}

impl<T, D> Debug for GPUNdarray<T, D>
where
    T: Debug + Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + Debug + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GPUNdarray")
            .field("config", &self.config)
            .field("on_gpu", &self.on_gpu)
            .field("id", &self.id)
            .field("shape", &self.host_data.shape())
            .finish()
    }
}

impl<T, D> GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    /// Creates a new GPU array from host data and a configuration.
    #[must_use]
    pub fn new(host_data: Array<T, D>, config: GPUConfig) -> Self {
        let id = format!("gpu_array_{}", uuid::Uuid::new_v4());
        let mut array = Self {
            host_data,
            config,
            on_gpu: false,
            id,
        };

        // A real implementation would copy the data to the device here;
        // the transfer is simulated by flipping the flag.
        array.on_gpu = true;

        array
    }

    /// Returns the shape of the array.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.host_data.shape()
    }

    /// Returns a reference to the host copy of the data.
    #[must_use]
    pub const fn host_data(&self) -> &Array<T, D> {
        &self.host_data
    }

    /// Returns a mutable reference to the host copy of the data.
    pub fn host_data_mut(&mut self) -> &mut Array<T, D> {
        &mut self.host_data
    }

    /// Returns the GPU configuration.
    #[must_use]
    pub const fn config(&self) -> &GPUConfig {
        &self.config
    }

    /// Runs a kernel against the array and returns its result.
    ///
    /// The kernel currently executes on the host copy of the data.
    pub fn execute_kernel<F, R>(&self, kernel: F) -> CoreResult<R>
    where
        F: FnOnce(&Array<T, D>) -> CoreResult<R>,
    {
        kernel(&self.host_data)
    }

    /// Synchronizes device data back to the host.
    ///
    /// A no-op here, since the host copy is authoritative.
    pub fn sync_to_host(&mut self) -> CoreResult<()> {
        Ok(())
    }

    /// Synchronizes host data to the device (simulated).
    pub fn sync_to_gpu(&mut self) -> CoreResult<()> {
        self.on_gpu = true;
        Ok(())
    }
}

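// Illustrative sketch of the API above: wrap host data, then run a host-side
// "kernel" closure against it. The closure body is hypothetical.
//
//     let gpu = GPUNdarray::new(ndarray::Array2::<f64>::ones((2, 3)), GPUConfig::default());
//     let total = gpu.execute_kernel(|data| Ok(data.sum())).unwrap();
//     assert_eq!(total, 6.0);
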
impl<T, D> ArrayProtocol for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if let Some(&ax) = axis {
                    // Reduce along the requested axis rather than ignoring it;
                    // out-of-range axes are reported as unsupported instead of
                    // panicking. The result is type-erased to a dynamic
                    // dimension so callers receive a uniform array type.
                    if ax >= self.host_data.ndim() {
                        return Err(NotImplemented);
                    }
                    let sum = self.host_data.sum_axis(ndarray::Axis(ax)).into_dyn();
                    Ok(Box::new(sum))
                } else {
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                }
            }
230 "scirs2::array_protocol::operations::mean" => {
231 let sum = self.host_data.sum();
233 let count = self.host_data.len();
234 #[allow(clippy::cast_precision_loss)]
235 let mean = sum / count as f64;
236
237 Ok(Box::new(mean))
238 }
239 "scirs2::array_protocol::operations::add" => {
240 if args.len() < 2 {
242 return Err(NotImplemented);
243 }
244
245 if let Some(other) = args[1].downcast_ref::<Self>() {
247 if self.shape() != other.shape() {
249 return Err(NotImplemented);
250 }
251
252 let Ok(result) = kernels::add(self, other) else {
254 return Err(NotImplemented);
255 };
256
257 return Ok(Box::new(result));
258 }
259
260 Err(NotImplemented)
263 }
264 "scirs2::array_protocol::operations::multiply" => {
265 if args.len() < 2 {
267 return Err(NotImplemented);
268 }
269
270 if let Some(other) = args[1].downcast_ref::<Self>() {
272 if self.shape() != other.shape() {
274 return Err(NotImplemented);
275 }
276
277 let Ok(result) = kernels::multiply(self, other) else {
279 return Err(NotImplemented);
280 };
281
282 return Ok(Box::new(result));
283 }
284
285 Err(NotImplemented)
288 }
289 "scirs2::array_protocol::operations::matmul" => {
290 if args.len() < 2 {
292 return Err(NotImplemented);
293 }
294
295 if TypeId::of::<D>() != TypeId::of::<ndarray::Ix2>() {
299 return Err(NotImplemented);
300 }
301
302 if let Some(other) = args[1].downcast_ref::<Self>() {
304 if TypeId::of::<T>() == TypeId::of::<f64>()
307 && TypeId::of::<D>() == TypeId::of::<ndarray::Ix2>()
308 {
309 let self_f64 = unsafe {
310 &*std::ptr::from_ref(self).cast::<GPUNdarray<f64, ndarray::Ix2>>()
311 };
312 let other_f64 = unsafe {
313 &*std::ptr::from_ref(other).cast::<GPUNdarray<f64, ndarray::Ix2>>()
314 };
315
316 match kernels::matmul(self_f64, other_f64) {
317 Ok(result) => {
318 return Ok(Box::new(result));
322 }
323 Err(_) => return Err(NotImplemented),
324 }
325 }
326 let result = Self::new(self.host_data.clone(), self.config.clone());
329 return Ok(Box::new(result));
330 }
331
332 Err(NotImplemented)
333 }
334 "scirs2::array_protocol::operations::transpose" => {
335 if TypeId::of::<D>() != TypeId::of::<ndarray::Ix2>() {
338 return Err(NotImplemented);
339 }
340
341 let transposed = self.host_data.t().to_owned();
344 let result = Self::new(transposed, self.config.clone());
345
346 Ok(Box::new(result))
347 }
348 "scirs2::array_protocol::operations::reshape" => {
349 if let Some(shape) = kwargs
351 .get("shape")
352 .and_then(|s| s.downcast_ref::<Vec<usize>>())
353 {
354 match self.host_data.clone().into_shape_with_order(shape.clone()) {
355 Ok(reshaped) => {
356 let result = GPUNdarray::new(reshaped, self.config.clone());
357 return Ok(Box::new(result));
358 }
359 Err(_) => return Err(NotImplemented),
360 }
361 }
362
363 Err(NotImplemented)
364 }
365 _ => Err(NotImplemented),
366 }
367 }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.host_data.shape()
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}

impl<T, D> GPUArray for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
        // Already (notionally) on the GPU, so a clone suffices.
        Ok(Box::new(self.clone()))
    }

    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
        // Wrap the host copy in a plain ndarray-backed implementation.
        let array = super::NdarrayWrapper::new(self.host_data.clone());

        Ok(Box::new(array) as Box<dyn ArrayProtocol>)
    }

    fn is_on_gpu(&self) -> bool {
        self.on_gpu
    }

    fn device_info(&self) -> HashMap<String, String> {
        let mut info = HashMap::new();
        info.insert(
            "backend".to_string(),
            format!("{backend:?}", backend = self.config.backend),
        );
        info.insert("device_id".to_string(), self.config.device_id.to_string());
        info.insert("on_gpu".to_string(), self.on_gpu.to_string());
        info.insert("id".to_string(), self.id.clone());
        info
    }
}

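// Illustrative sketch: moving data back to a CPU-side wrapper. The shape is
// preserved across the round trip; `NdarrayWrapper` comes from the parent
// module, as used in `to_cpu` above.
//
//     let gpu = GPUNdarray::new(ndarray::Array2::<f64>::zeros((2, 2)), GPUConfig::default());
//     let cpu = gpu.to_cpu().unwrap();
//     assert_eq!(cpu.shape(), &[2, 2]);
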
impl<T, D> Clone for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    fn clone(&self) -> Self {
        Self {
            host_data: self.host_data.clone(),
            config: self.config.clone(),
            on_gpu: self.on_gpu,
            id: self.id.clone(),
        }
    }
}

/// Builder for configuring and creating `GPUNdarray` instances.
pub struct GPUArrayBuilder {
    config: GPUConfig,
}

impl Default for GPUArrayBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl GPUArrayBuilder {
    /// Creates a builder with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: GPUConfig::default(),
        }
    }

    /// Sets the GPU backend.
    #[must_use]
    pub const fn backend(mut self, backend: GPUBackend) -> Self {
        self.config.backend = backend;
        self
    }

    /// Sets the device ID.
    #[must_use]
    pub const fn device_id(mut self, device_id: usize) -> Self {
        self.config.device_id = device_id;
        self
    }

    /// Enables or disables asynchronous operations.
    #[must_use]
    pub const fn async_ops(mut self, async_ops: bool) -> Self {
        self.config.async_ops = async_ops;
        self
    }

    /// Enables or disables mixed-precision arithmetic.
    #[must_use]
    pub const fn mixed_precision(mut self, mixed_precision: bool) -> Self {
        self.config.mixed_precision = mixed_precision;
        self
    }

    /// Sets the fraction of device memory the array may occupy.
    #[must_use]
    pub const fn memory_fraction(mut self, memory_fraction: f32) -> Self {
        self.config.memory_fraction = memory_fraction;
        self
    }

    /// Consumes the builder and wraps `host_data` in a `GPUNdarray`.
    #[must_use]
    pub fn build<T, D>(self, host_data: Array<T, D>) -> GPUNdarray<T, D>
    where
        T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
        D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
    {
        GPUNdarray::new(host_data, self.config)
    }
}

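// Illustrative sketch of the builder chain (see also the builder test below);
// the chosen values are arbitrary:
//
//     let gpu = GPUArrayBuilder::new()
//         .backend(GPUBackend::WebGPU)
//         .device_id(0)
//         .memory_fraction(0.5)
//         .build(ndarray::Array2::<f64>::ones((8, 8)));
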
/// Simulated GPU kernels operating on the host copies of the data.
pub mod kernels {
    use super::*;
    use ndarray::{Array, Dimension};

    /// Element-wise addition of two GPU arrays.
    ///
    /// # Errors
    ///
    /// Returns a `ShapeError` if the shapes do not match.
    pub fn add<T, D>(a: &GPUNdarray<T, D>, b: &GPUNdarray<T, D>) -> CoreResult<GPUNdarray<T, D>>
    where
        T: Clone
            + std::ops::Add<Output = T>
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>,
        D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
    {
        if a.shape() != b.shape() {
            return Err(CoreError::ShapeError(ErrorContext::new(format!(
                "Shape mismatch: {:?} vs {:?}",
                a.shape(),
                b.shape()
            ))));
        }

        // Performed on the host copies; a real backend would launch a kernel.
        let result_data = a.host_data().clone() + b.host_data().clone();

        Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
    }

    /// Element-wise multiplication of two GPU arrays.
    ///
    /// # Errors
    ///
    /// Returns a `ShapeError` if the shapes do not match.
    pub fn multiply<T, D>(
        a: &GPUNdarray<T, D>,
        b: &GPUNdarray<T, D>,
    ) -> CoreResult<GPUNdarray<T, D>>
    where
        T: Clone
            + std::ops::Mul<Output = T>
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>,
        D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
    {
        if a.shape() != b.shape() {
            return Err(CoreError::ShapeError(ErrorContext::new(format!(
                "Shape mismatch: {:?} vs {:?}",
                a.shape(),
                b.shape()
            ))));
        }

        let result_data = a.host_data().clone() * b.host_data().clone();

        Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
    }

    /// Matrix multiplication of two 2-D GPU arrays.
    ///
    /// Computed with a naive host-side triple loop; a real backend would
    /// dispatch this to the device.
    ///
    /// # Errors
    ///
    /// Returns a `ShapeError` if the inner dimensions do not agree.
    pub fn matmul<T>(
        a: &GPUNdarray<T, ndarray::Ix2>,
        b: &GPUNdarray<T, ndarray::Ix2>,
    ) -> CoreResult<GPUNdarray<T, ndarray::Ix2>>
    where
        T: Clone
            + std::ops::Mul<Output = T>
            + std::ops::Add<Output = T>
            + Default
            + Send
            + Sync
            + 'static
            + num_traits::Zero
            + std::ops::Div<f64, Output = T>,
    {
        let a_shape = a.shape();
        let b_shape = b.shape();

        if a_shape.len() != 2 || b_shape.len() != 2 || a_shape[1] != b_shape[0] {
            return Err(CoreError::ShapeError(ErrorContext::new(format!(
                "Incompatible shapes for matmul: {a_shape:?} vs {b_shape:?}"
            ))));
        }

        let m = a_shape[0];
        let k = a_shape[1];
        let p = b_shape[1];

        // Naive O(m * k * p) host-side multiplication, replacing the earlier
        // zero-filled placeholder so the kernel returns the actual product.
        let mut result_data = Array::zeros((m, p));
        for i in 0..m {
            for j in 0..p {
                let mut acc = T::zero();
                for l in 0..k {
                    acc = acc + a.host_data()[[i, l]].clone() * b.host_data()[[l, j]].clone();
                }
                result_data[[i, j]] = acc;
            }
        }

        Ok(GPUNdarray::<T, ndarray::Ix2>::new(
            result_data,
            a.config.clone(),
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    #[test]
    fn test_gpu_ndarray_creation() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();

        let gpu_array = GPUNdarray::new(array.clone(), config);

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").unwrap(), "CUDA");
        assert_eq!(info.get("device_id").unwrap(), "0");
        assert_eq!(info.get("on_gpu").unwrap(), "true");
    }

    #[test]
    fn test_gpu_array_builder() {
        let array = Array2::<f64>::ones((10, 5));

        let gpu_array = GPUArrayBuilder::new()
            .backend(GPUBackend::CUDA)
            .device_id(1)
            .async_ops(true)
            .mixed_precision(true)
            .memory_fraction(0.8)
            .build(array.clone());

        assert_eq!(gpu_array.config.backend, GPUBackend::CUDA);
        assert_eq!(gpu_array.config.device_id, 1);
        assert!(gpu_array.config.async_ops);
        assert!(gpu_array.config.mixed_precision);
        // Compare floats with a tolerance rather than exact equality.
        assert!((gpu_array.config.memory_fraction - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_gpu_array_kernels() {
        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
        let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

        let result = kernels::add(&gpu_a, &gpu_b).unwrap();
        let expected = a + b;
        assert_eq!(result.host_data(), &expected);

        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
        let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

        let result = kernels::multiply(&gpu_a, &gpu_b).unwrap();
        let expected = a * b;
        assert_eq!(result.host_data(), &expected);
    }
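
    // Added illustrative test (assumes the naive host-side product sketched
    // in `kernels::matmul` above): checks a small 2x2 multiplication against
    // values computed by hand.
    #[test]
    fn test_gpu_array_matmul_kernel() {
        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_a = GPUNdarray::new(a, GPUConfig::default());
        let gpu_b = GPUNdarray::new(b, GPUConfig::default());

        let result = kernels::matmul(&gpu_a, &gpu_b).unwrap();
        // [[1*5 + 2*7, 1*6 + 2*8], [3*5 + 4*7, 3*6 + 4*8]]
        let expected = arr2(&[[19.0, 22.0], [43.0, 50.0]]);
        assert_eq!(result.host_data(), &expected);
    }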
}