wgpu_algorithms/sort/
pipeline.rs1use crate::{common, context::Context};
2
3pub struct SortPipeline {
4 pub bind_group_layout: wgpu::BindGroupLayout,
5 pub reduce_pipeline: wgpu::ComputePipeline,
6 pub scatter_pipeline: wgpu::ComputePipeline,
7 pub vt: u32,
8 pub block_size: u32,
9}
10
11impl SortPipeline {
12 pub fn new(ctx: &Context) -> Self {
13 let bind_group_layout =
14 ctx.device
15 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
16 label: Some("Fused Sort Layout"),
17 entries: &[
18 common::buffers::bind_entry(0, true, false), common::buffers::bind_entry(1, false, false), common::buffers::bind_entry(2, false, false), common::buffers::bind_entry(3, false, true), ],
23 });
24
25 let limits = ctx.device.limits();
26 let max_shared_mem = limits.max_compute_workgroup_storage_size;
27
28 let (vt, block_size) = if max_shared_mem >= 32768 {
29 (8, 256) } else {
31 (4, 128) };
33
34 let config = common::shader::ShaderConfig { vt, block_size };
35
36 common::shader::create_compute_pipeline(
37 &ctx.device,
38 &bind_group_layout,
39 include_str!("sort.wgsl"),
40 &format!("Reduce VT{}", vt),
41 "main_reduce",
42 Some(&config),
43 );
44
45 common::shader::create_compute_pipeline(
46 &ctx.device,
47 &bind_group_layout,
48 include_str!("sort.wgsl"),
49 &format!("Scatter VT{}", vt),
50 "main_scatter",
51 Some(&config),
52 );
53
54 let raw_shader = include_str!("sort.wgsl");
55 let final_source = raw_shader
56 .replace("{{VT}}", &vt.to_string())
57 .replace("{{BLOCK_SIZE}}", &block_size.to_string());
58
59 let shader_module = ctx
60 .device
61 .create_shader_module(wgpu::ShaderModuleDescriptor {
62 label: Some("Fused Sort Shader"),
63 source: wgpu::ShaderSource::Wgsl(final_source.into()),
64 });
65
66 let pipeline_layout = ctx
67 .device
68 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
69 label: Some("Fused Pipeline Layout"),
70 bind_group_layouts: &[&bind_group_layout],
71 immediate_size: 0,
72 });
73
74 let reduce_pipeline =
75 ctx.device
76 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
77 label: Some("Reduce Pipeline"),
78 layout: Some(&pipeline_layout),
79 module: &shader_module,
80 entry_point: Some("main_reduce"),
81 compilation_options: Default::default(),
82 cache: None,
83 });
84
85 let scatter_pipeline =
86 ctx.device
87 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
88 label: Some("Scatter Pipeline"),
89 layout: Some(&pipeline_layout),
90 module: &shader_module,
91 entry_point: Some("main_scatter"),
92 compilation_options: Default::default(),
93 cache: None,
94 });
95
96 Self {
97 bind_group_layout,
98 reduce_pipeline,
99 scatter_pipeline,
100 vt,
101 block_size,
102 }
103 }
104}