Skip to main content

torsh_profiler/
attributes.rs

1//! Attribute-based profiling support
2//!
3//! This module provides function decorators and attribute-like functionality for automatic profiling.
//! Rust has no decorators like Python's, so this module provides similar functionality through wrapper types, traits, and macros.
5
6use crate::cpu::ProfileScope;
7use once_cell::sync::Lazy;
8use parking_lot::Mutex;
9use std::any::type_name;
10use std::time::Instant;
11
/// Trait for types that can run themselves under automatic profiling.
///
/// `Return` is the type produced by [`ProfiledMethod::profiled`]. `Args` does not appear in
/// the method signature; presumably it lets one type implement the trait for several
/// argument shapes — NOTE(review): confirm the intended use, nothing in this module
/// implements the trait.
pub trait ProfiledMethod<Args, Return> {
    /// Consume `self`, execute it with profiling, and return the result.
    ///
    /// `name` and `category` optionally label the profiling event; how `None` is handled is
    /// left to each implementation.
    fn profiled(self, name: Option<&str>, category: Option<&str>) -> Return;
}
17
18/// Wrapper for functions that enables automatic profiling
19pub struct ProfiledFunction<F> {
20    func: F,
21    name: String,
22    category: String,
23    enabled: bool,
24}
25
26impl<F> ProfiledFunction<F> {
27    /// Create a new profiled function wrapper
28    pub fn new(func: F, name: String, category: String) -> Self {
29        Self {
30            func,
31            name,
32            category,
33            enabled: true,
34        }
35    }
36
37    /// Enable/disable profiling for this function
38    pub fn set_enabled(&mut self, enabled: bool) {
39        self.enabled = enabled;
40    }
41
42    /// Check if profiling is enabled for this function
43    pub fn is_enabled(&self) -> bool {
44        self.enabled
45    }
46}
47
48impl<F, R> ProfiledFunction<F>
49where
50    F: FnOnce() -> R,
51{
52    /// Execute the function with profiling
53    pub fn call(self) -> R {
54        if self.enabled {
55            let _guard = ProfileScope::simple(self.name, self.category);
56            (self.func)()
57        } else {
58            (self.func)()
59        }
60    }
61}
62
63impl<F> ProfiledFunction<F> {
64    /// Execute the function with one argument and profiling
65    pub fn call_with_arg<A, R>(self, arg: A) -> R
66    where
67        F: FnOnce(A) -> R,
68    {
69        if self.enabled {
70            let _guard = ProfileScope::simple(self.name, self.category);
71            (self.func)(arg)
72        } else {
73            (self.func)(arg)
74        }
75    }
76
77    /// Execute the function with two arguments and profiling
78    pub fn call_with_args<A, B, R>(self, arg1: A, arg2: B) -> R
79    where
80        F: FnOnce(A, B) -> R,
81    {
82        if self.enabled {
83            let _guard = ProfileScope::simple(self.name, self.category);
84            (self.func)(arg1, arg2)
85        } else {
86            (self.func)(arg1, arg2)
87        }
88    }
89}
90
/// Attribute configuration for profiling.
///
/// Build one with [`ProfileAttribute::new`] (or [`Default::default`]) and the `with_*`
/// builder methods, then register it with an [`AttributeRegistry`].
#[derive(Debug, Clone)]
pub struct ProfileAttribute {
    /// Name of the profiling event; `None` leaves the choice to the consumer.
    pub name: Option<String>,
    /// Category of the profiling event (defaults to `"function"`).
    pub category: Option<String>,
    /// Whether to include stack traces.
    pub stack_trace: bool,
    /// Whether to track memory allocations.
    pub track_memory: bool,
    /// Whether to count FLOPS (for tensor operations).
    pub count_flops: bool,
    /// Custom key/value metadata to include.
    pub metadata: std::collections::HashMap<String, String>,
    /// Sampling rate (1 = profile every call, 10 = profile every 10th call).
    ///
    /// [`ProfileAttribute::with_sample_rate`] never stores 0, but because the field is public
    /// it can still be set to 0 directly.
    pub sample_rate: usize,
    /// Minimum duration threshold to record (in microseconds).
    pub min_duration_us: u64,
}

impl Default for ProfileAttribute {
    /// Unnamed, category `"function"`, all tracking off, no metadata, every call sampled,
    /// no duration threshold.
    fn default() -> Self {
        Self {
            name: None,
            category: Some("function".to_string()),
            stack_trace: false,
            track_memory: false,
            count_flops: false,
            metadata: std::collections::HashMap::new(),
            sample_rate: 1,
            min_duration_us: 0,
        }
    }
}

// The builder methods consume `self` and return the updated value, so dropping the result
// discards the configuration; `#[must_use]` turns that mistake into a warning.
impl ProfileAttribute {
    /// Create a new profile attribute with default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the profiling name.
    #[must_use]
    pub fn with_name<S: Into<String>>(mut self, name: S) -> Self {
        self.name = Some(name.into());
        self
    }

    /// Set the profiling category.
    #[must_use]
    pub fn with_category<S: Into<String>>(mut self, category: S) -> Self {
        self.category = Some(category.into());
        self
    }

    /// Enable stack trace collection.
    #[must_use]
    pub fn with_stack_trace(mut self) -> Self {
        self.stack_trace = true;
        self
    }

    /// Enable memory tracking.
    #[must_use]
    pub fn with_memory_tracking(mut self) -> Self {
        self.track_memory = true;
        self
    }

    /// Enable FLOPS counting.
    #[must_use]
    pub fn with_flops_counting(mut self) -> Self {
        self.count_flops = true;
        self
    }

    /// Add (or overwrite) one custom metadata entry.
    #[must_use]
    pub fn with_metadata<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Set the sampling rate; a rate of 0 is stored as 1 (profile every call).
    #[must_use]
    pub fn with_sample_rate(mut self, rate: usize) -> Self {
        self.sample_rate = rate.max(1);
        self
    }

    /// Set the minimum duration threshold, in microseconds.
    #[must_use]
    pub fn with_min_duration_us(mut self, min_us: u64) -> Self {
        self.min_duration_us = min_us;
        self
    }
}
181
182/// Function attribute registry for managing profiling attributes
183pub struct AttributeRegistry {
184    attributes: std::collections::HashMap<String, ProfileAttribute>,
185    global_enabled: bool,
186}
187
188impl Default for AttributeRegistry {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194impl AttributeRegistry {
195    /// Create a new attribute registry
196    pub fn new() -> Self {
197        Self {
198            attributes: std::collections::HashMap::new(),
199            global_enabled: true,
200        }
201    }
202
203    /// Register a function with profiling attributes
204    pub fn register<S: Into<String>>(&mut self, function_name: S, attr: ProfileAttribute) {
205        self.attributes.insert(function_name.into(), attr);
206    }
207
208    /// Get profiling attributes for a function
209    pub fn get_attributes(&self, function_name: &str) -> Option<&ProfileAttribute> {
210        self.attributes.get(function_name)
211    }
212
213    /// Enable/disable all profiling
214    pub fn set_enabled(&mut self, enabled: bool) {
215        self.global_enabled = enabled;
216    }
217
218    /// Check if profiling is globally enabled
219    pub fn is_enabled(&self) -> bool {
220        self.global_enabled
221    }
222
223    /// Check if a specific function should be profiled
224    pub fn should_profile(&self, function_name: &str, call_count: usize) -> bool {
225        if !self.global_enabled {
226            return false;
227        }
228
229        if let Some(attr) = self.attributes.get(function_name) {
230            call_count % attr.sample_rate == 0
231        } else {
232            false
233        }
234    }
235}
236
237/// Global attribute registry
238static mut GLOBAL_REGISTRY: Option<AttributeRegistry> = None;
239static REGISTRY_INIT: std::sync::Once = std::sync::Once::new();
240
241/// Get the global attribute registry
242pub fn get_registry() -> &'static mut AttributeRegistry {
243    unsafe {
244        REGISTRY_INIT.call_once(|| {
245            GLOBAL_REGISTRY = Some(AttributeRegistry::new());
246        });
247        GLOBAL_REGISTRY
248            .as_mut()
249            .expect("GLOBAL_REGISTRY should be initialized by call_once")
250    }
251}
252
253/// Wrapper function that applies profiling attributes to any function
254pub fn with_profiling<F, R>(function_name: &str, func: F) -> R
255where
256    F: FnOnce() -> R,
257{
258    let registry = get_registry();
259
260    // Check if we should profile this call
261    static CALL_COUNTS: Lazy<Mutex<std::collections::HashMap<String, usize>>> =
262        Lazy::new(|| Mutex::new(std::collections::HashMap::new()));
263    let call_count = {
264        let mut counts = CALL_COUNTS.lock();
265        let count = counts.entry(function_name.to_string()).or_insert(0);
266        *count += 1;
267        *count
268    };
269
270    if !registry.should_profile(function_name, call_count) {
271        return func();
272    }
273
274    let attr = registry.get_attributes(function_name);
275
276    // Determine profiling name and category
277    let profile_name = attr
278        .and_then(|a| a.name.as_ref())
279        .cloned()
280        .unwrap_or_else(|| function_name.to_string());
281
282    let profile_category = attr
283        .and_then(|a| a.category.as_ref())
284        .cloned()
285        .unwrap_or_else(|| "function".to_string());
286
287    let start_time = Instant::now();
288
289    // Set up profiling scope
290    let _guard = ProfileScope::simple(profile_name.clone(), profile_category.clone());
291
292    // Execute the function
293    let result = func();
294
295    let duration = start_time.elapsed();
296    let duration_us = duration.as_micros() as u64;
297
298    // Check minimum duration threshold
299    if let Some(attr) = attr {
300        if duration_us < attr.min_duration_us {
301            return result;
302        }
303    }
304
305    result
306}
307
/// Build a `ProfiledFunction` from a name, an optional category, and a callable.
///
/// `profiled_fn!(name, func)` uses the category `"function"`;
/// `profiled_fn!(name, category, func)` uses the given one. `name` and `category` are
/// converted with `to_string()`.
#[macro_export]
macro_rules! profiled_fn {
    // Two-argument form: forward to the three-argument form with the default category.
    ($name:expr, $func:expr) => {
        $crate::profiled_fn!($name, "function", $func)
    };
    ($name:expr, $category:expr, $func:expr) => {
        $crate::attributes::ProfiledFunction::new($func, $name.to_string(), $category.to_string())
    };
}
318
/// Attribute-like macro that opens a profiling scope for the rest of the enclosing block.
///
/// Invoke it as a statement; the scope guard is bound in the caller's block and is dropped
/// when that block ends. Supported forms:
///
/// - `profile_attribute!(#[profile]);`: name `"<module path>::<function_name!()>"`,
///   category `"function"`.
/// - `profile_attribute!(#[profile(name = N)]);`: name `N`, category `"function"`.
/// - `profile_attribute!(#[profile(name = N, category = C)]);`
/// - `profile_attribute!(#[profile(sample_rate = R)]);`: profiles only every `R`-th
///   execution of this invocation site (the 1st, the `R+1`-th, ...), with category
///   `"sampled_function"`. `R` must be a `usize`; `0` is treated as `1`.
///
/// NOTE(review): the `#[profile]` and `sample_rate` forms expand to `function_name!()`, which
/// is not defined in this file. The calling code must have such a macro in scope; confirm
/// which crate is expected to provide it.
#[macro_export]
macro_rules! profile_attribute {
    // Basic profiling
    (#[profile]) => {
        let _attr_guard = $crate::cpu::ProfileScope::simple(
            format!("{}::{}", module_path!(), function_name!()),
            "function".to_string(),
        );
    };

    // Profiling with custom name
    (#[profile(name = $name:expr)]) => {
        let _attr_guard =
            $crate::cpu::ProfileScope::simple($name.to_string(), "function".to_string());
    };

    // Profiling with custom name and category
    (#[profile(name = $name:expr, category = $category:expr)]) => {
        let _attr_guard =
            $crate::cpu::ProfileScope::simple($name.to_string(), $category.to_string());
    };

    // Profiling with sampling. The counter lives in an inner block so that several
    // invocations in one function get separate `CALL_COUNT` statics, and fully qualified
    // paths keep the macro from bringing `AtomicUsize`/`Ordering` into the caller's scope.
    (#[profile(sample_rate = $rate:expr)]) => {
        let _attr_guard = {
            static CALL_COUNT: ::std::sync::atomic::AtomicUsize =
                ::std::sync::atomic::AtomicUsize::new(0);

            // Same clamping as `ProfileAttribute::with_sample_rate`: 0 means "every call".
            let rate: usize = $rate;
            let call_num = CALL_COUNT.fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);
            if call_num % rate.max(1) == 0 {
                Some($crate::cpu::ProfileScope::simple(
                    format!("{}::{}", module_path!(), function_name!()),
                    "sampled_function".to_string(),
                ))
            } else {
                None
            }
        };
    };
}
358
359/// Method profiling wrapper for structs
360pub trait ProfiledStruct {
361    /// Execute a method with profiling
362    fn profiled_method<F, R>(&self, method_name: &str, func: F) -> R
363    where
364        F: FnOnce(&Self) -> R,
365    {
366        let type_name = type_name::<Self>();
367        let full_name = format!("{type_name}::{method_name}");
368
369        let _guard = ProfileScope::simple(full_name, "method".to_string());
370        func(self)
371    }
372
373    /// Execute a mutable method with profiling
374    fn profiled_method_mut<F, R>(&mut self, method_name: &str, func: F) -> R
375    where
376        F: FnOnce(&mut Self) -> R,
377    {
378        let type_name = type_name::<Self>();
379        let full_name = format!("{type_name}::{method_name}");
380
381        let _guard = ProfileScope::simple(full_name, "method".to_string());
382        func(self)
383    }
384}
385
386/// Blanket implementation for all types
387impl<T> ProfiledStruct for T {}
388
389/// Conditional profiling based on feature flags or runtime conditions
390pub struct ConditionalProfiler {
391    condition: Box<dyn Fn() -> bool + Send + Sync>,
392    fallback_enabled: bool,
393}
394
395impl ConditionalProfiler {
396    /// Create a new conditional profiler
397    pub fn new<F>(condition: F) -> Self
398    where
399        F: Fn() -> bool + Send + Sync + 'static,
400    {
401        Self {
402            condition: Box::new(condition),
403            fallback_enabled: true,
404        }
405    }
406
407    /// Create a conditional profiler that only profiles in debug mode
408    pub fn debug_only() -> Self {
409        Self::new(|| cfg!(debug_assertions))
410    }
411
412    /// Create a conditional profiler based on an environment variable
413    pub fn env_var(var_name: &str) -> Self {
414        let var_name = var_name.to_string();
415        Self::new(move || {
416            std::env::var(&var_name)
417                .map(|v| v == "1" || v.to_lowercase() == "true")
418                .unwrap_or(false)
419        })
420    }
421
422    /// Create a conditional profiler based on a feature flag
423    pub fn feature_flag(feature: &str) -> Self {
424        let enabled = feature == "profiling";
425        Self::new(move || enabled)
426    }
427
428    /// Execute a function with conditional profiling
429    pub fn profile<F, R>(&self, name: &str, category: &str, func: F) -> R
430    where
431        F: FnOnce() -> R,
432    {
433        if (self.condition)() {
434            let _guard = ProfileScope::simple(name.to_string(), category.to_string());
435            func()
436        } else {
437            func()
438        }
439    }
440}
441
442/// Helper for async function profiling
443pub struct AsyncProfiler;
444
445impl AsyncProfiler {
446    /// Profile an async function
447    pub async fn profile<F, Fut, R>(name: &str, category: &str, func: F) -> R
448    where
449        F: FnOnce() -> Fut,
450        Fut: std::future::Future<Output = R>,
451    {
452        let _guard = ProfileScope::simple(name.to_string(), category.to_string());
453        func().await
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_profile_attribute_creation() {
        let attr = ProfileAttribute::new()
            .with_name("test_function")
            .with_category("test")
            .with_stack_trace()
            .with_memory_tracking()
            .with_sample_rate(5)
            .with_min_duration_us(1000);

        assert_eq!(attr.name, Some("test_function".to_string()));
        assert_eq!(attr.category, Some("test".to_string()));
        assert!(attr.stack_trace);
        assert!(attr.track_memory);
        assert_eq!(attr.sample_rate, 5);
        assert_eq!(attr.min_duration_us, 1000);
    }

    #[test]
    fn test_profile_attribute_defaults_and_rate_clamp() {
        let attr = ProfileAttribute::default();
        assert_eq!(attr.name, None);
        assert_eq!(attr.category.as_deref(), Some("function"));
        assert!(!attr.stack_trace);
        assert!(!attr.track_memory);
        assert!(!attr.count_flops);
        assert!(attr.metadata.is_empty());
        assert_eq!(attr.sample_rate, 1);
        assert_eq!(attr.min_duration_us, 0);

        let attr = ProfileAttribute::new()
            .with_flops_counting()
            .with_metadata("dtype", "f32")
            .with_sample_rate(0);
        assert!(attr.count_flops);
        assert_eq!(attr.metadata.get("dtype").map(String::as_str), Some("f32"));
        // A rate of 0 would make sampling divide by zero, so the builder stores 1.
        assert_eq!(attr.sample_rate, 1);
    }

    #[test]
    fn test_attribute_registry() {
        let mut registry = AttributeRegistry::new();

        let attr = ProfileAttribute::new()
            .with_name("test_func")
            .with_category("test");

        registry.register("my_function", attr);

        let retrieved = registry.get_attributes("my_function");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, Some("test_func".to_string()));
    }

    #[test]
    fn test_registry_unregistered_and_disabled() {
        let mut registry = AttributeRegistry::default();
        assert!(registry.is_enabled());
        // Functions without registered attributes are never profiled.
        assert!(!registry.should_profile("unregistered", 1));

        registry.register("f", ProfileAttribute::new());
        assert!(registry.should_profile("f", 1));

        registry.set_enabled(false);
        assert!(!registry.is_enabled());
        assert!(!registry.should_profile("f", 1));
    }

    #[test]
    fn test_sampling() {
        let mut registry = AttributeRegistry::new();

        let attr = ProfileAttribute::new().with_sample_rate(3);
        registry.register("sampled_func", attr);

        // Should profile on calls 3, 6, 9, etc.
        assert!(!registry.should_profile("sampled_func", 1));
        assert!(!registry.should_profile("sampled_func", 2));
        assert!(registry.should_profile("sampled_func", 3));
        assert!(!registry.should_profile("sampled_func", 4));
        assert!(!registry.should_profile("sampled_func", 5));
        assert!(registry.should_profile("sampled_func", 6));
    }

    #[test]
    fn test_profiled_function() {
        let func = || {
            std::thread::sleep(Duration::from_millis(1));
            42
        };

        let profiled = ProfiledFunction::new(func, "test_func".to_string(), "test".to_string());
        let result = profiled.call();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_profiled_function_disabled_and_args() {
        let mut doubled = ProfiledFunction::new(
            |x: i32| x * 2,
            "double".to_string(),
            "test".to_string(),
        );
        doubled.set_enabled(false);
        assert!(!doubled.is_enabled());
        assert_eq!(doubled.call_with_arg(21), 42);

        let add = ProfiledFunction::new(
            |a: i32, b: i32| a + b,
            "add".to_string(),
            "test".to_string(),
        );
        assert!(add.is_enabled());
        assert_eq!(add.call_with_args(2, 3), 5);
    }

    #[test]
    fn test_with_profiling() {
        let result = with_profiling("test_function", || {
            std::thread::sleep(Duration::from_millis(1));
            "success"
        });
        assert_eq!(result, "success");
    }

    #[test]
    fn test_profiled_struct() {
        struct TestStruct {
            value: i32,
        }

        let mut test_struct = TestStruct { value: 42 };

        let result = test_struct.profiled_method("get_value", |s| s.value);
        assert_eq!(result, 42);

        test_struct.profiled_method_mut("set_value", |s| {
            s.value = 100;
        });
        assert_eq!(test_struct.value, 100);
    }

    #[test]
    fn test_conditional_profiler() {
        let profiler = ConditionalProfiler::new(|| true);

        let result = profiler.profile("test_op", "test", || {
            std::thread::sleep(Duration::from_millis(1));
            "conditional_result"
        });
        assert_eq!(result, "conditional_result");

        // Test with false condition
        let profiler = ConditionalProfiler::new(|| false);
        let result = profiler.profile("test_op", "test", || {
            std::thread::sleep(Duration::from_millis(1));
            "not_profiled"
        });
        assert_eq!(result, "not_profiled");
    }

    #[test]
    fn test_conditional_profiler_constructors() {
        assert!((ConditionalProfiler::feature_flag("profiling").condition)());
        assert!(!(ConditionalProfiler::feature_flag("other").condition)());
        // A variable that is never set must leave profiling off.
        assert!(!(ConditionalProfiler::env_var("TORSH_PROFILER_TEST_UNSET_VARIABLE").condition)());
    }

    #[test]
    fn test_debug_only_profiler() {
        let profiler = ConditionalProfiler::debug_only();

        let result = profiler.profile("debug_op", "debug", || "debug_result");
        assert_eq!(result, "debug_result");
    }

    #[tokio::test]
    async fn test_async_profiler() {
        let result = AsyncProfiler::profile("async_test", "async", || async {
            tokio::time::sleep(Duration::from_millis(1)).await;
            "async_success"
        })
        .await;

        assert_eq!(result, "async_success");
    }

    #[test]
    fn test_profiled_fn_macro() {
        let func = || {
            std::thread::sleep(Duration::from_millis(1));
            "macro_result"
        };

        let profiled = profiled_fn!("macro_test", func);
        let result = profiled.call();
        assert_eq!(result, "macro_result");
    }
}