Skip to main content

torsh_core/
webgpu.rs

1// Copyright (c) 2025 ToRSh Contributors
2//
3// WebGPU Compute Shader Integration
4//
5// This module provides abstractions for WebGPU compute shaders, enabling
6// high-performance tensor operations in web browsers and native applications.
7//
8// # Key Features
9//
10// - **WGSL Shader Compilation**: Compile and manage WebGPU Shading Language shaders
11// - **Compute Pipeline Management**: Efficient pipeline creation and caching
12// - **Buffer Management**: Optimized GPU buffer allocation and transfer
13// - **Workgroup Optimization**: Automatic workgroup size calculation
14// - **Cross-Platform**: Works in browsers (via wasm) and native applications
15//
16// # Design Principles
17//
18// 1. **Zero-Copy Transfer**: Minimize data transfer between CPU and GPU
19// 2. **Pipeline Caching**: Reuse compiled pipelines for performance
20// 3. **Memory Efficiency**: Efficient buffer pooling and management
21// 4. **Shader Composition**: Modular shader building blocks
22//
23// # Examples
24//
25// ```rust
26// use torsh_core::webgpu::{WGSLShader, ComputePipeline, GPUBuffer};
27//
28// // Define a WGSL compute shader
29// let shader = WGSLShader::new(r#"
30//     @compute @workgroup_size(64)
31//     fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
32//         // Compute shader code
33//     }
34// "#);
35//
36// // Create a compute pipeline
37// let pipeline = ComputePipeline::new(shader);
38//
// // Allocate GPU buffers
// let data = vec![0.0f32; 1024];
// let input_buffer = GPUBuffer::storage(data.len());
// let output_buffer = GPUBuffer::storage(data.len());
42// ```
43
44use core::fmt;
45
/// WebGPU compute shader in WGSL (WebGPU Shading Language)
///
/// Bundles WGSL source text with the entry-point name and the workgroup
/// dimensions to use when the shader is dispatched.
#[derive(Debug, Clone)]
pub struct WGSLShader {
    /// Shader source code in WGSL
    source: String,
    /// Entry point function name
    entry_point: String,
    /// Workgroup size (x, y, z)
    workgroup_size: (u32, u32, u32),
}

impl WGSLShader {
    /// Create a shader whose entry point is the conventional `main`.
    pub fn new(source: impl Into<String>) -> Self {
        Self::with_entry_point(source, "main")
    }

    /// Create a shader with a caller-supplied entry point.
    pub fn with_entry_point(source: impl Into<String>, entry_point: impl Into<String>) -> Self {
        Self {
            source: source.into(),
            entry_point: entry_point.into(),
            // 64 threads along x is a conservative default that maps well to
            // common GPU warp/wavefront sizes.
            workgroup_size: (64, 1, 1),
        }
    }

    /// Override the workgroup dimensions (builder style).
    pub fn with_workgroup_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroup_size = (x, y, z);
        self
    }

    /// Borrow the WGSL source text.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// Name of the entry-point function.
    pub fn entry_point(&self) -> &str {
        &self.entry_point
    }

    /// Workgroup dimensions as `(x, y, z)`.
    pub fn workgroup_size(&self) -> (u32, u32, u32) {
        self.workgroup_size
    }

    /// Cheap structural validation: the source must be non-empty and contain
    /// an `@compute` attribute. This is not a full WGSL parse.
    pub fn validate(&self) -> Result<(), ShaderError> {
        if self.source.is_empty() {
            Err(ShaderError::EmptyShader)
        } else if !self.source.contains("@compute") {
            Err(ShaderError::MissingComputeAttribute)
        } else {
            Ok(())
        }
    }
}

/// Errors produced while validating or compiling a shader.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShaderError {
    /// Shader source is empty
    EmptyShader,
    /// Missing @compute attribute
    MissingComputeAttribute,
    /// Syntax error in shader
    SyntaxError(String),
    /// Compilation failed
    CompilationFailed(String),
}

impl fmt::Display for ShaderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyShader => f.write_str("Shader source is empty"),
            Self::MissingComputeAttribute => f.write_str("Missing @compute attribute"),
            Self::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
            Self::CompilationFailed(msg) => write!(f, "Compilation failed: {}", msg),
        }
    }
}
135
/// GPU buffer types for WebGPU
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferUsage {
    /// Storage buffer (read/write)
    Storage,
    /// Uniform buffer (read-only, small data)
    Uniform,
    /// Staging buffer (CPU-GPU transfer)
    Staging,
    /// Vertex buffer
    Vertex,
    /// Index buffer
    Index,
}

/// GPU buffer descriptor
///
/// Captures the size, usage class, mapping state, and optional debug label
/// of a buffer to be allocated on the device.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    /// Buffer size in bytes
    size: usize,
    /// Buffer usage
    usage: BufferUsage,
    /// Whether buffer is mapped for CPU access
    mapped: bool,
    /// Buffer label for debugging
    label: Option<String>,
}

impl GPUBuffer {
    /// Shared constructor backing the usage-specific helpers below.
    fn describe(size: usize, usage: BufferUsage, mapped: bool) -> Self {
        Self {
            size,
            usage,
            mapped,
            label: None,
        }
    }

    /// Describe a read/write storage buffer of `size` bytes.
    pub fn storage(size: usize) -> Self {
        Self::describe(size, BufferUsage::Storage, false)
    }

    /// Describe a read-only uniform buffer of `size` bytes.
    pub fn uniform(size: usize) -> Self {
        Self::describe(size, BufferUsage::Uniform, false)
    }

    /// Describe a staging buffer of `size` bytes. Staging buffers start out
    /// mapped so the CPU can fill them before a transfer.
    pub fn staging(size: usize) -> Self {
        Self::describe(size, BufferUsage::Staging, true)
    }

    /// Attach a debug label (builder style).
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Size in bytes.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Usage class of this buffer.
    pub fn usage(&self) -> BufferUsage {
        self.usage
    }

    /// Whether the buffer is currently mapped for CPU access.
    pub fn is_mapped(&self) -> bool {
        self.mapped
    }

    /// Debug label, if one was set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}
223
224/// Compute pipeline for executing shaders
225///
226/// Represents a compiled compute pipeline that can be executed on the GPU.
227#[derive(Debug, Clone)]
228pub struct ComputePipeline {
229    /// Associated shader
230    shader: WGSLShader,
231    /// Bind group layouts
232    bind_groups: Vec<BindGroupLayout>,
233    /// Pipeline label
234    label: Option<String>,
235}
236
237impl ComputePipeline {
238    /// Create a new compute pipeline
239    pub fn new(shader: WGSLShader) -> Self {
240        Self {
241            shader,
242            bind_groups: Vec::new(),
243            label: None,
244        }
245    }
246
247    /// Add a bind group layout
248    pub fn with_bind_group(mut self, layout: BindGroupLayout) -> Self {
249        self.bind_groups.push(layout);
250        self
251    }
252
253    /// Set pipeline label
254    pub fn with_label(mut self, label: impl Into<String>) -> Self {
255        self.label = Some(label.into());
256        self
257    }
258
259    /// Get shader
260    pub fn shader(&self) -> &WGSLShader {
261        &self.shader
262    }
263
264    /// Get bind groups
265    pub fn bind_groups(&self) -> &[BindGroupLayout] {
266        &self.bind_groups
267    }
268
269    /// Get pipeline label
270    pub fn label(&self) -> Option<&str> {
271        self.label.as_deref()
272    }
273
274    /// Calculate optimal dispatch size for a given data size
275    pub fn optimal_dispatch_size(&self, data_size: usize) -> (u32, u32, u32) {
276        let (wg_x, wg_y, wg_z) = self.shader.workgroup_size();
277        let workgroup_count = (data_size as u32 + wg_x - 1) / wg_x;
278        (workgroup_count, wg_y, wg_z)
279    }
280}
281
/// Bind group layout entry
///
/// Describes one binding slot in a bind group: its index, the kind of
/// resource bound there, and which shader stages may see it.
#[derive(Debug, Clone)]
pub struct BindGroupEntry {
    /// Binding index
    binding: u32,
    /// Resource type
    resource_type: ResourceType,
    /// Shader visibility
    visibility: ShaderStage,
}

impl BindGroupEntry {
    /// Describe a binding slot.
    pub fn new(binding: u32, resource_type: ResourceType, visibility: ShaderStage) -> Self {
        Self {
            binding,
            resource_type,
            visibility,
        }
    }

    /// Binding index within the group.
    pub fn binding(&self) -> u32 {
        self.binding
    }

    /// Kind of resource bound at this slot.
    pub fn resource_type(&self) -> ResourceType {
        self.resource_type
    }

    /// Shader stages that can access this binding.
    pub fn visibility(&self) -> ShaderStage {
        self.visibility
    }
}

/// Bind group layout
///
/// An ordered collection of [`BindGroupEntry`] values plus an optional
/// debug label.
#[derive(Debug, Clone, Default)]
pub struct BindGroupLayout {
    /// Entries in this bind group
    entries: Vec<BindGroupEntry>,
    /// Layout label
    label: Option<String>,
}

impl BindGroupLayout {
    /// Create an empty layout (equivalent to `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Append an entry (builder style).
    pub fn with_entry(mut self, entry: BindGroupEntry) -> Self {
        self.entries.push(entry);
        self
    }

    /// Set a debug label (builder style).
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Entries in the order they were added.
    pub fn entries(&self) -> &[BindGroupEntry] {
        &self.entries
    }

    /// Debug label, if set.
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}

/// Resource types for bind group entries
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResourceType {
    /// Storage buffer (read/write)
    StorageBuffer,
    /// Uniform buffer (read-only)
    UniformBuffer,
    /// Read-only storage buffer
    ReadOnlyStorageBuffer,
    /// Texture (2D/3D)
    Texture,
    /// Sampler
    Sampler,
}

/// Shader stage visibility
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShaderStage {
    /// Vertex shader
    Vertex,
    /// Fragment shader
    Fragment,
    /// Compute shader
    Compute,
    /// All stages
    All,
}
397
/// Workgroup size optimizer
///
/// Picks workgroup sizes for 1D/2D/3D dispatches based on data extent and
/// the device's maximum workgroup size.
#[derive(Debug, Clone, Copy)]
pub struct WorkgroupOptimizer {
    /// Maximum workgroup size (device-dependent)
    max_workgroup_size: u32,
    /// Preferred workgroup size for 1D operations
    preferred_1d: u32,
    /// Preferred workgroup size for 2D operations
    preferred_2d: (u32, u32),
}

impl WorkgroupOptimizer {
    /// Create an optimizer for a device with the given maximum workgroup size.
    pub fn new(max_workgroup_size: u32) -> Self {
        Self {
            max_workgroup_size,
            preferred_1d: 256,
            preferred_2d: (16, 16),
        }
    }

    /// Calculate the workgroup size for 1D data.
    ///
    /// Small inputs get a small group; larger inputs scale up toward the
    /// device limit. Every branch is clamped to `max_workgroup_size` — the
    /// previous version could return 64 or 256 even when the device limit
    /// was lower, producing an invalid workgroup size.
    pub fn optimize_1d(&self, data_size: usize) -> u32 {
        let size = data_size as u32;
        let preferred = if size <= 64 {
            64
        } else if size <= self.preferred_1d {
            self.preferred_1d
        } else {
            // Cap at 512 even on devices that allow larger groups.
            512
        };
        preferred.min(self.max_workgroup_size)
    }

    /// Calculate the workgroup size for 2D data.
    ///
    /// NOTE(review): assumes the device allows at least 16*16 = 256
    /// invocations per workgroup (true for the WebGPU default limits);
    /// confirm before using with lower device limits.
    pub fn optimize_2d(&self, width: usize, height: usize) -> (u32, u32) {
        let (w, h) = (width as u32, height as u32);
        if w <= 16 && h <= 16 {
            (8, 8)
        } else if w <= 32 && h <= 32 {
            (16, 16)
        } else {
            self.preferred_2d
        }
    }

    /// Calculate the workgroup size for 3D data (at most 8*8*4 = 256
    /// invocations per workgroup).
    pub fn optimize_3d(&self, width: usize, height: usize, depth: usize) -> (u32, u32, u32) {
        let (w, h, d) = (width as u32, height as u32, depth as u32);
        if w <= 8 && h <= 8 && d <= 8 {
            (4, 4, 4)
        } else {
            (8, 8, 4)
        }
    }

    /// Device maximum workgroup size this optimizer was built with.
    pub fn max_workgroup_size(&self) -> u32 {
        self.max_workgroup_size
    }
}

impl Default for WorkgroupOptimizer {
    /// 256 matches the WebGPU default `maxComputeInvocationsPerWorkgroup`.
    fn default() -> Self {
        Self::new(256)
    }
}
467
468/// Pipeline cache for reusing compiled pipelines
469///
470/// Caches compiled compute pipelines to avoid recompilation.
471#[derive(Debug, Clone)]
472pub struct PipelineCache {
473    /// Cached pipelines (shader source hash -> pipeline)
474    cache: Vec<(u64, ComputePipeline)>,
475    /// Maximum cache size
476    max_size: usize,
477}
478
479impl PipelineCache {
480    /// Create a new pipeline cache
481    pub fn new(max_size: usize) -> Self {
482        Self {
483            cache: Vec::new(),
484            max_size,
485        }
486    }
487
488    /// Get or create a pipeline
489    pub fn get_or_create<F>(&mut self, shader: &WGSLShader, create_fn: F) -> &ComputePipeline
490    where
491        F: FnOnce(&WGSLShader) -> ComputePipeline,
492    {
493        let hash = self.hash_shader(shader);
494
495        // Check if already cached
496        if let Some(index) = self.cache.iter().position(|(h, _)| *h == hash) {
497            return &self.cache[index].1;
498        }
499
500        // Create new pipeline
501        let pipeline = create_fn(shader);
502        self.cache.push((hash, pipeline));
503
504        // Evict oldest if cache is full
505        if self.cache.len() > self.max_size {
506            self.cache.remove(0);
507        }
508
509        &self
510            .cache
511            .last()
512            .expect("cache should have at least one entry after push")
513            .1
514    }
515
516    /// Simple hash function for shader source
517    fn hash_shader(&self, shader: &WGSLShader) -> u64 {
518        // Simple hash based on source length and first few characters
519        let src = shader.source();
520        let mut hash = src.len() as u64;
521        for (i, byte) in src.bytes().take(16).enumerate() {
522            hash = hash
523                .wrapping_mul(31)
524                .wrapping_add(byte as u64 * (i as u64 + 1));
525        }
526        hash
527    }
528
529    /// Clear the cache
530    pub fn clear(&mut self) {
531        self.cache.clear();
532    }
533
534    /// Get cache size
535    pub fn size(&self) -> usize {
536        self.cache.len()
537    }
538
539    /// Get maximum cache size
540    pub fn max_size(&self) -> usize {
541        self.max_size
542    }
543}
544
545/// Common WGSL shader templates
546pub mod shaders {
547    use super::*;
548
549    /// Element-wise addition shader
550    pub fn elementwise_add() -> WGSLShader {
551        WGSLShader::new(
552            r#"
553@group(0) @binding(0) var<storage, read> input_a: array<f32>;
554@group(0) @binding(1) var<storage, read> input_b: array<f32>;
555@group(0) @binding(2) var<storage, read_write> output: array<f32>;
556
557@compute @workgroup_size(256)
558fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
559    let index = global_id.x;
560    if (index < arrayLength(&input_a)) {
561        output[index] = input_a[index] + input_b[index];
562    }
563}
564"#,
565        )
566    }
567
568    /// Element-wise multiplication shader
569    pub fn elementwise_mul() -> WGSLShader {
570        WGSLShader::new(
571            r#"
572@group(0) @binding(0) var<storage, read> input_a: array<f32>;
573@group(0) @binding(1) var<storage, read> input_b: array<f32>;
574@group(0) @binding(2) var<storage, read_write> output: array<f32>;
575
576@compute @workgroup_size(256)
577fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
578    let index = global_id.x;
579    if (index < arrayLength(&input_a)) {
580        output[index] = input_a[index] * input_b[index];
581    }
582}
583"#,
584        )
585    }
586
587    /// Matrix multiplication shader (simple version)
588    pub fn matrix_mul() -> WGSLShader {
589        WGSLShader::new(
590            r#"
591struct Dimensions {
592    m: u32,
593    n: u32,
594    k: u32,
595}
596
597@group(0) @binding(0) var<storage, read> matrix_a: array<f32>;
598@group(0) @binding(1) var<storage, read> matrix_b: array<f32>;
599@group(0) @binding(2) var<storage, read_write> output: array<f32>;
600@group(0) @binding(3) var<uniform> dims: Dimensions;
601
602@compute @workgroup_size(16, 16)
603fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
604    let row = global_id.y;
605    let col = global_id.x;
606
607    if (row >= dims.m || col >= dims.n) {
608        return;
609    }
610
611    var sum = 0.0;
612    for (var i = 0u; i < dims.k; i++) {
613        let a_index = row * dims.k + i;
614        let b_index = i * dims.n + col;
615        sum += matrix_a[a_index] * matrix_b[b_index];
616    }
617
618    let out_index = row * dims.n + col;
619    output[out_index] = sum;
620}
621"#,
622        )
623        .with_workgroup_size(16, 16, 1)
624    }
625
626    /// Reduction (sum) shader
627    pub fn reduce_sum() -> WGSLShader {
628        WGSLShader::new(
629            r#"
630@group(0) @binding(0) var<storage, read> input: array<f32>;
631@group(0) @binding(1) var<storage, read_write> output: array<f32>;
632
633var<workgroup> shared_data: array<f32, 256>;
634
635@compute @workgroup_size(256)
636fn main(
637    @builtin(global_invocation_id) global_id: vec3<u32>,
638    @builtin(local_invocation_id) local_id: vec3<u32>,
639) {
640    let tid = local_id.x;
641    let index = global_id.x;
642
643    // Load data into shared memory
644    if (index < arrayLength(&input)) {
645        shared_data[tid] = input[index];
646    } else {
647        shared_data[tid] = 0.0;
648    }
649
650    workgroupBarrier();
651
652    // Reduce in shared memory
653    for (var s = 128u; s > 0u; s >>= 1u) {
654        if (tid < s) {
655            shared_data[tid] += shared_data[tid + s];
656        }
657        workgroupBarrier();
658    }
659
660    // Write result
661    if (tid == 0u) {
662        output[global_id.x / 256u] = shared_data[0];
663    }
664}
665"#,
666        )
667    }
668}
669
#[cfg(test)]
mod tests {
    // Unit tests for this module: shader construction and validation,
    // buffer/pipeline descriptors, bind groups, workgroup optimization,
    // the pipeline cache, and the bundled WGSL shader templates.
    use super::*;

    // --- WGSLShader ---

    #[test]
    fn test_wgsl_shader_creation() {
        // `new` defaults to entry point "main" and workgroup size (64, 1, 1).
        let shader = WGSLShader::new("@compute @workgroup_size(64) fn main() {}");
        assert!(shader.source().contains("@compute"));
        assert_eq!(shader.entry_point(), "main");
        assert_eq!(shader.workgroup_size(), (64, 1, 1));
    }

    #[test]
    fn test_shader_with_entry_point() {
        let shader = WGSLShader::with_entry_point("code", "custom_entry");
        assert_eq!(shader.entry_point(), "custom_entry");
    }

    #[test]
    fn test_shader_workgroup_size() {
        let shader = WGSLShader::new("code").with_workgroup_size(16, 16, 1);
        assert_eq!(shader.workgroup_size(), (16, 16, 1));
    }

    #[test]
    fn test_shader_validation() {
        // validate() only checks non-emptiness and presence of "@compute".
        let shader = WGSLShader::new("@compute fn main() {}");
        assert!(shader.validate().is_ok());

        let empty_shader = WGSLShader::new("");
        assert_eq!(empty_shader.validate(), Err(ShaderError::EmptyShader));

        let invalid_shader = WGSLShader::new("fn main() {}");
        assert_eq!(
            invalid_shader.validate(),
            Err(ShaderError::MissingComputeAttribute)
        );
    }

    // --- GPUBuffer ---

    #[test]
    fn test_gpu_buffer_creation() {
        // Only staging buffers start out mapped.
        let storage = GPUBuffer::storage(1024);
        assert_eq!(storage.size(), 1024);
        assert_eq!(storage.usage(), BufferUsage::Storage);
        assert!(!storage.is_mapped());

        let uniform = GPUBuffer::uniform(256);
        assert_eq!(uniform.usage(), BufferUsage::Uniform);

        let staging = GPUBuffer::staging(512);
        assert_eq!(staging.usage(), BufferUsage::Staging);
        assert!(staging.is_mapped());
    }

    #[test]
    fn test_buffer_with_label() {
        let buffer = GPUBuffer::storage(1024).with_label("test_buffer");
        assert_eq!(buffer.label(), Some("test_buffer"));
    }

    // --- ComputePipeline ---

    #[test]
    fn test_compute_pipeline() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let pipeline = ComputePipeline::new(shader.clone());
        assert_eq!(pipeline.shader().source(), shader.source());
        assert_eq!(pipeline.bind_groups().len(), 0);
    }

    #[test]
    fn test_pipeline_with_label() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let pipeline = ComputePipeline::new(shader).with_label("test_pipeline");
        assert_eq!(pipeline.label(), Some("test_pipeline"));
    }

    #[test]
    fn test_optimal_dispatch_size() {
        let shader = WGSLShader::new("code").with_workgroup_size(64, 1, 1);
        let pipeline = ComputePipeline::new(shader);

        let (x, _, _) = pipeline.optimal_dispatch_size(1000);
        assert_eq!(x, 16); // ceil(1000 / 64) = 16
    }

    // --- Bind groups ---

    #[test]
    fn test_bind_group_entry() {
        let entry = BindGroupEntry::new(0, ResourceType::StorageBuffer, ShaderStage::Compute);
        assert_eq!(entry.binding(), 0);
        assert_eq!(entry.resource_type(), ResourceType::StorageBuffer);
        assert_eq!(entry.visibility(), ShaderStage::Compute);
    }

    #[test]
    fn test_bind_group_layout() {
        let entry = BindGroupEntry::new(0, ResourceType::StorageBuffer, ShaderStage::Compute);
        let layout = BindGroupLayout::new().with_entry(entry);
        assert_eq!(layout.entries().len(), 1);
    }

    // --- WorkgroupOptimizer ---

    #[test]
    fn test_workgroup_optimizer() {
        let optimizer = WorkgroupOptimizer::new(512);
        assert_eq!(optimizer.optimize_1d(50), 64); // size <= 64 -> 64
        assert_eq!(optimizer.optimize_1d(100), 256); // size <= 256 -> 256
        assert_eq!(optimizer.optimize_1d(500), 512); // size > 256 -> min(max, 512)

        let (w, h) = optimizer.optimize_2d(10, 10);
        assert_eq!((w, h), (8, 8));

        let (w, h, d) = optimizer.optimize_3d(10, 10, 10);
        assert_eq!((w, h, d), (8, 8, 4));
    }

    #[test]
    fn test_default_workgroup_optimizer() {
        let optimizer = WorkgroupOptimizer::default();
        assert_eq!(optimizer.max_workgroup_size(), 256);
    }

    // --- PipelineCache ---

    #[test]
    fn test_pipeline_cache() {
        let mut cache = PipelineCache::new(10);
        assert_eq!(cache.size(), 0);

        let shader = WGSLShader::new("@compute fn main() {}");
        let _pipeline = cache.get_or_create(&shader, |s| ComputePipeline::new(s.clone()));
        assert_eq!(cache.size(), 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    // --- Shader templates ---

    #[test]
    fn test_shader_templates() {
        // Each template should contain its expected bindings and pass the
        // basic structural validation.
        let add_shader = shaders::elementwise_add();
        assert!(add_shader.source().contains("input_a"));
        assert!(add_shader.validate().is_ok());

        let mul_shader = shaders::elementwise_mul();
        assert!(mul_shader.source().contains("input_b"));
        assert!(mul_shader.validate().is_ok());

        let matmul_shader = shaders::matrix_mul();
        assert!(matmul_shader.source().contains("matrix_a"));
        assert_eq!(matmul_shader.workgroup_size(), (16, 16, 1));
        assert!(matmul_shader.validate().is_ok());

        let reduce_shader = shaders::reduce_sum();
        assert!(reduce_shader.source().contains("shared_data"));
        assert!(reduce_shader.validate().is_ok());
    }

    // --- Enum smoke tests (construction only) ---

    #[test]
    fn test_resource_types() {
        let _storage = ResourceType::StorageBuffer;
        let _uniform = ResourceType::UniformBuffer;
        let _readonly = ResourceType::ReadOnlyStorageBuffer;
        let _texture = ResourceType::Texture;
        let _sampler = ResourceType::Sampler;
    }

    #[test]
    fn test_shader_stages() {
        let _vertex = ShaderStage::Vertex;
        let _fragment = ShaderStage::Fragment;
        let _compute = ShaderStage::Compute;
        let _all = ShaderStage::All;
    }

    #[test]
    fn test_buffer_usage_types() {
        let _storage = BufferUsage::Storage;
        let _uniform = BufferUsage::Uniform;
        let _staging = BufferUsage::Staging;
        let _vertex = BufferUsage::Vertex;
        let _index = BufferUsage::Index;
    }

    #[test]
    fn test_shader_error_display() {
        let err = ShaderError::EmptyShader;
        assert_eq!(format!("{}", err), "Shader source is empty");

        let err = ShaderError::MissingComputeAttribute;
        assert_eq!(format!("{}", err), "Missing @compute attribute");

        let err = ShaderError::SyntaxError("test".to_string());
        assert_eq!(format!("{}", err), "Syntax error: test");

        let err = ShaderError::CompilationFailed("test".to_string());
        assert_eq!(format!("{}", err), "Compilation failed: test");
    }

    #[test]
    fn test_pipeline_with_bind_group() {
        let shader = WGSLShader::new("@compute fn main() {}");
        let layout = BindGroupLayout::new();
        let pipeline = ComputePipeline::new(shader).with_bind_group(layout);
        assert_eq!(pipeline.bind_groups().len(), 1);
    }

    #[test]
    fn test_bind_group_with_label() {
        let layout = BindGroupLayout::new().with_label("test_layout");
        assert_eq!(layout.label(), Some("test_layout"));
    }

    #[test]
    fn test_default_bind_group_layout() {
        let layout = BindGroupLayout::default();
        assert_eq!(layout.entries().len(), 0);
    }

    #[test]
    fn test_cache_max_size() {
        let cache = PipelineCache::new(5);
        assert_eq!(cache.max_size(), 5);
    }
}