1#![allow(unused_variables)] use crate::errors::{Result, TrustformersError};
9use crate::tensor::Tensor;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, VecDeque};
12
/// Identifier of a node in the computation graph; assigned sequentially by
/// `ComputationGraph` starting at 0.
pub type NodeId = usize;
15
/// A directed acyclic graph of tensor operations used for reverse-mode
/// automatic differentiation (backpropagation).
#[derive(Debug)]
pub struct ComputationGraph {
    /// All nodes, keyed by their id.
    nodes: HashMap<NodeId, GraphNode>,
    /// Next id to hand out; incremented on every node insertion.
    next_id: NodeId,
    /// Cached topological ordering (parents before children).
    topological_order: Vec<NodeId>,
    /// True when the graph changed since `topological_order` was last computed.
    dirty: bool,
    /// Ids of nodes added via `add_node` with `requires_grad == true`.
    root_nodes: Vec<NodeId>,
    /// Ids explicitly registered through `set_leaf_node`.
    leaf_nodes: Vec<NodeId>,
}
32
/// A single node in the computation graph: a forward value plus the metadata
/// needed to propagate gradients through it.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Unique id assigned by the owning graph.
    pub id: NodeId,
    /// Forward-pass value of this node.
    pub value: Tensor,
    /// Accumulated gradient; `None` until `backward` populates it.
    pub gradient: Option<Tensor>,
    /// Operation that produced this node; `None` for leaf/input nodes.
    pub operation: Option<OperationType>,
    /// Ids of the nodes this value was computed from (operation inputs).
    pub parents: Vec<NodeId>,
    /// Ids of nodes that consume this value.
    pub children: Vec<NodeId>,
    /// Whether gradients should be accumulated into this node.
    pub requires_grad: bool,
    /// True for nodes created without an operation (graph inputs).
    pub is_leaf: bool,
    /// Optional human-readable label for debugging.
    pub name: Option<String>,
    /// Cached shape of `value`; kept in sync by `update_value`.
    pub shape: Vec<usize>,
}
57
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60pub enum OperationType {
61 Add,
63 Subtract,
64 Multiply,
65 Divide,
66 MatrixMultiply,
67
68 Negate,
70 Reciprocal,
71 Square,
72 Sqrt,
73 Log,
74 Exp,
75
76 Sigmoid,
78 Tanh,
79 ReLU,
80 LeakyReLU(f32),
81 Softmax,
82 LogSoftmax,
83
84 Reshape(Vec<usize>),
86 Transpose(Vec<usize>),
87 Slice(Vec<std::ops::Range<usize>>),
88 Concat(usize), Split(Vec<usize>), Sum(Option<Vec<usize>>), Mean(Option<Vec<usize>>), Max(Option<Vec<usize>>), Min(Option<Vec<usize>>), LayerNorm(f32), Dropout(f32), BatchNorm(f32), Custom(String),
104}
105
/// Pluggable gradient rule, intended for custom operations.
///
/// NOTE(review): this trait is declared but not referenced anywhere in this
/// chunk — presumably implemented by custom-op providers; confirm at call sites.
pub trait GradientFunction: Send + Sync {
    /// Computes gradients w.r.t. each input, given the gradient flowing into
    /// the output; returns one tensor per input, in input order.
    fn backward(&self, grad_output: &Tensor, inputs: &[&Tensor]) -> Result<Vec<Tensor>>;

    /// The operation variant this gradient function corresponds to.
    fn operation_type(&self) -> OperationType;
}
114
115impl ComputationGraph {
116 pub fn new() -> Self {
118 Self {
119 nodes: HashMap::new(),
120 next_id: 0,
121 topological_order: Vec::new(),
122 dirty: false,
123 root_nodes: Vec::new(),
124 leaf_nodes: Vec::new(),
125 }
126 }
127
128 pub fn add_node(&mut self, value: Tensor, requires_grad: bool, name: Option<String>) -> NodeId {
130 let id = self.next_id;
131 self.next_id += 1;
132
133 let shape = value.shape();
134 let node = GraphNode {
135 id,
136 value,
137 gradient: None,
138 operation: None,
139 parents: Vec::new(),
140 children: Vec::new(),
141 requires_grad,
142 is_leaf: true,
143 name,
144 shape,
145 };
146
147 self.nodes.insert(id, node);
148 if requires_grad {
149 self.root_nodes.push(id);
150 }
151 self.dirty = true;
152
153 id
154 }
155
156 pub fn add_operation_node(
158 &mut self,
159 value: Tensor,
160 operation: OperationType,
161 parents: Vec<NodeId>,
162 requires_grad: bool,
163 name: Option<String>,
164 ) -> Result<NodeId> {
165 let id = self.next_id;
166 self.next_id += 1;
167
168 for parent_id in &parents {
170 if let Some(parent) = self.nodes.get_mut(parent_id) {
171 parent.children.push(id);
172 } else {
173 return Err(TrustformersError::tensor_op_error(
174 &format!("Parent node {} not found", parent_id),
175 "ComputationGraph::add_operation_node",
176 ));
177 }
178 }
179
180 let shape = value.shape();
181 let node = GraphNode {
182 id,
183 value,
184 gradient: None,
185 operation: Some(operation),
186 parents,
187 children: Vec::new(),
188 requires_grad,
189 is_leaf: false,
190 name,
191 shape,
192 };
193
194 self.nodes.insert(id, node);
195 self.dirty = true;
196
197 Ok(id)
198 }
199
200 pub fn get_node(&self, id: NodeId) -> Option<&GraphNode> {
202 self.nodes.get(&id)
203 }
204
205 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut GraphNode> {
207 self.nodes.get_mut(&id)
208 }
209
    /// Recomputes the cached topological ordering using Kahn's algorithm.
    ///
    /// No-op while the graph is unchanged (`dirty == false`).
    ///
    /// Note: ties are broken by `HashMap` iteration order, so the result is a
    /// valid but not run-to-run deterministic ordering.
    ///
    /// # Errors
    /// Returns an error if a cycle is detected (some node's in-degree never
    /// reaches zero, so fewer nodes are emitted than exist).
    pub fn compute_topological_order(&mut self) -> Result<()> {
        if !self.dirty {
            return Ok(());
        }

        let mut in_degree = HashMap::new();
        let mut queue = VecDeque::new();
        let mut result = Vec::new();

        // Seed the queue with parentless nodes (graph inputs).
        for (id, node) in &self.nodes {
            in_degree.insert(*id, node.parents.len());
            if node.parents.is_empty() {
                queue.push_back(*id);
            }
        }

        while let Some(node_id) = queue.pop_front() {
            result.push(node_id);

            let Some(node) = self.nodes.get(&node_id) else {
                continue;
            };

            // A child becomes ready once all of its parent edges are consumed.
            // Duplicate parent edges are counted on both sides (parents.len()
            // and duplicate `children` entries), so the bookkeeping balances.
            for child_id in &node.children {
                let Some(degree) = in_degree.get_mut(child_id) else {
                    continue;
                };

                *degree -= 1;
                if *degree == 0 {
                    queue.push_back(*child_id);
                }
            }
        }

        // Fewer emitted nodes than exist => some in-degree never hit zero => cycle.
        if result.len() != self.nodes.len() {
            return Err(TrustformersError::tensor_op_error(
                "Cycle detected in computation graph",
                "ComputationGraph::compute_topological_order",
            ));
        }

        self.topological_order = result;
        self.dirty = false;

        Ok(())
    }
260
    /// Reverse-mode gradient pass from `output_id` back through the graph.
    ///
    /// `grad_output` seeds the output's gradient; when `None`, a ones tensor
    /// of the output's cached shape is used (the usual seed for a scalar
    /// loss). Gradients *accumulate* (add) into `GraphNode::gradient`; call
    /// `zero_grad` between passes to reset them.
    ///
    /// # Errors
    /// Returns an error if `output_id` is unknown, if the topological sort
    /// detects a cycle, or if any gradient tensor operation fails.
    pub fn backward(&mut self, output_id: NodeId, grad_output: Option<Tensor>) -> Result<()> {
        // Make sure the cached ordering reflects the current graph.
        self.compute_topological_order()?;

        if let Some(output_node) = self.nodes.get_mut(&output_id) {
            output_node.gradient = Some(grad_output.unwrap_or_else(|| {
                Tensor::ones(&output_node.shape).expect("Failed to create ones tensor")
            }));
        } else {
            return Err(TrustformersError::tensor_op_error(
                &format!("Output node {} not found", output_id),
                "ComputationGraph::backward",
            ));
        }

        // Walk children-before-parents (reverse topological order) so each
        // node's gradient is fully accumulated before being propagated.
        for &node_id in self.topological_order.iter().rev() {
            // Clone releases the immutable borrow of `self.nodes`; parent
            // entries are mutated further down in this iteration.
            let Some(node) = self.nodes.get(&node_id).cloned() else {
                continue;
            };

            // Nodes that received no gradient contribute nothing upstream.
            let Some(ref grad) = node.gradient else {
                continue;
            };

            // Leaf/input nodes carry no operation and have nothing to propagate.
            let Some(ref operation) = node.operation else {
                continue;
            };

            let parent_gradients =
                self.compute_operation_gradients(operation, grad, &node.parents)?;

            // Accumulate each parent's share. A parent listed twice (e.g.
            // `a + a`) receives two contributions — the correct sum rule.
            for (parent_id, parent_grad) in node.parents.iter().zip(parent_gradients.iter()) {
                let Some(parent_node) = self.nodes.get_mut(parent_id) else {
                    continue;
                };

                if !parent_node.requires_grad {
                    continue;
                }

                if let Some(ref mut existing_grad) = parent_node.gradient {
                    *existing_grad = existing_grad.add(parent_grad)?;
                } else {
                    parent_node.gradient = Some(parent_grad.clone());
                }
            }
        }

        Ok(())
    }
316
    /// Dispatches to the analytic gradient rule for `operation`, returning one
    /// gradient tensor per parent, in parent order.
    ///
    /// NOTE(review): operations without an implemented rule fall through to
    /// the `_` arm and yield *zero* gradients rather than an error — confirm
    /// this silent behavior is intended.
    fn compute_operation_gradients(
        &self,
        operation: &OperationType,
        grad_output: &Tensor,
        parent_ids: &[NodeId],
    ) -> Result<Vec<Tensor>> {
        // Indexing panics on a missing parent id; ids come from nodes already
        // in the graph, so absence would be an internal invariant violation.
        let parent_values: Vec<&Tensor> =
            parent_ids.iter().map(|id| &self.nodes[id].value).collect();

        match operation {
            // d(a+b)/da = 1 and d(a+b)/db = 1: pass the gradient through to both.
            OperationType::Add => {
                Ok(vec![grad_output.clone(), grad_output.clone()])
            },
            // d(a-b)/da = 1, d(a-b)/db = -1.
            OperationType::Subtract => {
                Ok(vec![grad_output.clone(), grad_output.neg()?])
            },
            // Product rule: d(a*b)/da = b, d(a*b)/db = a.
            OperationType::Multiply => {
                if parent_values.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Multiply operation requires exactly 2 inputs",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                Ok(vec![
                    grad_output.mul(parent_values[1])?,
                    grad_output.mul(parent_values[0])?,
                ])
            },
            // Quotient rule: d(a/b)/da = 1/b, d(a/b)/db = -a/b^2.
            OperationType::Divide => {
                if parent_values.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Divide operation requires exactly 2 inputs",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let a = parent_values[0];
                let b = parent_values[1];
                Ok(vec![
                    grad_output.div(b)?,
                    grad_output.mul(a)?.neg()?.div(&b.mul(b)?)?,
                ])
            },
            // For C = A·B: dL/dA = dL/dC · B^T and dL/dB = A^T · dL/dC.
            OperationType::MatrixMultiply => {
                if parent_values.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "MatrixMultiply operation requires exactly 2 inputs",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let a = parent_values[0];
                let b = parent_values[1];

                let a_shape = a.shape();
                let b_shape = b.shape();

                let grad_a = if a_shape.len() == 2 && b_shape.len() == 2 {
                    grad_output.matmul(&b.transpose(1, 0)?)?
                } else {
                    // NOTE(review): the non-2D branch assumes 3-D (batched)
                    // operands; `transpose(2, 1)` here vs `permute(&[0, 2, 1])`
                    // below are presumably equivalent swaps of the last two
                    // dims — confirm against the Tensor API.
                    let b_transposed = b.transpose(2, 1)?;
                    grad_output.matmul(&b_transposed)?
                };

                let grad_b = if a_shape.len() == 2 && b_shape.len() == 2 {
                    a.transpose(1, 0)?.matmul(grad_output)?
                } else {
                    let a_transposed = a.permute(&[0, 2, 1])?;
                    a_transposed.matmul(grad_output)?
                };

                Ok(vec![grad_a, grad_b])
            },
            // d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x)).
            // Note: recomputes the forward sigmoid from the stored input.
            OperationType::Sigmoid => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "Sigmoid operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let sigmoid_out = parent_values[0].sigmoid()?;
                let one = Tensor::ones(&sigmoid_out.shape())?;
                let grad_input = grad_output.mul(&sigmoid_out)?.mul(&one.sub(&sigmoid_out)?)?;
                Ok(vec![grad_input])
            },
            // d tanh(x)/dx = 1 - tanh(x)^2.
            OperationType::Tanh => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "Tanh operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let tanh_out = parent_values[0].tanh()?;
                let one = Tensor::ones(&tanh_out.shape())?;
                let grad_input = grad_output.mul(&one.sub(&tanh_out.mul(&tanh_out)?)?)?;
                Ok(vec![grad_input])
            },
            // d relu(x)/dx = 1 where x > 0, else 0 (mask via elementwise greater).
            OperationType::ReLU => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "ReLU operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let input = parent_values[0];
                let zero = Tensor::zeros(&input.shape())?;
                let mask = input.greater(&zero)?;
                let grad_input = grad_output.mul(&mask)?;
                Ok(vec![grad_input])
            },
            // d leaky_relu(x)/dx = 1 where x > 0, else alpha.
            OperationType::LeakyReLU(alpha) => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "LeakyReLU operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let input = parent_values[0];
                let zero = Tensor::zeros(&input.shape())?;
                let alpha_tensor = Tensor::scalar(*alpha)?;
                let one = Tensor::ones(&input.shape())?;

                // Complementary masks: positive + negative == 1 everywhere.
                let positive_mask = input.greater(&zero)?;
                let negative_mask = one.sub(&positive_mask)?;

                let grad_input =
                    grad_output.mul(&positive_mask.add(&negative_mask.mul(&alpha_tensor)?)?)?;
                Ok(vec![grad_input])
            },
            // Sum: gradient of each input element is 1, so broadcast the
            // output gradient back to the input shape.
            OperationType::Sum(axes) => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "Sum operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let input_shape = parent_values[0].shape();
                let grad_input =
                    self.broadcast_gradient(grad_output, &input_shape, axes.as_ref())?;
                Ok(vec![grad_input])
            },
            // Mean: like Sum, then divide by the number of reduced elements.
            OperationType::Mean(axes) => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "Mean operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let input_shape = parent_values[0].shape();
                let grad_broadcasted =
                    self.broadcast_gradient(grad_output, &input_shape, axes.as_ref())?;

                // Element count of the reduction: product over the reduced
                // axes, or over the whole shape for a full reduction.
                let num_elements = if let Some(axes) = axes {
                    axes.iter().map(|&axis| input_shape[axis]).product::<usize>()
                } else {
                    input_shape.iter().product::<usize>()
                };

                let grad_input = grad_broadcasted.scalar_div(num_elements as f32)?;
                Ok(vec![grad_input])
            },
            // Reshape is a bijection on elements: reshape the gradient back.
            // (The stored target shape is not needed for the backward pass.)
            OperationType::Reshape(target_shape) => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "Reshape operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let original_shape = parent_values[0].shape();
                let grad_input = grad_output.reshape(&original_shape)?;
                Ok(vec![grad_input])
            },
            // Transpose: apply the inverse permutation to the gradient.
            OperationType::Transpose(permutation) => {
                if parent_values.len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "Transpose operation requires exactly 1 input",
                        "ComputationGraph::compute_operation_gradients",
                    ));
                }
                let inverse_permutation = self.compute_inverse_permutation(permutation)?;
                let grad_input = grad_output.permute(&inverse_permutation)?;
                Ok(vec![grad_input])
            },
            // Fallback: unimplemented rules yield zero gradients (silently —
            // see the review note in the doc comment above).
            _ => {
                let zero_grads = parent_values
                    .iter()
                    .map(|input| {
                        Tensor::zeros(&input.shape()).expect("Failed to create zeros tensor")
                    })
                    .collect();
                Ok(zero_grads)
            },
        }
    }
530
    /// Expands a reduced gradient back to `original_shape`.
    ///
    /// For an axis-wise reduction the removed axes are re-inserted as size-1
    /// dims and the result broadcast; for a full reduction (`axes == None`)
    /// the scalar gradient is broadcast directly.
    ///
    /// NOTE(review): re-inserting axes in iteration order assumes `axes` is
    /// sorted ascending and contains axis positions of the *original* shape —
    /// confirm callers guarantee this.
    fn broadcast_gradient(
        &self,
        grad_output: &Tensor,
        original_shape: &[usize],
        axes: Option<&Vec<usize>>,
    ) -> Result<Tensor> {
        if let Some(axes) = axes {
            let mut result = grad_output.clone();
            for &axis in axes {
                result = result.unsqueeze(axis)?;
            }
            result.broadcast_to(original_shape)
        } else {
            let grad_scalar = grad_output.clone();
            grad_scalar.broadcast_to(original_shape)
        }
    }
551
552 fn compute_inverse_permutation(&self, permutation: &[usize]) -> Result<Vec<usize>> {
554 let mut inverse = vec![0; permutation.len()];
555 for (i, &p) in permutation.iter().enumerate() {
556 if p >= permutation.len() {
557 return Err(TrustformersError::tensor_op_error(
558 &format!("Invalid permutation index: {}", p),
559 "ComputationGraph::compute_inverse_permutation",
560 ));
561 }
562 inverse[p] = i;
563 }
564 Ok(inverse)
565 }
566
567 pub fn zero_grad(&mut self) {
569 for node in self.nodes.values_mut() {
570 node.gradient = None;
571 }
572 }
573
574 pub fn get_gradient(&self, node_id: NodeId) -> Option<&Tensor> {
576 self.nodes.get(&node_id)?.gradient.as_ref()
577 }
578
579 pub fn get_value(&self, node_id: NodeId) -> Option<&Tensor> {
581 self.nodes.get(&node_id).map(|node| &node.value)
582 }
583
584 pub fn update_value(&mut self, node_id: NodeId, value: Tensor) -> Result<()> {
586 if let Some(node) = self.nodes.get_mut(&node_id) {
587 node.value = value;
588 node.shape = node.value.shape();
589 Ok(())
590 } else {
591 Err(TrustformersError::tensor_op_error(
592 &format!("Node {} not found", node_id),
593 "ComputationGraph::update_value",
594 ))
595 }
596 }
597
598 pub fn get_root_nodes(&self) -> &[NodeId] {
600 &self.root_nodes
601 }
602
603 pub fn get_leaf_nodes(&self) -> &[NodeId] {
605 &self.leaf_nodes
606 }
607
608 pub fn set_leaf_node(&mut self, node_id: NodeId) {
610 if !self.leaf_nodes.contains(&node_id) {
611 self.leaf_nodes.push(node_id);
612 }
613 }
614
615 pub fn num_nodes(&self) -> usize {
617 self.nodes.len()
618 }
619
620 pub fn get_topological_order(&self) -> &[NodeId] {
622 &self.topological_order
623 }
624
625 pub fn export_graph(&self) -> GraphExport {
627 let nodes: Vec<_> = self.nodes.values().cloned().collect();
628 GraphExport {
629 nodes,
630 topological_order: self.topological_order.clone(),
631 }
632 }
633}
634
/// Snapshot of a `ComputationGraph` produced by `export_graph`.
#[derive(Debug, Clone)]
pub struct GraphExport {
    /// Cloned nodes, in unspecified (HashMap) order.
    pub nodes: Vec<GraphNode>,
    /// Node ids in parents-before-children order as cached at export time;
    /// may be empty or stale if the order was never (re)computed.
    pub topological_order: Vec<NodeId>,
}
641
642impl Default for ComputationGraph {
643 fn default() -> Self {
644 Self::new()
645 }
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A fresh graph is empty and node ids are handed out from 0.
    #[test]
    fn test_graph_creation() {
        let mut graph = ComputationGraph::new();
        assert_eq!(graph.num_nodes(), 0);

        let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let node_id = graph.add_node(tensor, true, Some("test".to_string()));
        assert_eq!(graph.num_nodes(), 1);
        assert_eq!(node_id, 0);
    }

    /// For c = a + b the topological order must place both inputs before c.
    /// (Only relative positions are asserted; the ordering of a vs b is
    /// unspecified because it follows HashMap iteration order.)
    #[test]
    fn test_topological_order() {
        let mut graph = ComputationGraph::new();

        let a = Tensor::ones(&[2, 2]).expect("Failed to create ones tensor");
        let b = Tensor::ones(&[2, 2]).expect("Failed to create ones tensor");
        let c = a.add(&b).expect("Addition failed");

        let node_a = graph.add_node(a, true, Some("a".to_string()));
        let node_b = graph.add_node(b, true, Some("b".to_string()));
        let node_c = graph
            .add_operation_node(
                c,
                OperationType::Add,
                vec![node_a, node_b],
                true,
                Some("c".to_string()),
            )
            .expect("operation failed in test");

        graph.compute_topological_order().expect("operation failed in test");
        let order = graph.get_topological_order();
        assert_eq!(order.len(), 3);

        let a_pos = order.iter().position(|&id| id == node_a).expect("operation failed in test");
        let b_pos = order.iter().position(|&id| id == node_b).expect("operation failed in test");
        let c_pos = order.iter().position(|&id| id == node_c).expect("operation failed in test");

        assert!(a_pos < c_pos);
        assert!(b_pos < c_pos);
    }

    /// Product rule: for c = a * b with a=2, b=3, dc/da = b = 3 and
    /// dc/db = a = 2. `backward(None)` seeds the output gradient with ones.
    #[test]
    fn test_backward_pass() {
        let mut graph = ComputationGraph::new();

        let a = Tensor::scalar(2.0).expect("tensor operation failed");
        let b = Tensor::scalar(3.0).expect("tensor operation failed");
        let c = a.mul(&b).expect("Multiplication failed");

        let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
        let node_b = graph.add_node(b.clone(), true, Some("b".to_string()));
        let node_c = graph
            .add_operation_node(
                c,
                OperationType::Multiply,
                vec![node_a, node_b],
                true,
                Some("c".to_string()),
            )
            .expect("operation failed in test");

        graph.backward(node_c, None).expect("operation failed in test");

        let grad_a = graph.get_gradient(node_a).expect("operation failed in test");
        let grad_b = graph.get_gradient(node_b).expect("operation failed in test");

        assert_eq!(
            grad_a.to_vec_f32().expect("operation failed in test")[0],
            3.0
        );
        assert_eq!(
            grad_b.to_vec_f32().expect("operation failed in test")[0],
            2.0
        );
    }

    /// A node used twice as a parent (d = a + a) must accumulate both
    /// gradient contributions: dd/da = 1 + 1 = 2.
    #[test]
    fn test_gradient_accumulation() {
        let mut graph = ComputationGraph::new();

        let a = Tensor::scalar(2.0).expect("tensor operation failed");
        let d = a.add(&a).expect("Addition failed");

        let node_a = graph.add_node(a.clone(), true, Some("a".to_string()));
        let node_d = graph
            .add_operation_node(
                d,
                OperationType::Add,
                vec![node_a, node_a],
                true,
                Some("d".to_string()),
            )
            .expect("operation failed in test");

        graph.backward(node_d, None).expect("operation failed in test");

        let grad_a = graph.get_gradient(node_a).expect("operation failed in test");

        assert_eq!(
            grad_a.to_vec_f32().expect("operation failed in test")[0],
            2.0
        );
    }
}
769}