scirs2_optimize/automatic_differentiation/
tape.rs

1//! Computational tape for reverse-mode automatic differentiation
2//!
3//! This module implements the tape structure used to record computational
4//! operations for later backpropagation in reverse-mode AD.
5
6use crate::error::OptimizeError;
7use std::collections::HashMap;
8
/// Boxed callback that consumes one batch of recorded tape nodes.
///
/// The callback receives the batch as a slice of [`TapeNode`]s and reports
/// failure through [`OptimizeError`]. It is stored by [`StreamingTape`] and
/// invoked whenever a batch is flushed.
type BatchProcessor = Box<dyn Fn(&[TapeNode]) -> Result<(), OptimizeError>>;
11
/// A variable in the computational tape.
///
/// Pairs the identifier used to address the variable on the tape with the
/// value it currently holds. `PartialEq` is derived so variables can be
/// compared directly (for example in `assert_eq!`); `Eq` is not, because the
/// value is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct Variable {
    /// Unique identifier for this variable
    pub id: usize,
    /// Current value
    pub value: f64,
}

impl Variable {
    /// Create a new variable with the given identifier and value.
    pub fn new(id: usize, value: f64) -> Self {
        Self { id, value }
    }
}
27
/// Type of unary operation that can be recorded on the tape.
///
/// The enum carries no data, so it derives the full set of comparison and
/// hashing traits; this lets callers compare op types with `==` and use them
/// as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOpType {
    /// Negation: `-x`
    Neg,
    /// Natural logarithm: `ln(x)`
    Ln,
    /// Exponential: `exp(x)`
    Exp,
    /// Sine: `sin(x)`
    Sin,
    /// Cosine: `cos(x)`
    Cos,
    /// Tangent: `tan(x)`
    Tan,
    /// Square root: `sqrt(x)`
    Sqrt,
    /// Square: `x^2`
    Square,
    /// Reciprocal: `1/x`
    Reciprocal,
}
50
/// Type of binary operation that can be recorded on the tape.
///
/// The enum carries no data, so it derives the full set of comparison and
/// hashing traits; this lets callers compare op types with `==` and use them
/// as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOpType {
    /// Addition: `x + y`
    Add,
    /// Subtraction: `x - y`
    Sub,
    /// Multiplication: `x * y`
    Mul,
    /// Division: `x / y`
    Div,
    /// Power: `x^y`
    Pow,
}
65
/// A node in the computational tape representing one recorded operation.
///
/// Every variant except `Input` names the variable id it produces in a
/// `result` field. Operation variants also store the local partial
/// derivatives of that result with respect to each operand, which is what the
/// reverse pass multiplies by when propagating gradients.
#[derive(Debug, Clone)]
pub enum TapeNode {
    /// Input variable (leaf node), identified by its variable id
    Input { var_id: usize },
    /// Constant value stored under the id `result`
    Constant { value: f64, result: usize },
    /// Unary operation
    UnaryOp {
        /// Which elementary function was applied
        op_type: UnaryOpType,
        /// Variable id of the operand
        input: usize,
        /// Variable id of the produced value
        result: usize,
        /// Local derivative ∂result/∂input
        partial: f64, // ∂result/∂input
    },
    /// Binary operation
    BinaryOp {
        /// Which arithmetic operation was applied
        op_type: BinaryOpType,
        /// Variable id of the left operand
        left: usize,
        /// Variable id of the right operand
        right: usize,
        /// Variable id of the produced value
        result: usize,
        /// Local derivative ∂result/∂left
        left_partial: f64,  // ∂result/∂left
        /// Local derivative ∂result/∂right
        right_partial: f64, // ∂result/∂right
    },
    /// N-ary operation (for efficiency with many inputs)
    NAryOp {
        /// Variable ids of the operands
        inputs: Vec<usize>,
        /// Variable id of the produced value
        result: usize,
        /// Local derivatives, `partials[i]` being ∂result/∂inputs[i]
        partials: Vec<f64>, // ∂result/∂inputs[i]
    },
}
96
97/// Computational tape for recording operations
98#[derive(Debug)]
99pub struct ComputationTape {
100    /// Sequence of operations in forward order
101    nodes: Vec<TapeNode>,
102    /// Input variables
103    inputs: Vec<Variable>,
104    /// Mapping from variable ID to its position in the tape
105    var_positions: HashMap<usize, usize>,
106    /// Maximum variable ID used
107    max_var_id: usize,
108}
109
110impl ComputationTape {
111    /// Create a new empty tape
112    pub fn new() -> Self {
113        Self {
114            nodes: Vec::new(),
115            inputs: Vec::new(),
116            var_positions: HashMap::new(),
117            max_var_id: 0,
118        }
119    }
120
121    /// Add an input variable to the tape
122    pub fn add_input(&mut self, var: Variable) {
123        self.var_positions.insert(var.id, self.nodes.len());
124        self.max_var_id = self.max_var_id.max(var.id);
125
126        self.nodes.push(TapeNode::Input { var_id: var.id });
127        self.inputs.push(var);
128    }
129
130    /// Add a computation node to the tape
131    pub fn add_node(&mut self, node: TapeNode) {
132        // Update variable positions for result variables
133        match &node {
134            TapeNode::Constant { result, .. } => {
135                self.var_positions.insert(*result, self.nodes.len());
136                self.max_var_id = self.max_var_id.max(*result);
137            }
138            TapeNode::UnaryOp { result, .. } => {
139                self.var_positions.insert(*result, self.nodes.len());
140                self.max_var_id = self.max_var_id.max(*result);
141            }
142            TapeNode::BinaryOp { result, .. } => {
143                self.var_positions.insert(*result, self.nodes.len());
144                self.max_var_id = self.max_var_id.max(*result);
145            }
146            TapeNode::NAryOp { result, .. } => {
147                self.var_positions.insert(*result, self.nodes.len());
148                self.max_var_id = self.max_var_id.max(*result);
149            }
150            _ => {}
151        }
152
153        self.nodes.push(node);
154    }
155
156    /// Perform backpropagation to compute gradients
157    pub fn backward(&self, gradients: &mut Vec<f64>) -> Result<(), OptimizeError> {
158        // Ensure gradients vector is large enough
159        if gradients.len() <= self.max_var_id {
160            gradients.resize(self.max_var_id + 1, 0.0);
161        }
162
163        // Reverse pass through the tape
164        for node in self.nodes.iter().rev() {
165            match node {
166                TapeNode::Input { .. } => {
167                    // Input nodes don't propagate gradients backward
168                }
169                TapeNode::Constant { .. } => {
170                    // Constants have zero gradient
171                }
172                TapeNode::UnaryOp {
173                    op_type: _,
174                    input,
175                    result,
176                    partial,
177                } => {
178                    // Propagate gradient: ∂L/∂input += ∂L/∂result * ∂result/∂input
179                    // Skip constants (which have index usize::MAX)
180                    if *input != usize::MAX && *input < gradients.len() {
181                        gradients[*input] += gradients[*result] * partial;
182                    }
183                }
184                TapeNode::BinaryOp {
185                    op_type: _,
186                    left,
187                    right,
188                    result,
189                    left_partial,
190                    right_partial,
191                } => {
192                    // Propagate gradients to both inputs
193                    // Skip constants (which have index usize::MAX)
194                    if *left != usize::MAX && *left < gradients.len() {
195                        gradients[*left] += gradients[*result] * left_partial;
196                    }
197                    if *right != usize::MAX && *right < gradients.len() {
198                        gradients[*right] += gradients[*result] * right_partial;
199                    }
200                }
201                TapeNode::NAryOp {
202                    inputs,
203                    result,
204                    partials,
205                } => {
206                    // Propagate gradient to all inputs
207                    // Skip constants (which have index usize::MAX)
208                    for (input_id, partial) in inputs.iter().zip(partials.iter()) {
209                        if *input_id != usize::MAX && *input_id < gradients.len() {
210                            gradients[*input_id] += gradients[*result] * partial;
211                        }
212                    }
213                }
214            }
215        }
216
217        Ok(())
218    }
219
220    /// Forward pass to compute all variable values
221    pub fn forward(&self, input_values: &[f64]) -> Result<Vec<f64>, OptimizeError> {
222        let mut values = vec![0.0; self.max_var_id + 1];
223
224        // Set input values
225        for (i, var) in self.inputs.iter().enumerate() {
226            if i < input_values.len() {
227                values[var.id] = input_values[i];
228            } else {
229                values[var.id] = var.value; // Use default value
230            }
231        }
232
233        // Forward pass through the tape
234        for node in &self.nodes {
235            match node {
236                TapeNode::Input { .. } => {
237                    // Already handled above
238                }
239                TapeNode::Constant { value, result } => {
240                    // Set constant value
241                    values[*result] = *value;
242                }
243                TapeNode::UnaryOp {
244                    op_type,
245                    input,
246                    result,
247                    ..
248                } => {
249                    // Perform actual unary operation
250                    let input_val = values[*input];
251                    values[*result] = match op_type {
252                        UnaryOpType::Neg => -input_val,
253                        UnaryOpType::Ln => input_val.ln(),
254                        UnaryOpType::Exp => input_val.exp(),
255                        UnaryOpType::Sin => input_val.sin(),
256                        UnaryOpType::Cos => input_val.cos(),
257                        UnaryOpType::Tan => input_val.tan(),
258                        UnaryOpType::Sqrt => input_val.sqrt(),
259                        UnaryOpType::Square => input_val * input_val,
260                        UnaryOpType::Reciprocal => 1.0 / input_val,
261                    };
262                }
263                TapeNode::BinaryOp {
264                    op_type,
265                    left,
266                    right,
267                    result,
268                    ..
269                } => {
270                    // Perform actual binary operation
271                    let left_val = values[*left];
272                    let right_val = values[*right];
273                    values[*result] = match op_type {
274                        BinaryOpType::Add => left_val + right_val,
275                        BinaryOpType::Sub => left_val - right_val,
276                        BinaryOpType::Mul => left_val * right_val,
277                        BinaryOpType::Div => left_val / right_val,
278                        BinaryOpType::Pow => left_val.powf(right_val),
279                    };
280                }
281                TapeNode::NAryOp { inputs, result, .. } => {
282                    // N-ary operations are application-specific
283                    // For now, implement as sum (could be extended for other operations)
284                    values[*result] = inputs.iter().map(|&id| values[id]).sum();
285                }
286            }
287        }
288
289        Ok(values)
290    }
291
292    /// Add a constant to the tape
293    pub fn add_constant(&mut self, value: f64) -> usize {
294        let result_id = self.max_var_id + 1;
295        self.add_node(TapeNode::Constant {
296            value,
297            result: result_id,
298        });
299        result_id
300    }
301
302    /// Add a unary operation with automatic partial derivative computation
303    pub fn add_unary_op(
304        &mut self,
305        op_type: UnaryOpType,
306        input: usize,
307        input_values: &[f64],
308    ) -> usize {
309        let result_id = self.max_var_id + 1;
310
311        // Compute partial derivative based on operation type and current input value
312        let input_val = input_values[input];
313        let partial = match op_type {
314            UnaryOpType::Neg => -1.0,
315            UnaryOpType::Ln => 1.0 / input_val,
316            UnaryOpType::Exp => input_val.exp(),
317            UnaryOpType::Sin => input_val.cos(),
318            UnaryOpType::Cos => -input_val.sin(),
319            UnaryOpType::Tan => 1.0 + input_val.tan().powi(2), // sec^2(x)
320            UnaryOpType::Sqrt => 1.0 / (2.0 * input_val.sqrt()),
321            UnaryOpType::Square => 2.0 * input_val,
322            UnaryOpType::Reciprocal => -1.0 / (input_val * input_val),
323        };
324
325        self.add_node(TapeNode::UnaryOp {
326            op_type,
327            input,
328            result: result_id,
329            partial,
330        });
331
332        result_id
333    }
334
335    /// Add a binary operation with automatic partial derivative computation  
336    pub fn add_binary_op(
337        &mut self,
338        op_type: BinaryOpType,
339        left: usize,
340        right: usize,
341        input_values: &[f64],
342    ) -> usize {
343        let result_id = self.max_var_id + 1;
344
345        // Compute partial derivatives based on operation type and current input values
346        let left_val = input_values[left];
347        let right_val = input_values[right];
348
349        let (left_partial, right_partial) = match op_type {
350            BinaryOpType::Add => (1.0, 1.0),
351            BinaryOpType::Sub => (1.0, -1.0),
352            BinaryOpType::Mul => (right_val, left_val),
353            BinaryOpType::Div => (1.0 / right_val, -left_val / (right_val * right_val)),
354            BinaryOpType::Pow => {
355                // d/dx[f^g] = f^(g-1) * (g * f' + f * ln(f) * g')
356                // d/dx[x^y] = y * x^(y-1), d/dy[x^y] = x^y * ln(x)
357                (
358                    right_val * left_val.powf(right_val - 1.0),
359                    left_val.powf(right_val) * left_val.ln(),
360                )
361            }
362        };
363
364        self.add_node(TapeNode::BinaryOp {
365            op_type,
366            left,
367            right,
368            result: result_id,
369            left_partial,
370            right_partial,
371        });
372
373        result_id
374    }
375
376    /// Forward-mode AD: compute function value and derivatives simultaneously
377    pub fn forward_ad(
378        &self,
379        input_values: &[f64],
380        seed_derivatives: &[f64],
381    ) -> Result<(Vec<f64>, Vec<f64>), OptimizeError> {
382        let mut values = vec![0.0; self.max_var_id + 1];
383        let mut derivatives = vec![0.0; self.max_var_id + 1];
384
385        // Set input values and seed derivatives
386        for (i, var) in self.inputs.iter().enumerate() {
387            if i < input_values.len() {
388                values[var.id] = input_values[i];
389                if i < seed_derivatives.len() {
390                    derivatives[var.id] = seed_derivatives[i];
391                }
392            } else {
393                values[var.id] = var.value;
394            }
395        }
396
397        // Forward pass through the tape
398        for node in &self.nodes {
399            match node {
400                TapeNode::Input { .. } => {
401                    // Already handled above
402                }
403                TapeNode::Constant { value, result } => {
404                    // Constants have zero derivative
405                    values[*result] = *value;
406                    derivatives[*result] = 0.0;
407                }
408                TapeNode::UnaryOp {
409                    op_type,
410                    input,
411                    result,
412                    ..
413                } => {
414                    // Forward-mode AD for unary operations
415                    let input_val = values[*input];
416                    let input_deriv = derivatives[*input];
417
418                    // Compute function value
419                    values[*result] = match op_type {
420                        UnaryOpType::Neg => -input_val,
421                        UnaryOpType::Ln => input_val.ln(),
422                        UnaryOpType::Exp => input_val.exp(),
423                        UnaryOpType::Sin => input_val.sin(),
424                        UnaryOpType::Cos => input_val.cos(),
425                        UnaryOpType::Tan => input_val.tan(),
426                        UnaryOpType::Sqrt => input_val.sqrt(),
427                        UnaryOpType::Square => input_val * input_val,
428                        UnaryOpType::Reciprocal => 1.0 / input_val,
429                    };
430
431                    // Compute derivative using chain rule: d/dx[f(g(x))] = f'(g(x)) * g'(x)
432                    let f_prime = match op_type {
433                        UnaryOpType::Neg => -1.0,
434                        UnaryOpType::Ln => 1.0 / input_val,
435                        UnaryOpType::Exp => input_val.exp(),
436                        UnaryOpType::Sin => input_val.cos(),
437                        UnaryOpType::Cos => -input_val.sin(),
438                        UnaryOpType::Tan => 1.0 + input_val.tan().powi(2),
439                        UnaryOpType::Sqrt => 1.0 / (2.0 * input_val.sqrt()),
440                        UnaryOpType::Square => 2.0 * input_val,
441                        UnaryOpType::Reciprocal => -1.0 / (input_val * input_val),
442                    };
443                    derivatives[*result] = f_prime * input_deriv;
444                }
445                TapeNode::BinaryOp {
446                    op_type,
447                    left,
448                    right,
449                    result,
450                    ..
451                } => {
452                    // Forward-mode AD for binary operations
453                    let left_val = values[*left];
454                    let right_val = values[*right];
455                    let left_deriv = derivatives[*left];
456                    let right_deriv = derivatives[*right];
457
458                    // Compute function value
459                    values[*result] = match op_type {
460                        BinaryOpType::Add => left_val + right_val,
461                        BinaryOpType::Sub => left_val - right_val,
462                        BinaryOpType::Mul => left_val * right_val,
463                        BinaryOpType::Div => left_val / right_val,
464                        BinaryOpType::Pow => left_val.powf(right_val),
465                    };
466
467                    // Compute derivative using product rule and chain rule
468                    derivatives[*result] = match op_type {
469                        BinaryOpType::Add => left_deriv + right_deriv,
470                        BinaryOpType::Sub => left_deriv - right_deriv,
471                        BinaryOpType::Mul => left_deriv * right_val + left_val * right_deriv,
472                        BinaryOpType::Div => {
473                            (left_deriv * right_val - left_val * right_deriv)
474                                / (right_val * right_val)
475                        }
476                        BinaryOpType::Pow => {
477                            // d/dx[f^g] = f^g * (g' * ln(f) + g * f'/f)
478                            let result_val = left_val.powf(right_val);
479                            result_val
480                                * (right_deriv * left_val.ln() + right_val * left_deriv / left_val)
481                        }
482                    };
483                }
484                TapeNode::NAryOp {
485                    inputs,
486                    result,
487                    partials,
488                } => {
489                    // N-ary operations: sum for now
490                    values[*result] = inputs.iter().map(|&id| values[id]).sum();
491                    derivatives[*result] = inputs
492                        .iter()
493                        .enumerate()
494                        .map(|(i, &id)| partials.get(i).unwrap_or(&1.0) * derivatives[id])
495                        .sum();
496                }
497            }
498        }
499
500        Ok((values, derivatives))
501    }
502
503    /// Optimize the tape by removing unnecessary operations
504    pub fn optimize(&mut self) {
505        // Remove redundant operations, constant folding, etc.
506        // This is a placeholder for more sophisticated optimizations
507
508        // Remove nodes that are never used
509        let mut used_vars = std::collections::HashSet::new();
510
511        // Mark all variables that are actually used
512        for node in &self.nodes {
513            match node {
514                TapeNode::UnaryOp { input, result, .. } => {
515                    used_vars.insert(*input);
516                    used_vars.insert(*result);
517                }
518                TapeNode::BinaryOp {
519                    left,
520                    right,
521                    result,
522                    ..
523                } => {
524                    used_vars.insert(*left);
525                    used_vars.insert(*right);
526                    used_vars.insert(*result);
527                }
528                TapeNode::NAryOp { inputs, result, .. } => {
529                    for &input_id in inputs {
530                        used_vars.insert(input_id);
531                    }
532                    used_vars.insert(*result);
533                }
534                TapeNode::Input { var_id } => {
535                    used_vars.insert(*var_id);
536                }
537                _ => {}
538            }
539        }
540
541        // Could implement more optimizations here
542    }
543
544    /// Get the size of the tape
545    pub fn size(&self) -> usize {
546        self.nodes.len()
547    }
548
549    /// Check if the tape is empty
550    pub fn is_empty(&self) -> bool {
551        self.nodes.is_empty()
552    }
553
554    /// Clear the tape
555    pub fn clear(&mut self) {
556        self.nodes.clear();
557        self.inputs.clear();
558        self.var_positions.clear();
559        self.max_var_id = 0;
560    }
561
562    /// Get statistics about the tape
563    pub fn get_stats(&self) -> TapeStats {
564        let mut unary_ops = 0;
565        let mut binary_ops = 0;
566        let mut nary_ops = 0;
567        let mut constants = 0;
568
569        for node in &self.nodes {
570            match node {
571                TapeNode::Input { .. } => {}
572                TapeNode::Constant { .. } => constants += 1,
573                TapeNode::UnaryOp { .. } => unary_ops += 1,
574                TapeNode::BinaryOp { .. } => binary_ops += 1,
575                TapeNode::NAryOp { .. } => nary_ops += 1,
576            }
577        }
578
579        TapeStats {
580            total_nodes: self.nodes.len(),
581            input_vars: self.inputs.len(),
582            unary_ops,
583            binary_ops,
584            nary_ops,
585            constants,
586            max_var_id: self.max_var_id,
587        }
588    }
589}
590
591impl Default for ComputationTape {
592    fn default() -> Self {
593        Self::new()
594    }
595}
596
/// Statistics about a computation tape.
///
/// A plain bundle of counters. It derives equality so two snapshots can be
/// compared directly, and `Default` so an all-zero snapshot is available.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TapeStats {
    /// Total number of nodes
    pub total_nodes: usize,
    /// Number of input variables
    pub input_vars: usize,
    /// Number of unary operations
    pub unary_ops: usize,
    /// Number of binary operations
    pub binary_ops: usize,
    /// Number of n-ary operations
    pub nary_ops: usize,
    /// Number of constants
    pub constants: usize,
    /// Maximum variable ID
    pub max_var_id: usize,
}
615
616/// Tape builder for more convenient tape construction
617pub struct TapeBuilder {
618    tape: ComputationTape,
619    next_var_id: usize,
620}
621
622impl TapeBuilder {
623    /// Create a new tape builder
624    pub fn new() -> Self {
625        Self {
626            tape: ComputationTape::new(),
627            next_var_id: 0,
628        }
629    }
630
631    /// Add an input variable
632    pub fn input(&mut self, value: f64) -> usize {
633        let var_id = self.next_var_id;
634        self.next_var_id += 1;
635
636        let var = Variable::new(var_id, value);
637        self.tape.add_input(var);
638
639        var_id
640    }
641
642    /// Add a unary operation
643    pub fn unary_op(&mut self, op_type: UnaryOpType, input: usize, partial: f64) -> usize {
644        let result_id = self.next_var_id;
645        self.next_var_id += 1;
646
647        let node = TapeNode::UnaryOp {
648            op_type,
649            input,
650            result: result_id,
651            partial,
652        };
653        self.tape.add_node(node);
654
655        result_id
656    }
657
658    /// Add a binary operation
659    pub fn binary_op(
660        &mut self,
661        op_type: BinaryOpType,
662        left: usize,
663        right: usize,
664        left_partial: f64,
665        right_partial: f64,
666    ) -> usize {
667        let result_id = self.next_var_id;
668        self.next_var_id += 1;
669
670        let node = TapeNode::BinaryOp {
671            op_type,
672            left,
673            right,
674            result: result_id,
675            left_partial,
676            right_partial,
677        };
678        self.tape.add_node(node);
679
680        result_id
681    }
682
683    /// Finish building and return the tape
684    pub fn build(self) -> ComputationTape {
685        self.tape
686    }
687}
688
689impl Default for TapeBuilder {
690    fn default() -> Self {
691        Self::new()
692    }
693}
694
695/// Memory-efficient tape that can handle very large computations
696pub struct StreamingTape {
697    /// Current batch of operations
698    current_batch: Vec<TapeNode>,
699    /// Batch size for processing
700    batch_size: usize,
701    /// Function to process completed batches
702    batch_processor: Option<BatchProcessor>,
703}
704
705impl StreamingTape {
706    /// Create a new streaming tape
707    pub fn new(batch_size: usize) -> Self {
708        Self {
709            current_batch: Vec::with_capacity(batch_size),
710            batch_size,
711            batch_processor: None,
712        }
713    }
714
715    /// Set the batch processor
716    pub fn set_batch_processor<F>(&mut self, processor: F)
717    where
718        F: Fn(&[TapeNode]) -> Result<(), OptimizeError> + 'static,
719    {
720        self.batch_processor = Some(Box::new(processor));
721    }
722
723    /// Add a node to the streaming tape
724    pub fn add_node(&mut self, node: TapeNode) -> Result<(), OptimizeError> {
725        self.current_batch.push(node);
726
727        if self.current_batch.len() >= self.batch_size {
728            self.flush_batch()?;
729        }
730
731        Ok(())
732    }
733
734    /// Flush the current batch
735    pub fn flush_batch(&mut self) -> Result<(), OptimizeError> {
736        if let Some(ref processor) = self.batch_processor {
737            processor(&self.current_batch)?;
738        }
739        self.current_batch.clear();
740        Ok(())
741    }
742
743    /// Finalize the streaming tape
744    pub fn finalize(&mut self) -> Result<(), OptimizeError> {
745        if !self.current_batch.is_empty() {
746            self.flush_batch()?;
747        }
748        Ok(())
749    }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::rc::Rc;

    /// Building `z = (x + y) * x` through the builder yields the expected
    /// node counts, and replaying the tape produces the expected value.
    #[test]
    fn test_tape_construction() {
        let mut builder = TapeBuilder::new();

        // Build tape for: z = (x + y) * x
        let x = builder.input(2.0);
        let y = builder.input(3.0);
        // x + y, partials: ∂/∂x = 1, ∂/∂y = 1
        let sum = builder.binary_op(BinaryOpType::Add, x, y, 1.0, 1.0);
        // sum * x, partials: ∂/∂sum = x = 2, ∂/∂x = sum = 5
        let result = builder.binary_op(BinaryOpType::Mul, sum, x, 2.0, 5.0);

        let tape = builder.build();

        assert_eq!(tape.size(), 4); // 2 inputs + 2 operations
        assert!(!tape.is_empty());

        let stats = tape.get_stats();
        assert_eq!(stats.input_vars, 2);
        assert_eq!(stats.binary_ops, 2);

        // Replaying the tape evaluates (2 + 3) * 2.
        let values = tape.forward(&[2.0, 3.0]).unwrap();
        assert_eq!(values[result], 10.0);
    }

    /// The reverse pass distributes the output seed to both operands of an
    /// addition.
    #[test]
    fn test_backward_pass() {
        let mut tape = ComputationTape::new();

        // Add inputs: x=2, y=3
        tape.add_input(Variable::new(0, 2.0));
        tape.add_input(Variable::new(1, 3.0));

        // Add operation: z = x + y (result=2)
        tape.add_node(TapeNode::BinaryOp {
            op_type: BinaryOpType::Add,
            left: 0,
            right: 1,
            result: 2,
            left_partial: 1.0,  // ∂z/∂x = 1
            right_partial: 1.0, // ∂z/∂y = 1
        });

        // Initialize gradients: ∂L/∂z = 1 (z is the output)
        let mut gradients = vec![0.0, 0.0, 1.0];

        tape.backward(&mut gradients).unwrap();

        // Check gradients
        assert_eq!(gradients[0], 1.0); // ∂L/∂x = ∂L/∂z * ∂z/∂x = 1 * 1 = 1
        assert_eq!(gradients[1], 1.0); // ∂L/∂y = ∂L/∂z * ∂z/∂y = 1 * 1 = 1
    }

    /// Optimizing must never grow the tape.
    #[test]
    fn test_tape_optimization() {
        let mut tape = ComputationTape::new();

        tape.add_input(Variable::new(0, 1.0));
        tape.add_node(TapeNode::UnaryOp {
            op_type: UnaryOpType::Neg,
            input: 0,
            result: 1,
            partial: -1.0, // d(-x)/dx = -1
        });

        let original_size = tape.size();
        tape.optimize();

        // Optimization might not change this simple tape, but it shouldn't break it
        assert!(tape.size() <= original_size);
    }

    /// With a batch size of two, three nodes produce one full batch while
    /// adding and one partial batch on finalize.
    #[test]
    fn test_streaming_tape() {
        let mut streaming_tape = StreamingTape::new(2);

        // Record the size of every batch the processor is handed.
        let batch_sizes = Rc::new(RefCell::new(Vec::new()));
        let recorder = Rc::clone(&batch_sizes);
        streaming_tape.set_batch_processor(move |batch: &[TapeNode]| {
            recorder.borrow_mut().push(batch.len());
            Ok(())
        });

        // Add nodes - the second one fills the batch
        streaming_tape
            .add_node(TapeNode::Input { var_id: 0 })
            .unwrap();
        assert!(batch_sizes.borrow().is_empty());
        streaming_tape
            .add_node(TapeNode::Input { var_id: 1 })
            .unwrap();
        assert_eq!(*batch_sizes.borrow(), vec![2]);

        // This node starts a new, partially filled batch
        streaming_tape
            .add_node(TapeNode::UnaryOp {
                op_type: UnaryOpType::Neg,
                input: 0,
                result: 2,
                partial: -1.0, // d(-x)/dx = -1
            })
            .unwrap();
        assert_eq!(*batch_sizes.borrow(), vec![2]);

        streaming_tape.finalize().unwrap();
        assert_eq!(*batch_sizes.borrow(), vec![2, 1]);
    }

    /// Statistics count every node kind separately.
    #[test]
    fn test_tape_stats() {
        let mut builder = TapeBuilder::new();

        let x = builder.input(1.0);
        let y = builder.input(2.0);
        builder.binary_op(BinaryOpType::Add, x, y, 1.0, 1.0);
        builder.unary_op(UnaryOpType::Neg, x, -1.0); // d(-x)/dx = -1

        let tape = builder.build();
        let stats = tape.get_stats();

        assert_eq!(stats.input_vars, 2);
        assert_eq!(stats.binary_ops, 1);
        assert_eq!(stats.unary_ops, 1);
        assert_eq!(stats.nary_ops, 0);
        assert_eq!(stats.constants, 0);
        assert_eq!(stats.total_nodes, 4); // 2 inputs + 1 binary + 1 unary
        assert_eq!(stats.max_var_id, 3);
    }
}