Skip to main content

rustorch_core/backend/
wgpu.rs

1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3use wgpu;
4use wgpu::util::DeviceExt;
5
/// Size-bucketed pool of reusable GPU buffers.
///
/// Buffers are keyed by their exact byte size; `return_buffer` parks a
/// buffer under its size and `get_buffer` reuses one of the same size
/// before allocating fresh. The pool grows without bound until `clear`.
pub struct MemoryPool {
    // Map from buffer byte size -> stack of idle buffers of that size.
    pool: Mutex<HashMap<u64, Vec<wgpu::Buffer>>>,
}
9
10impl MemoryPool {
11    fn new() -> Self {
12        Self {
13            pool: Mutex::new(HashMap::new()),
14        }
15    }
16
17    pub fn get_buffer(&self, device: &wgpu::Device, size: u64) -> wgpu::Buffer {
18        let mut pool = self.pool.lock().unwrap();
19        if let Some(buffers) = pool.get_mut(&size) {
20            if let Some(buf) = buffers.pop() {
21                return buf;
22            }
23        }
24
25        device.create_buffer(&wgpu::BufferDescriptor {
26            label: Some("Pooled Buffer"),
27            size,
28            usage: wgpu::BufferUsages::STORAGE
29                | wgpu::BufferUsages::COPY_SRC
30                | wgpu::BufferUsages::COPY_DST
31                | wgpu::BufferUsages::UNIFORM,
32            mapped_at_creation: false,
33        })
34    }
35
36    pub fn return_buffer(&self, buffer: wgpu::Buffer, size: u64) {
37        let mut pool = self.pool.lock().unwrap();
38        pool.entry(size).or_insert_with(Vec::new).push(buffer);
39    }
40
41    pub fn clear(&self) {
42        let mut pool = self.pool.lock().unwrap();
43        pool.clear();
44    }
45}
46
// Process-wide buffer pool, lazily initialized on first use.
static MEMORY_POOL: OnceLock<MemoryPool> = OnceLock::new();

/// Returns the global [`MemoryPool`], creating it on first call.
pub fn get_memory_pool() -> &'static MemoryPool {
    MEMORY_POOL.get_or_init(MemoryPool::new)
}

/// Releases every buffer held by the global pool.
pub fn clear_memory_pool() {
    get_memory_pool().clear();
}
56
/// Global WGPU state shared by the whole backend: the device and queue,
/// every precompiled compute pipeline with its bind group layout, and the
/// command encoder used to batch work between queue submissions.
pub struct WgpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    // Matrix multiply (matmul.wgsl).
    pub matmul_pipeline: wgpu::ComputePipeline,
    pub matmul_bind_group_layout: wgpu::BindGroupLayout,
    // Elementwise ops (elementwise.wgsl); the vec4 fast-path variant
    // (elementwise_vec4.wgsl) shares the same bind group layout.
    pub elementwise_pipeline: wgpu::ComputePipeline,
    pub elementwise_bind_group_layout: wgpu::BindGroupLayout,
    pub elementwise_vec4_pipeline: wgpu::ComputePipeline,
    // Adam optimizer step (adam.wgsl).
    pub adam_pipeline: wgpu::ComputePipeline,
    pub adam_bind_group_layout: wgpu::BindGroupLayout,
    // Sum reductions (reduce.wgsl): dim-0, general-axis, and whole-tensor
    // entry points compiled from one shader module over one shared layout.
    pub reduce_pipeline: wgpu::ComputePipeline,
    pub reduce_general_pipeline: wgpu::ComputePipeline,
    pub reduce_all_pipeline: wgpu::ComputePipeline,
    pub reduce_bind_group_layout: wgpu::BindGroupLayout,
    // Strided-to-contiguous copy (contiguous.wgsl).
    pub contiguous_pipeline: wgpu::ComputePipeline,
    pub contiguous_bind_group_layout: wgpu::BindGroupLayout,
    // Open encoder that accumulates recorded commands until `flush_queue`
    // submits them; `None` when nothing is pending.
    pub current_encoder: Mutex<Option<wgpu::CommandEncoder>>,
}
75
// Lazily-initialized global context; all GPU work in this backend goes
// through it.
static CONTEXT: OnceLock<WgpuContext> = OnceLock::new();

/// Returns the process-wide [`WgpuContext`], initializing it (adapter,
/// device, all shader modules and pipelines) on first call.
///
/// NOTE(review): despite the `Option` return type this never returns
/// `None` — adapter/device acquisition failures panic via `expect` inside
/// the initializer, so callers cannot actually observe "WGPU unavailable".
pub fn get_context() -> Option<&'static WgpuContext> {
    Some(CONTEXT.get_or_init(|| {
        // Prefer the highest-performance adapter; block on the async
        // requests since initialization is synchronous.
        let instance = wgpu::Instance::default();
        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        }))
        .expect("No suitable WGPU adapter found");

        let (device, queue) = pollster::block_on(adapter.request_device(
            &wgpu::DeviceDescriptor {
                label: Some("RusTorch Device"),
                required_features: wgpu::Features::empty(),
                // Conservative limits for broad hardware compatibility.
                required_limits: wgpu::Limits::downlevel_defaults(),
            },
            None,
        ))
        .expect("Failed to create device");

        // --- MatMul Pipeline ---
        let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("MatMul Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("matmul.wgsl").into()),
        });

        // Bindings: 0 = lhs (read-only), 1 = rhs (read-only), 2 = output
        // (read-write), 3 = uniform params, 4 = extra read-only input
        // (presumably the fused bias — confirm against matmul.wgsl).
        let matmul_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("MatMul Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 4,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let matmul_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("MatMul Pipeline"),
            layout: Some(
                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("MatMul Pipeline Layout"),
                    bind_group_layouts: &[&matmul_bind_group_layout],
                    push_constant_ranges: &[],
                }),
            ),
            module: &matmul_shader,
            // Must name an @compute entry point in matmul.wgsl.
            entry_point: "main",
        });

        // --- Elementwise Pipeline ---
        let elem_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Elementwise Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("elementwise.wgsl").into()),
        });

        // Bindings: 0 = input1, 1 = input2 (dummy for unary ops),
        // 2 = output, 3 = uniform params.
        let elem_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Elementwise Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let elementwise_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Elementwise Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Elementwise Pipeline Layout"),
                        bind_group_layouts: &[&elem_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &elem_shader,
                entry_point: "main",
            });

        // --- Elementwise Vec4 Pipeline ---
        let elem_vec4_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Elementwise Vec4 Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("elementwise_vec4.wgsl").into()),
        });

        let elementwise_vec4_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Elementwise Vec4 Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Elementwise Vec4 Pipeline Layout"),
                        bind_group_layouts: &[&elem_bind_group_layout], // Reusing same layout
                        push_constant_ranges: &[],
                    }),
                ),
                module: &elem_vec4_shader,
                entry_point: "main",
            });

        // --- Adam Pipeline ---
        let adam_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Adam Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("adam.wgsl").into()),
        });

        // Bindings: 0/2/3 are read-write state buffers, 1 is read-only,
        // 4 is the uniform hyper-parameters (roles defined in adam.wgsl).
        let adam_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Adam Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 4,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let adam_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Adam Pipeline"),
            layout: Some(
                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("Adam Pipeline Layout"),
                    bind_group_layouts: &[&adam_bind_group_layout],
                    push_constant_ranges: &[],
                }),
            ),
            module: &adam_shader,
            entry_point: "main",
        });

        // --- Reduce Pipeline ---
        // One shader module with three entry points (dim-0, general-axis,
        // whole-tensor sum); all share one bind group layout.
        let reduce_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Reduce Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("reduce.wgsl").into()),
        });

        let reduce_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Reduce Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let reduce_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Reduce Pipeline"),
            layout: Some(
                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("Reduce Pipeline Layout"),
                    bind_group_layouts: &[&reduce_bind_group_layout],
                    push_constant_ranges: &[],
                }),
            ),
            module: &reduce_shader,
            entry_point: "reduce_sum_dim0",
        });

        // NOTE(review): the two pipeline layouts below reuse the label
        // "Reduce Pipeline Layout"; harmless, but distinct labels would
        // make GPU debugging traces clearer.
        let reduce_general_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Reduce General Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Reduce Pipeline Layout"),
                        bind_group_layouts: &[&reduce_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &reduce_shader,
                entry_point: "reduce_sum_general",
            });

        let reduce_all_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Reduce All Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Reduce Pipeline Layout"),
                        bind_group_layouts: &[&reduce_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &reduce_shader,
                entry_point: "reduce_sum_all",
            });

        // --- Contiguous Pipeline ---
        let contiguous_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Contiguous Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("contiguous.wgsl").into()),
        });

        let contiguous_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Contiguous Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let contiguous_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Contiguous Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Contiguous Pipeline Layout"),
                        bind_group_layouts: &[&contiguous_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &contiguous_shader,
                entry_point: "main",
            });

        WgpuContext {
            device,
            queue,
            matmul_pipeline,
            matmul_bind_group_layout,
            elementwise_pipeline,
            elementwise_bind_group_layout: elem_bind_group_layout,
            elementwise_vec4_pipeline,
            adam_pipeline,
            adam_bind_group_layout,
            reduce_pipeline,
            reduce_general_pipeline,
            reduce_all_pipeline,
            reduce_bind_group_layout,
            contiguous_pipeline,
            contiguous_bind_group_layout,
            current_encoder: Mutex::new(None),
        }
    }))
}
508
509pub fn flush_queue() {
510    let ctx = get_context().expect("WGPU not initialized");
511    let mut lock = ctx.current_encoder.lock().unwrap();
512    if let Some(encoder) = lock.take() {
513        ctx.queue.submit(Some(encoder.finish()));
514    }
515}
516
517pub fn record_commands<F>(f: F)
518where
519    F: FnOnce(&mut wgpu::CommandEncoder),
520{
521    let ctx = get_context().expect("WGPU not initialized");
522    let mut lock = ctx.current_encoder.lock().unwrap();
523
524    if lock.is_none() {
525        *lock = Some(
526            ctx.device
527                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
528                    label: Some("Buffered Encoder"),
529                }),
530        );
531    }
532
533    let encoder = lock.as_mut().unwrap();
534    f(encoder);
535}
536
/// Opcode for the elementwise kernels. The discriminant is written as a
/// `u32` into the kernel's uniform parameters, so these values form an ABI
/// with elementwise.wgsl / elementwise_vec4.wgsl and must stay in sync
/// with the shaders' op switch. Gaps between bands are intentional.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementwiseOp {
    // 0-9: binary arithmetic and forward activations.
    Add = 0,
    Sub = 1,
    Mul = 2,
    Div = 3,
    ReLU = 4,
    Sigmoid = 5,
    Tanh = 6,
    // 10-19: activation backward passes.
    ReLUBackward = 10,
    SigmoidBackward = 11,
    TanhBackward = 12,
    // 20+: optimizer and shape utilities.
    SGDStep = 20,
    ExpandRepeat = 30,
}
552
553pub use crate::backend::Activation;
554
555pub fn elementwise_wgpu_buffer(
556    input1: &wgpu::Buffer,
557    input1_shape: &[usize],
558    input1_strides: &[usize],
559    input2: Option<(&wgpu::Buffer, &[usize], &[usize])>,
560    output_shape: &[usize],
561    op: ElementwiseOp,
562    alpha: Option<f32>,
563) -> wgpu::Buffer {
564    let ctx = get_context().expect("WGPU not initialized");
565    let device = &ctx.device;
566    // queue used for buffer creation (internal)
567
568    let output_size: usize = output_shape.iter().product();
569    let output_byte_size = output_size * std::mem::size_of::<f32>();
570    let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
571
572    fn is_contiguous_or_scalar(shape: &[usize], strides: &[usize]) -> bool {
573        if shape.iter().product::<usize>() == 1 {
574            return true;
575        }
576        let mut st = 1;
577        for i in (0..shape.len()).rev() {
578            if strides[i] != st {
579                return false;
580            }
581            st *= shape[i];
582        }
583        return true;
584    }
585
586    let can_use_vec4 = output_size % 4 == 0
587        && is_contiguous_or_scalar(input1_shape, input1_strides)
588        && input2.map_or(true, |(_, s2, st2)| is_contiguous_or_scalar(s2, st2));
589
590    if can_use_vec4 {
591        #[repr(C)]
592        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
593        struct ParamsVec4 {
594            numel_vec4: u32,
595            op: u32,
596            alpha: f32,
597            stride_mode_1: u32,
598            stride_mode_2: u32,
599        }
600
601        let stride_mode_1 = if input1_shape.iter().product::<usize>() == 1 {
602            1
603        } else {
604            0
605        };
606        let stride_mode_2 = if let Some((_, s2, _)) = input2 {
607            if s2.iter().product::<usize>() == 1 {
608                1
609            } else {
610                0
611            }
612        } else {
613            0
614        };
615
616        let params = ParamsVec4 {
617            numel_vec4: (output_size / 4) as u32,
618            op: op as u32,
619            alpha: alpha.unwrap_or(0.0),
620            stride_mode_1,
621            stride_mode_2,
622        };
623
624        let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
625            label: Some("Elementwise Vec4 Params"),
626            contents: bytemuck::bytes_of(&params),
627            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
628        });
629
630        // Input 2 (optional)
631        let dummy = device.create_buffer(&wgpu::BufferDescriptor {
632            label: Some("Dummy"),
633            size: 16,
634            usage: wgpu::BufferUsages::STORAGE,
635            mapped_at_creation: false,
636        });
637
638        let input2_buf = input2.map(|(b, _, _)| b).unwrap_or(&dummy);
639
640        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
641            label: Some("Elementwise Vec4 Bind Group"),
642            layout: &ctx.elementwise_bind_group_layout, // Compatible layout
643            entries: &[
644                wgpu::BindGroupEntry {
645                    binding: 0,
646                    resource: input1.as_entire_binding(),
647                },
648                wgpu::BindGroupEntry {
649                    binding: 1,
650                    resource: input2_buf.as_entire_binding(),
651                },
652                wgpu::BindGroupEntry {
653                    binding: 2,
654                    resource: output_buffer.as_entire_binding(),
655                },
656                wgpu::BindGroupEntry {
657                    binding: 3,
658                    resource: params_buffer.as_entire_binding(),
659                },
660            ],
661        });
662
663        record_commands(|encoder| {
664            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
665                label: Some("Elementwise Vec4 Pass"),
666                timestamp_writes: None,
667            });
668            cpass.set_pipeline(&ctx.elementwise_vec4_pipeline);
669            cpass.set_bind_group(0, &bind_group, &[]);
670            let workgroups = (params.numel_vec4 + 255) / 256;
671            cpass.dispatch_workgroups(workgroups, 1, 1);
672        });
673    } else {
674        #[repr(C)]
675        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
676        struct Params {
677            numel: u32,
678            op: u32,
679            alpha: f32,
680            ndim: u32,
681            shape: [u32; 4],
682            strides_out: [u32; 4],
683            strides_1: [u32; 4],
684            strides_2: [u32; 4],
685        }
686
687        let mut shape_arr = [1u32; 4];
688        let stride_out_arr = [0u32; 4]; // Unused in shader logic but present in struct
689        let mut stride_1_arr = [0u32; 4];
690        let mut stride_2_arr = [0u32; 4];
691
692        // Populate arrays (broadcasting logic)
693        // Note: Shader iterates over output_shape (target shape)
694        // We need to map target_index -> target_coord -> input_coord -> input_index
695        // Using strides directly allows: index -> dot(coord, stride)
696        // But shader does: index -> coord -> dot(coord, stride)
697
698        let ndim = output_shape.len();
699        for i in 0..ndim {
700            shape_arr[4 - ndim + i] = output_shape[i] as u32;
701        }
702
703        fn fill_strides(
704            shape: &[usize],
705            strides: &[usize],
706            target_shape: &[usize],
707            out_arr: &mut [u32; 4],
708        ) {
709            let ndim = target_shape.len();
710            let s_ndim = shape.len();
711            let offset = ndim - s_ndim;
712
713            for i in 0..ndim {
714                if i >= offset {
715                    let s_idx = i - offset;
716                    if shape[s_idx] == 1 && target_shape[i] > 1 {
717                        out_arr[4 - ndim + i] = 0; // Broadcast
718                    } else {
719                        out_arr[4 - ndim + i] = strides[s_idx] as u32;
720                    }
721                } else {
722                    out_arr[4 - ndim + i] = 0; // Broadcast new dim
723                }
724            }
725        }
726
727        fill_strides(
728            input1_shape,
729            input1_strides,
730            output_shape,
731            &mut stride_1_arr,
732        );
733        if let Some((_, s2, st2)) = input2 {
734            fill_strides(s2, st2, output_shape, &mut stride_2_arr);
735        }
736
737        let p = Params {
738            numel: output_size as u32,
739            op: op as u32,
740            alpha: alpha.unwrap_or(0.0),
741            ndim: ndim as u32,
742            shape: shape_arr,
743            strides_out: stride_out_arr,
744            strides_1: stride_1_arr,
745            strides_2: stride_2_arr,
746        };
747
748        let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
749            label: Some("Elementwise Params"),
750            contents: bytemuck::bytes_of(&p),
751            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
752        });
753
754        let dummy = device.create_buffer(&wgpu::BufferDescriptor {
755            label: Some("Dummy"),
756            size: 16,
757            usage: wgpu::BufferUsages::STORAGE,
758            mapped_at_creation: false,
759        });
760        let input2_buf = input2.map(|(b, _, _)| b).unwrap_or(&dummy);
761
762        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
763            label: Some("Elementwise Bind Group"),
764            layout: &ctx.elementwise_bind_group_layout,
765            entries: &[
766                wgpu::BindGroupEntry {
767                    binding: 0,
768                    resource: input1.as_entire_binding(),
769                },
770                wgpu::BindGroupEntry {
771                    binding: 1,
772                    resource: input2_buf.as_entire_binding(),
773                },
774                wgpu::BindGroupEntry {
775                    binding: 2,
776                    resource: output_buffer.as_entire_binding(),
777                },
778                wgpu::BindGroupEntry {
779                    binding: 3,
780                    resource: params_buffer.as_entire_binding(),
781                },
782            ],
783        });
784
785        record_commands(|encoder| {
786            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
787                label: Some("Elementwise Pass"),
788                timestamp_writes: None,
789            });
790            cpass.set_pipeline(&ctx.elementwise_pipeline);
791            cpass.set_bind_group(0, &bind_group, &[]);
792            let workgroups = (output_size as u32 + 255) / 256;
793            cpass.dispatch_workgroups(workgroups, 1, 1);
794        });
795    }
796
797    output_buffer
798}
799
800pub fn matmul_wgpu_buffer(
801    lhs: &wgpu::Buffer,
802    lhs_shape: &[usize],
803    rhs: &wgpu::Buffer,
804    rhs_shape: &[usize],
805    activation: Activation,
806) -> wgpu::Buffer {
807    matmul_fused_wgpu_buffer(lhs, lhs_shape, rhs, rhs_shape, None, activation)
808}
809
810pub fn matmul_fused_wgpu_buffer(
811    lhs: &wgpu::Buffer,
812    lhs_shape: &[usize],
813    rhs: &wgpu::Buffer,
814    rhs_shape: &[usize],
815    bias: Option<(&wgpu::Buffer, &[usize])>,
816    activation: Activation,
817) -> wgpu::Buffer {
818    let ctx = get_context().expect("WGPU not initialized");
819    let device = &ctx.device;
820
821    let m = lhs_shape[0] as u32;
822    let k = lhs_shape[1] as u32;
823    let n = rhs_shape[1] as u32;
824
825    let output_size = (m * n) as usize;
826    let output_byte_size = output_size * std::mem::size_of::<f32>();
827    let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
828
829    #[repr(C)]
830    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
831    struct MatmulParams {
832        m: u32,
833        k: u32,
834        n: u32,
835        activation_packed: u32, // activation | (has_bias << 8)
836    }
837
838    let act_val = activation as u32;
839    let has_bias_val = if bias.is_some() { 1u32 } else { 0u32 };
840
841    let params = MatmulParams {
842        m,
843        k,
844        n,
845        activation_packed: act_val | (has_bias_val << 8),
846    };
847
848    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
849        label: Some("MatMul Params"),
850        contents: bytemuck::bytes_of(&params),
851        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
852    });
853
854    // Handle Bias (Optional)
855    let dummy = device.create_buffer(&wgpu::BufferDescriptor {
856        label: Some("Dummy Bias"),
857        size: 16,
858        usage: wgpu::BufferUsages::STORAGE,
859        mapped_at_creation: false,
860    });
861
862    let bias_buf = bias.map(|(b, _)| b).unwrap_or(&dummy);
863
864    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
865        label: Some("MatMul Bind Group"),
866        layout: &ctx.matmul_bind_group_layout,
867        entries: &[
868            wgpu::BindGroupEntry {
869                binding: 0,
870                resource: lhs.as_entire_binding(),
871            },
872            wgpu::BindGroupEntry {
873                binding: 1,
874                resource: rhs.as_entire_binding(),
875            },
876            wgpu::BindGroupEntry {
877                binding: 2,
878                resource: output_buffer.as_entire_binding(),
879            },
880            wgpu::BindGroupEntry {
881                binding: 3,
882                resource: params_buffer.as_entire_binding(),
883            },
884            wgpu::BindGroupEntry {
885                binding: 4,
886                resource: bias_buf.as_entire_binding(),
887            },
888        ],
889    });
890
891    record_commands(|encoder| {
892        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
893            label: Some("MatMul Pass"),
894            timestamp_writes: None,
895        });
896        cpass.set_pipeline(&ctx.matmul_pipeline);
897        cpass.set_bind_group(0, &bind_group, &[]);
898
899        // Workgroup size (16, 16)
900        let wg_x = (n + 15) / 16;
901        let wg_y = (m + 15) / 16;
902        cpass.dispatch_workgroups(wg_x, wg_y, 1);
903    });
904
905    output_buffer
906}
907
908pub fn adam_step_wgpu(
909    param: &wgpu::Buffer,
910    grad: &wgpu::Buffer,
911    m: &wgpu::Buffer,
912    v: &wgpu::Buffer,
913    numel: usize,
914    lr: f32,
915    beta1: f32,
916    beta2: f32,
917    epsilon: f32,
918    step: u32,
919) {
920    let ctx = get_context().expect("WGPU not initialized");
921    let device = &ctx.device;
922
923    #[repr(C)]
924    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
925    struct AdamParams {
926        lr: f32,
927        beta1: f32,
928        beta2: f32,
929        epsilon: f32,
930        step: u32,
931        numel: u32,
932        _pad1: u32,
933        _pad2: u32,
934    }
935
936    let params = AdamParams {
937        lr,
938        beta1,
939        beta2,
940        epsilon,
941        step,
942        numel: numel as u32,
943        _pad1: 0,
944        _pad2: 0,
945    };
946
947    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
948        label: Some("Adam Params"),
949        contents: bytemuck::bytes_of(&params),
950        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
951    });
952
953    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
954        label: Some("Adam Bind Group"),
955        layout: &ctx.adam_bind_group_layout,
956        entries: &[
957            wgpu::BindGroupEntry {
958                binding: 0,
959                resource: param.as_entire_binding(),
960            },
961            wgpu::BindGroupEntry {
962                binding: 1,
963                resource: grad.as_entire_binding(),
964            },
965            wgpu::BindGroupEntry {
966                binding: 2,
967                resource: m.as_entire_binding(),
968            },
969            wgpu::BindGroupEntry {
970                binding: 3,
971                resource: v.as_entire_binding(),
972            },
973            wgpu::BindGroupEntry {
974                binding: 4,
975                resource: params_buffer.as_entire_binding(),
976            },
977        ],
978    });
979
980    record_commands(|encoder| {
981        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
982            label: Some("Adam Pass"),
983            timestamp_writes: None,
984        });
985        cpass.set_pipeline(&ctx.adam_pipeline);
986        cpass.set_bind_group(0, &bind_group, &[]);
987        let workgroups = (numel as u32 + 255) / 256;
988        cpass.dispatch_workgroups(workgroups, 1, 1);
989    });
990}
991
992pub fn matmul_wgpu(lhs: &[f32], lhs_shape: &[usize], rhs: &[f32], rhs_shape: &[usize]) -> Vec<f32> {
993    let ctx = get_context().expect("WGPU not initialized");
994    let device = &ctx.device;
995
996    let lhs_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
997        label: Some("LHS"),
998        contents: bytemuck::cast_slice(lhs),
999        usage: wgpu::BufferUsages::STORAGE,
1000    });
1001    let rhs_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1002        label: Some("RHS"),
1003        contents: bytemuck::cast_slice(rhs),
1004        usage: wgpu::BufferUsages::STORAGE,
1005    });
1006
1007    let out_buf = matmul_wgpu_buffer(&lhs_buf, lhs_shape, &rhs_buf, rhs_shape, Activation::None);
1008
1009    // Create Staging Buffer
1010    let size = (lhs_shape[0] * rhs_shape[1] * 4) as u64;
1011    let staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
1012        label: Some("Staging Buffer"),
1013        size,
1014        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1015        mapped_at_creation: false,
1016    });
1017
1018    record_commands(|encoder| {
1019        encoder.copy_buffer_to_buffer(&out_buf, 0, &staging_buf, 0, size);
1020    });
1021
1022    flush_queue();
1023
1024    let slice = staging_buf.slice(..);
1025    let (tx, rx) = std::sync::mpsc::channel();
1026    slice.map_async(wgpu::MapMode::Read, move |v| tx.send(v).unwrap());
1027    device.poll(wgpu::Maintain::Wait);
1028    rx.recv().unwrap().unwrap();
1029
1030    let data = slice.get_mapped_range();
1031    let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
1032    drop(data);
1033    staging_buf.unmap();
1034
1035    get_memory_pool().return_buffer(out_buf, size);
1036
1037    result
1038}
1039
1040pub fn reduce_sum_dim0_wgpu(input: &wgpu::Buffer, input_shape: &[usize]) -> wgpu::Buffer {
1041    let ctx = get_context().expect("WGPU not initialized");
1042    let device = &ctx.device;
1043
1044    let batch_size = input_shape[0] as u32;
1045    let input_size = input_shape[1] as u32;
1046    let stride_batch = input_size;
1047    let total_input_size: usize = input_shape.iter().product();
1048
1049    let output_size = input_size as usize;
1050    let output_byte_size = output_size * std::mem::size_of::<f32>();
1051    let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1052
1053    #[repr(C)]
1054    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1055    struct ReduceParams {
1056        input_size: u32,
1057        reduce_dim_size: u32,
1058        reduce_dim_stride: u32,
1059        outer_size: u32,
1060    }
1061
1062    let params = ReduceParams {
1063        input_size: total_input_size as u32,
1064        reduce_dim_size: batch_size,
1065        reduce_dim_stride: stride_batch,
1066        outer_size: input_size,
1067    };
1068
1069    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1070        label: Some("Reduce Params"),
1071        contents: bytemuck::bytes_of(&params),
1072        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1073    });
1074
1075    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1076        label: Some("Reduce Bind Group"),
1077        layout: &ctx.reduce_bind_group_layout,
1078        entries: &[
1079            wgpu::BindGroupEntry {
1080                binding: 0,
1081                resource: input.as_entire_binding(),
1082            },
1083            wgpu::BindGroupEntry {
1084                binding: 1,
1085                resource: output_buffer.as_entire_binding(),
1086            },
1087            wgpu::BindGroupEntry {
1088                binding: 2,
1089                resource: params_buffer.as_entire_binding(),
1090            },
1091        ],
1092    });
1093
1094    record_commands(|encoder| {
1095        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1096            label: Some("Reduce Pass"),
1097            timestamp_writes: None,
1098        });
1099        cpass.set_pipeline(&ctx.reduce_pipeline);
1100        cpass.set_bind_group(0, &bind_group, &[]);
1101        let workgroups = (input_size + 255) / 256;
1102        cpass.dispatch_workgroups(workgroups, 1, 1);
1103    });
1104
1105    output_buffer
1106}
1107
1108pub fn reduce_sum_all_wgpu(input: &wgpu::Buffer, input_size: usize) -> wgpu::Buffer {
1109    let ctx = get_context().expect("WGPU not initialized");
1110    let device = &ctx.device;
1111
1112    let output_byte_size = std::mem::size_of::<f32>();
1113    let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1114
1115    #[repr(C)]
1116    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1117    struct ReduceParams {
1118        input_size: u32,
1119        reduce_dim_size: u32,
1120        reduce_dim_stride: u32,
1121        outer_size: u32,
1122    }
1123
1124    let params = ReduceParams {
1125        input_size: input_size as u32,
1126        reduce_dim_size: 0,
1127        reduce_dim_stride: 0,
1128        outer_size: 0,
1129    };
1130
1131    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1132        label: Some("Reduce Params"),
1133        contents: bytemuck::bytes_of(&params),
1134        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1135    });
1136
1137    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1138        label: Some("Reduce Bind Group"),
1139        layout: &ctx.reduce_bind_group_layout,
1140        entries: &[
1141            wgpu::BindGroupEntry {
1142                binding: 0,
1143                resource: input.as_entire_binding(),
1144            },
1145            wgpu::BindGroupEntry {
1146                binding: 1,
1147                resource: output_buffer.as_entire_binding(),
1148            },
1149            wgpu::BindGroupEntry {
1150                binding: 2,
1151                resource: params_buffer.as_entire_binding(),
1152            },
1153        ],
1154    });
1155
1156    record_commands(|encoder| {
1157        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1158            label: Some("Reduce All Pass"),
1159            timestamp_writes: None,
1160        });
1161        cpass.set_pipeline(&ctx.reduce_all_pipeline);
1162        cpass.set_bind_group(0, &bind_group, &[]);
1163        cpass.dispatch_workgroups(1, 1, 1);
1164    });
1165
1166    output_buffer
1167}
1168
1169pub fn reduce_sum_dim_wgpu(
1170    input: &wgpu::Buffer,
1171    input_shape: &[usize],
1172    dim: usize,
1173) -> wgpu::Buffer {
1174    let ctx = get_context().expect("WGPU not initialized");
1175    let device = &ctx.device;
1176
1177    let ndim = input_shape.len();
1178    if dim >= ndim {
1179        panic!("Invalid reduction dimension");
1180    }
1181
1182    if ndim == 1 || dim == 0 && ndim == 2 {
1183        return reduce_sum_dim0_wgpu(input, input_shape);
1184    }
1185
1186    let total_input_size: usize = input_shape.iter().product();
1187    let reduce_dim_size = input_shape[dim];
1188
1189    let outer_size: usize = if dim < ndim - 1 {
1190        input_shape[dim + 1..].iter().product()
1191    } else {
1192        1
1193    };
1194
1195    let output_size = total_input_size / reduce_dim_size;
1196    let output_byte_size = output_size * std::mem::size_of::<f32>();
1197    let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1198
1199    #[repr(C)]
1200    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1201    struct ReduceParams {
1202        input_size: u32,
1203        reduce_dim_size: u32,
1204        reduce_dim_stride: u32,
1205        outer_size: u32,
1206    }
1207
1208    let params = ReduceParams {
1209        input_size: total_input_size as u32,
1210        reduce_dim_size: reduce_dim_size as u32,
1211        reduce_dim_stride: outer_size as u32,
1212        outer_size: outer_size as u32,
1213    };
1214
1215    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1216        label: Some("Reduce Params"),
1217        contents: bytemuck::bytes_of(&params),
1218        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1219    });
1220
1221    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1222        label: Some("Reduce Bind Group"),
1223        layout: &ctx.reduce_bind_group_layout,
1224        entries: &[
1225            wgpu::BindGroupEntry {
1226                binding: 0,
1227                resource: input.as_entire_binding(),
1228            },
1229            wgpu::BindGroupEntry {
1230                binding: 1,
1231                resource: output_buffer.as_entire_binding(),
1232            },
1233            wgpu::BindGroupEntry {
1234                binding: 2,
1235                resource: params_buffer.as_entire_binding(),
1236            },
1237        ],
1238    });
1239
1240    record_commands(|encoder| {
1241        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1242            label: Some("Reduce General Pass"),
1243            timestamp_writes: None,
1244        });
1245        cpass.set_pipeline(&ctx.reduce_general_pipeline);
1246        cpass.set_bind_group(0, &bind_group, &[]);
1247        let workgroups = ((output_size as u32 + 255) / 256).max(1);
1248        cpass.dispatch_workgroups(workgroups, 1, 1);
1249    });
1250
1251    output_buffer
1252}
1253
1254pub fn contiguous_wgpu(input: &wgpu::Buffer, shape: &[usize], strides: &[usize]) -> wgpu::Buffer {
1255    let ctx = get_context().expect("WGPU not initialized");
1256    let device = &ctx.device;
1257
1258    let size: usize = shape.iter().product();
1259    let output_byte_size = size * std::mem::size_of::<f32>();
1260    let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1261
1262    #[repr(C)]
1263    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1264    struct ContiguousParams {
1265        size: u32,
1266        ndim: u32,
1267        pad0: u32,
1268        pad1: u32,
1269    }
1270
1271    let params = ContiguousParams {
1272        size: size as u32,
1273        ndim: shape.len() as u32,
1274        pad0: 0,
1275        pad1: 0,
1276    };
1277
1278    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1279        label: Some("Contiguous Params"),
1280        contents: bytemuck::bytes_of(&params),
1281        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1282    });
1283
1284    #[repr(C)]
1285    #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1286    struct ShapeStride {
1287        shape0: u32,
1288        shape1: u32,
1289        shape2: u32,
1290        shape3: u32,
1291        shape4: u32,
1292        shape5: u32,
1293        shape6: u32,
1294        shape7: u32,
1295        stride0: u32,
1296        stride1: u32,
1297        stride2: u32,
1298        stride3: u32,
1299        stride4: u32,
1300        stride5: u32,
1301        stride6: u32,
1302        stride7: u32,
1303    }
1304
1305    let shape_stride = ShapeStride {
1306        shape0: shape.get(0).copied().unwrap_or(1) as u32,
1307        shape1: shape.get(1).copied().unwrap_or(1) as u32,
1308        shape2: shape.get(2).copied().unwrap_or(1) as u32,
1309        shape3: shape.get(3).copied().unwrap_or(1) as u32,
1310        shape4: shape.get(4).copied().unwrap_or(1) as u32,
1311        shape5: shape.get(5).copied().unwrap_or(1) as u32,
1312        shape6: shape.get(6).copied().unwrap_or(1) as u32,
1313        shape7: shape.get(7).copied().unwrap_or(1) as u32,
1314        stride0: strides.get(0).copied().unwrap_or(0) as u32,
1315        stride1: strides.get(1).copied().unwrap_or(0) as u32,
1316        stride2: strides.get(2).copied().unwrap_or(0) as u32,
1317        stride3: strides.get(3).copied().unwrap_or(0) as u32,
1318        stride4: strides.get(4).copied().unwrap_or(0) as u32,
1319        stride5: strides.get(5).copied().unwrap_or(0) as u32,
1320        stride6: strides.get(6).copied().unwrap_or(0) as u32,
1321        stride7: strides.get(7).copied().unwrap_or(0) as u32,
1322    };
1323
1324    let shape_strides_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1325        label: Some("Shape Strides"),
1326        contents: bytemuck::bytes_of(&shape_stride),
1327        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1328    });
1329
1330    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1331        label: Some("Contiguous Bind Group"),
1332        layout: &ctx.contiguous_bind_group_layout,
1333        entries: &[
1334            wgpu::BindGroupEntry {
1335                binding: 0,
1336                resource: input.as_entire_binding(),
1337            },
1338            wgpu::BindGroupEntry {
1339                binding: 1,
1340                resource: output_buffer.as_entire_binding(),
1341            },
1342            wgpu::BindGroupEntry {
1343                binding: 2,
1344                resource: params_buffer.as_entire_binding(),
1345            },
1346            wgpu::BindGroupEntry {
1347                binding: 3,
1348                resource: shape_strides_buffer.as_entire_binding(),
1349            },
1350        ],
1351    });
1352
1353    record_commands(|encoder| {
1354        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1355            label: Some("Contiguous Pass"),
1356            timestamp_writes: None,
1357        });
1358        cpass.set_pipeline(&ctx.contiguous_pipeline);
1359        cpass.set_bind_group(0, &bind_group, &[]);
1360        let workgroups = ((size as u32 + 255) / 256).max(1);
1361        cpass.dispatch_workgroups(workgroups, 1, 1);
1362    });
1363
1364    output_buffer
1365}