use std::sync::{Arc, OnceLock};

use wgpu::util::DeviceExt;

use crate::device::GpuDevice;
use crate::shaders;
use crate::tensor::GpuTensor;

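/// GPU compute ops over [`GpuTensor`]s: elementwise arithmetic, scalar
/// multiplication, sum/mean reduction, matrix multiplication, and 2D
/// transpose.
///
/// Each compute pipeline is compiled lazily on first use and cached in a
/// [`OnceLock`], so constructing a `GpuOps` is cheap and each shader is
/// compiled at most once.
///
/// A minimal usage sketch (marked `ignore` because it needs a live GPU
/// adapter; it mirrors the tests at the bottom of this file):
///
/// ```ignore
/// let device = GpuDevice::get().expect("no GPU adapter available");
/// let ops = GpuOps::new(device.clone());
/// let a = GpuTensor::from_f32(&[1.0, 2.0], vec![2], device.clone());
/// let b = GpuTensor::from_f32(&[3.0, 4.0], vec![2], device.clone());
/// let c = ops.add(&a, &b).unwrap();
/// assert_eq!(c.read_f32().unwrap(), vec![4.0, 6.0]);
/// ```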
pub struct GpuOps {
    device: Arc<GpuDevice>,
    binary_pipeline: OnceLock<wgpu::ComputePipeline>,
    scalar_pipeline: OnceLock<wgpu::ComputePipeline>,
    reduce_pipeline: OnceLock<wgpu::ComputePipeline>,
    matmul_pipeline: OnceLock<wgpu::ComputePipeline>,
    transpose_pipeline: OnceLock<wgpu::ComputePipeline>,
}

impl GpuOps {
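    /// Creates a `GpuOps` bound to `device`. No shaders are compiled here;
    /// every pipeline is built lazily the first time its op runs.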
    pub fn new(device: Arc<GpuDevice>) -> Self {
        GpuOps {
            device,
            binary_pipeline: OnceLock::new(),
            scalar_pipeline: OnceLock::new(),
            reduce_pipeline: OnceLock::new(),
            matmul_pipeline: OnceLock::new(),
            transpose_pipeline: OnceLock::new(),
        }
    }

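    // Lazy pipeline accessors: each compiles its WGSL shader on the first
    // call and returns the cached pipeline on every call after that.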
    fn get_binary_pipeline(&self) -> &wgpu::ComputePipeline {
        self.binary_pipeline
            .get_or_init(|| self.create_pipeline(shaders::ELEMENTWISE_BINARY, "main"))
    }

    fn get_scalar_pipeline(&self) -> &wgpu::ComputePipeline {
        self.scalar_pipeline
            .get_or_init(|| self.create_pipeline(shaders::SCALAR_MUL, "main"))
    }

    fn get_reduce_pipeline(&self) -> &wgpu::ComputePipeline {
        self.reduce_pipeline
            .get_or_init(|| self.create_pipeline(shaders::REDUCE_SUM, "main"))
    }

    fn get_matmul_pipeline(&self) -> &wgpu::ComputePipeline {
        self.matmul_pipeline
            .get_or_init(|| self.create_pipeline(shaders::MATMUL, "main"))
    }

    fn get_transpose_pipeline(&self) -> &wgpu::ComputePipeline {
        self.transpose_pipeline
            .get_or_init(|| self.create_pipeline(shaders::TRANSPOSE, "main"))
    }

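    /// Compiles `shader_src` as WGSL and builds a compute pipeline around the
    /// given `entry` point. `layout: None` makes wgpu derive the bind group
    /// layout from the shader, which is why the ops below fetch it with
    /// `pipeline.get_bind_group_layout(0)`.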
    fn create_pipeline(&self, shader_src: &str, entry: &str) -> wgpu::ComputePipeline {
        let module = self
            .device
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("compute_shader"),
                source: wgpu::ShaderSource::Wgsl(shader_src.into()),
            });
        self.device
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("compute_pipeline"),
                layout: None,
                module: &module,
                entry_point: Some(entry),
                compilation_options: Default::default(),
                cache: None,
            })
    }

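    /// Shared driver for the elementwise binary ops. `op` selects the
    /// operation inside the shader (0 = add, 1 = sub, 2 = mul, 3 = div,
    /// matching the public wrappers below). Dispatch is one thread per
    /// element, rounded up to whole 256-wide workgroups.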
    fn binary_op(&self, a: &GpuTensor, b: &GpuTensor, op: u32) -> Result<GpuTensor, String> {
        if a.shape != b.shape {
            return Err(format!("Shape mismatch: {:?} vs {:?}", a.shape, b.shape));
        }

        let pipeline = self.get_binary_pipeline();
        let dev = &self.device.device;

        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("binary_result"),
            size: a.byte_size(),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            len: u32,
            op: u32,
        }
        let params = Params {
            len: a.numel as u32,
            op,
        };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("binary_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let workgroups = (a.numel as u32 + 255) / 256;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("binary_op"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        Ok(GpuTensor {
            buffer: result_buf,
            shape: a.shape.clone(),
            dtype: a.dtype,
            numel: a.numel,
            device: self.device.clone(),
        })
    }

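    /// Elementwise addition `a + b`; shapes must match.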
    pub fn add(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 0)
    }

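    /// Elementwise subtraction `a - b`; shapes must match.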
    pub fn sub(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 1)
    }

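    /// Elementwise multiplication `a * b`; shapes must match.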
    pub fn mul(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 2)
    }

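    /// Elementwise division `a / b`; shapes must match.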
    pub fn div(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 3)
    }

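    /// Multiplies every element of `a` by `scalar`. Infallible, since the
    /// output always has the same shape as the input.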
    pub fn scale(&self, a: &GpuTensor, scalar: f32) -> GpuTensor {
        let pipeline = self.get_scalar_pipeline();
        let dev = &self.device.device;

        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("scale_result"),
            size: a.byte_size(),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            len: u32,
            scalar: f32,
        }
        let params = Params {
            len: a.numel as u32,
            scalar,
        };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("scale_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scale_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let workgroups = (a.numel as u32 + 255) / 256;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("scale"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("scale"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        GpuTensor {
            buffer: result_buf,
            shape: a.shape.clone(),
            dtype: a.dtype,
            numel: a.numel,
            device: self.device.clone(),
        }
    }

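    /// Sums all elements of `a`. The shader reduces each 256-element
    /// workgroup to a single partial sum on the GPU; the partials are then
    /// read back and summed on the CPU.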
    pub fn sum(&self, a: &GpuTensor) -> Result<f32, String> {
        let pipeline = self.get_reduce_pipeline();
        let dev = &self.device.device;

        let num_workgroups = (a.numel as u32 + 255) / 256;

        let partial_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("reduce_partial"),
            size: (num_workgroups as usize * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            len: u32,
        }
        let params = Params {
            len: a.numel as u32,
        };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("reduce_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("reduce_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: partial_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("reduce"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("reduce"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(num_workgroups, 1, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        let partial_tensor = GpuTensor {
            buffer: partial_buf,
            shape: vec![num_workgroups as usize],
            dtype: crate::tensor::DType::F32,
            numel: num_workgroups as usize,
            device: self.device.clone(),
        };
        let partials = partial_tensor.read_f32()?;
        Ok(partials.iter().sum())
    }

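    /// Arithmetic mean of the elements of `a`: `sum(a) / numel`.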
    pub fn mean(&self, a: &GpuTensor) -> Result<f32, String> {
        let s = self.sum(a)?;
        Ok(s / a.numel as f32)
    }

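    /// Matrix product of an `[m, k]` tensor and a `[k, n]` tensor, producing
    /// an `[m, n]` result. Both inputs must be 2D with matching inner
    /// dimensions. Work is dispatched in 16x16 tiles over the output matrix
    /// (the shader's assumed workgroup size).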
    pub fn matmul(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        if a.shape.len() != 2 || b.shape.len() != 2 {
            return Err("matmul requires 2D tensors".to_string());
        }
        let m = a.shape[0] as u32;
        let k = a.shape[1] as u32;
        let k2 = b.shape[0] as u32;
        let n = b.shape[1] as u32;
        if k != k2 {
            return Err(format!("matmul dimension mismatch: [{m},{k}] x [{k2},{n}]"));
        }

        let pipeline = self.get_matmul_pipeline();
        let dev = &self.device.device;

        let result_numel = (m * n) as usize;
        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("matmul_result"),
            size: (result_numel * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            m: u32,
            k: u32,
            n: u32,
            _pad: u32,
        }
        let params = Params { m, k, n, _pad: 0 };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("matmul_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("matmul_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let wg_x = (n + 15) / 16;
        let wg_y = (m + 15) / 16;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("matmul"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("matmul"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        Ok(GpuTensor {
            buffer: result_buf,
            shape: vec![m as usize, n as usize],
            dtype: a.dtype,
            numel: result_numel,
            device: self.device.clone(),
        })
    }

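    /// Transposes a 2D tensor: an `[rows, cols]` input becomes a
    /// `[cols, rows]` output. Dispatched in 16x16 tiles over the input
    /// (the shader's assumed workgroup size).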
    pub fn transpose(&self, a: &GpuTensor) -> Result<GpuTensor, String> {
        if a.shape.len() != 2 {
            return Err("transpose requires a 2D tensor".to_string());
        }
        let rows = a.shape[0] as u32;
        let cols = a.shape[1] as u32;

        let pipeline = self.get_transpose_pipeline();
        let dev = &self.device.device;

        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("transpose_result"),
            size: a.byte_size(),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            rows: u32,
            cols: u32,
        }
        let params = Params { rows, cols };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("transpose_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("transpose_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let wg_x = (cols + 15) / 16;
        let wg_y = (rows + 15) / 16;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("transpose"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("transpose"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        Ok(GpuTensor {
            buffer: result_buf,
            shape: vec![cols as usize, rows as usize],
            dtype: a.dtype,
            numel: a.numel,
            device: self.device.clone(),
        })
    }
}

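// Each test bails out early when no GPU adapter is available
// (`GpuDevice::get()` returns `None`), so the suite is safe to run on
// machines without a GPU.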
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::GpuTensor;

    #[test]
    fn test_gpu_add() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![4], device.clone());
        let b = GpuTensor::from_f32(&[10.0, 20.0, 30.0, 40.0], vec![4], device.clone());

        let c = ops.add(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_gpu_sub() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[10.0, 20.0, 30.0], vec![3], device.clone());
        let b = GpuTensor::from_f32(&[1.0, 2.0, 3.0], vec![3], device.clone());

        let c = ops.sub(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_gpu_mul() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[2.0, 3.0, 4.0], vec![3], device.clone());
        let b = GpuTensor::from_f32(&[5.0, 6.0, 7.0], vec![3], device.clone());

        let c = ops.mul(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_gpu_div() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[10.0, 20.0, 30.0], vec![3], device.clone());
        let b = GpuTensor::from_f32(&[2.0, 5.0, 10.0], vec![3], device.clone());

        let c = ops.div(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![5.0, 4.0, 3.0]);
    }

    #[test]
    fn test_gpu_matmul() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![2, 2], device.clone());
        let b = GpuTensor::from_f32(&[5.0, 6.0, 7.0, 8.0], vec![2, 2], device.clone());

        let c = ops.matmul(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
        assert_eq!(c.shape, vec![2, 2]);
    }

    #[test]
    fn test_gpu_sum() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![4], device.clone());
        let s = ops.sum(&a).unwrap();
        assert!((s - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_gpu_transpose() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], device.clone());
        let t = ops.transpose(&a).unwrap();
        let result = t.read_f32().unwrap();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(t.shape, vec![3, 2]);
    }
}