1use std::any::{Any, TypeId};
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16use crate::array_protocol::{ArrayFunction, ArrayProtocol, GPUArray, NotImplemented};
17use crate::error::{CoreError, CoreResult, ErrorContext};
18use ::ndarray::{Array, Dimension};
19
/// Backends that a `GPUConfig` can select.
///
/// `GPUBackend::default()` yields [`GPUBackend::CUDA`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GPUBackend {
    /// CUDA backend; this is the default variant.
    #[default]
    CUDA,

    /// ROCm backend.
    ROCm,

    /// Metal backend.
    Metal,

    /// WebGPU backend.
    WebGPU,

    /// OpenCL backend.
    OpenCL,
}
44
45#[derive(Debug, Clone)]
47pub struct GPUConfig {
48 pub backend: GPUBackend,
50
51 pub device_id: usize,
53
54 pub async_ops: bool,
56
57 pub mixed_precision: bool,
59
60 pub memory_fraction: f32,
62}
63
64impl Default for GPUConfig {
65 fn default() -> Self {
66 Self {
67 backend: GPUBackend::default(),
68 device_id: 0,
69 async_ops: true,
70 mixed_precision: false,
71 memory_fraction: 0.9,
72 }
73 }
74}
75
76pub struct GPUNdarray<T, D: Dimension>
78where
79 T: Clone + Send + Sync + 'static + num_traits::Zero,
80 T: std::ops::Div<f64, Output = T>,
81 D: Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
82{
83 host_data: Array<T, D>,
85
86 config: GPUConfig,
88
89 on_gpu: bool,
91
92 id: String,
94}
95
96impl<T, D> Debug for GPUNdarray<T, D>
97where
98 T: Debug + Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
99 D: Dimension + Debug + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
100{
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("GPUNdarray")
103 .field("config", &self.config)
104 .field("on_gpu", &self.on_gpu)
105 .field("id", &self.id)
106 .field("shape", &self.host_data.shape())
107 .finish()
108 }
109}
110
111impl<T, D> GPUNdarray<T, D>
112where
113 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
114 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
115{
116 #[must_use]
118 pub fn new(hostdata: Array<T, D>, config: GPUConfig) -> Self {
119 let uuid = uuid::Uuid::new_v4();
120 let id = format!("uuid{uuid}");
121 let mut array = Self {
122 host_data: hostdata,
123 config,
124 on_gpu: false,
125 id,
126 };
127
128 array.on_gpu = true;
131
132 array
133 }
134
135 #[must_use]
137 pub fn shape(&self) -> &[usize] {
138 self.host_data.shape()
139 }
140
141 #[must_use]
143 pub const fn host_data(&self) -> &Array<T, D> {
144 &self.host_data
145 }
146
147 pub fn host_data_mut(&mut self) -> &mut Array<T, D> {
149 &mut self.host_data
151 }
152
153 #[must_use]
155 pub const fn config(&self) -> &GPUConfig {
156 &self.config
157 }
158
159 pub fn execute_kernel<F, R>(&self, kernel: F) -> CoreResult<R>
164 where
165 F: FnOnce(&Array<T, D>) -> CoreResult<R>,
166 {
167 kernel(&self.host_data)
170 }
171
172 pub fn sync_to_host(&mut self) -> CoreResult<()> {
177 Ok(())
180 }
181
182 pub fn sync_to_gpu(&mut self) -> CoreResult<()> {
187 self.on_gpu = true;
190 Ok(())
191 }
192}
193
194impl<T, D> ArrayProtocol for GPUNdarray<T, D>
195where
196 T: Clone + Send + Sync + 'static + num_traits::Zero,
197 T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
198 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
199{
200 fn array_function(
201 &self,
202 func: &ArrayFunction,
203 _types: &[TypeId],
204 args: &[Box<dyn Any>],
205 kwargs: &HashMap<String, Box<dyn Any>>,
206 ) -> Result<Box<dyn Any>, NotImplemented> {
207 match func.name {
208 "scirs2::array_protocol::operations::sum" => {
209 let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());
212
213 if let Some(&_ax) = axis {
214 let sum = self.host_data.sum();
218 Ok(Box::new(sum))
219 } else {
220 let sum = self.host_data.sum();
222 Ok(Box::new(sum))
223 }
224 }
225 "scirs2::array_protocol::operations::mean" => {
226 let sum = self.host_data.sum();
228 let count = self.host_data.len();
229 #[allow(clippy::cast_precision_loss)]
230 let mean = sum / count as f64;
231
232 Ok(Box::new(mean))
233 }
234 "scirs2::array_protocol::operations::add" => {
235 if args.len() < 2 {
237 return Err(NotImplemented);
238 }
239
240 if let Some(other) = args[1].downcast_ref::<Self>() {
242 if self.shape() != other.shape() {
244 return Err(NotImplemented);
245 }
246
247 let Ok(result) = kernels::add(self, other) else {
249 return Err(NotImplemented);
250 };
251
252 return Ok(Box::new(result));
253 }
254
255 Err(NotImplemented)
258 }
259 "scirs2::array_protocol::operations::multiply" => {
260 if args.len() < 2 {
262 return Err(NotImplemented);
263 }
264
265 if let Some(other) = args[1].downcast_ref::<Self>() {
267 if self.shape() != other.shape() {
269 return Err(NotImplemented);
270 }
271
272 let Ok(result) = kernels::multiply(self, other) else {
274 return Err(NotImplemented);
275 };
276
277 return Ok(Box::new(result));
278 }
279
280 Err(NotImplemented)
283 }
284 "scirs2::array_protocol::operations::matmul" => {
285 if args.len() < 2 {
287 return Err(NotImplemented);
288 }
289
290 if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
294 return Err(NotImplemented);
295 }
296
297 if let Some(other) = args[1].downcast_ref::<Self>() {
299 if TypeId::of::<T>() == TypeId::of::<f64>()
302 && TypeId::of::<D>() == TypeId::of::<crate::ndarray::Ix2>()
303 {
304 let self_f64 = unsafe {
305 &*std::ptr::from_ref(self)
306 .cast::<GPUNdarray<f64, crate::ndarray::Ix2>>()
307 };
308 let other_f64 = unsafe {
309 &*std::ptr::from_ref(other)
310 .cast::<GPUNdarray<f64, crate::ndarray::Ix2>>()
311 };
312
313 match kernels::matmul(self_f64, other_f64) {
314 Ok(result) => {
315 return Ok(Box::new(result));
319 }
320 Err(_) => return Err(NotImplemented),
321 }
322 }
323 return Err(NotImplemented);
329 }
330
331 Err(NotImplemented)
332 }
333 "scirs2::array_protocol::operations::transpose" => {
334 if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
337 return Err(NotImplemented);
338 }
339
340 let transposed = self.host_data.t().to_owned();
343 let result = Self::new(transposed, self.config.clone());
344
345 Ok(Box::new(result))
346 }
347 "scirs2::array_protocol::operations::reshape" => {
348 if let Some(shape) = kwargs
350 .get("shape")
351 .and_then(|s| s.downcast_ref::<Vec<usize>>())
352 {
353 match self.host_data.clone().into_shape_with_order(shape.clone()) {
354 Ok(reshaped) => {
355 let result = GPUNdarray::new(reshaped, self.config.clone());
356 return Ok(Box::new(result));
357 }
358 Err(_) => return Err(NotImplemented),
359 }
360 }
361
362 Err(NotImplemented)
363 }
364 _ => Err(NotImplemented),
365 }
366 }
367
368 fn as_any(&self) -> &dyn Any {
369 self
370 }
371
372 fn shape(&self) -> &[usize] {
373 self.host_data.shape()
374 }
375
376 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
377 Box::new(self.clone())
378 }
379}
380
381impl<T, D> GPUArray for GPUNdarray<T, D>
382where
383 T: Clone + Send + Sync + 'static + num_traits::Zero,
384 T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
385 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
386{
387 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
390 Ok(Box::new(self.clone()))
392 }
393
394 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
397 let array = super::NdarrayWrapper::new(self.host_data.clone());
399
400 Ok(Box::new(array) as Box<dyn ArrayProtocol>)
401 }
402
403 fn is_on_gpu(&self) -> bool {
404 self.on_gpu
405 }
406
407 fn device_info(&self) -> HashMap<String, String> {
408 let mut info = HashMap::new();
409 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
410 info.insert("device_id".to_string(), self.config.device_id.to_string());
411 info.insert("on_gpu".to_string(), self.on_gpu.to_string());
412 info.insert("id".to_string(), self.id.clone());
413 info
414 }
415}
416
417impl<T, D> Clone for GPUNdarray<T, D>
418where
419 T: Clone + Send + Sync + 'static + num_traits::Zero,
420 T: std::ops::Div<f64, Output = T>,
421 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
422{
423 fn clone(&self) -> Self {
424 Self {
425 host_data: self.host_data.clone(),
426 config: self.config.clone(),
427 on_gpu: self.on_gpu,
428 id: self.id.clone(),
429 }
430 }
431}
432
433pub struct GPUArrayBuilder {
435 config: GPUConfig,
436}
437
438impl Default for GPUArrayBuilder {
439 fn default() -> Self {
440 Self::new()
441 }
442}
443
444impl GPUArrayBuilder {
445 #[must_use]
447 pub fn new() -> Self {
448 Self {
449 config: GPUConfig::default(),
450 }
451 }
452
453 #[must_use]
455 pub const fn backend(mut self, backend: GPUBackend) -> Self {
456 self.config.backend = backend;
457 self
458 }
459
460 #[must_use]
462 pub const fn device_id(mut self, device_id: usize) -> Self {
463 self.config.device_id = device_id;
464 self
465 }
466
467 #[must_use]
469 pub const fn async_ops(mut self, asyncops: bool) -> Self {
470 self.config.async_ops = asyncops;
471 self
472 }
473
474 #[must_use]
476 pub const fn mixed_precision(mut self, mixedprecision: bool) -> Self {
477 self.config.mixed_precision = mixedprecision;
478 self
479 }
480
481 #[must_use]
483 pub const fn memory_fraction(mut self, memoryfraction: f32) -> Self {
484 self.config.memory_fraction = memoryfraction;
485 self
486 }
487
488 #[must_use]
490 pub fn build<T, D>(self, hostdata: Array<T, D>) -> GPUNdarray<T, D>
491 where
492 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
493 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
494 {
495 GPUNdarray::new(hostdata, self.config)
496 }
497}
498
499pub mod kernels {
501 use super::*;
502 use ::ndarray::{Array, Dimension};
503
504 pub fn add<T, D>(a: &GPUNdarray<T, D>, b: &GPUNdarray<T, D>) -> CoreResult<GPUNdarray<T, D>>
509 where
510 T: Clone
511 + std::ops::Add<Output = T>
512 + Send
513 + Sync
514 + 'static
515 + num_traits::Zero
516 + std::ops::Div<f64, Output = T>,
517 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
518 {
519 if a.shape() != b.shape() {
524 return Err(CoreError::ShapeError(ErrorContext::new(format!(
525 "Shape mismatch: {:?} vs {:?}",
526 a.shape(),
527 b.shape()
528 ))));
529 }
530
531 let result_data = a.host_data().clone() + b.host_data().clone();
533
534 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
536 }
537
538 pub fn multiply<T, D>(
543 a: &GPUNdarray<T, D>,
544 b: &GPUNdarray<T, D>,
545 ) -> CoreResult<GPUNdarray<T, D>>
546 where
547 T: Clone
548 + std::ops::Mul<Output = T>
549 + Send
550 + Sync
551 + 'static
552 + num_traits::Zero
553 + std::ops::Div<f64, Output = T>,
554 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
555 {
556 if a.shape() != b.shape() {
561 return Err(CoreError::ShapeError(ErrorContext::new(format!(
562 "Shape mismatch: {:?} vs {:?}",
563 a.shape(),
564 b.shape()
565 ))));
566 }
567
568 let result_data = a.host_data().clone() * b.host_data().clone();
570
571 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
573 }
574
575 pub fn matmul<T>(
580 a: &GPUNdarray<T, crate::ndarray::Ix2>,
581 b: &GPUNdarray<T, crate::ndarray::Ix2>,
582 ) -> CoreResult<GPUNdarray<T, crate::ndarray::Ix2>>
583 where
584 T: Clone
585 + std::ops::Mul<Output = T>
586 + std::ops::Add<Output = T>
587 + Default
588 + Send
589 + Sync
590 + 'static
591 + num_traits::Zero
592 + std::ops::Div<f64, Output = T>,
593 {
594 let ashape = a.shape();
599 let bshape = b.shape();
600
601 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
602 return Err(CoreError::ShapeError(ErrorContext::new(format!(
603 "Incompatible shapes for matmul: {ashape:?} vs {bshape:?}"
604 ))));
605 }
606
607 let m = ashape[0];
610 let p = bshape[1];
611
612 let result_data = Array::default((m, p));
614
615 Ok(GPUNdarray::<T, crate::ndarray::Ix2>::new(
617 result_data,
618 a.config.clone(),
619 ))
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// A freshly built array reports its shape, is flagged on-GPU and
    /// exposes the default configuration through `device_info`.
    #[test]
    fn test_gpu_ndarray_creation() {
        let gpu_array = GPUNdarray::new(Array2::<f64>::ones((10, 5)), GPUConfig::default());

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");
        assert_eq!(info.get("device_id").expect("Operation failed"), "0");
        assert_eq!(info.get("on_gpu").expect("Operation failed"), "true");
    }

    /// Every builder setter is reflected in the built array's configuration.
    #[test]
    fn test_gpu_array_builder() {
        let builder = GPUArrayBuilder::new()
            .backend(GPUBackend::CUDA)
            .device_id(1)
            .async_ops(true)
            .mixed_precision(true)
            .memory_fraction(0.8);
        let gpu_array = builder.build(Array2::<f64>::ones((10, 5)));

        let config = &gpu_array.config;
        assert_eq!(config.backend, GPUBackend::CUDA);
        assert_eq!(config.device_id, 1);
        assert!(config.async_ops);
        assert!(config.mixed_precision);
        assert_eq!(config.memory_fraction, 0.8);
    }

    /// The `add` and `multiply` kernels agree with plain `ndarray` arithmetic.
    #[test]
    fn test_gpu_array_kernels() {
        // Element-wise addition.
        {
            let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
            let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
            let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
            let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

            let result = kernels::add(&gpu_a, &gpu_b).expect("Operation failed");
            assert_eq!(result.host_data(), &(a + b));
        }

        // Element-wise multiplication.
        {
            let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
            let b = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
            let gpu_a = GPUNdarray::new(a.clone(), GPUConfig::default());
            let gpu_b = GPUNdarray::new(b.clone(), GPUConfig::default());

            let result = kernels::multiply(&gpu_a, &gpu_b).expect("Operation failed");
            assert_eq!(result.host_data(), &(a * b));
        }
    }
}