1use super::{layout::AutoLayout, module::ShaderModule, Shader};
2use crate::{
3 buffer::{
4 index::{DefaultIndex, Index},
5 vertex::{DefaultVertex, Vertex},
6 },
7 label,
8 target::Target,
9};
10use std::marker::PhantomData;
11use wgpu::{
12 BlendState, ColorTargetState, ColorWrites, FragmentState, FrontFace, MultisampleState,
13 PipelineLayoutDescriptor, PolygonMode, PrimitiveState, PrimitiveTopology,
14 RenderPipelineDescriptor, TextureFormat, VertexState,
15};
16
17pub struct ShaderBuilder<
20 's,
21 V = DefaultVertex,
22 I = DefaultIndex,
23 const VS: bool = false,
24 const FS: bool = false,
25 const FMT: bool = false,
26> {
27 pub(crate) vert: Option<(&'s ShaderModule<'s>, &'s str)>,
28 pub(crate) frag: Option<(&'s ShaderModule<'s>, &'s str)>,
29 format: Option<TextureFormat>,
30 layout: Option<PipelineLayoutDescriptor<'s>>,
31 topology: PrimitiveTopology,
32 label: Option<&'s str>,
33
34 _p: PhantomData<(V, I)>,
35}
36
37impl<'s, V, I, const VS: bool, const FS: bool, const FMT: bool> Default
40 for ShaderBuilder<'s, V, I, VS, FS, FMT>
41{
42 fn default() -> Self {
43 Self {
44 vert: None,
45 frag: None,
46 format: None,
47 layout: None,
48 topology: PrimitiveTopology::TriangleStrip,
49 label: label!(),
50
51 _p: PhantomData::default(),
52 }
53 }
54}
55
56impl<'s, V, I, const VS: bool, const FS: bool, const FMT: bool>
57 ShaderBuilder<'s, V, I, VS, FS, FMT>
58{
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 fn pass<Vn, In, const VSN: bool, const FSN: bool, const FMTN: bool>(
64 self,
65 ) -> ShaderBuilder<'s, Vn, In, VSN, FSN, FMTN> {
66 ShaderBuilder {
67 vert: self.vert,
68 frag: self.frag,
69 format: self.format,
70 layout: self.layout,
71 topology: self.topology,
72 label: self.label,
73
74 _p: PhantomData::default(),
75 }
76 }
77
78 pub fn with_vertex<'n: 's>(
79 self,
80 module: &'n ShaderModule,
81 entry: &'n str,
82 ) -> ShaderBuilder<'s, V, I, true, FS, FMT> {
83 ShaderBuilder {
84 vert: Some((module, entry)),
85 ..self.pass()
86 }
87 }
88
89 pub fn with_fragment<'n: 's>(
90 self,
91 module: &'n ShaderModule,
92 entry: &'n str,
93 ) -> ShaderBuilder<'s, V, I, VS, true, FMT> {
94 ShaderBuilder {
95 frag: Some((module, entry)),
96 ..self.pass()
97 }
98 }
99
100 pub fn with_format(self, format: TextureFormat) -> ShaderBuilder<'s, V, I, VS, FS, true> {
101 ShaderBuilder {
102 format: Some(format),
103 ..self.pass()
104 }
105 }
106
107 pub fn with_vertex_format<Vn>(self) -> ShaderBuilder<'s, Vn, I, VS, FS, true> {
108 ShaderBuilder { ..self.pass() }
109 }
110
111 pub fn with_index_format<In>(self) -> ShaderBuilder<'s, V, In, VS, FS, true> {
112 ShaderBuilder { ..self.pass() }
113 }
114
115 pub fn with_topology(mut self, topology: PrimitiveTopology) -> Self {
116 self.topology = topology;
117 self
118 }
119
120 pub fn with_label<'n: 's>(mut self, label: Option<&'n str>) -> Self {
121 self.label = label;
122 self
123 }
124
125 pub fn with_baked_layout<'l: 's>(
126 self,
127 layout: PipelineLayoutDescriptor<'l>,
128 ) -> ShaderBuilder<'s, V, I, VS, FS, FMT> {
129 ShaderBuilder {
130 layout: Some(layout),
131 ..self.pass()
132 }
133 }
134}
135
136impl<'s, V, I> ShaderBuilder<'s, V, I, true, true, true>
137where
138 V: Vertex,
139 I: Index,
140{
141 pub fn build(self, target: &Target) -> Shader<V, I> {
142 let (vert_mod, vert_entry) = self.vert.unwrap();
143 let (frag_mod, frag_entry) = self.frag.unwrap();
144 let format = self.format.unwrap();
145
146 let layout = match self.layout {
147 Some(l) => target.device.create_pipeline_layout(&l),
148 None => {
149 let a = AutoLayout::new(target, (vert_mod, vert_entry), (frag_mod, frag_entry));
150 let a = a.get();
151 target.device.create_pipeline_layout(&a.get())
152 }
153 };
154
155 let strip_index_format = if let PrimitiveTopology::LineStrip
156 | PrimitiveTopology::TriangleStrip = self.topology
157 {
158 Some(I::FORMAT)
159 } else {
160 None
161 };
162
163 let pipeline = target
164 .device
165 .create_render_pipeline(&RenderPipelineDescriptor {
166 label: self.label,
167 layout: Some(&layout),
168 vertex: VertexState {
169 module: &vert_mod.inner,
170 entry_point: vert_entry,
171 buffers: V::LAYOUT,
172 },
173 primitive: PrimitiveState {
174 topology: self.topology,
175 strip_index_format,
176 front_face: FrontFace::Ccw,
177 cull_mode: None,
178 unclipped_depth: false,
179 polygon_mode: PolygonMode::Fill,
180 conservative: false,
181 },
182 depth_stencil: None,
183 multisample: MultisampleState::default(),
184 fragment: Some(FragmentState {
185 module: &frag_mod.inner,
186 entry_point: frag_entry,
187 targets: &[ColorTargetState {
188 format,
189 blend: Some(BlendState::ALPHA_BLENDING),
190 write_mask: ColorWrites::ALL,
191 }],
192 }),
193 multiview: None,
194 });
195
196 Shader {
197 pipeline,
198 format,
199
200 _p: PhantomData::default(),
201 }
202 }
203}