// template_compiler/gen/template.rs

1use std::collections::HashSet;
2
3use wasm_encoder::{
4    DataCountSection, DataSection, Function, Instruction, MemArg, TypeSection, ValType, ComponentTypeSection, PrimitiveValType, ComponentValType, BlockType,
5};
6
7use crate::{parse::Node, FileData};
8
/// Index of the core function called to allocate guest memory.
/// NOTE(review): assumes the realloc function occupies function index 0
/// (e.g. it is the first function import) — confirm against module assembly.
const REALLOC_FUNC_INDEX: u32 = 0;
/// Index of the linear memory targeted by every load, store, copy and init.
const MEMORY_INDEX: u32 = 0;

/// Maximum number of flattened core parameters before the arguments are
/// passed indirectly through a pointer into linear memory ("spilled").
const MAX_FLAT_PARAMS: u32 = 16;
13
14pub struct TemplateGenerator<'source> {
15    params: Params<'source>,
16    file_data: &'source FileData<'source>,
17}
18
/// The distinct parameter names a template uses, split by how they are used.
///
/// Both lists are sorted when built, which fixes the field order of the
/// generated record type and allows binary-search lookups.
pub struct Params<'source> {
    /// Names substituted into the output as text (`string` record fields).
    text_params: Vec<&'source str>,
    /// Names used as conditional flags (`bool` record fields).
    cond_params: Vec<&'source str>,
}
23
24impl<'source> Params<'source> {
25    pub fn new(contents: &'source Vec<Node<'source>>) -> Self {
26        let mut text_params = HashSet::new();
27        let mut cond_params = HashSet::new();
28        for node in contents {
29            Self::collect_params(node, &mut text_params, &mut cond_params);
30        }
31        let mut text_params: Vec<&str> = text_params.into_iter().collect();
32        let mut cond_params: Vec<&str> = cond_params.into_iter().collect();
33        text_params.sort();
34        cond_params.sort();
35        Params {
36            text_params,
37            cond_params,
38        }
39    }
40
41    fn collect_params(
42        node: &'source Node<'source>,
43        text_params: &mut HashSet<&'source str>,
44        cond_params: &mut HashSet<&'source str>,
45    ) {
46        match node {
47            Node::Text { .. } => {}
48            Node::Parameter { name } => {
49                text_params.insert(name.value);
50            }
51            Node::Conditional {
52                if_kwd: _,
53                cond_ident,
54                contents,
55                endif_kwd: _,
56            } => {
57                cond_params.insert(cond_ident.value);
58                for node in contents {
59                    Self::collect_params(node, text_params, cond_params);
60                }
61            }
62        }
63    }
64
65    pub fn stack_len(&self) -> u32 {
66        self.text_stack_len() + (self.cond_params.len() as u32)
67    }
68
69    fn text_stack_len(&self) -> u32 {
70        2 * (self.text_params.len() as u32)
71    }
72
73    fn text_mem_len(&self) -> u32 {
74        8 * (self.text_params.len() as u32)
75    }
76
77    pub fn must_spill(&self) -> bool {
78        self.stack_len() > MAX_FLAT_PARAMS
79    }
80
81    // The number of text parameters
82    pub fn text_params_len(&self) -> usize {
83        self.text_params.len()
84    }
85
86    // The index in the parameters of a given text parameter name
87    pub fn text_param_index(&self, param: &str) -> usize {
88        self.text_params.binary_search(&param).unwrap()
89    }
90
91    // The index in the parameters of a given condition parameter name
92    pub fn cond_param_index(&self, param: &str) -> usize {
93        self.cond_params.binary_search(&param).unwrap()
94    }
95
96    pub fn record_type(&self) -> ComponentTypeSection {
97        let mut types = ComponentTypeSection::new();
98        let converted_names: Vec<String> = self.text_params.iter().map(|param: &&str| snake_to_kebab(param)).collect();
99        let text_fields = converted_names.iter().map(|param| {
100            (
101                param.as_str(),
102                ComponentValType::Primitive(PrimitiveValType::String),
103            )
104        });
105        let converted_names: Vec<String> = self.cond_params.iter().map(|param: &&str| snake_to_kebab(param)).collect();
106        let cond_fields = converted_names.iter().map(|param| {
107            (
108                param.as_str(),
109                ComponentValType::Primitive(PrimitiveValType::Bool),
110            )
111        });
112        let fields: Vec<(&str, ComponentValType)> = text_fields.chain(cond_fields).collect();
113        types.defined_type().record(fields);
114        types
115    }
116
117    fn gen_push_text_offset(&self, func: &mut Function, text_index: u32) {
118        self.gen_push_text_field(func, text_index, 0)
119    }
120
121    fn gen_push_text_len(&self, func: &mut Function, text_index: u32) {
122        self.gen_push_text_field(func, text_index, 1)
123    }
124    
125    fn gen_push_text_field(&self, func: &mut Function, text_index: u32, field: u32) {
126        if self.must_spill() {
127            // push params offset
128            func.instruction(&Instruction::LocalGet(0));
129            // push param index shift
130            let shift = (text_index * 8) + (field * 4);
131            let shift = shift.try_into().unwrap();
132            func.instruction(&Instruction::I32Const(shift));
133            // compute the final param index
134            func.instruction(&Instruction::I32Add);
135            // load the param string offset
136            func.instruction(&Instruction::I32Load(MemArg {
137                offset: 0,
138                align: 4,
139                memory_index: 0,
140            }));
141        } else {
142            let local_index = 2 * text_index + field;
143            func.instruction(&Instruction::LocalGet(local_index));
144        }
145    }
146
147    fn gen_push_cond(&self, func: &mut Function, cond_index: u32) {
148        if self.must_spill() {
149            // push params offset
150            func.instruction(&Instruction::LocalGet(0));
151            // push param index shift
152            let shift = self.text_mem_len() + cond_index;
153            let shift = shift.try_into().unwrap();
154            func.instruction(&Instruction::I32Const(shift));
155            // compute the final param index
156            func.instruction(&Instruction::I32Add);
157            // load the param string offset
158            func.instruction(&Instruction::I32Load8U(MemArg {
159                offset: 0,
160                align: 1,
161                memory_index: 0,
162            }));
163        } else {
164            let local_index = self.text_stack_len() + cond_index;
165            func.instruction(&Instruction::LocalGet(local_index));
166        }
167    }
168}
169
170impl<'source> TemplateGenerator<'source> {
171    pub fn new(params: Params<'source>, file_data: &'source FileData<'source>) -> Self {
172        Self { params, file_data }
173    }
174
175    pub fn params(&self) -> &Params<'source> {
176        &self.params
177    }
178
179    fn arguments_len(&self) -> u32 {
180        if self.params.must_spill() {
181            1
182        } else {
183            self.params.stack_len()
184        }
185    }
186
187    fn result_len_local(&self) -> u32 {
188        self.arguments_len() + 0
189    }
190
191    fn result_addr_local(&self) -> u32 {
192        self.arguments_len() + 1
193    }
194
195    fn return_area_local(&self) -> u32 {
196        self.arguments_len() + 2
197    }
198
199    fn result_cursor_local(&self) -> u32 {
200        self.arguments_len() + 3
201    }
202
203    fn locals_len(&self) -> u32 {
204        4
205    }
206
207    pub fn gen_core_type(&self, types: &mut TypeSection) {
208        let params = vec![ValType::I32; self.arguments_len() as usize];
209        let results = vec![ValType::I32];
210        types.function(params, results);
211    }
212
213    pub fn gen_data(&self) -> (DataCountSection, DataSection) {
214        let mut count = 0;
215        let mut data = DataSection::new();
216
217        for node in self.file_data.contents.iter() {
218            Self::collect_data(node, &mut count, &mut data);
219        }
220
221        let count = DataCountSection { count };
222        (count, data)
223    }
224
225    fn collect_data(node: &Node<'source>, count: &mut u32, data: &mut DataSection) {
226        match node {
227            Node::Text { index: _, text } => {
228                data.passive(text.value.bytes());
229                *count += 1;
230            }
231            Node::Parameter { name: _ } => {}
232            Node::Conditional {
233                if_kwd: _,
234                cond_ident: _,
235                contents,
236                endif_kwd: _,
237            } => {
238                for node in contents {
239                    Self::collect_data(node, count, data);
240                }
241            }
242        }
243    }
244
245    pub fn gen_core_function(&self) -> Function {
246        // Local variables
247        let locals = vec![(self.locals_len(), ValType::I32)];
248        let mut func = Function::new(locals);
249
250        self.gen_calculate_len(&mut func);
251        self.gen_allocate_results(&mut func);
252        self.gen_init_cursor(&mut func);
253        self.gen_write_template(&mut func);
254
255        func.instruction(&Instruction::LocalGet(self.return_area_local()));
256        func.instruction(&Instruction::End);
257        func
258    }
259
260    fn gen_calculate_len(&self, func: &mut Function) {
261        self.gen_calculate_sequence_len(func, self.file_data.contents.as_slice());
262        // Store the calculated length
263        func.instruction(&Instruction::LocalSet(self.result_len_local()));
264    }
265
266    fn gen_calculate_sequence_len(&self, func: &mut Function, sequence: &[Node<'source>]) {
267        let mut base_length = 0;
268        let mut param_counts = vec![0; self.params.text_params_len()];
269        let mut prior_exists = false;
270        for node in sequence.iter() {
271            match node {
272                Node::Text { index: _, text } => {
273                    base_length += text.value.len() as i32;
274                }
275                Node::Parameter { name } => {
276                    let index = self.params.text_param_index(&name.value);
277                    param_counts[index] += 1;
278                }
279                Node::Conditional {
280                    if_kwd: _,
281                    cond_ident,
282                    contents,
283                    endif_kwd: _,
284                } => {
285                    let cond_index = self.params.cond_param_index(cond_ident.value) as u32;
286
287                    self.params.gen_push_cond(func, cond_index);
288                    func.instruction(&Instruction::If(BlockType::Result(ValType::I32)));
289                    self.gen_calculate_sequence_len(func, &contents);
290                    func.instruction(&Instruction::Else);
291                    func.instruction(&Instruction::I32Const(0));
292                    func.instruction(&Instruction::End);
293
294                    if prior_exists {
295                        func.instruction(&Instruction::I32Add);
296                    }
297
298                    prior_exists = true;
299                }
300            }
301        }
302
303        // push the base length
304        func.instruction(&Instruction::I32Const(base_length));
305
306        if prior_exists {
307            func.instruction(&Instruction::I32Add);
308        }
309
310        // accumulate the dynamic part of the length
311        for (index, count) in param_counts.iter().enumerate() {
312            if *count > 0 {
313                // load the length of the parameter
314                self.params.gen_push_text_len(func, index as u32);
315                // push the count of parameter occurrences
316                func.instruction(&Instruction::I32Const(*count));
317                // multiple the length by the occurrences
318                func.instruction(&Instruction::I32Mul);
319                // add this length addition to the total length
320                func.instruction(&Instruction::I32Add);
321            }
322        }
323    }
324
325    fn gen_allocate_results(&self, func: &mut Function) {
326        // allocate result string
327        func.instruction(&Instruction::I32Const(0));
328        func.instruction(&Instruction::I32Const(0));
329        func.instruction(&Instruction::I32Const(1));
330        func.instruction(&Instruction::LocalGet(self.result_len_local()));
331        func.instruction(&Instruction::Call(REALLOC_FUNC_INDEX));
332        // store allocated address
333        func.instruction(&Instruction::LocalSet(self.result_addr_local()));
334
335        // allocate return area
336        func.instruction(&Instruction::I32Const(0));
337        func.instruction(&Instruction::I32Const(0));
338        func.instruction(&Instruction::I32Const(4));
339        func.instruction(&Instruction::I32Const(8));
340        func.instruction(&Instruction::Call(REALLOC_FUNC_INDEX));
341        // store allocated address
342        func.instruction(&Instruction::LocalSet(self.return_area_local()));
343
344        // populate return area
345        // store result addr
346        let mem_arg = MemArg {
347            offset: 0,
348            align: 2,
349            memory_index: MEMORY_INDEX,
350        };
351        func.instruction(&Instruction::LocalGet(self.return_area_local()));
352        func.instruction(&Instruction::LocalGet(self.result_addr_local()));
353        func.instruction(&Instruction::I32Store(mem_arg));
354        // store result len
355        func.instruction(&Instruction::LocalGet(self.return_area_local()));
356        func.instruction(&Instruction::I32Const(4));
357        func.instruction(&Instruction::I32Add);
358        func.instruction(&Instruction::LocalGet(self.result_len_local()));
359        func.instruction(&Instruction::I32Store(mem_arg));
360    }
361
362    fn gen_init_cursor(&self, func: &mut Function) {
363        // set cursor to result string address
364        func.instruction(&Instruction::LocalGet(self.result_addr_local()));
365        func.instruction(&Instruction::LocalSet(self.result_cursor_local()));
366    }
367
368    fn gen_write_template(&self, func: &mut Function) {
369        self.gen_write_sequence_template(func, &self.file_data.contents);
370    }
371
372    fn gen_write_sequence_template(&self, func: &mut Function, sequence: &[Node<'source>]) {
373        for node in sequence {
374            // note both branches end by pushing the cursor shift
375            match node {
376                Node::Text { index, text } => {
377                    self.gen_write_segment(func, *index as u32, text.value.len() as i32);
378                }
379                Node::Parameter { name } => {
380                    let index = self.params.text_param_index(&name.value);
381                    self.gen_write_param(func, index as u32);
382                }
383                Node::Conditional {
384                    if_kwd: _,
385                    cond_ident,
386                    contents,
387                    endif_kwd: _,
388                } => {
389                    let cond_index = self.params.cond_param_index(cond_ident.value) as u32;
390
391                    self.params.gen_push_cond(func, cond_index);
392                    func.instruction(&Instruction::If(BlockType::Empty));
393                    self.gen_write_sequence_template(func, contents);
394                    func.instruction(&Instruction::Else);
395                    func.instruction(&Instruction::End);
396                }
397            }
398            
399            if matches!(node, Node::Text { .. }) || matches!(node, Node::Parameter { .. }) {
400                // push cursor and add to shift
401                func.instruction(&Instruction::LocalGet(self.result_cursor_local()));
402                func.instruction(&Instruction::I32Add);
403                func.instruction(&Instruction::LocalSet(self.result_cursor_local()));
404            }
405        }
406    }
407
408    fn gen_write_segment(&self, func: &mut Function, data_index: u32, length: i32) {
409        // push destination
410        func.instruction(&Instruction::LocalGet(self.result_cursor_local()));
411        // push source
412        func.instruction(&Instruction::I32Const(0));
413        // push length
414        func.instruction(&Instruction::I32Const(length));
415        // copy data segment into output
416        func.instruction(&Instruction::MemoryInit { mem: 0, data_index });
417
418        // push length
419        func.instruction(&Instruction::I32Const(length));
420    }
421
422    fn gen_write_param(&self, func: &mut Function, param_index: u32) {
423        // push destination
424        func.instruction(&Instruction::LocalGet(self.result_cursor_local()));
425        // push source
426        self.params.gen_push_text_offset(func, param_index);
427        // push length
428        self.params.gen_push_text_len(func, param_index);
429        // copy the argument data
430        func.instruction(&Instruction::MemoryCopy {
431            src_mem: MEMORY_INDEX,
432            dst_mem: MEMORY_INDEX,
433        });
434
435        // push length
436        self.params.gen_push_text_len(func, param_index);
437    }
438}
439
/// Converts a `snake_case` identifier to `kebab-case` by mapping every
/// underscore to a hyphen; all other characters are kept unchanged.
fn snake_to_kebab(ident: &str) -> String {
    ident
        .chars()
        .map(|ch| if ch == '_' { '-' } else { ch })
        .collect()
}