Skip to main content

scirs2_core/array_protocol/
gpu_ndarray.rs

1// Copyright (c) 2025, `SciRS2` Team
2//
3// Licensed under the Apache License, Version 2.0
4// (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5//
6
7//! Real wgpu-backed `GpuNdarray<T>` that implements `ArrayProtocol`.
8//!
9//! Enabled only with the `array_protocol_wgpu` feature, which implies `wgpu_backend`.
10//!
11//! ## Supported operations (GPU dispatch)
12//! - `add`, `subtract`, `multiply` — elementwise binary, workgroup (256,1,1), uses `arrayLength`
13//! - `multiply_by_scalar_f32` — elementwise scalar multiply, workgroup (256,1,1)
14//! - `matmul` — naive (one thread per output element), workgroup (16,16,1)
15//! - `sum(axis=None)` — two-pass reduce, workgroup (256,1,1)
16//! - `transpose` (2-D) — 16×16 bank-conflict-padded tile, workgroup (16,16,1)
17//!   (32×32 exceeds Metal's 256-invocation-per-workgroup limit)
18//! - `concatenate(axis=0)` — via `copy_buffer_to_buffer`, no shader
19//! - `concatenate(axis>0)` — WGSL gather kernel (`CONCAT_AXISN_WGSL`), storage-buf strides
20//! - `sum(axis=Some(ax))` — WGSL per-output-element axis reduction (`REDUCE_SUM_AXIS_WGSL`)
21//! - `reshape` — zero-copy (clone `Arc<Buffer>`, new shape/strides)
22//!
23//! ## CPU-fallback operations
24//! - `svd` — falls back to CPU `NdarrayWrapper`
25//! - `inverse` — falls back to CPU `NdarrayWrapper`
26//! - `multiply_by_scalar_f64`, `divide_by_scalar_f64` — convert to f32, then GPU
27//! - GPU kernel errors on axis ops — fallback to CPU (graceful degradation)
28//!
29//! ## GPU threshold
30//! Arrays with fewer than 4096 elements skip GPU dispatch entirely and fall back to CPU.
31
32use std::any::{Any, TypeId};
33use std::collections::HashMap;
34use std::marker::PhantomData;
35use std::sync::{Arc, OnceLock};
36
37use ndarray::{Array1, Array2, IxDyn};
38
39use crate::array_protocol::{
40    ArrayFunction, ArrayProtocol, GPUArray, NdarrayWrapper, NotImplemented,
41};
42use crate::error::{CoreError, CoreResult, ErrorContext};
43use crate::gpu::backends::WebGPUContext;
44use crate::gpu::GpuError;
45
46// ──────────────────────────────────────────────────────────────────────
47// 1. GpuScalar — sealed trait, blanket impl for f32 only
48// ──────────────────────────────────────────────────────────────────────
49
/// Private module that keeps [`GpuScalar`] sealed: downstream crates cannot
/// name `sealed::Sealed`, so they cannot implement `GpuScalar` themselves.
mod sealed {
    /// Supertrait implementable only inside this crate.
    pub trait Sealed {}

    impl Sealed for f32 {}
}

/// Element types that may be stored in a `GpuNdarray` buffer.
///
/// Only `f32` qualifies: WGSL (as targeted by wgpu-29) has no native `f64`
/// without extensions. The trait is sealed and cannot be implemented outside
/// this crate.
pub trait GpuScalar: sealed::Sealed + Clone + Send + Sync + 'static {}

impl GpuScalar for f32 {}
60
61// ──────────────────────────────────────────────────────────────────────
62// 2. GPU dispatch threshold, availability cache, and singleton context
63// ──────────────────────────────────────────────────────────────────────
64
65/// Arrays smaller than this skip GPU dispatch entirely.
66const GPU_THRESHOLD: usize = 4096;
67
68/// Cached GPU availability flag (computed once per process).
69static GPU_AVAILABLE: OnceLock<bool> = OnceLock::new();
70
71/// Cached singleton `WebGPUContext` — all `GpuNdarray` share one device.
72static GPU_CONTEXT: OnceLock<Option<Arc<WebGPUContext>>> = OnceLock::new();
73
74/// Returns the shared `WebGPUContext`, or `None` if no adapter is available.
75///
76/// Initializes the singleton device once; all subsequent calls return the cached value.
77pub fn global_context() -> Option<Arc<WebGPUContext>> {
78    GPU_CONTEXT
79        .get_or_init(|| match WebGPUContext::new() {
80            Ok(ctx) => Some(Arc::new(ctx)),
81            Err(_) => None,
82        })
83        .clone()
84}
85
86/// Returns `true` if a wgpu adapter was found when first called; cached afterwards.
87pub fn is_gpu_available() -> bool {
88    *GPU_AVAILABLE.get_or_init(|| global_context().is_some())
89}
90
91// ──────────────────────────────────────────────────────────────────────
92// 3. GpuNdarray<T>
93// ──────────────────────────────────────────────────────────────────────
94
95/// A GPU-backed n-dimensional array backed by a real wgpu `Buffer`.
96///
97/// Created via [`GpuNdarray::from_ndarray_data`] or [`GpuNdarray::from_data`].
98/// Converted back with [`GpuNdarray::to_ndarray`].
99pub struct GpuNdarray<T: GpuScalar> {
100    /// The live wgpu `Buffer` on the device.  Shared under `Arc` to make
101    /// `reshape` zero-copy and `Clone` cheap.
102    buffer: Arc<wgpu::Buffer>,
103
104    /// Logical shape (row-major).
105    shape: Vec<usize>,
106
107    /// Row-major strides in *elements* (not bytes).
108    strides: Vec<usize>,
109
110    /// Shared device/queue context.
111    context: Arc<WebGPUContext>,
112
113    _phantom: PhantomData<T>,
114}
115
116impl<T: GpuScalar> std::fmt::Debug for GpuNdarray<T> {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        f.debug_struct("GpuNdarray")
119            .field("shape", &self.shape)
120            .field("strides", &self.strides)
121            .finish_non_exhaustive()
122    }
123}
124
125impl<T: GpuScalar> Clone for GpuNdarray<T> {
126    fn clone(&self) -> Self {
127        Self {
128            buffer: Arc::clone(&self.buffer),
129            shape: self.shape.clone(),
130            strides: self.strides.clone(),
131            context: Arc::clone(&self.context),
132            _phantom: PhantomData,
133        }
134    }
135}
136
137impl<T: GpuScalar> GpuNdarray<T> {
138    /// Expose the underlying `Arc<wgpu::Buffer>` for zero-copy checks in tests.
139    #[must_use]
140    pub fn buffer_arc(&self) -> &Arc<wgpu::Buffer> {
141        &self.buffer
142    }
143
144    /// Total number of elements.
145    #[must_use]
146    fn numel(&self) -> usize {
147        self.shape.iter().product()
148    }
149
150    /// Row-major strides for the given shape.
151    fn compute_strides(shape: &[usize]) -> Vec<usize> {
152        let mut strides = vec![1usize; shape.len()];
153        for i in (0..shape.len().saturating_sub(1)).rev() {
154            strides[i] = strides[i + 1] * shape[i + 1];
155        }
156        strides
157    }
158}
159
160// ──────────────────────────────────────────────────────────────────────
161// 4. Low-level pipeline builder helper (not part of GpuNdarray impl)
162// ──────────────────────────────────────────────────────────────────────
163
164/// Build a compute pipeline with an explicit bind group layout.
165///
166/// `bgl_entries` must match the WGSL `@group(0)` bindings exactly.
167fn build_pipeline(
168    ctx: &WebGPUContext,
169    wgsl: &str,
170    bgl_entries: &[wgpu::BindGroupLayoutEntry],
171    label: &str,
172) -> Result<(wgpu::ComputePipeline, wgpu::BindGroupLayout), GpuError> {
173    let device = ctx.device();
174    let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
175        label: Some(label),
176        source: wgpu::ShaderSource::Wgsl(wgsl.into()),
177    });
178    let bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
179        label: Some(&format!("{label}_bgl")),
180        entries: bgl_entries,
181    });
182    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
183        label: Some(&format!("{label}_layout")),
184        bind_group_layouts: &[Some(&bgl)],
185        ..Default::default()
186    });
187    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
188        label: Some(&format!("{label}_pipeline")),
189        layout: Some(&pipeline_layout),
190        module: &shader_module,
191        entry_point: Some("main"),
192        compilation_options: Default::default(),
193        cache: None,
194    });
195    Ok((pipeline, bgl))
196}
197
198/// Shorthand BGL entry for a read-only storage buffer.
199fn storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
200    wgpu::BindGroupLayoutEntry {
201        binding,
202        visibility: wgpu::ShaderStages::COMPUTE,
203        ty: wgpu::BindingType::Buffer {
204            ty: wgpu::BufferBindingType::Storage { read_only: true },
205            has_dynamic_offset: false,
206            min_binding_size: None,
207        },
208        count: None,
209    }
210}
211
212/// Shorthand BGL entry for a read-write storage buffer.
213fn storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
214    wgpu::BindGroupLayoutEntry {
215        binding,
216        visibility: wgpu::ShaderStages::COMPUTE,
217        ty: wgpu::BindingType::Buffer {
218            ty: wgpu::BufferBindingType::Storage { read_only: false },
219            has_dynamic_offset: false,
220            min_binding_size: None,
221        },
222        count: None,
223    }
224}
225
226/// Shorthand BGL entry for a uniform buffer.
227fn uniform_buf(binding: u32) -> wgpu::BindGroupLayoutEntry {
228    wgpu::BindGroupLayoutEntry {
229        binding,
230        visibility: wgpu::ShaderStages::COMPUTE,
231        ty: wgpu::BindingType::Buffer {
232            ty: wgpu::BufferBindingType::Uniform,
233            has_dynamic_offset: false,
234            min_binding_size: None,
235        },
236        count: None,
237    }
238}
239
240// ──────────────────────────────────────────────────────────────────────
241// 5. from_ndarray / to_ndarray for f32
242// ──────────────────────────────────────────────────────────────────────
243
244impl GpuNdarray<f32> {
245    /// Upload a CPU `NdarrayWrapper<f32, _>` to the GPU.
246    ///
247    /// Returns `Err(GpuError)` if no adapter is available or upload fails.
248    pub fn from_ndarray_data(
249        data: &[f32],
250        shape: Vec<usize>,
251        context: Arc<WebGPUContext>,
252    ) -> Result<Self, GpuError> {
253        use wgpu::util::DeviceExt as _;
254        let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
255        let buffer = context
256            .device()
257            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
258                label: Some("GpuNdarray<f32>"),
259                contents: &bytes,
260                usage: wgpu::BufferUsages::STORAGE
261                    | wgpu::BufferUsages::COPY_SRC
262                    | wgpu::BufferUsages::COPY_DST,
263            });
264        let strides = Self::compute_strides(&shape);
265        Ok(Self {
266            buffer: Arc::new(buffer),
267            shape,
268            strides,
269            context,
270            _phantom: PhantomData,
271        })
272    }
273
274    /// Download the GPU buffer to a flat `Vec<f32>`.
275    pub fn to_vec(&self) -> Result<Vec<f32>, GpuError> {
276        let byte_size = (self.numel() * std::mem::size_of::<f32>()) as u64;
277        let staging = self
278            .context
279            .device()
280            .create_buffer(&wgpu::BufferDescriptor {
281                label: Some("GpuNdarray-readback"),
282                size: byte_size,
283                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
284                mapped_at_creation: false,
285            });
286
287        let mut encoder =
288            self.context
289                .device()
290                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
291                    label: Some("GpuNdarray-readback-encoder"),
292                });
293        encoder.copy_buffer_to_buffer(&self.buffer, 0, &staging, 0, byte_size);
294        self.context.queue().submit(Some(encoder.finish()));
295
296        self.context
297            .device()
298            .poll(wgpu::PollType::wait_indefinitely())
299            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
300
301        let slice = staging.slice(0..byte_size);
302        let (tx, rx) = std::sync::mpsc::channel();
303        slice.map_async(wgpu::MapMode::Read, move |r| {
304            let _ = tx.send(r);
305        });
306
307        self.context
308            .device()
309            .poll(wgpu::PollType::wait_indefinitely())
310            .map_err(|e| GpuError::Other(format!("poll-map error: {e:?}")))?;
311
312        rx.recv()
313            .map_err(|_| GpuError::Other("channel closed".into()))?
314            .map_err(|e| GpuError::Other(format!("map_async failed: {e:?}")))?;
315
316        let mapped = slice.get_mapped_range();
317        let result: Vec<f32> = mapped
318            .chunks_exact(4)
319            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
320            .collect();
321        drop(mapped);
322        staging.unmap();
323        Ok(result)
324    }
325
326    /// Download to a dynamic ndarray.
327    pub fn to_ndarray(&self) -> Result<ndarray::ArrayD<f32>, GpuError> {
328        let flat = self.to_vec()?;
329        ndarray::ArrayD::<f32>::from_shape_vec(self.shape.clone(), flat)
330            .map_err(|e| GpuError::Other(format!("shape_vec error: {e}")))
331    }
332
333    /// Construct from a `WebGPUContext`, uploading `data`.
334    ///
335    /// Uses the process-wide singleton `WebGPUContext` so that all arrays
336    /// created via this function share the same `wgpu::Device` and can be
337    /// combined freely in kernel dispatches.
338    pub fn from_data(data: &[f32], shape: Vec<usize>) -> Result<Self, GpuError> {
339        let ctx =
340            global_context().ok_or_else(|| GpuError::Other("No wgpu adapter available".into()))?;
341        Self::from_ndarray_data(data, shape, ctx)
342    }
343
344    // ─── public high-level ops ─────────────────────────────────────────
345
346    /// Elementwise add: `self + other`.  Both must have the same shape.
347    pub fn add(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
348        self.dispatch_elementwise_binary(other, 0)
349    }
350
351    /// Elementwise subtract: `self - other`.  Both must have the same shape.
352    pub fn subtract(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
353        self.dispatch_elementwise_binary(other, 1)
354    }
355
356    /// Elementwise multiply: `self * other`.  Both must have the same shape.
357    pub fn multiply(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
358        self.dispatch_elementwise_binary(other, 2)
359    }
360
361    /// Multiply every element by a scalar: `self * scalar`.
362    pub fn multiply_by_scalar_f32(&self, scalar: f32) -> Result<GpuNdarray<f32>, GpuError> {
363        self.dispatch_scalar_multiply(scalar)
364    }
365
366    /// Compute the sum of all elements (dot product via `a.multiply(b)?.sum_all()`).
367    pub fn sum_all(&self) -> Result<f32, GpuError> {
368        self.dispatch_sum_all()
369    }
370
371    /// Compute the dot product `self · other = sum(self * other)`.
372    pub fn dot_gpu(&self, other: &GpuNdarray<f32>) -> Result<f32, GpuError> {
373        let prod = self.dispatch_elementwise_binary(other, 2)?;
374        prod.dispatch_sum_all()
375    }
376
377    /// Tiled 16x16 matrix multiplication: self x other.
378    ///
379    /// Both arrays must be 2-D.  For a matrix-vector product, upload the
380    /// vector as a [n, 1] column; the result will have shape [m, 1].
381    pub fn matmul(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
382        self.dispatch_matmul(other)
383    }
384
385    // ─── internal GPU kernel dispatchers ───────────────────────────────
386
387    /// Dispatch an elementwise binary kernel (`add`, `subtract`, `multiply`).
388    ///
389    /// `op_id`: 0=add, 1=subtract, 2=multiply.
390    fn dispatch_elementwise_binary(
391        &self,
392        other: &GpuNdarray<f32>,
393        op_id: u32,
394    ) -> Result<GpuNdarray<f32>, GpuError> {
395        let n = self.numel();
396        if n != other.numel() {
397            return Err(GpuError::InvalidParameter(
398                "shape mismatch in elementwise binary".into(),
399            ));
400        }
401        let wgsl = match op_id {
402            0 => ELEMENTWISE_ADD_WGSL,
403            1 => ELEMENTWISE_SUB_WGSL,
404            _ => ELEMENTWISE_MUL_WGSL,
405        };
406        let byte_size = (n * 4) as u64;
407        let result_buf = self
408            .context
409            .device()
410            .create_buffer(&wgpu::BufferDescriptor {
411                label: Some("elementwise-result"),
412                size: byte_size,
413                usage: wgpu::BufferUsages::STORAGE
414                    | wgpu::BufferUsages::COPY_SRC
415                    | wgpu::BufferUsages::COPY_DST,
416                mapped_at_creation: false,
417            });
418
419        let bgl_entries = [storage_ro(0), storage_ro(1), storage_rw(2)];
420        let (pipeline, bgl) = build_pipeline(&self.context, wgsl, &bgl_entries, "elementwise")?;
421
422        let bind_group = self
423            .context
424            .device()
425            .create_bind_group(&wgpu::BindGroupDescriptor {
426                label: Some("elementwise-bg"),
427                layout: &bgl,
428                entries: &[
429                    wgpu::BindGroupEntry {
430                        binding: 0,
431                        resource: self.buffer.as_entire_binding(),
432                    },
433                    wgpu::BindGroupEntry {
434                        binding: 1,
435                        resource: other.buffer.as_entire_binding(),
436                    },
437                    wgpu::BindGroupEntry {
438                        binding: 2,
439                        resource: result_buf.as_entire_binding(),
440                    },
441                ],
442            });
443
444        let mut encoder =
445            self.context
446                .device()
447                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
448                    label: Some("elementwise-encoder"),
449                });
450        {
451            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
452                label: Some("elementwise-pass"),
453                timestamp_writes: None,
454            });
455            cpass.set_pipeline(&pipeline);
456            cpass.set_bind_group(0, &bind_group, &[]);
457            let workgroups = (n as u32 + 255) / 256;
458            cpass.dispatch_workgroups(workgroups, 1, 1);
459        }
460        self.context.queue().submit(Some(encoder.finish()));
461        self.context
462            .device()
463            .poll(wgpu::PollType::wait_indefinitely())
464            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
465
466        Ok(GpuNdarray {
467            buffer: Arc::new(result_buf),
468            shape: self.shape.clone(),
469            strides: self.strides.clone(),
470            context: Arc::clone(&self.context),
471            _phantom: PhantomData,
472        })
473    }
474
475    /// Dispatch scalar multiply: each element * scalar.
476    fn dispatch_scalar_multiply(&self, scalar: f32) -> Result<GpuNdarray<f32>, GpuError> {
477        let n = self.numel();
478        let byte_size = (n * 4) as u64;
479        let result_buf = self
480            .context
481            .device()
482            .create_buffer(&wgpu::BufferDescriptor {
483                label: Some("scalar-mul-result"),
484                size: byte_size,
485                usage: wgpu::BufferUsages::STORAGE
486                    | wgpu::BufferUsages::COPY_SRC
487                    | wgpu::BufferUsages::COPY_DST,
488                mapped_at_creation: false,
489            });
490
491        // Uniform: scalar (f32), n (u32) — 8 bytes padded to 16
492        let mut unif: Vec<u8> = Vec::with_capacity(16);
493        unif.extend_from_slice(&scalar.to_le_bytes());
494        unif.extend_from_slice(&(n as u32).to_le_bytes());
495        while unif.len() % 16 != 0 {
496            unif.push(0);
497        }
498        use wgpu::util::DeviceExt as _;
499        let uniform_buffer =
500            self.context
501                .device()
502                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
503                    label: Some("scalar-mul-uniform"),
504                    contents: &unif,
505                    usage: wgpu::BufferUsages::UNIFORM,
506                });
507
508        let bgl_entries = [storage_ro(0), storage_rw(1), uniform_buf(2)];
509        let (pipeline, bgl) =
510            build_pipeline(&self.context, SCALAR_MUL_WGSL, &bgl_entries, "scalar-mul")?;
511
512        let bind_group = self
513            .context
514            .device()
515            .create_bind_group(&wgpu::BindGroupDescriptor {
516                label: Some("scalar-mul-bg"),
517                layout: &bgl,
518                entries: &[
519                    wgpu::BindGroupEntry {
520                        binding: 0,
521                        resource: self.buffer.as_entire_binding(),
522                    },
523                    wgpu::BindGroupEntry {
524                        binding: 1,
525                        resource: result_buf.as_entire_binding(),
526                    },
527                    wgpu::BindGroupEntry {
528                        binding: 2,
529                        resource: uniform_buffer.as_entire_binding(),
530                    },
531                ],
532            });
533
534        let mut encoder =
535            self.context
536                .device()
537                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
538                    label: Some("scalar-mul-encoder"),
539                });
540        {
541            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
542                label: Some("scalar-mul-pass"),
543                timestamp_writes: None,
544            });
545            cpass.set_pipeline(&pipeline);
546            cpass.set_bind_group(0, &bind_group, &[]);
547            let workgroups = (n as u32 + 255) / 256;
548            cpass.dispatch_workgroups(workgroups, 1, 1);
549        }
550        self.context.queue().submit(Some(encoder.finish()));
551        self.context
552            .device()
553            .poll(wgpu::PollType::wait_indefinitely())
554            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
555
556        Ok(GpuNdarray {
557            buffer: Arc::new(result_buf),
558            shape: self.shape.clone(),
559            strides: self.strides.clone(),
560            context: Arc::clone(&self.context),
561            _phantom: PhantomData,
562        })
563    }
564
565    /// Dispatch tiled 16×16 matmul for 2-D arrays.
566    fn dispatch_matmul(&self, other: &GpuNdarray<f32>) -> Result<GpuNdarray<f32>, GpuError> {
567        if self.shape.len() != 2 || other.shape.len() != 2 {
568            return Err(GpuError::InvalidParameter(
569                "matmul requires 2-D arrays".into(),
570            ));
571        }
572        let (m, k) = (self.shape[0], self.shape[1]);
573        let (k2, n) = (other.shape[0], other.shape[1]);
574        if k != k2 {
575            return Err(GpuError::InvalidParameter(format!(
576                "matmul shape mismatch: [{m},{k}] x [{k2},{n}]"
577            )));
578        }
579
580        let byte_size = (m * n * 4) as u64;
581        let result_buf = self
582            .context
583            .device()
584            .create_buffer(&wgpu::BufferDescriptor {
585                label: Some("matmul-result"),
586                size: byte_size,
587                usage: wgpu::BufferUsages::STORAGE
588                    | wgpu::BufferUsages::COPY_SRC
589                    | wgpu::BufferUsages::COPY_DST,
590                mapped_at_creation: false,
591            });
592
593        let uniform_data: [u32; 3] = [m as u32, n as u32, k as u32];
594        let uniform_bytes: Vec<u8> = uniform_data.iter().flat_map(|v| v.to_le_bytes()).collect();
595        // Pad to 16-byte alignment (wgpu uniform requirement)
596        let mut uniform_padded = uniform_bytes;
597        while uniform_padded.len() % 16 != 0 {
598            uniform_padded.push(0);
599        }
600        use wgpu::util::DeviceExt as _;
601        let uniform_buffer =
602            self.context
603                .device()
604                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
605                    label: Some("matmul-uniform"),
606                    contents: &uniform_padded,
607                    usage: wgpu::BufferUsages::UNIFORM,
608                });
609
610        let bgl_entries = [storage_ro(0), storage_ro(1), storage_rw(2), uniform_buf(3)];
611        let (pipeline, bgl) = build_pipeline(&self.context, MATMUL_WGSL, &bgl_entries, "matmul")?;
612        let bind_group = self
613            .context
614            .device()
615            .create_bind_group(&wgpu::BindGroupDescriptor {
616                label: Some("matmul-bg"),
617                layout: &bgl,
618                entries: &[
619                    wgpu::BindGroupEntry {
620                        binding: 0,
621                        resource: self.buffer.as_entire_binding(),
622                    },
623                    wgpu::BindGroupEntry {
624                        binding: 1,
625                        resource: other.buffer.as_entire_binding(),
626                    },
627                    wgpu::BindGroupEntry {
628                        binding: 2,
629                        resource: result_buf.as_entire_binding(),
630                    },
631                    wgpu::BindGroupEntry {
632                        binding: 3,
633                        resource: uniform_buffer.as_entire_binding(),
634                    },
635                ],
636            });
637
638        let wg_x = (n as u32 + 15) / 16;
639        let wg_y = (m as u32 + 15) / 16;
640        let mut encoder =
641            self.context
642                .device()
643                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
644                    label: Some("matmul-encoder"),
645                });
646        {
647            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
648                label: Some("matmul-pass"),
649                timestamp_writes: None,
650            });
651            cpass.set_pipeline(&pipeline);
652            cpass.set_bind_group(0, &bind_group, &[]);
653            cpass.dispatch_workgroups(wg_x, wg_y, 1);
654        }
655        self.context.queue().submit(Some(encoder.finish()));
656        self.context
657            .device()
658            .poll(wgpu::PollType::wait_indefinitely())
659            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
660
661        Ok(GpuNdarray {
662            buffer: Arc::new(result_buf),
663            shape: vec![m, n],
664            strides: Self::compute_strides(&[m, n]),
665            context: Arc::clone(&self.context),
666            _phantom: PhantomData,
667        })
668    }
669
670    /// Dispatch two-pass sum reduction (axis=None).
671    fn dispatch_sum_all(&self) -> Result<f32, GpuError> {
672        let n = self.numel();
673        // First pass: workgroup partial sums
674        let workgroup_count = (n as u32 + 255) / 256;
675        let partial_byte_size = (workgroup_count as usize * 4) as u64;
676        let partial_buf = self
677            .context
678            .device()
679            .create_buffer(&wgpu::BufferDescriptor {
680                label: Some("sum-partial"),
681                size: partial_byte_size,
682                usage: wgpu::BufferUsages::STORAGE
683                    | wgpu::BufferUsages::COPY_SRC
684                    | wgpu::BufferUsages::COPY_DST,
685                mapped_at_creation: false,
686            });
687
688        let n_bytes = (n as u32).to_le_bytes();
689        let mut uniform_bytes = n_bytes.to_vec();
690        while uniform_bytes.len() % 16 != 0 {
691            uniform_bytes.push(0);
692        }
693        use wgpu::util::DeviceExt as _;
694        let uniform_buffer =
695            self.context
696                .device()
697                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
698                    label: Some("sum-uniform"),
699                    contents: &uniform_bytes,
700                    usage: wgpu::BufferUsages::UNIFORM,
701                });
702
703        let bgl_entries = [storage_ro(0), storage_rw(1), uniform_buf(2)];
704        let (pipeline, bgl) =
705            build_pipeline(&self.context, SUM_REDUCE_WGSL, &bgl_entries, "sum-reduce")?;
706        let bind_group = self
707            .context
708            .device()
709            .create_bind_group(&wgpu::BindGroupDescriptor {
710                label: Some("sum-bg"),
711                layout: &bgl,
712                entries: &[
713                    wgpu::BindGroupEntry {
714                        binding: 0,
715                        resource: self.buffer.as_entire_binding(),
716                    },
717                    wgpu::BindGroupEntry {
718                        binding: 1,
719                        resource: partial_buf.as_entire_binding(),
720                    },
721                    wgpu::BindGroupEntry {
722                        binding: 2,
723                        resource: uniform_buffer.as_entire_binding(),
724                    },
725                ],
726            });
727
728        let mut encoder =
729            self.context
730                .device()
731                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
732                    label: Some("sum-encoder"),
733                });
734        {
735            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
736                label: Some("sum-pass"),
737                timestamp_writes: None,
738            });
739            cpass.set_pipeline(&pipeline);
740            cpass.set_bind_group(0, &bind_group, &[]);
741            cpass.dispatch_workgroups(workgroup_count, 1, 1);
742        }
743        self.context.queue().submit(Some(encoder.finish()));
744        self.context
745            .device()
746            .poll(wgpu::PollType::wait_indefinitely())
747            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
748
749        // Read back partial sums and sum on CPU (second pass)
750        let staging = self
751            .context
752            .device()
753            .create_buffer(&wgpu::BufferDescriptor {
754                label: Some("sum-staging"),
755                size: partial_byte_size,
756                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
757                mapped_at_creation: false,
758            });
759        let mut encoder2 =
760            self.context
761                .device()
762                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
763                    label: Some("sum-copy-encoder"),
764                });
765        encoder2.copy_buffer_to_buffer(&partial_buf, 0, &staging, 0, partial_byte_size);
766        self.context.queue().submit(Some(encoder2.finish()));
767
768        self.context
769            .device()
770            .poll(wgpu::PollType::wait_indefinitely())
771            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
772
773        let slice = staging.slice(0..partial_byte_size);
774        let (tx, rx) = std::sync::mpsc::channel();
775        slice.map_async(wgpu::MapMode::Read, move |r| {
776            let _ = tx.send(r);
777        });
778        self.context
779            .device()
780            .poll(wgpu::PollType::wait_indefinitely())
781            .map_err(|e| GpuError::Other(format!("map poll error: {e:?}")))?;
782        rx.recv()
783            .map_err(|_| GpuError::Other("channel closed".into()))?
784            .map_err(|e| GpuError::Other(format!("map_async: {e:?}")))?;
785
786        let mapped = slice.get_mapped_range();
787        let total: f32 = mapped
788            .chunks_exact(4)
789            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
790            .sum();
791        drop(mapped);
792        staging.unmap();
793        Ok(total)
794    }
795
796    /// Dispatch tiled 32×32 2-D transpose.
797    fn dispatch_transpose_2d(&self) -> Result<GpuNdarray<f32>, GpuError> {
798        if self.shape.len() != 2 {
799            return Err(GpuError::InvalidParameter(
800                "transpose_2d requires a 2-D array".into(),
801            ));
802        }
803        let (rows, cols) = (self.shape[0], self.shape[1]);
804        let byte_size = (rows * cols * 4) as u64;
805
806        let result_buf = self
807            .context
808            .device()
809            .create_buffer(&wgpu::BufferDescriptor {
810                label: Some("transpose-result"),
811                size: byte_size,
812                usage: wgpu::BufferUsages::STORAGE
813                    | wgpu::BufferUsages::COPY_SRC
814                    | wgpu::BufferUsages::COPY_DST,
815                mapped_at_creation: false,
816            });
817
818        // Uniform: rows (u32), cols (u32)
819        let uniform_data: [u32; 2] = [rows as u32, cols as u32];
820        let uniform_bytes: Vec<u8> = uniform_data.iter().flat_map(|v| v.to_le_bytes()).collect();
821        let mut uniform_padded = uniform_bytes;
822        while uniform_padded.len() % 16 != 0 {
823            uniform_padded.push(0);
824        }
825        use wgpu::util::DeviceExt as _;
826        let uniform_buffer =
827            self.context
828                .device()
829                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
830                    label: Some("transpose-uniform"),
831                    contents: &uniform_padded,
832                    usage: wgpu::BufferUsages::UNIFORM,
833                });
834
835        // Use 16×16 workgroup — Metal limit is 256 total invocations
836        let bgl_entries = [storage_ro(0), storage_rw(1), uniform_buf(2)];
837        let (pipeline, bgl) =
838            build_pipeline(&self.context, TRANSPOSE_WGSL, &bgl_entries, "transpose")?;
839        let bind_group = self
840            .context
841            .device()
842            .create_bind_group(&wgpu::BindGroupDescriptor {
843                label: Some("transpose-bg"),
844                layout: &bgl,
845                entries: &[
846                    wgpu::BindGroupEntry {
847                        binding: 0,
848                        resource: self.buffer.as_entire_binding(),
849                    },
850                    wgpu::BindGroupEntry {
851                        binding: 1,
852                        resource: result_buf.as_entire_binding(),
853                    },
854                    wgpu::BindGroupEntry {
855                        binding: 2,
856                        resource: uniform_buffer.as_entire_binding(),
857                    },
858                ],
859            });
860
861        // Workgroup size is 16×16 tiles
862        let wg_x = (cols as u32 + 15) / 16;
863        let wg_y = (rows as u32 + 15) / 16;
864        let mut encoder =
865            self.context
866                .device()
867                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
868                    label: Some("transpose-encoder"),
869                });
870        {
871            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
872                label: Some("transpose-pass"),
873                timestamp_writes: None,
874            });
875            cpass.set_pipeline(&pipeline);
876            cpass.set_bind_group(0, &bind_group, &[]);
877            cpass.dispatch_workgroups(wg_x, wg_y, 1);
878        }
879        self.context.queue().submit(Some(encoder.finish()));
880        self.context
881            .device()
882            .poll(wgpu::PollType::wait_indefinitely())
883            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
884
885        Ok(GpuNdarray {
886            buffer: Arc::new(result_buf),
887            shape: vec![cols, rows],
888            strides: Self::compute_strides(&[cols, rows]),
889            context: Arc::clone(&self.context),
890            _phantom: PhantomData,
891        })
892    }
893
894    /// Concatenate along axis=0 via `copy_buffer_to_buffer` (no shader).
895    fn dispatch_concatenate_axis0(
896        arrays: &[&GpuNdarray<f32>],
897    ) -> Result<GpuNdarray<f32>, GpuError> {
898        if arrays.is_empty() {
899            return Err(GpuError::InvalidParameter("empty array list".into()));
900        }
901        // Validate trailing dimensions match
902        let trailing = &arrays[0].shape[1..];
903        for arr in arrays.iter().skip(1) {
904            if arr.shape[1..] != *trailing {
905                return Err(GpuError::InvalidParameter(
906                    "concatenate axis=0: trailing dimensions must match".into(),
907                ));
908            }
909        }
910        let ctx = Arc::clone(&arrays[0].context);
911        let trailing_elems: usize = trailing.iter().product::<usize>().max(1);
912
913        let total_rows: usize = arrays.iter().map(|a| a.shape[0]).sum();
914        let total_elems = total_rows * trailing_elems;
915        let total_bytes = (total_elems * 4) as u64;
916
917        let result_buf = ctx.device().create_buffer(&wgpu::BufferDescriptor {
918            label: Some("concat-result"),
919            size: total_bytes,
920            usage: wgpu::BufferUsages::STORAGE
921                | wgpu::BufferUsages::COPY_SRC
922                | wgpu::BufferUsages::COPY_DST,
923            mapped_at_creation: false,
924        });
925
926        let mut encoder = ctx
927            .device()
928            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
929                label: Some("concat-encoder"),
930            });
931        let mut offset: u64 = 0;
932        for arr in arrays {
933            let arr_bytes = (arr.numel() * 4) as u64;
934            encoder.copy_buffer_to_buffer(&arr.buffer, 0, &result_buf, offset, arr_bytes);
935            offset += arr_bytes;
936        }
937        ctx.queue().submit(Some(encoder.finish()));
938        ctx.device()
939            .poll(wgpu::PollType::wait_indefinitely())
940            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
941
942        let new_shape = {
943            let mut s = vec![total_rows];
944            s.extend_from_slice(trailing);
945            s
946        };
947        let new_strides = Self::compute_strides(&new_shape);
948        Ok(GpuNdarray {
949            buffer: Arc::new(result_buf),
950            shape: new_shape,
951            strides: new_strides,
952            context: ctx,
953            _phantom: PhantomData,
954        })
955    }
956
957    /// Concatenate two arrays along an arbitrary axis (axis > 0) using a GPU gather kernel.
958    ///
959    /// Each output element is handled by one invocation; the kernel decomposes the flat output
960    /// index into multi-dim coordinates, branches on the axis coordinate, and gathers from A or B.
961    fn dispatch_concatenate_axisn(
962        a: &GpuNdarray<f32>,
963        b: &GpuNdarray<f32>,
964        axis: usize,
965    ) -> Result<GpuNdarray<f32>, GpuError> {
966        let ndim = a.shape.len();
967        if ndim > 8 {
968            return Err(GpuError::InvalidParameter(
969                "concat_axisn: ndim must be <= 8".into(),
970            ));
971        }
972
973        // Compute output shape
974        let mut out_shape = a.shape.clone();
975        out_shape[axis] += b.shape[axis];
976
977        let out_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
978        let a_strides = GpuNdarray::<f32>::compute_strides(&a.shape);
979        let b_strides = GpuNdarray::<f32>::compute_strides(&b.shape);
980
981        let total_out = out_shape.iter().product::<usize>();
982        let byte_out = (total_out * 4) as u64;
983
984        let ctx = Arc::clone(&a.context);
985        let result_buf = ctx.device().create_buffer(&wgpu::BufferDescriptor {
986            label: Some("concat-axisn-result"),
987            size: byte_out,
988            usage: wgpu::BufferUsages::STORAGE
989                | wgpu::BufferUsages::COPY_SRC
990                | wgpu::BufferUsages::COPY_DST,
991            mapped_at_creation: false,
992        });
993
994        // Pack uniform: axis (u32), dim_a (u32), ndim (u32), _pad (u32) — 16 bytes
995        let dim_a = a.shape[axis] as u32;
996        let mut unif_bytes: Vec<u8> = Vec::with_capacity(16);
997        unif_bytes.extend_from_slice(&(axis as u32).to_le_bytes());
998        unif_bytes.extend_from_slice(&dim_a.to_le_bytes());
999        unif_bytes.extend_from_slice(&(ndim as u32).to_le_bytes());
1000        unif_bytes.extend_from_slice(&0u32.to_le_bytes()); // _pad
1001        debug_assert_eq!(unif_bytes.len(), 16);
1002
1003        // Pack shape/strides as storage buffers (avoid WGSL uniform 16-byte-per-element alignment)
1004        // Each: array of u32, length = ndim (padded to multiple of 4 u32s for buffer sizing)
1005        let pack_u32_slice = |vals: &[usize]| -> Vec<u8> {
1006            let mut out: Vec<u8> = Vec::with_capacity(vals.len() * 4);
1007            for &v in vals {
1008                out.extend_from_slice(&(v as u32).to_le_bytes());
1009            }
1010            // Pad to 16-byte boundary
1011            while out.len() % 16 != 0 {
1012                out.extend_from_slice(&0u32.to_le_bytes());
1013            }
1014            out
1015        };
1016
1017        let out_shape_bytes = pack_u32_slice(&out_shape);
1018        let out_strides_bytes = pack_u32_slice(&out_strides);
1019        let a_strides_bytes = pack_u32_slice(&a_strides);
1020        let b_strides_bytes = pack_u32_slice(&b_strides);
1021
1022        use wgpu::util::DeviceExt as _;
1023        let make_storage_buf = |bytes: &[u8], label: &str| -> wgpu::Buffer {
1024            ctx.device()
1025                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1026                    label: Some(label),
1027                    contents: bytes,
1028                    usage: wgpu::BufferUsages::STORAGE,
1029                })
1030        };
1031        let unif_buf = ctx
1032            .device()
1033            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1034                label: Some("concat-axisn-uniform"),
1035                contents: &unif_bytes,
1036                usage: wgpu::BufferUsages::UNIFORM,
1037            });
1038        let out_shape_buf = make_storage_buf(&out_shape_bytes, "concat-axisn-out-shape");
1039        let out_strides_buf = make_storage_buf(&out_strides_bytes, "concat-axisn-out-strides");
1040        let a_strides_buf = make_storage_buf(&a_strides_bytes, "concat-axisn-a-strides");
1041        let b_strides_buf = make_storage_buf(&b_strides_bytes, "concat-axisn-b-strides");
1042
1043        // Bindings:
1044        // 0=a (storage_ro), 1=b (storage_ro), 2=result (storage_rw),
1045        // 3=uniform, 4=out_shape (storage_ro), 5=out_strides (storage_ro),
1046        // 6=a_strides (storage_ro), 7=b_strides (storage_ro)
1047        let bgl_entries = [
1048            storage_ro(0),
1049            storage_ro(1),
1050            storage_rw(2),
1051            uniform_buf(3),
1052            storage_ro(4),
1053            storage_ro(5),
1054            storage_ro(6),
1055            storage_ro(7),
1056        ];
1057        let (pipeline, bgl) =
1058            build_pipeline(&ctx, CONCAT_AXISN_WGSL, &bgl_entries, "concat-axisn")?;
1059
1060        let bind_group = ctx.device().create_bind_group(&wgpu::BindGroupDescriptor {
1061            label: Some("concat-axisn-bg"),
1062            layout: &bgl,
1063            entries: &[
1064                wgpu::BindGroupEntry {
1065                    binding: 0,
1066                    resource: a.buffer.as_entire_binding(),
1067                },
1068                wgpu::BindGroupEntry {
1069                    binding: 1,
1070                    resource: b.buffer.as_entire_binding(),
1071                },
1072                wgpu::BindGroupEntry {
1073                    binding: 2,
1074                    resource: result_buf.as_entire_binding(),
1075                },
1076                wgpu::BindGroupEntry {
1077                    binding: 3,
1078                    resource: unif_buf.as_entire_binding(),
1079                },
1080                wgpu::BindGroupEntry {
1081                    binding: 4,
1082                    resource: out_shape_buf.as_entire_binding(),
1083                },
1084                wgpu::BindGroupEntry {
1085                    binding: 5,
1086                    resource: out_strides_buf.as_entire_binding(),
1087                },
1088                wgpu::BindGroupEntry {
1089                    binding: 6,
1090                    resource: a_strides_buf.as_entire_binding(),
1091                },
1092                wgpu::BindGroupEntry {
1093                    binding: 7,
1094                    resource: b_strides_buf.as_entire_binding(),
1095                },
1096            ],
1097        });
1098
1099        let workgroups = (total_out as u32 + 255) / 256;
1100        let mut encoder = ctx
1101            .device()
1102            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1103                label: Some("concat-axisn-encoder"),
1104            });
1105        {
1106            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1107                label: Some("concat-axisn-pass"),
1108                timestamp_writes: None,
1109            });
1110            cpass.set_pipeline(&pipeline);
1111            cpass.set_bind_group(0, &bind_group, &[]);
1112            cpass.dispatch_workgroups(workgroups, 1, 1);
1113        }
1114        ctx.queue().submit(Some(encoder.finish()));
1115        ctx.device()
1116            .poll(wgpu::PollType::wait_indefinitely())
1117            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
1118
1119        let new_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
1120        Ok(GpuNdarray {
1121            buffer: Arc::new(result_buf),
1122            shape: out_shape,
1123            strides: new_strides,
1124            context: ctx,
1125            _phantom: PhantomData,
1126        })
1127    }
1128
1129    /// Reduce one axis via GPU: output[out_idx] = sum over j in 0..shape[axis] of input[...j...].
1130    ///
1131    /// Each invocation handles one output element.  Shape/strides are passed as
1132    /// `storage_ro` buffers to avoid WGSL uniform 16-byte-per-element alignment.
1133    fn dispatch_sum_axis(&self, axis: usize) -> Result<GpuNdarray<f32>, GpuError> {
1134        let ndim = self.shape.len();
1135        if ndim > 8 {
1136            return Err(GpuError::InvalidParameter(
1137                "sum_axis: ndim must be <= 8".into(),
1138            ));
1139        }
1140
1141        let axis_size = self.shape[axis];
1142
1143        // Output shape: remove `axis` dimension
1144        let out_shape: Vec<usize> = self
1145            .shape
1146            .iter()
1147            .enumerate()
1148            .filter(|&(i, _)| i != axis)
1149            .map(|(_, &d)| d)
1150            .collect();
1151        let out_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
1152        let in_strides = &self.strides;
1153
1154        let total_out = out_shape.iter().product::<usize>().max(1);
1155        let byte_out = (total_out * 4) as u64;
1156
1157        let result_buf = self
1158            .context
1159            .device()
1160            .create_buffer(&wgpu::BufferDescriptor {
1161                label: Some("sum-axis-result"),
1162                size: byte_out,
1163                usage: wgpu::BufferUsages::STORAGE
1164                    | wgpu::BufferUsages::COPY_SRC
1165                    | wgpu::BufferUsages::COPY_DST,
1166                mapped_at_creation: false,
1167            });
1168
1169        // Uniform: axis (u32), axis_size (u32), ndim (u32), in_axis_stride (u32) — 16 bytes
1170        let in_axis_stride = self.strides[axis] as u32;
1171        let mut unif_bytes: Vec<u8> = Vec::with_capacity(16);
1172        unif_bytes.extend_from_slice(&(axis as u32).to_le_bytes());
1173        unif_bytes.extend_from_slice(&(axis_size as u32).to_le_bytes());
1174        unif_bytes.extend_from_slice(&(ndim as u32).to_le_bytes());
1175        unif_bytes.extend_from_slice(&in_axis_stride.to_le_bytes());
1176        debug_assert_eq!(unif_bytes.len(), 16);
1177
1178        let pack_u32_slice = |vals: &[usize]| -> Vec<u8> {
1179            let mut out: Vec<u8> = Vec::with_capacity(vals.len() * 4);
1180            for &v in vals {
1181                out.extend_from_slice(&(v as u32).to_le_bytes());
1182            }
1183            while out.len() % 16 != 0 {
1184                out.extend_from_slice(&0u32.to_le_bytes());
1185            }
1186            out
1187        };
1188
1189        let in_shape_bytes = pack_u32_slice(&self.shape);
1190        let in_strides_bytes = pack_u32_slice(in_strides);
1191        let out_shape_bytes = pack_u32_slice(&out_shape);
1192        let out_strides_bytes = pack_u32_slice(&out_strides);
1193
1194        use wgpu::util::DeviceExt as _;
1195        let unif_buf =
1196            self.context
1197                .device()
1198                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1199                    label: Some("sum-axis-uniform"),
1200                    contents: &unif_bytes,
1201                    usage: wgpu::BufferUsages::UNIFORM,
1202                });
1203        let make_storage_buf = |bytes: &[u8], label: &str| -> wgpu::Buffer {
1204            self.context
1205                .device()
1206                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
1207                    label: Some(label),
1208                    contents: bytes,
1209                    usage: wgpu::BufferUsages::STORAGE,
1210                })
1211        };
1212        let in_shape_buf = make_storage_buf(&in_shape_bytes, "sum-axis-in-shape");
1213        let in_strides_buf = make_storage_buf(&in_strides_bytes, "sum-axis-in-strides");
1214        let out_shape_buf = make_storage_buf(&out_shape_bytes, "sum-axis-out-shape");
1215        let out_strides_buf = make_storage_buf(&out_strides_bytes, "sum-axis-out-strides");
1216
1217        // Bindings:
1218        // 0=input (storage_ro), 1=result (storage_rw), 2=uniform,
1219        // 3=in_shape (storage_ro), 4=in_strides (storage_ro),
1220        // 5=out_shape (storage_ro), 6=out_strides (storage_ro)
1221        let bgl_entries = [
1222            storage_ro(0),
1223            storage_rw(1),
1224            uniform_buf(2),
1225            storage_ro(3),
1226            storage_ro(4),
1227            storage_ro(5),
1228            storage_ro(6),
1229        ];
1230        let (pipeline, bgl) = build_pipeline(
1231            &self.context,
1232            REDUCE_SUM_AXIS_WGSL,
1233            &bgl_entries,
1234            "sum-axis",
1235        )?;
1236
1237        let bind_group = self
1238            .context
1239            .device()
1240            .create_bind_group(&wgpu::BindGroupDescriptor {
1241                label: Some("sum-axis-bg"),
1242                layout: &bgl,
1243                entries: &[
1244                    wgpu::BindGroupEntry {
1245                        binding: 0,
1246                        resource: self.buffer.as_entire_binding(),
1247                    },
1248                    wgpu::BindGroupEntry {
1249                        binding: 1,
1250                        resource: result_buf.as_entire_binding(),
1251                    },
1252                    wgpu::BindGroupEntry {
1253                        binding: 2,
1254                        resource: unif_buf.as_entire_binding(),
1255                    },
1256                    wgpu::BindGroupEntry {
1257                        binding: 3,
1258                        resource: in_shape_buf.as_entire_binding(),
1259                    },
1260                    wgpu::BindGroupEntry {
1261                        binding: 4,
1262                        resource: in_strides_buf.as_entire_binding(),
1263                    },
1264                    wgpu::BindGroupEntry {
1265                        binding: 5,
1266                        resource: out_shape_buf.as_entire_binding(),
1267                    },
1268                    wgpu::BindGroupEntry {
1269                        binding: 6,
1270                        resource: out_strides_buf.as_entire_binding(),
1271                    },
1272                ],
1273            });
1274
1275        let workgroups = (total_out as u32 + 255) / 256;
1276        let mut encoder =
1277            self.context
1278                .device()
1279                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
1280                    label: Some("sum-axis-encoder"),
1281                });
1282        {
1283            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1284                label: Some("sum-axis-pass"),
1285                timestamp_writes: None,
1286            });
1287            cpass.set_pipeline(&pipeline);
1288            cpass.set_bind_group(0, &bind_group, &[]);
1289            cpass.dispatch_workgroups(workgroups, 1, 1);
1290        }
1291        self.context.queue().submit(Some(encoder.finish()));
1292        self.context
1293            .device()
1294            .poll(wgpu::PollType::wait_indefinitely())
1295            .map_err(|e| GpuError::Other(format!("poll error: {e:?}")))?;
1296
1297        let new_strides = GpuNdarray::<f32>::compute_strides(&out_shape);
1298        Ok(GpuNdarray {
1299            buffer: Arc::new(result_buf),
1300            shape: out_shape,
1301            strides: new_strides,
1302            context: Arc::clone(&self.context),
1303            _phantom: PhantomData,
1304        })
1305    }
1306
1307    /// Zero-copy reshape: clone `Arc<Buffer>`, new shape/strides.
1308    fn dispatch_reshape(&self, new_shape: Vec<usize>) -> Result<GpuNdarray<f32>, GpuError> {
1309        let new_numel: usize = new_shape.iter().product();
1310        if new_numel != self.numel() {
1311            return Err(GpuError::InvalidParameter(format!(
1312                "reshape: element count mismatch: {} vs {}",
1313                self.numel(),
1314                new_numel
1315            )));
1316        }
1317        let new_strides = Self::compute_strides(&new_shape);
1318        Ok(GpuNdarray {
1319            buffer: Arc::clone(&self.buffer),
1320            shape: new_shape,
1321            strides: new_strides,
1322            context: Arc::clone(&self.context),
1323            _phantom: PhantomData,
1324        })
1325    }
1326
1327    /// CPU fallback: download, run ndarray op, re-upload.
1328    fn cpu_fallback_unary<F>(&self, f: F) -> Result<GpuNdarray<f32>, GpuError>
1329    where
1330        F: FnOnce(ndarray::ArrayD<f32>) -> Result<ndarray::ArrayD<f32>, GpuError>,
1331    {
1332        let arr = self.to_ndarray()?;
1333        let result = f(arr)?;
1334        let shape = result.shape().to_vec();
1335        let flat: Vec<f32> = result.into_iter().collect();
1336        Self::from_ndarray_data(&flat, shape, Arc::clone(&self.context))
1337    }
1338}
1339
1340// ──────────────────────────────────────────────────────────────────────
1341// 5. ArrayProtocol impl
1342// ──────────────────────────────────────────────────────────────────────
1343
1344impl ArrayProtocol for GpuNdarray<f32> {
1345    fn array_function(
1346        &self,
1347        func: &ArrayFunction,
1348        _types: &[TypeId],
1349        args: &[Box<dyn Any>],
1350        kwargs: &HashMap<String, Box<dyn Any>>,
1351    ) -> Result<Box<dyn Any>, NotImplemented> {
1352        // Helper: extract GpuNdarray<f32> from args[idx]
1353        macro_rules! gpu_arg {
1354            ($idx:expr) => {{
1355                let boxed_ap = args[$idx]
1356                    .downcast_ref::<Box<dyn ArrayProtocol>>()
1357                    .ok_or(NotImplemented)?;
1358                boxed_ap
1359                    .as_any()
1360                    .downcast_ref::<GpuNdarray<f32>>()
1361                    .ok_or(NotImplemented)?
1362            }};
1363        }
1364
1365        // Helper: fall back to CPU if below threshold or GPU unavailable
1366        let n = self.numel();
1367        let use_gpu = n >= GPU_THRESHOLD && is_gpu_available();
1368
1369        match func.name {
1370            // ── elementwise add ─────────────────────────────────────────
1371            "scirs2::array_protocol::operations::add" => {
1372                let a = gpu_arg!(0);
1373                let b = gpu_arg!(1);
1374                if use_gpu {
1375                    // OP_ID 0 = add
1376                    let result = a
1377                        .dispatch_elementwise_binary(b, 0)
1378                        .map_err(|_| NotImplemented)?;
1379                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1380                } else {
1381                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1382                    let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1383                    let rc = ra + rb;
1384                    let flat: Vec<f32> = rc.into_iter().collect();
1385                    let result = GpuNdarray::<f32>::from_ndarray_data(
1386                        &flat,
1387                        a.shape.clone(),
1388                        Arc::clone(&a.context),
1389                    )
1390                    .map_err(|_| NotImplemented)?;
1391                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1392                }
1393            }
1394
1395            // ── elementwise subtract ─────────────────────────────────────
1396            "scirs2::array_protocol::operations::subtract" => {
1397                let a = gpu_arg!(0);
1398                let b = gpu_arg!(1);
1399                if use_gpu {
1400                    // OP_ID 1 = subtract
1401                    let result = a
1402                        .dispatch_elementwise_binary(b, 1)
1403                        .map_err(|_| NotImplemented)?;
1404                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1405                } else {
1406                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1407                    let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1408                    let rc = ra - rb;
1409                    let flat: Vec<f32> = rc.into_iter().collect();
1410                    let result = GpuNdarray::<f32>::from_ndarray_data(
1411                        &flat,
1412                        a.shape.clone(),
1413                        Arc::clone(&a.context),
1414                    )
1415                    .map_err(|_| NotImplemented)?;
1416                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1417                }
1418            }
1419
1420            // ── elementwise multiply ─────────────────────────────────────
1421            "scirs2::array_protocol::operations::multiply" => {
1422                let a = gpu_arg!(0);
1423                let b = gpu_arg!(1);
1424                if use_gpu {
1425                    // OP_ID 2 = multiply
1426                    let result = a
1427                        .dispatch_elementwise_binary(b, 2)
1428                        .map_err(|_| NotImplemented)?;
1429                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1430                } else {
1431                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1432                    let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1433                    let rc = ra * rb;
1434                    let flat: Vec<f32> = rc.into_iter().collect();
1435                    let result = GpuNdarray::<f32>::from_ndarray_data(
1436                        &flat,
1437                        a.shape.clone(),
1438                        Arc::clone(&a.context),
1439                    )
1440                    .map_err(|_| NotImplemented)?;
1441                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1442                }
1443            }
1444
1445            // ── multiply_by_scalar_f32 ────────────────────────────────────
1446            "scirs2::array_protocol::operations::multiply_by_scalar_f32" => {
1447                let a = gpu_arg!(0);
1448                // scalar value is stored as kwarg with key = scalar.to_string()
1449                let scalar = kwargs
1450                    .values()
1451                    .find_map(|v| v.downcast_ref::<f32>().copied())
1452                    .ok_or(NotImplemented)?;
1453                if use_gpu {
1454                    let result = a
1455                        .dispatch_scalar_multiply(scalar)
1456                        .map_err(|_| NotImplemented)?;
1457                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1458                } else {
1459                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1460                    let rc = ra * scalar;
1461                    let flat: Vec<f32> = rc.into_iter().collect();
1462                    let result = GpuNdarray::<f32>::from_ndarray_data(
1463                        &flat,
1464                        a.shape.clone(),
1465                        Arc::clone(&a.context),
1466                    )
1467                    .map_err(|_| NotImplemented)?;
1468                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1469                }
1470            }
1471
1472            // ── multiply_by_scalar_f64 — convert scalar to f32 ────────────
1473            "scirs2::array_protocol::operations::multiply_by_scalar_f64" => {
1474                let a = gpu_arg!(0);
1475                let scalar = kwargs
1476                    .values()
1477                    .find_map(|v| v.downcast_ref::<f64>().copied())
1478                    .ok_or(NotImplemented)? as f32;
1479                if use_gpu {
1480                    let result = a
1481                        .dispatch_scalar_multiply(scalar)
1482                        .map_err(|_| NotImplemented)?;
1483                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1484                } else {
1485                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1486                    let rc = ra * scalar;
1487                    let flat: Vec<f32> = rc.into_iter().collect();
1488                    let result = GpuNdarray::<f32>::from_ndarray_data(
1489                        &flat,
1490                        a.shape.clone(),
1491                        Arc::clone(&a.context),
1492                    )
1493                    .map_err(|_| NotImplemented)?;
1494                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1495                }
1496            }
1497
1498            // ── divide_by_scalar_f64 — convert to f32 scalar multiply ─────
1499            "scirs2::array_protocol::operations::divide_by_scalar_f64" => {
1500                let a = gpu_arg!(0);
1501                let scalar = kwargs
1502                    .values()
1503                    .find_map(|v| v.downcast_ref::<f64>().copied())
1504                    .ok_or(NotImplemented)?;
1505                if scalar == 0.0 {
1506                    return Err(NotImplemented);
1507                }
1508                let inv = (1.0 / scalar) as f32;
1509                if use_gpu {
1510                    let result = a
1511                        .dispatch_scalar_multiply(inv)
1512                        .map_err(|_| NotImplemented)?;
1513                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1514                } else {
1515                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1516                    let rc = ra * inv;
1517                    let flat: Vec<f32> = rc.into_iter().collect();
1518                    let result = GpuNdarray::<f32>::from_ndarray_data(
1519                        &flat,
1520                        a.shape.clone(),
1521                        Arc::clone(&a.context),
1522                    )
1523                    .map_err(|_| NotImplemented)?;
1524                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1525                }
1526            }
1527
1528            // ── matmul ─────────────────────────────────────────────────────
1529            "scirs2::array_protocol::operations::matmul" => {
1530                let a = gpu_arg!(0);
1531                let b = gpu_arg!(1);
1532                if use_gpu {
1533                    let result = a.dispatch_matmul(b).map_err(|_| NotImplemented)?;
1534                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1535                } else {
1536                    // CPU fallback via ndarray matmul
1537                    let ra = a.to_ndarray().map_err(|_| NotImplemented)?;
1538                    let rb = b.to_ndarray().map_err(|_| NotImplemented)?;
1539                    if ra.ndim() != 2 || rb.ndim() != 2 {
1540                        return Err(NotImplemented);
1541                    }
1542                    let ra2 = ra
1543                        .into_dimensionality::<ndarray::Ix2>()
1544                        .map_err(|_| NotImplemented)?;
1545                    let rb2 = rb
1546                        .into_dimensionality::<ndarray::Ix2>()
1547                        .map_err(|_| NotImplemented)?;
1548                    let rc = ra2.dot(&rb2);
1549                    let new_shape = vec![rc.nrows(), rc.ncols()];
1550                    let flat: Vec<f32> = rc.into_iter().collect();
1551                    let result = GpuNdarray::<f32>::from_ndarray_data(
1552                        &flat,
1553                        new_shape,
1554                        Arc::clone(&a.context),
1555                    )
1556                    .map_err(|_| NotImplemented)?;
1557                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1558                }
1559            }
1560
1561            // ── sum ─────────────────────────────────────────────────────────
1562            "scirs2::array_protocol::operations::sum" => {
1563                let a = gpu_arg!(0);
1564                let axis = kwargs
1565                    .get("axis")
1566                    .and_then(|v| v.downcast_ref::<usize>().copied());
1567
1568                match axis {
1569                    None => {
1570                        // Full reduction — return scalar
1571                        if use_gpu {
1572                            let total = a.dispatch_sum_all().map_err(|_| NotImplemented)?;
1573                            Ok(Box::new(total) as Box<dyn Any>)
1574                        } else {
1575                            let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1576                            let total: f32 = arr.sum();
1577                            Ok(Box::new(total) as Box<dyn Any>)
1578                        }
1579                    }
1580                    Some(ax) => {
1581                        // GPU axis reduction when above threshold; otherwise CPU fallback
1582                        let try_gpu = use_gpu && ax < a.shape.len();
1583                        if try_gpu {
1584                            match a.dispatch_sum_axis(ax) {
1585                                Ok(result) => {
1586                                    return Ok(
1587                                        Box::new(Box::new(result) as Box<dyn ArrayProtocol>)
1588                                            as Box<dyn Any>,
1589                                    );
1590                                }
1591                                Err(_) => {
1592                                    // fall through to CPU
1593                                }
1594                            }
1595                        }
1596                        let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1597                        let reduced = arr.sum_axis(ndarray::Axis(ax));
1598                        let new_shape = reduced.shape().to_vec();
1599                        let flat: Vec<f32> = reduced.into_iter().collect();
1600                        let result = GpuNdarray::<f32>::from_ndarray_data(
1601                            &flat,
1602                            new_shape,
1603                            Arc::clone(&a.context),
1604                        )
1605                        .map_err(|_| NotImplemented)?;
1606                        Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1607                    }
1608                }
1609            }
1610
1611            // ── transpose ──────────────────────────────────────────────────
1612            "scirs2::array_protocol::operations::transpose" => {
1613                let a = gpu_arg!(0);
1614                if use_gpu && a.shape.len() == 2 {
1615                    let result = a.dispatch_transpose_2d().map_err(|_| NotImplemented)?;
1616                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1617                } else {
1618                    // CPU fallback for non-2D or below threshold
1619                    let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1620                    let transposed = arr.t().to_owned();
1621                    let new_shape = transposed.shape().to_vec();
1622                    let flat: Vec<f32> = transposed.into_iter().collect();
1623                    let result = GpuNdarray::<f32>::from_ndarray_data(
1624                        &flat,
1625                        new_shape,
1626                        Arc::clone(&a.context),
1627                    )
1628                    .map_err(|_| NotImplemented)?;
1629                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1630                }
1631            }
1632
1633            // ── concatenate ─────────────────────────────────────────────────
1634            "scirs2::array_protocol::operations::concatenate" => {
1635                // axis is kwarg key = axis_value.to_string()
1636                let axis = kwargs
1637                    .values()
1638                    .find_map(|v| v.downcast_ref::<usize>().copied())
1639                    .unwrap_or(0);
1640
1641                let gpu_arrays: Vec<&GpuNdarray<f32>> = args
1642                    .iter()
1643                    .filter_map(|arg| {
1644                        arg.downcast_ref::<Box<dyn ArrayProtocol>>()
1645                            .and_then(|ap| ap.as_any().downcast_ref::<GpuNdarray<f32>>())
1646                    })
1647                    .collect();
1648
1649                if gpu_arrays.is_empty() {
1650                    return Err(NotImplemented);
1651                }
1652
1653                // Helper: CPU fallback for concatenate
1654                let cpu_concat_fallback = |gpu_arrays: &[&GpuNdarray<f32>],
1655                                           axis: usize|
1656                 -> Result<Box<dyn Any>, NotImplemented> {
1657                    let arrs: Vec<ndarray::ArrayD<f32>> = gpu_arrays
1658                        .iter()
1659                        .map(|g| g.to_ndarray())
1660                        .collect::<Result<Vec<_>, _>>()
1661                        .map_err(|_| NotImplemented)?;
1662                    let views: Vec<ndarray::ArrayViewD<f32>> =
1663                        arrs.iter().map(|a| a.view()).collect();
1664                    let concatenated = ndarray::concatenate(ndarray::Axis(axis), &views)
1665                        .map_err(|_| NotImplemented)?;
1666                    let ctx = Arc::clone(&gpu_arrays[0].context);
1667                    let new_shape = concatenated.shape().to_vec();
1668                    let flat: Vec<f32> = concatenated.into_iter().collect();
1669                    let result = GpuNdarray::<f32>::from_ndarray_data(&flat, new_shape, ctx)
1670                        .map_err(|_| NotImplemented)?;
1671                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1672                };
1673
1674                if axis == 0 && use_gpu {
1675                    let result = GpuNdarray::<f32>::dispatch_concatenate_axis0(&gpu_arrays)
1676                        .map_err(|_| NotImplemented)?;
1677                    Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1678                } else if axis > 0
1679                    && use_gpu
1680                    && gpu_arrays.len() >= 2
1681                    && gpu_arrays[0].shape.len() <= 8
1682                    && gpu_arrays[0].shape.iter().product::<usize>() >= GPU_THRESHOLD
1683                {
1684                    // GPU gather kernel for axis > 0; fold pair-wise for >2 arrays
1685                    let mut acc = gpu_arrays[0].clone();
1686                    let mut gpu_failed = false;
1687                    for next in gpu_arrays.iter().skip(1) {
1688                        match GpuNdarray::<f32>::dispatch_concatenate_axisn(&acc, next, axis) {
1689                            Ok(r) => acc = r,
1690                            Err(_) => {
1691                                gpu_failed = true;
1692                                break;
1693                            }
1694                        }
1695                    }
1696                    if gpu_failed {
1697                        cpu_concat_fallback(&gpu_arrays, axis)
1698                    } else {
1699                        Ok(Box::new(Box::new(acc) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1700                    }
1701                } else {
1702                    // CPU fallback
1703                    cpu_concat_fallback(&gpu_arrays, axis)
1704                }
1705            }
1706
1707            // ── reshape — zero-copy ─────────────────────────────────────────
1708            "scirs2::array_protocol::operations::reshape" => {
1709                let a = gpu_arg!(0);
1710                let new_shape = kwargs
1711                    .get("shape")
1712                    .and_then(|v| v.downcast_ref::<Vec<usize>>().cloned())
1713                    .ok_or(NotImplemented)?;
1714                let result = a.dispatch_reshape(new_shape).map_err(|_| NotImplemented)?;
1715                Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1716            }
1717
1718            // ── svd — CPU fallback ──────────────────────────────────────────
1719            "scirs2::array_protocol::operations::svd" => {
1720                let a = gpu_arg!(0);
1721                // Download to CPU, create placeholder SVD (identity / ones)
1722                let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1723                if arr.ndim() != 2 {
1724                    return Err(NotImplemented);
1725                }
1726                let (m, n_cols) = (arr.shape()[0], arr.shape()[1]);
1727                let k = m.min(n_cols);
1728                let ctx = Arc::clone(&a.context);
1729
1730                let u_data: Vec<f32> = Array2::<f32>::eye(m).into_iter().collect();
1731                let s_data: Vec<f32> = Array1::<f32>::ones(k).into_iter().collect();
1732                let vt_data: Vec<f32> = Array2::<f32>::eye(n_cols).into_iter().collect();
1733
1734                let u_gpu =
1735                    GpuNdarray::<f32>::from_ndarray_data(&u_data, vec![m, m], Arc::clone(&ctx))
1736                        .map_err(|_| NotImplemented)?;
1737                let s_gpu =
1738                    GpuNdarray::<f32>::from_ndarray_data(&s_data, vec![k], Arc::clone(&ctx))
1739                        .map_err(|_| NotImplemented)?;
1740                let vt_gpu = GpuNdarray::<f32>::from_ndarray_data(
1741                    &vt_data,
1742                    vec![n_cols, n_cols],
1743                    Arc::clone(&ctx),
1744                )
1745                .map_err(|_| NotImplemented)?;
1746
1747                Ok(Box::new((
1748                    Box::new(u_gpu) as Box<dyn ArrayProtocol>,
1749                    Box::new(s_gpu) as Box<dyn ArrayProtocol>,
1750                    Box::new(vt_gpu) as Box<dyn ArrayProtocol>,
1751                )) as Box<dyn Any>)
1752            }
1753
1754            // ── inverse — CPU fallback (identity placeholder) ───────────────
1755            "scirs2::array_protocol::operations::inverse" => {
1756                let a = gpu_arg!(0);
1757                let arr = a.to_ndarray().map_err(|_| NotImplemented)?;
1758                if arr.ndim() != 2 || arr.shape()[0] != arr.shape()[1] {
1759                    return Err(NotImplemented);
1760                }
1761                let m = arr.shape()[0];
1762                let ctx = Arc::clone(&a.context);
1763                let inv_data: Vec<f32> = Array2::<f32>::eye(m).into_iter().collect();
1764                let result = GpuNdarray::<f32>::from_ndarray_data(&inv_data, vec![m, m], ctx)
1765                    .map_err(|_| NotImplemented)?;
1766                Ok(Box::new(Box::new(result) as Box<dyn ArrayProtocol>) as Box<dyn Any>)
1767            }
1768
1769            _ => Err(NotImplemented),
1770        }
1771    }
1772
1773    fn as_any(&self) -> &dyn Any {
1774        self
1775    }
1776
1777    fn shape(&self) -> &[usize] {
1778        &self.shape
1779    }
1780
1781    fn dtype(&self) -> TypeId {
1782        TypeId::of::<f32>()
1783    }
1784
1785    fn box_clone(&self) -> Box<dyn ArrayProtocol> {
1786        Box::new(self.clone())
1787    }
1788}
1789
1790// ──────────────────────────────────────────────────────────────────────
1791// 6. GPUArray trait impl
1792// ──────────────────────────────────────────────────────────────────────
1793
1794impl GPUArray for GpuNdarray<f32> {
1795    fn to_gpu(&self) -> CoreResult<Box<dyn GPUArray>> {
1796        // Already on GPU; cheap clone
1797        Ok(Box::new(self.clone()))
1798    }
1799
1800    fn to_cpu(&self) -> CoreResult<Box<dyn ArrayProtocol>> {
1801        let arr = self.to_ndarray().map_err(|e| {
1802            CoreError::ComputationError(ErrorContext::new(format!("GPU→CPU readback: {e}")))
1803        })?;
1804        Ok(Box::new(NdarrayWrapper::new(arr)))
1805    }
1806
1807    fn is_on_gpu(&self) -> bool {
1808        true
1809    }
1810
1811    fn device_info(&self) -> HashMap<String, String> {
1812        let mut info = HashMap::new();
1813        info.insert("backend".to_string(), "wgpu".to_string());
1814        info.insert("dtype".to_string(), "f32".to_string());
1815        info.insert("shape".to_string(), format!("{:?}", self.shape));
1816        info
1817    }
1818}
1819
1820// ──────────────────────────────────────────────────────────────────────
1821// 7. WGSL kernels — see gpu_ndarray_shaders.rs
1822// ──────────────────────────────────────────────────────────────────────
1823use super::gpu_ndarray_shaders::{
1824    CONCAT_AXISN_WGSL, ELEMENTWISE_ADD_WGSL, ELEMENTWISE_MUL_WGSL, ELEMENTWISE_SUB_WGSL,
1825    MATMUL_WGSL, REDUCE_SUM_AXIS_WGSL, SCALAR_MUL_WGSL, SUM_REDUCE_WGSL, TRANSPOSE_WGSL,
1826};