1use std::any::{Any, TypeId};
33use std::collections::HashMap;
34use std::marker::PhantomData;
35use std::sync::{Arc, OnceLock};
36
37use ndarray::{Array1, Array2, IxDyn};
38
39use crate::array_protocol::{
40 ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
41};
42use crate::error::{CoreError, CoreResult, ErrorContext};
43use crate::gpu::backends::WebGPUContext;
44use crate::gpu::GpuError;
45
/// Private module holding the marker that seals [`GpuScalar`].
///
/// Because the module itself is not `pub`, code outside this crate cannot name
/// `sealed::Sealed` and therefore cannot implement [`GpuScalar`].
mod sealed {
    /// Supertrait used purely as a seal; it has no items.
    pub trait Sealed {}
}

/// Element types that may be stored in a [`GpuNdarray`].
///
/// Implementations are restricted to this crate via the private `sealed`
/// supertrait; `f32` is implemented immediately below.
pub trait GpuScalar: sealed::Sealed + Clone + Send + Sync + 'static {}

impl sealed::Sealed for f32 {}

impl GpuScalar for f32 {}
60
61const GPU_THRESHOLD: usize = 4096;
67
68static GPU_AVAILABLE: OnceLock<bool> = OnceLock::new();
70
71static GPU_CONTEXT: OnceLock<Option<Arc<WebGPUContext>>> = OnceLock::new();
73
74pub fn global_context() -> Option<Arc<WebGPUContext>> {
78 GPU_CONTEXT
79 .get_or_init(|| match WebGPUContext::new() {
80 Ok(ctx) => Some(Arc::new(ctx)),
81 Err(_) => None,
82 })
83 .clone()
84}
85
86pub fn is_gpu_available() -> bool {
88 *GPU_AVAILABLE.get_or_init(|| global_context().is_some())
89}
90
91pub struct GpuNdarray<T: GpuScalar> {
100 buffer: Arc<wgpu::Buffer>,
103
104 shape: Vec<usize>,
106
107 strides: Vec<usize>,
109
110 context: Arc<WebGPUContext>,
112
113 _phantom: PhantomData<T>,
114}
115
116impl<T: GpuScalar> std::fmt::Debug for GpuNdarray<T> {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 f.debug_struct("GpuNdarray")
119 .field("shape", &self.shape)
120 .field("strides", &self.strides)
121 .finish_non_exhaustive()
122 }
123}
124
125impl<T: GpuScalar> Clone for GpuNdarray<T> {
126 fn clone(&self) -> Self {
127 Self {
128 buffer: Arc::clone(&self.buffer),
129 shape: self.shape.clone(),
130 strides: self.strides.clone(),
131 context: Arc::clone(&self.context),
132 _phantom: PhantomData,
133 }
134 }
135}
136
137impl<T: GpuScalar> GpuNdarray<T> {
138 #[must_use]
140 pub fn buffer_arc(&self) -> &Arc<wgpu::Buffer> {
141 &self.buffer
142 }
143
144 #[must_use]
146 fn numel(&self) -> usize {
147 self.shape.iter().product()
148 }
149
150 fn compute_strides(shape: &[usize]) -> Vec<usize> {
152 let mut strides = vec![1usize; shape.len()];
153 for i in (0..shape.len().saturating_sub(1)).rev() {
154 strides[i] = strides[i + 1] * shape[i + 1];
155 }
156 strides
157 }
158}
159
160fn build_pipeline(
168 ctx: &WebGPUContext,
169 wgsl: &str,
170 bgl_entries: &[wgpu::BindGroupLayoutEntry],
171 label: &str,
172) -> Result<(wgpu::ComputePipeline, wgpu::BindGroupLayout), GpuError> {
173 let device = ctx.device();
174 let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
175 label: Some(label),
176 source: wgpu::ShaderSource::Wgsl(wgsl.into()),
177 });
178 let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
179 label: Some(&format!("{label}_bgl")),
180 entries: bgl_entries,
181 });
182 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
183 label: Some(&format!("{label}_layout")),
184 bind_group_layouts: &[Some(&bgl)],
185 ..Default::default()
186 });
187 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
188 label: Some(&format!("{label}_pipeline")),
189 layout: Some(&pipeline_layout),
190 module: &shader_module,
191 entry_point: Some("main"),
192 compilation_options: Default::default(),
193 cache: None,
194 });
195 Ok((pipeline, bgl))
196}
197
198fn storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
200 wgpu::BindGroupLayoutEntry {
201 binding,
202 visibility: wgpu::ShaderStages::COMPUTE,
203 ty: wgpu::BindingType::Buffer {
204 ty: wgpu::BufferBindingType::Storage { read_only: true },
205 has_dynamic_offset: false,
206 min_binding_size: None,
207 },
208 count: None,
209 }
210}
211
212fn storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
214 wgpu::BindGroupLayoutEntry {
215 binding,
216 visibility: wgpu::ShaderStages::COMPUTE,
217 ty: wgpu::BindingType::Buffer {
218 ty: wgpu::BufferBindingType::Storage { read_only: false },
219 has_dynamic_offset: false,
220 min_binding_size: None,
221 },
222 count: None,
223 }
224}
225
226fn uniform_buf(binding: u32) -> wgpu::BindGroupLayoutEntry {
228 wgpu::BindGroupLayoutEntry {
229 binding,
230 visibility: wgpu::ShaderStages::COMPUTE,
231 ty: wgpu::BindingType::Buffer {
232 ty: wgpu::BufferBindingType::Uniform,
233 has_dynamic_offset: false,
234 min_binding_size: None,
235 },
236 count: None,
237 }
238}
239
240impl GpuNdarray<f32> {
245 pub fn from_ndarray_data(
249 data: &[f32],
250 shape: Vec<usize>,
251 context: Arc<WebGPUContext>,
252 ) -> Result<Self, GpuError> {
253 use wgpu::util::DeviceExt as _;
254 let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
255 let buffer = context
256 .device()
257 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
258 label: Some("GpuNdarray<f32>"),
259 contents: &bytes,
260 usage: wgpu::BufferUsages::STORAGE
261 | wgpu::BufferUsages::COPY_SRC
262 | wgpu::BufferUsages::COPY_DST,
263 });
264 let strides = Self::compute_strides(&shape);
265 Ok(Self {
266 buffer: Arc::new(buffer),
267 shape,
268 strides,
269 context,
270 _phantom: PhantomData,
271 })
272 }
273
274 pub fn to_vec(&self) -> Result<Vec<f32>, GpuError> {
276 let byte_size = (self.numel() * std::mem::size_of::<f32>()) as u64;
277 let staging = self
278 .context
279 .device()
280 .create_buffer(&wgpu::BufferDescriptor {
281 label: Some("GpuNdarray-readback"),
282 size: byte_size,
283 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
284 mapped_at_creation: false,
285 });
286
287 let mut encoder =
288 self.context
289 .device()
290 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
291 label: Some("GpuNdarray-readback-encoder"),
292 });
293 encoder.copy_buffer_to_buffer(&self.buffer, 0, &staging, 0, byte_size);
294 self.context.queue().submit(Some(encoder.finish()));
295
296 self.context
297 .device()
298 .poll(wgpu::PollType::wait_indefinitely())
299 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
300
301 let slice = staging.slice(0..byte_size);
302 let (tx, rx) = std::sync::mpsc::channel();
303 slice.map_async(wgpu::MapMode::Read, move |r| {
304 let _ = tx.send(r);
305 });
306
307 self.context
308 .device()
309 .poll(wgpu::PollType::wait_indefinitely())
310 .map_err(|e| GpuError::Other(format!("poll-map error: {e:?}")))?;
311
312 rx.recv()
313 .map_err(|_| GpuError::Other("channel closed".into()))?
314 .map_err(|e| GpuError::Other(format!("map_async failed: {e:?}")))?;
315
316 let mapped = slice.get_mapped_range();
317 let result: Vec<f32> = mapped
318 .chunks_exact(4)
319 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
320 .collect();
321 drop(mapped);
322 staging.unmap();
323 Ok(result)
324 }
325
326 pub fn to_ndarray(&self) -> Result<ndarray::ArrayD<f32>, GpuError> {
328 let flat = self.to_vec()?;
329 ndarray::ArrayD::<f32>::from_shape_vec(self.shape.clone(), flat)
330 .map_err(|e| GpuError::Other(format!("shape_vec error: {e}")))
331 }
332
333 pub fn from_data(data: &[f32], shape: Vec<usize>) -> Result<Self, GpuError> {
339 let ctx =
340 global_context().ok_or_else(|| GpuError::Other("No wgpu adapter available".into()))?;
341 Self::from_ndarray_data(data, shape, ctx)
342 }
343
344 pub fn add(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
348 self.dispatch_elementwise_binary(other, 0)
349 }
350
351 pub fn subtract(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
353 self.dispatch_elementwise_binary(other, 1)
354 }
355
356 pub fn multiply(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
358 self.dispatch_elementwise_binary(other, 2)
359 }
360
361 pub fn multiply_by_scalar_f32(&self, scalar: f32) -> Result<GpuNdarray<f32>, GpuError> {
363 self.dispatch_scalar_multiply(scalar)
364 }
365
366 pub fn sum_all(&self) -> Result<f32, GpuError> {
368 self.dispatch_sum_all()
369 }
370
371 pub fn dot_gpu(&self, other: &GpuNdarray<f32>) -> Result<f32, GpuError> {
373 let prod = self.dispatch_elementwise_binary(other, 2)?;
374 prod.dispatch_sum_all()
375 }
376
377 pub fn matmul(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
382 self.dispatch_matmul(other)
383 }
384
385 fn dispatch_elementwise_binary(
391 &self,
392 other: &GpuNdarray<f32>,
393 op_id: u32,
394 ) -> Result<GpuNdarray<f32>, GpuError> {
395 let n = self.numel();
396 if n != other.numel() {
397 return Err(GpuError::InvalidParameter(
398 "shape mismatch in elementwise binary".into(),
399 ));
400 }
401 let wgsl = match op_id {
402 0 => ELEMENTWISE_ADD_WGSL,
403 1 => ELEMENTWISE_SUB_WGSL,
404 _ => ELEMENTWISE_MUL_WGSL,
405 };
406 let byte_size = (n * 4) as u64;
407 let result_buf = self
408 .context
409 .device()
410 .create_buffer(&wgpu::BufferDescriptor {
411 label: Some("elementwise-result"),
412 size: byte_size,
413 usage: wgpu::BufferUsages::STORAGE
414 | wgpu::BufferUsages::COPY_SRC
415 | wgpu::BufferUsages::COPY_DST,
416 mapped_at_creation: false,
417 });
418
419 let bgl_entries = [storage_ro(0), storage_ro(1), storage_rw(2)];
420 let (pipeline, bgl) = build_pipeline(&self.context, wgsl, &bgl_entries, "elementwise")?;
421
422 let bind_group = self
423 .context
424 .device()
425 .create_bind_group(&wgpu::BindGroupDescriptor {
426 label: Some("elementwise-bg"),
427 layout: &bgl,
428 entries: &[
429 wgpu::BindGroupEntry {
430 binding: 0,
431 resource: self.buffer.as_entire_binding(),
432 },
433 wgpu::BindGroupEntry {
434 binding: 1,
435 resource: other.buffer.as_entire_binding(),
436 },
437 wgpu::BindGroupEntry {
438 binding: 2,
439 resource: result_buf.as_entire_binding(),
440 },
441 ],
442 });
443
444 let mut encoder =
445 self.context
446 .device()
447 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
448 label: Some("elementwise-encoder"),
449 });
450 {
451 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
452 label: Some("elementwise-pass"),
453 timestamp_writes: None,
454 });
455 cpass.set_pipeline(&pipeline);
456 cpass.set_bind_group(0, &bind_group, &[]);
457 let workgroups = (n as u32 + 255) / 256;
458 cpass.dispatch_workgroups(workgroups, 1, 1);
459 }
460 self.context.queue().submit(Some(encoder.finish()));
461 self.context
462 .device()
463 .poll(wgpu::PollType::wait_indefinitely())
464 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
465
466 Ok(GpuNdarray {
467 buffer: Arc::new(result_buf),
468 shape: self.shape.clone(),
469 strides: self.strides.clone(),
470 context: Arc::clone(&self.context),
471 _phantom: PhantomData,
472 })
473 }
474
475 fn dispatch_scalar_multiply(&self, scalar: f32) -> Result<GpuNdarray<f32>, GpuError> {
477 let n = self.numel();
478 let byte_size = (n * 4) as u64;
479 let result_buf = self
480 .context
481 .device()
482 .create_buffer(&wgpu::BufferDescriptor {
483 label: Some("scalar-mul-result"),
484 size: byte_size,
485 usage: wgpu::BufferUsages::STORAGE
486 | wgpu::BufferUsages::COPY_SRC
487 | wgpu::BufferUsages::COPY_DST,
488 mapped_at_creation: false,
489 });
490
491 let mut unif: Vec<u8> = Vec::with_capacity(16);
493 unif.extend_from_slice(&scalar.to_le_bytes());
494 unif.extend_from_slice(&(n as u32).to_le_bytes());
495 while unif.len() % 16 != 0 {
496 unif.push(0);
497 }
498 use wgpu::util::DeviceExt as _;
499 let uniform_buffer =
500 self.context
501 .device()
502 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
503 label: Some("scalar-mul-uniform"),
504 contents: &unif,
505 usage: wgpu::BufferUsages::UNIFORM,
506 });
507
508 let bgl_entries = [storage_ro(0), storage_rw(1), uniform_buf(2)];
509 let (pipeline, bgl) =
510 build_pipeline(&self.context, SCALAR_MUL_WGSL, &bgl_entries, "scalar-mul")?;
511
512 let bind_group = self
513 .context
514 .device()
515 .create_bind_group(&wgpu::BindGroupDescriptor {
516 label: Some("scalar-mul-bg"),
517 layout: &bgl,
518 entries: &[
519 wgpu::BindGroupEntry {
520 binding: 0,
521 resource: self.buffer.as_entire_binding(),
522 },
523 wgpu::BindGroupEntry {
524 binding: 1,
525 resource: result_buf.as_entire_binding(),
526 },
527 wgpu::BindGroupEntry {
528 binding: 2,
529 resource: uniform_buffer.as_entire_binding(),
530 },
531 ],
532 });
533
534 let mut encoder =
535 self.context
536 .device()
537 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
538 label: Some("scalar-mul-encoder"),
539 });
540 {
541 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
542 label: Some("scalar-mul-pass"),
543 timestamp_writes: None,
544 });
545 cpass.set_pipeline(&pipeline);
546 cpass.set_bind_group(0, &bind_group, &[]);
547 let workgroups = (n as u32 + 255) / 256;
548 cpass.dispatch_workgroups(workgroups, 1, 1);
549 }
550 self.context.queue().submit(Some(encoder.finish()));
551 self.context
552 .device()
553 .poll(wgpu::PollType::wait_indefinitely())
554 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
555
556 Ok(GpuNdarray {
557 buffer: Arc::new(result_buf),
558 shape: self.shape.clone(),
559 strides: self.strides.clone(),
560 context: Arc::clone(&self.context),
561 _phantom: PhantomData,
562 })
563 }
564
565 fn dispatch_matmul(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
567 if self.shape.len() != 2 || other.shape.len() != 2 {
568 return Err(GpuError::InvalidParameter(
569 "matmul requires 2-D arrays".into(),
570 ));
571 }
572 let (m, k) = (self.shape[0], self.shape[1]);
573 let (k2, n) = (other.shape[0], other.shape[1]);
574 if k != k2 {
575 return Err(GpuError::InvalidParameter(format!(
576 "matmul shape mismatch: [{m},{k}] x [{k2},{n}]"
577 )));
578 }
579
580 let byte_size = (m * n * 4) as u64;
581 let result_buf = self
582 .context
583 .device()
584 .create_buffer(&wgpu::BufferDescriptor {
585 label: Some("matmul-result"),
586 size: byte_size,
587 usage: wgpu::BufferUsages::STORAGE
588 | wgpu::BufferUsages::COPY_SRC
589 | wgpu::BufferUsages::COPY_DST,
590 mapped_at_creation: false,
591 });
592
593 let uniform_data: [u32; 3] = [m as u32, n as u32, k as u32];
594 let uniform_bytes: Vec<u8> = uniform_data.iter().flat_map(|v| v.to_le_bytes()).collect();
595 let mut uniform_padded = uniform_bytes;
597 while uniform_padded.len() % 16 != 0 {
598 uniform_padded.push(0);
599 }
600 use wgpu::util::DeviceExt as _;
601 let uniform_buffer =
602 self.context
603 .device()
604 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
605 label: Some("matmul-uniform"),
606 contents: &uniform_padded,
607 usage: wgpu::BufferUsages::UNIFORM,
608 });
609
610 let bgl_entries = [storage_ro(0), storage_ro(1), storage_rw(2), uniform_buf(3)];
611 let (pipeline, bgl) = build_pipeline(&self.context, MATMUL_WGSL, &bgl_entries, "matmul")?;
612 let bind_group = self
613 .context
614 .device()
615 .create_bind_group(&wgpu::BindGroupDescriptor {
616 label: Some("matmul-bg"),
617 layout: &bgl,
618 entries: &[
619 wgpu::BindGroupEntry {
620 binding: 0,
621 resource: self.buffer.as_entire_binding(),
622 },
623 wgpu::BindGroupEntry {
624 binding: 1,
625 resource: other.buffer.as_entire_binding(),
626 },
627 wgpu::BindGroupEntry {
628 binding: 2,
629 resource: result_buf.as_entire_binding(),
630 },
631 wgpu::BindGroupEntry {
632 binding: 3,
633 resource: uniform_buffer.as_entire_binding(),
634 },
635 ],
636 });
637
638 let wg_x = (n as u32 + 15) / 16;
639 let wg_y = (m as u32 + 15) / 16;
640 let mut encoder =
641 self.context
642 .device()
643 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
644 label: Some("matmul-encoder"),
645 });
646 {
647 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
648 label: Some("matmul-pass"),
649 timestamp_writes: None,
650 });
651 cpass.set_pipeline(&pipeline);
652 cpass.set_bind_group(0, &bind_group, &[]);
653 cpass.dispatch_workgroups(wg_x, wg_y, 1);
654 }
655 self.context.queue().submit(Some(encoder.finish()));
656 self.context
657 .device()
658 .poll(wgpu::PollType::wait_indefinitely())
659 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
660
661 Ok(GpuNdarray {
662 buffer: Arc::new(result_buf),
663 shape: vec![m, n],
664 strides: Self::compute_strides(&[m, n]),
665 context: Arc::clone(&self.context),
666 _phantom: PhantomData,
667 })
668 }
669
670 fn dispatch_sum_all(&self) -> Result<f32, GpuError> {
672 let n = self.numel();
673 let workgroup_count = (n as u32 + 255) / 256;
675 let partial_byte_size = (workgroup_count as usize * 4) as u64;
676 let partial_buf = self
677 .context
678 .device()
679 .create_buffer(&wgpu::BufferDescriptor {
680 label: Some("sum-partial"),
681 size: partial_byte_size,
682 usage: wgpu::BufferUsages::STORAGE
683 | wgpu::BufferUsages::COPY_SRC
684 | wgpu::BufferUsages::COPY_DST,
685 mapped_at_creation: false,
686 });
687
688 let n_bytes = (n as u32).to_le_bytes();
689 let mut uniform_bytes = n_bytes.to_vec();
690 while uniform_bytes.len() % 16 != 0 {
691 uniform_bytes.push(0);
692 }
693 use wgpu::util::DeviceExt as _;
694 let uniform_buffer =
695 self.context
696 .device()
697 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
698 label: Some("sum-uniform"),
699 contents: &uniform_bytes,
700 usage: wgpu::BufferUsages::UNIFORM,
701 });
702
703 let bgl_entries = [storage_ro(0), storage_rw(1), uniform_buf(2)];
704 let (pipeline, bgl) =
705 build_pipeline(&self.context, SUM_REDUCE_WGSL, &bgl_entries, "sum-reduce")?;
706 let bind_group = self
707 .context
708 .device()
709 .create_bind_group(&wgpu::BindGroupDescriptor {
710 label: Some("sum-bg"),
711 layout: &bgl,
712 entries: &[
713 wgpu::BindGroupEntry {
714 binding: 0,
715 resource: self.buffer.as_entire_binding(),
716 },
717 wgpu::BindGroupEntry {
718 binding: 1,
719 resource: partial_buf.as_entire_binding(),
720 },
721 wgpu::BindGroupEntry {
722 binding: 2,
723 resource: uniform_buffer.as_entire_binding(),
724 },
725 ],
726 });
727
728 let mut encoder =
729 self.context
730 .device()
731 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
732 label: Some("sum-encoder"),
733 });
734 {
735 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
736 label: Some("sum-pass"),
737 timestamp_writes: None,
738 });
739 cpass.set_pipeline(&pipeline);
740 cpass.set_bind_group(0, &bind_group, &[]);
741 cpass.dispatch_workgroups(workgroup_count, 1, 1);
742 }
743 self.context.queue().submit(Some(encoder.finish()));
744 self.context
745 .device()
746 .poll(wgpu::PollType::wait_indefinitely())
747 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
748
749 let staging = self
751 .context
752 .device()
753 .create_buffer(&wgpu::BufferDescriptor {
754 label: Some("sum-staging"),
755 size: partial_byte_size,
756 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
757 mapped_at_creation: false,
758 });
759 let mut encoder2 =
760 self.context
761 .device()
762 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
763 label: Some("sum-copy-encoder"),
764 });
765 encoder2.copy_buffer_to_buffer(&partial_buf, 0, &staging, 0, partial_byte_size);
766 self.context.queue().submit(Some(encoder2.finish()));
767
768 self.context
769 .device()
770 .poll(wgpu::PollType::wait_indefinitely())
771 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
772
773 let slice = staging.slice(0..partial_byte_size);
774 let (tx, rx) = std::sync::mpsc::channel();
775 slice.map_async(wgpu::MapMode::Read, move |r| {
776 let _ = tx.send(r);
777 });
778 self.context
779 .device()
780 .poll(wgpu::PollType::wait_indefinitely())
781 .map_err(|e| GpuError::Other(format!("map poll error: {e:?}")))?;
782 rx.recv()
783 .map_err(|_| GpuError::Other("channel closed".into()))?
784 .map_err(|e| GpuError::Other(format!("map_async: {e:?}")))?;
785
786 let mapped = slice.get_mapped_range();
787 let total: f32 = mapped
788 .chunks_exact(4)
789 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
790 .sum();
791 drop(mapped);
792 staging.unmap();
793 Ok(total)
794 }
795
796 fn dispatch_transpose_2d(&self) -> Result<GpuNdarray<f32>, GpuError> {
798 if self.shape.len() != 2 {
799 return Err(GpuError::InvalidParameter(
800 "transpose_2d requires a 2-D array".into(),
801 ));
802 }
803 let (rows, cols) = (self.shape[0], self.shape[1]);
804 let byte_size = (rows * cols * 4) as u64;
805
806 let result_buf = self
807 .context
808 .device()
809 .create_buffer(&wgpu::BufferDescriptor {
810 label: Some("transpose-result"),
811 size: byte_size,
812 usage: wgpu::BufferUsages::STORAGE
813 | wgpu::BufferUsages::COPY_SRC
814 | wgpu::BufferUsages::COPY_DST,
815 mapped_at_creation: false,
816 });
817
818 let uniform_data: [u32; 2] = [rows as u32, cols as u32];
820 let uniform_bytes: Vec<u8> = uniform_data.iter().flat_map(|v| v.to_le_bytes()).collect();
821 let mut uniform_padded = uniform_bytes;
822 while uniform_padded.len() % 16 != 0 {
823 uniform_padded.push(0);
824 }
825 use wgpu::util::DeviceExt as _;
826 let uniform_buffer =
827 self.context
828 .device()
829 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
830 label: Some("transpose-uniform"),
831 contents: &uniform_padded,
832 usage: wgpu::BufferUsages::UNIFORM,
833 });
834
835 let bgl_entries = [storage_ro(0), storage_rw(1), uniform_buf(2)];
837 let (pipeline, bgl) =
838 build_pipeline(&self.context, TRANSPOSE_WGSL, &bgl_entries, "transpose")?;
839 let bind_group = self
840 .context
841 .device()
842 .create_bind_group(&wgpu::BindGroupDescriptor {
843 label: Some("transpose-bg"),
844 layout: &bgl,
845 entries: &[
846 wgpu::BindGroupEntry {
847 binding: 0,
848 resource: self.buffer.as_entire_binding(),
849 },
850 wgpu::BindGroupEntry {
851 binding: 1,
852 resource: result_buf.as_entire_binding(),
853 },
854 wgpu::BindGroupEntry {
855 binding: 2,
856 resource: uniform_buffer.as_entire_binding(),
857 },
858 ],
859 });
860
861 let wg_x = (cols as u32 + 15) / 16;
863 let wg_y = (rows as u32 + 15) / 16;
864 let mut encoder =
865 self.context
866 .device()
867 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
868 label: Some("transpose-encoder"),
869 });
870 {
871 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
872 label: Some("transpose-pass"),
873 timestamp_writes: None,
874 });
875 cpass.set_pipeline(&pipeline);
876 cpass.set_bind_group(0, &bind_group, &[]);
877 cpass.dispatch_workgroups(wg_x, wg_y, 1);
878 }
879 self.context.queue().submit(Some(encoder.finish()));
880 self.context
881 .device()
882 .poll(wgpu::PollType::wait_indefinitely())
883 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
884
885 Ok(GpuNdarray {
886 buffer: Arc::new(result_buf),
887 shape: vec![cols, rows],
888 strides: Self::compute_strides(&[cols, rows]),
889 context: Arc::clone(&self.context),
890 _phantom: PhantomData,
891 })
892 }
893
894 fn dispatch_concatenate_axis0(
896 arrays: &[&GpuNdarray<f32>],
897 ) -> Result<GpuNdarray<f32>, GpuError> {
898 if arrays.is_empty() {
899 return Err(GpuError::InvalidParameter("empty array list".into()));
900 }
901 let trailing = &arrays[0].shape[1..];
903 for arr in arrays.iter().skip(1) {
904 if arr.shape[1..] != *trailing {
905 return Err(GpuError::InvalidParameter(
906 "concatenate axis=0: trailing dimensions must match".into(),
907 ));
908 }
909 }
910 let ctx = Arc::clone(&arrays[0].context);
911 let trailing_elems: usize = trailing.iter().product::<usize>().max(1);
912
913 let total_rows: usize = arrays.iter().map(|a| a.shape[0]).sum();
914 let total_elems = total_rows * trailing_elems;
915 let total_bytes = (total_elems * 4) as u64;
916
917 let result_buf = ctx.device().create_buffer(&wgpu::BufferDescriptor {
918 label: Some("concat-result"),
919 size: total_bytes,
920 usage: wgpu::BufferUsages::STORAGE
921 | wgpu::BufferUsages::COPY_SRC
922 | wgpu::BufferUsages::COPY_DST,
923 mapped_at_creation: false,
924 });
925
926 let mut encoder = ctx
927 .device()
928 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
929 label: Some("concat-encoder"),
930 });
931 let mut offset: u64 = 0;
932 for arr in arrays {
933 let arr_bytes = (arr.numel() * 4) as u64;
934 encoder.copy_buffer_to_buffer(&arr.buffer, 0, &result_buf, offset, arr_bytes);
935 offset += arr_bytes;
936 }
937 ctx.queue().submit(Some(encoder.finish()));
938 ctx.device()
939 .poll(wgpu::PollType::wait_indefinitely())
940 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
941
942 let new_shape = {
943 let mut s = vec![total_rows];
944 s.extend_from_slice(trailing);
945 s
946 };
947 let new_strides = Self::compute_strides(&new_shape);
948 Ok(GpuNdarray {
949 buffer: Arc::new(result_buf),
950 shape: new_shape,
951 strides: new_strides,
952 context: ctx,
953 _phantom: PhantomData,
954 })
955 }
956
957 fn dispatch_concatenate_axisn(
962 a: &GpuNdarray<f32>,
963 b: &GpuNdarray<f32>,
964 axis: usize,
965 ) -> Result<GpuNdarray<f32>, GpuError> {
966 let ndim = a.shape.len();
967 if ndim > 8 {
968 return Err(GpuError::InvalidParameter(
969 "concat_axisn: ndim must be <= 8".into(),
970 ));
971 }
972
973 let mut out_shape = a.shape.clone();
975 out_shape[axis] += b.shape[axis];
976
977 let out_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
978 let a_strides = GpuNdarray::<f32>::compute_strides(&a.shape);
979 let b_strides = GpuNdarray::<f32>::compute_strides(&b.shape);
980
981 let total_out = out_shape.iter().product::<usize>();
982 let byte_out = (total_out * 4) as u64;
983
984 let ctx = Arc::clone(&a.context);
985 let result_buf = ctx.device().create_buffer(&wgpu::BufferDescriptor {
986 label: Some("concat-axisn-result"),
987 size: byte_out,
988 usage: wgpu::BufferUsages::STORAGE
989 | wgpu::BufferUsages::COPY_SRC
990 | wgpu::BufferUsages::COPY_DST,
991 mapped_at_creation: false,
992 });
993
994 let dim_a = a.shape[axis] as u32;
996 let mut unif_bytes: Vec<u8> = Vec::with_capacity(16);
997 unif_bytes.extend_from_slice(&(axis as u32).to_le_bytes());
998 unif_bytes.extend_from_slice(&dim_a.to_le_bytes());
999 unif_bytes.extend_from_slice(&(ndim as u32).to_le_bytes());
1000 unif_bytes.extend_from_slice(&0u32.to_le_bytes()); debug_assert_eq!(unif_bytes.len(), 16);
1002
1003 let pack_u32_slice = |vals: &[usize]| -> Vec<u8> {
1006 let mut out: Vec<u8> = Vec::with_capacity(vals.len() * 4);
1007 for &v in vals {
1008 out.extend_from_slice(&(v as u32).to_le_bytes());
1009 }
1010 while out.len() % 16 != 0 {
1012 out.extend_from_slice(&0u32.to_le_bytes());
1013 }
1014 out
1015 };
1016
1017 let out_shape_bytes = pack_u32_slice(&out_shape);
1018 let out_strides_bytes = pack_u32_slice(&out_strides);
1019 let a_strides_bytes = pack_u32_slice(&a_strides);
1020 let b_strides_bytes = pack_u32_slice(&b_strides);
1021
1022 use wgpu::util::DeviceExt as _;
1023 let make_storage_buf = |bytes: &[u8], label: &str| -> wgpu::Buffer {
1024 ctx.device()
1025 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1026 label: Some(label),
1027 contents: bytes,
1028 usage: wgpu::BufferUsages::STORAGE,
1029 })
1030 };
1031 let unif_buf = ctx
1032 .device()
1033 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1034 label: Some("concat-axisn-uniform"),
1035 contents: &unif_bytes,
1036 usage: wgpu::BufferUsages::UNIFORM,
1037 });
1038 let out_shape_buf = make_storage_buf(&out_shape_bytes, "concat-axisn-out-shape");
1039 let out_strides_buf = make_storage_buf(&out_strides_bytes, "concat-axisn-out-strides");
1040 let a_strides_buf = make_storage_buf(&a_strides_bytes, "concat-axisn-a-strides");
1041 let b_strides_buf = make_storage_buf(&b_strides_bytes, "concat-axisn-b-strides");
1042
1043 let bgl_entries = [
1048 storage_ro(0),
1049 storage_ro(1),
1050 storage_rw(2),
1051 uniform_buf(3),
1052 storage_ro(4),
1053 storage_ro(5),
1054 storage_ro(6),
1055 storage_ro(7),
1056 ];
1057 let (pipeline, bgl) =
1058 build_pipeline(&ctx, CONCAT_AXISN_WGSL, &bgl_entries, "concat-axisn")?;
1059
1060 let bind_group = ctx.device().create_bind_group(&wgpu::BindGroupDescriptor {
1061 label: Some("concat-axisn-bg"),
1062 layout: &bgl,
1063 entries: &[
1064 wgpu::BindGroupEntry {
1065 binding: 0,
1066 resource: a.buffer.as_entire_binding(),
1067 },
1068 wgpu::BindGroupEntry {
1069 binding: 1,
1070 resource: b.buffer.as_entire_binding(),
1071 },
1072 wgpu::BindGroupEntry {
1073 binding: 2,
1074 resource: result_buf.as_entire_binding(),
1075 },
1076 wgpu::BindGroupEntry {
1077 binding: 3,
1078 resource: unif_buf.as_entire_binding(),
1079 },
1080 wgpu::BindGroupEntry {
1081 binding: 4,
1082 resource: out_shape_buf.as_entire_binding(),
1083 },
1084 wgpu::BindGroupEntry {
1085 binding: 5,
1086 resource: out_strides_buf.as_entire_binding(),
1087 },
1088 wgpu::BindGroupEntry {
1089 binding: 6,
1090 resource: a_strides_buf.as_entire_binding(),
1091 },
1092 wgpu::BindGroupEntry {
1093 binding: 7,
1094 resource: b_strides_buf.as_entire_binding(),
1095 },
1096 ],
1097 });
1098
1099 let workgroups = (total_out as u32 + 255) / 256;
1100 let mut encoder = ctx
1101 .device()
1102 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1103 label: Some("concat-axisn-encoder"),
1104 });
1105 {
1106 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1107 label: Some("concat-axisn-pass"),
1108 timestamp_writes: None,
1109 });
1110 cpass.set_pipeline(&pipeline);
1111 cpass.set_bind_group(0, &bind_group, &[]);
1112 cpass.dispatch_workgroups(workgroups, 1, 1);
1113 }
1114 ctx.queue().submit(Some(encoder.finish()));
1115 ctx.device()
1116 .poll(wgpu::PollType::wait_indefinitely())
1117 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
1118
1119 let new_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
1120 Ok(GpuNdarray {
1121 buffer: Arc::new(result_buf),
1122 shape: out_shape,
1123 strides: new_strides,
1124 context: ctx,
1125 _phantom: PhantomData,
1126 })
1127 }
1128
1129 fn dispatch_sum_axis(&self, axis: usize) -> Result<GpuNdarray<f32>, GpuError> {
1134 let ndim = self.shape.len();
1135 if ndim > 8 {
1136 return Err(GpuError::InvalidParameter(
1137 "sum_axis: ndim must be <= 8".into(),
1138 ));
1139 }
1140
1141 let axis_size = self.shape[axis];
1142
1143 let out_shape: Vec<usize> = self
1145 .shape
1146 .iter()
1147 .enumerate()
1148 .filter(|&(i, _)| i != axis)
1149 .map(|(_, &d)| d)
1150 .collect();
1151 let out_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
1152 let in_strides = &self.strides;
1153
1154 let total_out = out_shape.iter().product::<usize>().max(1);
1155 let byte_out = (total_out * 4) as u64;
1156
1157 let result_buf = self
1158 .context
1159 .device()
1160 .create_buffer(&wgpu::BufferDescriptor {
1161 label: Some("sum-axis-result"),
1162 size: byte_out,
1163 usage: wgpu::BufferUsages::STORAGE
1164 | wgpu::BufferUsages::COPY_SRC
1165 | wgpu::BufferUsages::COPY_DST,
1166 mapped_at_creation: false,
1167 });
1168
1169 let in_axis_stride = self.strides[axis] as u32;
1171 let mut unif_bytes: Vec<u8> = Vec::with_capacity(16);
1172 unif_bytes.extend_from_slice(&(axis as u32).to_le_bytes());
1173 unif_bytes.extend_from_slice(&(axis_size as u32).to_le_bytes());
1174 unif_bytes.extend_from_slice(&(ndim as u32).to_le_bytes());
1175 unif_bytes.extend_from_slice(&in_axis_stride.to_le_bytes());
1176 debug_assert_eq!(unif_bytes.len(), 16);
1177
1178 let pack_u32_slice = |vals: &[usize]| -> Vec<u8> {
1179 let mut out: Vec<u8> = Vec::with_capacity(vals.len() * 4);
1180 for &v in vals {
1181 out.extend_from_slice(&(v as u32).to_le_bytes());
1182 }
1183 while out.len() % 16 != 0 {
1184 out.extend_from_slice(&0u32.to_le_bytes());
1185 }
1186 out
1187 };
1188
1189 let in_shape_bytes = pack_u32_slice(&self.shape);
1190 let in_strides_bytes = pack_u32_slice(in_strides);
1191 let out_shape_bytes = pack_u32_slice(&out_shape);
1192 let out_strides_bytes = pack_u32_slice(&out_strides);
1193
1194 use wgpu::util::DeviceExt as _;
1195 let unif_buf =
1196 self.context
1197 .device()
1198 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1199 label: Some("sum-axis-uniform"),
1200 contents: &unif_bytes,
1201 usage: wgpu::BufferUsages::UNIFORM,
1202 });
1203 let make_storage_buf = |bytes: &[u8], label: &str| -> wgpu::Buffer {
1204 self.context
1205 .device()
1206 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1207 label: Some(label),
1208 contents: bytes,
1209 usage: wgpu::BufferUsages::STORAGE,
1210 })
1211 };
1212 let in_shape_buf = make_storage_buf(&in_shape_bytes, "sum-axis-in-shape");
1213 let in_strides_buf = make_storage_buf(&in_strides_bytes, "sum-axis-in-strides");
1214 let out_shape_buf = make_storage_buf(&out_shape_bytes, "sum-axis-out-shape");
1215 let out_strides_buf = make_storage_buf(&out_strides_bytes, "sum-axis-out-strides");
1216
1217 let bgl_entries = [
1222 storage_ro(0),
1223 storage_rw(1),
1224 uniform_buf(2),
1225 storage_ro(3),
1226 storage_ro(4),
1227 storage_ro(5),
1228 storage_ro(6),
1229 ];
1230 let (pipeline, bgl) = build_pipeline(
1231 &self.context,
1232 REDUCE_SUM_AXIS_WGSL,
1233 &bgl_entries,
1234 "sum-axis",
1235 )?;
1236
1237 let bind_group = self
1238 .context
1239 .device()
1240 .create_bind_group(&wgpu::BindGroupDescriptor {
1241 label: Some("sum-axis-bg"),
1242 layout: &bgl,
1243 entries: &[
1244 wgpu::BindGroupEntry {
1245 binding: 0,
1246 resource: self.buffer.as_entire_binding(),
1247 },
1248 wgpu::BindGroupEntry {
1249 binding: 1,
1250 resource: result_buf.as_entire_binding(),
1251 },
1252 wgpu::BindGroupEntry {
1253 binding: 2,
1254 resource: unif_buf.as_entire_binding(),
1255 },
1256 wgpu::BindGroupEntry {
1257 binding: 3,
1258 resource: in_shape_buf.as_entire_binding(),
1259 },
1260 wgpu::BindGroupEntry {
1261 binding: 4,
1262 resource: in_strides_buf.as_entire_binding(),
1263 },
1264 wgpu::BindGroupEntry {
1265 binding: 5,
1266 resource: out_shape_buf.as_entire_binding(),
1267 },
1268 wgpu::BindGroupEntry {
1269 binding: 6,
1270 resource: out_strides_buf.as_entire_binding(),
1271 },
1272 ],
1273 });
1274
1275 let workgroups = (total_out as u32 + 255) / 256;
1276 let mut encoder =
1277 self.context
1278 .device()
1279 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1280 label: Some("sum-axis-encoder"),
1281 });
1282 {
1283 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1284 label: Some("sum-axis-pass"),
1285 timestamp_writes: None,
1286 });
1287 cpass.set_pipeline(&pipeline);
1288 cpass.set_bind_group(0, &bind_group, &[]);
1289 cpass.dispatch_workgroups(workgroups, 1, 1);
1290 }
1291 self.context.queue().submit(Some(encoder.finish()));
1292 self.context
1293 .device()
1294 .poll(wgpu::PollType::wait_indefinitely())
1295 .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
1296
1297 let new_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
1298 Ok(GpuNdarray {
1299 buffer: Arc::new(result_buf),
1300 shape: out_shape,
1301 strides: new_strides,
1302 context: Arc::clone(&self.context),
1303 _phantom: PhantomData,
1304 })
1305 }
1306
1307 fn dispatch_reshape(&self, new_shape: Vec<usize>) -> Result<GpuNdarray<f32>, GpuError> {
1309 let new_numel: usize = new_shape.iter().product();
1310 if new_numel != self.numel() {
1311 return Err(GpuError::InvalidParameter(format!(
1312 "reshape: element count mismatch: {} vs {}",
1313 self.numel(),
1314 new_numel
1315 )));
1316 }
1317 let new_strides = Self::compute_strides(&new_shape);
1318 Ok(GpuNdarray {
1319 buffer: Arc::clone(&self.buffer),
1320 shape: new_shape,
1321 strides: new_strides,
1322 context: Arc::clone(&self.context),
1323 _phantom: PhantomData,
1324 })
1325 }
1326
1327 fn cpu_fallback_unary<F>(&self, f: F) -> Result<GpuNdarray<f32>, GpuError>
1329 where
1330 F: FnOnce(ndarray::ArrayD<f32>) -> Result<ndarray::ArrayD<f32>, GpuError>,
1331 {
1332 let arr = self.to_ndarray()?;
1333 let result = f(arr)?;
1334 let shape = result.shape().to_vec();
1335 let flat: Vec<f32> = result.into_iter().collect();
1336 Self::from_ndarray_data(&flat, shape, Arc::clone(&self.context))
1337 }
1338}
1339
1340impl ArrayProtocol for GpuNdarray<f32> {
1345 fn array_function(
1346 &self,
1347 func: &ArrayFunction,
1348 _types: &[TypeId],
1349 args: &[Box<dyn Any>],
1350 kwargs: &HashMap<String, Box<dyn Any>>,
1351 ) -> Result<Box<dyn Any>, NotImplemented> {
1352 macro_rules! gpu_arg {
1354 ($idx:expr) => {{
1355 let boxed_ap = args[$idx]
1356 .downcast_ref::<Box<dyn ArrayProtocol>>()
1357 .ok_or(NotImplemented)?;
1358 boxed_ap
1359 .as_any()
1360 .downcast_ref::<GpuNdarray<f32>>()
1361 .ok_or(NotImplemented)?
1362 }};
1363 }
1364
1365 let n = self.numel();
1367 let use_gpu = n >= GPU_THRESHOLD && is_gpu_available();
1368
1369 match func.name {
1370 "scirs2::array_protocol::operations::add" => {
1372 let a = gpu_arg!(0);
1373 let b = gpu_arg!(1);
1374 if use_gpu {
1375 let result = a
1377 .dispatch_elementwise_binary(b, 0)
1378 .map_err(|_| NotImplemented)?;
1379 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1380 } else {
1381 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1382 let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1383 let rc = ra + rb;
1384 let flat: Vec<f32> = rc.into_iter().collect();
1385 let result = GpuNdarray::<f32>::from_ndarray_data(
1386 &flat,
1387 a.shape.clone(),
1388 Arc::clone(&a.context),
1389 )
1390 .map_err(|_| NotImplemented)?;
1391 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1392 }
1393 }
1394
1395 "scirs2::array_protocol::operations::subtract" => {
1397 let a = gpu_arg!(0);
1398 let b = gpu_arg!(1);
1399 if use_gpu {
1400 let result = a
1402 .dispatch_elementwise_binary(b, 1)
1403 .map_err(|_| NotImplemented)?;
1404 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1405 } else {
1406 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1407 let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1408 let rc = ra - rb;
1409 let flat: Vec<f32> = rc.into_iter().collect();
1410 let result = GpuNdarray::<f32>::from_ndarray_data(
1411 &flat,
1412 a.shape.clone(),
1413 Arc::clone(&a.context),
1414 )
1415 .map_err(|_| NotImplemented)?;
1416 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1417 }
1418 }
1419
1420 "scirs2::array_protocol::operations::multiply" => {
1422 let a = gpu_arg!(0);
1423 let b = gpu_arg!(1);
1424 if use_gpu {
1425 let result = a
1427 .dispatch_elementwise_binary(b, 2)
1428 .map_err(|_| NotImplemented)?;
1429 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1430 } else {
1431 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1432 let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1433 let rc = ra * rb;
1434 let flat: Vec<f32> = rc.into_iter().collect();
1435 let result = GpuNdarray::<f32>::from_ndarray_data(
1436 &flat,
1437 a.shape.clone(),
1438 Arc::clone(&a.context),
1439 )
1440 .map_err(|_| NotImplemented)?;
1441 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1442 }
1443 }
1444
1445 "scirs2::array_protocol::operations::multiply_by_scalar_f32" => {
1447 let a = gpu_arg!(0);
1448 let scalar = kwargs
1450 .values()
1451 .find_map(|v| v.downcast_ref::<f32>().copied())
1452 .ok_or(NotImplemented)?;
1453 if use_gpu {
1454 let result = a
1455 .dispatch_scalar_multiply(scalar)
1456 .map_err(|_| NotImplemented)?;
1457 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1458 } else {
1459 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1460 let rc = ra * scalar;
1461 let flat: Vec<f32> = rc.into_iter().collect();
1462 let result = GpuNdarray::<f32>::from_ndarray_data(
1463 &flat,
1464 a.shape.clone(),
1465 Arc::clone(&a.context),
1466 )
1467 .map_err(|_| NotImplemented)?;
1468 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1469 }
1470 }
1471
1472 "scirs2::array_protocol::operations::multiply_by_scalar_f64" => {
1474 let a = gpu_arg!(0);
1475 let scalar = kwargs
1476 .values()
1477 .find_map(|v| v.downcast_ref::<f64>().copied())
1478 .ok_or(NotImplemented)? as f32;
1479 if use_gpu {
1480 let result = a
1481 .dispatch_scalar_multiply(scalar)
1482 .map_err(|_| NotImplemented)?;
1483 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1484 } else {
1485 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1486 let rc = ra * scalar;
1487 let flat: Vec<f32> = rc.into_iter().collect();
1488 let result = GpuNdarray::<f32>::from_ndarray_data(
1489 &flat,
1490 a.shape.clone(),
1491 Arc::clone(&a.context),
1492 )
1493 .map_err(|_| NotImplemented)?;
1494 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1495 }
1496 }
1497
1498 "scirs2::array_protocol::operations::divide_by_scalar_f64" => {
1500 let a = gpu_arg!(0);
1501 let scalar = kwargs
1502 .values()
1503 .find_map(|v| v.downcast_ref::<f64>().copied())
1504 .ok_or(NotImplemented)?;
1505 if scalar == 0.0 {
1506 return Err(NotImplemented);
1507 }
1508 let inv = (1.0 / scalar) as f32;
1509 if use_gpu {
1510 let result = a
1511 .dispatch_scalar_multiply(inv)
1512 .map_err(|_| NotImplemented)?;
1513 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1514 } else {
1515 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1516 let rc = ra * inv;
1517 let flat: Vec<f32> = rc.into_iter().collect();
1518 let result = GpuNdarray::<f32>::from_ndarray_data(
1519 &flat,
1520 a.shape.clone(),
1521 Arc::clone(&a.context),
1522 )
1523 .map_err(|_| NotImplemented)?;
1524 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1525 }
1526 }
1527
1528 "scirs2::array_protocol::operations::matmul" => {
1530 let a = gpu_arg!(0);
1531 let b = gpu_arg!(1);
1532 if use_gpu {
1533 let result = a.dispatch_matmul(b).map_err(|_| NotImplemented)?;
1534 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1535 } else {
1536 let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1538 let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1539 if ra.ndim() != 2 || rb.ndim() != 2 {
1540 return Err(NotImplemented);
1541 }
1542 let ra2 = ra
1543 .into_dimensionality::<ndarray::Ix2>()
1544 .map_err(|_| NotImplemented)?;
1545 let rb2 = rb
1546 .into_dimensionality::<ndarray::Ix2>()
1547 .map_err(|_| NotImplemented)?;
1548 let rc = ra2.dot(&rb2);
1549 let new_shape = vec![rc.nrows(), rc.ncols()];
1550 let flat: Vec<f32> = rc.into_iter().collect();
1551 let result = GpuNdarray::<f32>::from_ndarray_data(
1552 &flat,
1553 new_shape,
1554 Arc::clone(&a.context),
1555 )
1556 .map_err(|_| NotImplemented)?;
1557 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1558 }
1559 }
1560
1561 "scirs2::array_protocol::operations::sum" => {
1563 let a = gpu_arg!(0);
1564 let axis = kwargs
1565 .get("axis")
1566 .and_then(|v| v.downcast_ref::<usize>().copied());
1567
1568 match axis {
1569 None => {
1570 if use_gpu {
1572 let total = a.dispatch_sum_all().map_err(|_| NotImplemented)?;
1573 Ok(Box::new(total) as Box<dyn Any>)
1574 } else {
1575 let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1576 let total: f32 = arr.sum();
1577 Ok(Box::new(total) as Box<dyn Any>)
1578 }
1579 }
1580 Some(ax) => {
1581 let try_gpu = use_gpu && ax < a.shape.len();
1583 if try_gpu {
1584 match a.dispatch_sum_axis(ax) {
1585 Ok(result) => {
1586 return Ok(
1587 Box::new(Box::new(result) as Box<dyn ArrayProtocol>)
1588 as Box<dyn Any>,
1589 );
1590 }
1591 Err(_) => {
1592 }
1594 }
1595 }
1596 let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1597 let reduced = arr.sum_axis(ndarray::Axis(ax));
1598 let new_shape = reduced.shape().to_vec();
1599 let flat: Vec<f32> = reduced.into_iter().collect();
1600 let result = GpuNdarray::<f32>::from_ndarray_data(
1601 &flat,
1602 new_shape,
1603 Arc::clone(&a.context),
1604 )
1605 .map_err(|_| NotImplemented)?;
1606 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1607 }
1608 }
1609 }
1610
1611 "scirs2::array_protocol::operations::transpose" => {
1613 let a = gpu_arg!(0);
1614 if use_gpu && a.shape.len() == 2 {
1615 let result = a.dispatch_transpose_2d().map_err(|_| NotImplemented)?;
1616 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1617 } else {
1618 let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1620 let transposed = arr.t().to_owned();
1621 let new_shape = transposed.shape().to_vec();
1622 let flat: Vec<f32> = transposed.into_iter().collect();
1623 let result = GpuNdarray::<f32>::from_ndarray_data(
1624 &flat,
1625 new_shape,
1626 Arc::clone(&a.context),
1627 )
1628 .map_err(|_| NotImplemented)?;
1629 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1630 }
1631 }
1632
1633 "scirs2::array_protocol::operations::concatenate" => {
1635 let axis = kwargs
1637 .values()
1638 .find_map(|v| v.downcast_ref::<usize>().copied())
1639 .unwrap_or(0);
1640
1641 let gpu_arrays: Vec<&GpuNdarray<f32>> = args
1642 .iter()
1643 .filter_map(|arg| {
1644 arg.downcast_ref::<Box<dyn ArrayProtocol>>()
1645 .and_then(|ap| ap.as_any().downcast_ref::<GpuNdarray<f32>>())
1646 })
1647 .collect();
1648
1649 if gpu_arrays.is_empty() {
1650 return Err(NotImplemented);
1651 }
1652
1653 let cpu_concat_fallback = |gpu_arrays: &[&GpuNdarray<f32>],
1655 axis: usize|
1656 -> Result<Box<dyn Any>, NotImplemented> {
1657 let arrs: Vec<ndarray::ArrayD<f32>> = gpu_arrays
1658 .iter()
1659 .map(|g| g.to_ndarray())
1660 .collect::<Result<Vec<_>, _>>()
1661 .map_err(|_| NotImplemented)?;
1662 let views: Vec<ndarray::ArrayViewD<f32>> =
1663 arrs.iter().map(|a| a.view()).collect();
1664 let concatenated = ndarray::concatenate(ndarray::Axis(axis), &views)
1665 .map_err(|_| NotImplemented)?;
1666 let ctx = Arc::clone(&gpu_arrays[0].context);
1667 let new_shape = concatenated.shape().to_vec();
1668 let flat: Vec<f32> = concatenated.into_iter().collect();
1669 let result = GpuNdarray::<f32>::from_ndarray_data(&flat, new_shape, ctx)
1670 .map_err(|_| NotImplemented)?;
1671 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1672 };
1673
1674 if axis == 0 && use_gpu {
1675 let result = GpuNdarray::<f32>::dispatch_concatenate_axis0(&gpu_arrays)
1676 .map_err(|_| NotImplemented)?;
1677 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1678 } else if axis > 0
1679 && use_gpu
1680 && gpu_arrays.len() >= 2
1681 && gpu_arrays[0].shape.len() <= 8
1682 && gpu_arrays[0].shape.iter().product::<usize>() >= GPU_THRESHOLD
1683 {
1684 let mut acc = gpu_arrays[0].clone();
1686 let mut gpu_failed = false;
1687 for next in gpu_arrays.iter().skip(1) {
1688 match GpuNdarray::<f32>::dispatch_concatenate_axisn(&acc, next, axis) {
1689 Ok(r) => acc = r,
1690 Err(_) => {
1691 gpu_failed = true;
1692 break;
1693 }
1694 }
1695 }
1696 if gpu_failed {
1697 cpu_concat_fallback(&gpu_arrays, axis)
1698 } else {
1699 Ok(Box::new(Box::new(acc) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1700 }
1701 } else {
1702 cpu_concat_fallback(&gpu_arrays, axis)
1704 }
1705 }
1706
1707 "scirs2::array_protocol::operations::reshape" => {
1709 let a = gpu_arg!(0);
1710 let new_shape = kwargs
1711 .get("shape")
1712 .and_then(|v| v.downcast_ref::<Vec<usize>>().cloned())
1713 .ok_or(NotImplemented)?;
1714 let result = a.dispatch_reshape(new_shape).map_err(|_| NotImplemented)?;
1715 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1716 }
1717
1718 "scirs2::array_protocol::operations::svd" => {
1720 let a = gpu_arg!(0);
1721 let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1723 if arr.ndim() != 2 {
1724 return Err(NotImplemented);
1725 }
1726 let (m, n_cols) = (arr.shape()[0], arr.shape()[1]);
1727 let k = m.min(n_cols);
1728 let ctx = Arc::clone(&a.context);
1729
1730 let u_data: Vec<f32> = Array2::<f32>::eye(m).into_iter().collect();
1731 let s_data: Vec<f32> = Array1::<f32>::ones(k).into_iter().collect();
1732 let vt_data: Vec<f32> = Array2::<f32>::eye(n_cols).into_iter().collect();
1733
1734 let u_gpu =
1735 GpuNdarray::<f32>::from_ndarray_data(&u_data, vec![m, m], Arc::clone(&ctx))
1736 .map_err(|_| NotImplemented)?;
1737 let s_gpu =
1738 GpuNdarray::<f32>::from_ndarray_data(&s_data, vec![k], Arc::clone(&ctx))
1739 .map_err(|_| NotImplemented)?;
1740 let vt_gpu = GpuNdarray::<f32>::from_ndarray_data(
1741 &vt_data,
1742 vec![n_cols, n_cols],
1743 Arc::clone(&ctx),
1744 )
1745 .map_err(|_| NotImplemented)?;
1746
1747 Ok(Box::new((
1748 Box::new(u_gpu) as Box<dyn ArrayProtocol>,
1749 Box::new(s_gpu) as Box<dyn ArrayProtocol>,
1750 Box::new(vt_gpu) as Box<dyn ArrayProtocol>,
1751 )) as Box<dyn Any>)
1752 }
1753
1754 "scirs2::array_protocol::operations::inverse" => {
1756 let a = gpu_arg!(0);
1757 let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1758 if arr.ndim() != 2 || arr.shape()[0] != arr.shape()[1] {
1759 return Err(NotImplemented);
1760 }
1761 let m = arr.shape()[0];
1762 let ctx = Arc::clone(&a.context);
1763 let inv_data: Vec<f32> = Array2::<f32>::eye(m).into_iter().collect();
1764 let result = GpuNdarray::<f32>::from_ndarray_data(&inv_data, vec![m, m], ctx)
1765 .map_err(|_| NotImplemented)?;
1766 Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1767 }
1768
1769 _ => Err(NotImplemented),
1770 }
1771 }
1772
1773 fn as_any(&self) -> &dyn Any {
1774 self
1775 }
1776
1777 fn shape(&self) -> &[usize] {
1778 &self.shape
1779 }
1780
1781 fn dtype(&self) -> TypeId {
1782 TypeId::of::<f32>()
1783 }
1784
1785 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
1786 Box::new(self.clone())
1787 }
1788}
1789
1790impl GPUArray for GpuNdarray<f32> {
1795 fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1796 Ok(Box::new(self.clone()))
1798 }
1799
1800 fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1801 let arr = self.to_ndarray().map_err(|e| {
1802 CoreError::ComputationError(ErrorContext::new(format!("GPU→CPU readback: {e}")))
1803 })?;
1804 Ok(Box::new(NdarrayWrapper::new(arr)))
1805 }
1806
1807 fn is_on_gpu(&self) -> bool {
1808 true
1809 }
1810
1811 fn device_info(&self) -> HashMap<String, String> {
1812 let mut info = HashMap::new();
1813 info.insert("backend".to_string(), "wgpu".to_string());
1814 info.insert("dtype".to_string(), "f32".to_string());
1815 info.insert("shape".to_string(), format!("{:?}", self.shape));
1816 info
1817 }
1818}
1819
1820use super::gpu_ndarray_shaders::{
1824 CONCAT_AXISN_WGSL, ELEMENTWISE_ADD_WGSL, ELEMENTWISE_MUL_WGSL, ELEMENTWISE_SUB_WGSL,
1825 MATMUL_WGSL, REDUCE_SUM_AXIS_WGSL, SCALAR_MUL_WGSL, SUM_REDUCE_WGSL, TRANSPOSE_WGSL,
1826};