1use std::marker::PhantomData;
2
3use crate::{
4 env::Environment,
5 expr::{ErasedExpr, Expr},
6 fun::{ErasedFunHandle, FunDef, FunHandle},
7 input::{
8 FragmentShaderInputs, GeometryShaderInputs, Inputs, TessCtrlShaderInputs, TessEvalShaderInputs,
9 VertexShaderInputs,
10 },
11 output::{
12 FragmentShaderOutputs, GeometryShaderOutputs, Outputs, TessCtrlShaderOutputs,
13 TessEvalShaderOutputs, VertexShaderOutputs,
14 },
15 scope::ScopedHandle,
16 shader::ShaderDecl,
17 types::ToType,
18};
19
20#[derive(Debug)]
22pub struct Stage<S, I, O, E>
23where
24 S: ?Sized,
25{
26 pub(crate) builder: ModBuilder<S, I, O, E>,
27}
28
29pub trait ShaderModule<I, O> {
36 type Inputs;
37 type Outputs;
38
39 fn new_shader_module<E>(
48 f: impl FnOnce(
49 ModBuilder<Self, I, O, E>,
50 Self::Inputs,
51 Self::Outputs,
52 E::Env,
53 ) -> Stage<Self, I, O, E>,
54 ) -> Stage<Self, I, O, E>
55 where
56 E: Environment;
57}
58
59#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
61pub struct VS;
62
63impl<I, O> ShaderModule<I, O> for VS
64where
65 I: Inputs,
66 O: Outputs,
67{
68 type Inputs = VertexShaderInputs<I::In>;
69 type Outputs = VertexShaderOutputs<O::Out>;
70
71 fn new_shader_module<E>(
72 f: impl FnOnce(
73 ModBuilder<Self, I, O, E>,
74 Self::Inputs,
75 Self::Outputs,
76 E::Env,
77 ) -> Stage<Self, I, O, E>,
78 ) -> Stage<Self, I, O, E>
79 where
80 E: Environment,
81 {
82 f(
83 ModBuilder::new(),
84 VertexShaderInputs::new(I::input()),
85 VertexShaderOutputs::new(O::output()),
86 E::env(),
87 )
88 }
89}
90
91#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
93pub struct TCS;
94
95impl<I, O> ShaderModule<I, O> for TCS
96where
97 I: Inputs,
98 O: Outputs,
99{
100 type Inputs = TessCtrlShaderInputs<I::In>;
101 type Outputs = TessCtrlShaderOutputs<O::Out>;
102
103 fn new_shader_module<E>(
104 f: impl FnOnce(
105 ModBuilder<Self, I, O, E>,
106 Self::Inputs,
107 Self::Outputs,
108 E::Env,
109 ) -> Stage<Self, I, O, E>,
110 ) -> Stage<Self, I, O, E>
111 where
112 E: Environment,
113 {
114 f(
115 ModBuilder::new(),
116 TessCtrlShaderInputs::new(I::input()),
117 TessCtrlShaderOutputs::new(O::output()),
118 E::env(),
119 )
120 }
121}
122
123#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
125pub struct TES;
126
127impl<I, O> ShaderModule<I, O> for TES
128where
129 I: Inputs,
130 O: Outputs,
131{
132 type Inputs = TessEvalShaderInputs<I::In>;
133 type Outputs = TessEvalShaderOutputs<O::Out>;
134
135 fn new_shader_module<E>(
136 f: impl FnOnce(
137 ModBuilder<Self, I, O, E>,
138 Self::Inputs,
139 Self::Outputs,
140 E::Env,
141 ) -> Stage<Self, I, O, E>,
142 ) -> Stage<Self, I, O, E>
143 where
144 E: Environment,
145 {
146 f(
147 ModBuilder::new(),
148 TessEvalShaderInputs::new(I::input()),
149 TessEvalShaderOutputs::new(O::output()),
150 E::env(),
151 )
152 }
153}
154
155#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
157pub struct GS;
158
159impl<I, O> ShaderModule<I, O> for GS
160where
161 I: Inputs,
162 O: Outputs,
163{
164 type Inputs = GeometryShaderInputs<I::In>;
165 type Outputs = GeometryShaderOutputs<O::Out>;
166
167 fn new_shader_module<E>(
168 f: impl FnOnce(
169 ModBuilder<Self, I, O, E>,
170 Self::Inputs,
171 Self::Outputs,
172 E::Env,
173 ) -> Stage<Self, I, O, E>,
174 ) -> Stage<Self, I, O, E>
175 where
176 E: Environment,
177 {
178 f(
179 ModBuilder::new(),
180 GeometryShaderInputs::new(I::input()),
181 GeometryShaderOutputs::new(O::output()),
182 E::env(),
183 )
184 }
185}
186
187#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
189pub struct FS;
190
191impl<I, O> ShaderModule<I, O> for FS
192where
193 I: Inputs,
194 O: Outputs,
195{
196 type Inputs = FragmentShaderInputs<I::In>;
197 type Outputs = FragmentShaderOutputs<O::Out>;
198
199 fn new_shader_module<E>(
200 f: impl FnOnce(
201 ModBuilder<Self, I, O, E>,
202 Self::Inputs,
203 Self::Outputs,
204 E::Env,
205 ) -> Stage<Self, I, O, E>,
206 ) -> Stage<Self, I, O, E>
207 where
208 E: Environment,
209 {
210 f(
211 ModBuilder::new(),
212 FragmentShaderInputs::new(I::input()),
213 FragmentShaderOutputs::new(O::output()),
214 E::env(),
215 )
216 }
217}
218
219#[derive(Debug)]
226pub struct ModBuilder<S, I, O, E>
227where
228 S: ?Sized,
229{
230 pub(crate) decls: Vec<ShaderDecl>,
231 next_fun_handle: u16,
232 next_global_handle: u16,
233 _phantom: PhantomData<(*const S, I, O, E)>,
234}
235
236impl<S, I, O, E> ModBuilder<S, I, O, E>
237where
238 S: ?Sized,
239 I: Inputs,
240 O: Outputs,
241 E: Environment,
242{
243 fn new() -> Self {
244 Self {
245 decls: Vec::new(),
246 next_fun_handle: 0,
247 next_global_handle: 0,
248 _phantom: PhantomData,
249 }
250 }
251 pub fn fun<R, A>(&mut self, fundef: FunDef<R, A>) -> FunHandle<R, A> {
260 let handle = self.next_fun_handle;
261 self.next_fun_handle += 1;
262
263 self.decls.push(ShaderDecl::FunDef(handle, fundef.erased));
264
265 FunHandle::new(ErasedFunHandle::UserDefined(handle as _))
266 }
267
268 pub fn constant<T>(&mut self, expr: Expr<T>) -> Expr<T>
277 where
278 T: ToType,
279 {
280 let handle = self.next_global_handle;
281 self.next_global_handle += 1;
282
283 self
284 .decls
285 .push(ShaderDecl::Const(handle, T::ty(), expr.erased));
286
287 Expr::new(ErasedExpr::Var(ScopedHandle::global(handle)))
288 }
289}
290
291impl<S, I, O, E> ModBuilder<S, I, O, E>
292where
293 S: ShaderModule<I, O>,
294 I: Inputs,
295 O: Outputs,
296 E: Environment,
297{
298 pub fn new_stage(
300 f: impl FnOnce(ModBuilder<S, I, O, E>, S::Inputs, S::Outputs, E::Env) -> Stage<S, I, O, E>,
301 ) -> Stage<S, I, O, E> {
302 S::new_shader_module(f)
303 }
304
305 pub fn main_fun(mut self, fundef: FunDef<(), ()>) -> Stage<S, I, O, E> {
318 self.decls.push(ShaderDecl::Main(fundef.erased));
319 Stage { builder: self }
320 }
321}