temporal_neural_solver/core/
utils.rs

1//! Utility functions for the temporal neural solver
2
3use crate::core::types::*;
4use crate::core::errors::*;
5use std::time::{Duration, Instant};
6use std::collections::HashMap;
7
/// Named stopwatch that reports how long it lived.
///
/// The clock starts when the value is built with [`Timer::new`]. When the
/// value is dropped, the elapsed time is printed to stdout in microseconds,
/// tagged with the timer's name.
pub struct Timer {
    /// Instant captured at construction; every reading is relative to it.
    start: Instant,
    /// Label printed next to the elapsed time on drop.
    name: String,
}

impl Timer {
    /// Starts a new timer labelled `name`.
    pub fn new(name: String) -> Self {
        let start = Instant::now();
        Timer { start, name }
    }

    /// Time elapsed since the timer was created.
    pub fn elapsed(&self) -> Duration {
        self.start.elapsed()
    }

    /// Elapsed time in microseconds, including the fractional part.
    pub fn elapsed_micros(&self) -> f64 {
        let seconds = self.elapsed().as_secs_f64();
        seconds * 1e6
    }

    /// Elapsed time in whole nanoseconds.
    pub fn elapsed_nanos(&self) -> u128 {
        let since_start = self.elapsed();
        since_start.as_nanos()
    }
}

impl Drop for Timer {
    /// Prints the timer's name and its elapsed time (µs, three decimals).
    fn drop(&mut self) {
        let micros = self.elapsed_micros();
        println!("⏱️  {} took {:.3}µs", self.name, micros);
    }
}
41
42/// Create a timer with RAII semantics
43pub fn time_block(name: &str) -> Timer {
44    Timer::new(name.to_string())
45}
46
47/// Calculate performance metrics from timing data
48pub fn calculate_metrics(timings: &[Duration]) -> PerformanceMetrics {
49    if timings.is_empty() {
50        return PerformanceMetrics {
51            min_latency: Duration::ZERO,
52            max_latency: Duration::ZERO,
53            mean_latency: Duration::ZERO,
54            p50_latency: Duration::ZERO,
55            p90_latency: Duration::ZERO,
56            p99_latency: Duration::ZERO,
57            p999_latency: Duration::ZERO,
58            throughput_ops_per_sec: 0.0,
59            samples: 0,
60        };
61    }
62
63    let mut sorted_timings = timings.to_vec();
64    sorted_timings.sort_unstable();
65
66    let n = sorted_timings.len();
67    let sum: Duration = sorted_timings.iter().sum();
68    let mean = sum / n as u32;
69
70    let min_latency = sorted_timings[0];
71    let max_latency = sorted_timings[n - 1];
72    let p50_latency = sorted_timings[n / 2];
73    let p90_latency = sorted_timings[(n * 9) / 10];
74    let p99_latency = sorted_timings[(n * 99) / 100];
75    let p999_latency = sorted_timings[((n * 999) / 1000).min(n - 1)];
76
77    let throughput_ops_per_sec = if p50_latency > Duration::ZERO {
78        1.0 / p50_latency.as_secs_f64()
79    } else {
80        0.0
81    };
82
83    PerformanceMetrics {
84        min_latency,
85        max_latency,
86        mean_latency: mean,
87        p50_latency,
88        p90_latency,
89        p99_latency,
90        p999_latency,
91        throughput_ops_per_sec,
92        samples: n,
93    }
94}
95
/// Formats a duration with the largest fitting unit among ns, µs, ms and s.
///
/// Durations below one microsecond are printed as whole nanoseconds
/// (`"500ns"`); everything else is printed with one decimal place
/// (`"1.5µs"`, `"1.5ms"`, `"1.5s"`).
///
/// The unit is chosen *after* rounding: a value that would round up to
/// `1000.0` of a unit is shown in the next larger unit instead, so
/// 999 999 ns is rendered as `"1.0ms"` rather than `"1000.0µs"`.
pub fn format_duration(duration: Duration) -> String {
    let nanos = duration.as_nanos();
    if nanos < 1_000 {
        return format!("{}ns", nanos);
    }

    let nanos = nanos as f64;
    for (nanos_per_unit, suffix) in [(1_000.0, "µs"), (1_000_000.0, "ms")] {
        let scaled = nanos / nanos_per_unit;
        if scaled < 1_000.0 {
            let text = format!("{:.1}", scaled);
            // One-decimal rounding can carry up to a full thousand; if it
            // does, fall through and express the value in the next unit.
            if text != "1000.0" {
                return format!("{}{}", text, suffix);
            }
        }
    }

    format!("{:.1}s", nanos / 1_000_000_000.0)
}
110
111/// Validate input vector dimensions
112pub fn validate_input(input: &[f32], expected_size: usize) -> Result<()> {
113    if input.len() != expected_size {
114        return Err(TemporalSolverError::DimensionMismatch {
115            expected: expected_size,
116            got: input.len(),
117        });
118    }
119
120    // Check for NaN or infinite values
121    for (i, &value) in input.iter().enumerate() {
122        if !value.is_finite() {
123            return Err(TemporalSolverError::NumericalError(
124                format!("Invalid value {} at index {}", value, i)
125            ));
126        }
127    }
128
129    Ok(())
130}
131
/// Produces `size` reproducible samples in the range `[0, 1]`.
///
/// Element `i` is `(sin(i * 0.01 + seed * 0.001) + 1) / 2`, evaluated in
/// `f64` and then narrowed to `f32`. The same `(size, seed)` pair always
/// yields the same vector.
pub fn generate_test_data(size: usize, seed: u64) -> Vec<f32> {
    // The seed only shifts the phase of the sine wave.
    let phase = seed as f64 * 0.001;

    (0..size)
        .map(|index| {
            let wave = (index as f64 * 0.01 + phase).sin();
            ((wave + 1.0) * 0.5) as f32
        })
        .collect()
}
144
145/// Generate test input vector
146pub fn generate_test_input(seed: u64) -> InputVector {
147    let data = generate_test_data(128, seed);
148    let mut input = [0.0f32; 128];
149    input.copy_from_slice(&data);
150    input
151}
152
153/// Compare two output vectors with tolerance
154pub fn compare_outputs(a: &OutputVector, b: &OutputVector, tolerance: f32) -> bool {
155    for (&x, &y) in a.iter().zip(b.iter()) {
156        let diff = (x - y).abs();
157        let rel_error = if x.abs() > tolerance {
158            diff / x.abs()
159        } else {
160            diff
161        };
162
163        if rel_error > tolerance {
164            return false;
165        }
166    }
167
168    true
169}
170
/// Ratio of `baseline_duration` to `optimized_duration`.
///
/// Values above `1.0` mean the optimized run was faster. Returns `0.0` when
/// `optimized_duration` is zero, since the ratio is undefined there.
pub fn calculate_speedup(baseline_duration: Duration, optimized_duration: Duration) -> f64 {
    if optimized_duration.is_zero() {
        return 0.0;
    }

    let baseline_secs = baseline_duration.as_secs_f64();
    let optimized_secs = optimized_duration.as_secs_f64();
    baseline_secs / optimized_secs
}
179
180/// Detect CPU features at runtime
181pub fn detect_cpu_features() -> HardwareFeatures {
182    #[cfg(target_arch = "x86_64")]
183    {
184        HardwareFeatures {
185            has_avx: is_x86_feature_detected!("avx"),
186            has_avx2: is_x86_feature_detected!("avx2"),
187            has_avx512: is_x86_feature_detected!("avx512f"),
188            has_fma: is_x86_feature_detected!("fma"),
189            has_sse4_2: is_x86_feature_detected!("sse4.2"),
190            cpu_cores: num_cpus::get_physical(),
191            cache_line_size: 64, // Typical x86_64 cache line size
192        }
193    }
194
195    #[cfg(not(target_arch = "x86_64"))]
196    {
197        HardwareFeatures {
198            has_avx: false,
199            has_avx2: false,
200            has_avx512: false,
201            has_fma: false,
202            has_sse4_2: false,
203            cpu_cores: num_cpus::get_physical(),
204            cache_line_size: 64,
205        }
206    }
207}
208
209/// System information gathering
210pub fn get_system_info() -> HashMap<String, String> {
211    let mut info = HashMap::new();
212
213    info.insert("arch".to_string(), std::env::consts::ARCH.to_string());
214    info.insert("os".to_string(), std::env::consts::OS.to_string());
215    info.insert("family".to_string(), std::env::consts::FAMILY.to_string());
216
217    let features = detect_cpu_features();
218    info.insert("cpu_cores".to_string(), features.cpu_cores.to_string());
219    info.insert("has_avx2".to_string(), features.has_avx2.to_string());
220    info.insert("has_avx512".to_string(), features.has_avx512.to_string());
221
222    // Add environment variables
223    if let Ok(rust_version) = std::env::var("CARGO_PKG_RUST_VERSION") {
224        info.insert("rust_version".to_string(), rust_version);
225    }
226
227    if let Ok(profile) = std::env::var("PROFILE") {
228        info.insert("profile".to_string(), profile);
229    }
230
231    if let Ok(target) = std::env::var("TARGET") {
232        info.insert("target".to_string(), target);
233    }
234
235    info
236}
237
/// Returns `true` when the address of `ptr` is a multiple of `alignment`.
///
/// The pointer is never dereferenced; only its address is inspected.
/// `alignment` is a byte count and need not be a power of two. An
/// `alignment` of zero is not a valid alignment and yields `false` instead
/// of panicking on a remainder-by-zero.
pub fn is_aligned(ptr: *const u8, alignment: usize) -> bool {
    alignment != 0 && (ptr as usize) % alignment == 0
}
242
243pub fn check_simd_alignment(data: &[f32]) -> bool {
244    let ptr = data.as_ptr() as *const u8;
245    is_aligned(ptr, 32) // AVX2 requires 32-byte alignment
246}
247
/// Arithmetic mean of `data`; `0.0` for an empty slice.
pub fn mean(data: &[f64]) -> f64 {
    match data.len() {
        0 => 0.0,
        count => data.iter().sum::<f64>() / count as f64,
    }
}
255
256pub fn variance(data: &[f64]) -> f64 {
257    if data.len() < 2 {
258        return 0.0;
259    }
260
261    let m = mean(data);
262    let sum_sq_diff: f64 = data.iter().map(|x| (x - m).powi(2)).sum();
263    sum_sq_diff / (data.len() - 1) as f64
264}
265
266pub fn std_dev(data: &[f64]) -> f64 {
267    variance(data).sqrt()
268}
269
/// Nearest-rank percentile of `data`.
///
/// `p` is a fraction (`0.0` = minimum, `0.5` = median, `1.0` = maximum). The
/// data is copied and sorted, and the element at index
/// `round(p * (len - 1))` is returned; an index past the end is clamped to
/// the last element, and a negative or NaN `p` selects the first element.
///
/// Returns `0.0` for an empty slice.
///
/// NaN values in `data` do not cause a panic: they are ordered after every
/// number, so they can only be returned for the highest percentiles.
pub fn percentile(data: &[f64], p: f64) -> f64 {
    if data.is_empty() {
        return 0.0;
    }

    let mut sorted = data.to_vec();
    // `partial_cmp` is `None` only when a NaN is involved. Falling back to
    // comparing the "is NaN" flags keeps the ordering total (all NaNs equal,
    // and greater than any number), which `sort_by` requires.
    sorted.sort_by(|a, b| {
        a.partial_cmp(b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    });

    let last = sorted.len() - 1;
    // The float-to-usize cast saturates, so out-of-range `p` cannot wrap.
    let index = (p * last as f64).round() as usize;
    sorted[index.min(last)]
}
281
282/// Progress tracking utilities
283pub struct ProgressBar {
284    total: usize,
285    current: usize,
286    start_time: Instant,
287    last_update: Instant,
288}
289
290impl ProgressBar {
291    pub fn new(total: usize) -> Self {
292        let now = Instant::now();
293        Self {
294            total,
295            current: 0,
296            start_time: now,
297            last_update: now,
298        }
299    }
300
301    pub fn update(&mut self, current: usize) {
302        self.current = current;
303        let now = Instant::now();
304
305        // Update every 100ms
306        if now.duration_since(self.last_update) > Duration::from_millis(100) {
307            self.display();
308            self.last_update = now;
309        }
310    }
311
312    pub fn finish(&mut self) {
313        self.current = self.total;
314        self.display();
315        println!(); // New line after completion
316    }
317
318    fn display(&self) {
319        let percentage = if self.total > 0 {
320            (self.current as f64 / self.total as f64) * 100.0
321        } else {
322            0.0
323        };
324
325        let elapsed = self.start_time.elapsed();
326        let rate = if elapsed.as_secs_f64() > 0.0 {
327            self.current as f64 / elapsed.as_secs_f64()
328        } else {
329            0.0
330        };
331
332        let eta = if rate > 0.0 && self.current < self.total {
333            Duration::from_secs_f64((self.total - self.current) as f64 / rate)
334        } else {
335            Duration::ZERO
336        };
337
338        print!("\r🔄 Progress: {:.1}% ({}/{}) | {:.1} it/s | ETA: {}",
339            percentage, self.current, self.total, rate, format_duration(eta));
340        std::io::Write::flush(&mut std::io::stdout()).unwrap();
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Five samples must be counted and yield a positive throughput.
    #[test]
    fn test_calculate_metrics() {
        let timings: Vec<Duration> = [100, 200, 150, 120, 180]
            .iter()
            .map(|&micros| Duration::from_micros(micros))
            .collect();

        let metrics = calculate_metrics(&timings);

        assert_eq!(metrics.samples, 5);
        assert!(metrics.throughput_ops_per_sec > 0.0);
    }

    /// Each magnitude is rendered with its own unit.
    #[test]
    fn test_format_duration() {
        let cases = [
            (Duration::from_nanos(500), "500ns"),
            (Duration::from_micros(1500), "1.5ms"),
            (Duration::from_millis(1500), "1.5s"),
        ];

        for (duration, expected) in cases {
            assert_eq!(format_duration(duration), expected);
        }
    }

    /// Length mismatches and non-finite values are both rejected.
    #[test]
    fn test_validate_input() {
        let finite = [1.0f32, 2.0, 3.0];
        assert!(validate_input(&finite, 3).is_ok());
        assert!(validate_input(&finite, 4).is_err());

        let with_nan = [1.0f32, f32::NAN, 3.0];
        assert!(validate_input(&with_nan, 3).is_err());
    }

    /// Output depends on the seed and on nothing else.
    #[test]
    fn test_generate_test_data() {
        let first = generate_test_data(10, 42);
        let repeat = generate_test_data(10, 42);
        assert_eq!(first, repeat);

        let other_seed = generate_test_data(10, 43);
        assert_ne!(first, other_seed);
    }

    /// A 1 % deviation passes at 2 % tolerance and fails at 0.5 %.
    #[test]
    fn test_compare_outputs() {
        let reference = [1.0, 2.0, 3.0, 4.0];
        let candidate = [1.01, 1.99, 3.01, 3.99];

        assert!(compare_outputs(&reference, &candidate, 0.02));
        assert!(!compare_outputs(&reference, &candidate, 0.005));
    }

    /// 100 µs against 20 µs is a five-fold speedup.
    #[test]
    fn test_calculate_speedup() {
        let speedup = calculate_speedup(
            Duration::from_micros(100),
            Duration::from_micros(20),
        );

        assert!((speedup - 5.0).abs() < 0.001);
    }
}