1use crate::{JitError, JitResult};
4use petgraph::graph::{DiGraph, NodeIndex};
5use petgraph::visit::EdgeRef;
6use petgraph::Direction;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use torsh_core::{DType, DeviceType, Shape};
10
11pub use crate::graph::metadata::GraphMetadata;
12pub use crate::graph::operations::Operation;
13
14pub type NodeId = NodeIndex;
15
/// Weight attached to every edge of the computation graph.
///
/// An edge connects one output slot of its source node to one input slot of
/// its destination node. `Default` yields slot `0` on both ends.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Edge {
    /// Output slot index on the source node.
    pub src_output: usize,
    /// Input slot index on the destination node.
    pub dst_input: usize,
}
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
27pub struct SerializableNodeIndex(pub u32);
28
29impl From<NodeIndex> for SerializableNodeIndex {
30 fn from(node_index: NodeIndex) -> Self {
31 SerializableNodeIndex(node_index.index() as u32)
32 }
33}
34
35impl From<SerializableNodeIndex> for NodeIndex {
36 fn from(serializable: SerializableNodeIndex) -> Self {
37 NodeIndex::new(serializable.0 as usize)
38 }
39}
40
41#[derive(Debug, Clone)]
43pub struct Node {
44 pub operation: Operation,
46
47 pub name: String,
49
50 pub input_shapes: Vec<Option<Shape>>,
52
53 pub output_shapes: Vec<Option<Shape>>,
55
56 pub dtypes: Vec<DType>,
58
59 pub device: DeviceType,
61
62 pub attributes: HashMap<String, crate::graph::operations::Attribute>,
64
65 pub op: Operation,
68
69 pub dtype: DType,
71
72 pub output_shape: Shape,
74
75 pub attrs: HashMap<String, crate::graph::operations::Attribute>,
77
78 pub inputs: Vec<NodeId>,
80
81 pub is_output: bool,
83}
84
85impl Node {
86 pub fn new(operation: Operation, name: String) -> Self {
88 let op = operation.clone();
89 let dtype = DType::F32; let output_shape = Shape::new(vec![1]); let attributes = HashMap::new();
92
93 Self {
94 operation,
95 name,
96 input_shapes: Vec::new(),
97 output_shapes: Vec::new(),
98 dtypes: Vec::new(),
99 device: DeviceType::Cpu,
100 attributes: attributes.clone(),
101
102 op,
104 dtype,
105 output_shape,
106 attrs: attributes,
107 inputs: Vec::new(),
108 is_output: false,
109 }
110 }
111
112 pub fn with_input_shapes(mut self, shapes: Vec<Option<Shape>>) -> Self {
114 self.input_shapes = shapes;
115 self.sync_compatibility_fields();
116 self
117 }
118
119 pub fn with_output_shapes(mut self, shapes: Vec<Option<Shape>>) -> Self {
121 self.output_shapes = shapes;
122 self.sync_compatibility_fields();
123 self
124 }
125
126 pub fn with_dtypes(mut self, dtypes: Vec<DType>) -> Self {
128 self.dtypes = dtypes;
129 self.sync_compatibility_fields();
130 self
131 }
132
133 pub fn with_device(mut self, device: DeviceType) -> Self {
135 self.device = device;
136 self
137 }
138
139 pub fn with_attribute(
141 mut self,
142 key: String,
143 value: crate::graph::operations::Attribute,
144 ) -> Self {
145 self.attributes.insert(key, value);
146 self.sync_compatibility_fields();
147 self
148 }
149
150 pub fn num_inputs(&self) -> usize {
152 self.input_shapes.len()
153 }
154
155 pub fn num_outputs(&self) -> usize {
157 self.output_shapes.len().max(1) }
159
160 pub fn input_shape(&self, index: usize) -> Option<&Shape> {
162 self.input_shapes.get(index).and_then(|s| s.as_ref())
163 }
164
165 pub fn output_shape(&self, index: usize) -> Option<&Shape> {
167 self.output_shapes.get(index).and_then(|s| s.as_ref())
168 }
169
170 pub fn dtype(&self, index: usize) -> Option<&DType> {
172 self.dtypes.get(index)
173 }
174
175 pub fn is_input(&self) -> bool {
177 matches!(self.operation, Operation::Input | Operation::Parameter(_))
178 }
179
180 pub fn is_constant(&self) -> bool {
182 matches!(self.operation, Operation::Constant(_))
183 }
184
185 pub fn is_control_flow(&self) -> bool {
187 matches!(
188 self.operation,
189 Operation::If(_)
190 | Operation::While(_)
191 | Operation::For(_)
192 | Operation::Break
193 | Operation::Continue
194 | Operation::Return(_)
195 | Operation::Block(_)
196 | Operation::Merge(_)
197 )
198 }
199
200 pub fn memory_estimate(&self) -> usize {
202 let mut total = 0;
203 for shape_opt in &self.output_shapes {
204 if let Some(shape) = shape_opt {
205 let elements = shape.dims().iter().product::<usize>();
206 total += elements * 4;
208 }
209 }
210 total
211 }
212
213 pub fn complexity_estimate(&self) -> usize {
215 match &self.operation {
216 Operation::MatMul | Operation::BatchMatMul => {
217 if self.input_shapes.len() >= 2 {
218 if let (Some(Some(a_shape)), Some(Some(b_shape))) =
219 (self.input_shapes.get(0), self.input_shapes.get(1))
220 {
221 if a_shape.dims().len() >= 2 && b_shape.dims().len() >= 2 {
223 let m = a_shape.dims()[a_shape.dims().len() - 2];
224 let k = a_shape.dims()[a_shape.dims().len() - 1];
225 let n = b_shape.dims()[b_shape.dims().len() - 1];
226 return 2 * m * n * k;
227 }
228 }
229 }
230 0
231 }
232 Operation::Conv2d(_) => {
233 if let Some(Some(output_shape)) = self.output_shapes.get(0) {
235 output_shape.dims().iter().product::<usize>() * 9 } else {
237 0
238 }
239 }
240 _ => {
241 if let Some(Some(output_shape)) = self.output_shapes.get(0) {
243 output_shape.dims().iter().product::<usize>()
244 } else {
245 1
246 }
247 }
248 }
249 }
250
251 pub fn sync_compatibility_fields(&mut self) {
253 self.op = self.operation.clone();
254 self.dtype = self.dtypes.first().copied().unwrap_or(DType::F32);
255 self.output_shape = self
256 .output_shapes
257 .first()
258 .and_then(|s| s.as_ref())
259 .cloned()
260 .unwrap_or_else(|| Shape::new(vec![1]));
261 self.attrs = self.attributes.clone();
262 }
263
264 pub fn set_attribute(&mut self, key: String, value: crate::graph::operations::Attribute) {
266 self.attributes.insert(key.clone(), value.clone());
267 self.attrs.insert(key, value);
268 }
269
270 pub fn set_optimization_hint(&mut self, hint: &str, value: &str) -> crate::JitResult<()> {
272 let attr_value = crate::graph::operations::Attribute::String(value.to_string());
273 self.set_attribute(hint.to_string(), attr_value);
274 Ok(())
275 }
276
277 pub fn get_attribute(&self, key: &str) -> Option<&crate::graph::operations::Attribute> {
279 self.attributes.get(key)
280 }
281
282 pub fn operation_type(&self) -> &str {
284 self.operation.as_str()
285 }
286
287 pub fn has_side_effects(&self) -> bool {
289 matches!(
290 self.operation,
291 Operation::Custom(_) | Operation::Break | Operation::Continue | Operation::Return(_)
292 )
293 }
294
295 pub fn operation_category(&self) -> OperationCategory {
297 match &self.operation {
298 Operation::Add
299 | Operation::Sub
300 | Operation::Mul
301 | Operation::Div
302 | Operation::Neg
303 | Operation::Abs
304 | Operation::Exp
305 | Operation::Log
306 | Operation::Sqrt
307 | Operation::Sin
308 | Operation::Cos
309 | Operation::Tanh
310 | Operation::Sigmoid
311 | Operation::Relu
312 | Operation::Gelu
313 | Operation::Silu => OperationCategory::ElementWise,
314 Operation::MatMul | Operation::BatchMatMul => OperationCategory::LinearAlgebra,
315 Operation::Conv2d(_) | Operation::Linear(_) => OperationCategory::NeuralNetwork,
316 Operation::Sum { .. }
317 | Operation::Mean { .. }
318 | Operation::Max { .. }
319 | Operation::Min { .. } => OperationCategory::Reduction,
320 Operation::Reshape { .. }
321 | Operation::Transpose { .. }
322 | Operation::Squeeze { .. }
323 | Operation::Unsqueeze { .. }
324 | Operation::Slice { .. }
325 | Operation::Concat { .. } => OperationCategory::ShapeManipulation,
326 Operation::If(_)
327 | Operation::While(_)
328 | Operation::For(_)
329 | Operation::Break
330 | Operation::Continue
331 | Operation::Return(_)
332 | Operation::Block(_)
333 | Operation::Merge(_) => OperationCategory::ControlFlow,
334 Operation::Input | Operation::Parameter(_) | Operation::Constant(_) => {
335 OperationCategory::Input
336 }
337 _ => OperationCategory::Other,
338 }
339 }
340
341 pub fn is_vectorizable(&self) -> bool {
343 match &self.operation {
344 Operation::Add
346 | Operation::Sub
347 | Operation::Mul
348 | Operation::Div
349 | Operation::Neg
350 | Operation::Abs
351 | Operation::Exp
352 | Operation::Log
353 | Operation::Sqrt
354 | Operation::Sin
355 | Operation::Cos
356 | Operation::Tanh
357 | Operation::Sigmoid
358 | Operation::Relu
359 | Operation::Gelu
360 | Operation::Silu => true,
361 Operation::MatMul | Operation::BatchMatMul => true,
363 Operation::Sum { .. }
365 | Operation::Mean { .. }
366 | Operation::Max { .. }
367 | Operation::Min { .. } => true,
368 Operation::Conv2d(_) => true,
370 _ => false,
372 }
373 }
374
375 pub fn has_memory_access(&self) -> bool {
377 match &self.operation {
378 Operation::Input | Operation::Parameter(_) | Operation::Constant(_) => false,
380 Operation::Break | Operation::Continue | Operation::Return(_) => false,
382 _ => true,
384 }
385 }
386
387 pub fn estimate_working_set_size(&self) -> usize {
389 let mut working_set = 0;
390
391 for shape_opt in &self.input_shapes {
393 if let Some(shape) = shape_opt {
394 let elements = shape.dims().iter().product::<usize>();
395 working_set += elements * 4;
397 }
398 }
399
400 for shape_opt in &self.output_shapes {
402 if let Some(shape) = shape_opt {
403 let elements = shape.dims().iter().product::<usize>();
404 working_set += elements * 4;
405 }
406 }
407
408 match &self.operation {
410 Operation::MatMul | Operation::BatchMatMul => {
411 working_set * 2
413 }
414 Operation::Conv2d(_) => {
415 working_set * 3
417 }
418 _ => working_set,
419 }
420 }
421}
422
/// Coarse classification of an operation, as produced by
/// `Node::operation_category`.
///
/// The enum is fieldless, so it is `Copy` and `Hash` and can be used directly
/// as a map key or passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationCategory {
    /// Per-element arithmetic and activation functions (e.g. `Add`, `Relu`).
    ElementWise,
    /// `MatMul` and `BatchMatMul`.
    LinearAlgebra,
    /// Layer-style operations: `Conv2d` and `Linear`.
    NeuralNetwork,
    /// `Sum`, `Mean`, `Max` and `Min`.
    Reduction,
    /// Shape-only transforms such as `Reshape`, `Transpose`, `Slice`, `Concat`.
    ShapeManipulation,
    /// `If`, `While`, `For`, `Break`, `Continue`, `Return`, `Block`, `Merge`.
    ControlFlow,
    /// Graph sources: `Input`, `Parameter` and `Constant`.
    Input,
    /// Any operation not covered by the categories above.
    Other,
}
435
436#[derive(Debug, Clone)]
438pub struct ComputationGraph {
439 pub(crate) graph: DiGraph<Node, Edge>,
441
442 pub inputs: Vec<NodeId>,
444
445 pub outputs: Vec<NodeId>,
447
448 pub metadata: GraphMetadata,
450}
451
452impl ComputationGraph {
453 pub fn new() -> Self {
455 Self {
456 graph: DiGraph::new(),
457 inputs: Vec::new(),
458 outputs: Vec::new(),
459 metadata: GraphMetadata::default(),
460 }
461 }
462
463 pub fn add_node(&mut self, node: Node) -> NodeId {
465 self.graph.add_node(node)
466 }
467
468 pub fn add_edge(&mut self, from: NodeId, to: NodeId, edge: Edge) {
470 self.graph.add_edge(from, to, edge);
471 }
472
473 pub fn add_input(&mut self, node: NodeId) {
475 if !self.inputs.contains(&node) {
476 self.inputs.push(node);
477 }
478 }
479
480 pub fn add_output(&mut self, node: NodeId) {
482 if !self.outputs.contains(&node) {
483 self.outputs.push(node);
484 }
485 }
486
487 pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &Node)> {
489 self.graph
490 .node_indices()
491 .map(move |idx| (idx, &self.graph[idx]))
492 }
493
494 pub fn edges(&self) -> impl Iterator<Item = (NodeId, NodeId, &Edge)> + '_ {
496 self.graph.edge_indices().map(move |idx| {
497 let (src, dst) = self
498 .graph
499 .edge_endpoints(idx)
500 .expect("edge index should be valid");
501 (src, dst, &self.graph[idx])
502 })
503 }
504
505 pub fn get_node(&self, id: NodeId) -> Option<&Node> {
507 self.graph.node_weight(id)
508 }
509
510 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
512 self.graph.node_weight_mut(id)
513 }
514
515 pub fn get_node_inputs(&self, id: NodeId) -> Vec<NodeId> {
517 self.graph
518 .neighbors_directed(id, Direction::Incoming)
519 .collect()
520 }
521
522 pub fn get_node_outputs(&self, id: NodeId) -> Vec<NodeId> {
524 self.graph
525 .neighbors_directed(id, Direction::Outgoing)
526 .collect()
527 }
528
529 pub fn incoming_edges(&self, id: NodeId) -> Vec<(NodeId, NodeId, &Edge)> {
531 self.graph
532 .edges_directed(id, Direction::Incoming)
533 .map(|edge_ref| (edge_ref.source(), edge_ref.target(), edge_ref.weight()))
534 .collect()
535 }
536
537 pub fn outgoing_edges(&self, id: NodeId) -> Vec<(NodeId, NodeId, &Edge)> {
539 self.graph
540 .edges_directed(id, Direction::Outgoing)
541 .map(|edge_ref| (edge_ref.source(), edge_ref.target(), edge_ref.weight()))
542 .collect()
543 }
544
545 pub fn remove_node(&mut self, id: NodeId) -> Option<Node> {
547 self.inputs.retain(|&x| x != id);
549 self.outputs.retain(|&x| x != id);
550
551 self.graph.remove_node(id)
552 }
553
554 pub fn remove_edge(&mut self, from: NodeId, to: NodeId) -> bool {
556 if let Some(edge_id) = self.graph.find_edge(from, to) {
557 self.graph.remove_edge(edge_id).is_some()
558 } else {
559 false
560 }
561 }
562
563 pub fn node_count(&self) -> usize {
565 self.graph.node_count()
566 }
567
568 pub fn edge_count(&self) -> usize {
570 self.graph.edge_count()
571 }
572
573 pub fn is_empty(&self) -> bool {
575 self.graph.node_count() == 0
576 }
577
578 pub fn validate(&self) -> JitResult<()> {
580 for &input_id in &self.inputs {
582 if self.graph.node_weight(input_id).is_none() {
583 return Err(JitError::GraphError(format!(
584 "Input node {:?} does not exist in graph",
585 input_id
586 )));
587 }
588 }
589
590 for &output_id in &self.outputs {
591 if self.graph.node_weight(output_id).is_none() {
592 return Err(JitError::GraphError(format!(
593 "Output node {:?} does not exist in graph",
594 output_id
595 )));
596 }
597 }
598
599 self.validate_acyclic()?;
601
602 Ok(())
603 }
604
605 fn validate_acyclic(&self) -> JitResult<()> {
607 use petgraph::algo::is_cyclic_directed;
608
609 if is_cyclic_directed(&self.graph) {
610 return Err(JitError::GraphError("Graph contains cycles".to_string()));
611 }
612
613 Ok(())
614 }
615
616 pub fn topological_sort(&self) -> JitResult<Vec<NodeId>> {
618 use petgraph::algo::toposort;
619
620 toposort(&self.graph, None)
621 .map_err(|_| JitError::GraphError("Graph contains cycles".to_string()))
622 }
623
624 pub fn subgraph(&self, node_ids: &[NodeId]) -> JitResult<ComputationGraph> {
626 let mut new_graph = ComputationGraph::new();
627 let mut node_mapping = HashMap::new();
628
629 for &node_id in node_ids {
631 if let Some(node) = self.get_node(node_id) {
632 let new_id = new_graph.add_node(node.clone());
633 node_mapping.insert(node_id, new_id);
634 } else {
635 return Err(JitError::GraphError(format!(
636 "Node {:?} not found in original graph",
637 node_id
638 )));
639 }
640 }
641
642 for &src_id in node_ids {
644 for &dst_id in node_ids {
645 if let Some(edge_ref) = self.graph.find_edge(src_id, dst_id) {
646 let edge = self.graph.edge_weight(edge_ref).expect("edge should exist");
647 let new_src = node_mapping[&src_id];
648 let new_dst = node_mapping[&dst_id];
649 new_graph.add_edge(new_src, new_dst, edge.clone());
650 }
651 }
652 }
653
654 for &input_id in &self.inputs {
656 if let Some(&new_id) = node_mapping.get(&input_id) {
657 new_graph.add_input(new_id);
658 }
659 }
660
661 for &output_id in &self.outputs {
662 if let Some(&new_id) = node_mapping.get(&output_id) {
663 new_graph.add_output(new_id);
664 }
665 }
666
667 new_graph.metadata = self.metadata.clone();
668
669 Ok(new_graph)
670 }
671
672 pub fn strongly_connected_components(&self) -> Vec<Vec<NodeId>> {
674 use petgraph::algo::tarjan_scc;
675 tarjan_scc(&self.graph)
676 }
677
678 pub fn memory_estimate(&self) -> usize {
680 self.graph
681 .node_weights()
682 .map(|node| node.memory_estimate())
683 .sum()
684 }
685
686 pub fn complexity_estimate(&self) -> usize {
688 self.graph
689 .node_weights()
690 .map(|node| node.complexity_estimate())
691 .sum()
692 }
693
694 pub fn predecessors(&self, node_id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
696 self.graph.neighbors_directed(node_id, Direction::Incoming)
697 }
698
699 pub fn successors(&self, node_id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
701 self.graph.neighbors_directed(node_id, Direction::Outgoing)
702 }
703
704 pub fn node(&self, id: NodeId) -> Option<&Node> {
706 self.get_node(id)
707 }
708
709 pub fn node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
711 self.get_node_mut(id)
712 }
713
714 pub fn edges_directed(
716 &self,
717 node_id: NodeId,
718 direction: Direction,
719 ) -> impl Iterator<Item = petgraph::graph::EdgeReference<'_, Edge>> {
720 self.graph.edges_directed(node_id, direction)
721 }
722
723 pub fn is_acyclic(&self) -> bool {
725 use petgraph::algo::is_cyclic_directed;
726 !is_cyclic_directed(&self.graph)
727 }
728
729 pub fn replace_node_with_input(
746 &mut self,
747 node_id: NodeId,
748 replacement_id: NodeId,
749 ) -> crate::JitResult<()> {
750 let is_predecessor = self
752 .predecessors(node_id)
753 .any(|pred| pred == replacement_id);
754
755 if !is_predecessor {
756 return Err(crate::JitError::CompilationError(format!(
757 "Node {:?} is not a predecessor of node {:?}",
758 replacement_id, node_id
759 )));
760 }
761
762 let successors: Vec<(NodeId, Edge)> = self
764 .graph
765 .edges_directed(node_id, Direction::Outgoing)
766 .map(|edge_ref| (edge_ref.target(), edge_ref.weight().clone()))
767 .collect();
768
769 for (successor_id, edge) in successors {
771 self.graph.add_edge(replacement_id, successor_id, edge);
772 }
773
774 if let Some(pos) = self.outputs.iter().position(|&id| id == node_id) {
776 self.outputs[pos] = replacement_id;
777 }
778
779 self.remove_node(node_id);
781
782 Ok(())
783 }
784
785 pub fn replace_node_with_sequence(
803 &mut self,
804 node_id: NodeId,
805 sequence: &[Node],
806 ) -> crate::JitResult<()> {
807 if sequence.is_empty() {
808 return Err(crate::JitError::CompilationError(
809 "Cannot replace node with empty sequence".to_string(),
810 ));
811 }
812
813 let sequence_ids: Vec<NodeId> = sequence
815 .iter()
816 .map(|node| self.graph.add_node(node.clone()))
817 .collect();
818
819 let first_id = sequence_ids[0];
820 let last_id = *sequence_ids.last().expect("sequence should not be empty");
821
822 for window in sequence_ids.windows(2) {
824 let edge = Edge {
825 src_output: 0,
826 dst_input: 0,
827 };
828 self.graph.add_edge(window[0], window[1], edge);
829 }
830
831 let predecessors: Vec<(NodeId, Edge)> = self
833 .graph
834 .edges_directed(node_id, Direction::Incoming)
835 .map(|edge_ref| (edge_ref.source(), edge_ref.weight().clone()))
836 .collect();
837
838 for (pred_id, edge) in predecessors {
840 self.graph.add_edge(pred_id, first_id, edge);
841 }
842
843 let successors: Vec<(NodeId, Edge)> = self
845 .graph
846 .edges_directed(node_id, Direction::Outgoing)
847 .map(|edge_ref| (edge_ref.target(), edge_ref.weight().clone()))
848 .collect();
849
850 for (succ_id, edge) in successors {
852 self.graph.add_edge(last_id, succ_id, edge);
853 }
854
855 if let Some(pos) = self.inputs.iter().position(|&id| id == node_id) {
857 self.inputs[pos] = first_id;
858 }
859
860 if let Some(pos) = self.outputs.iter().position(|&id| id == node_id) {
862 self.outputs[pos] = last_id;
863 }
864
865 self.remove_node(node_id);
867
868 Ok(())
869 }
870}
871
872impl Default for ComputationGraph {
873 fn default() -> Self {
874 Self::new()
875 }
876}
877
878pub fn shape_from_slice(dims: &[usize]) -> Shape {
880 Shape::new(dims.to_vec())
881}