Skip to main content

trustformers_core/kernel_fusion/
graph.rs

//! Computation graph structures for kernel fusion analysis
//!
//! This module provides data structures for representing computation graphs,
//! nodes, tensor information, and related metadata used in fusion analysis.
5
6use crate::kernel_fusion::operation_types::OperationType;
7use std::collections::HashMap;
8
9/// Computation graph representation for fusion analysis
10#[derive(Debug, Clone)]
11pub struct ComputationGraph {
12    pub nodes: HashMap<String, GraphNode>,
13    pub edges: HashMap<String, Vec<String>>, // node_id -> dependencies
14    pub execution_order: Vec<String>,
15}
16
17#[derive(Debug, Clone)]
18pub struct GraphNode {
19    pub id: String,
20    pub operation: OperationType,
21    pub inputs: Vec<TensorInfo>,
22    pub outputs: Vec<TensorInfo>,
23    pub metadata: NodeMetadata,
24}
25
26#[derive(Debug, Clone)]
27pub struct TensorInfo {
28    pub shape: Vec<usize>,
29    pub dtype: DataType,
30    pub device: Device,
31    pub memory_layout: MemoryLayout,
32}
33
/// Element type of a tensor.
///
/// The enum carries no data, so it is `Copy` and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point.
    F32,
    /// 16-bit floating point.
    F16,
    /// 16-bit bfloat.
    BF16,
    /// 32-bit signed integer.
    I32,
    /// 8-bit signed integer.
    I8,
    /// 8-bit unsigned integer.
    U8,
    /// Boolean.
    Bool,
}
44
/// Device recorded for a tensor.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Device {
    /// Host CPU.
    CPU,
    /// GPU, identified by its device id.
    GPU(u32),
    /// ASIC, identified by a string (the naming scheme is not defined in this file).
    ASIC(String),
}
51
/// Memory layout of a tensor.
///
/// Every payload is a `usize` or `Vec<usize>`, so the type is also `Eq` and
/// `Hash` and can be used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Row-major layout.
    RowMajor,
    /// Column-major layout.
    ColumnMajor,
    /// Blocked layout; the payload presumably holds the block sizes — TODO confirm.
    Blocked(Vec<usize>),
    /// Cache-optimized tiled layout for better spatial locality
    Tiled {
        tile_sizes: Vec<usize>,
    },
    /// NCHW format commonly used in computer vision
    NCHW,
    /// NHWC format for better memory coalescing on some devices
    NHWC,
    /// Packed format for quantized tensors
    Packed {
        elements_per_pack: usize,
    },
    /// Strided layout with custom strides
    Strided {
        strides: Vec<usize>,
    },
}
74
/// Cost estimates and fusion hints attached to a graph node.
#[derive(Debug, Clone, PartialEq)]
pub struct NodeMetadata {
    /// Estimated number of operations performed by the node.
    pub estimated_ops: u64,
    /// Estimated memory use of the node (presumably in bytes — TODO confirm).
    pub estimated_memory: usize,
    /// Whether the node may take part in fusion.
    pub is_fusible: bool,
    /// Fusion priority.
    ///
    /// NOTE(review): how the value is ordered or scaled is not visible in this
    /// file — confirm against the fusion pass.
    pub fusion_priority: f64,
    /// Execution time in nanoseconds, when one is available.
    pub execution_time_ns: Option<u64>,
}
83
84impl ComputationGraph {
85    pub fn new() -> Self {
86        Self {
87            nodes: HashMap::new(),
88            edges: HashMap::new(),
89            execution_order: Vec::new(),
90        }
91    }
92
93    pub fn add_node(&mut self, node: GraphNode) {
94        let node_id = node.id.clone();
95        self.nodes.insert(node_id.clone(), node);
96        self.edges.entry(node_id).or_default();
97    }
98
99    pub fn add_edge(&mut self, from: &str, to: &str) {
100        self.edges.entry(to.to_string()).or_default().push(from.to_string());
101    }
102
103    pub fn get_node(&self, id: &str) -> Option<&GraphNode> {
104        self.nodes.get(id)
105    }
106
107    pub fn get_dependencies(&self, id: &str) -> Option<&Vec<String>> {
108        self.edges.get(id)
109    }
110}
111
112impl Default for ComputationGraph {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118impl GraphNode {
119    pub fn new(id: String, operation: OperationType) -> Self {
120        Self {
121            id,
122            operation,
123            inputs: Vec::new(),
124            outputs: Vec::new(),
125            metadata: NodeMetadata::default(),
126        }
127    }
128}
129
130impl Default for NodeMetadata {
131    fn default() -> Self {
132        Self {
133            estimated_ops: 0,
134            estimated_memory: 0,
135            is_fusible: true,
136            fusion_priority: 1.0,
137            execution_time_ns: None,
138        }
139    }
140}
141
142impl TensorInfo {
143    pub fn new(shape: Vec<usize>, dtype: DataType, device: Device) -> Self {
144        Self {
145            shape,
146            dtype,
147            device,
148            memory_layout: MemoryLayout::RowMajor,
149        }
150    }
151
152    pub fn element_count(&self) -> usize {
153        self.shape.iter().product()
154    }
155
156    pub fn memory_size(&self) -> usize {
157        self.element_count() * self.dtype.size_bytes()
158    }
159}
160
161impl DataType {
162    pub fn size_bytes(&self) -> usize {
163        match self {
164            DataType::F32 => 4,
165            DataType::F16 => 2,
166            DataType::BF16 => 2,
167            DataType::I32 => 4,
168            DataType::I8 => 1,
169            DataType::U8 => 1,
170            DataType::Bool => 1,
171        }
172    }
173}