1#![forbid(unsafe_code)]
2
/// A three-valued digit taking the value -1, 0 or +1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Trit {
    /// The value -1.
    Neg = -1,
    /// The value 0.
    Zero = 0,
    /// The value +1.
    Pos = 1,
}

impl Trit {
    /// Converts `v` into a `Trit`.
    ///
    /// Returns `None` for any value other than -1, 0 or 1.
    pub fn from_i8(v: i8) -> Option<Self> {
        [Trit::Neg, Trit::Zero, Trit::Pos]
            .iter()
            .copied()
            .find(|candidate| candidate.to_i8() == v)
    }

    /// Returns the numeric value of this trit (-1, 0 or 1).
    pub fn to_i8(self) -> i8 {
        match self {
            Trit::Neg => -1,
            Trit::Zero => 0,
            Trit::Pos => 1,
        }
    }
}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
30pub enum Op {
31 LoadConst(Trit),
33 Load(usize),
35 Store(usize),
37 Add,
39 Mul,
41 Neg,
43 Nop,
45 Jump(usize),
47 JumpIfZero(usize),
49 JumpIfNeg(usize),
51 JumpIfPos(usize),
53 Input,
55 Output,
57 Halt,
59}
60
61#[derive(Debug, Clone)]
63pub struct Program {
64 pub instructions: Vec<Op>,
65}
66
67impl Program {
68 pub fn new(instructions: Vec<Op>) -> Self {
69 Self { instructions }
70 }
71
72 pub fn len(&self) -> usize {
73 self.instructions.len()
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.instructions.is_empty()
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct OptimizationResult {
84 pub program: Program,
85 pub passes_applied: Vec<String>,
86 pub trits_eliminated: usize,
87}
88
89pub fn dead_trit_elimination(program: &Program) -> Program {
97 let mut used_regs: std::collections::HashSet<usize> = std::collections::HashSet::new();
98 let mut used_labels: std::collections::HashSet<usize> = std::collections::HashSet::new();
99
100 for op in &program.instructions {
102 match op {
103 Op::Load(reg) => { used_regs.insert(*reg); }
104 Op::Jump(target) | Op::JumpIfZero(target) | Op::JumpIfNeg(target) | Op::JumpIfPos(target) => {
105 used_labels.insert(*target);
106 }
107 _ => {}
108 }
109 }
110
111 let mut needed = std::collections::HashSet::new();
113 for op in program.instructions.iter().rev() {
114 match op {
115 Op::Load(reg) => { if needed.contains(reg) { used_regs.insert(*reg); } }
116 Op::Store(reg) => { if used_regs.contains(reg) { needed.insert(*reg); } }
117 _ => {}
118 }
119 }
120
121 let mut result = Vec::new();
123 for (i, op) in program.instructions.iter().enumerate() {
124 match op {
125 Op::Nop => {} Op::LoadConst(_) => result.push(*op), Op::Load(reg) => {
128 if used_regs.contains(reg) {
129 result.push(*op);
130 }
131 }
132 _ => result.push(*op),
133 }
134 let _ = used_labels.contains(&i); }
137
138 Program::new(result)
141}
142
143pub fn constant_folding(program: &Program) -> Program {
151 let mut result: Vec<Op> = Vec::new();
152 let mut i = 0;
153 let instrs = &program.instructions;
154
155 while i < instrs.len() {
156 match (instrs.get(i), instrs.get(i + 1), instrs.get(i + 2)) {
157 (Some(Op::LoadConst(a)), Some(Op::LoadConst(b)), Some(Op::Add)) => {
159 let sum = (a.to_i8() + b.to_i8()).clamp(-1, 1);
160 result.push(Op::LoadConst(Trit::from_i8(sum).unwrap_or(Trit::Zero)));
161 i += 3;
162 }
163 (Some(Op::LoadConst(a)), Some(Op::LoadConst(b)), Some(Op::Mul)) => {
165 let product = a.to_i8() * b.to_i8();
166 result.push(Op::LoadConst(Trit::from_i8(product.clamp(-1, 1)).unwrap_or(Trit::Zero)));
167 i += 3;
168 }
169 (Some(Op::LoadConst(a)), Some(Op::Neg), _) => {
171 let neg = match a {
172 Trit::Pos => Trit::Neg,
173 Trit::Neg => Trit::Pos,
174 Trit::Zero => Trit::Zero,
175 };
176 result.push(Op::LoadConst(neg));
177 i += 2;
178 }
179 (Some(Op::LoadConst(Trit::Zero)), Some(Op::Mul), _) => {
180 result.push(Op::LoadConst(Trit::Zero));
182 i += 2;
183 }
184 _ => {
185 result.push(instrs[i]);
186 i += 1;
187 }
188 }
189 }
190
191 Program::new(result)
192}
193
194pub fn trit_merging(program: &Program) -> Program {
203 let mut result: Vec<Op> = Vec::new();
204 let mut i = 0;
205 let instrs = &program.instructions;
206
207 while i < instrs.len() {
208 match (instrs.get(i), instrs.get(i + 1)) {
209 (Some(Op::Neg), Some(Op::Neg)) => { i += 2; }
211 (Some(Op::LoadConst(Trit::Zero)), Some(Op::Add)) => { i += 2; }
213 (Some(Op::LoadConst(Trit::Pos)), Some(Op::Mul)) => { i += 2; }
215 (Some(Op::LoadConst(Trit::Zero)), Some(Op::Mul)) => {
217 result.push(Op::LoadConst(Trit::Zero));
218 i += 2;
219 }
220 _ => {
221 result.push(instrs[i]);
222 i += 1;
223 }
224 }
225 }
226
227 Program::new(result)
228}
229
230pub struct PeepholeOptimizer {
237 pub window_size: usize,
238}
239
240impl PeepholeOptimizer {
241 pub fn new(window_size: usize) -> Self {
242 Self { window_size }
243 }
244
245 pub fn optimize(&self, program: &Program) -> Program {
246 let mut result: Vec<Op> = Vec::new();
247 let instrs = &program.instructions;
248 let mut i = 0;
249
250 while i < instrs.len() {
251 if i + 1 < instrs.len() {
253 if let (Op::Store(r1), Op::Load(r2)) = (&instrs[i], &instrs[i + 1]) {
254 if r1 == r2 {
255 i += 2;
256 continue;
257 }
258 }
259 }
260
261 if i + 2 < instrs.len() {
263 if let (Op::LoadConst(_), Op::Store(_), Op::Load(_)) = (&instrs[i], &instrs[i + 1], &instrs[i + 2]) {
264 result.push(instrs[i]); result.push(instrs[i + 1]); i += 3;
268 continue;
269 }
270 }
271
272 result.push(instrs[i]);
273 i += 1;
274 }
275
276 Program::new(result)
277 }
278}
279
/// A loop found by `detect_loops`: a jump whose target is at or before itself.
#[derive(Debug, Clone)]
pub struct LoopInfo {
    /// Index the jump goes to.
    pub start: usize,
    /// Index of the jump instruction.
    pub end: usize,
    /// Index of the jump instruction; `detect_loops` sets it to the same value as `end`.
    pub back_edge: usize,
    /// Estimated iteration count, or `None` when no estimate was made.
    pub estimated_iterations: Option<usize>,
}
290
291pub fn detect_loops(program: &Program) -> Vec<LoopInfo> {
293 let mut loops = Vec::new();
294
295 for (i, op) in program.instructions.iter().enumerate() {
296 let target = match op {
297 Op::Jump(t) | Op::JumpIfZero(t) | Op::JumpIfNeg(t) | Op::JumpIfPos(t) => Some(*t),
298 _ => None,
299 };
300
301 if let Some(target) = target {
302 if target <= i {
303 loops.push(LoopInfo {
305 start: target,
306 end: i,
307 back_edge: i,
308 estimated_iterations: None,
309 });
310 }
311 }
312 }
313
314 loops
315}
316
317pub fn detect_loops_with_iterations(program: &Program) -> Vec<LoopInfo> {
319 let mut loops = detect_loops(program);
320
321 for loop_info in &mut loops {
322 if loop_info.start > 0 {
325 if let Some(Op::LoadConst(Trit::Pos)) = program.instructions.get(loop_info.start.saturating_sub(1)) {
326 loop_info.estimated_iterations = Some(1);
327 }
328 }
329 }
330
331 loops
332}
333
334pub struct OptimizationPipeline {
341 pub passes: Vec<Box<dyn Fn(&Program) -> Program>>,
342 pub pass_names: Vec<String>,
343 pub max_iterations: usize,
344}
345
346impl OptimizationPipeline {
347 pub fn new() -> Self {
348 Self {
349 passes: Vec::new(),
350 pass_names: Vec::new(),
351 max_iterations: 10,
352 }
353 }
354
355 pub fn add_pass<F: Fn(&Program) -> Program + 'static>(mut self, name: &str, pass: F) -> Self {
356 self.passes.push(Box::new(pass));
357 self.pass_names.push(name.to_string());
358 self
359 }
360
361 pub fn run_once(&self, program: &Program) -> OptimizationResult {
363 let mut current = program.clone();
364 for pass in &self.passes {
365 current = pass(¤t);
366 }
367 let eliminated = program.len().saturating_sub(current.len());
368 OptimizationResult {
369 program: current,
370 passes_applied: self.pass_names.clone(),
371 trits_eliminated: eliminated,
372 }
373 }
374
375 pub fn run_to_fixed_point(&self, program: &Program) -> OptimizationResult {
377 let mut current = program.clone();
378 let mut total_eliminated = 0;
379 let mut all_applied = Vec::new();
380
381 for _ in 0..self.max_iterations {
382 let prev_len = current.len();
383 for pass in &self.passes {
384 current = pass(¤t);
385 }
386 all_applied.extend(self.pass_names.iter().cloned());
387
388 let eliminated = prev_len.saturating_sub(current.len());
389 total_eliminated += eliminated;
390
391 if current.len() == prev_len {
392 break;
393 }
394 }
395
396 OptimizationResult {
397 program: current,
398 passes_applied: all_applied,
399 trits_eliminated: total_eliminated,
400 }
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Program` from a slice of instructions.
    fn program_of(ops: &[Op]) -> Program {
        Program::new(ops.to_vec())
    }

    // --- Trit -----------------------------------------------------------

    #[test]
    fn test_trit_from_i8() {
        let cases = [
            (-1, Some(Trit::Neg)),
            (0, Some(Trit::Zero)),
            (1, Some(Trit::Pos)),
            (2, None),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(Trit::from_i8(raw), expected);
        }
    }

    // --- dead_trit_elimination --------------------------------------------

    #[test]
    fn test_dead_trit_elimination_nop() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::Nop, Op::Halt]);
        let optimized = dead_trit_elimination(&prog);
        assert_eq!(
            optimized.instructions,
            vec![Op::LoadConst(Trit::Pos), Op::Halt]
        );
    }

    #[test]
    fn test_dead_trit_elimination_unused_load() {
        let prog = program_of(&[Op::Load(5), Op::Store(5), Op::Halt]);
        let optimized = dead_trit_elimination(&prog);
        assert!(optimized.len() >= 1);
    }

    // --- constant_folding -------------------------------------------------

    #[test]
    fn test_constant_folding_add() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::LoadConst(Trit::Pos), Op::Add]);
        let optimized = constant_folding(&prog);
        // Pos + Pos saturates to Pos.
        assert_eq!(optimized.instructions, vec![Op::LoadConst(Trit::Pos)]);
    }

    #[test]
    fn test_constant_folding_mul() {
        let prog = program_of(&[Op::LoadConst(Trit::Neg), Op::LoadConst(Trit::Neg), Op::Mul]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.instructions, vec![Op::LoadConst(Trit::Pos)]);
    }

    #[test]
    fn test_constant_folding_neg() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::Neg]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.instructions, vec![Op::LoadConst(Trit::Neg)]);
    }

    #[test]
    fn test_constant_folding_neg_zero() {
        let prog = program_of(&[Op::LoadConst(Trit::Zero), Op::Neg]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.instructions, vec![Op::LoadConst(Trit::Zero)]);
    }

    #[test]
    fn test_constant_folding_add_neg_pos() {
        let prog = program_of(&[Op::LoadConst(Trit::Neg), Op::LoadConst(Trit::Pos), Op::Add]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.instructions, vec![Op::LoadConst(Trit::Zero)]);
    }

    #[test]
    fn test_constant_folding_zero_mul_pattern() {
        let prog = program_of(&[Op::LoadConst(Trit::Zero), Op::Mul, Op::Halt]);
        let optimized = constant_folding(&prog);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
    }

    // --- trit_merging -----------------------------------------------------

    #[test]
    fn test_trit_merging_double_neg() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::Neg, Op::Neg, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 2);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Pos));
    }

    #[test]
    fn test_trit_merging_zero_add() {
        let prog = program_of(&[Op::LoadConst(Trit::Zero), Op::Add, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.instructions, vec![Op::Halt]);
    }

    #[test]
    fn test_trit_merging_pos_mul() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::Mul, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 1);
    }

    #[test]
    fn test_trit_merging_zero_mul() {
        let prog = program_of(&[Op::LoadConst(Trit::Zero), Op::Mul, Op::Halt]);
        let optimized = trit_merging(&prog);
        assert_eq!(optimized.len(), 2);
        assert_eq!(optimized.instructions[0], Op::LoadConst(Trit::Zero));
    }

    // --- PeepholeOptimizer --------------------------------------------------

    #[test]
    fn test_peephole_store_load() {
        let prog = program_of(&[
            Op::LoadConst(Trit::Pos),
            Op::Store(0),
            Op::Load(0),
            Op::Halt,
        ]);
        let optimized = PeepholeOptimizer::new(2).optimize(&prog);
        assert!(optimized.len() <= 4);
    }

    #[test]
    fn test_peephole_preserves_halt() {
        let prog = program_of(&[Op::Halt]);
        let optimized = PeepholeOptimizer::new(2).optimize(&prog);
        assert_eq!(optimized.instructions, vec![Op::Halt]);
    }

    // --- loop detection -----------------------------------------------------

    #[test]
    fn test_loop_detection_simple() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::JumpIfZero(0)]);
        let loops = detect_loops(&prog);
        assert_eq!(loops.len(), 1);
        let only = &loops[0];
        assert_eq!(only.start, 0);
        assert_eq!(only.back_edge, 1);
    }

    #[test]
    fn test_loop_detection_no_loop() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::Jump(2), Op::Halt]);
        assert!(detect_loops(&prog).is_empty());
    }

    #[test]
    fn test_loop_detection_nested() {
        let prog = program_of(&[
            Op::LoadConst(Trit::Pos),
            Op::JumpIfZero(0),
            Op::JumpIfNeg(0),
        ]);
        assert_eq!(detect_loops(&prog).len(), 2);
    }

    #[test]
    fn test_detect_loops_with_iterations() {
        let prog = program_of(&[Op::LoadConst(Trit::Pos), Op::JumpIfZero(1)]);
        assert_eq!(detect_loops_with_iterations(&prog).len(), 1);
    }

    // --- Program --------------------------------------------------------------

    #[test]
    fn test_program_empty() {
        let prog = program_of(&[]);
        assert!(prog.is_empty());
        assert_eq!(prog.len(), 0);
    }

    // --- OptimizationPipeline ---------------------------------------------------

    #[test]
    fn test_optimization_pipeline_single_pass() {
        let pipeline = OptimizationPipeline::new()
            .add_pass("dead_trit_elimination", dead_trit_elimination)
            .add_pass("constant_folding", constant_folding);

        let prog = program_of(&[
            Op::LoadConst(Trit::Pos),
            Op::LoadConst(Trit::Pos),
            Op::Add,
            Op::Nop,
        ]);
        let result = pipeline.run_once(&prog);
        assert!(result.program.len() < prog.len());
        assert_eq!(result.passes_applied.len(), 2);
    }

    #[test]
    fn test_optimization_pipeline_fixed_point() {
        let pipeline = OptimizationPipeline::new()
            .add_pass("constant_folding", constant_folding)
            .add_pass("trit_merging", trit_merging)
            .add_pass("dead_trit_elimination", dead_trit_elimination);

        let prog = program_of(&[
            Op::LoadConst(Trit::Pos),
            Op::LoadConst(Trit::Neg),
            Op::Mul,
            Op::Neg,
            Op::Neg,
            Op::Nop,
        ]);
        let result = pipeline.run_to_fixed_point(&prog);
        assert!(result.program.len() < prog.len());
    }

    #[test]
    fn test_optimization_pipeline_no_change() {
        let pipeline = OptimizationPipeline::new().add_pass("constant_folding", constant_folding);
        let prog = program_of(&[Op::Halt]);
        let result = pipeline.run_once(&prog);
        assert_eq!(result.program.len(), 1);
    }
}