1use std::any::{Any, TypeId};
13use std::collections::HashMap;
14use std::fmt::Debug;
15
16use crate::array_protocol::{ArrayFunction, ArrayProtocol, GPUArray, NotImplemented};
17use crate::error::{CoreError, CoreResult, ErrorContext};
18use ::ndarray::{Array, Dimension};
19
/// GPU compute backend that a [`GPUNdarray`] targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GPUBackend {
    /// NVIDIA CUDA.
    CUDA,

    /// AMD ROCm.
    ROCm,

    /// Apple Metal.
    Metal,

    /// WebGPU (browser / portable GPU API).
    WebGPU,

    /// OpenCL.
    OpenCL,
}
38
impl Default for GPUBackend {
    /// CUDA is the default backend.
    fn default() -> Self {
        Self::CUDA
    }
}
44
/// Configuration describing how a [`GPUNdarray`] is placed and operated on a device.
#[derive(Debug, Clone)]
pub struct GPUConfig {
    /// Backend used for kernel execution.
    pub backend: GPUBackend,

    /// Index of the device to use.
    pub device_id: usize,

    /// Whether operations may run asynchronously.
    pub async_ops: bool,

    /// Whether mixed-precision execution is enabled.
    pub mixed_precision: bool,

    /// Fraction of device memory the array may use.
    // NOTE(review): presumably expected to lie in 0.0..=1.0 — not validated anywhere here; confirm.
    pub memory_fraction: f32,
}
63
64impl Default for GPUConfig {
65 fn default() -> Self {
66 Self {
67 backend: GPUBackend::default(),
68 device_id: 0,
69 async_ops: true,
70 mixed_precision: false,
71 memory_fraction: 0.9,
72 }
73 }
74}
75
/// An n-dimensional array logically resident on a GPU device.
///
/// In this implementation GPU residency is simulated: the element data is
/// always kept in host memory (`host_data`) and `on_gpu` is only a flag.
// NOTE(review): the trait bounds here duplicate those on the impl blocks;
// bounds on struct definitions are usually unnecessary.
pub struct GPUNdarray<T, D: Dimension>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T>,
    D: Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    /// Host-side copy of the array contents (the single source of truth).
    host_data: Array<T, D>,

    /// Device/backend configuration for this array.
    config: GPUConfig,

    /// Whether the array is (logically) resident on the GPU.
    on_gpu: bool,

    /// Unique identifier assigned at construction time.
    id: String,
}
95
96impl<T, D> Debug for GPUNdarray<T, D>
97where
98 T: Debug + Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
99 D: Dimension + Debug + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
100{
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("GPUNdarray")
103 .field("config", &self.config)
104 .field("on_gpu", &self.on_gpu)
105 .field("id", &self.id)
106 .field("shape", &self.host_data.shape())
107 .finish()
108 }
109}
110
111impl<T, D> GPUNdarray<T, D>
112where
113 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
114 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
115{
116 #[must_use]
118 pub fn new(hostdata: Array<T, D>, config: GPUConfig) -> Self {
119 let uuid = uuid::Uuid::new_v4();
120 let id = format!("uuid{uuid}");
121 let mut array = Self {
122 host_data: hostdata,
123 config,
124 on_gpu: false,
125 id,
126 };
127
128 array.on_gpu = true;
131
132 array
133 }
134
135 #[must_use]
137 pub fn shape(&self) -> &[usize] {
138 self.host_data.shape()
139 }
140
141 #[must_use]
143 pub const fn host_data(&self) -> &Array<T, D> {
144 &self.host_data
145 }
146
147 pub fn host_data_mut(&mut self) -> &mut Array<T, D> {
149 &mut self.host_data
151 }
152
153 #[must_use]
155 pub const fn config(&self) -> &GPUConfig {
156 &self.config
157 }
158
159 pub fn execute_kernel<F, R>(&self, kernel: F) -> CoreResult<R>
164 where
165 F: FnOnce(&Array<T, D>) -> CoreResult<R>,
166 {
167 kernel(&self.host_data)
170 }
171
172 pub fn sync_to_host(&mut self) -> CoreResult<()> {
177 Ok(())
180 }
181
182 pub fn sync_to_gpu(&mut self) -> CoreResult<()> {
187 self.on_gpu = true;
190 Ok(())
191 }
192}
193
impl<T, D> ArrayProtocol for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    /// Dispatches a protocol-level array function by fully-qualified name.
    ///
    /// Supported operations: `sum`, `mean`, `add`, `multiply`, `matmul`,
    /// `transpose`, and `reshape`. Any other name, argument-type mismatch,
    /// or shape mismatch yields `Err(NotImplemented)`. All computation is
    /// performed on the host copy of the data.
    fn array_function(
        &self,
        func: &ArrayFunction,
        _types: &[TypeId],
        args: &[Box<dyn Any>],
        kwargs: &HashMap<String, Box<dyn Any>>,
    ) -> Result<Box<dyn Any>, NotImplemented> {
        match func.name {
            "scirs2::array_protocol::operations::sum" => {
                // Optional `axis` kwarg (a usize).
                let axis = kwargs.get("axis").and_then(|a| a.downcast_ref::<usize>());

                if let Some(&_ax) = axis {
                    // NOTE(review): the axis value is ignored — this branch
                    // returns the total sum exactly like the axis-less case.
                    // An axis-wise sum (e.g. `sum_axis`) appears intended here.
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                } else {
                    // Full reduction over all elements.
                    let sum = self.host_data.sum();
                    Ok(Box::new(sum))
                }
            }
            "scirs2::array_protocol::operations::mean" => {
                // Mean = total sum divided by element count, using the
                // `T: Div<f64, Output = T>` bound for the division.
                let sum = self.host_data.sum();
                let count = self.host_data.len();
                #[allow(clippy::cast_precision_loss)]
                let mean = sum / count as f64;

                Ok(Box::new(mean))
            }
            "scirs2::array_protocol::operations::add" => {
                // Expects the second positional argument to be the RHS array.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    // Only same-shape (no broadcasting) addition is supported.
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let Ok(result) = kernels::add(self, other) else {
                        return Err(NotImplemented);
                    };

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::multiply" => {
                // Element-wise multiply; mirrors the `add` branch above.
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    if self.shape() != other.shape() {
                        return Err(NotImplemented);
                    }

                    let Ok(result) = kernels::multiply(self, other) else {
                        return Err(NotImplemented);
                    };

                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::matmul" => {
                if args.len() < 2 {
                    return Err(NotImplemented);
                }

                // Matmul is only defined for 2-D arrays.
                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                if let Some(other) = args[1].downcast_ref::<Self>() {
                    // Fast path: concrete f64 / Ix2, delegated to the kernel.
                    if TypeId::of::<T>() == TypeId::of::<f64>()
                        && TypeId::of::<D>() == TypeId::of::<crate::ndarray::Ix2>()
                    {
                        // SAFETY: the TypeId checks above guarantee that T is
                        // f64 and D is Ix2, so this cast merely reinterprets
                        // the reference at its actual concrete type.
                        let self_f64 = unsafe {
                            &*std::ptr::from_ref(self)
                                .cast::<GPUNdarray<f64, crate::ndarray::Ix2>>()
                        };
                        // SAFETY: same reasoning as for `self_f64` above.
                        let other_f64 = unsafe {
                            &*std::ptr::from_ref(other)
                                .cast::<GPUNdarray<f64, crate::ndarray::Ix2>>()
                        };

                        match kernels::matmul(self_f64, other_f64) {
                            Ok(result) => {
                                return Ok(Box::new(result));
                            }
                            Err(_) => return Err(NotImplemented),
                        }
                    }
                    // NOTE(review): for non-f64 element types this falls back
                    // to returning a copy of `self` rather than a product —
                    // presumably a placeholder; confirm intended behavior.
                    let result = Self::new(self.host_data.clone(), self.config.clone());
                    return Ok(Box::new(result));
                }

                Err(NotImplemented)
            }
            "scirs2::array_protocol::operations::transpose" => {
                // Transpose is only supported for 2-D arrays.
                if TypeId::of::<D>() != TypeId::of::<crate::ndarray::Ix2>() {
                    return Err(NotImplemented);
                }

                // Materialize the transposed view into a new owned array.
                let transposed = self.host_data.t().to_owned();
                let result = Self::new(transposed, self.config.clone());

                Ok(Box::new(result))
            }
            "scirs2::array_protocol::operations::reshape" => {
                // Expects a `shape` kwarg of type Vec<usize>.
                if let Some(shape) = kwargs
                    .get("shape")
                    .and_then(|s| s.downcast_ref::<Vec<usize>>())
                {
                    match self.host_data.clone().into_shape_with_order(shape.clone()) {
                        Ok(reshaped) => {
                            let result = GPUNdarray::new(reshaped, self.config.clone());
                            return Ok(Box::new(result));
                        }
                        // Incompatible element count → not implemented.
                        Err(_) => return Err(NotImplemented),
                    }
                }

                Err(NotImplemented)
            }
            // Unknown operation name.
            _ => Err(NotImplemented),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn shape(&self) -> &[usize] {
        self.host_data.shape()
    }

    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
        Box::new(self.clone())
    }
}
378
379impl<T, D> GPUArray for GPUNdarray<T, D>
380where
381 T: Clone + Send + Sync + 'static + num_traits::Zero,
382 T: std::ops::Div<f64, Output = T> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
383 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
384{
385 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
388 Ok(Box::new(self.clone()))
390 }
391
392 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
395 let array = super::NdarrayWrapper::new(self.host_data.clone());
397
398 Ok(Box::new(array) as Box<dyn ArrayProtocol>)
399 }
400
401 fn is_on_gpu(&self) -> bool {
402 self.on_gpu
403 }
404
405 fn device_info(&self) -> HashMap<String, String> {
406 let mut info = HashMap::new();
407 info.insert("backend".to_string(), format!("{:?}", self.config.backend));
408 info.insert("device_id".to_string(), self.config.device_id.to_string());
409 info.insert("on_gpu".to_string(), self.on_gpu.to_string());
410 info.insert("id".to_string(), self.id.clone());
411 info
412 }
413}
414
// NOTE(review): a manual impl is used instead of `#[derive(Clone)]` —
// presumably to keep the bounds explicit; the derive would require the same
// `T: Clone`/`D: Clone` bounds already present here. Confirm before deriving.
impl<T, D> Clone for GPUNdarray<T, D>
where
    T: Clone + Send + Sync + 'static + num_traits::Zero,
    T: std::ops::Div<f64, Output = T>,
    D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
{
    /// Deep-clones the host data and config; the clone keeps the same `id`
    /// and `on_gpu` flag as the source array.
    fn clone(&self) -> Self {
        Self {
            host_data: self.host_data.clone(),
            config: self.config.clone(),
            on_gpu: self.on_gpu,
            id: self.id.clone(),
        }
    }
}
430
/// Fluent builder for constructing a [`GPUNdarray`] with a custom [`GPUConfig`].
pub struct GPUArrayBuilder {
    /// Configuration accumulated by the setter methods.
    config: GPUConfig,
}
435
impl Default for GPUArrayBuilder {
    /// Equivalent to [`GPUArrayBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
441
442impl GPUArrayBuilder {
443 #[must_use]
445 pub fn new() -> Self {
446 Self {
447 config: GPUConfig::default(),
448 }
449 }
450
451 #[must_use]
453 pub const fn backend(mut self, backend: GPUBackend) -> Self {
454 self.config.backend = backend;
455 self
456 }
457
458 #[must_use]
460 pub const fn device_id(mut self, device_id: usize) -> Self {
461 self.config.device_id = device_id;
462 self
463 }
464
465 #[must_use]
467 pub const fn async_ops(mut self, asyncops: bool) -> Self {
468 self.config.async_ops = asyncops;
469 self
470 }
471
472 #[must_use]
474 pub const fn mixed_precision(mut self, mixedprecision: bool) -> Self {
475 self.config.mixed_precision = mixedprecision;
476 self
477 }
478
479 #[must_use]
481 pub const fn memory_fraction(mut self, memoryfraction: f32) -> Self {
482 self.config.memory_fraction = memoryfraction;
483 self
484 }
485
486 #[must_use]
488 pub fn build<T, D>(self, hostdata: Array<T, D>) -> GPUNdarray<T, D>
489 where
490 T: Clone + Send + Sync + 'static + num_traits::Zero + std::ops::Div<f64, Output = T>,
491 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
492 {
493 GPUNdarray::new(hostdata, self.config)
494 }
495}
496
497pub mod kernels {
499 use super::*;
500 use ::ndarray::{Array, Dimension};
501
502 pub fn add<T, D>(a: &GPUNdarray<T, D>, b: &GPUNdarray<T, D>) -> CoreResult<GPUNdarray<T, D>>
507 where
508 T: Clone
509 + std::ops::Add<Output = T>
510 + Send
511 + Sync
512 + 'static
513 + num_traits::Zero
514 + std::ops::Div<f64, Output = T>,
515 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
516 {
517 if a.shape() != b.shape() {
522 return Err(CoreError::ShapeError(ErrorContext::new(format!(
523 "Shape mismatch: {:?} vs {:?}",
524 a.shape(),
525 b.shape()
526 ))));
527 }
528
529 let result_data = a.host_data().clone() + b.host_data().clone();
531
532 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
534 }
535
536 pub fn multiply<T, D>(
541 a: &GPUNdarray<T, D>,
542 b: &GPUNdarray<T, D>,
543 ) -> CoreResult<GPUNdarray<T, D>>
544 where
545 T: Clone
546 + std::ops::Mul<Output = T>
547 + Send
548 + Sync
549 + 'static
550 + num_traits::Zero
551 + std::ops::Div<f64, Output = T>,
552 D: Dimension + Clone + Send + Sync + 'static + crate::ndarray::RemoveAxis,
553 {
554 if a.shape() != b.shape() {
559 return Err(CoreError::ShapeError(ErrorContext::new(format!(
560 "Shape mismatch: {:?} vs {:?}",
561 a.shape(),
562 b.shape()
563 ))));
564 }
565
566 let result_data = a.host_data().clone() * b.host_data().clone();
568
569 Ok(GPUNdarray::<T, D>::new(result_data, a.config.clone()))
571 }
572
573 pub fn matmul<T>(
578 a: &GPUNdarray<T, crate::ndarray::Ix2>,
579 b: &GPUNdarray<T, crate::ndarray::Ix2>,
580 ) -> CoreResult<GPUNdarray<T, crate::ndarray::Ix2>>
581 where
582 T: Clone
583 + std::ops::Mul<Output = T>
584 + std::ops::Add<Output = T>
585 + Default
586 + Send
587 + Sync
588 + 'static
589 + num_traits::Zero
590 + std::ops::Div<f64, Output = T>,
591 {
592 let ashape = a.shape();
597 let bshape = b.shape();
598
599 if ashape.len() != 2 || bshape.len() != 2 || ashape[1] != bshape[0] {
600 return Err(CoreError::ShapeError(ErrorContext::new(format!(
601 "Incompatible shapes for matmul: {ashape:?} vs {bshape:?}"
602 ))));
603 }
604
605 let m = ashape[0];
608 let p = bshape[1];
609
610 let result_data = Array::default((m, p));
612
613 Ok(GPUNdarray::<T, crate::ndarray::Ix2>::new(
615 result_data,
616 a.config.clone(),
617 ))
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{arr2, Array2};

    /// Constructing a GPU array records shape, residency, and device metadata.
    #[test]
    fn test_gpu_ndarray_creation() {
        let host = Array2::<f64>::ones((10, 5));
        let gpu_array = GPUNdarray::new(host, GPUConfig::default());

        assert_eq!(gpu_array.shape(), &[10, 5]);
        assert!(gpu_array.is_on_gpu());

        let info = gpu_array.device_info();
        assert_eq!(info.get("backend").expect("Operation failed"), "CUDA");
        assert_eq!(info.get("device_id").expect("Operation failed"), "0");
        assert_eq!(info.get("on_gpu").expect("Operation failed"), "true");
    }

    /// Every builder setter must be reflected in the resulting array's config.
    #[test]
    fn test_gpu_array_builder() {
        let host = Array2::<f64>::ones((10, 5));

        let gpu_array = GPUArrayBuilder::new()
            .backend(GPUBackend::CUDA)
            .device_id(1)
            .async_ops(true)
            .mixed_precision(true)
            .memory_fraction(0.8)
            .build(host);

        let cfg = &gpu_array.config;
        assert_eq!(cfg.backend, GPUBackend::CUDA);
        assert_eq!(cfg.device_id, 1);
        assert!(cfg.async_ops);
        assert!(cfg.mixed_precision);
        assert_eq!(cfg.memory_fraction, 0.8);
    }

    /// Host-side add/multiply kernels must match plain ndarray arithmetic.
    #[test]
    fn test_gpu_array_kernels() {
        let lhs = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let rhs = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_lhs = GPUNdarray::new(lhs.clone(), GPUConfig::default());
        let gpu_rhs = GPUNdarray::new(rhs.clone(), GPUConfig::default());

        let sum = kernels::add(&gpu_lhs, &gpu_rhs).expect("Operation failed");
        assert_eq!(sum.host_data(), &(lhs + rhs));

        // Fresh operands: the previous ones were consumed by the `+` above.
        let lhs = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let rhs = arr2(&[[5.0, 6.0], [7.0, 8.0]]);

        let gpu_lhs = GPUNdarray::new(lhs.clone(), GPUConfig::default());
        let gpu_rhs = GPUNdarray::new(rhs.clone(), GPUConfig::default());

        let product = kernels::multiply(&gpu_lhs, &gpu_rhs).expect("Operation failed");
        assert_eq!(product.host_data(), &(lhs * rhs));
    }
}