tensorlogic_cli/
analysis.rs1use std::collections::{HashMap, HashSet};
4use tensorlogic_ir::{EinsumGraph, OpType};
5
/// Summary statistics and rough cost estimates for an [`EinsumGraph`],
/// produced by [`GraphMetrics::analyze`].
#[derive(Debug, Clone, PartialEq)]
pub struct GraphMetrics {
    /// Number of tensors registered in the graph (`graph.tensors.len()`).
    pub tensor_count: usize,
    /// Number of operation nodes (`graph.nodes.len()`).
    pub node_count: usize,
    /// Number of graph-level input tensors (`graph.inputs.len()`).
    pub input_count: usize,
    /// Number of graph-level output tensors (`graph.outputs.len()`).
    pub output_count: usize,
    /// Node count per operation kind, keyed by the `OpType` variant name
    /// (`"Einsum"`, `"ElemUnary"`, `"ElemBinary"`, `"Reduce"`).
    pub op_breakdown: HashMap<String, usize>,
    /// Number of operation nodes on the longest dependency chain.
    pub depth: usize,
    /// Mean number of output tensors per node (`0.0` when the graph has no nodes).
    pub avg_fanout: f64,
    /// Heuristic FLOP count built from fixed per-node constants, not from tensor shapes.
    pub estimated_flops: u64,
    /// Heuristic memory footprint in bytes: a fixed cost per tensor, not derived from shapes.
    pub estimated_memory: u64,
}
28
29impl GraphMetrics {
30 pub fn analyze(graph: &EinsumGraph) -> Self {
32 let tensor_count = graph.tensors.len();
33 let node_count = graph.nodes.len();
34 let input_count = graph.inputs.len();
35 let output_count = graph.outputs.len();
36
37 let mut op_breakdown = HashMap::new();
39 for node in &graph.nodes {
40 let op_name = match &node.op {
41 OpType::Einsum { .. } => "Einsum",
42 OpType::ElemUnary { .. } => "ElemUnary",
43 OpType::ElemBinary { .. } => "ElemBinary",
44 OpType::Reduce { .. } => "Reduce",
45 };
46 *op_breakdown.entry(op_name.to_string()).or_insert(0) += 1;
47 }
48
49 let depth = calculate_depth(graph);
51
52 let total_outputs: usize = graph.nodes.iter().map(|n| n.outputs.len()).sum();
54 let avg_fanout = if node_count > 0 {
55 total_outputs as f64 / node_count as f64
56 } else {
57 0.0
58 };
59
60 let estimated_flops = estimate_flops(graph);
62 let estimated_memory = estimate_memory(graph);
63
64 Self {
65 tensor_count,
66 node_count,
67 input_count,
68 output_count,
69 op_breakdown,
70 depth,
71 avg_fanout,
72 estimated_flops,
73 estimated_memory,
74 }
75 }
76
77 pub fn print(&self) {
79 println!("Graph Metrics:");
80 println!(" Tensors: {}", self.tensor_count);
81 println!(" Nodes: {}", self.node_count);
82 println!(" Inputs: {}", self.input_count);
83 println!(" Outputs: {}", self.output_count);
84 println!(" Depth: {}", self.depth);
85 println!(" Avg Fanout: {:.2}", self.avg_fanout);
86 println!("\nOperation Breakdown:");
87 for (op, count) in &self.op_breakdown {
88 println!(" {}: {}", op, count);
89 }
90 println!("\nEstimates:");
91 println!(" FLOPs: {}", format_number(self.estimated_flops));
92 println!(" Memory: {}", format_bytes(self.estimated_memory));
93 }
94}
95
96fn calculate_depth(graph: &EinsumGraph) -> usize {
97 let mut depths = HashMap::new();
98
99 for input_id in &graph.inputs {
101 depths.insert(*input_id, 0);
102 }
103
104 let mut processed = HashSet::new();
106 let mut changed = true;
107
108 while changed {
109 changed = false;
110 for (node_idx, node) in graph.nodes.iter().enumerate() {
111 if processed.contains(&node_idx) {
112 continue;
113 }
114
115 let all_inputs_ready = node
117 .inputs
118 .iter()
119 .all(|input_id| depths.contains_key(input_id));
120
121 if all_inputs_ready {
122 let max_input_depth = node
124 .inputs
125 .iter()
126 .map(|id| *depths.get(id).unwrap_or(&0))
127 .max()
128 .unwrap_or(0);
129
130 let node_depth = max_input_depth + 1;
131
132 for output_id in &node.outputs {
134 depths.insert(*output_id, node_depth);
135 }
136
137 processed.insert(node_idx);
138 changed = true;
139 }
140 }
141 }
142
143 *depths.values().max().unwrap_or(&0)
145}
146
147fn estimate_flops(graph: &EinsumGraph) -> u64 {
148 let mut total_flops = 0u64;
149
150 for node in &graph.nodes {
151 let flops = match &node.op {
152 OpType::Einsum { .. } => {
153 2000
156 }
157 OpType::ElemUnary { .. } => {
158 1000
160 }
161 OpType::ElemBinary { .. } => {
162 1000
164 }
165 OpType::Reduce { .. } => {
166 999
168 }
169 };
170 total_flops += flops;
171 }
172
173 total_flops
174}
175
176fn estimate_memory(graph: &EinsumGraph) -> u64 {
177 let bytes_per_tensor = 8 * 1000;
179 (graph.tensors.len() as u64) * bytes_per_tensor
180}
181
/// Formats `n` compactly with a decimal suffix: `K` (10^3), `M` (10^6) or `B` (10^9).
///
/// Values below 1000 are printed as plain integers. Larger values are shown
/// with two decimals in the largest unit that does not exceed them, e.g.
/// `1500 -> "1.50K"`. A value that would round up to `1000.00` in its unit
/// (e.g. `999_999`) is shown in the next unit instead (`"1.00M"`).
pub fn format_number(n: u64) -> String {
    // (divisor, suffix), ordered from smallest to largest.
    const UNITS: [(u64, &str); 3] = [(1_000, "K"), (1_000_000, "M"), (1_000_000_000, "B")];

    let mut idx = match UNITS.iter().rposition(|&(divisor, _)| n >= divisor) {
        Some(idx) => idx,
        None => return n.to_string(),
    };
    let mut text = format!("{:.2}", n as f64 / UNITS[idx].0 as f64);
    // Two-decimal rounding can turn e.g. 999.999K into "1000.00K". Promote such
    // values to the next unit. The largest unit has no successor and is left as is.
    if idx + 1 < UNITS.len() && matches!(text.parse::<f64>(), Ok(v) if v >= 1_000.0) {
        idx += 1;
        text = format!("{:.2}", n as f64 / UNITS[idx].0 as f64);
    }
    format!("{}{}", text, UNITS[idx].1)
}
193
/// Formats a byte count using binary units (1 KB = 1024 bytes): `bytes`, `KB`, `MB`, `GB`.
///
/// Counts below 1 KB are printed as `"<n> bytes"`. Larger counts are shown with
/// two decimals in the largest unit that does not exceed them, e.g.
/// `2048 -> "2.00 KB"`. A value that would round up to `1024.00` in its unit
/// (e.g. `1024 * 1024 - 1`) is shown in the next unit instead (`"1.00 MB"`).
pub fn format_bytes(bytes: u64) -> String {
    const KB: u64 = 1024;
    const MB: u64 = KB * 1024;
    const GB: u64 = MB * 1024;
    // (unit size, label), ordered from smallest to largest.
    const UNITS: [(u64, &str); 3] = [(KB, "KB"), (MB, "MB"), (GB, "GB")];

    let mut idx = match UNITS.iter().rposition(|&(size, _)| bytes >= size) {
        Some(idx) => idx,
        None => return format!("{} bytes", bytes),
    };
    let mut text = format!("{:.2}", bytes as f64 / UNITS[idx].0 as f64);
    // Two-decimal rounding can turn e.g. 1023.999 KB into "1024.00 KB". Promote such
    // values to the next unit. GB has no successor and is left as is.
    if idx + 1 < UNITS.len() && matches!(text.parse::<f64>(), Ok(v) if v >= KB as f64) {
        idx += 1;
        text = format!("{:.2}", bytes as f64 / UNITS[idx].0 as f64);
    }
    format!("{} {}", text, UNITS[idx].1)
}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// `format_number` keeps small values as integers and scales larger ones.
    #[test]
    fn test_format_number() {
        let cases: [(u64, &str); 3] = [(500, "500"), (1500, "1.50K"), (1500000, "1.50M")];
        for &(input, expected) in cases.iter() {
            assert_eq!(format_number(input), expected, "format_number({})", input);
        }
    }

    /// `format_bytes` keeps sub-KB counts in bytes and scales larger ones.
    #[test]
    fn test_format_bytes() {
        let cases: [(u64, &str); 3] = [
            (512, "512 bytes"),
            (2048, "2.00 KB"),
            (2 * 1024 * 1024, "2.00 MB"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(format_bytes(input), expected, "format_bytes({})", input);
        }
    }
}