1use std::borrow::Cow;
2use std::fs;
3
4use wgpu::{
5 BindGroupLayout, BlendState, ColorTargetState, ColorWrites, DepthStencilState, Device, Face,
6 FragmentState, FrontFace, MultisampleState, PipelineLayoutDescriptor, PolygonMode,
7 PrimitiveState, PrimitiveTopology, RenderPipeline, RenderPipelineDescriptor, ShaderModule,
8 ShaderModuleDescriptor, ShaderSource, SurfaceConfiguration, TextureFormat, VertexBufferLayout,
9 VertexState,
10};
11
12pub struct RenderPipelineCreator<'a> {
14 device: &'a Device,
15 format: &'a TextureFormat,
16
17 shader: ShaderModule,
18 vertex_main: &'a str,
19 fragment_main: &'a str,
20
21 vertex_buffers: Vec<VertexBufferLayout<'a>>,
22 bind_groups: Vec<&'a BindGroupLayout>,
23
24 depth_stencil: Option<DepthStencilState>,
25
26 label: &'a str,
27
28 blend_state: BlendState,
29}
30
31impl<'a> RenderPipelineCreator<'a> {
32 pub fn from_shader_file(
34 path: &'a str,
35 device: &'a Device,
36 config: &'a SurfaceConfiguration,
37 ) -> RenderPipelineCreator<'a> {
38 Self::from_shader_code(
39 &fs::read_to_string(path)
40 .unwrap_or_else(|_| panic!("Could not find Shader-File at {}", path)),
41 device,
42 config,
43 )
44 }
45
46 pub fn from_shader_code(
48 shader_code: &str,
49 device: &'a Device,
50 config: &'a SurfaceConfiguration,
51 ) -> RenderPipelineCreator<'a> {
52 let shader = device.create_shader_module(ShaderModuleDescriptor {
53 label: Some("Render Pipeline Shader"),
54 source: ShaderSource::Wgsl(Cow::from(shader_code)),
55 });
56
57 RenderPipelineCreator {
58 device,
59 format: &config.format,
60 shader,
61 vertex_main: "vs_main",
62 fragment_main: "fs_main",
63
64 vertex_buffers: vec![],
65 bind_groups: vec![],
66
67 depth_stencil: None,
68
69 label: "Render Pipeline",
70 blend_state: BlendState::REPLACE,
71 }
72 }
73
74 pub fn add_vertex_buffer(mut self, layout: VertexBufferLayout<'a>) -> Self {
76 self.vertex_buffers.push(layout);
77
78 self
79 }
80
81 pub fn add_bind_group(mut self, layout: &'a BindGroupLayout) -> Self {
83 self.bind_groups.push(layout);
84
85 self
86 }
87
88 pub fn fragment_main(mut self, fn_name: &'a str) -> Self {
90 self.fragment_main = fn_name;
91 self
92 }
93
94 pub fn vertex_main(mut self, fn_name: &'a str) -> Self {
96 self.vertex_main = fn_name;
97 self
98 }
99
100 pub fn depth_stencil(mut self, depth_stencil: DepthStencilState) -> Self {
102 self.depth_stencil = Some(depth_stencil);
103 self
104 }
105
106 pub fn blend_state(mut self, blend_state: BlendState) -> Self {
108 self.blend_state = blend_state;
109 self
110 }
111
112 pub fn build(&self) -> RenderPipeline {
114 let render_pipeline_layout =
115 self.device
116 .create_pipeline_layout(&PipelineLayoutDescriptor {
117 label: Some(&(self.label.to_owned() + " Layout")),
118 bind_group_layouts: &self.bind_groups[..],
119 push_constant_ranges: &[],
120 });
121
122 self.device
123 .create_render_pipeline(&RenderPipelineDescriptor {
124 label: Some(self.label),
125 layout: Some(&render_pipeline_layout),
126 vertex: VertexState {
127 module: &self.shader,
128 entry_point: self.vertex_main,
129 buffers: &self.vertex_buffers[..],
130 },
131 fragment: Some(FragmentState {
132 module: &self.shader,
133 entry_point: self.fragment_main,
134 targets: &[Some(ColorTargetState {
135 format: self.format.to_owned(),
136 blend: Some(self.blend_state),
137 write_mask: ColorWrites::ALL,
138 })],
139 }),
140 primitive: PrimitiveState {
141 topology: PrimitiveTopology::TriangleList,
142 strip_index_format: None,
143 front_face: FrontFace::Ccw,
144 cull_mode: Some(Face::Back),
145 polygon_mode: PolygonMode::Fill,
146 unclipped_depth: false,
147 conservative: false,
148 },
149 depth_stencil: self.depth_stencil.to_owned(),
150 multisample: MultisampleState {
151 count: 1,
152 mask: !0,
153 alpha_to_coverage_enabled: false,
154 },
155 multiview: None,
156 })
157 }
158}