Skip to main content

ternary_compiler_optimizer/
lib.rs

1#![forbid(unsafe_code)]
2
3//! Optimization passes for ternary bytecode.
4//!
5//! Provides dead trit elimination, constant folding, trit merging, peephole
6//! optimization, loop detection, and a configurable optimization pipeline.
7
/// A single ternary digit (trit) as it appears in bytecode constants.
///
/// Each variant's discriminant is its numeric value, which is exactly what
/// [`Trit::to_i8`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Trit {
    /// Numeric value `-1`.
    Neg = -1,
    /// Numeric value `0`.
    Zero = 0,
    /// Numeric value `1`.
    Pos = 1,
}

impl Trit {
    /// Converts an `i8` into the trit with that numeric value.
    ///
    /// Returns `None` for any value other than `-1`, `0` or `1`.
    pub fn from_i8(v: i8) -> Option<Self> {
        // Look the value up among the three variants instead of spelling out
        // a match arm per value.
        [Trit::Neg, Trit::Zero, Trit::Pos]
            .iter()
            .copied()
            .find(|trit| trit.to_i8() == v)
    }

    /// Returns the numeric value of this trit: `-1`, `0` or `1`.
    pub fn to_i8(self) -> i8 {
        self as i8
    }
}
27
/// Opcodes for a ternary bytecode virtual machine.
///
/// Jump operands are instruction indices: loop detection in this file compares
/// them directly against an instruction's position in the program.
// NOTE(review): the interpreter is not part of this file, so the exact stack
// effect of each opcode (for example whether `Store` or the conditional jumps
// pop the value they inspect) is taken from the per-variant descriptions
// below — confirm against the VM before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Op {
    /// Load a constant trit value.
    LoadConst(Trit),
    /// Load from register index.
    Load(usize),
    /// Store to register index.
    Store(usize),
    /// Add two trits (clamped); constant folding clamps the sum to `-1..=1`.
    Add,
    /// Multiply two trits.
    Mul,
    /// Negate top of stack.
    Neg,
    /// No operation.
    Nop,
    /// Jump to instruction index.
    Jump(usize),
    /// Jump to instruction index if top of stack is Zero.
    JumpIfZero(usize),
    /// Jump to instruction index if top of stack is Neg.
    JumpIfNeg(usize),
    /// Jump to instruction index if top of stack is Pos.
    JumpIfPos(usize),
    /// Push trit from input stream.
    Input,
    /// Pop trit to output stream.
    Output,
    /// Halt execution.
    Halt,
}
60
61/// A program is a sequence of instructions.
62#[derive(Debug, Clone)]
63pub struct Program {
64    pub instructions: Vec<Op>,
65}
66
67impl Program {
68    pub fn new(instructions: Vec<Op>) -> Self {
69        Self { instructions }
70    }
71
72    pub fn len(&self) -> usize {
73        self.instructions.len()
74    }
75
76    pub fn is_empty(&self) -> bool {
77        self.instructions.is_empty()
78    }
79}
80
/// Result of running an optimization pipeline over a program.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// The program after the passes have been applied.
    pub program: Program,
    /// Names of the passes that were run, in order. A fixed-point run records
    /// the full list of pass names once per iteration.
    pub passes_applied: Vec<String>,
    /// How many instructions the run removed, measured as a reduction in
    /// instruction count (never negative).
    pub trits_eliminated: usize,
}
88
89// ---- Pass 1: Dead Trit Elimination ----
90
91/// Removes instructions whose results are never used.
92///
93/// Tracks which registers are written but never read, and which
94/// stack values are pushed but consumed by other pushes (Nop patterns).
95/// Removes Nop instructions and unused LoadConst/Load sequences.
96pub fn dead_trit_elimination(program: &Program) -> Program {
97    let mut used_regs: std::collections::HashSet<usize> = std::collections::HashSet::new();
98    let mut used_labels: std::collections::HashSet<usize> = std::collections::HashSet::new();
99
100    // First pass: find which registers are read and which labels are jumped to
101    for op in &program.instructions {
102        match op {
103            Op::Load(reg) => { used_regs.insert(*reg); }
104            Op::Jump(target) | Op::JumpIfZero(target) | Op::JumpIfNeg(target) | Op::JumpIfPos(target) => {
105                used_labels.insert(*target);
106            }
107            _ => {}
108        }
109    }
110
111    // Also mark registers used in Store as "needed" if they're later read
112    let mut needed = std::collections::HashSet::new();
113    for op in program.instructions.iter().rev() {
114        match op {
115            Op::Load(reg) => { if needed.contains(reg) { used_regs.insert(*reg); } }
116            Op::Store(reg) => { if used_regs.contains(reg) { needed.insert(*reg); } }
117            _ => {}
118        }
119    }
120
121    // Second pass: filter out dead instructions
122    let mut result = Vec::new();
123    for (i, op) in program.instructions.iter().enumerate() {
124        match op {
125            Op::Nop => {} // Always eliminate Nops
126            Op::LoadConst(_) => result.push(*op), // Keep for now (may be used by stack ops)
127            Op::Load(reg) => {
128                if used_regs.contains(reg) {
129                    result.push(*op);
130                }
131            }
132            _ => result.push(*op),
133        }
134        // Keep jump targets even if they seem dead
135        let _ = used_labels.contains(&i); // just reference i to avoid warning
136    }
137
138    // Re-add jump target labels as needed (they're index-based, so we need to remap)
139    // For simplicity, we keep all non-Nop, non-dead-load instructions
140    Program::new(result)
141}
142
143// ---- Pass 2: Constant Folding ----
144
145/// Evaluates constant expressions at compile time.
146///
147/// When a sequence like [LoadConst(a), LoadConst(b), Add] appears,
148/// it is replaced with [LoadConst(a+b)] (clamped to ternary).
149/// Similarly for Mul and Neg.
150pub fn constant_folding(program: &Program) -> Program {
151    let mut result: Vec<Op> = Vec::new();
152    let mut i = 0;
153    let instrs = &program.instructions;
154
155    while i < instrs.len() {
156        match (instrs.get(i), instrs.get(i + 1), instrs.get(i + 2)) {
157            // LoadConst(a), LoadConst(b), Add → LoadConst(a+b)
158            (Some(Op::LoadConst(a)), Some(Op::LoadConst(b)), Some(Op::Add)) => {
159                let sum = (a.to_i8() + b.to_i8()).clamp(-1, 1);
160                result.push(Op::LoadConst(Trit::from_i8(sum).unwrap_or(Trit::Zero)));
161                i += 3;
162            }
163            // LoadConst(a), LoadConst(b), Mul → LoadConst(a*b)
164            (Some(Op::LoadConst(a)), Some(Op::LoadConst(b)), Some(Op::Mul)) => {
165                let product = a.to_i8() * b.to_i8();
166                result.push(Op::LoadConst(Trit::from_i8(product.clamp(-1, 1)).unwrap_or(Trit::Zero)));
167                i += 3;
168            }
169            // LoadConst(a), Neg → LoadConst(-a)
170            (Some(Op::LoadConst(a)), Some(Op::Neg), _) => {
171                let neg = match a {
172                    Trit::Pos => Trit::Neg,
173                    Trit::Neg => Trit::Pos,
174                    Trit::Zero => Trit::Zero,
175                };
176                result.push(Op::LoadConst(neg));
177                i += 2;
178            }
179            (Some(Op::LoadConst(Trit::Zero)), Some(Op::Mul), _) => {
180                // LoadConst(Zero) followed by Mul: result is always Zero
181                result.push(Op::LoadConst(Trit::Zero));
182                i += 2;
183            }
184            _ => {
185                result.push(instrs[i]);
186                i += 1;
187            }
188        }
189    }
190
191    Program::new(result)
192}
193
194// ---- Pass 3: Trit Merging ----
195
196/// Merges redundant sequences of ternary operations.
197///
198/// Patterns recognized:
199/// - Double negation (Neg, Neg) → eliminated
200/// - Multiplication by Pos (identity) → eliminated
201/// - Addition of Zero → eliminated
202pub fn trit_merging(program: &Program) -> Program {
203    let mut result: Vec<Op> = Vec::new();
204    let mut i = 0;
205    let instrs = &program.instructions;
206
207    while i < instrs.len() {
208        match (instrs.get(i), instrs.get(i + 1)) {
209            // Double negation
210            (Some(Op::Neg), Some(Op::Neg)) => { i += 2; }
211            // LoadConst(Zero), Add → identity (just skip both)
212            (Some(Op::LoadConst(Trit::Zero)), Some(Op::Add)) => { i += 2; }
213            // LoadConst(Pos), Mul → identity
214            (Some(Op::LoadConst(Trit::Pos)), Some(Op::Mul)) => { i += 2; }
215            // LoadConst(Zero), Mul → replace with LoadConst(Zero)
216            (Some(Op::LoadConst(Trit::Zero)), Some(Op::Mul)) => {
217                result.push(Op::LoadConst(Trit::Zero));
218                i += 2;
219            }
220            _ => {
221                result.push(instrs[i]);
222                i += 1;
223            }
224        }
225    }
226
227    Program::new(result)
228}
229
230// ---- Pass 4: Peephole Optimizer ----
231
232/// A peephole optimizer that examines small windows of instructions
233/// and replaces them with more efficient patterns.
234///
235/// Window size is typically 2-4 instructions.
236pub struct PeepholeOptimizer {
237    pub window_size: usize,
238}
239
240impl PeepholeOptimizer {
241    pub fn new(window_size: usize) -> Self {
242        Self { window_size }
243    }
244
245    pub fn optimize(&self, program: &Program) -> Program {
246        let mut result: Vec<Op> = Vec::new();
247        let instrs = &program.instructions;
248        let mut i = 0;
249
250        while i < instrs.len() {
251            // Pattern: Store(r), Load(r) → nothing (value already there)
252            if i + 1 < instrs.len() {
253                if let (Op::Store(r1), Op::Load(r2)) = (&instrs[i], &instrs[i + 1]) {
254                    if r1 == r2 {
255                        i += 2;
256                        continue;
257                    }
258                }
259            }
260
261            // Pattern: LoadConst(a), Store(r), Load(r) → LoadConst(a), (keep a on stack)
262            if i + 2 < instrs.len() {
263                if let (Op::LoadConst(_), Op::Store(_), Op::Load(_)) = (&instrs[i], &instrs[i + 1], &instrs[i + 2]) {
264                    result.push(instrs[i]); // LoadConst
265                    result.push(instrs[i + 1]); // Store
266                    // Skip Load — value is still on conceptual stack or in register
267                    i += 3;
268                    continue;
269                }
270            }
271
272            result.push(instrs[i]);
273            i += 1;
274        }
275
276        Program::new(result)
277    }
278}
279
280// ---- Pass 5: Loop Detection ----
281
/// Information about a detected loop in ternary bytecode.
#[derive(Debug, Clone)]
pub struct LoopInfo {
    /// Index of the first instruction of the loop: the target of the back-edge jump.
    pub start: usize,
    /// Index of the last instruction of the loop. Loop detection in this file
    /// sets it to the index of the jump that forms the back-edge.
    pub end: usize,
    /// Index of the jump instruction whose target is at or before itself.
    pub back_edge: usize,
    /// Estimated number of iterations, or `None` when no estimate was made.
    pub estimated_iterations: Option<usize>,
}
290
291/// Detects loops in ternary bytecode by finding back-edges (jumps to earlier positions).
292pub fn detect_loops(program: &Program) -> Vec<LoopInfo> {
293    let mut loops = Vec::new();
294
295    for (i, op) in program.instructions.iter().enumerate() {
296        let target = match op {
297            Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) => Some(*t),
298            _ => None,
299        };
300
301        if let Some(target) = target {
302            if target <= i {
303                // Back-edge detected: loop from target to i
304                loops.push(LoopInfo {
305                    start: target,
306                    end: i,
307                    back_edge: i,
308                    estimated_iterations: None,
309                });
310            }
311        }
312    }
313
314    loops
315}
316
317/// Detects loops and attempts to estimate iteration count for fixed-count loops.
318pub fn detect_loops_with_iterations(program: &Program) -> Vec<LoopInfo> {
319    let mut loops = detect_loops(program);
320
321    for loop_info in &mut loops {
322        // Check for pattern: LoadConst(N), JumpIfZero/Pos/Neg at back_edge
323        // Simple heuristic: look for a LoadConst near the loop start
324        if loop_info.start > 0 {
325            if let Some(Op::LoadConst(Trit::Pos)) = program.instructions.get(loop_info.start.saturating_sub(1)) {
326                loop_info.estimated_iterations = Some(1);
327            }
328        }
329    }
330
331    loops
332}
333
334// ---- Pass 6: Optimization Pipeline ----
335
336/// A configurable pipeline of optimization passes.
337///
338/// Passes are applied in order. The pipeline can be run multiple times
339/// until no further reductions occur (fixed-point iteration).
340pub struct OptimizationPipeline {
341    pub passes: Vec<Box<dyn Fn(&Program) -> Program>>,
342    pub pass_names: Vec<String>,
343    pub max_iterations: usize,
344}
345
346impl OptimizationPipeline {
347    pub fn new() -> Self {
348        Self {
349            passes: Vec::new(),
350            pass_names: Vec::new(),
351            max_iterations: 10,
352        }
353    }
354
355    pub fn add_pass<F: Fn(&Program) -> Program + 'static>(mut self, name: &str, pass: F) -> Self {
356        self.passes.push(Box::new(pass));
357        self.pass_names.push(name.to_string());
358        self
359    }
360
361    /// Run all passes once.
362    pub fn run_once(&self, program: &Program) -> OptimizationResult {
363        let mut current = program.clone();
364        for pass in &self.passes {
365            current = pass(&current);
366        }
367        let eliminated = program.len().saturating_sub(current.len());
368        OptimizationResult {
369            program: current,
370            passes_applied: self.pass_names.clone(),
371            trits_eliminated: eliminated,
372        }
373    }
374
375    /// Run passes repeatedly until the program stops changing or max iterations reached.
376    pub fn run_to_fixed_point(&self, program: &Program) -> OptimizationResult {
377        let mut current = program.clone();
378        let mut total_eliminated = 0;
379        let mut all_applied = Vec::new();
380
381        for _ in 0..self.max_iterations {
382            let prev_len = current.len();
383            for pass in &self.passes {
384                current = pass(&current);
385            }
386            all_applied.extend(self.pass_names.iter().cloned());
387
388            let eliminated = prev_len.saturating_sub(current.len());
389            total_eliminated += eliminated;
390
391            if current.len() == prev_len {
392                break;
393            }
394        }
395
396        OptimizationResult {
397            program: current,
398            passes_applied: all_applied,
399            trits_eliminated: total_eliminated,
400        }
401    }
402}
403
#[cfg(test)]
mod tests {
    //! Unit tests for the optimization passes, loop detection and the pipeline.
    //!
    //! Assertions pin the exact instruction sequences the passes produce so a
    //! regression cannot hide behind a length-only check.

    use super::*;

    #[test]
    fn test_trit_from_i8() {
        assert_eq!(Trit::from_i8(-1), Some(Trit::Neg));
        assert_eq!(Trit::from_i8(0), Some(Trit::Zero));
        assert_eq!(Trit::from_i8(1), Some(Trit::Pos));
        assert_eq!(Trit::from_i8(2), None);
    }

    #[test]
    fn test_dead_trit_elimination_nop() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Nop, Op::Halt]);
        let optimized = dead_trit_elimination(&prog);
        assert_eq!(optimized.len(), 2);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos));
        assert_eq!(optimized.instructions[1], Op::Halt);
    }

    #[test]
    fn test_dead_trit_elimination_unused_load() {
        let prog = Program::new(vec![Op::Load(5), Op::Store(5), Op::Halt]);
        let optimized = dead_trit_elimination(&prog);
        // Nothing here is a Nop, so the program comes back unchanged.
        assert_eq!(
            optimized.instructions,
            vec![Op::Load(5), Op::Store(5), Op::Halt]
        );
    }

    #[test]
    fn test_constant_folding_add() {
        let prog = Program::new(vec![
            Op::LoadConst(Trit::Pos),
            Op::LoadConst(Trit::Pos),
            Op::Add,
        ]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos)); // 1+1 clamped to 1
    }

    #[test]
    fn test_constant_folding_mul() {
        let prog = Program::new(vec![
            Op::LoadConst(Trit::Neg),
            Op::LoadConst(Trit::Neg),
            Op::Mul,
        ]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos)); // -1 * -1 = 1
    }

    #[test]
    fn test_constant_folding_neg() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Neg]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Neg));
    }

    #[test]
    fn test_constant_folding_neg_zero() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Neg]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
    }

    #[test]
    fn test_trit_merging_double_neg() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Neg, Op::Neg, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 2); // Neg, Neg eliminated
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos));
        assert_eq!(optimized.instructions[1], Op::Halt);
    }

    #[test]
    fn test_trit_merging_zero_add() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Add, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 1); // LoadConst(Zero), Add eliminated
        assert_eq!(optimized.instructions[0], Op::Halt);
    }

    #[test]
    fn test_trit_merging_pos_mul() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Mul, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 1); // LoadConst(Pos), Mul eliminated
        assert_eq!(optimized.instructions[0], Op::Halt);
    }

    #[test]
    fn test_trit_merging_zero_mul() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Mul, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 2);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
    }

    #[test]
    fn test_peephole_store_load() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Store(0), Op::Load(0), Op::Halt]);
        let optimizer = PeepholeOptimizer::new(2);
        let optimized = optimizer.optimize(&prog);
        // The reload of register 0 right after storing a constant into it is dropped.
        assert_eq!(
            optimized.instructions,
            vec![Op::LoadConst(Trit::Pos), Op::Store(0), Op::Halt]
        );
    }

    #[test]
    fn test_peephole_preserves_halt() {
        let prog = Program::new(vec![Op::Halt]);
        let optimizer = PeepholeOptimizer::new(2);
        let optimized = optimizer.optimize(&prog);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized.instructions[0], Op::Halt);
    }

    #[test]
    fn test_loop_detection_simple() {
        // 0: LoadConst(Pos)
        // 1: JumpIfZero(0)
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::JumpIfZero(0)]);
        let loops = detect_loops(&prog);
        assert_eq!(loops.len(), 1);
        assert_eq!(loops[0].start, 0);
        assert_eq!(loops[0].end, 1);
        assert_eq!(loops[0].back_edge, 1);
        assert_eq!(loops[0].estimated_iterations, None);
    }

    #[test]
    fn test_loop_detection_no_loop() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Pos), Op::Jump(2), Op::Halt]);
        let loops = detect_loops(&prog);
        assert!(loops.is_empty());
    }

    #[test]
    fn test_loop_detection_nested() {
        // 0: LoadConst(Pos)
        // 1: JumpIfZero(0)   ← inner loop
        // 2: JumpIfNeg(0)    ← outer loop
        let prog = Program::new(vec![
            Op::LoadConst(Trit::Pos),
            Op::JumpIfZero(0),
            Op::JumpIfNeg(0),
        ]);
        let loops = detect_loops(&prog);
        assert_eq!(loops.len(), 2);
        assert_eq!((loops[0].start, loops[0].back_edge), (0, 1));
        assert_eq!((loops[1].start, loops[1].back_edge), (0, 2));
    }

    #[test]
    fn test_optimization_pipeline_single_pass() {
        let pipeline = OptimizationPipeline::new()
            .add_pass("dead_trit_elimination", |p| dead_trit_elimination(p))
            .add_pass("constant_folding", |p| constant_folding(p));

        let prog = Program::new(vec![
            Op::LoadConst(Trit::Pos),
            Op::LoadConst(Trit::Pos),
            Op::Add,
            Op::Nop,
        ]);
        let result = pipeline.run_once(&prog);
        assert_eq!(result.program.instructions, vec![Op::LoadConst(Trit::Pos)]);
        assert_eq!(result.trits_eliminated, 3);
        assert_eq!(result.passes_applied.len(), 2);
    }

    #[test]
    fn test_optimization_pipeline_fixed_point() {
        let pipeline = OptimizationPipeline::new()
            .add_pass("constant_folding", |p| constant_folding(p))
            .add_pass("trit_merging", |p| trit_merging(p))
            .add_pass("dead_trit_elimination", |p| dead_trit_elimination(p));

        let prog = Program::new(vec![
            Op::LoadConst(Trit::Pos),
            Op::LoadConst(Trit::Neg),
            Op::Mul,           // folded with the two constants to LoadConst(Neg)
            Op::Neg,           // first half of a double negation
            Op::Neg,           // second half; the pair is removed by trit_merging
            Op::Nop,           // removed by dead_trit_elimination
        ]);
        let result = pipeline.run_to_fixed_point(&prog);
        assert_eq!(result.program.instructions, vec![Op::LoadConst(Trit::Neg)]);
        assert_eq!(result.trits_eliminated, 5);
    }

    #[test]
    fn test_constant_folding_add_neg_pos() {
        let prog = Program::new(vec![
            Op::LoadConst(Trit::Neg),
            Op::LoadConst(Trit::Pos),
            Op::Add,
        ]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.len(), 1);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
    }

    #[test]
    fn test_constant_folding_zero_mul_pattern() {
        let prog = Program::new(vec![Op::LoadConst(Trit::Zero), Op::Mul, Op::Halt]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
    }

    #[test]
    fn test_program_empty() {
        let prog = Program::new(vec![]);
        assert!(prog.is_empty());
        assert_eq!(prog.len(), 0);
    }

    #[test]
    fn test_optimization_pipeline_no_change() {
        let pipeline = OptimizationPipeline::new()
            .add_pass("constant_folding", |p| constant_folding(p));
        let prog = Program::new(vec![Op::Halt]);
        let result = pipeline.run_once(&prog);
        assert_eq!(result.program.len(), 1);
        assert_eq!(result.program.instructions[0], Op::Halt);
        assert_eq!(result.trits_eliminated, 0);
    }

    #[test]
    fn test_detect_loops_with_iterations() {
        // 0: LoadConst(Pos)
        // 1: JumpIfZero(1)   ← self-loop, preceded by LoadConst(Pos)
        let prog = Program::new(vec![
            Op::LoadConst(Trit::Pos),
            Op::JumpIfZero(1),
        ]);
        let loops = detect_loops_with_iterations(&prog);
        assert_eq!(loops.len(), 1);
        assert_eq!(loops[0].start, 1);
        // The instruction before the loop start is LoadConst(Pos), so the
        // heuristic reports a single iteration.
        assert_eq!(loops[0].estimated_iterations, Some(1));
    }
}