1use crate::error::OptimizeError;
7use std::collections::HashMap;
8
/// Boxed callback that [`StreamingTape`] invokes with each batch of buffered
/// nodes; returning `Err` makes the flush that called it fail with that error.
type BatchProcessor = Box<dyn Fn(&[TapeNode]) -> Result<(), OptimizeError>>;
11
/// A scalar variable on the tape, identified by a numeric id.
///
/// `value` is the number recorded when the variable was created;
/// `ComputationTape::forward` falls back to it when the caller supplies no
/// value for this input.
#[derive(Debug, Clone, PartialEq)]
pub struct Variable {
    /// Identifier used to index value and gradient buffers.
    pub id: usize,
    /// Value recorded for this variable.
    pub value: f64,
}

impl Variable {
    /// Creates a variable with the given id and value.
    pub const fn new(id: usize, value: f64) -> Self {
        Self { id, value }
    }
}
27
/// Elementary single-argument operations that can be recorded on the tape.
///
/// The type is a plain tag, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOpType {
    /// Negation, `-x`.
    Neg,
    /// Natural logarithm, `ln(x)`.
    Ln,
    /// Exponential, `e^x`.
    Exp,
    /// Sine.
    Sin,
    /// Cosine.
    Cos,
    /// Tangent.
    Tan,
    /// Square root.
    Sqrt,
    /// Square, `x * x`.
    Square,
    /// Reciprocal, `1 / x`.
    Reciprocal,
}
50
/// Elementary two-argument operations that can be recorded on the tape.
///
/// The type is a plain tag, so it is cheap to copy, compare and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOpType {
    /// Addition, `left + right`.
    Add,
    /// Subtraction, `left - right`.
    Sub,
    /// Multiplication, `left * right`.
    Mul,
    /// Division, `left / right`.
    Div,
    /// Power, `left.powf(right)`.
    Pow,
}
65
66#[derive(Debug, Clone)]
68pub enum TapeNode {
69 Input { var_id: usize },
71 Constant { value: f64, result: usize },
73 UnaryOp {
75 op_type: UnaryOpType,
76 input: usize,
77 result: usize,
78 partial: f64, },
80 BinaryOp {
82 op_type: BinaryOpType,
83 left: usize,
84 right: usize,
85 result: usize,
86 left_partial: f64, right_partial: f64, },
89 NAryOp {
91 inputs: Vec<usize>,
92 result: usize,
93 partials: Vec<f64>, },
95}
96
97#[derive(Debug)]
99pub struct ComputationTape {
100 nodes: Vec<TapeNode>,
102 inputs: Vec<Variable>,
104 var_positions: HashMap<usize, usize>,
106 max_var_id: usize,
108}
109
110impl ComputationTape {
111 pub fn new() -> Self {
113 Self {
114 nodes: Vec::new(),
115 inputs: Vec::new(),
116 var_positions: HashMap::new(),
117 max_var_id: 0,
118 }
119 }
120
121 pub fn add_input(&mut self, var: Variable) {
123 self.var_positions.insert(var.id, self.nodes.len());
124 self.max_var_id = self.max_var_id.max(var.id);
125
126 self.nodes.push(TapeNode::Input { var_id: var.id });
127 self.inputs.push(var);
128 }
129
130 pub fn add_node(&mut self, node: TapeNode) {
132 match &node {
134 TapeNode::Constant { result, .. } => {
135 self.var_positions.insert(*result, self.nodes.len());
136 self.max_var_id = self.max_var_id.max(*result);
137 }
138 TapeNode::UnaryOp { result, .. } => {
139 self.var_positions.insert(*result, self.nodes.len());
140 self.max_var_id = self.max_var_id.max(*result);
141 }
142 TapeNode::BinaryOp { result, .. } => {
143 self.var_positions.insert(*result, self.nodes.len());
144 self.max_var_id = self.max_var_id.max(*result);
145 }
146 TapeNode::NAryOp { result, .. } => {
147 self.var_positions.insert(*result, self.nodes.len());
148 self.max_var_id = self.max_var_id.max(*result);
149 }
150 _ => {}
151 }
152
153 self.nodes.push(node);
154 }
155
156 pub fn backward(&self, gradients: &mut Vec<f64>) -> Result<(), OptimizeError> {
158 if gradients.len() <= self.max_var_id {
160 gradients.resize(self.max_var_id + 1, 0.0);
161 }
162
163 for node in self.nodes.iter().rev() {
165 match node {
166 TapeNode::Input { .. } => {
167 }
169 TapeNode::Constant { .. } => {
170 }
172 TapeNode::UnaryOp {
173 op_type: _,
174 input,
175 result,
176 partial,
177 } => {
178 if *input != usize::MAX && *input < gradients.len() {
181 gradients[*input] += gradients[*result] * partial;
182 }
183 }
184 TapeNode::BinaryOp {
185 op_type: _,
186 left,
187 right,
188 result,
189 left_partial,
190 right_partial,
191 } => {
192 if *left != usize::MAX && *left < gradients.len() {
195 gradients[*left] += gradients[*result] * left_partial;
196 }
197 if *right != usize::MAX && *right < gradients.len() {
198 gradients[*right] += gradients[*result] * right_partial;
199 }
200 }
201 TapeNode::NAryOp {
202 inputs,
203 result,
204 partials,
205 } => {
206 for (input_id, partial) in inputs.iter().zip(partials.iter()) {
209 if *input_id != usize::MAX && *input_id < gradients.len() {
210 gradients[*input_id] += gradients[*result] * partial;
211 }
212 }
213 }
214 }
215 }
216
217 Ok(())
218 }
219
220 pub fn forward(&self, input_values: &[f64]) -> Result<Vec<f64>, OptimizeError> {
222 let mut values = vec![0.0; self.max_var_id + 1];
223
224 for (i, var) in self.inputs.iter().enumerate() {
226 if i < input_values.len() {
227 values[var.id] = input_values[i];
228 } else {
229 values[var.id] = var.value; }
231 }
232
233 for node in &self.nodes {
235 match node {
236 TapeNode::Input { .. } => {
237 }
239 TapeNode::Constant { value, result } => {
240 values[*result] = *value;
242 }
243 TapeNode::UnaryOp {
244 op_type,
245 input,
246 result,
247 ..
248 } => {
249 let input_val = values[*input];
251 values[*result] = match op_type {
252 UnaryOpType::Neg => -input_val,
253 UnaryOpType::Ln => input_val.ln(),
254 UnaryOpType::Exp => input_val.exp(),
255 UnaryOpType::Sin => input_val.sin(),
256 UnaryOpType::Cos => input_val.cos(),
257 UnaryOpType::Tan => input_val.tan(),
258 UnaryOpType::Sqrt => input_val.sqrt(),
259 UnaryOpType::Square => input_val * input_val,
260 UnaryOpType::Reciprocal => 1.0 / input_val,
261 };
262 }
263 TapeNode::BinaryOp {
264 op_type,
265 left,
266 right,
267 result,
268 ..
269 } => {
270 let left_val = values[*left];
272 let right_val = values[*right];
273 values[*result] = match op_type {
274 BinaryOpType::Add => left_val + right_val,
275 BinaryOpType::Sub => left_val - right_val,
276 BinaryOpType::Mul => left_val * right_val,
277 BinaryOpType::Div => left_val / right_val,
278 BinaryOpType::Pow => left_val.powf(right_val),
279 };
280 }
281 TapeNode::NAryOp { inputs, result, .. } => {
282 values[*result] = inputs.iter().map(|&id| values[id]).sum();
285 }
286 }
287 }
288
289 Ok(values)
290 }
291
292 pub fn add_constant(&mut self, value: f64) -> usize {
294 let result_id = self.max_var_id + 1;
295 self.add_node(TapeNode::Constant {
296 value,
297 result: result_id,
298 });
299 result_id
300 }
301
302 pub fn add_unary_op(
304 &mut self,
305 op_type: UnaryOpType,
306 input: usize,
307 input_values: &[f64],
308 ) -> usize {
309 let result_id = self.max_var_id + 1;
310
311 let input_val = input_values[input];
313 let partial = match op_type {
314 UnaryOpType::Neg => -1.0,
315 UnaryOpType::Ln => 1.0 / input_val,
316 UnaryOpType::Exp => input_val.exp(),
317 UnaryOpType::Sin => input_val.cos(),
318 UnaryOpType::Cos => -input_val.sin(),
319 UnaryOpType::Tan => 1.0 + input_val.tan().powi(2), UnaryOpType::Sqrt => 1.0 / (2.0 * input_val.sqrt()),
321 UnaryOpType::Square => 2.0 * input_val,
322 UnaryOpType::Reciprocal => -1.0 / (input_val * input_val),
323 };
324
325 self.add_node(TapeNode::UnaryOp {
326 op_type,
327 input,
328 result: result_id,
329 partial,
330 });
331
332 result_id
333 }
334
335 pub fn add_binary_op(
337 &mut self,
338 op_type: BinaryOpType,
339 left: usize,
340 right: usize,
341 input_values: &[f64],
342 ) -> usize {
343 let result_id = self.max_var_id + 1;
344
345 let left_val = input_values[left];
347 let right_val = input_values[right];
348
349 let (left_partial, right_partial) = match op_type {
350 BinaryOpType::Add => (1.0, 1.0),
351 BinaryOpType::Sub => (1.0, -1.0),
352 BinaryOpType::Mul => (right_val, left_val),
353 BinaryOpType::Div => (1.0 / right_val, -left_val / (right_val * right_val)),
354 BinaryOpType::Pow => {
355 (
358 right_val * left_val.powf(right_val - 1.0),
359 left_val.powf(right_val) * left_val.ln(),
360 )
361 }
362 };
363
364 self.add_node(TapeNode::BinaryOp {
365 op_type,
366 left,
367 right,
368 result: result_id,
369 left_partial,
370 right_partial,
371 });
372
373 result_id
374 }
375
376 pub fn forward_ad(
378 &self,
379 input_values: &[f64],
380 seed_derivatives: &[f64],
381 ) -> Result<(Vec<f64>, Vec<f64>), OptimizeError> {
382 let mut values = vec![0.0; self.max_var_id + 1];
383 let mut derivatives = vec![0.0; self.max_var_id + 1];
384
385 for (i, var) in self.inputs.iter().enumerate() {
387 if i < input_values.len() {
388 values[var.id] = input_values[i];
389 if i < seed_derivatives.len() {
390 derivatives[var.id] = seed_derivatives[i];
391 }
392 } else {
393 values[var.id] = var.value;
394 }
395 }
396
397 for node in &self.nodes {
399 match node {
400 TapeNode::Input { .. } => {
401 }
403 TapeNode::Constant { value, result } => {
404 values[*result] = *value;
406 derivatives[*result] = 0.0;
407 }
408 TapeNode::UnaryOp {
409 op_type,
410 input,
411 result,
412 ..
413 } => {
414 let input_val = values[*input];
416 let input_deriv = derivatives[*input];
417
418 values[*result] = match op_type {
420 UnaryOpType::Neg => -input_val,
421 UnaryOpType::Ln => input_val.ln(),
422 UnaryOpType::Exp => input_val.exp(),
423 UnaryOpType::Sin => input_val.sin(),
424 UnaryOpType::Cos => input_val.cos(),
425 UnaryOpType::Tan => input_val.tan(),
426 UnaryOpType::Sqrt => input_val.sqrt(),
427 UnaryOpType::Square => input_val * input_val,
428 UnaryOpType::Reciprocal => 1.0 / input_val,
429 };
430
431 let f_prime = match op_type {
433 UnaryOpType::Neg => -1.0,
434 UnaryOpType::Ln => 1.0 / input_val,
435 UnaryOpType::Exp => input_val.exp(),
436 UnaryOpType::Sin => input_val.cos(),
437 UnaryOpType::Cos => -input_val.sin(),
438 UnaryOpType::Tan => 1.0 + input_val.tan().powi(2),
439 UnaryOpType::Sqrt => 1.0 / (2.0 * input_val.sqrt()),
440 UnaryOpType::Square => 2.0 * input_val,
441 UnaryOpType::Reciprocal => -1.0 / (input_val * input_val),
442 };
443 derivatives[*result] = f_prime * input_deriv;
444 }
445 TapeNode::BinaryOp {
446 op_type,
447 left,
448 right,
449 result,
450 ..
451 } => {
452 let left_val = values[*left];
454 let right_val = values[*right];
455 let left_deriv = derivatives[*left];
456 let right_deriv = derivatives[*right];
457
458 values[*result] = match op_type {
460 BinaryOpType::Add => left_val + right_val,
461 BinaryOpType::Sub => left_val - right_val,
462 BinaryOpType::Mul => left_val * right_val,
463 BinaryOpType::Div => left_val / right_val,
464 BinaryOpType::Pow => left_val.powf(right_val),
465 };
466
467 derivatives[*result] = match op_type {
469 BinaryOpType::Add => left_deriv + right_deriv,
470 BinaryOpType::Sub => left_deriv - right_deriv,
471 BinaryOpType::Mul => left_deriv * right_val + left_val * right_deriv,
472 BinaryOpType::Div => {
473 (left_deriv * right_val - left_val * right_deriv)
474 / (right_val * right_val)
475 }
476 BinaryOpType::Pow => {
477 let result_val = left_val.powf(right_val);
479 result_val
480 * (right_deriv * left_val.ln() + right_val * left_deriv / left_val)
481 }
482 };
483 }
484 TapeNode::NAryOp {
485 inputs,
486 result,
487 partials,
488 } => {
489 values[*result] = inputs.iter().map(|&id| values[id]).sum();
491 derivatives[*result] = inputs
492 .iter()
493 .enumerate()
494 .map(|(i, &id)| partials.get(i).unwrap_or(&1.0) * derivatives[id])
495 .sum();
496 }
497 }
498 }
499
500 Ok((values, derivatives))
501 }
502
503 pub fn optimize(&mut self) {
505 let mut used_vars = std::collections::HashSet::new();
510
511 for node in &self.nodes {
513 match node {
514 TapeNode::UnaryOp { input, result, .. } => {
515 used_vars.insert(*input);
516 used_vars.insert(*result);
517 }
518 TapeNode::BinaryOp {
519 left,
520 right,
521 result,
522 ..
523 } => {
524 used_vars.insert(*left);
525 used_vars.insert(*right);
526 used_vars.insert(*result);
527 }
528 TapeNode::NAryOp { inputs, result, .. } => {
529 for &input_id in inputs {
530 used_vars.insert(input_id);
531 }
532 used_vars.insert(*result);
533 }
534 TapeNode::Input { var_id } => {
535 used_vars.insert(*var_id);
536 }
537 _ => {}
538 }
539 }
540
541 }
543
544 pub fn size(&self) -> usize {
546 self.nodes.len()
547 }
548
549 pub fn is_empty(&self) -> bool {
551 self.nodes.is_empty()
552 }
553
554 pub fn clear(&mut self) {
556 self.nodes.clear();
557 self.inputs.clear();
558 self.var_positions.clear();
559 self.max_var_id = 0;
560 }
561
562 pub fn get_stats(&self) -> TapeStats {
564 let mut unary_ops = 0;
565 let mut binary_ops = 0;
566 let mut nary_ops = 0;
567 let mut constants = 0;
568
569 for node in &self.nodes {
570 match node {
571 TapeNode::Input { .. } => {}
572 TapeNode::Constant { .. } => constants += 1,
573 TapeNode::UnaryOp { .. } => unary_ops += 1,
574 TapeNode::BinaryOp { .. } => binary_ops += 1,
575 TapeNode::NAryOp { .. } => nary_ops += 1,
576 }
577 }
578
579 TapeStats {
580 total_nodes: self.nodes.len(),
581 input_vars: self.inputs.len(),
582 unary_ops,
583 binary_ops,
584 nary_ops,
585 constants,
586 max_var_id: self.max_var_id,
587 }
588 }
589}
590
591impl Default for ComputationTape {
592 fn default() -> Self {
593 Self::new()
594 }
595}
596
/// Summary counters describing the contents of a tape.
///
/// Comparable and defaultable so snapshots can be checked with `assert_eq!`
/// and built up from an all-zero value.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct TapeStats {
    /// Total number of nodes on the tape.
    pub total_nodes: usize,
    /// Number of registered input variables.
    pub input_vars: usize,
    /// Number of unary operation nodes.
    pub unary_ops: usize,
    /// Number of binary operation nodes.
    pub binary_ops: usize,
    /// Number of n-ary operation nodes.
    pub nary_ops: usize,
    /// Number of constant nodes.
    pub constants: usize,
    /// Largest variable id on the tape.
    pub max_var_id: usize,
}
615
616pub struct TapeBuilder {
618 tape: ComputationTape,
619 next_var_id: usize,
620}
621
622impl TapeBuilder {
623 pub fn new() -> Self {
625 Self {
626 tape: ComputationTape::new(),
627 next_var_id: 0,
628 }
629 }
630
631 pub fn input(&mut self, value: f64) -> usize {
633 let var_id = self.next_var_id;
634 self.next_var_id += 1;
635
636 let var = Variable::new(var_id, value);
637 self.tape.add_input(var);
638
639 var_id
640 }
641
642 pub fn unary_op(&mut self, op_type: UnaryOpType, input: usize, partial: f64) -> usize {
644 let result_id = self.next_var_id;
645 self.next_var_id += 1;
646
647 let node = TapeNode::UnaryOp {
648 op_type,
649 input,
650 result: result_id,
651 partial,
652 };
653 self.tape.add_node(node);
654
655 result_id
656 }
657
658 pub fn binary_op(
660 &mut self,
661 op_type: BinaryOpType,
662 left: usize,
663 right: usize,
664 left_partial: f64,
665 right_partial: f64,
666 ) -> usize {
667 let result_id = self.next_var_id;
668 self.next_var_id += 1;
669
670 let node = TapeNode::BinaryOp {
671 op_type,
672 left,
673 right,
674 result: result_id,
675 left_partial,
676 right_partial,
677 };
678 self.tape.add_node(node);
679
680 result_id
681 }
682
683 pub fn build(self) -> ComputationTape {
685 self.tape
686 }
687}
688
689impl Default for TapeBuilder {
690 fn default() -> Self {
691 Self::new()
692 }
693}
694
695pub struct StreamingTape {
697 current_batch: Vec<TapeNode>,
699 batch_size: usize,
701 batch_processor: Option<BatchProcessor>,
703}
704
705impl StreamingTape {
706 pub fn new(batch_size: usize) -> Self {
708 Self {
709 current_batch: Vec::with_capacity(batch_size),
710 batch_size,
711 batch_processor: None,
712 }
713 }
714
715 pub fn set_batch_processor<F>(&mut self, processor: F)
717 where
718 F: Fn(&[TapeNode]) -> Result<(), OptimizeError> + 'static,
719 {
720 self.batch_processor = Some(Box::new(processor));
721 }
722
723 pub fn add_node(&mut self, node: TapeNode) -> Result<(), OptimizeError> {
725 self.current_batch.push(node);
726
727 if self.current_batch.len() >= self.batch_size {
728 self.flush_batch()?;
729 }
730
731 Ok(())
732 }
733
734 pub fn flush_batch(&mut self) -> Result<(), OptimizeError> {
736 if let Some(ref processor) = self.batch_processor {
737 processor(&self.current_batch)?;
738 }
739 self.current_batch.clear();
740 Ok(())
741 }
742
743 pub fn finalize(&mut self) -> Result<(), OptimizeError> {
745 if !self.current_batch.is_empty() {
746 self.flush_batch()?;
747 }
748 Ok(())
749 }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::rc::Rc;

    #[test]
    fn test_tape_construction() {
        let mut builder = TapeBuilder::new();

        let x = builder.input(2.0);
        let y = builder.input(3.0);
        // d(x + y)/dx = d(x + y)/dy = 1
        let sum = builder.binary_op(BinaryOpType::Add, x, y, 1.0, 1.0);
        // d(sum * x)/d sum = x = 2, d(sum * x)/dx = sum = 5
        let _result = builder.binary_op(BinaryOpType::Mul, sum, x, 2.0, 5.0);
        let tape = builder.build();

        // Two input nodes plus two operation nodes.
        assert_eq!(tape.size(), 4);
        let stats = tape.get_stats();
        assert_eq!(stats.input_vars, 2);
        assert_eq!(stats.binary_ops, 2);
    }

    #[test]
    fn test_backward_pass() {
        let mut tape = ComputationTape::new();

        tape.add_input(Variable::new(0, 2.0));
        tape.add_input(Variable::new(1, 3.0));

        tape.add_node(TapeNode::BinaryOp {
            op_type: BinaryOpType::Add,
            left: 0,
            right: 1,
            result: 2,
            left_partial: 1.0,
            right_partial: 1.0,
        });

        // Seed the output (variable 2) with gradient 1.
        let mut gradients = vec![0.0, 0.0, 1.0];

        tape.backward(&mut gradients).unwrap();

        assert_eq!(gradients[0], 1.0);
        assert_eq!(gradients[1], 1.0);
    }

    #[test]
    fn test_tape_optimization() {
        let mut tape = ComputationTape::new();

        tape.add_input(Variable::new(0, 1.0));
        tape.add_node(TapeNode::UnaryOp {
            op_type: UnaryOpType::Neg,
            input: 0,
            result: 1,
            // d(-x)/dx
            partial: -1.0,
        });

        let original_size = tape.size();
        let before = tape.forward(&[1.0]).unwrap();
        tape.optimize();

        assert!(tape.size() <= original_size);
        // Optimising must never change what the tape computes.
        let after = tape.forward(&[1.0]).unwrap();
        assert_eq!(before[1], -1.0);
        assert_eq!(after[1], before[1]);
    }

    #[test]
    fn test_streaming_tape() {
        let mut streaming_tape = StreamingTape::new(2);

        let batches_seen = Rc::new(Cell::new(0usize));
        let nodes_seen = Rc::new(Cell::new(0usize));
        let (batch_counter, node_counter) = (Rc::clone(&batches_seen), Rc::clone(&nodes_seen));

        streaming_tape.set_batch_processor(move |batch| {
            batch_counter.set(batch_counter.get() + 1);
            node_counter.set(node_counter.get() + batch.len());
            Ok(())
        });

        streaming_tape
            .add_node(TapeNode::Input { var_id: 0 })
            .unwrap();
        assert_eq!(batches_seen.get(), 0);

        // The second node fills the batch and triggers the first flush.
        streaming_tape
            .add_node(TapeNode::Input { var_id: 1 })
            .unwrap();
        assert_eq!(batches_seen.get(), 1);
        assert_eq!(nodes_seen.get(), 2);

        streaming_tape
            .add_node(TapeNode::UnaryOp {
                op_type: UnaryOpType::Neg,
                input: 0,
                result: 2,
                // d(-x)/dx
                partial: -1.0,
            })
            .unwrap();
        assert_eq!(batches_seen.get(), 1);

        // The partial batch is delivered on finalize.
        streaming_tape.finalize().unwrap();
        assert_eq!(batches_seen.get(), 2);
        assert_eq!(nodes_seen.get(), 3);

        // Nothing is left to flush, so a second finalize does not call the processor.
        streaming_tape.finalize().unwrap();
        assert_eq!(batches_seen.get(), 2);
    }

    #[test]
    fn test_tape_stats() {
        let mut builder = TapeBuilder::new();

        let x = builder.input(1.0);
        let y = builder.input(2.0);
        builder.binary_op(BinaryOpType::Add, x, y, 1.0, 1.0);
        // d(-x)/dx
        builder.unary_op(UnaryOpType::Neg, x, -1.0);

        let tape = builder.build();
        let stats = tape.get_stats();

        assert_eq!(stats.input_vars, 2);
        assert_eq!(stats.binary_ops, 1);
        assert_eq!(stats.unary_ops, 1);
        assert_eq!(stats.nary_ops, 0);
        assert_eq!(stats.constants, 0);
        // Two inputs plus two operations, with ids 0 through 3.
        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.max_var_id, 3);
    }
}