// tl_gpu/ops.rs
// GpuOps — compute shader dispatch engine

use std::sync::{Arc, OnceLock};
use wgpu::util::DeviceExt;

use crate::device::GpuDevice;
use crate::shaders;
use crate::tensor::GpuTensor;

/// GPU operations dispatcher with cached pipelines.
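/// Each pipeline is compiled lazily on first use and then reused via `OnceLock`.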
pub struct GpuOps {
    device: Arc<GpuDevice>,
    binary_pipeline: OnceLock<wgpu::ComputePipeline>,
    scalar_pipeline: OnceLock<wgpu::ComputePipeline>,
    reduce_pipeline: OnceLock<wgpu::ComputePipeline>,
    matmul_pipeline: OnceLock<wgpu::ComputePipeline>,
    transpose_pipeline: OnceLock<wgpu::ComputePipeline>,
}

impl GpuOps {
    pub fn new(device: Arc<GpuDevice>) -> Self {
        GpuOps {
            device,
            binary_pipeline: OnceLock::new(),
            scalar_pipeline: OnceLock::new(),
            reduce_pipeline: OnceLock::new(),
            matmul_pipeline: OnceLock::new(),
            transpose_pipeline: OnceLock::new(),
        }
    }

    // ── Pipeline builders ──

    fn get_binary_pipeline(&self) -> &wgpu::ComputePipeline {
        self.binary_pipeline
            .get_or_init(|| self.create_pipeline(shaders::ELEMENTWISE_BINARY, "main"))
    }

    fn get_scalar_pipeline(&self) -> &wgpu::ComputePipeline {
        self.scalar_pipeline
            .get_or_init(|| self.create_pipeline(shaders::SCALAR_MUL, "main"))
    }

    fn get_reduce_pipeline(&self) -> &wgpu::ComputePipeline {
        self.reduce_pipeline
            .get_or_init(|| self.create_pipeline(shaders::REDUCE_SUM, "main"))
    }

    fn get_matmul_pipeline(&self) -> &wgpu::ComputePipeline {
        self.matmul_pipeline
            .get_or_init(|| self.create_pipeline(shaders::MATMUL, "main"))
    }

    fn get_transpose_pipeline(&self) -> &wgpu::ComputePipeline {
        self.transpose_pipeline
            .get_or_init(|| self.create_pipeline(shaders::TRANSPOSE, "main"))
    }

    fn create_pipeline(&self, shader_src: &str, entry: &str) -> wgpu::ComputePipeline {
        let module = self
            .device
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("compute_shader"),
                source: wgpu::ShaderSource::Wgsl(shader_src.into()),
            });
        self.device
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("compute_pipeline"),
                layout: None, // auto-layout
                module: &module,
                entry_point: Some(entry),
                compilation_options: Default::default(),
                cache: None,
            })
    }
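
    // With `layout: None`, wgpu derives the pipeline layout from the shader's
    // own bindings, which is what lets the dispatch methods below fetch it via
    // `pipeline.get_bind_group_layout(0)`.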

    // ── Elementwise binary operations ──

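    // `op` selects the arithmetic performed by the shader: 0 = add, 1 = sub,
    // 2 = mul, 3 = div (see the public wrappers below).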
    fn binary_op(&self, a: &GpuTensor, b: &GpuTensor, op: u32) -> Result<GpuTensor, String> {
        if a.shape != b.shape {
            return Err(format!("Shape mismatch: {:?} vs {:?}", a.shape, b.shape));
        }

        let pipeline = self.get_binary_pipeline();
        let dev = &self.device.device;

        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("binary_result"),
            size: a.byte_size(),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            len: u32,
            op: u32,
        }
        let params = Params {
            len: a.numel as u32,
            op,
        };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("binary_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        // Ceil-divide so every element is covered; 256 is assumed to match the
        // shader's @workgroup_size.
        let workgroups = (a.numel as u32 + 255) / 256;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("binary_op"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("binary"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        Ok(GpuTensor {
            buffer: result_buf,
            shape: a.shape.clone(),
            dtype: a.dtype,
            numel: a.numel,
            device: self.device.clone(),
        })
    }

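    /// Elementwise `a + b`.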
    pub fn add(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 0)
    }

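    /// Elementwise `a - b`.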
    pub fn sub(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 1)
    }

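    /// Elementwise `a * b` (Hadamard product).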
    pub fn mul(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 2)
    }

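    /// Elementwise `a / b`.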
    pub fn div(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        self.binary_op(a, b, 3)
    }

    // ── Scalar multiplication ──

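    /// Multiplies every element by `scalar`. Unlike the binary ops this cannot
    /// fail on a shape mismatch, so it returns the tensor directly.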
    pub fn scale(&self, a: &GpuTensor, scalar: f32) -> GpuTensor {
        let pipeline = self.get_scalar_pipeline();
        let dev = &self.device.device;

        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("scale_result"),
            size: a.byte_size(),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            len: u32,
            scalar: f32,
        }
        let params = Params {
            len: a.numel as u32,
            scalar,
        };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("scale_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("scale_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let workgroups = (a.numel as u32 + 255) / 256;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("scale"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("scale"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        GpuTensor {
            buffer: result_buf,
            shape: a.shape.clone(),
            dtype: a.dtype,
            numel: a.numel,
            device: self.device.clone(),
        }
    }

    // ── Reduction operations ──

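    /// Sums all elements. The shader reduces each workgroup to a single
    /// partial sum; the partials are then read back and added on the CPU.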
    pub fn sum(&self, a: &GpuTensor) -> Result<f32, String> {
        let pipeline = self.get_reduce_pipeline();
        let dev = &self.device.device;

        let num_workgroups = (a.numel as u32 + 255) / 256;

        let partial_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("reduce_partial"),
            size: (num_workgroups as usize * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            len: u32,
        }
        let params = Params {
            len: a.numel as u32,
        };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("reduce_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("reduce_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: partial_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("reduce"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("reduce"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(num_workgroups, 1, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        // Read partial sums back and finish on CPU
        let partial_tensor = GpuTensor {
            buffer: partial_buf,
            shape: vec![num_workgroups as usize],
            dtype: crate::tensor::DType::F32,
            numel: num_workgroups as usize,
            device: self.device.clone(),
        };
        let partials = partial_tensor.read_f32()?;
        Ok(partials.iter().sum())
    }

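    /// Arithmetic mean of all elements (`sum / numel`).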
    pub fn mean(&self, a: &GpuTensor) -> Result<f32, String> {
        let s = self.sum(a)?;
        Ok(s / a.numel as f32)
    }

    // ── Matrix multiply ──

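    /// Row-major matrix product: `[m, k] x [k, n] -> [m, n]`.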
    pub fn matmul(&self, a: &GpuTensor, b: &GpuTensor) -> Result<GpuTensor, String> {
        if a.shape.len() != 2 || b.shape.len() != 2 {
            return Err("matmul requires 2D tensors".to_string());
        }
        let m = a.shape[0] as u32;
        let k = a.shape[1] as u32;
        let k2 = b.shape[0] as u32;
        let n = b.shape[1] as u32;
        if k != k2 {
            return Err(format!("matmul dimension mismatch: [{m},{k}] x [{k2},{n}]"));
        }

        let pipeline = self.get_matmul_pipeline();
        let dev = &self.device.device;

        let result_numel = (m * n) as usize;
        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("matmul_result"),
            size: (result_numel * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            m: u32,
            k: u32,
            n: u32,
            _pad: u32, // rounds the uniform block up to 16 bytes
        }
        let params = Params { m, k, n, _pad: 0 };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("matmul_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("matmul_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: b.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        // One 16x16 workgroup per output tile; 16 is assumed to match the
        // shader's @workgroup_size.
        let wg_x = (n + 15) / 16;
        let wg_y = (m + 15) / 16;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("matmul"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("matmul"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        Ok(GpuTensor {
            buffer: result_buf,
            shape: vec![m as usize, n as usize],
            dtype: a.dtype,
            numel: result_numel,
            device: self.device.clone(),
        })
    }

    // ── Transpose ──

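    /// Transposes a 2D tensor: `[rows, cols] -> [cols, rows]`.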
    pub fn transpose(&self, a: &GpuTensor) -> Result<GpuTensor, String> {
        if a.shape.len() != 2 {
            return Err("transpose requires a 2D tensor".to_string());
        }
        let rows = a.shape[0] as u32;
        let cols = a.shape[1] as u32;

        let pipeline = self.get_transpose_pipeline();
        let dev = &self.device.device;

        let result_buf = dev.create_buffer(&wgpu::BufferDescriptor {
            label: Some("transpose_result"),
            size: a.byte_size(),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            rows: u32,
            cols: u32,
        }
        let params = Params { rows, cols };

        let param_buf = dev.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("transpose_params"),
            contents: bytemuck::bytes_of(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

        let bind_group_layout = pipeline.get_bind_group_layout(0);
        let bind_group = dev.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("transpose_bg"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: a.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: result_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: param_buf.as_entire_binding(),
                },
            ],
        });

        let wg_x = (cols + 15) / 16;
        let wg_y = (rows + 15) / 16;
        let mut encoder = dev.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("transpose"),
        });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("transpose"),
                timestamp_writes: None,
            });
            pass.set_pipeline(pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(wg_x, wg_y, 1);
        }
        self.device.queue.submit(std::iter::once(encoder.finish()));

        Ok(GpuTensor {
            buffer: result_buf,
            shape: vec![cols as usize, rows as usize],
            dtype: a.dtype,
            numel: a.numel,
            device: self.device.clone(),
        })
    }
}
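
// Typical usage, assuming a device obtained from `GpuDevice::get()`:
//
//     let ops = GpuOps::new(device.clone());
//     let a = GpuTensor::from_f32(&[1.0, 2.0], vec![2], device.clone());
//     let b = GpuTensor::from_f32(&[3.0, 4.0], vec![2], device.clone());
//     let sum = ops.add(&a, &b)?.read_f32()?;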

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::GpuTensor;

    #[test]
    fn test_gpu_add() {
        // Each test passes trivially when no GPU adapter is available.
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![4], device.clone());
        let b = GpuTensor::from_f32(&[10.0, 20.0, 30.0, 40.0], vec![4], device.clone());

        let c = ops.add(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_gpu_sub() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[10.0, 20.0, 30.0], vec![3], device.clone());
        let b = GpuTensor::from_f32(&[1.0, 2.0, 3.0], vec![3], device.clone());

        let c = ops.sub(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_gpu_mul() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[2.0, 3.0, 4.0], vec![3], device.clone());
        let b = GpuTensor::from_f32(&[5.0, 6.0, 7.0], vec![3], device.clone());

        let c = ops.mul(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_gpu_div() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[10.0, 20.0, 30.0], vec![3], device.clone());
        let b = GpuTensor::from_f32(&[2.0, 5.0, 10.0], vec![3], device.clone());

        let c = ops.div(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        assert_eq!(result, vec![5.0, 4.0, 3.0]);
    }

    #[test]
    fn test_gpu_matmul() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        // [2,2] x [2,2]
        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![2, 2], device.clone());
        let b = GpuTensor::from_f32(&[5.0, 6.0, 7.0, 8.0], vec![2, 2], device.clone());

        let c = ops.matmul(&a, &b).unwrap();
        let result = c.read_f32().unwrap();
        // [1*5+2*7, 1*6+2*8, 3*5+4*7, 3*6+4*8] = [19, 22, 43, 50]
        assert_eq!(result, vec![19.0, 22.0, 43.0, 50.0]);
        assert_eq!(c.shape, vec![2, 2]);
    }

    #[test]
    fn test_gpu_sum() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![4], device.clone());
        let s = ops.sum(&a).unwrap();
        assert!((s - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_gpu_transpose() {
        let Some(device) = GpuDevice::get() else {
            return;
        };
        let ops = GpuOps::new(device.clone());

        // [[1,2,3],[4,5,6]] -> [[1,4],[2,5],[3,6]]
        let a = GpuTensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3], device.clone());
        let t = ops.transpose(&a).unwrap();
        let result = t.read_f32().unwrap();
        assert_eq!(result, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(t.shape, vec![3, 2]);
    }
}