Skip to main content

sbpf_assembler/
program.rs

1use {
2    crate::{
3        debug::{self, DebugData, reuse_debug_sections},
4        dynsym::{DynamicSymbol, RelDyn, RelocationType},
5        header::{ElfHeader, ProgramHeader},
6        parser::ParseResult,
7        section::{
8            DebugSection, DynStrSection, DynSymSection, DynamicSection, NullSection, RelDynSection,
9            Section, SectionType, ShStrTabSection,
10        },
11    },
12    std::{fs::File, io::Write, path::Path},
13};
14
15#[derive(Debug)]
16pub struct Program {
17    pub elf_header: ElfHeader,
18    pub program_headers: Option<Vec<ProgramHeader>>,
19    pub sections: Vec<SectionType>,
20}
21
22impl Program {
23    pub fn from_parse_result(
24        ParseResult {
25            code_section,
26            data_section,
27            dynamic_symbols,
28            relocation_data,
29            prog_is_static,
30            arch,
31            debug_sections,
32        }: ParseResult,
33        debug_data: Option<DebugData>,
34    ) -> Self {
35        let mut elf_header = ElfHeader::new();
36        let mut program_headers = None;
37
38        let bytecode_size = code_section.size();
39        let rodata_size = data_section.size();
40
41        let has_rodata = rodata_size > 0;
42        let ph_count = if arch.is_v3() {
43            if has_rodata { 2 } else { 1 }
44        } else if prog_is_static {
45            0
46        } else {
47            3
48        };
49
50        elf_header.e_flags = arch.e_flags();
51        elf_header.e_phnum = ph_count;
52
53        // save read + execute size for program header before
54        // ownership of code/data sections is transferred
55        let text_size = bytecode_size + rodata_size;
56
57        // Calculate base offset after ELF header and program headers
58        let base_offset = 64 + (ph_count as u64 * 56); // 64 bytes ELF header, 56 bytes per program header
59        let mut current_offset = base_offset;
60
61        let text_offset = if arch.is_v3() && has_rodata {
62            rodata_size + base_offset
63        } else {
64            base_offset
65        };
66
67        // Get the entry point offset from dynamic_symbols if available
68        let entry_point_offset = dynamic_symbols
69            .get_entry_points()
70            .first()
71            .map(|(_, offset)| *offset)
72            .unwrap_or(0);
73
74        elf_header.e_entry = if arch.is_v3() {
75            ProgramHeader::V3_BYTECODE_VADDR + entry_point_offset
76        } else {
77            text_offset + entry_point_offset
78        };
79
80        // Create a vector of sections
81        let mut sections = Vec::new();
82        sections.push(SectionType::Default(NullSection::new()));
83
84        let mut section_names = Vec::new();
85
86        // Add section_names in fixed order for shstrtab
87        section_names.push(".text".to_string());
88        if has_rodata {
89            section_names.push(".rodata".to_string());
90        }
91
92        if arch.is_v3() && has_rodata {
93            // Data section
94            let mut rodata_section = SectionType::Data(data_section);
95            rodata_section.set_offset(current_offset);
96            current_offset += rodata_section.size();
97            sections.push(rodata_section);
98
99            // Code section
100            let mut text_section = SectionType::Code(code_section);
101            text_section.set_offset(current_offset);
102            current_offset += text_section.size();
103            sections.push(text_section);
104        } else {
105            // Code section
106            let mut text_section = SectionType::Code(code_section);
107            text_section.set_offset(current_offset);
108            current_offset += text_section.size();
109            sections.push(text_section);
110
111            // Data section (if any)
112            if has_rodata {
113                let mut rodata_section = SectionType::Data(data_section);
114                rodata_section.set_offset(current_offset);
115                current_offset += rodata_section.size();
116                sections.push(rodata_section);
117            }
118        }
119
120        let padding = (8 - (current_offset % 8)) % 8;
121        current_offset += padding;
122
123        if arch.is_v3() {
124            // Generate debug sections
125            let debug_sections = Self::generate_debug_sections(
126                debug_sections,
127                &debug_data,
128                text_offset,
129                &mut section_names,
130                &mut current_offset,
131            );
132
133            for debug_section in debug_sections {
134                sections.push(debug_section);
135            }
136
137            let mut shstrtab_section = ShStrTabSection::new(
138                section_names
139                    .iter()
140                    .map(|name| name.len() + 1)
141                    .sum::<usize>() as u32,
142                section_names,
143            );
144            shstrtab_section.set_offset(current_offset);
145            current_offset += shstrtab_section.size();
146            sections.push(SectionType::ShStrTab(shstrtab_section));
147
148            if has_rodata {
149                // 2 headers: rodata (PF_R) then bytecode (PF_X)
150                let rodata_offset = base_offset;
151                let bytecode_offset = base_offset + rodata_size;
152                program_headers = Some(vec![
153                    ProgramHeader::new_load(rodata_offset, rodata_size, false, arch),
154                    ProgramHeader::new_load(bytecode_offset, bytecode_size, true, arch),
155                ]);
156            } else {
157                // 1 header: bytecode only (PF_X)
158                program_headers = Some(vec![ProgramHeader::new_load(
159                    base_offset,
160                    bytecode_size,
161                    true,
162                    arch,
163                )]);
164            }
165        } else if !prog_is_static {
166            let mut symbol_names = Vec::new();
167            let mut dyn_syms = Vec::new();
168            let mut dyn_str_offset = 1;
169
170            dyn_syms.push(DynamicSymbol::new(0, 0, 0, 0, 0, 0));
171
172            // all symbols handled right now are all global symbols
173            for (name, _) in dynamic_symbols.get_entry_points() {
174                symbol_names.push(name.clone());
175                dyn_syms.push(DynamicSymbol::new(
176                    dyn_str_offset as u32,
177                    0x10,
178                    0,
179                    1,
180                    elf_header.e_entry,
181                    0,
182                ));
183                dyn_str_offset += name.len() + 1;
184            }
185
186            for (name, _) in dynamic_symbols.get_call_targets() {
187                symbol_names.push(name.clone());
188                dyn_syms.push(DynamicSymbol::new(dyn_str_offset as u32, 0x10, 0, 0, 0, 0));
189                dyn_str_offset += name.len() + 1;
190            }
191
192            let mut rel_count = 0;
193            let mut rel_dyns = Vec::new();
194            for (offset, rel_type, name) in relocation_data.get_rel_dyns() {
195                if rel_type == RelocationType::RSbfSyscall {
196                    if let Some(index) = symbol_names.iter().position(|n| *n == name) {
197                        rel_dyns.push(RelDyn::new(
198                            offset + text_offset,
199                            rel_type as u64,
200                            index as u64 + 1,
201                        ));
202                    } else {
203                        panic!("Symbol {} not found in symbol_names", name);
204                    }
205                } else if rel_type == RelocationType::RSbf64Relative {
206                    rel_count += 1;
207                    rel_dyns.push(RelDyn::new(offset + text_offset, rel_type as u64, 0));
208                }
209            }
210            // create four dynamic related sections
211            let mut dynamic_section = SectionType::Dynamic(DynamicSection::new(
212                (section_names
213                    .iter()
214                    .map(|name| name.len() + 1)
215                    .sum::<usize>()
216                    + 1) as u32,
217            ));
218            section_names.push(dynamic_section.name().to_string());
219
220            let mut dynsym_section = SectionType::DynSym(DynSymSection::new(
221                (section_names
222                    .iter()
223                    .map(|name| name.len() + 1)
224                    .sum::<usize>()
225                    + 1) as u32,
226                dyn_syms,
227            ));
228            section_names.push(dynsym_section.name().to_string());
229
230            let mut dynstr_section = SectionType::DynStr(DynStrSection::new(
231                (section_names
232                    .iter()
233                    .map(|name| name.len() + 1)
234                    .sum::<usize>()
235                    + 1) as u32,
236                symbol_names,
237            ));
238            section_names.push(dynstr_section.name().to_string());
239
240            let mut rel_dyn_section = SectionType::RelDyn(RelDynSection::new(
241                (section_names
242                    .iter()
243                    .map(|name| name.len() + 1)
244                    .sum::<usize>()
245                    + 1) as u32,
246                rel_dyns,
247            ));
248            section_names.push(rel_dyn_section.name().to_string());
249
250            dynamic_section.set_offset(current_offset);
251            if let SectionType::Dynamic(ref mut dynamic_section) = dynamic_section {
252                // link to .dynstr
253                dynamic_section.set_link(
254                    section_names
255                        .iter()
256                        .position(|name| name == ".dynstr")
257                        .expect("missing .dynstr section") as u32
258                        + 1,
259                );
260                dynamic_section.set_rel_count(rel_count);
261            }
262            current_offset += dynamic_section.size();
263
264            dynsym_section.set_offset(current_offset);
265            if let SectionType::DynSym(ref mut dynsym_section) = dynsym_section {
266                // link to .dynstr
267                dynsym_section.set_link(
268                    section_names
269                        .iter()
270                        .position(|name| name == ".dynstr")
271                        .expect("missing .dynstr section") as u32
272                        + 1,
273                );
274            }
275            current_offset += dynsym_section.size();
276
277            dynstr_section.set_offset(current_offset);
278            current_offset += dynstr_section.size();
279
280            rel_dyn_section.set_offset(current_offset);
281            if let SectionType::RelDyn(ref mut rel_dyn_section) = rel_dyn_section {
282                // link to .dynsym
283                rel_dyn_section.set_link(
284                    section_names
285                        .iter()
286                        .position(|name| name == ".dynsym")
287                        .expect("missing .dynsym section") as u32
288                        + 1,
289                );
290            }
291            current_offset += rel_dyn_section.size();
292
293            if let SectionType::Dynamic(ref mut dynamic_section) = dynamic_section {
294                dynamic_section.set_rel_offset(rel_dyn_section.offset());
295                dynamic_section.set_rel_size(rel_dyn_section.size());
296                dynamic_section.set_dynsym_offset(dynsym_section.offset());
297                dynamic_section.set_dynstr_offset(dynstr_section.offset());
298                dynamic_section.set_dynstr_size(dynstr_section.size());
299            }
300
301            // Generate debug sections
302            let debug_sections = Self::generate_debug_sections(
303                debug_sections,
304                &debug_data,
305                text_offset,
306                &mut section_names,
307                &mut current_offset,
308            );
309
310            let mut shstrtab_section = SectionType::ShStrTab(ShStrTabSection::new(
311                (section_names
312                    .iter()
313                    .map(|name| name.len() + 1)
314                    .sum::<usize>()
315                    + 1) as u32,
316                section_names,
317            ));
318            shstrtab_section.set_offset(current_offset);
319            current_offset += shstrtab_section.size();
320
321            program_headers = Some(vec![
322                ProgramHeader::new_load(
323                    text_offset,
324                    text_size,
325                    true, // executable
326                    arch,
327                ),
328                ProgramHeader::new_load(
329                    dynsym_section.offset(),
330                    dynsym_section.size() + dynstr_section.size() + rel_dyn_section.size(),
331                    false,
332                    arch,
333                ),
334                ProgramHeader::new_dynamic(dynamic_section.offset(), dynamic_section.size()),
335            ]);
336
337            sections.push(dynamic_section);
338            sections.push(dynsym_section);
339            sections.push(dynstr_section);
340            sections.push(rel_dyn_section);
341
342            for debug_section in debug_sections {
343                sections.push(debug_section);
344            }
345
346            sections.push(shstrtab_section);
347        } else {
348            // Create a vector of section names
349            let mut section_names = Vec::new();
350            for section in &sections {
351                section_names.push(section.name().to_string());
352            }
353
354            // Generate debug sections
355            let debug_sections = Self::generate_debug_sections(
356                debug_sections,
357                &debug_data,
358                text_offset,
359                &mut section_names,
360                &mut current_offset,
361            );
362
363            for debug_section in debug_sections {
364                sections.push(debug_section);
365            }
366
367            let mut shstrtab_section = ShStrTabSection::new(
368                section_names
369                    .iter()
370                    .map(|name| name.len() + 1)
371                    .sum::<usize>() as u32,
372                section_names,
373            );
374            shstrtab_section.set_offset(current_offset);
375            current_offset += shstrtab_section.size();
376            sections.push(SectionType::ShStrTab(shstrtab_section));
377        }
378
379        // Update section header offset in ELF header
380        let padding = (8 - (current_offset % 8)) % 8;
381        elf_header.e_shoff = current_offset + padding;
382        elf_header.e_shnum = sections.len() as u16;
383        elf_header.e_shstrndx = sections.len() as u16 - 1;
384
385        Self {
386            elf_header,
387            program_headers,
388            sections,
389        }
390    }
391
392    pub fn emit_bytecode(&self) -> Vec<u8> {
393        let mut bytes = Vec::new();
394
395        // Emit ELF Header bytes
396        bytes.extend(self.elf_header.bytecode());
397
398        // Emit program headers
399        if let Some(program_headers) = &self.program_headers {
400            for ph in program_headers {
401                bytes.extend(ph.bytecode());
402            }
403        }
404
405        // Emit sections
406        for section in &self.sections {
407            bytes.extend(section.bytecode());
408        }
409
410        // Emit section headers
411        for section in &self.sections {
412            bytes.extend(section.section_header_bytecode());
413        }
414
415        bytes
416    }
417
418    fn generate_debug_sections(
419        parsed_debug_sections: Vec<DebugSection>,
420        debug_data: &Option<DebugData>,
421        text_offset: u64,
422        section_names: &mut Vec<String>,
423        current_offset: &mut u64,
424    ) -> Vec<SectionType> {
425        if let Some(data) = debug_data {
426            debug::generate_debug_sections(data, text_offset, section_names, current_offset)
427                .into_iter()
428                .enumerate()
429                .map(|(i, s)| match i {
430                    0 => SectionType::DebugAbbrev(s),
431                    1 => SectionType::DebugInfo(s),
432                    2 => SectionType::DebugLine(s),
433                    3 => SectionType::DebugLineStr(s),
434                    _ => unreachable!(),
435                })
436                .collect()
437        } else {
438            reuse_debug_sections(parsed_debug_sections, section_names, current_offset)
439        }
440    }
441
442    pub fn has_rodata(&self) -> bool {
443        self.sections.iter().any(|s| s.name() == ".rodata")
444    }
445
446    pub fn parse_rodata(&self) -> Vec<(String, usize, String)> {
447        let rodata = self
448            .sections
449            .iter()
450            .find(|s| s.name() == ".rodata")
451            .unwrap();
452        if let SectionType::Data(data_section) = rodata {
453            data_section.rodata()
454        } else {
455            panic!("ROData section not found");
456        }
457    }
458
459    pub fn save_to_file(&self, input_path: &str) -> std::io::Result<()> {
460        // Get the file stem (name without extension) from input path
461        let path = Path::new(input_path);
462        let file_stem = path
463            .file_stem()
464            .and_then(|s| s.to_str())
465            .unwrap_or("output");
466
467        // Create the output file name with .so extension
468        let output_path = format!("{}.so", file_stem);
469
470        // Get the bytecode
471        let bytes = self.emit_bytecode();
472
473        // Write bytes to file
474        let mut file = File::create(output_path)?;
475        file.write_all(&bytes)?;
476
477        Ok(())
478    }
479}
480
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{SbpfArch, parser::parse},
    };

    /// A program consisting of a single `exit` instruction.
    const EXIT_ONLY: &str = "exit";

    /// A program with one `.rodata` string and a global `entrypoint`.
    const WITH_RODATA: &str = r#"
.rodata
msg: .ascii "test"
.text
.globl entrypoint
entrypoint:
    exit
        "#;

    /// A program with a global `entrypoint` and no data.
    const ENTRY_ONLY: &str = r#"
.globl entrypoint
entrypoint:
    exit
        "#;

    /// A program whose entrypoint calls `sol_log_64_`.
    const WITH_CALL: &str = r#"
.globl entrypoint
entrypoint:
    call sol_log_64_
    exit
        "#;

    /// Parses `source` for `arch` and lays it out without debug data.
    fn assemble(source: &str, arch: SbpfArch) -> Program {
        let parsed = parse(source, arch).unwrap();
        Program::from_parse_result(parsed, None)
    }

    /// Names of every section of `program`, in section order.
    fn section_names_of(program: &Program) -> Vec<&str> {
        program.sections.iter().map(|section| section.name()).collect()
    }

    #[test]
    fn test_program_from_simple_source() {
        let program = assemble(EXIT_ONLY, SbpfArch::V0);

        // At minimum the null section and .text.
        assert!(!program.sections.is_empty());
        assert!(program.sections.len() >= 2);
    }

    #[test]
    fn test_program_without_rodata() {
        let program = assemble(EXIT_ONLY, SbpfArch::V0);
        assert!(!program.has_rodata());
    }

    #[test]
    fn test_program_emit_bytecode() {
        let bytes = assemble(EXIT_ONLY, SbpfArch::V0).emit_bytecode();

        assert!(!bytes.is_empty());
        // ELF magic number.
        assert_eq!(&bytes[..4], b"\x7fELF");
    }

    #[test]
    fn test_program_static_no_program_headers() {
        let mut parsed = parse(EXIT_ONLY, SbpfArch::V0).unwrap();
        // Force the static layout.
        parsed.prog_is_static = true;

        let program = Program::from_parse_result(parsed, None);
        assert!(program.program_headers.is_none());
        assert_eq!(program.elf_header.e_phnum, 0);
    }

    #[test]
    fn test_program_sections_ordering() {
        let program = assemble(EXIT_ONLY, SbpfArch::V0);
        let names = section_names_of(&program);

        assert_eq!(names[0], "", "first section must be the null section");
        assert_eq!(names[1], ".text");
    }

    #[test]
    fn test_program_sections_debug() {
        let parsed = parse(EXIT_ONLY, SbpfArch::V0).unwrap();
        let debug_data = DebugData {
            filename: "test.s".to_string(),
            directory: "/test".to_string(),
            lines: vec![],
            labels: vec![],
            code_start: 0,
            code_end: 8,
        };
        let program = Program::from_parse_result(parsed, Some(debug_data));

        let debug_names: Vec<&str> = section_names_of(&program)
            .into_iter()
            .filter(|name| name.starts_with(".debug_"))
            .collect();

        for expected in [".debug_abbrev", ".debug_info", ".debug_line", ".debug_line_str"] {
            assert!(debug_names.contains(&expected), "missing {}", expected);
        }
    }

    #[test]
    fn test_v3_e_flags() {
        let program = assemble(EXIT_ONLY, SbpfArch::V3);
        assert_eq!(program.elf_header.e_flags, 3);
    }

    #[test]
    fn test_v3_no_rodata_one_header() {
        let program = assemble(EXIT_ONLY, SbpfArch::V3);
        let headers = program
            .program_headers
            .as_deref()
            .expect("v3 programs always have program headers");

        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].p_flags, ProgramHeader::PF_X);
        assert_eq!(headers[0].p_vaddr, ProgramHeader::V3_BYTECODE_VADDR);
    }

    #[test]
    fn test_v3_with_rodata_two_headers() {
        let program = assemble(WITH_RODATA, SbpfArch::V3);
        let headers = program
            .program_headers
            .as_deref()
            .expect("v3 programs always have program headers");

        assert_eq!(headers.len(), 2);
        let (rodata, bytecode) = (&headers[0], &headers[1]);
        // Read-only data segment at the rodata base address.
        assert_eq!(rodata.p_flags, ProgramHeader::PF_R);
        assert_eq!(rodata.p_vaddr, ProgramHeader::V3_RODATA_VADDR);
        // Executable bytecode segment at the bytecode base address.
        assert_eq!(bytecode.p_flags, ProgramHeader::PF_X);
        assert_eq!(bytecode.p_vaddr, ProgramHeader::V3_BYTECODE_VADDR);
    }

    #[test]
    fn test_v3_e_entry() {
        let program = assemble(ENTRY_ONLY, SbpfArch::V3);
        // v3: the entry point must be at or above V3_BYTECODE_VADDR.
        assert!(program.elf_header.e_entry >= ProgramHeader::V3_BYTECODE_VADDR);
    }

    #[test]
    fn test_v3_p_offset() {
        let program = assemble(WITH_RODATA, SbpfArch::V3);
        let headers = program
            .program_headers
            .as_deref()
            .expect("v3 programs always have program headers");

        // The first segment starts right after the ELF header and the program header table.
        let header_bytes = 64 + (program.elf_header.e_phnum as u64) * 56;
        assert_eq!(headers[0].p_offset, header_bytes);
        // The second segment immediately follows the first in the file.
        assert_eq!(headers[1].p_offset, headers[0].p_offset + headers[0].p_filesz);
    }

    #[test]
    fn test_v3_no_dynamic_sections() {
        let program = assemble(WITH_CALL, SbpfArch::V3);
        let names = section_names_of(&program);

        for dynamic in [".dynamic", ".dynsym", ".dynstr", ".rel.dyn"] {
            assert!(!names.contains(&dynamic), "unexpected {} section in v3 output", dynamic);
        }
    }
}
668}