srs2dge_core/shader/
builder.rs

1use super::{layout::AutoLayout, module::ShaderModule, Shader};
2use crate::{
3    buffer::{
4        index::{DefaultIndex, Index},
5        vertex::{DefaultVertex, Vertex},
6    },
7    label,
8    target::Target,
9};
10use std::marker::PhantomData;
11use wgpu::{
12    BlendState, ColorTargetState, ColorWrites, FragmentState, FrontFace, MultisampleState,
13    PipelineLayoutDescriptor, PolygonMode, PrimitiveState, PrimitiveTopology,
14    RenderPipelineDescriptor, TextureFormat, VertexState,
15};
16
17//
18
19pub struct ShaderBuilder<
20    's,
21    V = DefaultVertex,
22    I = DefaultIndex,
23    const VS: bool = false,
24    const FS: bool = false,
25    const FMT: bool = false,
26> {
27    pub(crate) vert: Option<(&'s ShaderModule<'s>, &'s str)>,
28    pub(crate) frag: Option<(&'s ShaderModule<'s>, &'s str)>,
29    format: Option<TextureFormat>,
30    layout: Option<PipelineLayoutDescriptor<'s>>,
31    topology: PrimitiveTopology,
32    label: Option<&'s str>,
33
34    _p: PhantomData<(V, I)>,
35}
36
37//
38
39impl<'s, V, I, const VS: bool, const FS: bool, const FMT: bool> Default
40    for ShaderBuilder<'s, V, I, VS, FS, FMT>
41{
42    fn default() -> Self {
43        Self {
44            vert: None,
45            frag: None,
46            format: None,
47            layout: None,
48            topology: PrimitiveTopology::TriangleStrip,
49            label: label!(),
50
51            _p: PhantomData::default(),
52        }
53    }
54}
55
56impl<'s, V, I, const VS: bool, const FS: bool, const FMT: bool>
57    ShaderBuilder<'s, V, I, VS, FS, FMT>
58{
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    fn pass<Vn, In, const VSN: bool, const FSN: bool, const FMTN: bool>(
64        self,
65    ) -> ShaderBuilder<'s, Vn, In, VSN, FSN, FMTN> {
66        ShaderBuilder {
67            vert: self.vert,
68            frag: self.frag,
69            format: self.format,
70            layout: self.layout,
71            topology: self.topology,
72            label: self.label,
73
74            _p: PhantomData::default(),
75        }
76    }
77
78    pub fn with_vertex<'n: 's>(
79        self,
80        module: &'n ShaderModule,
81        entry: &'n str,
82    ) -> ShaderBuilder<'s, V, I, true, FS, FMT> {
83        ShaderBuilder {
84            vert: Some((module, entry)),
85            ..self.pass()
86        }
87    }
88
89    pub fn with_fragment<'n: 's>(
90        self,
91        module: &'n ShaderModule,
92        entry: &'n str,
93    ) -> ShaderBuilder<'s, V, I, VS, true, FMT> {
94        ShaderBuilder {
95            frag: Some((module, entry)),
96            ..self.pass()
97        }
98    }
99
100    pub fn with_format(self, format: TextureFormat) -> ShaderBuilder<'s, V, I, VS, FS, true> {
101        ShaderBuilder {
102            format: Some(format),
103            ..self.pass()
104        }
105    }
106
107    pub fn with_vertex_format<Vn>(self) -> ShaderBuilder<'s, Vn, I, VS, FS, true> {
108        ShaderBuilder { ..self.pass() }
109    }
110
111    pub fn with_index_format<In>(self) -> ShaderBuilder<'s, V, In, VS, FS, true> {
112        ShaderBuilder { ..self.pass() }
113    }
114
115    pub fn with_topology(mut self, topology: PrimitiveTopology) -> Self {
116        self.topology = topology;
117        self
118    }
119
120    pub fn with_label<'n: 's>(mut self, label: Option<&'n str>) -> Self {
121        self.label = label;
122        self
123    }
124
125    pub fn with_baked_layout<'l: 's>(
126        self,
127        layout: PipelineLayoutDescriptor<'l>,
128    ) -> ShaderBuilder<'s, V, I, VS, FS, FMT> {
129        ShaderBuilder {
130            layout: Some(layout),
131            ..self.pass()
132        }
133    }
134}
135
136impl<'s, V, I> ShaderBuilder<'s, V, I, true, true, true>
137where
138    V: Vertex,
139    I: Index,
140{
141    pub fn build(self, target: &Target) -> Shader<V, I> {
142        let (vert_mod, vert_entry) = self.vert.unwrap();
143        let (frag_mod, frag_entry) = self.frag.unwrap();
144        let format = self.format.unwrap();
145
146        let layout = match self.layout {
147            Some(l) => target.device.create_pipeline_layout(&l),
148            None => {
149                let a = AutoLayout::new(target, (vert_mod, vert_entry), (frag_mod, frag_entry));
150                let a = a.get();
151                target.device.create_pipeline_layout(&a.get())
152            }
153        };
154
155        let strip_index_format = if let PrimitiveTopology::LineStrip
156        | PrimitiveTopology::TriangleStrip = self.topology
157        {
158            Some(I::FORMAT)
159        } else {
160            None
161        };
162
163        let pipeline = target
164            .device
165            .create_render_pipeline(&RenderPipelineDescriptor {
166                label: self.label,
167                layout: Some(&layout),
168                vertex: VertexState {
169                    module: &vert_mod.inner,
170                    entry_point: vert_entry,
171                    buffers: V::LAYOUT,
172                },
173                primitive: PrimitiveState {
174                    topology: self.topology,
175                    strip_index_format,
176                    front_face: FrontFace::Ccw,
177                    cull_mode: None,
178                    unclipped_depth: false,
179                    polygon_mode: PolygonMode::Fill,
180                    conservative: false,
181                },
182                depth_stencil: None,
183                multisample: MultisampleState::default(),
184                fragment: Some(FragmentState {
185                    module: &frag_mod.inner,
186                    entry_point: frag_entry,
187                    targets: &[ColorTargetState {
188                        format,
189                        blend: Some(BlendState::ALPHA_BLENDING),
190                        write_mask: ColorWrites::ALL,
191                    }],
192                }),
193                multiview: None,
194            });
195
196        Shader {
197            pipeline,
198            format,
199
200            _p: PhantomData::default(),
201        }
202    }
203}