Skip to main content

torsh_functional/profiling/
core.rs

1//! Core profiling types and functionality
2//!
3//! This module provides the basic profiling infrastructure including
4//! OperationMetrics and the main Profiler struct.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8use torsh_core::{Result as TorshResult, TorshError};
9use torsh_tensor::Tensor;
10
/// Performance metrics for a single profiled operation.
///
/// A freshly constructed value (see [`OperationMetrics::new`]) has a zero
/// duration, empty shape lists and every optional measurement set to `None`;
/// callers fill in whichever fields they actually measured.
#[derive(Debug, Clone)]
pub struct OperationMetrics {
    /// Name of the operation
    pub name: String,
    /// Execution time
    pub duration: Duration,
    /// Peak memory usage during operation (in bytes)
    pub peak_memory: Option<usize>,
    /// Input tensor shapes
    pub input_shapes: Vec<Vec<usize>>,
    /// Output tensor shapes
    pub output_shapes: Vec<Vec<usize>>,
    /// Number of floating-point operations (estimated)
    pub flops: Option<u64>,
    /// Memory bandwidth utilization (bytes/second)
    pub memory_bandwidth: Option<f64>,
    /// CPU utilization percentage
    pub cpu_utilization: Option<f32>,
    /// Additional custom metrics
    pub custom_metrics: HashMap<String, f64>,
}

impl OperationMetrics {
    /// Create new operation metrics with the given `name` and no measurements.
    pub fn new(name: String) -> Self {
        Self {
            name,
            duration: Duration::default(),
            peak_memory: None,
            input_shapes: Vec::new(),
            output_shapes: Vec::new(),
            flops: None,
            memory_bandwidth: None,
            cpu_utilization: None,
            custom_metrics: HashMap::new(),
        }
    }

    /// Add a custom metric, replacing any previous value stored under `key`.
    pub fn add_metric(&mut self, key: String, value: f64) {
        self.custom_metrics.insert(key, value);
    }

    /// Get throughput in operations per second (`1 / duration`).
    ///
    /// Returns `0.0` when the recorded duration is zero.
    pub fn throughput(&self) -> f64 {
        let secs = self.duration.as_secs_f64();
        if secs > 0.0 {
            1.0 / secs
        } else {
            0.0
        }
    }

    /// Get FLOPS (floating-point operations per second).
    ///
    /// Returns `None` when no FLOP count was recorded, or when the recorded
    /// duration is zero (a rate cannot be derived; dividing would produce
    /// `inf` or `NaN`).
    pub fn flops_per_second(&self) -> Option<f64> {
        let secs = self.duration.as_secs_f64();
        if secs > 0.0 {
            self.flops.map(|flops| flops as f64 / secs)
        } else {
            None
        }
    }

    /// Get memory efficiency (fraction of peak bandwidth utilized).
    ///
    /// `peak_bandwidth_gbps` is the peak bandwidth in GB/s (multiplied by
    /// `1e9` to compare against the bytes/second stored in
    /// `memory_bandwidth`).
    ///
    /// Returns `None` when no bandwidth was recorded, or when
    /// `peak_bandwidth_gbps` is not a positive number (zero, negative or
    /// `NaN`), since no meaningful ratio exists in that case.
    pub fn memory_efficiency(&self, peak_bandwidth_gbps: f64) -> Option<f64> {
        // `> 0.0` is false for NaN as well, so NaN peaks are rejected here too.
        if peak_bandwidth_gbps > 0.0 {
            self.memory_bandwidth
                .map(|bw| bw / (peak_bandwidth_gbps * 1e9))
        } else {
            None
        }
    }
}
76
77/// Performance profiler for tracking operation metrics
78pub struct Profiler {
79    /// Collected metrics
80    pub metrics: Vec<OperationMetrics>,
81    /// Currently active profiling session
82    current_session: Option<ProfilingSession>,
83    /// Enable detailed memory tracking
84    track_memory: bool,
85    /// Enable FLOPS counting
86    count_flops: bool,
87    /// Custom profiling hooks
88    hooks: Vec<Box<dyn Fn(&OperationMetrics) + Send + Sync>>,
89}
90
91#[derive(Debug)]
92struct ProfilingSession {
93    name: String,
94    start_time: Instant,
95    input_shapes: Vec<Vec<usize>>,
96    initial_memory: Option<usize>,
97}
98
99impl Default for Profiler {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl Profiler {
106    /// Create a new profiler
107    pub fn new() -> Self {
108        Self {
109            metrics: Vec::new(),
110            current_session: None,
111            track_memory: false,
112            count_flops: false,
113            hooks: Vec::new(),
114        }
115    }
116
117    /// Enable memory tracking
118    pub fn enable_memory_tracking(&mut self) {
119        self.track_memory = true;
120    }
121
122    /// Enable FLOPS counting
123    pub fn enable_flops_counting(&mut self) {
124        self.count_flops = true;
125    }
126
127    /// Add a profiling hook
128    pub fn add_hook<F>(&mut self, hook: F)
129    where
130        F: Fn(&OperationMetrics) + Send + Sync + 'static,
131    {
132        self.hooks.push(Box::new(hook));
133    }
134
135    /// Start profiling an operation
136    pub fn start_operation(&mut self, name: &str, inputs: &[&Tensor]) -> TorshResult<()> {
137        if self.current_session.is_some() {
138            return Err(TorshError::invalid_argument_with_context(
139                "Cannot start operation while another is in progress",
140                "Profiler::start_operation",
141            ));
142        }
143
144        let input_shapes: Vec<Vec<usize>> =
145            inputs.iter().map(|t| t.shape().dims().to_vec()).collect();
146
147        let initial_memory = if self.track_memory {
148            Some(get_current_memory_usage())
149        } else {
150            None
151        };
152
153        self.current_session = Some(ProfilingSession {
154            name: name.to_string(),
155            start_time: Instant::now(),
156            input_shapes,
157            initial_memory,
158        });
159
160        Ok(())
161    }
162
163    /// Finish profiling an operation
164    pub fn finish_operation(&mut self, outputs: &[&Tensor]) -> TorshResult<()> {
165        let session = self.current_session.take().ok_or_else(|| {
166            TorshError::invalid_argument_with_context(
167                "No operation in progress",
168                "Profiler::finish_operation",
169            )
170        })?;
171
172        let duration = session.start_time.elapsed();
173        let output_shapes: Vec<Vec<usize>> =
174            outputs.iter().map(|t| t.shape().dims().to_vec()).collect();
175
176        let peak_memory = if self.track_memory {
177            Some(get_current_memory_usage().saturating_sub(session.initial_memory.unwrap_or(0)))
178        } else {
179            None
180        };
181
182        let flops = if self.count_flops {
183            Some(estimate_flops(
184                &session.name,
185                &session.input_shapes,
186                &output_shapes,
187            ))
188        } else {
189            None
190        };
191
192        let memory_bandwidth =
193            calculate_memory_bandwidth(&session.input_shapes, &output_shapes, duration);
194
195        let metrics = OperationMetrics {
196            name: session.name,
197            duration,
198            peak_memory,
199            input_shapes: session.input_shapes,
200            output_shapes,
201            flops,
202            memory_bandwidth: Some(memory_bandwidth),
203            cpu_utilization: None, // TODO: Implement CPU utilization tracking
204            custom_metrics: HashMap::new(),
205        };
206
207        // Call hooks
208        for hook in &self.hooks {
209            hook(&metrics);
210        }
211
212        self.metrics.push(metrics);
213        Ok(())
214    }
215
216    /// Get metrics for a specific operation
217    pub fn get_metrics(&self, operation_name: &str) -> Vec<&OperationMetrics> {
218        self.metrics
219            .iter()
220            .filter(|m| m.name == operation_name)
221            .collect()
222    }
223
224    /// Get summary statistics for an operation
225    pub fn get_summary(&self, operation_name: &str) -> Option<OperationSummary> {
226        let metrics: Vec<_> = self.get_metrics(operation_name);
227        if metrics.is_empty() {
228            return None;
229        }
230
231        let count = metrics.len();
232        let durations: Vec<f64> = metrics.iter().map(|m| m.duration.as_secs_f64()).collect();
233
234        let mean_duration = durations.iter().sum::<f64>() / count as f64;
235        let min_duration = durations.iter().fold(f64::INFINITY, |a, &b| a.min(b));
236        let max_duration = durations.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
237
238        // Calculate standard deviation
239        let variance = durations
240            .iter()
241            .map(|d| (d - mean_duration).powi(2))
242            .sum::<f64>()
243            / count as f64;
244        let std_duration = variance.sqrt();
245
246        let total_flops: Option<u64> = metrics
247            .iter()
248            .try_fold(0u64, |acc, m| m.flops.map(|f| acc + f));
249
250        let mean_throughput = metrics.iter().map(|m| m.throughput()).sum::<f64>() / count as f64;
251
252        Some(OperationSummary {
253            operation_name: operation_name.to_string(),
254            count,
255            mean_duration,
256            std_duration,
257            min_duration,
258            max_duration,
259            total_flops,
260            mean_throughput,
261        })
262    }
263
264    /// Clear all collected metrics
265    pub fn clear(&mut self) {
266        self.metrics.clear();
267    }
268
269    /// Export metrics to CSV format
270    pub fn export_csv(&self) -> String {
271        let mut csv = String::from(
272            "operation,duration_ms,peak_memory_mb,input_shapes,output_shapes,flops,throughput\n",
273        );
274
275        for metric in &self.metrics {
276            let input_shapes_str = format!("{:?}", metric.input_shapes);
277            let output_shapes_str = format!("{:?}", metric.output_shapes);
278            let peak_memory_mb = metric
279                .peak_memory
280                .map(|m| m as f64 / 1024.0 / 1024.0)
281                .unwrap_or(0.0);
282
283            csv.push_str(&format!(
284                "{},{},{},{},{},{},{}\n",
285                metric.name,
286                metric.duration.as_millis(),
287                peak_memory_mb,
288                input_shapes_str,
289                output_shapes_str,
290                metric.flops.unwrap_or(0),
291                metric.throughput()
292            ));
293        }
294
295        csv
296    }
297}
298
/// Summary statistics for an operation
///
/// Aggregates every recorded run that shares one operation name; built by
/// `Profiler::get_summary`.
#[derive(Debug, Clone)]
pub struct OperationSummary {
    /// Name of the operation the statistics refer to
    pub operation_name: String,
    /// Number of recorded runs included in the summary
    pub count: usize,
    /// Mean run duration, in seconds
    pub mean_duration: f64,
    /// Population standard deviation of the run durations, in seconds
    pub std_duration: f64,
    /// Shortest run duration, in seconds
    pub min_duration: f64,
    /// Longest run duration, in seconds
    pub max_duration: f64,
    /// Sum of the per-run FLOP counts; `None` if any run has no FLOP count
    pub total_flops: Option<u64>,
    /// Mean of the per-run throughput values (operations per second)
    pub mean_throughput: f64,
}
311
312// Helper functions
/// Report the current memory usage.
///
/// This is a placeholder that always returns `0`, so any memory figure
/// derived from it is also `0`.
///
/// NOTE(review): callers in this file treat the value as a byte count;
/// confirm the unit when a real implementation is added.
pub fn get_current_memory_usage() -> usize {
    // Simplified memory usage tracking
    // In a real implementation, this would use platform-specific APIs
    0
}
318
/// Estimate the number of floating-point operations performed by `operation`.
///
/// The estimate is derived purely from tensor shapes:
///
/// * `"matmul"` / `"bmm"`: `2 * m * k * n * batch`, where `m` and `k` are the
///   last two dims of the first input, `n` is the last dim of the second
///   input, and `batch` is the product of the first input's leading dims
///   (`1` when there are none). Requires two inputs of rank >= 2.
/// * `"conv2d"`: `output_elements * 9 * 2`, i.e. a 3x3 kernel is assumed.
///   Requires at least one input and one output shape.
/// * `"add"` / `"sub"` / `"mul"` / `"div"`: one operation per element of the
///   first output.
///
/// Returns `0` for any other operation name, or when the shapes needed for
/// the estimate are missing. All arithmetic is done in `u64` and saturates at
/// `u64::MAX` instead of overflowing.
pub fn estimate_flops(
    operation: &str,
    input_shapes: &[Vec<usize>],
    output_shapes: &[Vec<usize>],
) -> u64 {
    // Product of `dims` (1 for an empty slice), saturating at `u64::MAX`.
    fn element_count(dims: &[usize]) -> u64 {
        dims.iter()
            .fold(1u64, |acc, &d| acc.saturating_mul(d as u64))
    }

    match operation {
        "matmul" | "bmm" => match input_shapes {
            [a_shape, b_shape, ..] if a_shape.len() >= 2 && b_shape.len() >= 2 => {
                // Leading dims of `a` are batch dims; the trailing two are (m, k).
                let (batch_dims, matrix_dims) = a_shape.split_at(a_shape.len() - 2);
                let m = matrix_dims[0] as u64;
                let k = matrix_dims[1] as u64;
                let n = b_shape[b_shape.len() - 1] as u64;
                2u64.saturating_mul(m)
                    .saturating_mul(k)
                    .saturating_mul(n)
                    .saturating_mul(element_count(batch_dims))
            }
            _ => 0,
        },
        "conv2d" => match output_shapes.first() {
            // Rough estimate: 2 operations per output element per filter
            // weight, assuming a 3x3 kernel.
            Some(out_shape) if !input_shapes.is_empty() => {
                element_count(out_shape).saturating_mul(9 * 2)
            }
            _ => 0,
        },
        "add" | "sub" | "mul" | "div" => output_shapes
            .first()
            .map_or(0, |out_shape| element_count(out_shape)),
        _ => 0,
    }
}
355
/// Estimate the memory bandwidth of an operation in bytes per second.
///
/// The traffic estimate is the total element count of all input and output
/// shapes multiplied by 4 bytes (every element is assumed to be an `f32`),
/// divided by `duration`.
///
/// Returns `0.0` when `duration` is zero, since no rate can be derived (a
/// plain division would yield `inf` or `NaN`). Element and byte counts are
/// accumulated in `u64` and saturate rather than overflow.
fn calculate_memory_bandwidth(
    input_shapes: &[Vec<usize>],
    output_shapes: &[Vec<usize>],
    duration: Duration,
) -> f64 {
    // Assume f32
    const BYTES_PER_ELEMENT: u64 = 4;

    let total_elements = input_shapes
        .iter()
        .chain(output_shapes)
        .map(|shape| {
            shape
                .iter()
                .fold(1u64, |acc, &d| acc.saturating_mul(d as u64))
        })
        .fold(0u64, u64::saturating_add);
    let total_bytes = total_elements.saturating_mul(BYTES_PER_ELEMENT);

    let secs = duration.as_secs_f64();
    if secs > 0.0 {
        total_bytes as f64 / secs
    } else {
        0.0
    }
}