1use {
2 crate::{
3 CompileError,
4 astnode::{ASTNode, ROData},
5 dynsym::{DynamicSymbolMap, RelDynMap, RelocationType, get_relocation_info},
6 parser::ParseResult,
7 section::{CodeSection, DataSection},
8 },
9 either::Either,
10 sbpf_common::{inst_param::Number, instruction::Instruction, opcode::Opcode},
11 std::collections::HashMap,
12};
13
14#[derive(Default)]
15pub struct AST {
16 pub nodes: Vec<ASTNode>,
17 pub rodata_nodes: Vec<ASTNode>,
18
19 pub entry_label: Option<String>,
20 text_size: u64,
21 rodata_size: u64,
22}
23
24impl AST {
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn set_text_size(&mut self, text_size: u64) {
31 self.text_size = text_size;
32 }
33
34 pub fn set_rodata_size(&mut self, rodata_size: u64) {
36 self.rodata_size = rodata_size;
37 }
38
39 pub fn get_instruction_at_offset(&mut self, offset: u64) -> Option<&mut Instruction> {
41 self.nodes
42 .iter_mut()
43 .find(|node| match node {
44 ASTNode::Instruction {
45 instruction: _,
46 offset: inst_offset,
47 ..
48 } => offset == *inst_offset,
49 _ => false,
50 })
51 .map(|node| match node {
52 ASTNode::Instruction { instruction, .. } => instruction,
53 _ => panic!("Expected Instruction node"),
54 })
55 }
56
57 pub fn get_rodata_at_offset(&self, offset: u64) -> Option<&ROData> {
59 self.rodata_nodes
60 .iter()
61 .find(|node| match node {
62 ASTNode::ROData {
63 rodata: _,
64 offset: rodata_offset,
65 ..
66 } => offset == *rodata_offset,
67 _ => false,
68 })
69 .map(|node| match node {
70 ASTNode::ROData { rodata, .. } => rodata,
71 _ => panic!("Expected ROData node"),
72 })
73 }
74
/// Resolves a local numeric label reference such as `1f` or `2b`.
///
/// `label_ref` must be a non-empty run of ASCII digits followed by a
/// direction suffix: `f` selects a definition after the referencing node,
/// `b` selects one before it.
///
/// * `current_idx` - node index of the instruction holding the reference.
/// * `numeric_labels` - `(name, offset, node index)` for each label
///   definition. The slice is scanned front-to-back for `f` and back-to-front
///   for `b`, so it is expected to be ordered by node index for the match to
///   be the nearest definition.
///
/// Returns the offset of the matching definition, or `None` when `label_ref`
/// is not a numeric reference or no definition lies in the requested
/// direction.
fn resolve_numeric_label(
    label_ref: &str,
    current_idx: usize,
    numeric_labels: &[(String, u64, usize)],
) -> Option<u64> {
    let (label_num, forward) = if let Some(prefix) = label_ref.strip_suffix('f') {
        (prefix, true)
    } else if let Some(prefix) = label_ref.strip_suffix('b') {
        (prefix, false)
    } else {
        return None;
    };

    // Only digit-only names are local labels. Without this check a misspelt
    // reference like `endf` would silently bind to an ordinary label `end`
    // instead of being reported as undefined.
    if label_num.is_empty() || !label_num.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }

    let mut candidates = numeric_labels
        .iter()
        .filter(|(name, _, _)| name == label_num);

    let hit = if forward {
        candidates.find(|(_, _, node_idx)| *node_idx > current_idx)
    } else {
        candidates
            .rev()
            .find(|(_, _, node_idx)| *node_idx < current_idx)
    };

    hit.map(|(_, offset, _)| *offset)
}
104
105 pub fn build_program(&mut self) -> Result<ParseResult, Vec<CompileError>> {
107 let mut label_offset_map: HashMap<String, u64> = HashMap::new();
108 let mut numeric_labels: Vec<(String, u64, usize)> = Vec::new();
109
110 for (idx, node) in self.nodes.iter().enumerate() {
113 if let ASTNode::Label { label, offset } = node {
114 label_offset_map.insert(label.name.clone(), *offset);
115 numeric_labels.push((label.name.clone(), *offset, idx));
117 }
118 }
119
120 for node in &self.rodata_nodes {
121 if let ASTNode::ROData { rodata, offset } = node {
122 label_offset_map.insert(rodata.name.clone(), *offset + self.text_size);
123 }
124 }
125
126 let program_is_static = !self.nodes.iter().any(|node| matches!(node, ASTNode::Instruction { instruction: inst, .. } if inst.needs_relocation()));
130 let mut relocations = RelDynMap::new();
131 let mut dynamic_symbols = DynamicSymbolMap::new();
132
133 let mut errors = Vec::new();
134
135 for (idx, node) in self.nodes.iter_mut().enumerate() {
136 if let ASTNode::Instruction {
137 instruction: inst,
138 offset,
139 ..
140 } = node
141 {
142 if inst.is_jump()
144 && let Some(Either::Left(label)) = &inst.off
145 {
146 let target_offset = if let Some(offset) = label_offset_map.get(label) {
147 Some(*offset)
148 } else {
149 Self::resolve_numeric_label(label, idx, &numeric_labels)
151 };
152
153 if let Some(target_offset) = target_offset {
154 let rel_offset = (target_offset as i64 - *offset as i64) / 8 - 1;
155 inst.off = Some(Either::Right(rel_offset as i16));
156 } else {
157 errors.push(CompileError::UndefinedLabel {
158 label: label.clone(),
159 span: inst.span.clone(),
160 custom_label: None,
161 });
162 }
163 } else if inst.opcode == Opcode::Call
164 && let Some(Either::Left(label)) = &inst.imm
165 && let Some(target_offset) = label_offset_map.get(label)
166 {
167 let rel_offset = (*target_offset as i64 - *offset as i64) / 8 - 1;
168 inst.imm = Some(Either::Right(Number::Int(rel_offset)));
169 }
170
171 if inst.needs_relocation() {
172 let (reloc_type, label) = get_relocation_info(inst);
173 relocations.add_rel_dyn(*offset, reloc_type, label.clone());
174 if reloc_type == RelocationType::RSbfSyscall {
175 dynamic_symbols.add_call_target(label.clone(), *offset);
176 }
177 }
178 if inst.opcode == Opcode::Lddw
179 && let Some(Either::Left(name)) = &inst.imm
180 {
181 let label = name.clone();
182 if let Some(target_offset) = label_offset_map.get(&label) {
183 let ph_count = if program_is_static { 1 } else { 3 };
184 let ph_offset = 64 + (ph_count as u64 * 56) as i64;
185 let abs_offset = *target_offset as i64 + ph_offset;
186 inst.imm = Some(Either::Right(Number::Addr(abs_offset)));
188 } else {
189 errors.push(CompileError::UndefinedLabel {
190 label: name.clone(),
191 span: inst.span.clone(),
192 custom_label: None,
193 });
194 }
195 }
196 }
197 }
198
199 if let Some(entry_label) = &self.entry_label
201 && let Some(offset) = label_offset_map.get(entry_label)
202 {
203 dynamic_symbols.add_entry_point(entry_label.clone(), *offset);
204 }
205
206 if !errors.is_empty() {
207 Err(errors)
208 } else {
209 Ok(ParseResult {
210 code_section: CodeSection::new(std::mem::take(&mut self.nodes), self.text_size),
211 data_section: DataSection::new(
212 std::mem::take(&mut self.rodata_nodes),
213 self.rodata_size,
214 ),
215 dynamic_symbols,
216 relocation_data: relocations,
217 prog_is_static: program_is_static,
218 })
219 }
220 }
221}
222
#[cfg(test)]
mod tests {
    use {super::*, crate::parser::Token};

    /// Builds a bare `exit` instruction with no operands.
    fn exit_instruction() -> Instruction {
        Instruction {
            opcode: Opcode::Exit,
            dst: None,
            src: None,
            off: None,
            imm: None,
            span: 0..4,
        }
    }

    /// Builds an AST whose only code node is `instruction`, placed at offset 0.
    fn ast_with_instruction(instruction: Instruction) -> AST {
        let mut ast = AST::new();
        ast.nodes.push(ASTNode::Instruction {
            instruction,
            offset: 0,
        });
        ast
    }

    /// Label table shared by the numeric-label tests: `(name, offset, node index)`.
    fn sample_numeric_labels() -> Vec<(String, u64, usize)> {
        vec![("1".to_string(), 16, 2), ("2".to_string(), 32, 4)]
    }

    #[test]
    fn test_ast_new() {
        let ast = AST::new();
        assert!(ast.nodes.is_empty());
        assert!(ast.rodata_nodes.is_empty());
        assert!(ast.entry_label.is_none());
        assert_eq!(ast.text_size, 0);
        assert_eq!(ast.rodata_size, 0);
    }

    #[test]
    fn test_ast_set_sizes() {
        let mut ast = AST::new();
        ast.set_text_size(100);
        ast.set_rodata_size(50);
        assert_eq!(ast.text_size, 100);
        assert_eq!(ast.rodata_size, 50);
    }

    #[test]
    fn test_get_instruction_at_offset() {
        let mut ast = ast_with_instruction(exit_instruction());

        let found = ast
            .get_instruction_at_offset(0)
            .expect("an instruction was pushed at offset 0");
        assert_eq!(found.opcode, Opcode::Exit);

        // No node lives at offset 8.
        assert!(ast.get_instruction_at_offset(8).is_none());
    }

    #[test]
    fn test_get_rodata_at_offset() {
        let mut ast = AST::new();
        let rodata = ROData {
            name: "data".to_string(),
            args: vec![
                Token::Directive("ascii".to_string(), 0..5),
                Token::StringLiteral("test".to_string(), 6..12),
            ],
            span: 0..12,
        };
        ast.rodata_nodes.push(ASTNode::ROData { rodata, offset: 0 });

        let found = ast
            .get_rodata_at_offset(0)
            .expect("a rodata entry was pushed at offset 0");
        assert_eq!(found.name, "data");

        // No entry lives at offset 8.
        assert!(ast.get_rodata_at_offset(8).is_none());
    }

    #[test]
    fn test_resolve_numeric_label_forward() {
        let numeric_labels = sample_numeric_labels();

        assert_eq!(AST::resolve_numeric_label("1f", 0, &numeric_labels), Some(16));
        assert_eq!(AST::resolve_numeric_label("2f", 3, &numeric_labels), Some(32));

        // A definition at the referencing index itself is not "forward".
        assert_eq!(AST::resolve_numeric_label("1f", 2, &numeric_labels), None);
        // Unknown label number.
        assert_eq!(AST::resolve_numeric_label("3f", 0, &numeric_labels), None);
    }

    #[test]
    fn test_resolve_numeric_label_backward() {
        let numeric_labels = sample_numeric_labels();

        assert_eq!(AST::resolve_numeric_label("1b", 3, &numeric_labels), Some(16));
        assert_eq!(AST::resolve_numeric_label("2b", 5, &numeric_labels), Some(32));

        // A definition at the referencing index itself is not "backward".
        assert_eq!(AST::resolve_numeric_label("1b", 2, &numeric_labels), None);
        // Nothing precedes index 0.
        assert_eq!(AST::resolve_numeric_label("2b", 0, &numeric_labels), None);
    }

    #[test]
    fn test_resolve_numeric_label_requires_direction_suffix() {
        let numeric_labels = sample_numeric_labels();

        assert_eq!(AST::resolve_numeric_label("1", 0, &numeric_labels), None);
        assert_eq!(AST::resolve_numeric_label("", 0, &numeric_labels), None);
    }

    #[test]
    fn test_build_program_simple() {
        let mut ast = ast_with_instruction(exit_instruction());
        ast.set_text_size(8);
        ast.set_rodata_size(0);

        let result = ast.build_program();
        assert!(result.is_ok());
        let parse_result = result.unwrap();
        assert!(parse_result.prog_is_static);

        // A successful build moves the nodes out of the AST.
        assert!(ast.nodes.is_empty());
        assert!(ast.rodata_nodes.is_empty());
    }

    #[test]
    fn test_build_program_undefined_label_error() {
        let jump = Instruction {
            opcode: Opcode::Ja,
            dst: None,
            src: None,
            off: Some(Either::Left("undefined_label".to_string())),
            imm: None,
            span: 0..10,
        };
        let mut ast = ast_with_instruction(jump);
        ast.set_text_size(8);

        let Err(errors) = ast.build_program() else {
            panic!("expected build_program to fail on an undefined jump target");
        };
        assert_eq!(errors.len(), 1);
        assert!(matches!(
            &errors[0],
            CompileError::UndefinedLabel { label, .. } if label == "undefined_label"
        ));
    }
}