Skip to main content

torsh_cli/commands/model/
enhanced_profiling.rs

1//! Enhanced model profiling with layer-by-layer analysis
2//!
3//! This module provides detailed profiling of models including memory,
4//! computation, and performance analysis for each layer.
5
6// Infrastructure module - functions designed for CLI command integration
7#![allow(dead_code)]
8
9use anyhow::Result;
10use std::collections::HashMap;
11use std::time::Instant;
12use tracing::{debug, info};
13
14// ✅ SciRS2 POLICY COMPLIANT: Use scirs2-core unified access patterns
15use scirs2_core::random::{thread_rng, Distribution, Normal};
16
17// ToRSh integration
18use torsh::core::device::DeviceType;
19use torsh::tensor::Tensor;
20
21use super::tensor_integration::estimate_tensor_flops;
22use super::types::{LayerInfo, TorshModel};
23
/// Layer profiling result
///
/// Static metrics (parameters, memory, FLOPs) are always populated by
/// `profile_model`; the runtime fields stay `None` until filled in by
/// `profile_model_runtime`.
#[derive(Debug, Clone)]
pub struct LayerProfile {
    /// Layer name
    pub name: String,
    /// Layer type (Linear, Conv2d, etc.)
    pub layer_type: String,
    /// Input shape
    pub input_shape: Vec<usize>,
    /// Output shape
    pub output_shape: Vec<usize>,
    /// Number of parameters
    pub parameters: u64,
    /// Memory footprint (bytes) of the layer's weight + bias tensors
    pub memory_bytes: u64,
    /// Estimated FLOPs per forward pass
    pub flops: u64,
    /// Execution time (ms) - if measured
    pub execution_time_ms: Option<f64>,
    /// Memory usage during execution (MB) - if measured
    pub runtime_memory_mb: Option<f64>,
    /// Percentage of total model parameters (0.0 when the model has none)
    pub param_percentage: f64,
    /// Percentage of total model FLOPs (0.0 when the model has none)
    pub flops_percentage: f64,
}
50
/// Complete model profile
///
/// Produced by `profile_model`; `execution_breakdown` is empty until
/// populated by `profile_model_runtime`.
#[derive(Debug, Clone)]
pub struct ModelProfile {
    /// Individual layer profiles
    pub layers: Vec<LayerProfile>,
    /// Total parameters
    pub total_parameters: u64,
    /// Total FLOPs
    pub total_flops: u64,
    /// Total memory footprint (bytes) across all weight tensors
    pub total_memory: u64,
    /// Execution time breakdown (layer name -> average time in ms)
    pub execution_breakdown: HashMap<String, f64>,
    /// Memory hotspots (top 5 layers by memory footprint)
    pub memory_hotspots: Vec<(String, u64)>,
    /// Computation hotspots (top 5 layers by FLOPs)
    pub computation_hotspots: Vec<(String, u64)>,
}
69
70/// Profile a model's layers
71pub fn profile_model(model: &TorshModel) -> Result<ModelProfile> {
72    info!("Profiling model with {} layers", model.layers.len());
73
74    // First pass: calculate totals
75    let total_parameters: u64 = model.layers.iter().map(|l| l.parameters).sum();
76
77    let total_flops: u64 = model.layers.iter().map(|l| estimate_layer_flops(l)).sum();
78
79    let total_memory: u64 = model
80        .weights
81        .values()
82        .map(|t| {
83            let elements: usize = t.shape.iter().product();
84            (elements * t.dtype.size_bytes()) as u64
85        })
86        .sum();
87
88    debug!(
89        "Model totals: {} params, {} FLOPs, {:.2} MB",
90        total_parameters,
91        total_flops,
92        total_memory as f64 / (1024.0 * 1024.0)
93    );
94
95    // Second pass: profile each layer
96    let mut layer_profiles = Vec::new();
97
98    for layer in &model.layers {
99        let flops = estimate_layer_flops(layer);
100        let memory = calculate_layer_memory(layer, model);
101
102        let param_percentage = if total_parameters > 0 {
103            (layer.parameters as f64 / total_parameters as f64) * 100.0
104        } else {
105            0.0
106        };
107
108        let flops_percentage = if total_flops > 0 {
109            (flops as f64 / total_flops as f64) * 100.0
110        } else {
111            0.0
112        };
113
114        let profile = LayerProfile {
115            name: layer.name.clone(),
116            layer_type: layer.layer_type.clone(),
117            input_shape: layer.input_shape.clone(),
118            output_shape: layer.output_shape.clone(),
119            parameters: layer.parameters,
120            memory_bytes: memory,
121            flops,
122            execution_time_ms: None, // Will be filled by runtime profiling
123            runtime_memory_mb: None,
124            param_percentage,
125            flops_percentage,
126        };
127
128        layer_profiles.push(profile);
129    }
130
131    // Identify hotspots
132    let mut memory_hotspots: Vec<(String, u64)> = layer_profiles
133        .iter()
134        .map(|p| (p.name.clone(), p.memory_bytes))
135        .collect();
136    memory_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
137    memory_hotspots.truncate(5);
138
139    let mut computation_hotspots: Vec<(String, u64)> = layer_profiles
140        .iter()
141        .map(|p| (p.name.clone(), p.flops))
142        .collect();
143    computation_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
144    computation_hotspots.truncate(5);
145
146    Ok(ModelProfile {
147        layers: layer_profiles,
148        total_parameters,
149        total_flops,
150        total_memory,
151        execution_breakdown: HashMap::new(),
152        memory_hotspots,
153        computation_hotspots,
154    })
155}
156
157/// Estimate FLOPs for a layer
158fn estimate_layer_flops(layer: &LayerInfo) -> u64 {
159    estimate_tensor_flops(
160        &layer.layer_type.to_lowercase(),
161        &layer.input_shape,
162        &layer.output_shape,
163    )
164}
165
166/// Calculate memory footprint for a layer
167fn calculate_layer_memory(layer: &LayerInfo, model: &TorshModel) -> u64 {
168    let weight_name = format!("{}.weight", layer.name);
169    let bias_name = format!("{}.bias", layer.name);
170
171    let mut memory = 0u64;
172
173    if let Some(weight) = model.weights.get(&weight_name) {
174        let elements: usize = weight.shape.iter().product();
175        memory += (elements * weight.dtype.size_bytes()) as u64;
176    }
177
178    if let Some(bias) = model.weights.get(&bias_name) {
179        let elements: usize = bias.shape.iter().product();
180        memory += (elements * bias.dtype.size_bytes()) as u64;
181    }
182
183    memory
184}
185
186/// Format model profile as human-readable text
187pub fn format_model_profile(profile: &ModelProfile) -> String {
188    let mut output = String::new();
189
190    output.push_str("╔═══════════════════════════════════════════════════════════════════════╗\n");
191    output.push_str("║                        MODEL PROFILE REPORT                           ║\n");
192    output
193        .push_str("╚═══════════════════════════════════════════════════════════════════════╝\n\n");
194
195    // Summary section
196    output.push_str("📊 Overall Statistics\n");
197    output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
198    output.push_str(&format!(
199        "  Total Parameters: {} ({:.2} M)\n",
200        profile.total_parameters,
201        profile.total_parameters as f64 / 1_000_000.0
202    ));
203    output.push_str(&format!(
204        "  Total FLOPs:      {} ({:.2} GFLOPs)\n",
205        profile.total_flops,
206        profile.total_flops as f64 / 1_000_000_000.0
207    ));
208    output.push_str(&format!(
209        "  Total Memory:     {:.2} MB\n",
210        profile.total_memory as f64 / (1024.0 * 1024.0)
211    ));
212    output.push_str(&format!("  Number of Layers: {}\n", profile.layers.len()));
213    output.push_str("\n");
214
215    // Layer-by-layer breakdown
216    output.push_str("📋 Layer-by-Layer Breakdown\n");
217    output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
218
219    for (i, layer) in profile.layers.iter().enumerate() {
220        output.push_str(&format!(
221            "\n[{}] {} ({})\n",
222            i, layer.name, layer.layer_type
223        ));
224        output.push_str(&format!(
225            "    Shape: {:?} → {:?}\n",
226            layer.input_shape, layer.output_shape
227        ));
228        output.push_str(&format!(
229            "    Parameters: {} ({:.1}% of total)\n",
230            layer.parameters, layer.param_percentage
231        ));
232        output.push_str(&format!(
233            "    Memory: {:.2} KB\n",
234            layer.memory_bytes as f64 / 1024.0
235        ));
236        output.push_str(&format!(
237            "    FLOPs: {:.2} MFLOPs ({:.1}% of total)\n",
238            layer.flops as f64 / 1_000_000.0,
239            layer.flops_percentage
240        ));
241
242        if let Some(time) = layer.execution_time_ms {
243            output.push_str(&format!("    Execution Time: {:.2} ms\n", time));
244        }
245    }
246
247    // Hotspots
248    output.push_str("\n\n🔥 Memory Hotspots (Top 5)\n");
249    output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
250    for (i, (name, memory)) in profile.memory_hotspots.iter().enumerate() {
251        output.push_str(&format!(
252            "  {}. {} - {:.2} MB\n",
253            i + 1,
254            name,
255            *memory as f64 / (1024.0 * 1024.0)
256        ));
257    }
258
259    output.push_str("\n🚀 Computation Hotspots (Top 5)\n");
260    output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
261    for (i, (name, flops)) in profile.computation_hotspots.iter().enumerate() {
262        output.push_str(&format!(
263            "  {}. {} - {:.2} GFLOPs\n",
264            i + 1,
265            name,
266            *flops as f64 / 1_000_000_000.0
267        ));
268    }
269
270    output.push_str("\n");
271    output
272}
273
274/// Profile model with runtime execution measurements
275pub async fn profile_model_runtime(
276    model: &TorshModel,
277    batch_size: usize,
278    iterations: usize,
279) -> Result<ModelProfile> {
280    info!(
281        "Runtime profiling model with batch size {} for {} iterations",
282        batch_size, iterations
283    );
284
285    // Get static profile first
286    let mut profile = profile_model(model)?;
287
288    // Create input tensor
289    let input_shape = model
290        .layers
291        .first()
292        .map(|l| l.input_shape.clone())
293        .unwrap_or_else(|| vec![784]);
294
295    let _input = create_test_input(&input_shape, batch_size)?;
296
297    // Measure each layer's execution time (simulated for now)
298    // In real implementation, this would do actual forward passes per layer
299    for layer_profile in &mut profile.layers {
300        debug!("Profiling layer: {}", layer_profile.name);
301
302        let mut timings = Vec::new();
303
304        for _ in 0..iterations {
305            let start = Instant::now();
306
307            // Simulate layer execution based on FLOPs
308            // In real implementation, would do actual layer forward pass
309            let compute_time = (layer_profile.flops as f64 / 1_000_000_000.0) * 10.0; // Rough estimate
310            tokio::time::sleep(std::time::Duration::from_micros(
311                (compute_time * 1000.0) as u64,
312            ))
313            .await;
314
315            timings.push(start.elapsed().as_secs_f64() * 1000.0);
316        }
317
318        let avg_time = timings.iter().sum::<f64>() / timings.len() as f64;
319        layer_profile.execution_time_ms = Some(avg_time);
320
321        // Estimate runtime memory
322        let activation_memory = layer_profile.output_shape.iter().product::<usize>() * 4; // f32
323        layer_profile.runtime_memory_mb = Some(
324            (layer_profile.memory_bytes + activation_memory as u64) as f64 / (1024.0 * 1024.0),
325        );
326    }
327
328    // Build execution breakdown
329    let mut execution_breakdown = HashMap::new();
330    for layer_profile in &profile.layers {
331        if let Some(time) = layer_profile.execution_time_ms {
332            execution_breakdown.insert(layer_profile.name.clone(), time);
333        }
334    }
335    profile.execution_breakdown = execution_breakdown;
336
337    Ok(profile)
338}
339
340/// Create test input tensor
341fn create_test_input(shape: &[usize], batch_size: usize) -> Result<Tensor<f32>> {
342    let mut full_shape = vec![batch_size];
343    full_shape.extend_from_slice(shape);
344
345    let mut rng = thread_rng();
346    let normal = Normal::new(0.0, 1.0)?;
347
348    let num_elements: usize = full_shape.iter().product();
349    let data: Vec<f32> = (0..num_elements)
350        .map(|_| normal.sample(&mut rng) as f32)
351        .collect();
352
353    Ok(Tensor::from_data(data, full_shape, DeviceType::Cpu)?)
354}
355
356/// Export profile to JSON
357pub fn export_profile_json(profile: &ModelProfile) -> Result<String> {
358    let json = serde_json::json!({
359        "summary": {
360            "total_parameters": profile.total_parameters,
361            "total_flops": profile.total_flops,
362            "total_memory_bytes": profile.total_memory,
363            "num_layers": profile.layers.len(),
364        },
365        "layers": profile.layers.iter().map(|l| {
366            serde_json::json!({
367                "name": l.name,
368                "type": l.layer_type,
369                "input_shape": l.input_shape,
370                "output_shape": l.output_shape,
371                "parameters": l.parameters,
372                "memory_bytes": l.memory_bytes,
373                "flops": l.flops,
374                "param_percentage": l.param_percentage,
375                "flops_percentage": l.flops_percentage,
376                "execution_time_ms": l.execution_time_ms,
377                "runtime_memory_mb": l.runtime_memory_mb,
378            })
379        }).collect::<Vec<_>>(),
380        "hotspots": {
381            "memory": profile.memory_hotspots.iter().map(|(name, mem)| {
382                serde_json::json!({"layer": name, "memory_bytes": mem})
383            }).collect::<Vec<_>>(),
384            "computation": profile.computation_hotspots.iter().map(|(name, flops)| {
385                serde_json::json!({"layer": name, "flops": flops})
386            }).collect::<Vec<_>>(),
387        }
388    });
389
390    Ok(serde_json::to_string_pretty(&json)?)
391}
392
#[cfg(test)]
mod tests {
    use super::super::tensor_integration::create_real_model;
    use super::*;

    // Static profiling of a 3-layer model should yield nonzero totals
    // and populated hotspot lists.
    #[test]
    fn test_model_profiling() {
        let model = create_real_model("test", 3, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model(&model).expect("profile model should succeed");

        assert_eq!(profile.layers.len(), 3);
        assert!(profile.total_parameters > 0);
        assert!(profile.total_flops > 0);
        assert!(!profile.memory_hotspots.is_empty());
        assert!(!profile.computation_hotspots.is_empty());
    }

    // The text report should contain each major section header.
    #[test]
    fn test_profile_formatting() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model(&model).expect("profile model should succeed");
        let formatted = format_model_profile(&profile);

        assert!(formatted.contains("MODEL PROFILE REPORT"));
        assert!(formatted.contains("Overall Statistics"));
        assert!(formatted.contains("Layer-by-Layer Breakdown"));
    }

    // JSON export should include the summary, layers, and hotspots keys.
    #[test]
    fn test_profile_export_json() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model(&model).expect("profile model should succeed");
        let json = export_profile_json(&profile).expect("export profile json should succeed");

        assert!(json.contains("total_parameters"));
        assert!(json.contains("layers"));
        assert!(json.contains("hotspots"));
    }

    // Runtime profiling with batch size 1 over 5 iterations should fill
    // in execution times for at least one layer.
    #[tokio::test]
    async fn test_runtime_profiling() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model_runtime(&model, 1, 5)
            .await
            .expect("operation should succeed");

        // Check that execution times were measured
        assert!(profile.layers.iter().any(|l| l.execution_time_ms.is_some()));
    }
}