1use core::fmt;
45
/// A WGSL compute shader: the raw source text plus the metadata needed
/// to build a compute pipeline from it.
#[derive(Debug, Clone)]
pub struct WGSLShader {
    /// Raw WGSL source text.
    source: String,
    /// Name of the entry-point function inside `source`.
    entry_point: String,
    /// Workgroup dimensions `(x, y, z)` used for dispatch-size math.
    workgroup_size: (u32, u32, u32),
}

impl WGSLShader {
    /// Creates a shader with the default entry point `"main"` and the
    /// default workgroup size `(64, 1, 1)`.
    pub fn new(source: impl Into<String>) -> Self {
        Self {
            source: source.into(),
            entry_point: "main".to_string(),
            workgroup_size: (64, 1, 1),
        }
    }

    /// Creates a shader with a custom entry point.
    ///
    /// Delegates to [`WGSLShader::new`] so the defaults are defined in
    /// exactly one place (previously both constructors repeated them).
    pub fn with_entry_point(source: impl Into<String>, entry_point: impl Into<String>) -> Self {
        Self {
            entry_point: entry_point.into(),
            ..Self::new(source)
        }
    }

    /// Builder-style setter for the workgroup dimensions.
    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Returns the WGSL source text.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// Returns the entry-point function name.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }

    /// Returns the workgroup dimensions `(x, y, z)`.
    pub fn workgroup_size(&self) -> (u32, u32, u32) {
        self.workgroup_size
    }

    /// Performs cheap structural validation of the shader.
    ///
    /// This is a textual check only — it does not parse or compile the
    /// WGSL.
    ///
    /// # Errors
    ///
    /// Returns [`ShaderError::EmptyShader`] for an empty source string,
    /// and [`ShaderError::MissingComputeAttribute`] when the source does
    /// not contain the literal `@compute`.
    pub fn validate(&self) -> Result<(), ShaderError> {
        if self.source.is_empty() {
            return Err(ShaderError::EmptyShader);
        }
        if !self.source.contains("@compute") {
            return Err(ShaderError::MissingComputeAttribute);
        }
        Ok(())
    }
}

/// Errors produced when validating or compiling a [`WGSLShader`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShaderError {
    /// The shader source string was empty.
    EmptyShader,
    /// The source contained no `@compute` attribute.
    MissingComputeAttribute,
    /// The source failed to parse; the payload describes the problem.
    SyntaxError(String),
    /// The backend rejected the shader; the payload is its message.
    CompilationFailed(String),
}

impl fmt::Display for ShaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ShaderError::EmptyShader => write!(f, "Shader source is empty"),
            ShaderError::MissingComputeAttribute => write!(f, "Missing @compute attribute"),
            ShaderError::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
            ShaderError::CompilationFailed(msg) => write!(f, "Compilation failed: {}", msg),
        }
    }
}
135
/// Intended usage class of a [`GPUBuffer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferUsage {
    /// Read/write storage buffer.
    Storage,
    /// Uniform (constant) buffer.
    Uniform,
    /// Host-visible staging buffer.
    Staging,
    /// Vertex buffer.
    Vertex,
    /// Index buffer.
    Index,
}

/// Descriptor for a GPU-side buffer allocation.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    /// Buffer size in bytes.
    size: usize,
    /// Declared usage class.
    usage: BufferUsage,
    /// Whether the buffer is host-mapped.
    mapped: bool,
    /// Optional debug label.
    label: Option<String>,
}

impl GPUBuffer {
    /// Shared constructor all public constructors funnel through.
    fn describe(size: usize, usage: BufferUsage, mapped: bool) -> Self {
        Self {
            size,
            usage,
            mapped,
            label: None,
        }
    }

    /// A storage buffer of `size` bytes (not host-mapped).
    pub fn storage(size: usize) -> Self {
        Self::describe(size, BufferUsage::Storage, false)
    }

    /// A uniform buffer of `size` bytes (not host-mapped).
    pub fn uniform(size: usize) -> Self {
        Self::describe(size, BufferUsage::Uniform, false)
    }

    /// A staging buffer of `size` bytes, created host-mapped.
    pub fn staging(size: usize) -> Self {
        Self::describe(size, BufferUsage::Staging, true)
    }

    /// Attaches a debug label, consuming and returning `self`.
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Size of the buffer in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Declared usage class.
    pub fn usage(&self) -> BufferUsage {
        self.usage
    }

    /// Whether the buffer is host-mapped.
    pub fn is_mapped(&self) -> bool {
        self.mapped
    }

    /// Debug label, if one was attached.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}
223
/// A compute pipeline description: a shader plus the bind-group layouts
/// it expects, and an optional debug label.
#[derive(Debug, Clone)]
pub struct ComputePipeline {
    /// The compute shader this pipeline executes.
    shader: WGSLShader,
    /// Bind-group layouts, in the order they were added.
    bind_groups: Vec<BindGroupLayout>,
    /// Optional debug label.
    label: Option<String>,
}
236
237impl ComputePipeline {
238 pub fn new(shader: WGSLShader) -> Self {
240 Self {
241 shader,
242 bind_groups: Vec::new(),
243 label: None,
244 }
245 }
246
247 pub fn with_bind_group(mut self, layout: BindGroupLayout) -> Self {
249 self.bind_groups.push(layout);
250 self
251 }
252
253 pub fn with_label(mut self, label: impl Into<String>) -> Self {
255 self.label = Some(label.into());
256 self
257 }
258
259 pub fn shader(&self) -> &WGSLShader {
261 &self.shader
262 }
263
264 pub fn bind_groups(&self) -> &[BindGroupLayout] {
266 &self.bind_groups
267 }
268
269 pub fn label(&self) -> Option<&str> {
271 self.label.as_deref()
272 }
273
274 pub fn optimal_dispatch_size(&self, data_size: usize) -> (u32, u32, u32) {
276 let (wg_x, wg_y, wg_z) = self.shader.workgroup_size();
277 let workgroup_count = (data_size as u32 + wg_x - 1) / wg_x;
278 (workgroup_count, wg_y, wg_z)
279 }
280}
281
/// A single resource binding within a [`BindGroupLayout`].
#[derive(Debug, Clone)]
pub struct BindGroupEntry {
    /// Binding index within the group.
    binding: u32,
    /// Kind of resource bound at this slot.
    resource_type: ResourceType,
    /// Shader stage(s) that can access the resource.
    visibility: ShaderStage,
}
294
295impl BindGroupEntry {
296 pub fn new(binding: u32, resource_type: ResourceType, visibility: ShaderStage) -> Self {
298 Self {
299 binding,
300 resource_type,
301 visibility,
302 }
303 }
304
305 pub fn binding(&self) -> u32 {
307 self.binding
308 }
309
310 pub fn resource_type(&self) -> ResourceType {
312 self.resource_type
313 }
314
315 pub fn visibility(&self) -> ShaderStage {
317 self.visibility
318 }
319}
320
/// An ordered collection of [`BindGroupEntry`] items describing one
/// bind group, with an optional debug label.
#[derive(Debug, Clone)]
pub struct BindGroupLayout {
    /// Entries in the order they were added.
    entries: Vec<BindGroupEntry>,
    /// Optional debug label.
    label: Option<String>,
}
331
332impl BindGroupLayout {
333 pub fn new() -> Self {
335 Self {
336 entries: Vec::new(),
337 label: None,
338 }
339 }
340
341 pub fn with_entry(mut self, entry: BindGroupEntry) -> Self {
343 self.entries.push(entry);
344 self
345 }
346
347 pub fn with_label(mut self, label: impl Into<String>) -> Self {
349 self.label = Some(label.into());
350 self
351 }
352
353 pub fn entries(&self) -> &[BindGroupEntry] {
355 &self.entries
356 }
357
358 pub fn label(&self) -> Option<&str> {
360 self.label.as_deref()
361 }
362}
363
364impl Default for BindGroupLayout {
365 fn default() -> Self {
366 Self::new()
367 }
368}
369
/// The kind of resource bound at a [`BindGroupEntry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResourceType {
    /// Read/write storage buffer.
    StorageBuffer,
    /// Uniform (constant) buffer.
    UniformBuffer,
    /// Read-only storage buffer.
    ReadOnlyStorageBuffer,
    /// Texture resource.
    Texture,
    /// Sampler resource.
    Sampler,
}
384
/// Shader stage(s) a binding is visible to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShaderStage {
    /// Vertex stage only.
    Vertex,
    /// Fragment stage only.
    Fragment,
    /// Compute stage only.
    Compute,
    /// All stages.
    All,
}
397
/// Heuristics for choosing workgroup sizes given a device's maximum
/// workgroup size.
#[derive(Debug, Clone, Copy)]
pub struct WorkgroupOptimizer {
    /// Device limit on the workgroup size (assumed >= 1).
    max_workgroup_size: u32,
    /// Preferred 1-D workgroup size for medium-to-large inputs.
    preferred_1d: u32,
    /// Preferred 2-D workgroup dimensions for large domains.
    preferred_2d: (u32, u32),
}

impl WorkgroupOptimizer {
    /// Creates an optimizer for a device whose workgroup limit is
    /// `max_workgroup_size`. Preferred sizes default to 256 (1-D) and
    /// `(16, 16)` (2-D).
    pub fn new(max_workgroup_size: u32) -> Self {
        Self {
            max_workgroup_size,
            preferred_1d: 256,
            preferred_2d: (16, 16),
        }
    }

    /// Picks a 1-D workgroup size for `data_size` elements.
    ///
    /// Small inputs (<= 64) pick 64, medium inputs (<= 256) pick the
    /// preferred size, large inputs pick `min(512, max)`. The result is
    /// always clamped to `max_workgroup_size` — previously the small and
    /// medium branches could return a size above the device limit (e.g.
    /// `new(32).optimize_1d(10)` returned 64).
    pub fn optimize_1d(&self, data_size: usize) -> u32 {
        let size = data_size as u32;
        let chosen = if size <= 64 {
            64
        } else if size <= self.preferred_1d {
            self.preferred_1d
        } else {
            512
        };
        chosen.min(self.max_workgroup_size)
    }

    /// Picks 2-D workgroup dimensions for a `width` x `height` domain.
    ///
    /// NOTE(review): the product of the returned dimensions (up to
    /// 16 * 16 = 256 invocations) is not checked against
    /// `max_workgroup_size` — confirm against device limits below 256.
    pub fn optimize_2d(&self, width: usize, height: usize) -> (u32, u32) {
        let (w, h) = (width as u32, height as u32);
        if w <= 16 && h <= 16 {
            (8, 8)
        } else if w <= 32 && h <= 32 {
            (16, 16)
        } else {
            self.preferred_2d
        }
    }

    /// Picks 3-D workgroup dimensions for a volume: `(4, 4, 4)` for
    /// small volumes (all extents <= 8), `(8, 8, 4)` otherwise.
    pub fn optimize_3d(&self, width: usize, height: usize, depth: usize) -> (u32, u32, u32) {
        let (w, h, d) = (width as u32, height as u32, depth as u32);
        if w <= 8 && h <= 8 && d <= 8 {
            (4, 4, 4)
        } else {
            (8, 8, 4)
        }
    }

    /// The device workgroup-size limit this optimizer was built with.
    pub fn max_workgroup_size(&self) -> u32 {
        self.max_workgroup_size
    }
}

impl Default for WorkgroupOptimizer {
    /// A conservative default limit of 256 invocations.
    fn default() -> Self {
        Self::new(256)
    }
}
467
/// A small FIFO cache mapping shader hashes to built pipelines.
#[derive(Debug, Clone)]
pub struct PipelineCache {
    /// (shader hash, pipeline) pairs in insertion order; index 0 is the
    /// oldest entry and is the first evicted.
    cache: Vec<(u64, ComputePipeline)>,
    /// Maximum number of cached pipelines.
    max_size: usize,
}
478
479impl PipelineCache {
480 pub fn new(max_size: usize) -> Self {
482 Self {
483 cache: Vec::new(),
484 max_size,
485 }
486 }
487
488 pub fn get_or_create<F>(&mut self, shader: &WGSLShader, create_fn: F) -> &ComputePipeline
490 where
491 F: FnOnce(&WGSLShader) -> ComputePipeline,
492 {
493 let hash = self.hash_shader(shader);
494
495 if let Some(index) = self.cache.iter().position(|(h, _)| *h == hash) {
497 return &self.cache[index].1;
498 }
499
500 let pipeline = create_fn(shader);
502 self.cache.push((hash, pipeline));
503
504 if self.cache.len() > self.max_size {
506 self.cache.remove(0);
507 }
508
509 &self
510 .cache
511 .last()
512 .expect("cache should have at least one entry after push")
513 .1
514 }
515
516 fn hash_shader(&self, shader: &WGSLShader) -> u64 {
518 let src = shader.source();
520 let mut hash = src.len() as u64;
521 for (i, byte) in src.bytes().take(16).enumerate() {
522 hash = hash
523 .wrapping_mul(31)
524 .wrapping_add(byte as u64 * (i as u64 + 1));
525 }
526 hash
527 }
528
529 pub fn clear(&mut self) {
531 self.cache.clear();
532 }
533
534 pub fn size(&self) -> usize {
536 self.cache.len()
537 }
538
539 pub fn max_size(&self) -> usize {
541 self.max_size
542 }
543}
544
545pub mod shaders {
547 use super::*;
548
549 pub fn elementwise_add() -> WGSLShader {
551 WGSLShader::new(
552 r#"
553@group(0) @binding(0) var<storage, read> input_a: array<f32>;
554@group(0) @binding(1) var<storage, read> input_b: array<f32>;
555@group(0) @binding(2) var<storage, read_write> output: array<f32>;
556
557@compute @workgroup_size(256)
558fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
559 let index = global_id.x;
560 if (index < arrayLength(&input_a)) {
561 output[index] = input_a[index] + input_b[index];
562 }
563}
564"#,
565 )
566 }
567
568 pub fn elementwise_mul() -> WGSLShader {
570 WGSLShader::new(
571 r#"
572@group(0) @binding(0) var<storage, read> input_a: array<f32>;
573@group(0) @binding(1) var<storage, read> input_b: array<f32>;
574@group(0) @binding(2) var<storage, read_write> output: array<f32>;
575
576@compute @workgroup_size(256)
577fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
578 let index = global_id.x;
579 if (index < arrayLength(&input_a)) {
580 output[index] = input_a[index] * input_b[index];
581 }
582}
583"#,
584 )
585 }
586
587 pub fn matrix_mul() -> WGSLShader {
589 WGSLShader::new(
590 r#"
591struct Dimensions {
592 m: u32,
593 n: u32,
594 k: u32,
595}
596
597@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
598@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
599@group(0) @binding(2) var<storage, read_write> output: array<f32>;
600@group(0) @binding(3) var<uniform> dims: Dimensions;
601
602@compute @workgroup_size(16, 16)
603fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
604 let row = global_id.y;
605 let col = global_id.x;
606
607 if (row >= dims.m || col >= dims.n) {
608 return;
609 }
610
611 var sum = 0.0;
612 for (var i = 0u; i < dims.k; i++) {
613 let a_index = row * dims.k + i;
614 let b_index = i * dims.n + col;
615 sum += matrix_a[a_index] * matrix_b[b_index];
616 }
617
618 let out_index = row * dims.n + col;
619 output[out_index] = sum;
620}
621"#,
622 )
623 .with_workgroup_size(16, 16, 1)
624 }
625
626 pub fn reduce_sum() -> WGSLShader {
628 WGSLShader::new(
629 r#"
630@group(0) @binding(0) var<storage, read> input: array<f32>;
631@group(0) @binding(1) var<storage, read_write> output: array<f32>;
632
633var<workgroup> shared_data: array<f32, 256>;
634
635@compute @workgroup_size(256)
636fn main(
637 @builtin(global_invocation_id) global_id: vec3<u32>,
638 @builtin(local_invocation_id) local_id: vec3<u32>,
639) {
640 let tid = local_id.x;
641 let index = global_id.x;
642
643 // Load data into shared memory
644 if (index < arrayLength(&input)) {
645 shared_data[tid] = input[index];
646 } else {
647 shared_data[tid] = 0.0;
648 }
649
650 workgroupBarrier();
651
652 // Reduce in shared memory
653 for (var s = 128u; s > 0u; s >>= 1u) {
654 if (tid < s) {
655 shared_data[tid] += shared_data[tid + s];
656 }
657 workgroupBarrier();
658 }
659
660 // Write result
661 if (tid == 0u) {
662 output[global_id.x / 256u] = shared_data[0];
663 }
664}
665"#,
666 )
667 }
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    // WGSLShader::new defaults: entry point "main", workgroup (64, 1, 1).
    #[test]
    fn test_wgsl_shader_creation() {
        let shader = WGSLShader::new("@compute @workgroup_size(64) fn main() {}");
        assert!(shader.source().contains("@compute"));
        assert_eq!(shader.entry_point(), "main");
        assert_eq!(shader.workgroup_size(), (64, 1, 1));
    }

    // with_entry_point overrides the default entry point.
    #[test]
    fn test_shader_with_entry_point() {
        let shader = WGSLShader::with_entry_point("code", "custom_entry");
        assert_eq!(shader.entry_point(), "custom_entry");
    }

    // with_workgroup_size replaces the default (64, 1, 1).
    #[test]
    fn test_shader_workgroup_size() {
        let shader = WGSLShader::new("code").with_workgroup_size(16, 16, 1);
        assert_eq!(shader.workgroup_size(), (16, 16, 1));
    }

    // validate: ok with @compute; errors on empty or missing @compute.
    #[test]
    fn test_shader_validation() {
        let shader = WGSLShader::new("@compute fn main() {}");
        assert!(shader.validate().is_ok());

        let empty_shader = WGSLShader::new("");
        assert_eq!(empty_shader.validate(), Err(ShaderError::EmptyShader));

        let invalid_shader = WGSLShader::new("fn main() {}");
        assert_eq!(
            invalid_shader.validate(),
            Err(ShaderError::MissingComputeAttribute)
        );
    }

    // Constructors set the expected usage; only staging starts mapped.
    #[test]
    fn test_gpu_buffer_creation() {
        let storage = GPUBuffer::storage(1024);
        assert_eq!(storage.size(), 1024);
        assert_eq!(storage.usage(), BufferUsage::Storage);
        assert!(!storage.is_mapped());

        let uniform = GPUBuffer::uniform(256);
        assert_eq!(uniform.usage(), BufferUsage::Uniform);

        let staging = GPUBuffer::staging(512);
        assert_eq!(staging.usage(), BufferUsage::Staging);
        assert!(staging.is_mapped());
    }

    #[test]
    fn test_buffer_with_label() {
        let buffer = GPUBuffer::storage(1024).with_label("test_buffer");
        assert_eq!(buffer.label(), Some("test_buffer"));
    }

    // A new pipeline carries its shader and starts with no bind groups.
    #[test]
    fn test_compute_pipeline() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let pipeline = ComputePipeline::new(shader.clone());
        assert_eq!(pipeline.shader().source(), shader.source());
        assert_eq!(pipeline.bind_groups().len(), 0);
    }

    #[test]
    fn test_pipeline_with_label() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let pipeline = ComputePipeline::new(shader).with_label("test_pipeline");
        assert_eq!(pipeline.label(), Some("test_pipeline"));
    }

    // ceil(1000 / 64) == 16 workgroups along X.
    #[test]
    fn test_optimal_dispatch_size() {
        let shader = WGSLShader::new("code").with_workgroup_size(64, 1, 1);
        let pipeline = ComputePipeline::new(shader);

        let (x, _, _) = pipeline.optimal_dispatch_size(1000);
        assert_eq!(x, 16);
    }

    #[test]
    fn test_bind_group_entry() {
        let entry = BindGroupEntry::new(0, ResourceType::StorageBuffer, ShaderStage::Compute);
        assert_eq!(entry.binding(), 0);
        assert_eq!(entry.resource_type(), ResourceType::StorageBuffer);
        assert_eq!(entry.visibility(), ShaderStage::Compute);
    }

    #[test]
    fn test_bind_group_layout() {
        let entry = BindGroupEntry::new(0, ResourceType::StorageBuffer, ShaderStage::Compute);
        let layout = BindGroupLayout::new().with_entry(entry);
        assert_eq!(layout.entries().len(), 1);
    }

    // 1-D sizes step 64 -> 256 -> min(max, 512); small 2-D/3-D domains
    // get the small-tile branches.
    #[test]
    fn test_workgroup_optimizer() {
        let optimizer = WorkgroupOptimizer::new(512);
        assert_eq!(optimizer.optimize_1d(50), 64);
        assert_eq!(optimizer.optimize_1d(100), 256);
        assert_eq!(optimizer.optimize_1d(500), 512);

        let (w, h) = optimizer.optimize_2d(10, 10);
        assert_eq!((w, h), (8, 8));

        let (w, h, d) = optimizer.optimize_3d(10, 10, 10);
        assert_eq!((w, h, d), (8, 8, 4));
    }

    #[test]
    fn test_default_workgroup_optimizer() {
        let optimizer = WorkgroupOptimizer::default();
        assert_eq!(optimizer.max_workgroup_size(), 256);
    }

    // Miss builds and stores a pipeline; clear empties the cache.
    #[test]
    fn test_pipeline_cache() {
        let mut cache = PipelineCache::new(10);
        assert_eq!(cache.size(), 0);

        let shader = WGSLShader::new("@compute fn main() {}");
        let _pipeline = cache.get_or_create(&shader, |s| ComputePipeline::new(s.clone()));
        assert_eq!(cache.size(), 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    // Canned shaders contain their expected symbols and validate.
    #[test]
    fn test_shader_templates() {
        let add_shader = shaders::elementwise_add();
        assert!(add_shader.source().contains("input_a"));
        assert!(add_shader.validate().is_ok());

        let mul_shader = shaders::elementwise_mul();
        assert!(mul_shader.source().contains("input_b"));
        assert!(mul_shader.validate().is_ok());

        let matmul_shader = shaders::matrix_mul();
        assert!(matmul_shader.source().contains("matrix_a"));
        assert_eq!(matmul_shader.workgroup_size(), (16, 16, 1));
        assert!(matmul_shader.validate().is_ok());

        let reduce_shader = shaders::reduce_sum();
        assert!(reduce_shader.source().contains("shared_data"));
        assert!(reduce_shader.validate().is_ok());
    }

    // Smoke checks: all variants are constructible.
    #[test]
    fn test_resource_types() {
        let _storage = ResourceType::StorageBuffer;
        let _uniform = ResourceType::UniformBuffer;
        let _readonly = ResourceType::ReadOnlyStorageBuffer;
        let _texture = ResourceType::Texture;
        let _sampler = ResourceType::Sampler;
    }

    #[test]
    fn test_shader_stages() {
        let _vertex = ShaderStage::Vertex;
        let _fragment = ShaderStage::Fragment;
        let _compute = ShaderStage::Compute;
        let _all = ShaderStage::All;
    }

    #[test]
    fn test_buffer_usage_types() {
        let _storage = BufferUsage::Storage;
        let _uniform = BufferUsage::Uniform;
        let _staging = BufferUsage::Staging;
        let _vertex = BufferUsage::Vertex;
        let _index = BufferUsage::Index;
    }

    // Display strings match the impl in ShaderError's Display.
    #[test]
    fn test_shader_error_display() {
        let err = ShaderError::EmptyShader;
        assert_eq!(format!("{}", err), "Shader source is empty");

        let err = ShaderError::MissingComputeAttribute;
        assert_eq!(format!("{}", err), "Missing @compute attribute");

        let err = ShaderError::SyntaxError("test".to_string());
        assert_eq!(format!("{}", err), "Syntax error: test");

        let err = ShaderError::CompilationFailed("test".to_string());
        assert_eq!(format!("{}", err), "Compilation failed: test");
    }

    #[test]
    fn test_pipeline_with_bind_group() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let layout = BindGroupLayout::new();
        let pipeline = ComputePipeline::new(shader).with_bind_group(layout);
        assert_eq!(pipeline.bind_groups().len(), 1);
    }

    #[test]
    fn test_bind_group_with_label() {
        let layout = BindGroupLayout::new().with_label("test_layout");
        assert_eq!(layout.label(), Some("test_layout"));
    }

    #[test]
    fn test_default_bind_group_layout() {
        let layout = BindGroupLayout::default();
        assert_eq!(layout.entries().len(), 0);
    }

    #[test]
    fn test_cache_max_size() {
        let cache = PipelineCache::new(5);
        assert_eq!(cache.max_size(), 5);
    }
}