1use std::collections::HashSet;
2
3use wasm_encoder::{
4 DataCountSection, DataSection, Function, Instruction, MemArg, TypeSection, ValType, ComponentTypeSection, PrimitiveValType, ComponentValType, BlockType,
5};
6
7use crate::{parse::Node, FileData};
8
/// Function index of the allocator (realloc) the generated code calls to
/// obtain output memory.
const REALLOC_FUNC_INDEX: u32 = 0;
/// Index of the linear memory the generated code reads from and writes to.
const MEMORY_INDEX: u32 = 0;

/// Upper bound on flattened core values passed as direct arguments; beyond
/// this the parameters are spilled to linear memory (Canonical ABI limit).
const MAX_FLAT_PARAMS: u32 = 16;
13
14pub struct TemplateGenerator<'source> {
15 params: Params<'source>,
16 file_data: &'source FileData<'source>,
17}
18
/// The distinct parameter names referenced by a template.
///
/// `Params::new` keeps both lists deduplicated and sorted so names can be
/// located with binary search.
pub struct Params<'source> {
    /// Names substituted into the output as text (passed as strings).
    text_params: Vec<&'source str>,
    /// Names used as conditional flags (passed as booleans).
    cond_params: Vec<&'source str>,
}
23
24impl<'source> Params<'source> {
25 pub fn new(contents: &'source Vec<Node<'source>>) -> Self {
26 let mut text_params = HashSet::new();
27 let mut cond_params = HashSet::new();
28 for node in contents {
29 Self::collect_params(node, &mut text_params, &mut cond_params);
30 }
31 let mut text_params: Vec<&str> = text_params.into_iter().collect();
32 let mut cond_params: Vec<&str> = cond_params.into_iter().collect();
33 text_params.sort();
34 cond_params.sort();
35 Params {
36 text_params,
37 cond_params,
38 }
39 }
40
41 fn collect_params(
42 node: &'source Node<'source>,
43 text_params: &mut HashSet<&'source str>,
44 cond_params: &mut HashSet<&'source str>,
45 ) {
46 match node {
47 Node::Text { .. } => {}
48 Node::Parameter { name } => {
49 text_params.insert(name.value);
50 }
51 Node::Conditional {
52 if_kwd: _,
53 cond_ident,
54 contents,
55 endif_kwd: _,
56 } => {
57 cond_params.insert(cond_ident.value);
58 for node in contents {
59 Self::collect_params(node, text_params, cond_params);
60 }
61 }
62 }
63 }
64
65 pub fn stack_len(&self) -> u32 {
66 self.text_stack_len() + (self.cond_params.len() as u32)
67 }
68
69 fn text_stack_len(&self) -> u32 {
70 2 * (self.text_params.len() as u32)
71 }
72
73 fn text_mem_len(&self) -> u32 {
74 8 * (self.text_params.len() as u32)
75 }
76
77 pub fn must_spill(&self) -> bool {
78 self.stack_len() > MAX_FLAT_PARAMS
79 }
80
81 pub fn text_params_len(&self) -> usize {
83 self.text_params.len()
84 }
85
86 pub fn text_param_index(&self, param: &str) -> usize {
88 self.text_params.binary_search(¶m).unwrap()
89 }
90
91 pub fn cond_param_index(&self, param: &str) -> usize {
93 self.cond_params.binary_search(¶m).unwrap()
94 }
95
96 pub fn record_type(&self) -> ComponentTypeSection {
97 let mut types = ComponentTypeSection::new();
98 let converted_names: Vec<String> = self.text_params.iter().map(|param: &&str| snake_to_kebab(param)).collect();
99 let text_fields = converted_names.iter().map(|param| {
100 (
101 param.as_str(),
102 ComponentValType::Primitive(PrimitiveValType::String),
103 )
104 });
105 let converted_names: Vec<String> = self.cond_params.iter().map(|param: &&str| snake_to_kebab(param)).collect();
106 let cond_fields = converted_names.iter().map(|param| {
107 (
108 param.as_str(),
109 ComponentValType::Primitive(PrimitiveValType::Bool),
110 )
111 });
112 let fields: Vec<(&str, ComponentValType)> = text_fields.chain(cond_fields).collect();
113 types.defined_type().record(fields);
114 types
115 }
116
117 fn gen_push_text_offset(&self, func: &mut Function, text_index: u32) {
118 self.gen_push_text_field(func, text_index, 0)
119 }
120
121 fn gen_push_text_len(&self, func: &mut Function, text_index: u32) {
122 self.gen_push_text_field(func, text_index, 1)
123 }
124
125 fn gen_push_text_field(&self, func: &mut Function, text_index: u32, field: u32) {
126 if self.must_spill() {
127 func.instruction(&Instruction::LocalGet(0));
129 let shift = (text_index * 8) + (field * 4);
131 let shift = shift.try_into().unwrap();
132 func.instruction(&Instruction::I32Const(shift));
133 func.instruction(&Instruction::I32Add);
135 func.instruction(&Instruction::I32Load(MemArg {
137 offset: 0,
138 align: 4,
139 memory_index: 0,
140 }));
141 } else {
142 let local_index = 2 * text_index + field;
143 func.instruction(&Instruction::LocalGet(local_index));
144 }
145 }
146
147 fn gen_push_cond(&self, func: &mut Function, cond_index: u32) {
148 if self.must_spill() {
149 func.instruction(&Instruction::LocalGet(0));
151 let shift = self.text_mem_len() + cond_index;
153 let shift = shift.try_into().unwrap();
154 func.instruction(&Instruction::I32Const(shift));
155 func.instruction(&Instruction::I32Add);
157 func.instruction(&Instruction::I32Load8U(MemArg {
159 offset: 0,
160 align: 1,
161 memory_index: 0,
162 }));
163 } else {
164 let local_index = self.text_stack_len() + cond_index;
165 func.instruction(&Instruction::LocalGet(local_index));
166 }
167 }
168}
169
170impl<'source> TemplateGenerator<'source> {
171 pub fn new(params: Params<'source>, file_data: &'source FileData<'source>) -> Self {
172 Self { params, file_data }
173 }
174
175 pub fn params(&self) -> &Params<'source> {
176 &self.params
177 }
178
179 fn arguments_len(&self) -> u32 {
180 if self.params.must_spill() {
181 1
182 } else {
183 self.params.stack_len()
184 }
185 }
186
187 fn result_len_local(&self) -> u32 {
188 self.arguments_len() + 0
189 }
190
191 fn result_addr_local(&self) -> u32 {
192 self.arguments_len() + 1
193 }
194
195 fn return_area_local(&self) -> u32 {
196 self.arguments_len() + 2
197 }
198
199 fn result_cursor_local(&self) -> u32 {
200 self.arguments_len() + 3
201 }
202
203 fn locals_len(&self) -> u32 {
204 4
205 }
206
207 pub fn gen_core_type(&self, types: &mut TypeSection) {
208 let params = vec![ValType::I32; self.arguments_len() as usize];
209 let results = vec![ValType::I32];
210 types.function(params, results);
211 }
212
213 pub fn gen_data(&self) -> (DataCountSection, DataSection) {
214 let mut count = 0;
215 let mut data = DataSection::new();
216
217 for node in self.file_data.contents.iter() {
218 Self::collect_data(node, &mut count, &mut data);
219 }
220
221 let count = DataCountSection { count };
222 (count, data)
223 }
224
225 fn collect_data(node: &Node<'source>, count: &mut u32, data: &mut DataSection) {
226 match node {
227 Node::Text { index: _, text } => {
228 data.passive(text.value.bytes());
229 *count += 1;
230 }
231 Node::Parameter { name: _ } => {}
232 Node::Conditional {
233 if_kwd: _,
234 cond_ident: _,
235 contents,
236 endif_kwd: _,
237 } => {
238 for node in contents {
239 Self::collect_data(node, count, data);
240 }
241 }
242 }
243 }
244
245 pub fn gen_core_function(&self) -> Function {
246 let locals = vec![(self.locals_len(), ValType::I32)];
248 let mut func = Function::new(locals);
249
250 self.gen_calculate_len(&mut func);
251 self.gen_allocate_results(&mut func);
252 self.gen_init_cursor(&mut func);
253 self.gen_write_template(&mut func);
254
255 func.instruction(&Instruction::LocalGet(self.return_area_local()));
256 func.instruction(&Instruction::End);
257 func
258 }
259
260 fn gen_calculate_len(&self, func: &mut Function) {
261 self.gen_calculate_sequence_len(func, self.file_data.contents.as_slice());
262 func.instruction(&Instruction::LocalSet(self.result_len_local()));
264 }
265
266 fn gen_calculate_sequence_len(&self, func: &mut Function, sequence: &[Node<'source>]) {
267 let mut base_length = 0;
268 let mut param_counts = vec![0; self.params.text_params_len()];
269 let mut prior_exists = false;
270 for node in sequence.iter() {
271 match node {
272 Node::Text { index: _, text } => {
273 base_length += text.value.len() as i32;
274 }
275 Node::Parameter { name } => {
276 let index = self.params.text_param_index(&name.value);
277 param_counts[index] += 1;
278 }
279 Node::Conditional {
280 if_kwd: _,
281 cond_ident,
282 contents,
283 endif_kwd: _,
284 } => {
285 let cond_index = self.params.cond_param_index(cond_ident.value) as u32;
286
287 self.params.gen_push_cond(func, cond_index);
288 func.instruction(&Instruction::If(BlockType::Result(ValType::I32)));
289 self.gen_calculate_sequence_len(func, &contents);
290 func.instruction(&Instruction::Else);
291 func.instruction(&Instruction::I32Const(0));
292 func.instruction(&Instruction::End);
293
294 if prior_exists {
295 func.instruction(&Instruction::I32Add);
296 }
297
298 prior_exists = true;
299 }
300 }
301 }
302
303 func.instruction(&Instruction::I32Const(base_length));
305
306 if prior_exists {
307 func.instruction(&Instruction::I32Add);
308 }
309
310 for (index, count) in param_counts.iter().enumerate() {
312 if *count > 0 {
313 self.params.gen_push_text_len(func, index as u32);
315 func.instruction(&Instruction::I32Const(*count));
317 func.instruction(&Instruction::I32Mul);
319 func.instruction(&Instruction::I32Add);
321 }
322 }
323 }
324
325 fn gen_allocate_results(&self, func: &mut Function) {
326 func.instruction(&Instruction::I32Const(0));
328 func.instruction(&Instruction::I32Const(0));
329 func.instruction(&Instruction::I32Const(1));
330 func.instruction(&Instruction::LocalGet(self.result_len_local()));
331 func.instruction(&Instruction::Call(REALLOC_FUNC_INDEX));
332 func.instruction(&Instruction::LocalSet(self.result_addr_local()));
334
335 func.instruction(&Instruction::I32Const(0));
337 func.instruction(&Instruction::I32Const(0));
338 func.instruction(&Instruction::I32Const(4));
339 func.instruction(&Instruction::I32Const(8));
340 func.instruction(&Instruction::Call(REALLOC_FUNC_INDEX));
341 func.instruction(&Instruction::LocalSet(self.return_area_local()));
343
344 let mem_arg = MemArg {
347 offset: 0,
348 align: 2,
349 memory_index: MEMORY_INDEX,
350 };
351 func.instruction(&Instruction::LocalGet(self.return_area_local()));
352 func.instruction(&Instruction::LocalGet(self.result_addr_local()));
353 func.instruction(&Instruction::I32Store(mem_arg));
354 func.instruction(&Instruction::LocalGet(self.return_area_local()));
356 func.instruction(&Instruction::I32Const(4));
357 func.instruction(&Instruction::I32Add);
358 func.instruction(&Instruction::LocalGet(self.result_len_local()));
359 func.instruction(&Instruction::I32Store(mem_arg));
360 }
361
362 fn gen_init_cursor(&self, func: &mut Function) {
363 func.instruction(&Instruction::LocalGet(self.result_addr_local()));
365 func.instruction(&Instruction::LocalSet(self.result_cursor_local()));
366 }
367
368 fn gen_write_template(&self, func: &mut Function) {
369 self.gen_write_sequence_template(func, &self.file_data.contents);
370 }
371
372 fn gen_write_sequence_template(&self, func: &mut Function, sequence: &[Node<'source>]) {
373 for node in sequence {
374 match node {
376 Node::Text { index, text } => {
377 self.gen_write_segment(func, *index as u32, text.value.len() as i32);
378 }
379 Node::Parameter { name } => {
380 let index = self.params.text_param_index(&name.value);
381 self.gen_write_param(func, index as u32);
382 }
383 Node::Conditional {
384 if_kwd: _,
385 cond_ident,
386 contents,
387 endif_kwd: _,
388 } => {
389 let cond_index = self.params.cond_param_index(cond_ident.value) as u32;
390
391 self.params.gen_push_cond(func, cond_index);
392 func.instruction(&Instruction::If(BlockType::Empty));
393 self.gen_write_sequence_template(func, contents);
394 func.instruction(&Instruction::Else);
395 func.instruction(&Instruction::End);
396 }
397 }
398
399 if matches!(node, Node::Text { .. }) || matches!(node, Node::Parameter { .. }) {
400 func.instruction(&Instruction::LocalGet(self.result_cursor_local()));
402 func.instruction(&Instruction::I32Add);
403 func.instruction(&Instruction::LocalSet(self.result_cursor_local()));
404 }
405 }
406 }
407
408 fn gen_write_segment(&self, func: &mut Function, data_index: u32, length: i32) {
409 func.instruction(&Instruction::LocalGet(self.result_cursor_local()));
411 func.instruction(&Instruction::I32Const(0));
413 func.instruction(&Instruction::I32Const(length));
415 func.instruction(&Instruction::MemoryInit { mem: 0, data_index });
417
418 func.instruction(&Instruction::I32Const(length));
420 }
421
422 fn gen_write_param(&self, func: &mut Function, param_index: u32) {
423 func.instruction(&Instruction::LocalGet(self.result_cursor_local()));
425 self.params.gen_push_text_offset(func, param_index);
427 self.params.gen_push_text_len(func, param_index);
429 func.instruction(&Instruction::MemoryCopy {
431 src_mem: MEMORY_INDEX,
432 dst_mem: MEMORY_INDEX,
433 });
434
435 self.params.gen_push_text_len(func, param_index);
437 }
438}
439
/// Converts a `snake_case` identifier to `kebab-case` by replacing every
/// underscore with a hyphen; all other characters are kept as-is.
fn snake_to_kebab(ident: &str) -> String {
    ident
        .chars()
        .map(|ch| if ch == '_' { '-' } else { ch })
        .collect()
}