// torsh_special/visualization.rs
1//! Visualization tools for function behavior and accuracy analysis
2//!
3//! This module provides utilities for analyzing and visualizing the behavior
4//! of special functions, including accuracy comparisons, convergence analysis,
5//! and numerical stability assessment.
6
7use crate::error_functions;
8use crate::fast_approximations;
9use crate::gamma;
10use std::collections::HashMap;
11use torsh_core::{device::DeviceType, error::Result as TorshResult};
12use torsh_tensor::Tensor;
13
/// Function behavior analysis results
///
/// Produced by [`analyze_function_behavior`]; summarizes the value range,
/// detected singularities, monotonicity, and an estimated numerical accuracy
/// of a function sampled over a closed interval.
#[derive(Debug, Clone)]
pub struct FunctionAnalysis {
    /// Function name
    pub name: String,
    /// Input range analyzed, as (start, end)
    pub input_range: (f32, f32),
    /// Number of sample points used across the range
    pub num_points: usize,
    /// Maximum absolute value observed in the range
    pub max_value: f32,
    /// Minimum absolute value observed in the range
    pub min_value: f32,
    /// Sample points where the function jumps by more than 100 between
    /// neighbors or takes a non-finite value
    pub singularities: Vec<f32>,
    /// Estimated numerical accuracy (relative error), derived from the
    /// largest local curvature of the sampled values
    pub numerical_accuracy: f32,
    /// Function monotonicity classification over the sampled range
    pub monotonicity: Monotonicity,
}
34
/// Monotonicity classification
///
/// Describes how sampled function values evolve over an interval. The
/// classification is heuristic: see [`assess_monotonicity`] — a trend is
/// reported when at least 90% of adjacent finite sample pairs move in the
/// same direction.
//
// `Eq`, `Copy`, and `Hash` are derived in addition to `PartialEq`: the enum
// is fieldless, so all three are free and make the type easier to use in
// matches, sets, and map keys. This is backward-compatible for all callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Monotonicity {
    /// Values predominantly increase across the range
    Increasing,
    /// Values predominantly decrease across the range
    Decreasing,
    /// Values both increase and decrease significantly
    NonMonotonic,
    /// Values are (almost) unchanged across the range
    Constant,
}
43
/// Accuracy comparison between two function implementations
///
/// Produced by [`compare_function_accuracy`]; aggregates error statistics of
/// a test implementation against a reference over a sampled interval.
#[derive(Debug, Clone)]
pub struct AccuracyComparison {
    /// Reference function name
    pub reference_name: String,
    /// Test function name
    pub test_name: String,
    /// Maximum relative error over all comparable sample points
    pub max_relative_error: f32,
    /// Average relative error over all comparable sample points
    pub avg_relative_error: f32,
    /// Root mean square of the absolute errors
    pub rms_error: f32,
    /// Up to five sample points with the largest relative errors,
    /// sorted descending by relative error
    pub worst_points: Vec<(f32, f32, f32)>, // (input, error, relative_error)
}
60
/// ASCII plot data for function visualization
///
/// Produced by [`generate_ascii_plot`]; holds both the raw sampled values
/// and a rendered character-grid representation.
#[derive(Debug, Clone)]
pub struct PlotData {
    /// X values sampled across the plot range
    pub x_values: Vec<f32>,
    /// Y values of the function at each sampled X
    pub y_values: Vec<f32>,
    /// Plot width in characters
    pub width: usize,
    /// Plot height in characters
    pub height: usize,
    /// ASCII representation (rows joined with '\n')
    pub ascii_plot: String,
}
75
76/// Analyze the behavior of a special function over a given range
77pub fn analyze_function_behavior<F>(
78    name: &str,
79    func: F,
80    range: (f32, f32),
81    num_points: usize,
82) -> TorshResult<FunctionAnalysis>
83where
84    F: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
85{
86    let device = DeviceType::Cpu;
87    let (start, end) = range;
88
89    // Generate input points
90    let step = (end - start) / (num_points - 1) as f32;
91    let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
92    let x_tensor = Tensor::from_data(x_values.clone(), vec![num_points], device)?;
93
94    // Evaluate function
95    let result = func(&x_tensor)?;
96    let y_values = result.data()?;
97
98    // Analyze properties
99    let max_value = y_values
100        .iter()
101        .fold(f32::NEG_INFINITY, |a, &b| a.max(b.abs()));
102    let min_value = y_values.iter().fold(f32::INFINITY, |a, &b| a.min(b.abs()));
103
104    // Detect singularities (large jumps or infinite values)
105    let mut singularities = Vec::new();
106    for i in 1..y_values.len() {
107        let jump = (y_values[i] - y_values[i - 1]).abs();
108        if jump > 100.0 || !y_values[i].is_finite() {
109            singularities.push(x_values[i]);
110        }
111    }
112
113    // Assess monotonicity
114    let monotonicity = assess_monotonicity(&y_values);
115
116    // Estimate numerical accuracy (using finite differences)
117    let numerical_accuracy = estimate_numerical_accuracy(&x_values, &y_values);
118
119    Ok(FunctionAnalysis {
120        name: name.to_string(),
121        input_range: range,
122        num_points,
123        max_value,
124        min_value,
125        singularities,
126        numerical_accuracy,
127        monotonicity,
128    })
129}
130
131/// Compare accuracy between two function implementations
132pub fn compare_function_accuracy<F1, F2>(
133    reference_name: &str,
134    reference_func: F1,
135    test_name: &str,
136    test_func: F2,
137    range: (f32, f32),
138    num_points: usize,
139) -> TorshResult<AccuracyComparison>
140where
141    F1: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
142    F2: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
143{
144    let device = DeviceType::Cpu;
145    let (start, end) = range;
146
147    // Generate input points
148    let step = (end - start) / (num_points - 1) as f32;
149    let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
150    let x_tensor = Tensor::from_data(x_values.clone(), vec![num_points], device)?;
151
152    // Evaluate both functions
153    let ref_result = reference_func(&x_tensor)?;
154    let test_result = test_func(&x_tensor)?;
155
156    let ref_values = ref_result.data()?;
157    let test_values = test_result.data()?;
158
159    // Calculate errors
160    let mut errors = Vec::new();
161    let mut relative_errors = Vec::new();
162    let mut worst_points: Vec<(f32, f32, f32)> = Vec::new();
163
164    for i in 0..num_points {
165        if ref_values[i].is_finite() && test_values[i].is_finite() && ref_values[i] != 0.0 {
166            let error = (test_values[i] - ref_values[i]).abs();
167            let rel_error = error / ref_values[i].abs();
168
169            errors.push(error);
170            relative_errors.push(rel_error);
171
172            // Track worst points
173            if worst_points.len() < 5 || rel_error > worst_points[4].2 {
174                worst_points.push((x_values[i], error, rel_error));
175                worst_points.sort_by(|a, b| {
176                    b.2.partial_cmp(&a.2)
177                        .expect("relative error comparison should succeed for finite floats")
178                });
179                worst_points.truncate(5);
180            }
181        }
182    }
183
184    let max_relative_error = relative_errors.iter().fold(0.0f32, |a, &b| a.max(b));
185    let avg_relative_error = relative_errors.iter().sum::<f32>() / relative_errors.len() as f32;
186    let rms_error = (errors.iter().map(|&x| x * x).sum::<f32>() / errors.len() as f32).sqrt();
187
188    Ok(AccuracyComparison {
189        reference_name: reference_name.to_string(),
190        test_name: test_name.to_string(),
191        max_relative_error,
192        avg_relative_error,
193        rms_error,
194        worst_points,
195    })
196}
197
198/// Generate ASCII plot of function behavior
199pub fn generate_ascii_plot<F>(
200    func: F,
201    range: (f32, f32),
202    num_points: usize,
203    width: usize,
204    height: usize,
205) -> TorshResult<PlotData>
206where
207    F: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
208{
209    let device = DeviceType::Cpu;
210    let (start, end) = range;
211
212    // Generate input points
213    let step = (end - start) / (num_points - 1) as f32;
214    let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
215    let x_tensor = Tensor::from_data(x_values.clone(), vec![num_points], device)?;
216
217    // Evaluate function
218    let result = func(&x_tensor)?;
219    let y_values = result.data()?;
220
221    // Find y-range
222    let y_min = y_values.iter().fold(
223        f32::INFINITY,
224        |a, &b| {
225            if b.is_finite() {
226                a.min(b)
227            } else {
228                a
229            }
230        },
231    );
232    let y_max = y_values.iter().fold(
233        f32::NEG_INFINITY,
234        |a, &b| {
235            if b.is_finite() {
236                a.max(b)
237            } else {
238                a
239            }
240        },
241    );
242
243    // Create ASCII plot
244    let mut plot = vec![vec![' '; width]; height];
245
246    // Plot axes
247    for row in plot.iter_mut().take(height) {
248        row[0] = '|'; // Y-axis
249    }
250    for j in 0..width {
251        plot[height - 1][j] = '-'; // X-axis
252    }
253    plot[height - 1][0] = '+'; // Origin
254
255    // Plot function points
256    for i in 0..num_points {
257        if y_values[i].is_finite() {
258            let x_pos = ((x_values[i] - start) / (end - start) * (width - 1) as f32) as usize;
259            let y_pos = ((y_max - y_values[i]) / (y_max - y_min) * (height - 1) as f32) as usize;
260
261            if x_pos < width && y_pos < height {
262                plot[y_pos][x_pos] = '*';
263            }
264        }
265    }
266
267    // Convert to string
268    let ascii_plot = plot
269        .iter()
270        .map(|row| row.iter().collect::<String>())
271        .collect::<Vec<_>>()
272        .join("\n");
273
274    Ok(PlotData {
275        x_values,
276        y_values: y_values.to_vec(),
277        width,
278        height,
279        ascii_plot,
280    })
281}
282
283/// Benchmark function performance across optimization levels
284pub fn benchmark_optimization_levels(
285    range: (f32, f32),
286    num_points: usize,
287    iterations: usize,
288) -> TorshResult<HashMap<String, f64>> {
289    use std::time::Instant;
290
291    let device = DeviceType::Cpu;
292    let (start, end) = range;
293
294    // Generate input data
295    let step = (end - start) / (num_points - 1) as f32;
296    let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
297    let x_tensor = Tensor::from_data(x_values, vec![num_points], device)?;
298
299    let mut results = HashMap::new();
300
301    // Benchmark standard gamma function
302    let start_time = Instant::now();
303    for _ in 0..iterations {
304        let _ = gamma::gamma(&x_tensor)?;
305    }
306    let gamma_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
307    results.insert("gamma_standard".to_string(), gamma_time);
308
309    // Benchmark fast gamma approximation
310    let start_time = Instant::now();
311    for _ in 0..iterations {
312        let _ = fast_approximations::gamma_fast(&x_tensor)?;
313    }
314    let gamma_fast_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
315    results.insert("gamma_fast".to_string(), gamma_fast_time);
316
317    // Benchmark standard error function
318    let start_time = Instant::now();
319    for _ in 0..iterations {
320        let _ = error_functions::erf(&x_tensor)?;
321    }
322    let erf_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
323    results.insert("erf_standard".to_string(), erf_time);
324
325    // Benchmark fast error function approximation
326    let start_time = Instant::now();
327    for _ in 0..iterations {
328        let _ = fast_approximations::erf_fast(&x_tensor)?;
329    }
330    let erf_fast_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
331    results.insert("erf_fast".to_string(), erf_fast_time);
332
333    Ok(results)
334}
335
336/// Assess monotonicity of a function from sample values
337fn assess_monotonicity(values: &[f32]) -> Monotonicity {
338    if values.len() < 2 {
339        return Monotonicity::Constant;
340    }
341
342    let mut increasing = 0;
343    let mut decreasing = 0;
344    let mut constant = 0;
345
346    for i in 1..values.len() {
347        if values[i].is_finite() && values[i - 1].is_finite() {
348            if values[i] > values[i - 1] {
349                increasing += 1;
350            } else if values[i] < values[i - 1] {
351                decreasing += 1;
352            } else {
353                constant += 1;
354            }
355        }
356    }
357
358    let total = increasing + decreasing + constant;
359    if total == 0 {
360        return Monotonicity::Constant;
361    }
362
363    let inc_ratio = increasing as f32 / total as f32;
364    let dec_ratio = decreasing as f32 / total as f32;
365
366    if inc_ratio > 0.9 {
367        Monotonicity::Increasing
368    } else if dec_ratio > 0.9 {
369        Monotonicity::Decreasing
370    } else if inc_ratio < 0.1 && dec_ratio < 0.1 {
371        Monotonicity::Constant
372    } else {
373        Monotonicity::NonMonotonic
374    }
375}
376
/// Estimate numerical accuracy using finite differences
///
/// Computes a curvature proxy (normalized second difference) over every
/// interior sample triple and scales machine epsilon by the largest one.
/// The result is clamped to `[f32::EPSILON, 1e-3]`; fewer than three
/// samples fall back to a nominal `1e-6`.
fn estimate_numerical_accuracy(x_values: &[f32], y_values: &[f32]) -> f32 {
    // A second difference needs three points; default otherwise.
    if x_values.len() < 3 {
        return 1e-6;
    }

    // Largest curvature proxy found over all interior sample triples.
    let mut peak_curvature = 0.0f32;
    for (xs, ys) in x_values.windows(3).zip(y_values.windows(3)) {
        let finite = ys.iter().all(|v| v.is_finite());
        let h1 = xs[1] - xs[0];
        let h2 = xs[2] - xs[1];
        if finite && h1 > 0.0 && h2 > 0.0 {
            // Second-derivative approximation from forward/backward slopes
            let d2y = (ys[2] - ys[1]) / h2 - (ys[1] - ys[0]) / h1;
            peak_curvature = peak_curvature.max(d2y.abs() / (h1 + h2));
        }
    }

    // Scale machine epsilon by curvature, then clamp into a sane band.
    let machine_eps = f32::EPSILON;
    (machine_eps * (1.0 + peak_curvature)).clamp(machine_eps, 1e-3)
}
406
407/// Print comprehensive function analysis report
408pub fn print_analysis_report(analysis: &FunctionAnalysis) {
409    println!("═══ Function Analysis Report ═══");
410    println!("Function: {}", analysis.name);
411    println!(
412        "Range: [{:.3}, {:.3}]",
413        analysis.input_range.0, analysis.input_range.1
414    );
415    println!("Sample points: {}", analysis.num_points);
416    println!(
417        "Value range: [{:.6}, {:.6}]",
418        analysis.min_value, analysis.max_value
419    );
420    println!("Monotonicity: {:?}", analysis.monotonicity);
421    println!("Numerical accuracy: {:.2e}", analysis.numerical_accuracy);
422
423    if !analysis.singularities.is_empty() {
424        println!("Singularities detected at: {:?}", analysis.singularities);
425    } else {
426        println!("No singularities detected");
427    }
428
429    println!("═══════════════════════════════");
430}
431
432/// Print accuracy comparison report
433pub fn print_accuracy_report(comparison: &AccuracyComparison) {
434    println!("═══ Accuracy Comparison Report ═══");
435    println!("Reference: {}", comparison.reference_name);
436    println!("Test function: {}", comparison.test_name);
437    println!("Max relative error: {:.2e}", comparison.max_relative_error);
438    println!(
439        "Average relative error: {:.2e}",
440        comparison.avg_relative_error
441    );
442    println!("RMS error: {:.2e}", comparison.rms_error);
443
444    println!("\nWorst accuracy points:");
445    for (i, &(x, err, rel_err)) in comparison.worst_points.iter().enumerate() {
446        println!(
447            "  {}: x={:.4}, error={:.2e}, rel_error={:.2e}",
448            i + 1,
449            x,
450            err,
451            rel_err
452        );
453    }
454
455    println!("═══════════════════════════════════");
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check `analyze_function_behavior` on the gamma function.
    #[test]
    fn test_function_analysis() -> TorshResult<()> {
        let analysis = analyze_function_behavior("gamma", gamma::gamma, (0.1, 3.0), 50)?;

        assert_eq!(analysis.name, "gamma");
        // Gamma takes positive values on (0.1, 3.0), so max |value| > 0.
        assert!(analysis.max_value > 0.0);
        // Accuracy estimate is clamped to at least machine epsilon.
        assert!(analysis.numerical_accuracy > 0.0);

        Ok(())
    }

    /// Compare the standard and fast gamma implementations; error
    /// statistics must be non-negative by construction.
    #[test]
    fn test_accuracy_comparison() -> TorshResult<()> {
        let comparison = compare_function_accuracy(
            "gamma_standard",
            gamma::gamma,
            "gamma_fast",
            fast_approximations::gamma_fast,
            (0.5, 2.0),
            20,
        )?;

        assert!(comparison.max_relative_error >= 0.0);
        assert!(comparison.avg_relative_error >= 0.0);
        assert!(comparison.rms_error >= 0.0);

        Ok(())
    }

    /// Render a plot of gamma and verify canvas dimensions and that both
    /// data points and axes appear in the output.
    #[test]
    fn test_ascii_plot() -> TorshResult<()> {
        let plot = generate_ascii_plot(gamma::gamma, (0.5, 2.0), 20, 40, 20)?;

        assert_eq!(plot.width, 40);
        assert_eq!(plot.height, 20);
        assert!(!plot.ascii_plot.is_empty());
        assert!(plot.ascii_plot.contains('*')); // Should have plot points
        assert!(plot.ascii_plot.contains('|')); // Should have axes
        Ok(())
    }

    /// Exercise each of the four monotonicity classifications.
    #[test]
    fn test_monotonicity_assessment() {
        assert_eq!(
            assess_monotonicity(&[1.0, 2.0, 3.0, 4.0]),
            Monotonicity::Increasing
        );
        assert_eq!(
            assess_monotonicity(&[4.0, 3.0, 2.0, 1.0]),
            Monotonicity::Decreasing
        );
        assert_eq!(
            assess_monotonicity(&[2.0, 2.0, 2.0, 2.0]),
            Monotonicity::Constant
        );
        assert_eq!(
            assess_monotonicity(&[1.0, 3.0, 2.0, 4.0]),
            Monotonicity::NonMonotonic
        );
    }

    /// Smoke-test the benchmark: all four entries present with positive
    /// timings, and fast variants within a loose bound of the standard ones.
    #[test]
    fn test_benchmark() -> TorshResult<()> {
        let results = benchmark_optimization_levels((0.5, 2.0), 100, 5)?;

        assert!(results.contains_key("gamma_standard"));
        assert!(results.contains_key("gamma_fast"));
        assert!(results.contains_key("erf_standard"));
        assert!(results.contains_key("erf_fast"));

        // Verify all benchmarks returned positive timing values
        assert!(results["gamma_standard"] > 0.0);
        assert!(results["gamma_fast"] > 0.0);
        assert!(results["erf_standard"] > 0.0);
        assert!(results["erf_fast"] > 0.0);

        // Fast functions should generally be faster, but allow large tolerance for system variability
        // Allow up to 10x slowdown to account for system load, cold cache, etc.
        assert!(results["gamma_fast"] <= results["gamma_standard"] * 10.0);
        assert!(results["erf_fast"] <= results["erf_standard"] * 10.0);

        Ok(())
    }
}