wgpu_algorithms/sort/
pipeline.rs

1use crate::{common, context::Context};
2
3pub struct SortPipeline {
4    pub bind_group_layout: wgpu::BindGroupLayout,
5    pub reduce_pipeline: wgpu::ComputePipeline,
6    pub scatter_pipeline: wgpu::ComputePipeline,
7    pub vt: u32,
8    pub block_size: u32,
9}
10
11impl SortPipeline {
12    pub fn new(ctx: &Context) -> Self {
13        let bind_group_layout =
14            ctx.device
15                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
16                    label: Some("Fused Sort Layout"),
17                    entries: &[
18                        common::buffers::bind_entry(0, true, false),  // Input
19                        common::buffers::bind_entry(1, false, false), // Hist
20                        common::buffers::bind_entry(2, false, false), // Output
21                        common::buffers::bind_entry(3, false, true),  // Uniforms
22                    ],
23                });
24
25        let limits = ctx.device.limits();
26        let max_shared_mem = limits.max_compute_workgroup_storage_size;
27
28        let (vt, block_size) = if max_shared_mem >= 32768 {
29            (8, 256) // M3 / Desktop
30        } else {
31            (4, 128) // Mobile
32        };
33
34        let config = common::shader::ShaderConfig { vt, block_size };
35
36        common::shader::create_compute_pipeline(
37            &ctx.device,
38            &bind_group_layout,
39            include_str!("sort.wgsl"),
40            &format!("Reduce VT{}", vt),
41            "main_reduce",
42            Some(&config),
43        );
44
45        common::shader::create_compute_pipeline(
46            &ctx.device,
47            &bind_group_layout,
48            include_str!("sort.wgsl"),
49            &format!("Scatter VT{}", vt),
50            "main_scatter",
51            Some(&config),
52        );
53
54        let raw_shader = include_str!("sort.wgsl");
55        let final_source = raw_shader
56            .replace("{{VT}}", &vt.to_string())
57            .replace("{{BLOCK_SIZE}}", &block_size.to_string());
58
59        let shader_module = ctx
60            .device
61            .create_shader_module(wgpu::ShaderModuleDescriptor {
62                label: Some("Fused Sort Shader"),
63                source: wgpu::ShaderSource::Wgsl(final_source.into()),
64            });
65
66        let pipeline_layout = ctx
67            .device
68            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
69                label: Some("Fused Pipeline Layout"),
70                bind_group_layouts: &[&bind_group_layout],
71                immediate_size: 0,
72            });
73
74        let reduce_pipeline =
75            ctx.device
76                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
77                    label: Some("Reduce Pipeline"),
78                    layout: Some(&pipeline_layout),
79                    module: &shader_module,
80                    entry_point: Some("main_reduce"),
81                    compilation_options: Default::default(),
82                    cache: None,
83                });
84
85        let scatter_pipeline =
86            ctx.device
87                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
88                    label: Some("Scatter Pipeline"),
89                    layout: Some(&pipeline_layout),
90                    module: &shader_module,
91                    entry_point: Some("main_scatter"),
92                    compilation_options: Default::default(),
93                    cache: None,
94                });
95
96        Self {
97            bind_group_layout,
98            reduce_pipeline,
99            scatter_pipeline,
100            vt,
101            block_size,
102        }
103    }
104}