Skip to main content

sbpf_assembler/
ast.rs

1use {
2    crate::{
3        CompileError, SbpfArch,
4        astnode::{ASTNode, ROData},
5        dynsym::{DynamicSymbolMap, RelDynMap, RelocationType},
6        parser::ParseResult,
7        section::{CodeSection, DataSection},
8    },
9    either::Either,
10    sbpf_common::{
11        inst_param::{Number, Register},
12        instruction::Instruction,
13        opcode::Opcode,
14    },
15    std::collections::HashMap,
16    syscall_map::murmur3_32,
17};
18
19#[derive(Default, Debug)]
20pub struct AST {
21    pub nodes: Vec<ASTNode>,
22    pub rodata_nodes: Vec<ASTNode>,
23
24    text_size: u64,
25    rodata_size: u64,
26}
27
28impl AST {
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    //
34    pub fn set_text_size(&mut self, text_size: u64) {
35        self.text_size = text_size;
36    }
37
38    //
39    pub fn set_rodata_size(&mut self, rodata_size: u64) {
40        self.rodata_size = rodata_size;
41    }
42
43    //
44    pub fn get_instruction_at_offset(&mut self, offset: u64) -> Option<&mut Instruction> {
45        self.nodes
46            .iter_mut()
47            .find(|node| match node {
48                ASTNode::Instruction {
49                    instruction: _,
50                    offset: inst_offset,
51                    ..
52                } => offset == *inst_offset,
53                _ => false,
54            })
55            .map(|node| match node {
56                ASTNode::Instruction { instruction, .. } => instruction,
57                _ => panic!("Expected Instruction node"),
58            })
59    }
60
61    //
62    pub fn get_rodata_at_offset(&self, offset: u64) -> Option<&ROData> {
63        self.rodata_nodes
64            .iter()
65            .find(|node| match node {
66                ASTNode::ROData {
67                    rodata: _,
68                    offset: rodata_offset,
69                    ..
70                } => offset == *rodata_offset,
71                _ => false,
72            })
73            .map(|node| match node {
74                ASTNode::ROData { rodata, .. } => rodata,
75                _ => panic!("Expected ROData node"),
76            })
77    }
78
79    /// Resolve numeric label references (like "2f" or "1b")
80    fn resolve_numeric_label(
81        label_ref: &str,
82        current_idx: usize,
83        numeric_labels: &[(String, u64, usize)],
84    ) -> Option<u64> {
85        if let Some(direction) = label_ref.chars().last()
86            && (direction == 'f' || direction == 'b')
87        {
88            let label_num = &label_ref[..label_ref.len() - 1];
89
90            if direction == 'f' {
91                // search forward from current position
92                for (name, offset, node_idx) in numeric_labels {
93                    if name == label_num && *node_idx > current_idx {
94                        return Some(*offset);
95                    }
96                }
97            } else {
98                // search backward from current position
99                for (name, offset, node_idx) in numeric_labels.iter().rev() {
100                    if name == label_num && *node_idx < current_idx {
101                        return Some(*offset);
102                    }
103                }
104            }
105        }
106        None
107    }
108
109    pub fn build_program(&mut self, arch: SbpfArch) -> Result<ParseResult, Vec<CompileError>> {
110        let mut label_offset_map: HashMap<String, u64> = HashMap::new();
111        let mut numeric_labels: Vec<(String, u64, usize)> = Vec::new();
112
113        // iterate through text labels and rodata labels and find the pair
114        // of each label and offset
115        for (idx, node) in self.nodes.iter().enumerate() {
116            if let ASTNode::Label { label, offset } = node {
117                label_offset_map.insert(label.name.clone(), *offset);
118                // Also track numeric labels separately for forward/backward resolution
119                numeric_labels.push((label.name.clone(), *offset, idx));
120            }
121        }
122
123        for node in &self.rodata_nodes {
124            if let ASTNode::ROData { rodata, offset } = node {
125                label_offset_map.insert(rodata.name.clone(), *offset + self.text_size);
126            }
127        }
128
129        // 1. resolve labels in the intruction nodes for lddw and jump
130        // 2. find relocation information
131
132        let mut relocations = RelDynMap::new();
133        let mut dynamic_symbols = DynamicSymbolMap::new();
134
135        // Resolve both static and dynamic syscalls.
136        for node in self.nodes.iter_mut() {
137            if let ASTNode::Instruction {
138                instruction: inst,
139                offset,
140            } = node
141                && inst.is_syscall()
142                && let Some(Either::Left(syscall_name)) = &inst.imm
143            {
144                let syscall_name = syscall_name.clone();
145                if arch.is_v3() {
146                    // Static syscall: src = 0, imm = hash
147                    inst.src = Some(Register { n: 0 });
148                    inst.imm = Some(Either::Right(Number::Int(murmur3_32(&syscall_name) as i64)));
149                } else {
150                    // Dynamic syscall: src = 1, imm = -1
151                    inst.src = Some(Register { n: 1 });
152                    inst.imm = Some(Either::Right(Number::Int(-1)));
153
154                    // Add relocation for dynamic syscall
155                    relocations.add_rel_dyn(
156                        *offset,
157                        RelocationType::RSbfSyscall,
158                        syscall_name.clone(),
159                    );
160                    dynamic_symbols.add_call_target(syscall_name.clone(), *offset);
161                }
162            }
163        }
164
165        let program_is_static = !self.nodes.iter().any(|node| {
166            matches!(node, ASTNode::Instruction { instruction: inst, .. }
167                if inst.needs_relocation())
168        });
169
170        let mut errors = Vec::new();
171
172        for (idx, node) in self.nodes.iter_mut().enumerate() {
173            if let ASTNode::Instruction {
174                instruction: inst,
175                offset,
176                ..
177            } = node
178            {
179                // For jump/call instructions, replace label with relative offsets
180                if inst.is_jump()
181                    && let Some(Either::Left(label)) = &inst.off
182                {
183                    let target_offset = if let Some(offset) = label_offset_map.get(label) {
184                        Some(*offset)
185                    } else {
186                        // Handle numeric label references
187                        Self::resolve_numeric_label(label, idx, &numeric_labels)
188                    };
189
190                    if let Some(target_offset) = target_offset {
191                        let rel_offset = (target_offset as i64 - *offset as i64) / 8 - 1;
192                        inst.off = Some(Either::Right(rel_offset as i16));
193                    } else {
194                        errors.push(CompileError::UndefinedLabel {
195                            label: label.clone(),
196                            span: inst.span.clone(),
197                            custom_label: None,
198                        });
199                    }
200                } else if inst.opcode == Opcode::Call
201                    && let Some(Either::Left(label)) = &inst.imm
202                    && let Some(target_offset) = label_offset_map.get(label)
203                {
204                    let rel_offset = (*target_offset as i64 - *offset as i64) / 8 - 1;
205                    inst.src = Some(Register { n: 1 });
206                    inst.imm = Some(Either::Right(Number::Int(rel_offset)));
207                }
208
209                if inst.opcode == Opcode::Lddw
210                    && let Some(Either::Left(name)) = &inst.imm
211                {
212                    let label = name.clone();
213                    // Add relocation for lddw (only for v0)
214                    if !arch.is_v3() {
215                        relocations.add_rel_dyn(
216                            *offset,
217                            RelocationType::RSbf64Relative,
218                            label.clone(),
219                        );
220                    }
221
222                    if let Some(target_offset) = label_offset_map.get(&label) {
223                        let abs_offset = if arch.is_v3() {
224                            (*target_offset - self.text_size) as i64
225                        } else {
226                            let ph_count = if program_is_static { 1 } else { 3 };
227                            let ph_offset = 64 + (ph_count as u64 * 56) as i64;
228                            *target_offset as i64 + ph_offset
229                        };
230                        // Replace label with immediate value
231                        inst.imm = Some(Either::Right(Number::Addr(abs_offset)));
232                    } else {
233                        errors.push(CompileError::UndefinedLabel {
234                            label: name.clone(),
235                            span: inst.span.clone(),
236                            custom_label: None,
237                        });
238                    }
239                }
240            }
241        }
242
243        // Set entry point offset if a GlobalDecl was specified
244        let entry_label = self.nodes.iter().find_map(|node| {
245            if let ASTNode::GlobalDecl { global_decl } = node {
246                Some(global_decl.entry_label.clone())
247            } else {
248                None
249            }
250        });
251        if let Some(entry_label) = entry_label
252            && let Some(offset) = label_offset_map.get(&entry_label)
253        {
254            dynamic_symbols.add_entry_point(entry_label, *offset);
255        }
256
257        if !errors.is_empty() {
258            Err(errors)
259        } else {
260            Ok(ParseResult {
261                code_section: CodeSection::new(std::mem::take(&mut self.nodes), self.text_size),
262                data_section: DataSection::new(
263                    std::mem::take(&mut self.rodata_nodes),
264                    self.rodata_size,
265                ),
266                dynamic_symbols,
267                relocation_data: relocations,
268                prog_is_static: program_is_static,
269                arch,
270                debug_sections: Vec::default(),
271            })
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use {super::*, crate::parser::Token};

    /// Appends `instruction` to the text nodes of `ast` at `offset`.
    fn push_instruction(ast: &mut AST, instruction: Instruction, offset: u64) {
        ast.nodes.push(ASTNode::Instruction {
            instruction,
            offset,
        });
    }

    /// AST holding a single `exit` instruction at offset 0.
    fn single_exit_ast() -> AST {
        let mut ast = AST::new();
        push_instruction(
            &mut ast,
            Instruction {
                opcode: Opcode::Exit,
                dst: None,
                src: None,
                off: None,
                imm: None,
                span: 0..4,
            },
            0,
        );
        ast
    }

    /// AST holding `call sol_log_` at offset 0 followed by `exit` at offset 8,
    /// with a text size of 16 and no rodata.
    fn syscall_then_exit_ast() -> AST {
        let mut ast = AST::new();
        push_instruction(
            &mut ast,
            Instruction {
                opcode: Opcode::Call,
                dst: None,
                src: None,
                off: None,
                imm: Some(Either::Left("sol_log_".to_string())),
                span: 0..8,
            },
            0,
        );
        push_instruction(
            &mut ast,
            Instruction {
                opcode: Opcode::Exit,
                dst: None,
                src: None,
                off: None,
                imm: None,
                span: 8..16,
            },
            8,
        );
        ast.set_text_size(16);
        ast.set_rodata_size(0);
        ast
    }

    /// Label "1" at offset 16 (node 2) and label "2" at offset 32 (node 4).
    fn sample_numeric_labels() -> Vec<(String, u64, usize)> {
        vec![("1".to_string(), 16, 2), ("2".to_string(), 32, 4)]
    }

    #[test]
    fn test_ast_new() {
        let ast = AST::new();
        assert!(ast.nodes.is_empty());
        assert!(ast.rodata_nodes.is_empty());
        assert_eq!(ast.text_size, 0);
        assert_eq!(ast.rodata_size, 0);
    }

    #[test]
    fn test_ast_set_sizes() {
        let mut ast = AST::new();
        ast.set_text_size(100);
        ast.set_rodata_size(50);
        assert_eq!(ast.text_size, 100);
        assert_eq!(ast.rodata_size, 50);
    }

    #[test]
    fn test_get_instruction_at_offset() {
        let mut ast = single_exit_ast();

        let hit = ast
            .get_instruction_at_offset(0)
            .expect("an instruction was pushed at offset 0");
        assert_eq!(hit.opcode, Opcode::Exit);

        assert!(ast.get_instruction_at_offset(8).is_none());
    }

    #[test]
    fn test_get_rodata_at_offset() {
        let mut ast = AST::new();
        ast.rodata_nodes.push(ASTNode::ROData {
            rodata: ROData {
                name: "data".to_string(),
                args: vec![
                    Token::Directive("ascii".to_string(), 0..5),
                    Token::StringLiteral("test".to_string(), 6..12),
                ],
                span: 0..12,
            },
            offset: 0,
        });

        let hit = ast
            .get_rodata_at_offset(0)
            .expect("a rodata entry was pushed at offset 0");
        assert_eq!(hit.name, "data");
    }

    #[test]
    fn test_resolve_numeric_label_forward() {
        let labels = sample_numeric_labels();

        assert_eq!(AST::resolve_numeric_label("1f", 0, &labels), Some(16));
        assert_eq!(AST::resolve_numeric_label("2f", 3, &labels), Some(32));
    }

    #[test]
    fn test_resolve_numeric_label_backward() {
        let labels = sample_numeric_labels();

        assert_eq!(AST::resolve_numeric_label("1b", 3, &labels), Some(16));
        assert_eq!(AST::resolve_numeric_label("2b", 5, &labels), Some(32));
    }

    #[test]
    fn test_build_program_simple() {
        let mut ast = single_exit_ast();
        ast.set_text_size(8);
        ast.set_rodata_size(0);

        let parse_result = ast
            .build_program(SbpfArch::V0)
            .expect("a lone exit instruction should build");
        assert!(parse_result.prog_is_static);
    }

    #[test]
    fn test_build_program_undefined_label_error() {
        let mut ast = AST::new();

        // Jump to a label that is never defined.
        push_instruction(
            &mut ast,
            Instruction {
                opcode: Opcode::Ja,
                dst: None,
                src: None,
                off: Some(Either::Left("undefined_label".to_string())),
                imm: None,
                span: 0..10,
            },
            0,
        );
        ast.set_text_size(8);

        assert!(ast.build_program(SbpfArch::V0).is_err());
    }

    #[test]
    fn test_build_program_static_syscalls_no_relocation() {
        let mut ast = syscall_then_exit_ast();

        let parse_result = ast
            .build_program(SbpfArch::V3)
            .expect("syscall program should build for V3");

        assert!(parse_result.prog_is_static);
        assert!(parse_result.relocation_data.get_rel_dyns().is_empty());
    }

    #[test]
    fn test_build_program_dynamic_syscalls_with_relocation() {
        let mut ast = syscall_then_exit_ast();

        let parse_result = ast
            .build_program(SbpfArch::V0)
            .expect("syscall program should build for V0");

        assert!(!parse_result.prog_is_static);
        assert!(!parse_result.relocation_data.get_rel_dyns().is_empty());
    }
}