1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use crate::array_protocol::{ArrayFunction, ArrayProtocol, GPUArray, NotImplemented};
23use crate::error::{CoreError, CoreResult, ErrorContext};
24use ::ndarray::{Array, Dimension};
25
/// Supported GPU compute backends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GPUBackend {
    /// NVIDIA CUDA.
    CUDA,

    /// AMD ROCm.
    ROCm,

    /// Apple Metal.
    Metal,

    /// WebGPU (portable, browser-capable API).
    WebGPU,

    /// OpenCL.
    OpenCL,
}
44
impl Default for GPUBackend {
    /// CUDA is the default backend.
    fn default() -> Self {
        Self::CUDA
    }
}
50
/// Configuration for creating GPU-backed arrays.
#[derive(Debug, Clone)]
pub struct GPUConfig {
    /// Which GPU backend to target.
    pub backend: GPUBackend,

    /// Index of the device to use (0 = first device).
    pub device_id: usize,

    /// Whether operations may execute asynchronously.
    pub async_ops: bool,

    /// Whether to allow mixed-precision computation.
    pub mixed_precision: bool,

    /// Fraction of device memory this array system may use (0.0..=1.0).
    pub memory_fraction: f32,
}
69
70impl Default for GPUConfig {
71 fn default() -> Self {
72 Self {
73 backend: GPUBackend::default(),
74 device_id: 0,
75 async_ops: true,
76 mixed_precision: false,
77 memory_fraction: 0.9,
78 }
79 }
80}
81
/// An ndarray wrapper that tracks GPU residency and configuration.
///
/// NOTE(review): this appears to be a CPU-backed stand-in — no device
/// allocation is visible anywhere in this file; the data always lives in
/// `host_data`. Confirm against the real backend integration.
pub struct GPUNdarray<T, D: Dimension>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T>,
    D: Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    // Host-side copy of the array data.
    host_data: Array<T, D>,

    // Configuration the array was created with.
    config: GPUConfig,

    // Whether the array is (notionally) resident on the GPU.
    on_gpu: bool,

    // Unique identifier, surfaced via `device_info` and `Debug`.
    id: String,
}
101
impl<T, D> Debug for GPUNdarray<T, D>
where
    T: Debug + Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
    D: Dimension + Debug + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    // Prints only the shape of `host_data`, not its contents, so output
    // stays compact for large arrays.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GPUNdarray")
            .field("config", &self.config)
            .field("on_gpu", &self.on_gpu)
            .field("id", &self.id)
            .field("shape", &self.host_data.shape())
            .finish()
    }
}
116
117impl<T, D> GPUNdarray<T, D>
118where
119 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
120 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
121{
122 #[must_use]
124 pub fn new(hostdata: Array<T, D>, config: GPUConfig) -> Self {
125 let uuid = uuid::Uuid::new_v4();
126 let id = format!("uuid{uuid}");
127 let mut array = Self {
128 host_data: hostdata,
129 config,
130 on_gpu: false,
131 id,
132 };
133
134 array.on_gpu = true;
137
138 array
139 }
140
141 #[must_use]
143 pub fn shape(&self) -> &[usize] {
144 self.host_data.shape()
145 }
146
147 #[must_use]
149 pub const fn host_data(&self) -> &Array<T, D> {
150 &self.host_data
151 }
152
153 pub fn host_data_mut(&mut self) -> &mut Array<T, D> {
155 &mut self.host_data
157 }
158
159 #[must_use]
161 pub const fn config(&self) -> &GPUConfig {
162 &self.config
163 }
164
165 pub fn execute_kernel<F, R>(&self, kernel: F) -> CoreResult<R>
170 where
171 F: FnOnce(&Array<T, D>) -> CoreResult<R>,
172 {
173 kernel(&self.host_data)
176 }
177
178 pub fn sync_to_host(&mut self) -> CoreResult<()> {
183 Ok(())
186 }
187
188 pub fn sync_to_gpu(&mut self) -> CoreResult<()> {
193 self.on_gpu = true;
196 Ok(())
197 }
198}
199
impl<T, D> ArrayProtocol for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    /// Dispatches a named array-protocol operation against this array.
    ///
    /// Handles `sum`, `mean`, `add`, `multiply`, `matmul`, `transpose`, and
    /// `reshape` under the `scirs2::array_protocol::operations::` namespace;
    /// anything else yields `NotImplemented`.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                // NOTE(review): both branches compute the full sum — the
                // `axis` kwarg is parsed but effectively ignored. Confirm
                // whether axis-wise reduction was intended here.
                if let Some(&_ax) = axis {
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                } else {
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Mean = total / element count; the division is `T / f64`
                // per the `Div<f64, Output = T>` bound.
                // NOTE(review): an empty array divides by zero — for float
                // element types this presumably yields NaN; confirm intended.
                let sum = self.host_data.sum();
                let count = self.host_data.len();
                #[allow(clippy::cast_precision_loss)]
                let mean = sum / count as f64;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // args[1] must be another GPUNdarray of the same T/D and shape.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    // Kernel errors are mapped to NotImplemented so the
                    // protocol can fall back to another implementation.
                    let Ok(result) = kernels::add(self, other) else {
                        return Err(NotImplemented);
                    };

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Element-wise multiply; same argument contract as `add`.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let Ok(result) = kernels::multiply(self, other) else {
                        return Err(NotImplemented);
                    };

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // Matmul is only defined for 2-D arrays.
                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    // Typed fast path: when T == f64 and D == Ix2, call the
                    // concrete f64 kernel.
                    if TypeId::of::<T>() == TypeId::of::<f64>()
                        && TypeId::of::<D>() == TypeId::of::<crate::ndarray::Ix2>()
                    {
                        // SAFETY: the TypeId checks above guarantee T == f64
                        // and D == Ix2, so `Self` at runtime is exactly
                        // `GPUNdarray<f64, Ix2>`; the cast reinterprets a
                        // reference to the same concrete type.
                        let self_f64 = unsafe {
                            &*std::ptr::from_ref(self)
                                .cast::<GPUNdarray<f64, crate::ndarray::Ix2>>()
                        };
                        // SAFETY: same reasoning as for `self_f64`.
                        let other_f64 = unsafe {
                            &*std::ptr::from_ref(other)
                                .cast::<GPUNdarray<f64, crate::ndarray::Ix2>>()
                        };

                        match kernels::matmul(self_f64, other_f64) {
                            Ok(result) => {
                                return Ok(Box::new(result));
                            }
                            Err(_) => return Err(NotImplemented),
                        }
                    }
                    // NOTE(review): fallback for non-f64 element types just
                    // copies `self` — it is NOT a real matrix product.
                    let result = Self::new(self.host_data.clone(), self.config.clone());
                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // Transpose is only supported for 2-D arrays here.
                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                // `t()` is a view; `to_owned()` materializes the transpose.
                let transposed = self.host_data.t().to_owned();
                let result = Self::new(transposed, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // Expects a `shape: Vec<usize>` kwarg; any reshape failure
                // (e.g. element-count mismatch) maps to NotImplemented.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    match self.host_data.clone().into_shape_with_order(shape.clone()) {
                        Ok(reshaped) => {
                            let result = GPUNdarray::new(reshaped, self.config.clone());
                            return Ok(Box::new(result));
                        }
                        Err(_) => return Err(NotImplemented),
                    }
                }

                Err(NotImplemented)
            }
            _ => Err(NotImplemented),
        }
    }

    /// Upcast to `&dyn Any` for downcasting by callers.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the shape of the underlying host data.
    fn shape(&self) -> &[usize] {
        self.host_data.shape()
    }

    /// Clones this array into a boxed trait object.
    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
384
385impl<T, D> GPUArray for GPUNdarray<T, D>
386where
387 T: Clone + Send + Sync + 'static + num_traits::Zero,
388 T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
389 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
390{
391 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
394 Ok(Box::new(self.clone()))
396 }
397
398 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
401 let array = super::NdarrayWrapper::new(self.host_data.clone());
403
404 Ok(Box::new(array) as Box<dyn ArrayProtocol>)
405 }
406
407 fn is_on_gpu(&self) -> bool {
408 self.on_gpu
409 }
410
411 fn device_info(&self) -> HashMap<String, String> {
412 let mut info = HashMap::new();
413 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
414 info.insert("device_id".to_string(), self.config.device_id.to_string());
415 info.insert("on_gpu".to_string(), self.on_gpu.to_string());
416 info.insert("id".to_string(), self.id.clone());
417 info
418 }
419}
420
421impl<T, D> Clone for GPUNdarray<T, D>
422where
423 T: Clone + Send + Sync + 'static + num_traits::Zero,
424 T: std::ops::Div<f64, Output = T>,
425 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
426{
427 fn clone(&self) -> Self {
428 Self {
429 host_data: self.host_data.clone(),
430 config: self.config.clone(),
431 on_gpu: self.on_gpu,
432 id: self.id.clone(),
433 }
434 }
435}
436
/// Fluent builder for constructing `GPUNdarray` values.
pub struct GPUArrayBuilder {
    // Accumulated configuration, applied in `build`.
    config: GPUConfig,
}
441
442impl Default for GPUArrayBuilder {
443 fn default() -> Self {
444 Self::new()
445 }
446}
447
448impl GPUArrayBuilder {
449 #[must_use]
451 pub fn new() -> Self {
452 Self {
453 config: GPUConfig::default(),
454 }
455 }
456
457 #[must_use]
459 pub const fn backend(mut self, backend: GPUBackend) -> Self {
460 self.config.backend = backend;
461 self
462 }
463
464 #[must_use]
466 pub const fn device_id(mut self, device_id: usize) -> Self {
467 self.config.device_id = device_id;
468 self
469 }
470
471 #[must_use]
473 pub const fn async_ops(mut self, asyncops: bool) -> Self {
474 self.config.async_ops = asyncops;
475 self
476 }
477
478 #[must_use]
480 pub const fn mixed_precision(mut self, mixedprecision: bool) -> Self {
481 self.config.mixed_precision = mixedprecision;
482 self
483 }
484
485 #[must_use]
487 pub const fn memory_fraction(mut self, memoryfraction: f32) -> Self {
488 self.config.memory_fraction = memoryfraction;
489 self
490 }
491
492 #[must_use]
494 pub fn build<T, D>(self, hostdata: Array<T, D>) -> GPUNdarray<T, D>
495 where
496 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
497 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
498 {
499 GPUNdarray::new(hostdata, self.config)
500 }
501}
502
503pub mod kernels {
505 use super::*;
506 use ::ndarray::{Array, Dimension};
507
508 pub fn add<T, D>(a: &GPUNdarray<T, D>, b: &GPUNdarray<T, D>) -> CoreResult<GPUNdarray<T, D>>
513 where
514 T: Clone
515 + std::ops::Add<Output = T>
516 + Send
517 + Sync
518 + 'static
519 + num_traits::Zero
520 + std::ops::Div<f64, Output = T>,
521 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
522 {
523 if a.shape() != b.shape() {
528 return Err(CoreError::ShapeError(ErrorContext::new(format!(
529 "Shape mismatch: {:?} vs {:?}",
530 a.shape(),
531 b.shape()
532 ))));
533 }
534
535 let result_data = a.host_data().clone() + b.host_data().clone();
537
538 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
540 }
541
542 pub fn multiply<T, D>(
547 a: &GPUNdarray<T, D>,
548 b: &GPUNdarray<T, D>,
549 ) -> CoreResult<GPUNdarray<T, D>>
550 where
551 T: Clone
552 + std::ops::Mul<Output = T>
553 + Send
554 + Sync
555 + 'static
556 + num_traits::Zero
557 + std::ops::Div<f64, Output = T>,
558 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
559 {
560 if a.shape() != b.shape() {
565 return Err(CoreError::ShapeError(ErrorContext::new(format!(
566 "Shape mismatch: {:?} vs {:?}",
567 a.shape(),
568 b.shape()
569 ))));
570 }
571
572 let result_data = a.host_data().clone() * b.host_data().clone();
574
575 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
577 }
578
579 pub fn matmul<T>(
584 a: &GPUNdarray<T, crate::ndarray::Ix2>,
585 b: &GPUNdarray<T, crate::ndarray::Ix2>,
586 ) -> CoreResult<GPUNdarray<T, crate::ndarray::Ix2>>
587 where
588 T: Clone
589 + std::ops::Mul<Output = T>
590 + std::ops::Add<Output = T>
591 + Default
592 + Send
593 + Sync
594 + 'static
595 + num_traits::Zero
596 + std::ops::Div<f64, Output = T>,
597 {
598 let ashape = a.shape();
603 let bshape = b.shape();
604
605 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
606 return Err(CoreError::ShapeError(ErrorContext::new(format!(
607 "Incompatible shapes for matmul: {ashape:?} vs {bshape:?}"
608 ))));
609 }
610
611 let m = ashape[0];
614 let p = bshape[1];
615
616 let result_data = Array::default((m, p));
618
619 Ok(GPUNdarray::<T, crate::ndarray::Ix2>::new(
621 result_data,
622 a.config.clone(),
623 ))
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    // Creation should mark the array as on-GPU and expose config via device_info.
    #[test]
    fn test_gpu_ndarray_creation() {
        let array = Array2::<f64>::ones((10, 5));
        let config = GPUConfig::default();

        let gpu_array = GPUNdarray::new(array.clone(), config);

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        // device_info stringifies the backend via Debug, hence "CUDA".
        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").unwrap(), "CUDA");
        assert_eq!(info.get("device_id").unwrap(), "0");
        assert_eq!(info.get("on_gpu").unwrap(), "true");
    }

    // Every builder setter should land in the resulting array's config.
    #[test]
    fn test_gpu_array_builder() {
        let array = Array2::<f64>::ones((10, 5));

        let gpu_array = GPUArrayBuilder::new()
            .backend(GPUBackend::CUDA)
            .device_id(1)
            .async_ops(true)
            .mixed_precision(true)
            .memory_fraction(0.8)
            .build(array.clone());

        assert_eq!(gpu_array.config.backend, GPUBackend::CUDA);
        assert_eq!(gpu_array.config.device_id, 1);
        assert!(gpu_array.config.async_ops);
        assert!(gpu_array.config.mixed_precision);
        // Exact f32 comparison is safe here: 0.8 round-trips unchanged.
        assert_eq!(gpu_array.config.memory_fraction, 0.8);
    }

    // add/multiply kernels should match ndarray's element-wise operators.
    #[test]
    fn test_gpu_array_kernels() {
        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
        let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

        let result = kernels::add(&gpu_a, &gpu_b).unwrap();
        let expected = a + b;
        assert_eq!(result.host_data(), &expected);

        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
        let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

        let result = kernels::multiply(&gpu_a, &gpu_b).unwrap();
        let expected = a * b;
        assert_eq!(result.host_data(), &expected);
    }
}
694}