trustformers_core/kernel_fusion/
graph.rs1use crate::kernel_fusion::operation_types::OperationType;
7use std::collections::HashMap;
8
9#[derive(Debug, Clone)]
11pub struct ComputationGraph {
12 pub nodes: HashMap<String, GraphNode>,
13 pub edges: HashMap<String, Vec<String>>, pub execution_order: Vec<String>,
15}
16
17#[derive(Debug, Clone)]
18pub struct GraphNode {
19 pub id: String,
20 pub operation: OperationType,
21 pub inputs: Vec<TensorInfo>,
22 pub outputs: Vec<TensorInfo>,
23 pub metadata: NodeMetadata,
24}
25
26#[derive(Debug, Clone)]
27pub struct TensorInfo {
28 pub shape: Vec<usize>,
29 pub dtype: DataType,
30 pub device: Device,
31 pub memory_layout: MemoryLayout,
32}
33
/// Element type of a tensor.
///
/// This is a plain fieldless enum, so it is `Copy` (no `.clone()` needed when
/// passing it around) and `Eq + Hash` (usable as a `HashMap`/`HashSet` key).
/// Per-element sizes are given by `DataType::size_bytes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point.
    F32,
    /// 16-bit floating point.
    F16,
    /// 16-bit bfloat16 floating point.
    BF16,
    /// 32-bit signed integer.
    I32,
    /// 8-bit signed integer.
    I8,
    /// 8-bit unsigned integer.
    U8,
    /// Boolean (stored as one byte per element by `size_bytes`).
    Bool,
}
44
/// Placement of a tensor. Hashable, so it can key per-device maps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Device {
    /// The host CPU.
    CPU,
    /// A GPU. NOTE(review): the `u32` is presumably the device index — confirm.
    GPU(u32),
    /// A custom accelerator. NOTE(review): the `String` presumably names or
    /// identifies the accelerator — confirm.
    ASIC(String),
}
51
/// How a tensor's elements are arranged in memory.
///
/// Every payload is `usize` or `Vec<usize>`, so the type also derives `Eq`
/// and `Hash`. That lets layouts be compared exactly and used as
/// `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Row-major ordering (last dimension contiguous).
    RowMajor,
    /// Column-major ordering (first dimension contiguous).
    ColumnMajor,
    /// Blocked layout. NOTE(review): the payload is presumably the block size
    /// per dimension — confirm.
    Blocked(Vec<usize>),
    /// Tiled layout.
    Tiled {
        /// NOTE(review): presumably the tile extent per dimension — confirm.
        tile_sizes: Vec<usize>,
    },
    /// NCHW layout; presumably (batch, channels, height, width) order — confirm.
    NCHW,
    /// NHWC layout; presumably (batch, height, width, channels) order — confirm.
    NHWC,
    /// Packed layout.
    Packed {
        /// Number of elements grouped into one pack.
        elements_per_pack: usize,
    },
    /// Explicitly strided layout.
    Strided {
        /// One stride per dimension. NOTE(review): whether strides are counted
        /// in elements or bytes is not established here — confirm.
        strides: Vec<usize>,
    },
}
74
/// Cost estimates and fusion hints attached to a `GraphNode`.
#[derive(Debug, Clone)]
pub struct NodeMetadata {
    /// Estimated operation count. NOTE(review): the unit (e.g. FLOPs) is not
    /// established here — confirm.
    pub estimated_ops: u64,
    /// Estimated memory footprint; presumably in bytes — confirm.
    pub estimated_memory: usize,
    /// Whether this node is considered a fusion candidate.
    pub is_fusible: bool,
    /// Relative fusion priority. NOTE(review): whether larger means "fuse
    /// first" is not established here — confirm.
    pub fusion_priority: f64,
    /// Recorded execution time (nanoseconds, per the `_ns` suffix); `None`
    /// when no measurement is available.
    pub execution_time_ns: Option<u64>,
}
83
84impl ComputationGraph {
85 pub fn new() -> Self {
86 Self {
87 nodes: HashMap::new(),
88 edges: HashMap::new(),
89 execution_order: Vec::new(),
90 }
91 }
92
93 pub fn add_node(&mut self, node: GraphNode) {
94 let node_id = node.id.clone();
95 self.nodes.insert(node_id.clone(), node);
96 self.edges.entry(node_id).or_default();
97 }
98
99 pub fn add_edge(&mut self, from: &str, to: &str) {
100 self.edges.entry(to.to_string()).or_default().push(from.to_string());
101 }
102
103 pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
104 self.nodes.get(id)
105 }
106
107 pub fn get_dependencies(&self, id: &str) -> Option<&Vec<String>> {
108 self.edges.get(id)
109 }
110}
111
112impl Default for ComputationGraph {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl GraphNode {
119 pub fn new(id: String, operation: OperationType) -> Self {
120 Self {
121 id,
122 operation,
123 inputs: Vec::new(),
124 outputs: Vec::new(),
125 metadata: NodeMetadata::default(),
126 }
127 }
128}
129
130impl Default for NodeMetadata {
131 fn default() -> Self {
132 Self {
133 estimated_ops: 0,
134 estimated_memory: 0,
135 is_fusible: true,
136 fusion_priority: 1.0,
137 execution_time_ns: None,
138 }
139 }
140}
141
142impl TensorInfo {
143 pub fn new(shape: Vec<usize>, dtype: DataType, device: Device) -> Self {
144 Self {
145 shape,
146 dtype,
147 device,
148 memory_layout: MemoryLayout::RowMajor,
149 }
150 }
151
152 pub fn element_count(&self) -> usize {
153 self.shape.iter().product()
154 }
155
156 pub fn memory_size(&self) -> usize {
157 self.element_count() * self.dtype.size_bytes()
158 }
159}
160
161impl DataType {
162 pub fn size_bytes(&self) -> usize {
163 match self {
164 DataType::F32 => 4,
165 DataType::F16 => 2,
166 DataType::BF16 => 2,
167 DataType::I32 => 4,
168 DataType::I8 => 1,
169 DataType::U8 => 1,
170 DataType::Bool => 1,
171 }
172 }
173}