// torsh_cli/commands/model/enhanced_profiling.rs
#![allow(dead_code)]
8
9use anyhow::Result;
10use std::collections::HashMap;
11use std::time::Instant;
12use tracing::{debug, info};
13
14use scirs2_core::random::{thread_rng, Distribution, Normal};
16
17use torsh::core::device::DeviceType;
19use torsh::tensor::Tensor;
20
21use super::tensor_integration::estimate_tensor_flops;
22use super::types::{LayerInfo, TorshModel};
23
/// Per-layer profiling results: static estimates (parameters, memory,
/// FLOPs) plus optional runtime measurements filled in by
/// `profile_model_runtime`.
#[derive(Debug, Clone)]
pub struct LayerProfile {
    /// Layer name as recorded in the model definition.
    pub name: String,
    /// Layer kind string; lowercased and fed to the FLOP estimator.
    pub layer_type: String,
    /// Input tensor shape (assumed batch-free — TODO confirm with callers).
    pub input_shape: Vec<usize>,
    /// Output tensor shape, same convention as `input_shape`.
    pub output_shape: Vec<usize>,
    /// Total trainable parameter count for this layer.
    pub parameters: u64,
    /// Bytes occupied by this layer's stored weight/bias tensors.
    pub memory_bytes: u64,
    /// Estimated floating-point operations for one forward pass.
    pub flops: u64,
    /// Average measured (simulated) execution time; `None` until runtime
    /// profiling has run.
    pub execution_time_ms: Option<f64>,
    /// Weights plus activation memory in MB; `None` until runtime profiling.
    pub runtime_memory_mb: Option<f64>,
    /// This layer's share of the model's total parameters, in percent.
    pub param_percentage: f64,
    /// This layer's share of the model's total FLOPs, in percent.
    pub flops_percentage: f64,
}
50
/// Aggregated profile for a whole model, produced by `profile_model` and
/// optionally enriched by `profile_model_runtime`.
#[derive(Debug, Clone)]
pub struct ModelProfile {
    /// Per-layer profiles, in model layer order.
    pub layers: Vec<LayerProfile>,
    /// Sum of `parameters` over all layers.
    pub total_parameters: u64,
    /// Sum of estimated FLOPs over all layers.
    pub total_flops: u64,
    /// Total bytes of all stored weight tensors in the model.
    pub total_memory: u64,
    /// Layer name → average execution time in ms; empty until runtime
    /// profiling has run.
    pub execution_breakdown: HashMap<String, f64>,
    /// Top 5 layers by memory usage as `(name, bytes)` pairs, descending.
    pub memory_hotspots: Vec<(String, u64)>,
    /// Top 5 layers by FLOPs as `(name, flops)` pairs, descending.
    pub computation_hotspots: Vec<(String, u64)>,
}
69
70pub fn profile_model(model: &TorshModel) -> Result<ModelProfile> {
72 info!("Profiling model with {} layers", model.layers.len());
73
74 let total_parameters: u64 = model.layers.iter().map(|l| l.parameters).sum();
76
77 let total_flops: u64 = model.layers.iter().map(|l| estimate_layer_flops(l)).sum();
78
79 let total_memory: u64 = model
80 .weights
81 .values()
82 .map(|t| {
83 let elements: usize = t.shape.iter().product();
84 (elements * t.dtype.size_bytes()) as u64
85 })
86 .sum();
87
88 debug!(
89 "Model totals: {} params, {} FLOPs, {:.2} MB",
90 total_parameters,
91 total_flops,
92 total_memory as f64 / (1024.0 * 1024.0)
93 );
94
95 let mut layer_profiles = Vec::new();
97
98 for layer in &model.layers {
99 let flops = estimate_layer_flops(layer);
100 let memory = calculate_layer_memory(layer, model);
101
102 let param_percentage = if total_parameters > 0 {
103 (layer.parameters as f64 / total_parameters as f64) * 100.0
104 } else {
105 0.0
106 };
107
108 let flops_percentage = if total_flops > 0 {
109 (flops as f64 / total_flops as f64) * 100.0
110 } else {
111 0.0
112 };
113
114 let profile = LayerProfile {
115 name: layer.name.clone(),
116 layer_type: layer.layer_type.clone(),
117 input_shape: layer.input_shape.clone(),
118 output_shape: layer.output_shape.clone(),
119 parameters: layer.parameters,
120 memory_bytes: memory,
121 flops,
122 execution_time_ms: None, runtime_memory_mb: None,
124 param_percentage,
125 flops_percentage,
126 };
127
128 layer_profiles.push(profile);
129 }
130
131 let mut memory_hotspots: Vec<(String, u64)> = layer_profiles
133 .iter()
134 .map(|p| (p.name.clone(), p.memory_bytes))
135 .collect();
136 memory_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
137 memory_hotspots.truncate(5);
138
139 let mut computation_hotspots: Vec<(String, u64)> = layer_profiles
140 .iter()
141 .map(|p| (p.name.clone(), p.flops))
142 .collect();
143 computation_hotspots.sort_by(|a, b| b.1.cmp(&a.1));
144 computation_hotspots.truncate(5);
145
146 Ok(ModelProfile {
147 layers: layer_profiles,
148 total_parameters,
149 total_flops,
150 total_memory,
151 execution_breakdown: HashMap::new(),
152 memory_hotspots,
153 computation_hotspots,
154 })
155}
156
157fn estimate_layer_flops(layer: &LayerInfo) -> u64 {
159 estimate_tensor_flops(
160 &layer.layer_type.to_lowercase(),
161 &layer.input_shape,
162 &layer.output_shape,
163 )
164}
165
166fn calculate_layer_memory(layer: &LayerInfo, model: &TorshModel) -> u64 {
168 let weight_name = format!("{}.weight", layer.name);
169 let bias_name = format!("{}.bias", layer.name);
170
171 let mut memory = 0u64;
172
173 if let Some(weight) = model.weights.get(&weight_name) {
174 let elements: usize = weight.shape.iter().product();
175 memory += (elements * weight.dtype.size_bytes()) as u64;
176 }
177
178 if let Some(bias) = model.weights.get(&bias_name) {
179 let elements: usize = bias.shape.iter().product();
180 memory += (elements * bias.dtype.size_bytes()) as u64;
181 }
182
183 memory
184}
185
186pub fn format_model_profile(profile: &ModelProfile) -> String {
188 let mut output = String::new();
189
190 output.push_str("╔═══════════════════════════════════════════════════════════════════════╗\n");
191 output.push_str("║ MODEL PROFILE REPORT ║\n");
192 output
193 .push_str("╚═══════════════════════════════════════════════════════════════════════╝\n\n");
194
195 output.push_str("📊 Overall Statistics\n");
197 output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
198 output.push_str(&format!(
199 " Total Parameters: {} ({:.2} M)\n",
200 profile.total_parameters,
201 profile.total_parameters as f64 / 1_000_000.0
202 ));
203 output.push_str(&format!(
204 " Total FLOPs: {} ({:.2} GFLOPs)\n",
205 profile.total_flops,
206 profile.total_flops as f64 / 1_000_000_000.0
207 ));
208 output.push_str(&format!(
209 " Total Memory: {:.2} MB\n",
210 profile.total_memory as f64 / (1024.0 * 1024.0)
211 ));
212 output.push_str(&format!(" Number of Layers: {}\n", profile.layers.len()));
213 output.push_str("\n");
214
215 output.push_str("📋 Layer-by-Layer Breakdown\n");
217 output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
218
219 for (i, layer) in profile.layers.iter().enumerate() {
220 output.push_str(&format!(
221 "\n[{}] {} ({})\n",
222 i, layer.name, layer.layer_type
223 ));
224 output.push_str(&format!(
225 " Shape: {:?} → {:?}\n",
226 layer.input_shape, layer.output_shape
227 ));
228 output.push_str(&format!(
229 " Parameters: {} ({:.1}% of total)\n",
230 layer.parameters, layer.param_percentage
231 ));
232 output.push_str(&format!(
233 " Memory: {:.2} KB\n",
234 layer.memory_bytes as f64 / 1024.0
235 ));
236 output.push_str(&format!(
237 " FLOPs: {:.2} MFLOPs ({:.1}% of total)\n",
238 layer.flops as f64 / 1_000_000.0,
239 layer.flops_percentage
240 ));
241
242 if let Some(time) = layer.execution_time_ms {
243 output.push_str(&format!(" Execution Time: {:.2} ms\n", time));
244 }
245 }
246
247 output.push_str("\n\n🔥 Memory Hotspots (Top 5)\n");
249 output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
250 for (i, (name, memory)) in profile.memory_hotspots.iter().enumerate() {
251 output.push_str(&format!(
252 " {}. {} - {:.2} MB\n",
253 i + 1,
254 name,
255 *memory as f64 / (1024.0 * 1024.0)
256 ));
257 }
258
259 output.push_str("\n🚀 Computation Hotspots (Top 5)\n");
260 output.push_str("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n");
261 for (i, (name, flops)) in profile.computation_hotspots.iter().enumerate() {
262 output.push_str(&format!(
263 " {}. {} - {:.2} GFLOPs\n",
264 i + 1,
265 name,
266 *flops as f64 / 1_000_000_000.0
267 ));
268 }
269
270 output.push_str("\n");
271 output
272}
273
274pub async fn profile_model_runtime(
276 model: &TorshModel,
277 batch_size: usize,
278 iterations: usize,
279) -> Result<ModelProfile> {
280 info!(
281 "Runtime profiling model with batch size {} for {} iterations",
282 batch_size, iterations
283 );
284
285 let mut profile = profile_model(model)?;
287
288 let input_shape = model
290 .layers
291 .first()
292 .map(|l| l.input_shape.clone())
293 .unwrap_or_else(|| vec![784]);
294
295 let _input = create_test_input(&input_shape, batch_size)?;
296
297 for layer_profile in &mut profile.layers {
300 debug!("Profiling layer: {}", layer_profile.name);
301
302 let mut timings = Vec::new();
303
304 for _ in 0..iterations {
305 let start = Instant::now();
306
307 let compute_time = (layer_profile.flops as f64 / 1_000_000_000.0) * 10.0; tokio::time::sleep(std::time::Duration::from_micros(
311 (compute_time * 1000.0) as u64,
312 ))
313 .await;
314
315 timings.push(start.elapsed().as_secs_f64() * 1000.0);
316 }
317
318 let avg_time = timings.iter().sum::<f64>() / timings.len() as f64;
319 layer_profile.execution_time_ms = Some(avg_time);
320
321 let activation_memory = layer_profile.output_shape.iter().product::<usize>() * 4; layer_profile.runtime_memory_mb = Some(
324 (layer_profile.memory_bytes + activation_memory as u64) as f64 / (1024.0 * 1024.0),
325 );
326 }
327
328 let mut execution_breakdown = HashMap::new();
330 for layer_profile in &profile.layers {
331 if let Some(time) = layer_profile.execution_time_ms {
332 execution_breakdown.insert(layer_profile.name.clone(), time);
333 }
334 }
335 profile.execution_breakdown = execution_breakdown;
336
337 Ok(profile)
338}
339
340fn create_test_input(shape: &[usize], batch_size: usize) -> Result<Tensor<f32>> {
342 let mut full_shape = vec![batch_size];
343 full_shape.extend_from_slice(shape);
344
345 let mut rng = thread_rng();
346 let normal = Normal::new(0.0, 1.0)?;
347
348 let num_elements: usize = full_shape.iter().product();
349 let data: Vec<f32> = (0..num_elements)
350 .map(|_| normal.sample(&mut rng) as f32)
351 .collect();
352
353 Ok(Tensor::from_data(data, full_shape, DeviceType::Cpu)?)
354}
355
356pub fn export_profile_json(profile: &ModelProfile) -> Result<String> {
358 let json = serde_json::json!({
359 "summary": {
360 "total_parameters": profile.total_parameters,
361 "total_flops": profile.total_flops,
362 "total_memory_bytes": profile.total_memory,
363 "num_layers": profile.layers.len(),
364 },
365 "layers": profile.layers.iter().map(|l| {
366 serde_json::json!({
367 "name": l.name,
368 "type": l.layer_type,
369 "input_shape": l.input_shape,
370 "output_shape": l.output_shape,
371 "parameters": l.parameters,
372 "memory_bytes": l.memory_bytes,
373 "flops": l.flops,
374 "param_percentage": l.param_percentage,
375 "flops_percentage": l.flops_percentage,
376 "execution_time_ms": l.execution_time_ms,
377 "runtime_memory_mb": l.runtime_memory_mb,
378 })
379 }).collect::<Vec<_>>(),
380 "hotspots": {
381 "memory": profile.memory_hotspots.iter().map(|(name, mem)| {
382 serde_json::json!({"layer": name, "memory_bytes": mem})
383 }).collect::<Vec<_>>(),
384 "computation": profile.computation_hotspots.iter().map(|(name, flops)| {
385 serde_json::json!({"layer": name, "flops": flops})
386 }).collect::<Vec<_>>(),
387 }
388 });
389
390 Ok(serde_json::to_string_pretty(&json)?)
391}
392
#[cfg(test)]
mod tests {
    use super::super::tensor_integration::create_real_model;
    use super::*;

    // Static profiling of a 3-layer model should yield one profile per
    // layer, nonzero totals, and non-empty hotspot lists.
    #[test]
    fn test_model_profiling() {
        let model = create_real_model("test", 3, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model(&model).expect("profile model should succeed");

        assert_eq!(profile.layers.len(), 3);
        assert!(profile.total_parameters > 0);
        assert!(profile.total_flops > 0);
        assert!(!profile.memory_hotspots.is_empty());
        assert!(!profile.computation_hotspots.is_empty());
    }

    // The text report should contain each major section heading.
    #[test]
    fn test_profile_formatting() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model(&model).expect("profile model should succeed");
        let formatted = format_model_profile(&profile);

        assert!(formatted.contains("MODEL PROFILE REPORT"));
        assert!(formatted.contains("Overall Statistics"));
        assert!(formatted.contains("Layer-by-Layer Breakdown"));
    }

    // JSON export should include the summary, layers, and hotspots keys.
    #[test]
    fn test_profile_export_json() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model(&model).expect("profile model should succeed");
        let json = export_profile_json(&profile).expect("export profile json should succeed");

        assert!(json.contains("total_parameters"));
        assert!(json.contains("layers"));
        assert!(json.contains("hotspots"));
    }

    // Runtime (simulated) profiling should populate execution times for at
    // least one layer.
    #[tokio::test]
    async fn test_runtime_profiling() {
        let model = create_real_model("test", 2, DeviceType::Cpu)
            .expect("create real model should succeed");
        let profile = profile_model_runtime(&model, 1, 5)
            .await
            .expect("operation should succeed");

        assert!(profile.layers.iter().any(|l| l.execution_time_ms.is_some()));
    }
}