sbpf_assembler/
ast.rs

1use {
2    crate::{
3        CompileError,
4        astnode::{ASTNode, ROData},
5        dynsym::{DynamicSymbolMap, RelDynMap, RelocationType, get_relocation_info},
6        parser::ParseResult,
7        section::{CodeSection, DataSection},
8    },
9    either::Either,
10    sbpf_common::{inst_param::Number, instruction::Instruction, opcode::Opcode},
11    std::collections::HashMap,
12};
13
14#[derive(Default)]
15pub struct AST {
16    pub nodes: Vec<ASTNode>,
17    pub rodata_nodes: Vec<ASTNode>,
18
19    pub entry_label: Option<String>,
20    text_size: u64,
21    rodata_size: u64,
22}
23
24impl AST {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    //
30    pub fn set_text_size(&mut self, text_size: u64) {
31        self.text_size = text_size;
32    }
33
34    //
35    pub fn set_rodata_size(&mut self, rodata_size: u64) {
36        self.rodata_size = rodata_size;
37    }
38
39    //
40    pub fn get_instruction_at_offset(&mut self, offset: u64) -> Option<&mut Instruction> {
41        self.nodes
42            .iter_mut()
43            .find(|node| match node {
44                ASTNode::Instruction {
45                    instruction: _,
46                    offset: inst_offset,
47                    ..
48                } => offset == *inst_offset,
49                _ => false,
50            })
51            .map(|node| match node {
52                ASTNode::Instruction { instruction, .. } => instruction,
53                _ => panic!("Expected Instruction node"),
54            })
55    }
56
57    //
58    pub fn get_rodata_at_offset(&self, offset: u64) -> Option<&ROData> {
59        self.rodata_nodes
60            .iter()
61            .find(|node| match node {
62                ASTNode::ROData {
63                    rodata: _,
64                    offset: rodata_offset,
65                    ..
66                } => offset == *rodata_offset,
67                _ => false,
68            })
69            .map(|node| match node {
70                ASTNode::ROData { rodata, .. } => rodata,
71                _ => panic!("Expected ROData node"),
72            })
73    }
74
75    /// Resolve numeric label references (like "2f" or "1b")
76    fn resolve_numeric_label(
77        label_ref: &str,
78        current_idx: usize,
79        numeric_labels: &[(String, u64, usize)],
80    ) -> Option<u64> {
81        if let Some(direction) = label_ref.chars().last()
82            && (direction == 'f' || direction == 'b')
83        {
84            let label_num = &label_ref[..label_ref.len() - 1];
85
86            if direction == 'f' {
87                // search forward from current position
88                for (name, offset, node_idx) in numeric_labels {
89                    if name == label_num && *node_idx > current_idx {
90                        return Some(*offset);
91                    }
92                }
93            } else {
94                // search backward from current position
95                for (name, offset, node_idx) in numeric_labels.iter().rev() {
96                    if name == label_num && *node_idx < current_idx {
97                        return Some(*offset);
98                    }
99                }
100            }
101        }
102        None
103    }
104
105    //
106    pub fn build_program(&mut self) -> Result<ParseResult, Vec<CompileError>> {
107        let mut label_offset_map: HashMap<String, u64> = HashMap::new();
108        let mut numeric_labels: Vec<(String, u64, usize)> = Vec::new();
109
110        // iterate through text labels and rodata labels and find the pair
111        // of each label and offset
112        for (idx, node) in self.nodes.iter().enumerate() {
113            if let ASTNode::Label { label, offset } = node {
114                label_offset_map.insert(label.name.clone(), *offset);
115                // Also track numeric labels separately for forward/backward resolution
116                numeric_labels.push((label.name.clone(), *offset, idx));
117            }
118        }
119
120        for node in &self.rodata_nodes {
121            if let ASTNode::ROData { rodata, offset } = node {
122                label_offset_map.insert(rodata.name.clone(), *offset + self.text_size);
123            }
124        }
125
126        // 1. resolve labels in the intruction nodes for lddw and jump
127        // 2. find relocation information
128
129        let program_is_static = !self.nodes.iter().any(|node| matches!(node, ASTNode::Instruction { instruction: inst, .. } if inst.needs_relocation()));
130        let mut relocations = RelDynMap::new();
131        let mut dynamic_symbols = DynamicSymbolMap::new();
132
133        let mut errors = Vec::new();
134
135        for (idx, node) in self.nodes.iter_mut().enumerate() {
136            if let ASTNode::Instruction {
137                instruction: inst,
138                offset,
139                ..
140            } = node
141            {
142                // For jump/call instructions, replace label with relative offsets
143                if inst.is_jump()
144                    && let Some(Either::Left(label)) = &inst.off
145                {
146                    let target_offset = if let Some(offset) = label_offset_map.get(label) {
147                        Some(*offset)
148                    } else {
149                        // Handle numeric label references
150                        Self::resolve_numeric_label(label, idx, &numeric_labels)
151                    };
152
153                    if let Some(target_offset) = target_offset {
154                        let rel_offset = (target_offset as i64 - *offset as i64) / 8 - 1;
155                        inst.off = Some(Either::Right(rel_offset as i16));
156                    } else {
157                        errors.push(CompileError::UndefinedLabel {
158                            label: label.clone(),
159                            span: inst.span.clone(),
160                            custom_label: None,
161                        });
162                    }
163                } else if inst.opcode == Opcode::Call
164                    && let Some(Either::Left(label)) = &inst.imm
165                    && let Some(target_offset) = label_offset_map.get(label)
166                {
167                    let rel_offset = (*target_offset as i64 - *offset as i64) / 8 - 1;
168                    inst.imm = Some(Either::Right(Number::Int(rel_offset)));
169                }
170
171                if inst.needs_relocation() {
172                    let (reloc_type, label) = get_relocation_info(inst);
173                    relocations.add_rel_dyn(*offset, reloc_type, label.clone());
174                    if reloc_type == RelocationType::RSbfSyscall {
175                        dynamic_symbols.add_call_target(label.clone(), *offset);
176                    }
177                }
178                if inst.opcode == Opcode::Lddw
179                    && let Some(Either::Left(name)) = &inst.imm
180                {
181                    let label = name.clone();
182                    if let Some(target_offset) = label_offset_map.get(&label) {
183                        let ph_count = if program_is_static { 1 } else { 3 };
184                        let ph_offset = 64 + (ph_count as u64 * 56) as i64;
185                        let abs_offset = *target_offset as i64 + ph_offset;
186                        // Replace label with immediate value
187                        inst.imm = Some(Either::Right(Number::Addr(abs_offset)));
188                    } else {
189                        errors.push(CompileError::UndefinedLabel {
190                            label: name.clone(),
191                            span: inst.span.clone(),
192                            custom_label: None,
193                        });
194                    }
195                }
196            }
197        }
198
199        // Set entry point offset if an entry label was specified
200        if let Some(entry_label) = &self.entry_label
201            && let Some(offset) = label_offset_map.get(entry_label)
202        {
203            dynamic_symbols.add_entry_point(entry_label.clone(), *offset);
204        }
205
206        if !errors.is_empty() {
207            Err(errors)
208        } else {
209            Ok(ParseResult {
210                code_section: CodeSection::new(std::mem::take(&mut self.nodes), self.text_size),
211                data_section: DataSection::new(
212                    std::mem::take(&mut self.rodata_nodes),
213                    self.rodata_size,
214                ),
215                dynamic_symbols,
216                relocation_data: relocations,
217                prog_is_static: program_is_static,
218            })
219        }
220    }
221}
222
#[cfg(test)]
mod tests {
    use {super::*, crate::parser::Token};

    /// Builds an `exit` instruction node located at `offset`.
    fn exit_node(offset: u64) -> ASTNode {
        ASTNode::Instruction {
            instruction: Instruction {
                opcode: Opcode::Exit,
                dst: None,
                src: None,
                off: None,
                imm: None,
                span: 0..4,
            },
            offset,
        }
    }

    /// Label table shared by the numeric-label tests:
    /// label "1" at offset 16 (node 2) and label "2" at offset 32 (node 4).
    fn sample_labels() -> Vec<(String, u64, usize)> {
        vec![("1".to_string(), 16, 2), ("2".to_string(), 32, 4)]
    }

    // A fresh AST holds no nodes, no entry label, and zero sizes.
    #[test]
    fn test_ast_new() {
        let ast = AST::new();
        assert!(ast.nodes.is_empty());
        assert!(ast.rodata_nodes.is_empty());
        assert!(ast.entry_label.is_none());
        assert_eq!(ast.text_size, 0);
        assert_eq!(ast.rodata_size, 0);
    }

    // The setters store exactly the values they are given.
    #[test]
    fn test_ast_set_sizes() {
        let mut ast = AST::new();
        ast.set_text_size(100);
        ast.set_rodata_size(50);
        assert_eq!(ast.text_size, 100);
        assert_eq!(ast.rodata_size, 50);
    }

    // Lookup hits on a matching offset and misses otherwise.
    #[test]
    fn test_get_instruction_at_offset() {
        let mut ast = AST::new();
        ast.nodes.push(exit_node(0));

        let hit = ast.get_instruction_at_offset(0);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().opcode, Opcode::Exit);

        let miss = ast.get_instruction_at_offset(8);
        assert!(miss.is_none());
    }

    // A rodata entry is found by the offset it was stored with.
    #[test]
    fn test_get_rodata_at_offset() {
        let mut ast = AST::new();
        ast.rodata_nodes.push(ASTNode::ROData {
            rodata: ROData {
                name: "data".to_string(),
                args: vec![
                    Token::Directive("ascii".to_string(), 0..5),
                    Token::StringLiteral("test".to_string(), 6..12),
                ],
                span: 0..12,
            },
            offset: 0,
        });

        let hit = ast.get_rodata_at_offset(0);
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().name, "data");
    }

    // "Nf" resolves to a definition located after the current node.
    #[test]
    fn test_resolve_numeric_label_forward() {
        let labels = sample_labels();

        assert_eq!(AST::resolve_numeric_label("1f", 0, &labels), Some(16));
        assert_eq!(AST::resolve_numeric_label("2f", 3, &labels), Some(32));
    }

    // "Nb" resolves to a definition located before the current node.
    #[test]
    fn test_resolve_numeric_label_backward() {
        let labels = sample_labels();

        assert_eq!(AST::resolve_numeric_label("1b", 3, &labels), Some(16));
        assert_eq!(AST::resolve_numeric_label("2b", 5, &labels), Some(32));
    }

    // A lone `exit` builds successfully and is reported as static.
    #[test]
    fn test_build_program_simple() {
        let mut ast = AST::new();
        ast.nodes.push(exit_node(0));
        ast.set_text_size(8);
        ast.set_rodata_size(0);

        let built = ast.build_program();
        assert!(built.is_ok());
        assert!(built.unwrap().prog_is_static);
    }

    // A jump to a label that is never defined fails the build.
    #[test]
    fn test_build_program_undefined_label_error() {
        let mut ast = AST::new();
        ast.nodes.push(ASTNode::Instruction {
            instruction: Instruction {
                opcode: Opcode::Ja,
                dst: None,
                src: None,
                off: Some(Either::Left("undefined_label".to_string())),
                imm: None,
                span: 0..10,
            },
            offset: 0,
        });
        ast.set_text_size(8);

        assert!(ast.build_program().is_err());
    }
}