torsh_functional/profiling/
core.rs1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8use torsh_core::{Result as TorshResult, TorshError};
9use torsh_tensor::Tensor;
10
/// Measurements captured for a single profiled operation.
///
/// Records are normally produced by `Profiler::finish_operation`, but can also be built by
/// hand starting from `OperationMetrics::new`.
#[derive(Debug, Clone)]
pub struct OperationMetrics {
    /// Operation name (e.g. `"matmul"`); summaries group records by this value.
    pub name: String,

    /// Wall-clock time the operation took.
    pub duration: Duration,

    /// Memory figure in bytes, when memory tracking is enabled.
    ///
    /// NOTE: `Profiler` stores the growth in current usage since the operation started here,
    /// not a true high-water mark.
    pub peak_memory: Option<usize>,

    /// Shapes of the input tensors, in the order they were supplied.
    pub input_shapes: Vec<Vec<usize>>,

    /// Shapes of the output tensors, in the order they were supplied.
    pub output_shapes: Vec<Vec<usize>>,

    /// Estimated floating-point operation count, when FLOP counting is enabled.
    pub flops: Option<u64>,

    /// Estimated memory traffic in bytes per second.
    pub memory_bandwidth: Option<f64>,

    /// CPU utilization; `Profiler` never fills this in (always `None` from it).
    pub cpu_utilization: Option<f32>,

    /// Arbitrary caller-defined measurements keyed by name.
    pub custom_metrics: HashMap<String, f64>,
}
33
34impl OperationMetrics {
35 pub fn new(name: String) -> Self {
37 Self {
38 name,
39 duration: Duration::default(),
40 peak_memory: None,
41 input_shapes: Vec::new(),
42 output_shapes: Vec::new(),
43 flops: None,
44 memory_bandwidth: None,
45 cpu_utilization: None,
46 custom_metrics: HashMap::new(),
47 }
48 }
49
50 pub fn add_metric(&mut self, key: String, value: f64) {
52 self.custom_metrics.insert(key, value);
53 }
54
55 pub fn throughput(&self) -> f64 {
57 if self.duration.as_secs_f64() > 0.0 {
58 1.0 / self.duration.as_secs_f64()
59 } else {
60 0.0
61 }
62 }
63
64 pub fn flops_per_second(&self) -> Option<f64> {
66 self.flops
67 .map(|flops| flops as f64 / self.duration.as_secs_f64())
68 }
69
70 pub fn memory_efficiency(&self, peak_bandwidth_gbps: f64) -> Option<f64> {
72 self.memory_bandwidth
73 .map(|bw| bw / (peak_bandwidth_gbps * 1e9))
74 }
75}
76
77pub struct Profiler {
79 pub metrics: Vec<OperationMetrics>,
81 current_session: Option<ProfilingSession>,
83 track_memory: bool,
85 count_flops: bool,
87 hooks: Vec<Box<dyn Fn(&OperationMetrics) + Send + Sync>>,
89}
90
/// State kept for the operation a `Profiler` is currently timing.
#[derive(Debug)]
struct ProfilingSession {
    /// Operation name given when timing started.
    name: String,
    /// Instant at which timing started.
    start_time: Instant,
    /// Shapes of the operation's inputs, captured when timing started.
    input_shapes: Vec<Vec<usize>>,
    /// Memory usage sampled at start; `Some` only if memory tracking was on at that point.
    initial_memory: Option<usize>,
}
98
99impl Default for Profiler {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl Profiler {
106 pub fn new() -> Self {
108 Self {
109 metrics: Vec::new(),
110 current_session: None,
111 track_memory: false,
112 count_flops: false,
113 hooks: Vec::new(),
114 }
115 }
116
117 pub fn enable_memory_tracking(&mut self) {
119 self.track_memory = true;
120 }
121
122 pub fn enable_flops_counting(&mut self) {
124 self.count_flops = true;
125 }
126
127 pub fn add_hook<F>(&mut self, hook: F)
129 where
130 F: Fn(&OperationMetrics) + Send + Sync + 'static,
131 {
132 self.hooks.push(Box::new(hook));
133 }
134
135 pub fn start_operation(&mut self, name: &str, inputs: &[&Tensor]) -> TorshResult<()> {
137 if self.current_session.is_some() {
138 return Err(TorshError::invalid_argument_with_context(
139 "Cannot start operation while another is in progress",
140 "Profiler::start_operation",
141 ));
142 }
143
144 let input_shapes: Vec<Vec<usize>> =
145 inputs.iter().map(|t| t.shape().dims().to_vec()).collect();
146
147 let initial_memory = if self.track_memory {
148 Some(get_current_memory_usage())
149 } else {
150 None
151 };
152
153 self.current_session = Some(ProfilingSession {
154 name: name.to_string(),
155 start_time: Instant::now(),
156 input_shapes,
157 initial_memory,
158 });
159
160 Ok(())
161 }
162
163 pub fn finish_operation(&mut self, outputs: &[&Tensor]) -> TorshResult<()> {
165 let session = self.current_session.take().ok_or_else(|| {
166 TorshError::invalid_argument_with_context(
167 "No operation in progress",
168 "Profiler::finish_operation",
169 )
170 })?;
171
172 let duration = session.start_time.elapsed();
173 let output_shapes: Vec<Vec<usize>> =
174 outputs.iter().map(|t| t.shape().dims().to_vec()).collect();
175
176 let peak_memory = if self.track_memory {
177 Some(get_current_memory_usage().saturating_sub(session.initial_memory.unwrap_or(0)))
178 } else {
179 None
180 };
181
182 let flops = if self.count_flops {
183 Some(estimate_flops(
184 &session.name,
185 &session.input_shapes,
186 &output_shapes,
187 ))
188 } else {
189 None
190 };
191
192 let memory_bandwidth =
193 calculate_memory_bandwidth(&session.input_shapes, &output_shapes, duration);
194
195 let metrics = OperationMetrics {
196 name: session.name,
197 duration,
198 peak_memory,
199 input_shapes: session.input_shapes,
200 output_shapes,
201 flops,
202 memory_bandwidth: Some(memory_bandwidth),
203 cpu_utilization: None, custom_metrics: HashMap::new(),
205 };
206
207 for hook in &self.hooks {
209 hook(&metrics);
210 }
211
212 self.metrics.push(metrics);
213 Ok(())
214 }
215
216 pub fn get_metrics(&self, operation_name: &str) -> Vec<&OperationMetrics> {
218 self.metrics
219 .iter()
220 .filter(|m| m.name == operation_name)
221 .collect()
222 }
223
224 pub fn get_summary(&self, operation_name: &str) -> Option<OperationSummary> {
226 let metrics: Vec<_> = self.get_metrics(operation_name);
227 if metrics.is_empty() {
228 return None;
229 }
230
231 let count = metrics.len();
232 let durations: Vec<f64> = metrics.iter().map(|m| m.duration.as_secs_f64()).collect();
233
234 let mean_duration = durations.iter().sum::<f64>() / count as f64;
235 let min_duration = durations.iter().fold(f64::INFINITY, |a, &b| a.min(b));
236 let max_duration = durations.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
237
238 let variance = durations
240 .iter()
241 .map(|d| (d - mean_duration).powi(2))
242 .sum::<f64>()
243 / count as f64;
244 let std_duration = variance.sqrt();
245
246 let total_flops: Option<u64> = metrics
247 .iter()
248 .try_fold(0u64, |acc, m| m.flops.map(|f| acc + f));
249
250 let mean_throughput = metrics.iter().map(|m| m.throughput()).sum::<f64>() / count as f64;
251
252 Some(OperationSummary {
253 operation_name: operation_name.to_string(),
254 count,
255 mean_duration,
256 std_duration,
257 min_duration,
258 max_duration,
259 total_flops,
260 mean_throughput,
261 })
262 }
263
264 pub fn clear(&mut self) {
266 self.metrics.clear();
267 }
268
269 pub fn export_csv(&self) -> String {
271 let mut csv = String::from(
272 "operation,duration_ms,peak_memory_mb,input_shapes,output_shapes,flops,throughput\n",
273 );
274
275 for metric in &self.metrics {
276 let input_shapes_str = format!("{:?}", metric.input_shapes);
277 let output_shapes_str = format!("{:?}", metric.output_shapes);
278 let peak_memory_mb = metric
279 .peak_memory
280 .map(|m| m as f64 / 1024.0 / 1024.0)
281 .unwrap_or(0.0);
282
283 csv.push_str(&format!(
284 "{},{},{},{},{},{},{}\n",
285 metric.name,
286 metric.duration.as_millis(),
287 peak_memory_mb,
288 input_shapes_str,
289 output_shapes_str,
290 metric.flops.unwrap_or(0),
291 metric.throughput()
292 ));
293 }
294
295 csv
296 }
297}
298
/// Aggregate statistics over every recorded run of one operation, as built by
/// `Profiler::get_summary`.
///
/// All durations are in seconds.
#[derive(Debug, Clone)]
pub struct OperationSummary {
    /// Name shared by every summarized run.
    pub operation_name: String,
    /// Number of runs summarized.
    pub count: usize,
    /// Arithmetic mean of the run durations.
    pub mean_duration: f64,
    /// Population standard deviation of the run durations.
    pub std_duration: f64,
    /// Shortest run duration.
    pub min_duration: f64,
    /// Longest run duration.
    pub max_duration: f64,
    /// Sum of the per-run FLOP estimates, or `None` if any run lacks one.
    pub total_flops: Option<u64>,
    /// Mean of the per-run throughputs (runs per second).
    pub mean_throughput: f64,
}
311
/// Returns the current memory usage, presumably meant to be in bytes (`Profiler::export_csv`
/// converts it to MB by dividing by 1024²).
///
/// Placeholder: no measurement is implemented, so this always returns `0`. Any memory figure
/// the profiler records is therefore zero until a real implementation lands.
pub fn get_current_memory_usage() -> usize {
    const NOT_MEASURED: usize = 0;
    NOT_MEASURED
}
318
/// Estimates the floating-point operation count of a profiled operation from its shapes.
///
/// Supported operations:
/// - `"matmul"` / `"bmm"`: `2 * m * k * n` per batch element (one multiply plus one add per
///   inner-product term).
///   - The first two input shapes are the operands.
///   - A 1-D left operand is treated as a single row `(1, k)`, and a 1-D right operand as a
///     single column `(k, 1)`.
///   - Leading batch dimensions of both operands are broadcast against each other.
/// - `"conv2d"`: `2 * output_elements * macs_per_output`.
///   - When the second input shape is 4-D, it is treated as the weight, and
///     `macs_per_output` is the product of its last three dimensions.
///   - Otherwise a single-channel 3×3 kernel is assumed (9 MACs per output element).
/// - `"add"`, `"sub"`, `"mul"`, `"div"`: one operation per element of the first output.
///
/// Unknown operations, and shapes that are missing or too small to interpret, yield `0`.
/// All arithmetic is done in `u64` and saturates at `u64::MAX` instead of overflowing.
pub fn estimate_flops(
    operation: &str,
    input_shapes: &[Vec<usize>],
    output_shapes: &[Vec<usize>],
) -> u64 {
    /// Number of elements in a tensor of shape `dims` (1 for a scalar), saturating.
    fn element_count(dims: &[usize]) -> u64 {
        dims.iter().fold(1u64, |acc, &d| acc.saturating_mul(d as u64))
    }

    /// Product of the batch dimensions obtained by broadcasting `a` against `b`.
    ///
    /// Dimensions are aligned from the right, missing leading dims count as 1, and a
    /// dimension of 1 stretches to match the other operand.
    fn broadcast_batch(a: &[usize], b: &[usize]) -> u64 {
        let dim_at = |s: &[usize], i: usize| s.iter().rev().nth(i).copied().unwrap_or(1);
        (0..a.len().max(b.len())).fold(1u64, |acc, i| {
            let (da, db) = (dim_at(a, i), dim_at(b, i));
            let d = if da == 1 { db } else { da };
            acc.saturating_mul(d as u64)
        })
    }

    match operation {
        "matmul" | "bmm" => match (input_shapes.first(), input_shapes.get(1)) {
            (Some(a), Some(b)) if !a.is_empty() && !b.is_empty() => {
                // Left operand: a 1-D vector (k,) behaves like a (1, k) matrix.
                let (m, k, a_batch) = if a.len() == 1 {
                    (1, a[0], &a[..0])
                } else {
                    let r = a.len() - 2;
                    (a[r], a[r + 1], &a[..r])
                };
                // Right operand: a 1-D vector (k,) behaves like a (k, 1) matrix.
                let (n, b_batch) = if b.len() == 1 {
                    (1, &b[..0])
                } else {
                    let l = b.len();
                    (b[l - 1], &b[..l - 2])
                };
                element_count(&[m, k, n])
                    .saturating_mul(2)
                    .saturating_mul(broadcast_batch(a_batch, b_batch))
            }
            _ => 0,
        },
        "conv2d" => match (input_shapes.first(), output_shapes.first()) {
            (Some(_), Some(output)) => {
                // Assumes a 4-D second input is the weight, laid out as
                // [out_channels, in_channels / groups, kh, kw] (PyTorch convention).
                // TODO: confirm the input ordering with callers.
                let macs_per_output = match input_shapes.get(1) {
                    Some(weight) if weight.len() == 4 => element_count(&weight[1..]),
                    // No usable weight shape: keep the historical 3x3 single-channel guess.
                    _ => 9,
                };
                element_count(output)
                    .saturating_mul(macs_per_output)
                    .saturating_mul(2)
            }
            _ => 0,
        },
        "add" | "sub" | "mul" | "div" => output_shapes
            .first()
            .map_or(0, |out| element_count(out)),
        _ => 0,
    }
}
355
/// Estimates memory traffic, in bytes per second, for an operation that took `duration`.
///
/// Every input and output element is counted as touched exactly once, at 4 bytes per element,
/// whatever the tensors' real element type.
///
/// Returns `0.0` for a zero duration, because no rate can be derived; this matches
/// `OperationMetrics::throughput`. Element and byte counts saturate instead of overflowing.
fn calculate_memory_bandwidth(
    input_shapes: &[Vec<usize>],
    output_shapes: &[Vec<usize>],
    duration: Duration,
) -> f64 {
    // Element size assumed for every tensor; the actual element type is not known here.
    const BYTES_PER_ELEMENT: usize = 4;

    let secs = duration.as_secs_f64();
    if secs <= 0.0 {
        return 0.0;
    }

    let total_elements = input_shapes
        .iter()
        .chain(output_shapes)
        .map(|shape| shape.iter().fold(1usize, |acc, &d| acc.saturating_mul(d)))
        .fold(0usize, usize::saturating_add);

    total_elements.saturating_mul(BYTES_PER_ELEMENT) as f64 / secs
}