1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use crate::array_protocol::{ArrayFunction, ArrayProtocol, GPUArray, NotImplemented};
23use crate::error::{CoreError, CoreResult, ErrorContext};
24use ndarray::{Array, Dimension};
25
/// GPU compute backends a `GPUNdarray` can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GPUBackend {
    /// NVIDIA CUDA.
    CUDA,

    /// AMD ROCm.
    ROCm,

    /// Apple Metal.
    Metal,

    /// WebGPU (portable / browser).
    WebGPU,

    /// OpenCL (cross-vendor).
    OpenCL,
}
44
45impl Default for GPUBackend {
46 fn default() -> Self {
47 Self::CUDA
48 }
49}
50
/// Configuration for creating a `GPUNdarray`.
#[derive(Debug, Clone)]
pub struct GPUConfig {
    /// Which GPU backend to use.
    pub backend: GPUBackend,

    /// Index of the device to run on (0-based).
    pub device_id: usize,

    /// Whether operations may run asynchronously.
    pub async_ops: bool,

    /// Whether to enable mixed-precision computation.
    pub mixed_precision: bool,

    /// Fraction of device memory (0.0–1.0) the allocator may use.
    pub memory_fraction: f32,
}
69
70impl Default for GPUConfig {
71 fn default() -> Self {
72 Self {
73 backend: GPUBackend::default(),
74 device_id: 0,
75 async_ops: true,
76 mixed_precision: false,
77 memory_fraction: 0.9,
78 }
79 }
80}
81
/// An ndarray whose contents are (conceptually) mirrored on a GPU device.
///
/// NOTE(review): in the current implementation the data only ever lives in
/// `host_data`; GPU residency is simulated via the `on_gpu` flag. The
/// `Div<f64>` bound exists to support the `mean` array-protocol operation.
pub struct GPUNdarray<T, D: Dimension>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T>,
    D: Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    // CPU-side copy of the array data.
    host_data: Array<T, D>,

    // Backend/device configuration this array was created with.
    config: GPUConfig,

    // Whether the data is currently considered resident on the device.
    on_gpu: bool,

    // Unique identifier ("uuid" + UUIDv4) for this array instance.
    id: String,
}
101
102impl<T, D> Debug for GPUNdarray<T, D>
103where
104 T: Debug + Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
105 D: Dimension + Debug + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
106{
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("GPUNdarray")
109 .field("config", &self.config)
110 .field("on_gpu", &self.on_gpu)
111 .field("id", &self.id)
112 .field("shape", &self.host_data.shape())
113 .finish()
114 }
115}
116
117impl<T, D> GPUNdarray<T, D>
118where
119 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
120 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
121{
122 #[must_use]
124 pub fn new(hostdata: Array<T, D>, config: GPUConfig) -> Self {
125 let uuid = uuid::Uuid::new_v4();
126 let id = format!("uuid{uuid}");
127 let mut array = Self {
128 host_data: hostdata,
129 config,
130 on_gpu: false,
131 id,
132 };
133
134 array.on_gpu = true;
137
138 array
139 }
140
141 #[must_use]
143 pub fn shape(&self) -> &[usize] {
144 self.host_data.shape()
145 }
146
147 #[must_use]
149 pub const fn host_data(&self) -> &Array<T, D> {
150 &self.host_data
151 }
152
153 pub fn host_data_mut(&mut self) -> &mut Array<T, D> {
155 &mut self.host_data
157 }
158
159 #[must_use]
161 pub const fn config(&self) -> &GPUConfig {
162 &self.config
163 }
164
165 pub fn execute_kernel<F, R>(&self, kernel: F) -> CoreResult<R>
170 where
171 F: FnOnce(&Array<T, D>) -> CoreResult<R>,
172 {
173 kernel(&self.host_data)
176 }
177
178 pub fn sync_to_host(&mut self) -> CoreResult<()> {
183 Ok(())
186 }
187
188 pub fn sync_to_gpu(&mut self) -> CoreResult<()> {
193 self.on_gpu = true;
196 Ok(())
197 }
198}
199
impl<T, D> ArrayProtocol for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
{
    /// Dispatches a named array-protocol operation against this GPU array.
    ///
    /// Supported: sum, mean, add, multiply, matmul, transpose, reshape.
    /// Any other function name returns `NotImplemented` so another
    /// implementation in the protocol chain can handle it.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                // Optional axis argument, passed via kwargs.
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                // NOTE(review): the axis is parsed but ignored — both branches
                // compute the full (scalar) sum. Axis-wise reduction is not
                // implemented yet.
                if let Some(&_ax) = axis {
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                } else {
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Scalar mean = sum / element count; this is why T: Div<f64>
                // is required on the struct's bounds.
                let sum = self.host_data.sum();
                let count = self.host_data.len();
                #[allow(clippy::cast_precision_loss)]
                let mean = sum / count as f64;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // Expects args = [self-like, other].
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // Only same-typed GPU arrays are supported as the RHS.
                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    // Kernel failures degrade to NotImplemented so the
                    // protocol can fall back to another implementation.
                    let Ok(result) = kernels::add(self, other) else {
                        return Err(NotImplemented);
                    };

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Expects args = [self-like, other]; element-wise product.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let Ok(result) = kernels::multiply(self, other) else {
                        return Err(NotImplemented);
                    };

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // matmul is only defined for 2-D arrays.
                if TypeId::of::<D>() != TypeId::of::<ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    // Specialized f64/Ix2 path that can call the typed kernel.
                    if TypeId::of::<T>() == TypeId::of::<f64>()
                        && TypeId::of::<D>() == TypeId::of::<ndarray::Ix2>()
                    {
                        // SAFETY: the TypeId checks above guarantee T == f64
                        // and D == Ix2, so reinterpreting the generic
                        // parameters is a no-op on the underlying value.
                        let self_f64 = unsafe {
                            &*std::ptr::from_ref(self).cast::<GPUNdarray<f64, ndarray::Ix2>>()
                        };
                        // SAFETY: same TypeId guarantees as above apply to
                        // `other`, which has the identical concrete type.
                        let other_f64 = unsafe {
                            &*std::ptr::from_ref(other).cast::<GPUNdarray<f64, ndarray::Ix2>>()
                        };

                        match kernels::matmul(self_f64, other_f64) {
                            Ok(result) => {
                                return Ok(Box::new(result));
                            }
                            Err(_) => return Err(NotImplemented),
                        }
                    }
                    // NOTE(review): placeholder fallback for non-f64 element
                    // types — returns a copy of `self`, NOT a real matrix
                    // product. Callers relying on this path get wrong values.
                    let result = Self::new(self.host_data.clone(), self.config.clone());
                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // Only 2-D transpose is supported.
                if TypeId::of::<D>() != TypeId::of::<ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                // `.t()` is a view; `to_owned` materializes the transposed data.
                let transposed = self.host_data.t().to_owned();
                let result = Self::new(transposed, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // Target shape arrives as a Vec<usize> kwarg.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    match self.host_data.clone().into_shape_with_order(shape.clone()) {
                        Ok(reshaped) => {
                            let result = GPUNdarray::new(reshaped, self.config.clone());
                            return Ok(Box::new(result));
                        }
                        // Element-count mismatch etc. — let others try.
                        Err(_) => return Err(NotImplemented),
                    }
                }

                Err(NotImplemented)
            }
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.host_data.shape()
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
382
383impl<T, D> GPUArray for GPUNdarray<T, D>
384where
385 T: Clone + Send + Sync + 'static + num_traits::Zero,
386 T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
387 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
388{
389 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
392 Ok(Box::new(self.clone()))
394 }
395
396 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
399 let array = super::NdarrayWrapper::new(self.host_data.clone());
401
402 Ok(Box::new(array) as Box<dyn ArrayProtocol>)
403 }
404
405 fn is_on_gpu(&self) -> bool {
406 self.on_gpu
407 }
408
409 fn device_info(&self) -> HashMap<String, String> {
410 let mut info = HashMap::new();
411 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
412 info.insert("device_id".to_string(), self.config.device_id.to_string());
413 info.insert("on_gpu".to_string(), self.on_gpu.to_string());
414 info.insert("id".to_string(), self.id.clone());
415 info
416 }
417}
418
419impl<T, D> Clone for GPUNdarray<T, D>
420where
421 T: Clone + Send + Sync + 'static + num_traits::Zero,
422 T: std::ops::Div<f64, Output = T>,
423 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
424{
425 fn clone(&self) -> Self {
426 Self {
427 host_data: self.host_data.clone(),
428 config: self.config.clone(),
429 on_gpu: self.on_gpu,
430 id: self.id.clone(),
431 }
432 }
433}
434
/// Fluent builder for constructing a `GPUNdarray` with a custom `GPUConfig`.
pub struct GPUArrayBuilder {
    // Accumulated configuration, consumed by `build`.
    config: GPUConfig,
}
439
440impl Default for GPUArrayBuilder {
441 fn default() -> Self {
442 Self::new()
443 }
444}
445
446impl GPUArrayBuilder {
447 #[must_use]
449 pub fn new() -> Self {
450 Self {
451 config: GPUConfig::default(),
452 }
453 }
454
455 #[must_use]
457 pub const fn backend(mut self, backend: GPUBackend) -> Self {
458 self.config.backend = backend;
459 self
460 }
461
462 #[must_use]
464 pub const fn device_id(mut self, device_id: usize) -> Self {
465 self.config.device_id = device_id;
466 self
467 }
468
469 #[must_use]
471 pub const fn async_ops(mut self, asyncops: bool) -> Self {
472 self.config.async_ops = asyncops;
473 self
474 }
475
476 #[must_use]
478 pub const fn mixed_precision(mut self, mixedprecision: bool) -> Self {
479 self.config.mixed_precision = mixedprecision;
480 self
481 }
482
483 #[must_use]
485 pub const fn memory_fraction(mut self, memoryfraction: f32) -> Self {
486 self.config.memory_fraction = memoryfraction;
487 self
488 }
489
490 #[must_use]
492 pub fn build<T, D>(self, hostdata: Array<T, D>) -> GPUNdarray<T, D>
493 where
494 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
495 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
496 {
497 GPUNdarray::new(hostdata, self.config)
498 }
499}
500
501pub mod kernels {
503 use super::*;
504 use ndarray::{Array, Dimension};
505
506 pub fn add<T, D>(a: &GPUNdarray<T, D>, b: &GPUNdarray<T, D>) -> CoreResult<GPUNdarray<T, D>>
511 where
512 T: Clone
513 + std::ops::Add<Output = T>
514 + Send
515 + Sync
516 + 'static
517 + num_traits::Zero
518 + std::ops::Div<f64, Output = T>,
519 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
520 {
521 if a.shape() != b.shape() {
526 return Err(CoreError::ShapeError(ErrorContext::new(format!(
527 "Shape mismatch: {:?} vs {:?}",
528 a.shape(),
529 b.shape()
530 ))));
531 }
532
533 let result_data = a.host_data().clone() + b.host_data().clone();
535
536 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
538 }
539
540 pub fn multiply<T, D>(
545 a: &GPUNdarray<T, D>,
546 b: &GPUNdarray<T, D>,
547 ) -> CoreResult<GPUNdarray<T, D>>
548 where
549 T: Clone
550 + std::ops::Mul<Output = T>
551 + Send
552 + Sync
553 + 'static
554 + num_traits::Zero
555 + std::ops::Div<f64, Output = T>,
556 D: Dimension + Clone + Send + Sync + 'static + ndarray::RemoveAxis,
557 {
558 if a.shape() != b.shape() {
563 return Err(CoreError::ShapeError(ErrorContext::new(format!(
564 "Shape mismatch: {:?} vs {:?}",
565 a.shape(),
566 b.shape()
567 ))));
568 }
569
570 let result_data = a.host_data().clone() * b.host_data().clone();
572
573 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
575 }
576
577 pub fn matmul<T>(
582 a: &GPUNdarray<T, ndarray::Ix2>,
583 b: &GPUNdarray<T, ndarray::Ix2>,
584 ) -> CoreResult<GPUNdarray<T, ndarray::Ix2>>
585 where
586 T: Clone
587 + std::ops::Mul<Output = T>
588 + std::ops::Add<Output = T>
589 + Default
590 + Send
591 + Sync
592 + 'static
593 + num_traits::Zero
594 + std::ops::Div<f64, Output = T>,
595 {
596 let ashape = a.shape();
601 let bshape = b.shape();
602
603 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
604 return Err(CoreError::ShapeError(ErrorContext::new(format!(
605 "Incompatible shapes for matmul: {ashape:?} vs {bshape:?}"
606 ))));
607 }
608
609 let m = ashape[0];
612 let p = bshape[1];
613
614 let result_data = Array::default((m, p));
616
617 Ok(GPUNdarray::<T, ndarray::Ix2>::new(
619 result_data,
620 a.config.clone(),
621 ))
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr2, Array2};

    /// Construction marks the array device-resident and exposes metadata.
    #[test]
    fn test_gpu_ndarray_creation() {
        let host = Array2::<f64>::ones((10, 5));
        let gpu_array = GPUNdarray::new(host, GPUConfig::default());

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").unwrap(), "CUDA");
        assert_eq!(info.get("device_id").unwrap(), "0");
        assert_eq!(info.get("on_gpu").unwrap(), "true");
    }

    /// Every builder setter is reflected in the resulting array's config.
    #[test]
    fn test_gpu_array_builder() {
        let gpu_array = GPUArrayBuilder::new()
            .backend(GPUBackend::CUDA)
            .device_id(1)
            .async_ops(true)
            .mixed_precision(true)
            .memory_fraction(0.8)
            .build(Array2::<f64>::ones((10, 5)));

        let cfg = gpu_array.config();
        assert_eq!(cfg.backend, GPUBackend::CUDA);
        assert_eq!(cfg.device_id, 1);
        assert!(cfg.async_ops);
        assert!(cfg.mixed_precision);
        assert_eq!(cfg.memory_fraction, 0.8);
    }

    /// Element-wise add and multiply kernels agree with host-side ndarray.
    #[test]
    fn test_gpu_array_kernels() {
        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
        let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

        let total = kernels::add(&gpu_a, &gpu_b).unwrap();
        assert_eq!(total.host_data(), &(a.clone() + b.clone()));

        let product = kernels::multiply(&gpu_a, &gpu_b).unwrap();
        assert_eq!(product.host_data(), &(a * b));
    }
}