1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3use wgpu;
4use wgpu::util::DeviceExt;
5
/// A simple size-bucketed pool of reusable GPU buffers.
///
/// Buffers are keyed by their exact byte size; `get_buffer` pops a free
/// buffer of the requested size (allocating a new one on miss) and
/// `return_buffer` pushes a buffer back for later reuse.
pub struct MemoryPool {
    // Map from buffer byte size -> stack of free buffers of that size.
    pool: Mutex<HashMap<u64, Vec<wgpu::Buffer>>>,
}
9
10impl MemoryPool {
11 fn new() -> Self {
12 Self {
13 pool: Mutex::new(HashMap::new()),
14 }
15 }
16
17 pub fn get_buffer(&self, device: &wgpu::Device, size: u64) -> wgpu::Buffer {
18 let mut pool = self.pool.lock().unwrap();
19 if let Some(buffers) = pool.get_mut(&size) {
20 if let Some(buf) = buffers.pop() {
21 return buf;
22 }
23 }
24
25 device.create_buffer(&wgpu::BufferDescriptor {
26 label: Some("Pooled Buffer"),
27 size,
28 usage: wgpu::BufferUsages::STORAGE
29 | wgpu::BufferUsages::COPY_SRC
30 | wgpu::BufferUsages::COPY_DST
31 | wgpu::BufferUsages::UNIFORM,
32 mapped_at_creation: false,
33 })
34 }
35
36 pub fn return_buffer(&self, buffer: wgpu::Buffer, size: u64) {
37 let mut pool = self.pool.lock().unwrap();
38 pool.entry(size).or_insert_with(Vec::new).push(buffer);
39 }
40
41 pub fn clear(&self) {
42 let mut pool = self.pool.lock().unwrap();
43 pool.clear();
44 }
45}
46
/// Process-wide buffer pool, lazily created on first use.
static MEMORY_POOL: OnceLock<MemoryPool> = OnceLock::new();

/// Returns the global [`MemoryPool`], initializing it on the first call.
pub fn get_memory_pool() -> &'static MemoryPool {
    MEMORY_POOL.get_or_init(MemoryPool::new)
}
52
53pub fn clear_memory_pool() {
54 get_memory_pool().clear();
55}
56
/// Global WGPU state: device/queue plus every precompiled compute pipeline
/// (and its bind group layout) used by the kernel launchers in this module.
pub struct WgpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    // 2-D matmul with optional fused bias + activation (matmul.wgsl).
    pub matmul_pipeline: wgpu::ComputePipeline,
    pub matmul_bind_group_layout: wgpu::BindGroupLayout,
    // Strided/broadcasting elementwise ops (elementwise.wgsl).
    pub elementwise_pipeline: wgpu::ComputePipeline,
    pub elementwise_bind_group_layout: wgpu::BindGroupLayout,
    // Fast path for contiguous data (elementwise_vec4.wgsl); shares the
    // elementwise bind group layout.
    pub elementwise_vec4_pipeline: wgpu::ComputePipeline,
    // Fused Adam optimizer step (adam.wgsl).
    pub adam_pipeline: wgpu::ComputePipeline,
    pub adam_bind_group_layout: wgpu::BindGroupLayout,
    // Three entry points of reduce.wgsl, all sharing one layout:
    // reduce_sum_dim0, reduce_sum_general, reduce_sum_all respectively.
    pub reduce_pipeline: wgpu::ComputePipeline,
    pub reduce_general_pipeline: wgpu::ComputePipeline,
    pub reduce_all_pipeline: wgpu::ComputePipeline,
    pub reduce_bind_group_layout: wgpu::BindGroupLayout,
    // Strided -> contiguous copy kernel (contiguous.wgsl).
    pub contiguous_pipeline: wgpu::ComputePipeline,
    pub contiguous_bind_group_layout: wgpu::BindGroupLayout,
    // Shared command encoder used to batch GPU work across calls;
    // created lazily by `record_commands` and submitted by `flush_queue`.
    pub current_encoder: Mutex<Option<wgpu::CommandEncoder>>,
}
75
/// Lazily-initialized global GPU context shared by every launcher below.
static CONTEXT: OnceLock<WgpuContext> = OnceLock::new();

/// Returns the global [`WgpuContext`], creating the adapter, device, queue
/// and every compute pipeline on first use.
///
/// Currently this always returns `Some(..)`: initialization failures (no
/// adapter, device request error) panic via `expect` instead of yielding
/// `None`. The `Option` return type is kept for caller compatibility.
pub fn get_context() -> Option<&'static WgpuContext> {
    Some(CONTEXT.get_or_init(|| {
        // --- Adapter / device ---------------------------------------------
        // Blocking waits are acceptable here: this closure runs exactly once.
        let instance = wgpu::Instance::default();
        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None,
            force_fallback_adapter: false,
        }))
        .expect("No suitable WGPU adapter found");

        let (device, queue) = pollster::block_on(adapter.request_device(
            &wgpu::DeviceDescriptor {
                label: Some("RusTorch Device"),
                required_features: wgpu::Features::empty(),
                // Downlevel defaults keep compatibility with weaker backends.
                required_limits: wgpu::Limits::downlevel_defaults(),
            },
            None,
        ))
        .expect("Failed to create device");

        // --- MatMul pipeline ----------------------------------------------
        // Bindings: 0 = lhs (ro), 1 = rhs (ro), 2 = output (rw),
        //           3 = params uniform, 4 = bias (ro).
        let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("MatMul Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("matmul.wgsl").into()),
        });

        let matmul_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("MatMul Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 4,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let matmul_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("MatMul Pipeline"),
            layout: Some(
                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("MatMul Pipeline Layout"),
                    bind_group_layouts: &[&matmul_bind_group_layout],
                    push_constant_ranges: &[],
                }),
            ),
            module: &matmul_shader,
            entry_point: "main",
        });

        // --- Elementwise pipelines ----------------------------------------
        // Bindings: 0 = input1 (ro), 1 = input2 (ro), 2 = output (rw),
        //           3 = params uniform. The vec4 pipeline reuses this layout.
        let elem_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Elementwise Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("elementwise.wgsl").into()),
        });

        let elem_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Elementwise Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let elementwise_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Elementwise Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Elementwise Pipeline Layout"),
                        bind_group_layouts: &[&elem_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &elem_shader,
                entry_point: "main",
            });

        let elem_vec4_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Elementwise Vec4 Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("elementwise_vec4.wgsl").into()),
        });

        let elementwise_vec4_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Elementwise Vec4 Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Elementwise Vec4 Pipeline Layout"),
                        bind_group_layouts: &[&elem_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &elem_vec4_shader,
                entry_point: "main",
            });

        // --- Adam pipeline ------------------------------------------------
        // Bindings: 0 = param (rw), 1 = grad (ro), 2 = m (rw), 3 = v (rw),
        //           4 = hyperparameter uniform.
        let adam_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Adam Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("adam.wgsl").into()),
        });

        let adam_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Adam Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 4,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let adam_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Adam Pipeline"),
            layout: Some(
                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("Adam Pipeline Layout"),
                    bind_group_layouts: &[&adam_bind_group_layout],
                    push_constant_ranges: &[],
                }),
            ),
            module: &adam_shader,
            entry_point: "main",
        });

        // --- Reduce pipelines ---------------------------------------------
        // One shader, three entry points; bindings: 0 = input (ro),
        // 1 = output (rw), 2 = params uniform.
        let reduce_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Reduce Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("reduce.wgsl").into()),
        });

        let reduce_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Reduce Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let reduce_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Reduce Pipeline"),
            layout: Some(
                &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                    label: Some("Reduce Pipeline Layout"),
                    bind_group_layouts: &[&reduce_bind_group_layout],
                    push_constant_ranges: &[],
                }),
            ),
            module: &reduce_shader,
            entry_point: "reduce_sum_dim0",
        });

        let reduce_general_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Reduce General Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Reduce Pipeline Layout"),
                        bind_group_layouts: &[&reduce_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &reduce_shader,
                entry_point: "reduce_sum_general",
            });

        let reduce_all_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Reduce All Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Reduce Pipeline Layout"),
                        bind_group_layouts: &[&reduce_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &reduce_shader,
                entry_point: "reduce_sum_all",
            });

        // --- Contiguous (strided gather) pipeline -------------------------
        // Bindings: 0 = input (ro), 1 = output (rw), 2 = params uniform,
        //           3 = extra read-only buffer (presumably strides/shape
        //           table — confirm against contiguous.wgsl).
        let contiguous_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Contiguous Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("contiguous.wgsl").into()),
        });

        let contiguous_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Contiguous Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let contiguous_pipeline =
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Contiguous Pipeline"),
                layout: Some(
                    &device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                        label: Some("Contiguous Pipeline Layout"),
                        bind_group_layouts: &[&contiguous_bind_group_layout],
                        push_constant_ranges: &[],
                    }),
                ),
                module: &contiguous_shader,
                entry_point: "main",
            });

        WgpuContext {
            device,
            queue,
            matmul_pipeline,
            matmul_bind_group_layout,
            elementwise_pipeline,
            elementwise_bind_group_layout: elem_bind_group_layout,
            elementwise_vec4_pipeline,
            adam_pipeline,
            adam_bind_group_layout,
            reduce_pipeline,
            reduce_general_pipeline,
            reduce_all_pipeline,
            reduce_bind_group_layout,
            contiguous_pipeline,
            contiguous_bind_group_layout,
            current_encoder: Mutex::new(None),
        }
    }))
}
508
509pub fn flush_queue() {
510 let ctx = get_context().expect("WGPU not initialized");
511 let mut lock = ctx.current_encoder.lock().unwrap();
512 if let Some(encoder) = lock.take() {
513 ctx.queue.submit(Some(encoder.finish()));
514 }
515}
516
517pub fn record_commands<F>(f: F)
518where
519 F: FnOnce(&mut wgpu::CommandEncoder),
520{
521 let ctx = get_context().expect("WGPU not initialized");
522 let mut lock = ctx.current_encoder.lock().unwrap();
523
524 if lock.is_none() {
525 *lock = Some(
526 ctx.device
527 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
528 label: Some("Buffered Encoder"),
529 }),
530 );
531 }
532
533 let encoder = lock.as_mut().unwrap();
534 f(encoder);
535}
536
/// Operation selector passed to the elementwise shaders as a `u32`.
///
/// Discriminant values must stay in sync with the op dispatch in
/// `elementwise.wgsl` / `elementwise_vec4.wgsl`; the gaps between groups
/// (7-9, 13-19, 21-29) are intentional reserved ranges.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ElementwiseOp {
    // Binary arithmetic: input1 (op) input2.
    Add = 0,
    Sub = 1,
    Mul = 2,
    Div = 3,
    // Unary activations applied to input1.
    ReLU = 4,
    Sigmoid = 5,
    Tanh = 6,
    // Activation gradients (presumably input1 = upstream grad, input2 =
    // forward-pass value — confirm against elementwise.wgsl).
    ReLUBackward = 10,
    SigmoidBackward = 11,
    TanhBackward = 12,
    // Fused SGD parameter update.
    SGDStep = 20,
    // Broadcast/repeat expansion.
    ExpandRepeat = 30,
}
552
553pub use crate::backend::Activation;
554
555pub fn elementwise_wgpu_buffer(
556 input1: &wgpu::Buffer,
557 input1_shape: &[usize],
558 input1_strides: &[usize],
559 input2: Option<(&wgpu::Buffer, &[usize], &[usize])>,
560 output_shape: &[usize],
561 op: ElementwiseOp,
562 alpha: Option<f32>,
563) -> wgpu::Buffer {
564 let ctx = get_context().expect("WGPU not initialized");
565 let device = &ctx.device;
566 let output_size: usize = output_shape.iter().product();
569 let output_byte_size = output_size * std::mem::size_of::<f32>();
570 let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
571
572 fn is_contiguous_or_scalar(shape: &[usize], strides: &[usize]) -> bool {
573 if shape.iter().product::<usize>() == 1 {
574 return true;
575 }
576 let mut st = 1;
577 for i in (0..shape.len()).rev() {
578 if strides[i] != st {
579 return false;
580 }
581 st *= shape[i];
582 }
583 return true;
584 }
585
586 let can_use_vec4 = output_size % 4 == 0
587 && is_contiguous_or_scalar(input1_shape, input1_strides)
588 && input2.map_or(true, |(_, s2, st2)| is_contiguous_or_scalar(s2, st2));
589
590 if can_use_vec4 {
591 #[repr(C)]
592 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
593 struct ParamsVec4 {
594 numel_vec4: u32,
595 op: u32,
596 alpha: f32,
597 stride_mode_1: u32,
598 stride_mode_2: u32,
599 }
600
601 let stride_mode_1 = if input1_shape.iter().product::<usize>() == 1 {
602 1
603 } else {
604 0
605 };
606 let stride_mode_2 = if let Some((_, s2, _)) = input2 {
607 if s2.iter().product::<usize>() == 1 {
608 1
609 } else {
610 0
611 }
612 } else {
613 0
614 };
615
616 let params = ParamsVec4 {
617 numel_vec4: (output_size / 4) as u32,
618 op: op as u32,
619 alpha: alpha.unwrap_or(0.0),
620 stride_mode_1,
621 stride_mode_2,
622 };
623
624 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
625 label: Some("Elementwise Vec4 Params"),
626 contents: bytemuck::bytes_of(¶ms),
627 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
628 });
629
630 let dummy = device.create_buffer(&wgpu::BufferDescriptor {
632 label: Some("Dummy"),
633 size: 16,
634 usage: wgpu::BufferUsages::STORAGE,
635 mapped_at_creation: false,
636 });
637
638 let input2_buf = input2.map(|(b, _, _)| b).unwrap_or(&dummy);
639
640 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
641 label: Some("Elementwise Vec4 Bind Group"),
642 layout: &ctx.elementwise_bind_group_layout, entries: &[
644 wgpu::BindGroupEntry {
645 binding: 0,
646 resource: input1.as_entire_binding(),
647 },
648 wgpu::BindGroupEntry {
649 binding: 1,
650 resource: input2_buf.as_entire_binding(),
651 },
652 wgpu::BindGroupEntry {
653 binding: 2,
654 resource: output_buffer.as_entire_binding(),
655 },
656 wgpu::BindGroupEntry {
657 binding: 3,
658 resource: params_buffer.as_entire_binding(),
659 },
660 ],
661 });
662
663 record_commands(|encoder| {
664 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
665 label: Some("Elementwise Vec4 Pass"),
666 timestamp_writes: None,
667 });
668 cpass.set_pipeline(&ctx.elementwise_vec4_pipeline);
669 cpass.set_bind_group(0, &bind_group, &[]);
670 let workgroups = (params.numel_vec4 + 255) / 256;
671 cpass.dispatch_workgroups(workgroups, 1, 1);
672 });
673 } else {
674 #[repr(C)]
675 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
676 struct Params {
677 numel: u32,
678 op: u32,
679 alpha: f32,
680 ndim: u32,
681 shape: [u32; 4],
682 strides_out: [u32; 4],
683 strides_1: [u32; 4],
684 strides_2: [u32; 4],
685 }
686
687 let mut shape_arr = [1u32; 4];
688 let stride_out_arr = [0u32; 4]; let mut stride_1_arr = [0u32; 4];
690 let mut stride_2_arr = [0u32; 4];
691
692 let ndim = output_shape.len();
699 for i in 0..ndim {
700 shape_arr[4 - ndim + i] = output_shape[i] as u32;
701 }
702
703 fn fill_strides(
704 shape: &[usize],
705 strides: &[usize],
706 target_shape: &[usize],
707 out_arr: &mut [u32; 4],
708 ) {
709 let ndim = target_shape.len();
710 let s_ndim = shape.len();
711 let offset = ndim - s_ndim;
712
713 for i in 0..ndim {
714 if i >= offset {
715 let s_idx = i - offset;
716 if shape[s_idx] == 1 && target_shape[i] > 1 {
717 out_arr[4 - ndim + i] = 0; } else {
719 out_arr[4 - ndim + i] = strides[s_idx] as u32;
720 }
721 } else {
722 out_arr[4 - ndim + i] = 0; }
724 }
725 }
726
727 fill_strides(
728 input1_shape,
729 input1_strides,
730 output_shape,
731 &mut stride_1_arr,
732 );
733 if let Some((_, s2, st2)) = input2 {
734 fill_strides(s2, st2, output_shape, &mut stride_2_arr);
735 }
736
737 let p = Params {
738 numel: output_size as u32,
739 op: op as u32,
740 alpha: alpha.unwrap_or(0.0),
741 ndim: ndim as u32,
742 shape: shape_arr,
743 strides_out: stride_out_arr,
744 strides_1: stride_1_arr,
745 strides_2: stride_2_arr,
746 };
747
748 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
749 label: Some("Elementwise Params"),
750 contents: bytemuck::bytes_of(&p),
751 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
752 });
753
754 let dummy = device.create_buffer(&wgpu::BufferDescriptor {
755 label: Some("Dummy"),
756 size: 16,
757 usage: wgpu::BufferUsages::STORAGE,
758 mapped_at_creation: false,
759 });
760 let input2_buf = input2.map(|(b, _, _)| b).unwrap_or(&dummy);
761
762 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
763 label: Some("Elementwise Bind Group"),
764 layout: &ctx.elementwise_bind_group_layout,
765 entries: &[
766 wgpu::BindGroupEntry {
767 binding: 0,
768 resource: input1.as_entire_binding(),
769 },
770 wgpu::BindGroupEntry {
771 binding: 1,
772 resource: input2_buf.as_entire_binding(),
773 },
774 wgpu::BindGroupEntry {
775 binding: 2,
776 resource: output_buffer.as_entire_binding(),
777 },
778 wgpu::BindGroupEntry {
779 binding: 3,
780 resource: params_buffer.as_entire_binding(),
781 },
782 ],
783 });
784
785 record_commands(|encoder| {
786 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
787 label: Some("Elementwise Pass"),
788 timestamp_writes: None,
789 });
790 cpass.set_pipeline(&ctx.elementwise_pipeline);
791 cpass.set_bind_group(0, &bind_group, &[]);
792 let workgroups = (output_size as u32 + 255) / 256;
793 cpass.dispatch_workgroups(workgroups, 1, 1);
794 });
795 }
796
797 output_buffer
798}
799
800pub fn matmul_wgpu_buffer(
801 lhs: &wgpu::Buffer,
802 lhs_shape: &[usize],
803 rhs: &wgpu::Buffer,
804 rhs_shape: &[usize],
805 activation: Activation,
806) -> wgpu::Buffer {
807 matmul_fused_wgpu_buffer(lhs, lhs_shape, rhs, rhs_shape, None, activation)
808}
809
810pub fn matmul_fused_wgpu_buffer(
811 lhs: &wgpu::Buffer,
812 lhs_shape: &[usize],
813 rhs: &wgpu::Buffer,
814 rhs_shape: &[usize],
815 bias: Option<(&wgpu::Buffer, &[usize])>,
816 activation: Activation,
817) -> wgpu::Buffer {
818 let ctx = get_context().expect("WGPU not initialized");
819 let device = &ctx.device;
820
821 let m = lhs_shape[0] as u32;
822 let k = lhs_shape[1] as u32;
823 let n = rhs_shape[1] as u32;
824
825 let output_size = (m * n) as usize;
826 let output_byte_size = output_size * std::mem::size_of::<f32>();
827 let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
828
829 #[repr(C)]
830 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
831 struct MatmulParams {
832 m: u32,
833 k: u32,
834 n: u32,
835 activation_packed: u32, }
837
838 let act_val = activation as u32;
839 let has_bias_val = if bias.is_some() { 1u32 } else { 0u32 };
840
841 let params = MatmulParams {
842 m,
843 k,
844 n,
845 activation_packed: act_val | (has_bias_val << 8),
846 };
847
848 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
849 label: Some("MatMul Params"),
850 contents: bytemuck::bytes_of(¶ms),
851 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
852 });
853
854 let dummy = device.create_buffer(&wgpu::BufferDescriptor {
856 label: Some("Dummy Bias"),
857 size: 16,
858 usage: wgpu::BufferUsages::STORAGE,
859 mapped_at_creation: false,
860 });
861
862 let bias_buf = bias.map(|(b, _)| b).unwrap_or(&dummy);
863
864 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
865 label: Some("MatMul Bind Group"),
866 layout: &ctx.matmul_bind_group_layout,
867 entries: &[
868 wgpu::BindGroupEntry {
869 binding: 0,
870 resource: lhs.as_entire_binding(),
871 },
872 wgpu::BindGroupEntry {
873 binding: 1,
874 resource: rhs.as_entire_binding(),
875 },
876 wgpu::BindGroupEntry {
877 binding: 2,
878 resource: output_buffer.as_entire_binding(),
879 },
880 wgpu::BindGroupEntry {
881 binding: 3,
882 resource: params_buffer.as_entire_binding(),
883 },
884 wgpu::BindGroupEntry {
885 binding: 4,
886 resource: bias_buf.as_entire_binding(),
887 },
888 ],
889 });
890
891 record_commands(|encoder| {
892 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
893 label: Some("MatMul Pass"),
894 timestamp_writes: None,
895 });
896 cpass.set_pipeline(&ctx.matmul_pipeline);
897 cpass.set_bind_group(0, &bind_group, &[]);
898
899 let wg_x = (n + 15) / 16;
901 let wg_y = (m + 15) / 16;
902 cpass.dispatch_workgroups(wg_x, wg_y, 1);
903 });
904
905 output_buffer
906}
907
908pub fn adam_step_wgpu(
909 param: &wgpu::Buffer,
910 grad: &wgpu::Buffer,
911 m: &wgpu::Buffer,
912 v: &wgpu::Buffer,
913 numel: usize,
914 lr: f32,
915 beta1: f32,
916 beta2: f32,
917 epsilon: f32,
918 step: u32,
919) {
920 let ctx = get_context().expect("WGPU not initialized");
921 let device = &ctx.device;
922
923 #[repr(C)]
924 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
925 struct AdamParams {
926 lr: f32,
927 beta1: f32,
928 beta2: f32,
929 epsilon: f32,
930 step: u32,
931 numel: u32,
932 _pad1: u32,
933 _pad2: u32,
934 }
935
936 let params = AdamParams {
937 lr,
938 beta1,
939 beta2,
940 epsilon,
941 step,
942 numel: numel as u32,
943 _pad1: 0,
944 _pad2: 0,
945 };
946
947 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
948 label: Some("Adam Params"),
949 contents: bytemuck::bytes_of(¶ms),
950 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
951 });
952
953 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
954 label: Some("Adam Bind Group"),
955 layout: &ctx.adam_bind_group_layout,
956 entries: &[
957 wgpu::BindGroupEntry {
958 binding: 0,
959 resource: param.as_entire_binding(),
960 },
961 wgpu::BindGroupEntry {
962 binding: 1,
963 resource: grad.as_entire_binding(),
964 },
965 wgpu::BindGroupEntry {
966 binding: 2,
967 resource: m.as_entire_binding(),
968 },
969 wgpu::BindGroupEntry {
970 binding: 3,
971 resource: v.as_entire_binding(),
972 },
973 wgpu::BindGroupEntry {
974 binding: 4,
975 resource: params_buffer.as_entire_binding(),
976 },
977 ],
978 });
979
980 record_commands(|encoder| {
981 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
982 label: Some("Adam Pass"),
983 timestamp_writes: None,
984 });
985 cpass.set_pipeline(&ctx.adam_pipeline);
986 cpass.set_bind_group(0, &bind_group, &[]);
987 let workgroups = (numel as u32 + 255) / 256;
988 cpass.dispatch_workgroups(workgroups, 1, 1);
989 });
990}
991
/// Host-side convenience wrapper: multiplies `lhs` (`[m, k]`) by `rhs`
/// (`[k, n]`) on the GPU and copies the `[m, n]` result back into a
/// `Vec<f32>`.
///
/// Blocks the calling thread until the GPU finishes; the buffer-based entry
/// points avoid this host round trip.
pub fn matmul_wgpu(lhs: &[f32], lhs_shape: &[usize], rhs: &[f32], rhs_shape: &[usize]) -> Vec<f32> {
    let ctx = get_context().expect("WGPU not initialized");
    let device = &ctx.device;

    // Upload both operands as read-only storage buffers.
    let lhs_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("LHS"),
        contents: bytemuck::cast_slice(lhs),
        usage: wgpu::BufferUsages::STORAGE,
    });
    let rhs_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("RHS"),
        contents: bytemuck::cast_slice(rhs),
        usage: wgpu::BufferUsages::STORAGE,
    });

    let out_buf = matmul_wgpu_buffer(&lhs_buf, lhs_shape, &rhs_buf, rhs_shape, Activation::None);

    // m * n f32 results, 4 bytes each.
    let size = (lhs_shape[0] * rhs_shape[1] * 4) as u64;
    let staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Staging Buffer"),
        size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // Queue the device->host copy behind the matmul, then submit everything.
    record_commands(|encoder| {
        encoder.copy_buffer_to_buffer(&out_buf, 0, &staging_buf, 0, size);
    });

    flush_queue();

    // Map the staging buffer: request the mapping, block the device until
    // the callback has fired, then confirm the map succeeded.
    let slice = staging_buf.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |v| tx.send(v).unwrap());
    device.poll(wgpu::Maintain::Wait);
    rx.recv().unwrap().unwrap();

    let data = slice.get_mapped_range();
    let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
    // The mapped view must be dropped before unmap.
    drop(data);
    staging_buf.unmap();

    // Recycle the GPU-side output buffer (`size` bytes) back into the pool.
    get_memory_pool().return_buffer(out_buf, size);

    result
}
1039
1040pub fn reduce_sum_dim0_wgpu(input: &wgpu::Buffer, input_shape: &[usize]) -> wgpu::Buffer {
1041 let ctx = get_context().expect("WGPU not initialized");
1042 let device = &ctx.device;
1043
1044 let batch_size = input_shape[0] as u32;
1045 let input_size = input_shape[1] as u32;
1046 let stride_batch = input_size;
1047 let total_input_size: usize = input_shape.iter().product();
1048
1049 let output_size = input_size as usize;
1050 let output_byte_size = output_size * std::mem::size_of::<f32>();
1051 let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1052
1053 #[repr(C)]
1054 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1055 struct ReduceParams {
1056 input_size: u32,
1057 reduce_dim_size: u32,
1058 reduce_dim_stride: u32,
1059 outer_size: u32,
1060 }
1061
1062 let params = ReduceParams {
1063 input_size: total_input_size as u32,
1064 reduce_dim_size: batch_size,
1065 reduce_dim_stride: stride_batch,
1066 outer_size: input_size,
1067 };
1068
1069 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1070 label: Some("Reduce Params"),
1071 contents: bytemuck::bytes_of(¶ms),
1072 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1073 });
1074
1075 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1076 label: Some("Reduce Bind Group"),
1077 layout: &ctx.reduce_bind_group_layout,
1078 entries: &[
1079 wgpu::BindGroupEntry {
1080 binding: 0,
1081 resource: input.as_entire_binding(),
1082 },
1083 wgpu::BindGroupEntry {
1084 binding: 1,
1085 resource: output_buffer.as_entire_binding(),
1086 },
1087 wgpu::BindGroupEntry {
1088 binding: 2,
1089 resource: params_buffer.as_entire_binding(),
1090 },
1091 ],
1092 });
1093
1094 record_commands(|encoder| {
1095 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1096 label: Some("Reduce Pass"),
1097 timestamp_writes: None,
1098 });
1099 cpass.set_pipeline(&ctx.reduce_pipeline);
1100 cpass.set_bind_group(0, &bind_group, &[]);
1101 let workgroups = (input_size + 255) / 256;
1102 cpass.dispatch_workgroups(workgroups, 1, 1);
1103 });
1104
1105 output_buffer
1106}
1107
1108pub fn reduce_sum_all_wgpu(input: &wgpu::Buffer, input_size: usize) -> wgpu::Buffer {
1109 let ctx = get_context().expect("WGPU not initialized");
1110 let device = &ctx.device;
1111
1112 let output_byte_size = std::mem::size_of::<f32>();
1113 let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1114
1115 #[repr(C)]
1116 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1117 struct ReduceParams {
1118 input_size: u32,
1119 reduce_dim_size: u32,
1120 reduce_dim_stride: u32,
1121 outer_size: u32,
1122 }
1123
1124 let params = ReduceParams {
1125 input_size: input_size as u32,
1126 reduce_dim_size: 0,
1127 reduce_dim_stride: 0,
1128 outer_size: 0,
1129 };
1130
1131 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1132 label: Some("Reduce Params"),
1133 contents: bytemuck::bytes_of(¶ms),
1134 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1135 });
1136
1137 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1138 label: Some("Reduce Bind Group"),
1139 layout: &ctx.reduce_bind_group_layout,
1140 entries: &[
1141 wgpu::BindGroupEntry {
1142 binding: 0,
1143 resource: input.as_entire_binding(),
1144 },
1145 wgpu::BindGroupEntry {
1146 binding: 1,
1147 resource: output_buffer.as_entire_binding(),
1148 },
1149 wgpu::BindGroupEntry {
1150 binding: 2,
1151 resource: params_buffer.as_entire_binding(),
1152 },
1153 ],
1154 });
1155
1156 record_commands(|encoder| {
1157 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1158 label: Some("Reduce All Pass"),
1159 timestamp_writes: None,
1160 });
1161 cpass.set_pipeline(&ctx.reduce_all_pipeline);
1162 cpass.set_bind_group(0, &bind_group, &[]);
1163 cpass.dispatch_workgroups(1, 1, 1);
1164 });
1165
1166 output_buffer
1167}
1168
1169pub fn reduce_sum_dim_wgpu(
1170 input: &wgpu::Buffer,
1171 input_shape: &[usize],
1172 dim: usize,
1173) -> wgpu::Buffer {
1174 let ctx = get_context().expect("WGPU not initialized");
1175 let device = &ctx.device;
1176
1177 let ndim = input_shape.len();
1178 if dim >= ndim {
1179 panic!("Invalid reduction dimension");
1180 }
1181
1182 if ndim == 1 || dim == 0 && ndim == 2 {
1183 return reduce_sum_dim0_wgpu(input, input_shape);
1184 }
1185
1186 let total_input_size: usize = input_shape.iter().product();
1187 let reduce_dim_size = input_shape[dim];
1188
1189 let outer_size: usize = if dim < ndim - 1 {
1190 input_shape[dim + 1..].iter().product()
1191 } else {
1192 1
1193 };
1194
1195 let output_size = total_input_size / reduce_dim_size;
1196 let output_byte_size = output_size * std::mem::size_of::<f32>();
1197 let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1198
1199 #[repr(C)]
1200 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1201 struct ReduceParams {
1202 input_size: u32,
1203 reduce_dim_size: u32,
1204 reduce_dim_stride: u32,
1205 outer_size: u32,
1206 }
1207
1208 let params = ReduceParams {
1209 input_size: total_input_size as u32,
1210 reduce_dim_size: reduce_dim_size as u32,
1211 reduce_dim_stride: outer_size as u32,
1212 outer_size: outer_size as u32,
1213 };
1214
1215 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1216 label: Some("Reduce Params"),
1217 contents: bytemuck::bytes_of(¶ms),
1218 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1219 });
1220
1221 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1222 label: Some("Reduce Bind Group"),
1223 layout: &ctx.reduce_bind_group_layout,
1224 entries: &[
1225 wgpu::BindGroupEntry {
1226 binding: 0,
1227 resource: input.as_entire_binding(),
1228 },
1229 wgpu::BindGroupEntry {
1230 binding: 1,
1231 resource: output_buffer.as_entire_binding(),
1232 },
1233 wgpu::BindGroupEntry {
1234 binding: 2,
1235 resource: params_buffer.as_entire_binding(),
1236 },
1237 ],
1238 });
1239
1240 record_commands(|encoder| {
1241 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1242 label: Some("Reduce General Pass"),
1243 timestamp_writes: None,
1244 });
1245 cpass.set_pipeline(&ctx.reduce_general_pipeline);
1246 cpass.set_bind_group(0, &bind_group, &[]);
1247 let workgroups = ((output_size as u32 + 255) / 256).max(1);
1248 cpass.dispatch_workgroups(workgroups, 1, 1);
1249 });
1250
1251 output_buffer
1252}
1253
1254pub fn contiguous_wgpu(input: &wgpu::Buffer, shape: &[usize], strides: &[usize]) -> wgpu::Buffer {
1255 let ctx = get_context().expect("WGPU not initialized");
1256 let device = &ctx.device;
1257
1258 let size: usize = shape.iter().product();
1259 let output_byte_size = size * std::mem::size_of::<f32>();
1260 let output_buffer = get_memory_pool().get_buffer(device, output_byte_size as u64);
1261
1262 #[repr(C)]
1263 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1264 struct ContiguousParams {
1265 size: u32,
1266 ndim: u32,
1267 pad0: u32,
1268 pad1: u32,
1269 }
1270
1271 let params = ContiguousParams {
1272 size: size as u32,
1273 ndim: shape.len() as u32,
1274 pad0: 0,
1275 pad1: 0,
1276 };
1277
1278 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1279 label: Some("Contiguous Params"),
1280 contents: bytemuck::bytes_of(¶ms),
1281 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
1282 });
1283
1284 #[repr(C)]
1285 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
1286 struct ShapeStride {
1287 shape0: u32,
1288 shape1: u32,
1289 shape2: u32,
1290 shape3: u32,
1291 shape4: u32,
1292 shape5: u32,
1293 shape6: u32,
1294 shape7: u32,
1295 stride0: u32,
1296 stride1: u32,
1297 stride2: u32,
1298 stride3: u32,
1299 stride4: u32,
1300 stride5: u32,
1301 stride6: u32,
1302 stride7: u32,
1303 }
1304
1305 let shape_stride = ShapeStride {
1306 shape0: shape.get(0).copied().unwrap_or(1) as u32,
1307 shape1: shape.get(1).copied().unwrap_or(1) as u32,
1308 shape2: shape.get(2).copied().unwrap_or(1) as u32,
1309 shape3: shape.get(3).copied().unwrap_or(1) as u32,
1310 shape4: shape.get(4).copied().unwrap_or(1) as u32,
1311 shape5: shape.get(5).copied().unwrap_or(1) as u32,
1312 shape6: shape.get(6).copied().unwrap_or(1) as u32,
1313 shape7: shape.get(7).copied().unwrap_or(1) as u32,
1314 stride0: strides.get(0).copied().unwrap_or(0) as u32,
1315 stride1: strides.get(1).copied().unwrap_or(0) as u32,
1316 stride2: strides.get(2).copied().unwrap_or(0) as u32,
1317 stride3: strides.get(3).copied().unwrap_or(0) as u32,
1318 stride4: strides.get(4).copied().unwrap_or(0) as u32,
1319 stride5: strides.get(5).copied().unwrap_or(0) as u32,
1320 stride6: strides.get(6).copied().unwrap_or(0) as u32,
1321 stride7: strides.get(7).copied().unwrap_or(0) as u32,
1322 };
1323
1324 let shape_strides_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1325 label: Some("Shape Strides"),
1326 contents: bytemuck::bytes_of(&shape_stride),
1327 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
1328 });
1329
1330 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1331 label: Some("Contiguous Bind Group"),
1332 layout: &ctx.contiguous_bind_group_layout,
1333 entries: &[
1334 wgpu::BindGroupEntry {
1335 binding: 0,
1336 resource: input.as_entire_binding(),
1337 },
1338 wgpu::BindGroupEntry {
1339 binding: 1,
1340 resource: output_buffer.as_entire_binding(),
1341 },
1342 wgpu::BindGroupEntry {
1343 binding: 2,
1344 resource: params_buffer.as_entire_binding(),
1345 },
1346 wgpu::BindGroupEntry {
1347 binding: 3,
1348 resource: shape_strides_buffer.as_entire_binding(),
1349 },
1350 ],
1351 });
1352
1353 record_commands(|encoder| {
1354 let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1355 label: Some("Contiguous Pass"),
1356 timestamp_writes: None,
1357 });
1358 cpass.set_pipeline(&ctx.contiguous_pipeline);
1359 cpass.set_bind_group(0, &bind_group, &[]);
1360 let workgroups = ((size as u32 + 255) / 256).max(1);
1361 cpass.dispatch_workgroups(workgroups, 1, 1);
1362 });
1363
1364 output_buffer
1365}