Skip to main content

sbpf_assembler/
parser.rs

1use {
2    crate::{
3        SbpfArch,
4        ast::AST,
5        astnode::{ASTNode, ExternDecl, GlobalDecl, Label, ROData, RodataDecl},
6        dynsym::{DynamicSymbolMap, RelDynMap},
7        errors::CompileError,
8        section::{CodeSection, DataSection, DebugSection},
9    },
10    either::Either,
11    pest::{Parser, iterators::Pair},
12    pest_derive::Parser,
13    sbpf_common::{
14        inst_param::{Number, Register},
15        instruction::Instruction,
16        opcode::Opcode,
17    },
18    std::{collections::HashMap, str::FromStr},
19};
/// pest-generated parser for sBPF assembly.
///
/// `#[derive(Parser)]` (from `pest_derive`) generates the parsing code and the
/// `Rule` enum from the grammar file `sbpf.pest`. `parse` drives it starting
/// from `Rule::program`.
#[derive(Parser)]
#[grammar = "sbpf.pest"]
pub struct SbpfParser;
23
24/// Context containing all mutable state during parsing
25struct ParseContext<'a> {
26    ast: &'a mut AST,
27    const_map: &'a mut HashMap<String, Number>,
28    label_spans: &'a mut HashMap<String, std::ops::Range<usize>>,
29    errors: Vec<CompileError>,
30    rodata_phase: bool,
31    text_offset: u64,
32    rodata_offset: u64,
33    missing_text_directive: bool,
34}
35
/// Source-operand bit (`BPF_X`) of an opcode.
///
/// OR-ing it into an immediate-operand opcode yields the matching
/// register-operand opcode.
const BPF_X: u8 = 0b0000_1000;
38
39/// Token types used in the AST
40#[derive(Debug, Clone)]
41pub enum Token {
42    Directive(String, std::ops::Range<usize>),
43    Identifier(String, std::ops::Range<usize>),
44    ImmediateValue(Number, std::ops::Range<usize>),
45    StringLiteral(String, std::ops::Range<usize>),
46    VectorLiteral(Vec<Number>, std::ops::Range<usize>),
47}
48
49pub struct ParseResult {
50    // TODO: parse result is basically 1. static part 2. dynamic part of the program
51    pub code_section: CodeSection,
52
53    pub data_section: DataSection,
54
55    pub dynamic_symbols: DynamicSymbolMap,
56
57    pub relocation_data: RelDynMap,
58
59    // TODO: this can be removed and dynamic-ness should just be
60    // determined by if there's any dynamic symbol
61    pub prog_is_static: bool,
62
63    pub arch: SbpfArch,
64
65    // Debug sections we came across while byteparsing
66    pub debug_sections: Vec<DebugSection>,
67}
68
69pub fn parse(source: &str, arch: SbpfArch) -> Result<ParseResult, Vec<CompileError>> {
70    let pairs = SbpfParser::parse(Rule::program, source).map_err(|e| {
71        vec![CompileError::ParseError {
72            error: e.to_string(),
73            span: 0..source.len(),
74            custom_label: None,
75        }]
76    })?;
77
78    let mut ast = AST::new();
79    let mut const_map = HashMap::<String, Number>::new();
80    let mut label_spans = HashMap::<String, std::ops::Range<usize>>::new();
81
82    let (text_offset, rodata_offset, errors) = {
83        let mut ctx = ParseContext {
84            ast: &mut ast,
85            const_map: &mut const_map,
86            label_spans: &mut label_spans,
87            errors: Vec::new(),
88            rodata_phase: false,
89            text_offset: 0,
90            rodata_offset: 0,
91            missing_text_directive: false,
92        };
93
94        for pair in pairs {
95            if pair.as_rule() == Rule::program {
96                for statement in pair.into_inner() {
97                    if statement.as_rule() == Rule::EOI {
98                        continue;
99                    }
100
101                    process_statement(statement, &mut ctx);
102                }
103            }
104        }
105
106        (ctx.text_offset, ctx.rodata_offset, ctx.errors)
107    };
108
109    if !errors.is_empty() {
110        return Err(errors);
111    }
112
113    ast.set_text_size(text_offset);
114    ast.set_rodata_size(rodata_offset);
115
116    ast.build_program(arch)
117}
118
119fn process_statement(pair: Pair<Rule>, ctx: &mut ParseContext) {
120    for inner in pair.into_inner() {
121        match inner.as_rule() {
122            Rule::label => {
123                let mut label_opt = None;
124                let mut directive_opt = None;
125                let mut instruction_opt = None;
126
127                for item in inner.into_inner() {
128                    match item.as_rule() {
129                        Rule::identifier | Rule::numeric_label => {
130                            match extract_label_from_pair(item) {
131                                Ok(label) => label_opt = Some(label),
132                                Err(e) => ctx.errors.push(e),
133                            }
134                        }
135                        Rule::directive_inner => {
136                            directive_opt = Some(item);
137                        }
138                        Rule::instruction => {
139                            instruction_opt = Some(item);
140                        }
141                        _ => {}
142                    }
143                }
144
145                if let Some((label_name, label_span)) = label_opt {
146                    // Check for duplicate labels
147                    if let Some(original_span) = ctx.label_spans.get(&label_name) {
148                        ctx.errors.push(CompileError::DuplicateLabel {
149                            label: label_name,
150                            span: label_span,
151                            original_span: original_span.clone(),
152                            custom_label: Some("Label already defined".to_string()),
153                        });
154                        continue;
155                    }
156                    ctx.label_spans
157                        .insert(label_name.clone(), label_span.clone());
158
159                    if ctx.rodata_phase {
160                        // Handle rodata label with directive
161                        if let Some(dir_pair) = directive_opt {
162                            match process_rodata_directive(
163                                label_name.clone(),
164                                label_span.clone(),
165                                dir_pair,
166                            ) {
167                                Ok(rodata) => {
168                                    let size = rodata.get_size();
169                                    ctx.ast.rodata_nodes.push(ASTNode::ROData {
170                                        rodata,
171                                        offset: ctx.rodata_offset,
172                                    });
173                                    ctx.rodata_offset += size;
174                                }
175                                Err(e) => ctx.errors.push(e),
176                            }
177                        } else if let Some(inst_pair) = instruction_opt {
178                            if let Err(e) = process_instruction(inst_pair, ctx.const_map) {
179                                ctx.errors.push(e);
180                            }
181                            if !ctx.missing_text_directive {
182                                ctx.missing_text_directive = true;
183                                ctx.errors.push(CompileError::MissingTextDirective {
184                                    span: label_span,
185                                    custom_label: None,
186                                });
187                            }
188                        }
189                    } else {
190                        ctx.ast.nodes.push(ASTNode::Label {
191                            label: Label {
192                                name: label_name,
193                                span: label_span,
194                            },
195                            offset: ctx.text_offset,
196                        });
197
198                        if let Some(inst_pair) = instruction_opt {
199                            match process_instruction(inst_pair, ctx.const_map) {
200                                Ok(instruction) => {
201                                    let size = instruction.get_size();
202                                    ctx.ast.nodes.push(ASTNode::Instruction {
203                                        instruction,
204                                        offset: ctx.text_offset,
205                                    });
206                                    ctx.text_offset += size;
207                                }
208                                Err(e) => ctx.errors.push(e),
209                            }
210                        }
211                    }
212                }
213            }
214            Rule::directive => {
215                process_directive_statement(inner, ctx);
216            }
217            Rule::instruction => {
218                let span = inner.as_span();
219                let span_range = span.start()..span.end();
220
221                match process_instruction(inner, ctx.const_map) {
222                    Ok(instruction) => {
223                        if !ctx.rodata_phase {
224                            let size = instruction.get_size();
225                            ctx.ast.nodes.push(ASTNode::Instruction {
226                                instruction,
227                                offset: ctx.text_offset,
228                            });
229                            ctx.text_offset += size;
230                        }
231                    }
232                    Err(e) => ctx.errors.push(e),
233                }
234
235                if ctx.rodata_phase && !ctx.missing_text_directive {
236                    ctx.missing_text_directive = true;
237                    ctx.errors.push(CompileError::MissingTextDirective {
238                        span: span_range,
239                        custom_label: None,
240                    });
241                }
242            }
243            _ => {}
244        }
245    }
246}
247
248fn extract_label_from_pair(
249    pair: Pair<Rule>,
250) -> Result<(String, std::ops::Range<usize>), CompileError> {
251    let span = pair.as_span();
252    Ok((pair.as_str().to_string(), span.start()..span.end()))
253}
254
255fn process_directive_statement(pair: Pair<Rule>, ctx: &mut ParseContext) {
256    for directive_inner_pair in pair.into_inner() {
257        process_directive_inner(directive_inner_pair, ctx);
258    }
259}
260
261fn process_directive_inner(pair: Pair<Rule>, ctx: &mut ParseContext) {
262    for inner in pair.into_inner() {
263        match inner.as_rule() {
264            Rule::directive_globl => {
265                let span = inner.as_span();
266                for globl_inner in inner.into_inner() {
267                    if globl_inner.as_rule() == Rule::globl_symbol {
268                        let entry_label = globl_inner.as_str().to_string();
269                        ctx.ast.nodes.push(ASTNode::GlobalDecl {
270                            global_decl: GlobalDecl {
271                                entry_label,
272                                span: span.start()..span.end(),
273                            },
274                        });
275                    }
276                }
277            }
278            Rule::directive_extern => {
279                let span = inner.as_span();
280                let mut symbols = Vec::new();
281                for extern_inner in inner.into_inner() {
282                    if extern_inner.as_rule() == Rule::symbol {
283                        let symbol_span = extern_inner.as_span();
284                        symbols.push(Token::Identifier(
285                            extern_inner.as_str().to_string(),
286                            symbol_span.start()..symbol_span.end(),
287                        ));
288                    }
289                }
290                ctx.ast.nodes.push(ASTNode::ExternDecl {
291                    extern_decl: ExternDecl {
292                        args: symbols,
293                        span: span.start()..span.end(),
294                    },
295                });
296            }
297            Rule::directive_equ => {
298                let mut ident = None;
299                let mut value = None;
300
301                for equ_inner in inner.into_inner() {
302                    match equ_inner.as_rule() {
303                        Rule::identifier => {
304                            ident = Some(equ_inner.as_str().to_string());
305                        }
306                        Rule::expression => match eval_expression(equ_inner, ctx.const_map) {
307                            Ok(v) => value = Some(v),
308                            Err(e) => ctx.errors.push(e),
309                        },
310                        _ => {}
311                    }
312                }
313
314                if let (Some(name), Some(val)) = (ident, value) {
315                    ctx.const_map.insert(name, val);
316                }
317            }
318            Rule::directive_section => {
319                let section_name = inner.as_str().trim_start_matches('.');
320                match section_name {
321                    "text" => ctx.rodata_phase = false,
322                    "rodata" => {
323                        ctx.rodata_phase = true;
324                        let span = inner.as_span();
325                        ctx.ast.nodes.push(ASTNode::RodataDecl {
326                            rodata_decl: RodataDecl {
327                                span: span.start()..span.end(),
328                            },
329                        });
330                    }
331                    _ => {}
332                }
333            }
334            _ => {}
335        }
336    }
337}
338
339fn process_rodata_directive(
340    label_name: String,
341    label_span: std::ops::Range<usize>,
342    pair: Pair<Rule>,
343) -> Result<ROData, CompileError> {
344    let inner_pair = if pair.as_rule() == Rule::directive_inner {
345        pair
346    } else {
347        pair.into_inner()
348            .next()
349            .ok_or_else(|| CompileError::ParseError {
350                error: "No directive content found".to_string(),
351                span: label_span.clone(),
352                custom_label: None,
353            })?
354    };
355
356    for inner in inner_pair.into_inner() {
357        let directive_span = inner.as_span();
358
359        match inner.as_rule() {
360            Rule::directive_ascii => {
361                for ascii_inner in inner.into_inner() {
362                    if ascii_inner.as_rule() == Rule::string_literal {
363                        for content_inner in ascii_inner.into_inner() {
364                            if content_inner.as_rule() == Rule::string_content {
365                                let content = content_inner.as_str().to_string();
366                                let content_span = content_inner.as_span();
367                                return Ok(ROData {
368                                    name: label_name,
369                                    args: vec![
370                                        Token::Directive(
371                                            "ascii".to_string(),
372                                            directive_span.start()..directive_span.end(),
373                                        ),
374                                        Token::StringLiteral(
375                                            content,
376                                            content_span.start()..content_span.end(),
377                                        ),
378                                    ],
379                                    span: label_span,
380                                });
381                            }
382                        }
383                    }
384                }
385            }
386            Rule::directive_byte
387            | Rule::directive_short
388            | Rule::directive_word
389            | Rule::directive_int
390            | Rule::directive_long
391            | Rule::directive_quad => {
392                let directive_name = match inner.as_rule() {
393                    Rule::directive_byte => "byte",
394                    Rule::directive_short => "short",
395                    Rule::directive_word => "word",
396                    Rule::directive_int => "int",
397                    Rule::directive_long => "long",
398                    Rule::directive_quad => "quad",
399                    _ => "byte",
400                };
401
402                let mut values = Vec::new();
403                for byte_inner in inner.into_inner() {
404                    if byte_inner.as_rule() == Rule::number {
405                        values.push(parse_number(byte_inner)?);
406                    }
407                }
408
409                let values_span = directive_span.start()..directive_span.end();
410                return Ok(ROData {
411                    name: label_name,
412                    args: vec![
413                        Token::Directive(
414                            directive_name.to_string(),
415                            directive_span.start()..directive_span.end(),
416                        ),
417                        Token::VectorLiteral(values, values_span),
418                    ],
419                    span: label_span,
420                });
421            }
422            _ => {}
423        }
424    }
425
426    Err(CompileError::InvalidRodataDecl {
427        span: label_span,
428        custom_label: None,
429    })
430}
431
432fn process_instruction(
433    pair: Pair<Rule>,
434    const_map: &HashMap<String, Number>,
435) -> Result<Instruction, CompileError> {
436    let outer_span = pair.as_span();
437    let outer_span_range = outer_span.start()..outer_span.end();
438
439    for inner in pair.into_inner() {
440        let span = inner.as_span();
441        let span_range = span.start()..span.end();
442
443        match inner.as_rule() {
444            Rule::instr_exit => {
445                return Ok(Instruction {
446                    opcode: Opcode::Exit,
447                    dst: None,
448                    src: None,
449                    off: None,
450                    imm: None,
451                    span: span_range,
452                });
453            }
454            Rule::instr_lddw => return process_lddw(inner, const_map, span_range),
455            Rule::instr_call => return process_call(inner, const_map, span_range),
456            Rule::instr_callx => return process_callx(inner, span_range),
457            Rule::instr_neg32 => return process_neg32(inner, span_range),
458            Rule::instr_neg64 => return process_neg64(inner, span_range),
459            Rule::instr_alu64_imm | Rule::instr_alu32_imm => {
460                return process_alu_imm(inner, const_map, span_range);
461            }
462            Rule::instr_alu64_reg | Rule::instr_alu32_reg => {
463                return process_alu_reg(inner, span_range);
464            }
465            Rule::instr_load => return process_load(inner, const_map, span_range),
466            Rule::instr_store_imm => return process_store_imm(inner, const_map, span_range),
467            Rule::instr_store_reg => return process_store_reg(inner, const_map, span_range),
468            Rule::instr_jump_imm => return process_jump_imm(inner, const_map, span_range),
469            Rule::instr_jump_reg => return process_jump_reg(inner, span_range),
470            Rule::instr_jump_uncond => return process_jump_uncond(inner, const_map, span_range),
471            Rule::instr_endian => return process_endian(inner, span_range),
472            _ => {}
473        }
474    }
475
476    Err(CompileError::ParseError {
477        error: "Invalid instruction".to_string(),
478        span: outer_span_range,
479        custom_label: None,
480    })
481}
482
483fn process_lddw(
484    pair: Pair<Rule>,
485    const_map: &HashMap<String, Number>,
486    span: std::ops::Range<usize>,
487) -> Result<Instruction, CompileError> {
488    let mut dst = None;
489    let mut imm = None;
490
491    for inner in pair.into_inner() {
492        match inner.as_rule() {
493            Rule::register => dst = Some(parse_register(inner)?),
494            Rule::operand => imm = Some(parse_operand(inner, const_map)?),
495            _ => {}
496        }
497    }
498
499    Ok(Instruction {
500        opcode: Opcode::Lddw,
501        dst,
502        src: None,
503        off: None,
504        imm,
505        span,
506    })
507}
508
509fn process_load(
510    pair: Pair<Rule>,
511    const_map: &HashMap<String, Number>,
512    span: std::ops::Range<usize>,
513) -> Result<Instruction, CompileError> {
514    let mut opcode = None;
515    let mut dst = None;
516    let mut src = None;
517    let mut off = None;
518
519    for inner in pair.into_inner() {
520        match inner.as_rule() {
521            Rule::load_op => opcode = Opcode::from_str(inner.as_str()).ok(),
522            Rule::register => dst = Some(parse_register(inner)?),
523            Rule::memory_ref => {
524                let (s, o) = parse_memory_ref(inner, const_map)?;
525                src = Some(s);
526                off = Some(o);
527            }
528            _ => {}
529        }
530    }
531
532    Ok(Instruction {
533        opcode: opcode.unwrap_or(Opcode::Exit),
534        dst,
535        src,
536        off,
537        imm: None,
538        span,
539    })
540}
541
542fn process_store_imm(
543    pair: Pair<Rule>,
544    const_map: &HashMap<String, Number>,
545    span: std::ops::Range<usize>,
546) -> Result<Instruction, CompileError> {
547    let mut opcode = None;
548    let mut dst = None;
549    let mut off = None;
550    let mut imm = None;
551
552    for inner in pair.into_inner() {
553        match inner.as_rule() {
554            Rule::store_op_imm => opcode = Opcode::from_str(inner.as_str()).ok(),
555            Rule::memory_ref => {
556                let (d, o) = parse_memory_ref(inner, const_map)?;
557                dst = Some(d);
558                off = Some(o);
559            }
560            Rule::operand => imm = Some(parse_operand(inner, const_map)?),
561            _ => {}
562        }
563    }
564
565    Ok(Instruction {
566        opcode: opcode.unwrap_or(Opcode::Exit),
567        dst,
568        src: None,
569        off,
570        imm,
571        span,
572    })
573}
574
575fn process_store_reg(
576    pair: Pair<Rule>,
577    const_map: &HashMap<String, Number>,
578    span: std::ops::Range<usize>,
579) -> Result<Instruction, CompileError> {
580    let mut opcode = None;
581    let mut dst = None;
582    let mut src = None;
583    let mut off = None;
584
585    for inner in pair.into_inner() {
586        match inner.as_rule() {
587            Rule::store_op_reg => opcode = Opcode::from_str(inner.as_str()).ok(),
588            Rule::memory_ref => {
589                let (d, o) = parse_memory_ref(inner, const_map)?;
590                dst = Some(d);
591                off = Some(o);
592            }
593            Rule::register => src = Some(parse_register(inner)?),
594            _ => {}
595        }
596    }
597
598    Ok(Instruction {
599        opcode: opcode.unwrap_or(Opcode::Exit),
600        dst,
601        src,
602        off,
603        imm: None,
604        span,
605    })
606}
607
608fn process_alu_imm(
609    pair: Pair<Rule>,
610    const_map: &HashMap<String, Number>,
611    span: std::ops::Range<usize>,
612) -> Result<Instruction, CompileError> {
613    let mut opcode = None;
614    let mut dst = None;
615    let mut imm = None;
616
617    for inner in pair.into_inner() {
618        match inner.as_rule() {
619            Rule::alu_64_op | Rule::alu_32_op => opcode = Opcode::from_str(inner.as_str()).ok(),
620            Rule::register => dst = Some(parse_register(inner)?),
621            Rule::operand => imm = Some(parse_operand(inner, const_map)?),
622            _ => {}
623        }
624    }
625
626    Ok(Instruction {
627        opcode: opcode.unwrap_or(Opcode::Exit),
628        dst,
629        src: None,
630        off: None,
631        imm,
632        span,
633    })
634}
635
636fn process_alu_reg(
637    pair: Pair<Rule>,
638    span: std::ops::Range<usize>,
639) -> Result<Instruction, CompileError> {
640    let mut opcode = None;
641    let mut dst = None;
642    let mut src = None;
643
644    for inner in pair.into_inner() {
645        match inner.as_rule() {
646            Rule::alu_64_op | Rule::alu_32_op => {
647                let op_str = inner.as_str();
648                let inner_span = inner.as_span();
649                if let Ok(opc) = Opcode::from_str(op_str) {
650                    // Convert to register variant using BPF_X flag
651                    let reg_opcode = Into::<u8>::into(opc) | BPF_X;
652                    opcode =
653                        Some(
654                            reg_opcode
655                                .try_into()
656                                .map_err(|e| CompileError::BytecodeError {
657                                    error: format!("Invalid opcode 0x{:02x}: {}", reg_opcode, e),
658                                    span: inner_span.start()..inner_span.end(),
659                                    custom_label: None,
660                                })?,
661                        );
662                }
663            }
664            Rule::register => {
665                if dst.is_none() {
666                    dst = Some(parse_register(inner)?);
667                } else {
668                    src = Some(parse_register(inner)?);
669                }
670            }
671            _ => {}
672        }
673    }
674
675    Ok(Instruction {
676        opcode: opcode.unwrap_or(Opcode::Exit),
677        dst,
678        src,
679        off: None,
680        imm: None,
681        span,
682    })
683}
684
685fn process_jump_imm(
686    pair: Pair<Rule>,
687    const_map: &HashMap<String, Number>,
688    span: std::ops::Range<usize>,
689) -> Result<Instruction, CompileError> {
690    let mut opcode = None;
691    let mut dst = None;
692    let mut imm = None;
693    let mut off = None;
694
695    for inner in pair.into_inner() {
696        match inner.as_rule() {
697            Rule::jump_op => opcode = Opcode::from_str(inner.as_str()).ok(),
698            Rule::register => dst = Some(parse_register(inner)?),
699            Rule::operand => imm = Some(parse_operand(inner, const_map)?),
700            Rule::jump_target => off = Some(parse_jump_target(inner, const_map)?),
701            _ => {}
702        }
703    }
704
705    Ok(Instruction {
706        opcode: opcode.unwrap_or(Opcode::Exit),
707        dst,
708        src: None,
709        off,
710        imm,
711        span,
712    })
713}
714
715fn process_jump_reg(
716    pair: Pair<Rule>,
717    span: std::ops::Range<usize>,
718) -> Result<Instruction, CompileError> {
719    let mut opcode = None;
720    let mut dst = None;
721    let mut src = None;
722    let mut off = None;
723
724    for inner in pair.into_inner() {
725        match inner.as_rule() {
726            Rule::jump_op => {
727                let op_str = inner.as_str();
728                let inner_span = inner.as_span();
729                if let Ok(opc) = Opcode::from_str(op_str) {
730                    // Convert Imm variant to Reg variant using BPF_X flag
731                    let reg_opcode = Into::<u8>::into(opc) | BPF_X;
732                    opcode =
733                        Some(
734                            reg_opcode
735                                .try_into()
736                                .map_err(|e| CompileError::BytecodeError {
737                                    error: format!("Invalid opcode 0x{:02x}: {}", reg_opcode, e),
738                                    span: inner_span.start()..inner_span.end(),
739                                    custom_label: None,
740                                })?,
741                        );
742                }
743            }
744            Rule::register => {
745                if dst.is_none() {
746                    dst = Some(parse_register(inner)?);
747                } else {
748                    src = Some(parse_register(inner)?);
749                }
750            }
751            Rule::jump_target => off = Some(parse_jump_target(inner, &HashMap::new())?),
752            _ => {}
753        }
754    }
755
756    Ok(Instruction {
757        opcode: opcode.unwrap_or(Opcode::Exit),
758        dst,
759        src,
760        off,
761        imm: None,
762        span,
763    })
764}
765
766fn process_jump_uncond(
767    pair: Pair<Rule>,
768    const_map: &HashMap<String, Number>,
769    span: std::ops::Range<usize>,
770) -> Result<Instruction, CompileError> {
771    let mut off = None;
772
773    for inner in pair.into_inner() {
774        if inner.as_rule() == Rule::jump_target {
775            off = Some(parse_jump_target(inner, const_map)?);
776        }
777    }
778
779    Ok(Instruction {
780        opcode: Opcode::Ja,
781        dst: None,
782        src: None,
783        off,
784        imm: None,
785        span,
786    })
787}
788
789fn process_call(
790    pair: Pair<Rule>,
791    const_map: &HashMap<String, Number>,
792    span: std::ops::Range<usize>,
793) -> Result<Instruction, CompileError> {
794    let mut imm = None;
795
796    for inner in pair.into_inner() {
797        if inner.as_rule() == Rule::symbol {
798            if let Some(symbol) = const_map.get(inner.as_str()) {
799                imm = Some(Either::Right(symbol.to_owned()));
800            } else {
801                imm = Some(Either::Left(inner.as_str().to_string()));
802            }
803        }
804    }
805
806    Ok(Instruction {
807        opcode: Opcode::Call,
808        dst: None,
809        src: None,
810        off: None,
811        imm,
812        span,
813    })
814}
815
816fn process_callx(
817    pair: Pair<Rule>,
818    span: std::ops::Range<usize>,
819) -> Result<Instruction, CompileError> {
820    let mut dst = None;
821
822    for inner in pair.into_inner() {
823        if inner.as_rule() == Rule::register {
824            dst = Some(parse_register(inner)?);
825        }
826    }
827
828    Ok(Instruction {
829        opcode: Opcode::Callx,
830        dst,
831        src: None,
832        off: None,
833        imm: None,
834        span,
835    })
836}
837
838fn process_neg32(
839    pair: Pair<Rule>,
840    span: std::ops::Range<usize>,
841) -> Result<Instruction, CompileError> {
842    let mut dst = None;
843
844    for inner in pair.into_inner() {
845        if inner.as_rule() == Rule::register {
846            dst = Some(parse_register(inner)?);
847        }
848    }
849
850    Ok(Instruction {
851        opcode: Opcode::Neg32,
852        dst,
853        src: None,
854        off: None,
855        imm: None,
856        span,
857    })
858}
859
860fn process_neg64(
861    pair: Pair<Rule>,
862    span: std::ops::Range<usize>,
863) -> Result<Instruction, CompileError> {
864    let mut dst = None;
865
866    for inner in pair.into_inner() {
867        if inner.as_rule() == Rule::register {
868            dst = Some(parse_register(inner)?);
869        }
870    }
871
872    Ok(Instruction {
873        opcode: Opcode::Neg64,
874        dst,
875        src: None,
876        off: None,
877        imm: None,
878        span,
879    })
880}
881
882fn process_endian(
883    pair: Pair<Rule>,
884    span: std::ops::Range<usize>,
885) -> Result<Instruction, CompileError> {
886    let mut opcode = None;
887    let mut dst = None;
888    let mut imm = None;
889
890    for inner in pair.into_inner() {
891        match inner.as_rule() {
892            Rule::endian_op => {
893                let op_str = inner.as_str();
894                let inner_span = inner.as_span();
895                // Extract opcode and size from instruction (example: "be16" = be opcode, 16 bits)
896                let (opc, size) = if let Some(size_str) = op_str.strip_prefix("be") {
897                    let size = size_str
898                        .parse::<i64>()
899                        .map_err(|_| CompileError::ParseError {
900                            error: format!("Invalid endian size in '{}'", op_str),
901                            span: inner_span.start()..inner_span.end(),
902                            custom_label: None,
903                        })?;
904                    (Opcode::Be, size)
905                } else if let Some(size_str) = op_str.strip_prefix("le") {
906                    let size = size_str
907                        .parse::<i64>()
908                        .map_err(|_| CompileError::ParseError {
909                            error: format!("Invalid endian size in '{}'", op_str),
910                            span: inner_span.start()..inner_span.end(),
911                            custom_label: None,
912                        })?;
913                    (Opcode::Le, size)
914                } else {
915                    return Err(CompileError::ParseError {
916                        error: format!("Invalid endian operation '{}'", op_str),
917                        span: inner_span.start()..inner_span.end(),
918                        custom_label: None,
919                    });
920                };
921                opcode = Some(opc);
922                imm = Some(Either::Right(Number::Int(size)));
923            }
924            Rule::register => dst = Some(parse_register(inner)?),
925            _ => {}
926        }
927    }
928
929    Ok(Instruction {
930        opcode: opcode.unwrap_or(Opcode::Exit),
931        dst,
932        src: None,
933        off: None,
934        imm,
935        span,
936    })
937}
938
939fn parse_register(pair: Pair<Rule>) -> Result<Register, CompileError> {
940    let reg_str = pair.as_str();
941    let span = pair.as_span();
942
943    if let Ok(n) = reg_str[1..].parse::<u8>() {
944        Ok(Register { n })
945    } else {
946        Err(CompileError::InvalidRegister {
947            register: reg_str.to_string(),
948            span: span.start()..span.end(),
949            custom_label: None,
950        })
951    }
952}
953
954fn parse_operand(
955    pair: Pair<Rule>,
956    const_map: &HashMap<String, Number>,
957) -> Result<Either<String, Number>, CompileError> {
958    let span = pair.as_span();
959    let span_range = span.start()..span.end();
960
961    for inner in pair.into_inner() {
962        match inner.as_rule() {
963            Rule::number => return Ok(Either::Right(parse_number(inner)?)),
964            Rule::symbol => {
965                let name = inner.as_str().to_string();
966                if let Some(value) = const_map.get(&name) {
967                    return Ok(Either::Right(value.clone()));
968                }
969                return Ok(Either::Left(name));
970            }
971            Rule::operand_expr => {
972                let mut sym_name = None;
973                let mut num_value = None;
974
975                for expr_inner in inner.into_inner() {
976                    match expr_inner.as_rule() {
977                        Rule::symbol => sym_name = Some(expr_inner.as_str().to_string()),
978                        Rule::number => num_value = Some(parse_number(expr_inner)?),
979                        _ => {}
980                    }
981                }
982
983                if let (Some(sym), Some(num)) = (sym_name, num_value) {
984                    if let Some(base_value) = const_map.get(&sym) {
985                        let result = base_value.clone() + num;
986                        return Ok(Either::Right(result));
987                    } else {
988                        return Ok(Either::Left(sym));
989                    }
990                }
991            }
992            _ => {}
993        }
994    }
995
996    Err(CompileError::ParseError {
997        error: "Invalid operand".to_string(),
998        span: span_range,
999        custom_label: None,
1000    })
1001}
1002
1003fn parse_jump_target(
1004    pair: Pair<Rule>,
1005    _const_map: &HashMap<String, Number>,
1006) -> Result<Either<String, i16>, CompileError> {
1007    let span = pair.as_span();
1008    let span_range = span.start()..span.end();
1009
1010    for inner in pair.into_inner() {
1011        match inner.as_rule() {
1012            Rule::symbol | Rule::numeric_label_ref => {
1013                return Ok(Either::Left(inner.as_str().to_string()));
1014            }
1015            Rule::number => {
1016                let num = parse_number(inner)?;
1017                return Ok(Either::Right(num.to_i16()));
1018            }
1019            _ => {}
1020        }
1021    }
1022
1023    Err(CompileError::ParseError {
1024        error: "Invalid jump target".to_string(),
1025        span: span_range,
1026        custom_label: None,
1027    })
1028}
1029
1030fn parse_memory_ref(
1031    pair: Pair<Rule>,
1032    const_map: &HashMap<String, Number>,
1033) -> Result<(Register, Either<String, i16>), CompileError> {
1034    let mut reg = None;
1035    let mut accumulated_offset: i16 = 0;
1036    let mut unresolved_symbol: Option<String> = None;
1037    let mut sign: i16 = 1;
1038
1039    for inner in pair.into_inner() {
1040        match inner.as_rule() {
1041            Rule::register => {
1042                reg = Some(parse_register(inner)?);
1043            }
1044            Rule::memory_op => {
1045                sign = if inner.as_str() == "+" { 1 } else { -1 };
1046            }
1047            Rule::memory_offset => {
1048                for offset_inner in inner.into_inner() {
1049                    match offset_inner.as_rule() {
1050                        Rule::number => {
1051                            let num = parse_number(offset_inner)?;
1052                            accumulated_offset =
1053                                accumulated_offset.wrapping_add(sign * num.to_i16());
1054                        }
1055                        Rule::symbol => {
1056                            let name = offset_inner.as_str().to_string();
1057                            if let Some(value) = const_map.get(&name) {
1058                                accumulated_offset =
1059                                    accumulated_offset.wrapping_add(sign * value.to_i16());
1060                            } else if unresolved_symbol.is_none() {
1061                                unresolved_symbol = Some(name);
1062                            }
1063                        }
1064                        _ => {}
1065                    }
1066                }
1067            }
1068            _ => {}
1069        }
1070    }
1071
1072    let offset = if let Some(sym) = unresolved_symbol {
1073        Either::Left(sym)
1074    } else {
1075        Either::Right(accumulated_offset)
1076    };
1077
1078    Ok((reg.unwrap_or(Register { n: 0 }), offset))
1079}
1080
1081fn parse_number(pair: Pair<Rule>) -> Result<Number, CompileError> {
1082    let span = pair.as_span();
1083    let span_range = span.start()..span.end();
1084    let number_str = pair.as_str().replace('_', "");
1085
1086    // Try parsing as i64 first
1087    if let Ok(value) = number_str.parse::<i64>() {
1088        return Ok(Number::Int(value));
1089    }
1090
1091    let mut sign: i64 = 1;
1092    let value = if number_str.starts_with('-') {
1093        sign = -1;
1094        number_str.strip_prefix('-').unwrap()
1095    } else {
1096        number_str.as_str()
1097    };
1098
1099    if value.starts_with("0x") {
1100        let hex_str = value.trim_start_matches("0x");
1101        if let Ok(value) = u64::from_str_radix(hex_str, 16) {
1102            return Ok(Number::Addr(sign * (value as i64)));
1103        }
1104    }
1105
1106    Err(CompileError::InvalidNumber {
1107        number: number_str,
1108        span: span_range,
1109        custom_label: None,
1110    })
1111}
1112
1113fn eval_expression(
1114    pair: Pair<Rule>,
1115    const_map: &HashMap<String, Number>,
1116) -> Result<Number, CompileError> {
1117    let span = pair.as_span();
1118    let span_range = span.start()..span.end();
1119
1120    let mut stack = Vec::new();
1121    let mut op_stack = Vec::new();
1122
1123    for inner in pair.into_inner() {
1124        match inner.as_rule() {
1125            Rule::term => {
1126                let val = eval_term(inner, const_map)?;
1127                stack.push(val);
1128            }
1129            Rule::bin_op => {
1130                op_stack.push(inner.as_str());
1131            }
1132            _ => {}
1133        }
1134    }
1135
1136    // Apply operators
1137    while let Some(op) = op_stack.pop() {
1138        if stack.len() >= 2 {
1139            let b = stack.pop().unwrap();
1140            let a = stack.pop().unwrap();
1141            let result = match op {
1142                "+" => a + b,
1143                "-" => a - b,
1144                "*" => a * b,
1145                "/" => a / b,
1146                _ => a,
1147            };
1148            stack.push(result);
1149        }
1150    }
1151
1152    stack.pop().ok_or_else(|| CompileError::ParseError {
1153        error: "Invalid expression".to_string(),
1154        span: span_range,
1155        custom_label: None,
1156    })
1157}
1158
1159fn eval_term(
1160    pair: Pair<Rule>,
1161    const_map: &HashMap<String, Number>,
1162) -> Result<Number, CompileError> {
1163    let span = pair.as_span();
1164    let span_range = span.start()..span.end();
1165
1166    for inner in pair.into_inner() {
1167        match inner.as_rule() {
1168            Rule::expression => {
1169                return eval_expression(inner, const_map);
1170            }
1171            Rule::number => {
1172                return parse_number(inner);
1173            }
1174            Rule::symbol => {
1175                let name = inner.as_str().to_string();
1176                if let Some(value) = const_map.get(&name) {
1177                    return Ok(value.clone());
1178                }
1179                return Err(CompileError::ParseError {
1180                    error: format!("Undefined constant: {}", name),
1181                    span: inner.as_span().start()..inner.as_span().end(),
1182                    custom_label: None,
1183                });
1184            }
1185            _ => {}
1186        }
1187    }
1188
1189    Err(CompileError::ParseError {
1190        error: "Invalid term".to_string(),
1191        span: span_range,
1192        custom_label: None,
1193    })
1194}