// trustformers_core/autodiff/graph.rs
1//! Computational graph for automatic differentiation.
2//!
3//! This module provides the computational graph infrastructure for tracking
4//! operations and computing gradients through reverse-mode automatic differentiation.
5
6#![allow(unused_variables)] // Autodiff graph
7
8use crate::errors::{Result, TrustformersError};
9use crate::tensor::Tensor;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12
/// Unique identifier for nodes in the computation graph.
/// IDs are assigned sequentially by `ComputationGraph` and are never reused.
pub type NodeId = usize;
15
16/// Computational graph for tracking operations and gradients
/// Computational graph for tracking operations and gradients.
///
/// Owns every node, hands out sequential [`NodeId`]s, and caches a
/// topological ordering that is invalidated (via `dirty`) whenever the
/// graph structure changes.
#[derive(Debug)]
pub struct ComputationGraph {
    /// Nodes in the computation graph, keyed by their ID
    nodes: HashMap<NodeId, GraphNode>,
    /// Next available node ID (monotonically increasing, never reused)
    next_id: NodeId,
    /// Cached topological ordering of nodes (valid only while `dirty` is false)
    topological_order: Vec<NodeId>,
    /// Whether the graph changed since `topological_order` was last computed
    dirty: bool,
    /// Root nodes (gradient-requiring variables added via `add_node`)
    root_nodes: Vec<NodeId>,
    /// Leaf nodes (outputs that gradients flow back from)
    leaf_nodes: Vec<NodeId>,
}
32
/// Node in the computation graph.
///
/// A node is either a leaf (a variable created with `add_node`) or the
/// result of an operation applied to parent nodes.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Unique identifier for this node
    pub id: NodeId,
    /// The tensor value at this node
    pub value: Tensor,
    /// Gradient accumulated at this node during the backward pass
    /// (`None` until `backward` runs, or after `zero_grad`)
    pub gradient: Option<Tensor>,
    /// Operation that produced this node (`None` for leaf variables)
    pub operation: Option<OperationType>,
    /// Parent nodes (inputs to the operation)
    pub parents: Vec<NodeId>,
    /// Child nodes (nodes that use this node's value as input)
    pub children: Vec<NodeId>,
    /// Whether this node requires gradients
    pub requires_grad: bool,
    /// Whether this node is a leaf (variable rather than operation result)
    pub is_leaf: bool,
    /// Optional human-readable name for debugging
    pub name: Option<String>,
    /// Cached shape of `value`, kept in sync by `update_value`
    pub shape: Vec<usize>,
}
57
/// Types of operations in the computation graph.
///
/// Payload-carrying variants store the forward-pass parameters that the
/// backward pass needs (axes, shapes, epsilon values, etc.).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OperationType {
    /// Binary operations
    Add,
    Subtract,
    Multiply,
    Divide,
    MatrixMultiply,

    /// Unary operations
    Negate,
    Reciprocal,
    Square,
    Sqrt,
    Log,
    Exp,

    /// Activation functions
    Sigmoid,
    Tanh,
    ReLU,
    LeakyReLU(f32), // negative-slope alpha
    Softmax,
    LogSoftmax,

    /// Tensor operations
    Reshape(Vec<usize>),   // target shape
    Transpose(Vec<usize>), // axis permutation
    Slice(Vec<std::ops::Range<usize>>),
    Concat(usize),     // axis
    Split(Vec<usize>), // split sizes

    /// Reduction operations
    Sum(Option<Vec<usize>>),  // axes (None = reduce over all axes)
    Mean(Option<Vec<usize>>), // axes (None = reduce over all axes)
    Max(Option<Vec<usize>>),  // axes
    Min(Option<Vec<usize>>),  // axes

    /// Specialized operations
    LayerNorm(f32), // epsilon
    Dropout(f32),   // probability
    BatchNorm(f32), // epsilon

    /// Custom operation, identified by name
    Custom(String),
}
105
/// Gradient function trait for operations.
///
/// Implementors define the backward rule for an operation: given the
/// gradient of the output and the original inputs, produce one gradient
/// per input, in the same order as `inputs`.
pub trait GradientFunction: Send + Sync {
    /// Compute gradients for the inputs given the gradient of the output.
    ///
    /// Returns one tensor per entry of `inputs`.
    fn backward(&self, grad_output: &Tensor, inputs: &[&Tensor]) -> Result<Vec<Tensor>>;

    /// Get the operation type this gradient function corresponds to.
    fn operation_type(&self) -> OperationType;
}
114
115impl ComputationGraph {
116    /// Create a new empty computation graph
117    pub fn new() -> Self {
118        Self {
119            nodes: HashMap::new(),
120            next_id: 0,
121            topological_order: Vec::new(),
122            dirty: false,
123            root_nodes: Vec::new(),
124            leaf_nodes: Vec::new(),
125        }
126    }
127
128    /// Add a new node to the graph
129    pub fn add_node(&mut self, value: Tensor, requires_grad: bool, name: Option<String>) -> NodeId {
130        let id = self.next_id;
131        self.next_id += 1;
132
133        let shape = value.shape();
134        let node = GraphNode {
135            id,
136            value,
137            gradient: None,
138            operation: None,
139            parents: Vec::new(),
140            children: Vec::new(),
141            requires_grad,
142            is_leaf: true,
143            name,
144            shape,
145        };
146
147        self.nodes.insert(id, node);
148        if requires_grad {
149            self.root_nodes.push(id);
150        }
151        self.dirty = true;
152
153        id
154    }
155
156    /// Add an operation node to the graph
157    pub fn add_operation_node(
158        &mut self,
159        value: Tensor,
160        operation: OperationType,
161        parents: Vec<NodeId>,
162        requires_grad: bool,
163        name: Option<String>,
164    ) -> Result<NodeId> {
165        let id = self.next_id;
166        self.next_id += 1;
167
168        // Update parent nodes to include this as a child
169        for parent_id in &parents {
170            if let Some(parent) = self.nodes.get_mut(parent_id) {
171                parent.children.push(id);
172            } else {
173                return Err(TrustformersError::tensor_op_error(
174                    &format!("Parent node {} not found", parent_id),
175                    "ComputationGraph::add_operation_node",
176                ));
177            }
178        }
179
180        let shape = value.shape();
181        let node = GraphNode {
182            id,
183            value,
184            gradient: None,
185            operation: Some(operation),
186            parents,
187            children: Vec::new(),
188            requires_grad,
189            is_leaf: false,
190            name,
191            shape,
192        };
193
194        self.nodes.insert(id, node);
195        self.dirty = true;
196
197        Ok(id)
198    }
199
    /// Get a shared reference to a node by ID, or `None` if it does not exist.
    pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
        self.nodes.get(&id)
    }

    /// Get a mutable reference to a node by ID, or `None` if it does not exist.
    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut GraphNode> {
        self.nodes.get_mut(&id)
    }
209
210    /// Compute topological ordering of nodes
211    pub fn compute_topological_order(&mut self) -> Result<()> {
212        if !self.dirty {
213            return Ok(());
214        }
215
216        let mut in_degree = HashMap::new();
217        let mut queue = VecDeque::new();
218        let mut result = Vec::new();
219
220        // Initialize in-degree counts
221        for (id, node) in &self.nodes {
222            in_degree.insert(*id, node.parents.len());
223            if node.parents.is_empty() {
224                queue.push_back(*id);
225            }
226        }
227
228        // Process nodes in topological order
229        while let Some(node_id) = queue.pop_front() {
230            result.push(node_id);
231
232            let Some(node) = self.nodes.get(&node_id) else {
233                continue;
234            };
235
236            for child_id in &node.children {
237                let Some(degree) = in_degree.get_mut(child_id) else {
238                    continue;
239                };
240
241                *degree -= 1;
242                if *degree == 0 {
243                    queue.push_back(*child_id);
244                }
245            }
246        }
247
248        if result.len() != self.nodes.len() {
249            return Err(TrustformersError::tensor_op_error(
250                "Cycle detected in computation graph",
251                "ComputationGraph::compute_topological_order",
252            ));
253        }
254
255        self.topological_order = result;
256        self.dirty = false;
257
258        Ok(())
259    }
260
261    /// Perform backward pass to compute gradients
262    pub fn backward(&mut self, output_id: NodeId, grad_output: Option<Tensor>) -> Result<()> {
263        // Ensure topological order is computed
264        self.compute_topological_order()?;
265
266        // Initialize gradient for output node
267        if let Some(output_node) = self.nodes.get_mut(&output_id) {
268            output_node.gradient = Some(grad_output.unwrap_or_else(|| {
269                Tensor::ones(&output_node.shape).expect("Failed to create ones tensor")
270            }));
271        } else {
272            return Err(TrustformersError::tensor_op_error(
273                &format!("Output node {} not found", output_id),
274                "ComputationGraph::backward",
275            ));
276        }
277
278        // Process nodes in reverse topological order
279        for &node_id in self.topological_order.iter().rev() {
280            let Some(node) = self.nodes.get(&node_id).cloned() else {
281                continue;
282            };
283
284            let Some(ref grad) = node.gradient else {
285                continue;
286            };
287
288            let Some(ref operation) = node.operation else {
289                continue;
290            };
291
292            // Compute gradients for parent nodes
293            let parent_gradients =
294                self.compute_operation_gradients(operation, grad, &node.parents)?;
295
296            // Accumulate gradients in parent nodes
297            for (parent_id, parent_grad) in node.parents.iter().zip(parent_gradients.iter()) {
298                let Some(parent_node) = self.nodes.get_mut(parent_id) else {
299                    continue;
300                };
301
302                if !parent_node.requires_grad {
303                    continue;
304                }
305
306                if let Some(ref mut existing_grad) = parent_node.gradient {
307                    *existing_grad = existing_grad.add(parent_grad)?;
308                } else {
309                    parent_node.gradient = Some(parent_grad.clone());
310                }
311            }
312        }
313
314        Ok(())
315    }
316
317    /// Compute gradients for an operation
318    fn compute_operation_gradients(
319        &self,
320        operation: &OperationType,
321        grad_output: &Tensor,
322        parent_ids: &[NodeId],
323    ) -> Result<Vec<Tensor>> {
324        let parent_values: Vec<&Tensor> =
325            parent_ids.iter().map(|id| &self.nodes[id].value).collect();
326
327        match operation {
328            OperationType::Add => {
329                // Gradient of addition is just the incoming gradient for both inputs
330                Ok(vec![grad_output.clone(), grad_output.clone()])
331            },
332            OperationType::Subtract => {
333                // Gradient of subtraction: da = dout, db = -dout
334                Ok(vec![grad_output.clone(), grad_output.neg()?])
335            },
336            OperationType::Multiply => {
337                // Gradient of multiplication: da = dout * b, db = dout * a
338                if parent_values.len() != 2 {
339                    return Err(TrustformersError::tensor_op_error(
340                        "Multiply operation requires exactly 2 inputs",
341                        "ComputationGraph::compute_operation_gradients",
342                    ));
343                }
344                Ok(vec![
345                    grad_output.mul(parent_values[1])?,
346                    grad_output.mul(parent_values[0])?,
347                ])
348            },
349            OperationType::Divide => {
350                // Gradient of division: da = dout / b, db = -dout * a / (b * b)
351                if parent_values.len() != 2 {
352                    return Err(TrustformersError::tensor_op_error(
353                        "Divide operation requires exactly 2 inputs",
354                        "ComputationGraph::compute_operation_gradients",
355                    ));
356                }
357                let a = parent_values[0];
358                let b = parent_values[1];
359                Ok(vec![
360                    grad_output.div(b)?,
361                    grad_output.mul(a)?.neg()?.div(&b.mul(b)?)?,
362                ])
363            },
364            OperationType::MatrixMultiply => {
365                // Gradient of matrix multiplication: da = dout @ b^T, db = a^T @ dout
366                if parent_values.len() != 2 {
367                    return Err(TrustformersError::tensor_op_error(
368                        "MatrixMultiply operation requires exactly 2 inputs",
369                        "ComputationGraph::compute_operation_gradients",
370                    ));
371                }
372                let a = parent_values[0];
373                let b = parent_values[1];
374
375                // Compute gradients
376                let a_shape = a.shape();
377                let b_shape = b.shape();
378
379                let grad_a = if a_shape.len() == 2 && b_shape.len() == 2 {
380                    // Simple 2D matrix multiplication
381                    grad_output.matmul(&b.transpose(1, 0)?)?
382                } else {
383                    // Handle batch matrix multiplication
384                    let b_transposed = b.transpose(2, 1)?;
385                    grad_output.matmul(&b_transposed)?
386                };
387
388                let grad_b = if a_shape.len() == 2 && b_shape.len() == 2 {
389                    // Simple 2D matrix multiplication
390                    a.transpose(1, 0)?.matmul(grad_output)?
391                } else {
392                    // Handle batch matrix multiplication
393                    let a_transposed = a.permute(&[0, 2, 1])?;
394                    a_transposed.matmul(grad_output)?
395                };
396
397                Ok(vec![grad_a, grad_b])
398            },
399            OperationType::Sigmoid => {
400                // Gradient of sigmoid: dout * sigmoid(x) * (1 - sigmoid(x))
401                if parent_values.len() != 1 {
402                    return Err(TrustformersError::tensor_op_error(
403                        "Sigmoid operation requires exactly 1 input",
404                        "ComputationGraph::compute_operation_gradients",
405                    ));
406                }
407                let sigmoid_out = parent_values[0].sigmoid()?;
408                let one = Tensor::ones(&sigmoid_out.shape())?;
409                let grad_input = grad_output.mul(&sigmoid_out)?.mul(&one.sub(&sigmoid_out)?)?;
410                Ok(vec![grad_input])
411            },
412            OperationType::Tanh => {
413                // Gradient of tanh: dout * (1 - tanh(x)^2)
414                if parent_values.len() != 1 {
415                    return Err(TrustformersError::tensor_op_error(
416                        "Tanh operation requires exactly 1 input",
417                        "ComputationGraph::compute_operation_gradients",
418                    ));
419                }
420                let tanh_out = parent_values[0].tanh()?;
421                let one = Tensor::ones(&tanh_out.shape())?;
422                let grad_input = grad_output.mul(&one.sub(&tanh_out.mul(&tanh_out)?)?)?;
423                Ok(vec![grad_input])
424            },
425            OperationType::ReLU => {
426                // Gradient of ReLU: dout * (x > 0)
427                if parent_values.len() != 1 {
428                    return Err(TrustformersError::tensor_op_error(
429                        "ReLU operation requires exactly 1 input",
430                        "ComputationGraph::compute_operation_gradients",
431                    ));
432                }
433                let input = parent_values[0];
434                let zero = Tensor::zeros(&input.shape())?;
435                let mask = input.greater(&zero)?;
436                let grad_input = grad_output.mul(&mask)?;
437                Ok(vec![grad_input])
438            },
439            OperationType::LeakyReLU(alpha) => {
440                // Gradient of LeakyReLU: dout * (x > 0 ? 1 : alpha)
441                if parent_values.len() != 1 {
442                    return Err(TrustformersError::tensor_op_error(
443                        "LeakyReLU operation requires exactly 1 input",
444                        "ComputationGraph::compute_operation_gradients",
445                    ));
446                }
447                let input = parent_values[0];
448                let zero = Tensor::zeros(&input.shape())?;
449                let alpha_tensor = Tensor::scalar(*alpha)?;
450                let one = Tensor::ones(&input.shape())?;
451
452                let positive_mask = input.greater(&zero)?;
453                let negative_mask = one.sub(&positive_mask)?;
454
455                let grad_input =
456                    grad_output.mul(&positive_mask.add(&negative_mask.mul(&alpha_tensor)?)?)?;
457                Ok(vec![grad_input])
458            },
459            OperationType::Sum(axes) => {
460                // Gradient of sum: broadcast the gradient back to original shape
461                if parent_values.len() != 1 {
462                    return Err(TrustformersError::tensor_op_error(
463                        "Sum operation requires exactly 1 input",
464                        "ComputationGraph::compute_operation_gradients",
465                    ));
466                }
467                let input_shape = parent_values[0].shape();
468                let grad_input =
469                    self.broadcast_gradient(grad_output, &input_shape, axes.as_ref())?;
470                Ok(vec![grad_input])
471            },
472            OperationType::Mean(axes) => {
473                // Gradient of mean: broadcast the gradient back and divide by the number of elements
474                if parent_values.len() != 1 {
475                    return Err(TrustformersError::tensor_op_error(
476                        "Mean operation requires exactly 1 input",
477                        "ComputationGraph::compute_operation_gradients",
478                    ));
479                }
480                let input_shape = parent_values[0].shape();
481                let grad_broadcasted =
482                    self.broadcast_gradient(grad_output, &input_shape, axes.as_ref())?;
483
484                // Compute the number of elements that were averaged
485                let num_elements = if let Some(axes) = axes {
486                    axes.iter().map(|&axis| input_shape[axis]).product::<usize>()
487                } else {
488                    input_shape.iter().product::<usize>()
489                };
490
491                let grad_input = grad_broadcasted.scalar_div(num_elements as f32)?;
492                Ok(vec![grad_input])
493            },
494            OperationType::Reshape(target_shape) => {
495                // Gradient of reshape: reshape gradient back to original shape
496                if parent_values.len() != 1 {
497                    return Err(TrustformersError::tensor_op_error(
498                        "Reshape operation requires exactly 1 input",
499                        "ComputationGraph::compute_operation_gradients",
500                    ));
501                }
502                let original_shape = parent_values[0].shape();
503                let grad_input = grad_output.reshape(&original_shape)?;
504                Ok(vec![grad_input])
505            },
506            OperationType::Transpose(permutation) => {
507                // Gradient of transpose: apply inverse permutation
508                if parent_values.len() != 1 {
509                    return Err(TrustformersError::tensor_op_error(
510                        "Transpose operation requires exactly 1 input",
511                        "ComputationGraph::compute_operation_gradients",
512                    ));
513                }
514                let inverse_permutation = self.compute_inverse_permutation(permutation)?;
515                let grad_input = grad_output.permute(&inverse_permutation)?;
516                Ok(vec![grad_input])
517            },
518            _ => {
519                // For unimplemented operations, return zero gradients
520                let zero_grads = parent_values
521                    .iter()
522                    .map(|input| {
523                        Tensor::zeros(&input.shape()).expect("Failed to create zeros tensor")
524                    })
525                    .collect();
526                Ok(zero_grads)
527            },
528        }
529    }
530
531    /// Broadcast gradient back to original shape
532    fn broadcast_gradient(
533        &self,
534        grad_output: &Tensor,
535        original_shape: &[usize],
536        axes: Option<&Vec<usize>>,
537    ) -> Result<Tensor> {
538        if let Some(axes) = axes {
539            // Sum was performed along specific axes
540            let mut result = grad_output.clone();
541            for &axis in axes {
542                result = result.unsqueeze(axis)?;
543            }
544            result.broadcast_to(original_shape)
545        } else {
546            // Sum was performed along all axes
547            let grad_scalar = grad_output.clone();
548            grad_scalar.broadcast_to(original_shape)
549        }
550    }
551
552    /// Compute inverse permutation for transpose
553    fn compute_inverse_permutation(&self, permutation: &[usize]) -> Result<Vec<usize>> {
554        let mut inverse = vec![0; permutation.len()];
555        for (i, &p) in permutation.iter().enumerate() {
556            if p >= permutation.len() {
557                return Err(TrustformersError::tensor_op_error(
558                    &format!("Invalid permutation index: {}", p),
559                    "ComputationGraph::compute_inverse_permutation",
560                ));
561            }
562            inverse[p] = i;
563        }
564        Ok(inverse)
565    }
566
567    /// Clear all gradients in the graph
568    pub fn zero_grad(&mut self) {
569        for node in self.nodes.values_mut() {
570            node.gradient = None;
571        }
572    }
573
    /// Get the accumulated gradient for a node, if one has been computed.
    ///
    /// Returns `None` when the node does not exist or has no gradient yet.
    pub fn get_gradient(&self, node_id: NodeId) -> Option<&Tensor> {
        self.nodes.get(&node_id)?.gradient.as_ref()
    }

    /// Get the tensor value stored at a node, if the node exists.
    pub fn get_value(&self, node_id: NodeId) -> Option<&Tensor> {
        self.nodes.get(&node_id).map(|node| &node.value)
    }
583
584    /// Update the value of a node
585    pub fn update_value(&mut self, node_id: NodeId, value: Tensor) -> Result<()> {
586        if let Some(node) = self.nodes.get_mut(&node_id) {
587            node.value = value;
588            node.shape = node.value.shape();
589            Ok(())
590        } else {
591            Err(TrustformersError::tensor_op_error(
592                &format!("Node {} not found", node_id),
593                "ComputationGraph::update_value",
594            ))
595        }
596    }
597
    /// Get all root nodes (gradient-requiring variables added via `add_node`).
    pub fn get_root_nodes(&self) -> &[NodeId] {
        &self.root_nodes
    }

    /// Get all registered leaf (output) nodes.
    pub fn get_leaf_nodes(&self) -> &[NodeId] {
        &self.leaf_nodes
    }

    /// Register a node as a leaf (output) node; duplicates are ignored.
    pub fn set_leaf_node(&mut self, node_id: NodeId) {
        if !self.leaf_nodes.contains(&node_id) {
            self.leaf_nodes.push(node_id);
        }
    }

    /// Get the number of nodes in the graph.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Get the cached topological order of nodes.
    ///
    /// May be stale if the graph is dirty; call `compute_topological_order`
    /// first to refresh it.
    pub fn get_topological_order(&self) -> &[NodeId] {
        &self.topological_order
    }
624
625    /// Export the graph structure for visualization
626    pub fn export_graph(&self) -> GraphExport {
627        let nodes: Vec<_> = self.nodes.values().cloned().collect();
628        GraphExport {
629            nodes,
630            topological_order: self.topological_order.clone(),
631        }
632    }
633}
634
/// Exported snapshot of the graph for visualization.
#[derive(Debug, Clone)]
pub struct GraphExport {
    // All nodes in the graph (cloned; order follows map iteration)
    pub nodes: Vec<GraphNode>,
    // The cached topological order at export time
    pub topological_order: Vec<NodeId>,
}
641
642impl Default for ComputationGraph {
643    fn default() -> Self {
644        Self::new()
645    }
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // Adding a node assigns sequential IDs starting at 0.
    #[test]
    fn test_graph_creation() {
        let mut graph = ComputationGraph::new();
        assert_eq!(graph.num_nodes(), 0);

        let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let node_id = graph.add_node(tensor, true, Some("test".to_string()));
        assert_eq!(graph.num_nodes(), 1);
        assert_eq!(node_id, 0);
    }

    // Topological ordering places parents before children.
    #[test]
    fn test_topological_order() {
        let mut graph = ComputationGraph::new();

        // Create a simple computation: c = a + b
        let a = Tensor::ones(&[2, 2]).expect("Failed to create ones tensor");
        let b = Tensor::ones(&[2, 2]).expect("Failed to create ones tensor");
        let c = a.add(&b).expect("Addition failed");

        let node_a = graph.add_node(a, true, Some("a".to_string()));
        let node_b = graph.add_node(b, true, Some("b".to_string()));
        let node_c = graph
            .add_operation_node(
                c,
                OperationType::Add,
                vec![node_a, node_b],
                true,
                Some("c".to_string()),
            )
            .expect("operation failed in test");

        graph.compute_topological_order().expect("operation failed in test");
        let order = graph.get_topological_order();
        assert_eq!(order.len(), 3);

        // Verify that parents come before children
        let a_pos = order.iter().position(|&id| id == node_a).expect("operation failed in test");
        let b_pos = order.iter().position(|&id| id == node_b).expect("operation failed in test");
        let c_pos = order.iter().position(|&id| id == node_c).expect("operation failed in test");

        assert!(a_pos < c_pos);
        assert!(b_pos < c_pos);
    }

    // Product rule: for c = a * b, da = b and db = a.
    #[test]
    fn test_backward_pass() {
        let mut graph = ComputationGraph::new();

        // Create computation: c = a * b
        let a = Tensor::scalar(2.0).expect("tensor operation failed");
        let b = Tensor::scalar(3.0).expect("tensor operation failed");
        let c = a.mul(&b).expect("Multiplication failed");

        let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
        let node_b = graph.add_node(b.clone(), true, Some("b".to_string()));
        let node_c = graph
            .add_operation_node(
                c,
                OperationType::Multiply,
                vec![node_a, node_b],
                true,
                Some("c".to_string()),
            )
            .expect("operation failed in test");

        // Backward pass (seed gradient defaults to ones)
        graph.backward(node_c, None).expect("operation failed in test");

        // Check gradients
        let grad_a = graph.get_gradient(node_a).expect("operation failed in test");
        let grad_b = graph.get_gradient(node_b).expect("operation failed in test");

        // Gradient of a should be b (3.0)
        // Gradient of b should be a (2.0)
        assert_eq!(
            grad_a.to_vec_f32().expect("operation failed in test")[0],
            3.0
        );
        assert_eq!(
            grad_b.to_vec_f32().expect("operation failed in test")[0],
            2.0
        );
    }

    // A node used as both inputs of one op accumulates both contributions.
    #[test]
    fn test_gradient_accumulation() {
        let mut graph = ComputationGraph::new();

        // Create computation: d = a + a (gradient should accumulate)
        let a = Tensor::scalar(2.0).expect("tensor operation failed");
        let d = a.add(&a).expect("Addition failed");

        let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
        let node_d = graph
            .add_operation_node(
                d,
                OperationType::Add,
                vec![node_a, node_a],
                true,
                Some("d".to_string()),
            )
            .expect("operation failed in test");

        // Backward pass
        graph.backward(node_d, None).expect("operation failed in test");

        // Check gradient accumulation
        let grad_a = graph.get_gradient(node_a).expect("operation failed in test");

        // Gradient should be 2.0 (1.0 + 1.0 from both uses)
        assert_eq!(
            grad_a.to_vec_f32().expect("operation failed in test")[0],
            2.0
        );
    }
}
769}