tensorlogic_cli/analysis.rs

//! Graph analysis and metrics for TensorLogic CLI

use std::collections::{HashMap, HashSet};
use tensorlogic_ir::{EinsumGraph, OpType};

/// Graph complexity metrics
#[derive(Debug, Clone)]
pub struct GraphMetrics {
    /// Total number of tensors
    pub tensor_count: usize,
    /// Total number of nodes
    pub node_count: usize,
    /// Number of input tensors
    pub input_count: usize,
    /// Number of output tensors
    pub output_count: usize,
    /// Operation type breakdown
    pub op_breakdown: HashMap<String, usize>,
    /// Graph depth (longest path)
    pub depth: usize,
    /// Average fanout (outputs per node)
    pub avg_fanout: f64,
    /// Estimated computational complexity (FLOPs)
    pub estimated_flops: u64,
    /// Estimated memory usage (bytes)
    pub estimated_memory: u64,
}

impl GraphMetrics {
    /// Analyze an einsum graph and compute its structural and cost metrics.
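    ///
    /// This is a purely structural pass; nothing is executed. A minimal usage
    /// sketch (not compiled as a doctest, since it assumes an `EinsumGraph`
    /// value built elsewhere):
    ///
    /// ```ignore
    /// let metrics = GraphMetrics::analyze(&graph);
    /// metrics.print();
    /// ```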
    pub fn analyze(graph: &EinsumGraph) -> Self {
        let tensor_count = graph.tensors.len();
        let node_count = graph.nodes.len();
        let input_count = graph.inputs.len();
        let output_count = graph.outputs.len();

        // Operation breakdown
        let mut op_breakdown = HashMap::new();
        for node in &graph.nodes {
            let op_name = match &node.op {
                OpType::Einsum { .. } => "Einsum",
                OpType::ElemUnary { .. } => "ElemUnary",
                OpType::ElemBinary { .. } => "ElemBinary",
                OpType::Reduce { .. } => "Reduce",
            };
            *op_breakdown.entry(op_name.to_string()).or_insert(0) += 1;
        }

        // Calculate depth
        let depth = calculate_depth(graph);

        // Calculate average fanout
        let total_outputs: usize = graph.nodes.iter().map(|n| n.outputs.len()).sum();
        let avg_fanout = if node_count > 0 {
            total_outputs as f64 / node_count as f64
        } else {
            0.0
        };

        // Estimate FLOPs and memory
        let estimated_flops = estimate_flops(graph);
        let estimated_memory = estimate_memory(graph);

        Self {
            tensor_count,
            node_count,
            input_count,
            output_count,
            op_breakdown,
            depth,
            avg_fanout,
            estimated_flops,
            estimated_memory,
        }
    }

    /// Print metrics in human-readable format
    pub fn print(&self) {
        println!("Graph Metrics:");
        println!("  Tensors: {}", self.tensor_count);
        println!("  Nodes: {}", self.node_count);
        println!("  Inputs: {}", self.input_count);
        println!("  Outputs: {}", self.output_count);
        println!("  Depth: {}", self.depth);
        println!("  Avg Fanout: {:.2}", self.avg_fanout);
        println!("\nOperation Breakdown:");
        for (op, count) in &self.op_breakdown {
            println!("  {}: {}", op, count);
        }
        println!("\nEstimates:");
        println!("  FLOPs: {}", format_number(self.estimated_flops));
        println!("  Memory: {}", format_bytes(self.estimated_memory));
    }
}

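/// Compute the graph depth as the longest input-to-output path.
///
/// Depths are resolved by fixed-point iteration: a node is processed once all
/// of its input tensors have known depths, and each of its outputs is assigned
/// `max(input depths) + 1`. Nodes whose inputs never resolve (e.g. tensors
/// that are neither graph inputs nor produced by another node) are skipped,
/// so the loop terminates even on malformed graphs.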
fn calculate_depth(graph: &EinsumGraph) -> usize {
    let mut depths = HashMap::new();

    // Initialize input tensors with depth 0
    for input_id in &graph.inputs {
        depths.insert(*input_id, 0);
    }

    // Topologically process nodes
    let mut processed = HashSet::new();
    let mut changed = true;

    while changed {
        changed = false;
        for (node_idx, node) in graph.nodes.iter().enumerate() {
            if processed.contains(&node_idx) {
                continue;
            }

            // Check if all inputs are processed
            let all_inputs_ready = node
                .inputs
                .iter()
                .all(|input_id| depths.contains_key(input_id));

            if all_inputs_ready {
                // Calculate depth as max(input depths) + 1
                let max_input_depth = node
                    .inputs
                    .iter()
                    .map(|id| *depths.get(id).unwrap_or(&0))
                    .max()
                    .unwrap_or(0);

                let node_depth = max_input_depth + 1;

                // Set depth for all output tensors
                for output_id in &node.outputs {
                    depths.insert(*output_id, node_depth);
                }

                processed.insert(node_idx);
                changed = true;
            }
        }
    }

    // Return maximum depth
    *depths.values().max().unwrap_or(&0)
}

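/// Estimate total FLOPs using fixed per-operation costs.
///
/// This is a deliberately coarse heuristic: tensor shapes are not inspected,
/// and every tensor is assumed to hold roughly 1000 elements.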
fn estimate_flops(graph: &EinsumGraph) -> u64 {
    let mut total_flops = 0u64;

    for node in &graph.nodes {
        let flops = match &node.op {
            OpType::Einsum { .. } => {
                // Rough estimate: 2 FLOPs per element (multiply-add)
                // Assume 1000 elements per tensor (very rough)
                2000
            }
            OpType::ElemUnary { .. } => {
                // 1 FLOP per element
                1000
            }
            OpType::ElemBinary { .. } => {
                // 1 FLOP per element
                1000
            }
            OpType::Reduce { .. } => {
                // Sum reduction: n-1 additions
                999
            }
        };
        total_flops += flops;
    }

    total_flops
}

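/// Estimate total memory under the same rough assumption: 1000 `f64`
/// elements (8000 bytes) per tensor.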
fn estimate_memory(graph: &EinsumGraph) -> u64 {
    // Assume f64 (8 bytes) and 1000 elements per tensor
    let bytes_per_tensor = 8 * 1000;
    (graph.tensors.len() as u64) * bytes_per_tensor
}

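/// Format a count with a short base-1000 suffix (K, M, B).
///
/// A couple of illustrative cases (marked `ignore` only because the public
/// module path for a doctest depends on how the crate re-exports this
/// function):
///
/// ```ignore
/// assert_eq!(format_number(1_500), "1.50K");
/// assert_eq!(format_number(2_000_000_000), "2.00B");
/// ```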
pub fn format_number(n: u64) -> String {
    if n >= 1_000_000_000 {
        format!("{:.2}B", n as f64 / 1_000_000_000.0)
    } else if n >= 1_000_000 {
        format!("{:.2}M", n as f64 / 1_000_000.0)
    } else if n >= 1_000 {
        format!("{:.2}K", n as f64 / 1_000.0)
    } else {
        n.to_string()
    }
}

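/// Format a byte count using binary units (1 KB = 1024 bytes).
///
/// Illustrative cases (doctest marked `ignore` for the same reason as above):
///
/// ```ignore
/// assert_eq!(format_bytes(512), "512 bytes");
/// assert_eq!(format_bytes(3 * 1024 * 1024 * 1024), "3.00 GB");
/// ```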
pub fn format_bytes(bytes: u64) -> String {
    const KB: u64 = 1024;
    const MB: u64 = KB * 1024;
    const GB: u64 = MB * 1024;

    if bytes >= GB {
        format!("{:.2} GB", bytes as f64 / GB as f64)
    } else if bytes >= MB {
        format!("{:.2} MB", bytes as f64 / MB as f64)
    } else if bytes >= KB {
        format!("{:.2} KB", bytes as f64 / KB as f64)
    } else {
        format!("{} bytes", bytes)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_number() {
        assert_eq!(format_number(500), "500");
        assert_eq!(format_number(1500), "1.50K");
        assert_eq!(format_number(1500000), "1.50M");
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(512), "512 bytes");
        assert_eq!(format_bytes(2048), "2.00 KB");
        assert_eq!(format_bytes(2 * 1024 * 1024), "2.00 MB");
    }
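
    #[test]
    fn test_format_large_values() {
        // Extra coverage for the B and GB branches; expected strings follow
        // directly from the formatting rules above.
        assert_eq!(format_number(2_000_000_000), "2.00B");
        assert_eq!(format_bytes(3 * 1024 * 1024 * 1024), "3.00 GB");
    }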
}