sklears_utils/
debug.rs

1//! Debug utilities for development and troubleshooting
2//!
3//! This module provides debugging helper functions, assertion macros,
4//! test data generation, and diagnostic utilities for ML development.
5
6use crate::{UtilsError, UtilsResult};
7use std::collections::HashMap;
8use std::fmt;
9use std::hash::Hash;
10use std::time::{Duration, Instant};
11
/// Debug context for collecting debugging information
///
/// Records the location a diagnostic refers to (module, function, line), the
/// instant at which the context was created, and free-form key/value metadata.
#[derive(Debug, Clone)]
pub struct DebugContext {
    /// Module path the context refers to.
    pub module: String,
    /// Function name the context refers to.
    pub function: String,
    /// Source line the context refers to.
    pub line: u32,
    /// Instant captured when the context was created by [`DebugContext::new`].
    pub timestamp: Instant,
    /// Key/value annotations added through [`DebugContext::with_metadata`].
    pub metadata: HashMap<String, String>,
}

impl DebugContext {
    /// Create a new debug context
    ///
    /// The timestamp is taken with `Instant::now()` and the metadata map
    /// starts out empty.
    pub fn new(module: &str, function: &str, line: u32) -> Self {
        Self {
            module: module.to_string(),
            function: function.to_string(),
            line,
            timestamp: Instant::now(),
            metadata: HashMap::new(),
        }
    }

    /// Add metadata to the debug context
    ///
    /// Consumes and returns `self` so calls can be chained. Inserting a key
    /// that is already present replaces its previous value.
    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
        self.metadata.insert(key.to_string(), value.to_string());
        self
    }

    /// Format the debug context as a string
    ///
    /// Produces `module::function:line`, followed by ` {k1=v1, k2=v2}` when
    /// metadata is present. Entries are listed in ascending key order so the
    /// output is stable across runs (`HashMap` iteration order is not).
    pub fn format(&self) -> String {
        let mut result = format!("{}::{}:{}", self.module, self.function, self.line);
        if !self.metadata.is_empty() {
            // Keys are unique, so ordering the pairs orders them by key.
            let mut entries: Vec<(&String, &String)> = self.metadata.iter().collect();
            entries.sort_unstable();

            let rendered: Vec<String> = entries
                .iter()
                .map(|(k, v)| format!("{k}={v}"))
                .collect();

            result.push_str(" {");
            result.push_str(&rendered.join(", "));
            result.push('}');
        }
        result
    }
}
56
/// Macro for creating debug contexts easily
///
/// * `debug_context!()` expands to
///   `DebugContext::new(module_path!(), function_name!(), line!())`.
/// * `debug_context!(key => value, ...)` additionally chains one
///   `with_metadata(key, &value.to_string())` call per pair, in the order
///   written. A trailing comma after the last pair is accepted.
///
/// The expansion names `DebugContext` and `function_name!` without a path, so
/// both must be resolvable at the call site.
// NOTE(review): `function_name!` is not defined in this file — confirm which
// crate or module is expected to provide it.
#[macro_export]
macro_rules! debug_context {
    () => {
        DebugContext::new(module_path!(), function_name!(), line!())
    };
    ($($key:expr => $value:expr),+ $(,)?) => {
        DebugContext::new(module_path!(), function_name!(), line!())
            $(.with_metadata($key, &$value.to_string()))+
    };
}
68
69/// Enhanced assertion macro with debugging information
70#[macro_export]
71macro_rules! debug_assert_msg {
72    ($condition:expr, $msg:expr) => {
73        if !$condition {
74            panic!(
75                "Assertion failed at {}:{}:{}: {}\nCondition: {}",
76                module_path!(),
77                file!(),
78                line!(),
79                $msg,
80                stringify!($condition)
81            );
82        }
83    };
84    ($condition:expr, $msg:expr, $($arg:tt)*) => {
85        if !$condition {
86            panic!(
87                "Assertion failed at {}:{}:{}: {}\nCondition: {}",
88                module_path!(),
89                file!(),
90                line!(),
91                format!($msg, $($arg)*),
92                stringify!($condition)
93            );
94        }
95    };
96}
97
/// Array debugging utilities
///
/// Stateless helpers that turn slices into human-readable diagnostics.
pub struct ArrayDebugger;

impl ArrayDebugger {
    /// Debug array statistics
    ///
    /// Returns `Array[len=<n>, type=<T>]`, where the type name comes from
    /// `std::any::type_name::<T>()`.
    pub fn array_stats<T: fmt::Debug + Clone>(array: &[T]) -> String {
        format!(
            "Array[len={}, type={}]",
            array.len(),
            std::any::type_name::<T>()
        )
    }

    /// Debug array shape for multidimensional arrays
    ///
    /// Returns the shape in `Debug` form together with the product of all
    /// dimensions (an empty shape yields a product of 1).
    pub fn array_shape_info(shape: &[usize]) -> String {
        format!(
            "Shape: {:?}, Total elements: {}",
            shape,
            shape.iter().product::<usize>()
        )
    }

    /// Find array differences for debugging
    ///
    /// Returns one message per difference, always reporting `a` before `b`:
    /// first a `Length mismatch: <a.len()> vs <b.len()>` entry if the lengths
    /// differ, then an `Element <i> differs` entry for every index in the
    /// common prefix whose elements are not equal. An empty vector means the
    /// slices are identical.
    pub fn compare_arrays<T: PartialEq + fmt::Debug>(a: &[T], b: &[T]) -> Vec<String> {
        let mut differences = Vec::new();

        if a.len() != b.len() {
            // Same `a` vs `b` order as the per-element messages below.
            differences.push(format!("Length mismatch: {} vs {}", a.len(), b.len()));
        }

        // `zip` stops at the shorter slice, i.e. it walks the common prefix.
        differences.extend(
            a.iter()
                .zip(b)
                .enumerate()
                .filter(|(_, (left, right))| left != right)
                .map(|(i, (left, right))| {
                    format!("Element {} differs: {:?} vs {:?}", i, left, right)
                }),
        );

        differences
    }

    /// Check for NaN and infinite values in float arrays
    ///
    /// Returns one message per offending element, in index order; finite
    /// values produce no entry.
    pub fn check_float_array(array: &[f64]) -> Vec<String> {
        array
            .iter()
            .enumerate()
            .filter_map(|(i, &value)| {
                if value.is_nan() {
                    Some(format!("NaN at index {i}"))
                } else if value.is_infinite() {
                    Some(format!("Infinite value at index {i}: {value}"))
                } else {
                    None
                }
            })
            .collect()
    }
}
153
/// Memory debugging utilities
///
/// Every method here is a placeholder that returns a fixed notice; no memory
/// is actually inspected.
pub struct MemoryDebugger;

impl MemoryDebugger {
    /// Get memory usage information (simplified version)
    ///
    /// Real figures would require system calls or an external crate, so a
    /// fixed notice is returned instead.
    pub fn memory_info() -> String {
        String::from("Memory debugging not fully implemented in this version")
    }

    /// Debug heap allocations (placeholder)
    ///
    /// Always returns the same notice.
    pub fn heap_info() -> String {
        String::from("Heap debugging requires external dependencies")
    }

    /// Check for potential memory leaks (placeholder)
    ///
    /// Always returns a single-element list containing a fixed notice.
    pub fn leak_check() -> Vec<String> {
        let notice = String::from("Memory leak detection requires runtime support");
        vec![notice]
    }
}
175
/// Performance debugging utilities
///
/// Keeps at most one running timer per name and a history of completed
/// durations per name.
#[derive(Debug, Clone)]
pub struct PerformanceDebugger {
    /// Start instants of timers that have been started but not yet stopped.
    timers: HashMap<String, Instant>,
    /// Completed measurements, grouped by timer name in recording order.
    durations: HashMap<String, Vec<Duration>>,
}

impl Default for PerformanceDebugger {
    fn default() -> Self {
        Self::new()
    }
}

impl PerformanceDebugger {
    /// Create a new performance debugger
    ///
    /// The new instance has no running timers and no recorded durations.
    pub fn new() -> Self {
        Self {
            timers: HashMap::new(),
            durations: HashMap::new(),
        }
    }

    /// Start timing an operation
    ///
    /// Starting a name that is already running replaces its start instant.
    pub fn start_timer(&mut self, name: &str) {
        self.timers.insert(name.to_string(), Instant::now());
    }

    /// Stop timing and record duration
    ///
    /// Returns the elapsed time and appends it to the history for `name`, or
    /// returns `None` (recording nothing) when no timer with that name is
    /// running.
    pub fn stop_timer(&mut self, name: &str) -> Option<Duration> {
        let started_at = self.timers.remove(name)?;
        let elapsed = started_at.elapsed();
        self.durations
            .entry(name.to_string())
            .or_default()
            .push(elapsed);
        Some(elapsed)
    }

    /// Get timing statistics
    ///
    /// Returns `None` when nothing has been recorded under `name`.
    pub fn timing_stats(&self, name: &str) -> Option<TimingStats> {
        let samples = self.durations.get(name)?;
        // `min`/`max` are `None` only for an empty history; treat that like
        // "nothing recorded" instead of panicking.
        let min = *samples.iter().min()?;
        let max = *samples.iter().max()?;

        let count = samples.len();
        let total: Duration = samples.iter().sum();
        // `Duration` divides by `u32`; clamp first so a huge count can never
        // wrap around to a zero divisor.
        let divisor = count.min(u32::MAX as usize) as u32;

        Some(TimingStats {
            name: name.to_string(),
            count,
            total,
            average: total / divisor,
            min,
            max,
        })
    }

    /// Get all timing statistics
    ///
    /// Entries are ordered by timer name so the result (and the report built
    /// from it) is stable across runs; `HashMap` key order is not.
    pub fn all_stats(&self) -> Vec<TimingStats> {
        let mut names: Vec<&String> = self.durations.keys().collect();
        names.sort_unstable();

        names
            .into_iter()
            .filter_map(|name| self.timing_stats(name))
            .collect()
    }

    /// Format timing report
    ///
    /// Emits a `Performance Report:` header followed by one line per timer
    /// name, in the order produced by [`PerformanceDebugger::all_stats`].
    pub fn timing_report(&self) -> String {
        let mut report = String::from("Performance Report:\n");

        for stats in self.all_stats() {
            report.push_str(&format!(
                "  {}: {} calls, avg: {:?}, min: {:?}, max: {:?}, total: {:?}\n",
                stats.name, stats.count, stats.average, stats.min, stats.max, stats.total
            ));
        }

        report
    }
}

/// Timing statistics for performance analysis
///
/// Summary of all measurements recorded under a single timer name.
#[derive(Debug, Clone)]
pub struct TimingStats {
    /// Timer name the statistics belong to.
    pub name: String,
    /// Number of recorded measurements.
    pub count: usize,
    /// Sum of all recorded durations.
    pub total: Duration,
    /// `total` divided by `count`.
    pub average: Duration,
    /// Shortest recorded duration.
    pub min: Duration,
    /// Longest recorded duration.
    pub max: Duration,
}
270
/// Test data generation utilities
///
/// All generators are deterministic: the "random" values are derived by
/// hashing the element index with `DefaultHasher`, so repeated calls with the
/// same arguments return the same data.
pub struct TestDataGenerator;

impl TestDataGenerator {
    /// Generate random integers in a range
    ///
    /// Returns `count` values in the inclusive range spanned by `min` and
    /// `max`. If `min > max` the two bounds are swapped. The arithmetic is
    /// done in `i64`, so even the full `i32::MIN..=i32::MAX` range is valid.
    pub fn random_integers(count: usize, min: i32, max: i32) -> Vec<i32> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let (low, high) = if min <= max { (min, max) } else { (max, min) };
        // At least 1 and at most 2^32, so it is never zero and fits in u64.
        let span = (i64::from(high) - i64::from(low) + 1) as u64;

        (0..count)
            .map(|i| {
                let mut hasher = DefaultHasher::new();
                (i as u64 + 12345).hash(&mut hasher); // Add seed for better distribution
                // offset < span <= 2^32, so the i64 conversion is lossless.
                let offset = (hasher.finish() % span) as i64;
                // low + offset lies in [low, high], hence fits in i32.
                (i64::from(low) + offset) as i32
            })
            .collect()
    }

    /// Generate random floats in a range
    ///
    /// Each value is `min + t * (max - min)`, where `t` is the index hash
    /// scaled into `[0, 1]`.
    pub fn random_floats(count: usize, min: f64, max: f64) -> Vec<f64> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        (0..count)
            .map(|i| {
                let mut hasher = DefaultHasher::new();
                (i as u64 + 54321).hash(&mut hasher); // Add different seed
                let normalized = (hasher.finish() as f64) / (u64::MAX as f64);
                min + normalized * (max - min)
            })
            .collect()
    }

    /// Generate test matrix data
    ///
    /// Returns `rows` rows of `cols` values; the element at `(i, j)` is its
    /// row-major index divided by `rows * cols`, so values grow from 0 towards
    /// (but never reach) 1.
    pub fn test_matrix(rows: usize, cols: usize) -> Vec<Vec<f64>> {
        // Only used as a divisor when at least one element exists, so it is
        // never zero at the point of division.
        let total = (rows * cols) as f64;

        (0..rows)
            .map(|i| {
                (0..cols)
                    .map(|j| (i * cols + j) as f64 / total)
                    .collect()
            })
            .collect()
    }

    /// Generate test strings
    ///
    /// Returns `count` strings of the form `<prefix>_<index>`, cycling through
    /// a fixed list of five prefixes.
    pub fn test_strings(count: usize) -> Vec<String> {
        let prefixes = ["test", "data", "sample", "debug", "item"];

        (0..count)
            .map(|i| {
                let prefix = prefixes[i % prefixes.len()];
                format!("{prefix}_{i}")
            })
            .collect()
    }

    /// Generate pathological test cases
    ///
    /// Returns named float vectors covering edge cases: empty and
    /// single-element input, constant data, sign changes, very large and very
    /// small magnitudes, and special values (infinities, NaN, `f64::MIN`,
    /// `f64::MAX`, `f64::EPSILON`).
    pub fn pathological_cases() -> HashMap<String, Vec<f64>> {
        let mut cases = HashMap::new();

        cases.insert("empty".to_string(), vec![]);
        cases.insert("single".to_string(), vec![1.0]);
        cases.insert("zeros".to_string(), vec![0.0; 10]);
        cases.insert("ones".to_string(), vec![1.0; 10]);
        cases.insert("alternating".to_string(), vec![1.0, -1.0, 1.0, -1.0, 1.0]);
        cases.insert("large_values".to_string(), vec![1e10, 1e11, 1e12]);
        cases.insert("small_values".to_string(), vec![1e-10, 1e-11, 1e-12]);
        cases.insert(
            "mixed_signs".to_string(),
            vec![-100.0, -1.0, 0.0, 1.0, 100.0],
        );

        // Add special float values
        cases.insert(
            "special_floats".to_string(),
            vec![
                f64::INFINITY,
                f64::NEG_INFINITY,
                f64::NAN,
                f64::MIN,
                f64::MAX,
                f64::EPSILON,
            ],
        );

        cases
    }
}
374
375/// Diagnostic utilities for troubleshooting
376pub struct DiagnosticTools;
377
378impl DiagnosticTools {
379    /// Run basic system diagnostics
380    pub fn system_check() -> Vec<String> {
381        let mut checks = Vec::new();
382
383        // Check basic Rust environment
384        checks.push(format!("Target architecture: {}", std::env::consts::ARCH));
385        checks.push(format!("Operating system: {}", std::env::consts::OS));
386        checks.push(format!(
387            "Profile: {}",
388            if cfg!(debug_assertions) {
389                "debug"
390            } else {
391                "release"
392            }
393        ));
394
395        // Check available features
396        checks.push(format!(
397            "Float support: {}",
398            if f64::INFINITY.is_infinite() {
399                "OK"
400            } else {
401                "FAIL"
402            }
403        ));
404        checks.push(format!(
405            "Threading: {}",
406            if std::thread::available_parallelism().is_ok() {
407                "OK"
408            } else {
409                "LIMITED"
410            }
411        ));
412
413        checks
414    }
415
416    /// Validate algorithm inputs
417    pub fn validate_ml_inputs(
418        features: &[Vec<f64>],
419        targets: Option<&[f64]>,
420    ) -> UtilsResult<Vec<String>> {
421        let mut warnings = Vec::new();
422
423        if features.is_empty() {
424            return Err(UtilsError::InvalidParameter(
425                "Empty feature matrix".to_string(),
426            ));
427        }
428
429        let feature_count = features[0].len();
430
431        // Check feature consistency
432        for (i, row) in features.iter().enumerate() {
433            if row.len() != feature_count {
434                warnings.push(format!(
435                    "Row {} has {} features, expected {}",
436                    i,
437                    row.len(),
438                    feature_count
439                ));
440            }
441
442            // Check for NaN/infinite values
443            for (j, &value) in row.iter().enumerate() {
444                if value.is_nan() {
445                    warnings.push(format!("NaN found at row {i}, col {j}"));
446                } else if value.is_infinite() {
447                    warnings.push(format!("Infinite value found at row {i}, col {j}: {value}"));
448                }
449            }
450        }
451
452        // Check targets if provided
453        if let Some(targets) = targets {
454            if targets.len() != features.len() {
455                warnings.push(format!(
456                    "Target count ({}) doesn't match sample count ({})",
457                    targets.len(),
458                    features.len()
459                ));
460            }
461
462            for (i, &target) in targets.iter().enumerate() {
463                if target.is_nan() {
464                    warnings.push(format!("NaN target at index {i}"));
465                } else if target.is_infinite() {
466                    warnings.push(format!("Infinite target at index {i}: {target}"));
467                }
468            }
469        }
470
471        Ok(warnings)
472    }
473
474    /// Visualize data distribution (simplified text histogram)
475    pub fn text_histogram(data: &[f64], bins: usize) -> String {
476        if data.is_empty() {
477            return "No data to visualize".to_string();
478        }
479
480        let min = data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
481        let max = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
482
483        if min == max {
484            return format!("All values equal: {min}");
485        }
486
487        let bin_width = (max - min) / bins as f64;
488        let mut counts = vec![0; bins];
489
490        for &value in data {
491            let bin = ((value - min) / bin_width).floor() as usize;
492            let bin = bin.min(bins - 1);
493            counts[bin] += 1;
494        }
495
496        let max_count = *counts.iter().max().unwrap();
497        let scale = 50.0 / max_count as f64;
498
499        let mut result = String::from("Data Distribution:\n");
500        for (i, &count) in counts.iter().enumerate() {
501            let bin_start = min + i as f64 * bin_width;
502            let bin_end = bin_start + bin_width;
503            let bar_length = (count as f64 * scale) as usize;
504            let bar = "█".repeat(bar_length);
505
506            result.push_str(&format!(
507                "[{bin_start:8.2} - {bin_end:8.2}]: {count:4} {bar}\n"
508            ));
509        }
510
511        result
512    }
513}
514
/// Macro for quick performance timing
///
/// `time_it!(name, { ... })` runs the block, prints `<name>: <elapsed>` to
/// standard output (the elapsed time in `Debug` form) and evaluates to the
/// block's value.
#[macro_export]
macro_rules! time_it {
    ($name:expr, $code:block) => {{
        let started_at = std::time::Instant::now();
        let output = $code;
        let elapsed = started_at.elapsed();
        println!("{}: {:?}", $name, elapsed);
        output
    }};
}
526
/// Macro for conditional debugging output
///
/// Accepts the same arguments as `println!` and prints them prefixed with
/// `[DEBUG] `, but only when `cfg!(debug_assertions)` is true. The arguments
/// are still type-checked when it is false.
#[macro_export]
macro_rules! debug_println {
    ($($arg:tt)*) => {
        if cfg!(debug_assertions) {
            println!("[DEBUG] {}", format_args!($($arg)*));
        }
    };
}
536
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    // The formatted context must contain the location and every metadata pair.
    #[test]
    fn test_debug_context() {
        let formatted = DebugContext::new("test_module", "test_function", 42)
            .with_metadata("key1", "value1")
            .with_metadata("key2", "value2")
            .format();

        for expected in ["test_module::test_function:42", "key1=value1", "key2=value2"] {
            assert!(formatted.contains(expected));
        }
    }

    // Length reporting, shape products and element-wise comparison.
    #[test]
    fn test_array_debugger() {
        let values = vec![1, 2, 3, 4, 5];
        assert!(ArrayDebugger::array_stats(&values).contains("len=5"));

        assert!(ArrayDebugger::array_shape_info(&[2, 3, 4]).contains("Total elements: 24"));

        let differences = ArrayDebugger::compare_arrays(&[1, 2, 3], &[1, 2, 4]);
        assert_eq!(differences.len(), 1);
        assert!(differences[0].contains("Element 2 differs"));
    }

    // One NaN and one infinity must yield exactly two issues.
    #[test]
    fn test_float_array_check() {
        let values = vec![1.0, 2.0, f64::NAN, f64::INFINITY, 5.0];
        let issues = ArrayDebugger::check_float_array(&values);

        assert_eq!(issues.len(), 2);
        for keyword in ["NaN", "Infinite"] {
            assert!(issues.iter().any(|issue| issue.contains(keyword)));
        }
    }

    // A stopped timer reports its duration and shows up in the statistics.
    #[test]
    fn test_performance_debugger() {
        let mut perf = PerformanceDebugger::new();

        perf.start_timer("test_operation");
        std::thread::sleep(std::time::Duration::from_millis(1));
        let elapsed = perf.stop_timer("test_operation");

        assert!(elapsed.is_some());
        assert!(elapsed.unwrap().as_millis() >= 1);

        let stats = perf.timing_stats("test_operation");
        assert!(stats.is_some());
        assert_eq!(stats.unwrap().count, 1);
    }

    // Every generator returns the requested amount of in-range data.
    #[test]
    fn test_test_data_generator() {
        let ints = TestDataGenerator::random_integers(10, 1, 100);
        assert_eq!(ints.len(), 10);
        assert!(ints.iter().all(|&x| x >= 1 && x <= 100));

        let reals = TestDataGenerator::random_floats(10, 0.0, 1.0);
        assert_eq!(reals.len(), 10);
        assert!(reals.iter().all(|&x| x >= 0.0 && x <= 1.0));

        let grid = TestDataGenerator::test_matrix(3, 4);
        assert_eq!(grid.len(), 3);
        assert!(grid.iter().all(|row| row.len() == 4));

        let labels = TestDataGenerator::test_strings(5);
        assert_eq!(labels.len(), 5);
        assert!(labels.iter().all(|s| !s.is_empty()));
    }

    // Spot-check a few of the named edge cases.
    #[test]
    fn test_pathological_cases() {
        let cases = TestDataGenerator::pathological_cases();

        for key in ["empty", "special_floats"] {
            assert!(cases.contains_key(key));
        }
        assert_eq!(cases["empty"].len(), 0);
        assert_eq!(cases["single"].len(), 1);
    }

    // System checks are present; validation is silent on clean input and
    // warns on ragged rows.
    #[test]
    fn test_diagnostic_tools() {
        let checks = DiagnosticTools::system_check();
        assert!(!checks.is_empty());
        assert!(checks.iter().any(|s| s.contains("Target architecture")));

        let targets = vec![1.0, 2.0];

        let clean_features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let clean_warnings =
            DiagnosticTools::validate_ml_inputs(&clean_features, Some(&targets)).unwrap();
        assert!(clean_warnings.is_empty());

        // Test with inconsistent data
        let ragged_features = vec![vec![1.0, 2.0], vec![3.0]];
        let ragged_warnings =
            DiagnosticTools::validate_ml_inputs(&ragged_features, Some(&targets)).unwrap();
        assert!(!ragged_warnings.is_empty());
    }

    // Regular, empty and constant inputs each take their own path.
    #[test]
    fn test_text_histogram() {
        let samples = vec![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 5.0];
        let rendered = DiagnosticTools::text_histogram(&samples, 5);
        assert!(rendered.contains("Data Distribution:"));
        assert!(rendered.contains("█"));

        let nothing = vec![];
        assert!(DiagnosticTools::text_histogram(&nothing, 5).contains("No data"));

        let constant = vec![1.0, 1.0, 1.0];
        assert!(DiagnosticTools::text_histogram(&constant, 5).contains("All values equal"));
    }
}