1use crate::error_functions;
8use crate::fast_approximations;
9use crate::gamma;
10use std::collections::HashMap;
11use torsh_core::{device::DeviceType, error::Result as TorshResult};
12use torsh_tensor::Tensor;
13
/// Summary of a scalar function's behavior sampled over a closed interval,
/// as produced by `analyze_function_behavior`.
#[derive(Debug, Clone)]
pub struct FunctionAnalysis {
    /// Display name of the analyzed function.
    pub name: String,
    /// `(start, end)` input interval that was sampled.
    pub input_range: (f32, f32),
    /// Number of evenly spaced sample points used.
    pub num_points: usize,
    // NOTE(review): in `analyze_function_behavior` this is currently derived
    // from absolute values (max of |y|), not the signed maximum — confirm intended.
    pub max_value: f32,
    // NOTE(review): likewise currently the minimum of |y|, not of y.
    pub min_value: f32,
    /// Input locations where a large jump or non-finite output was detected.
    pub singularities: Vec<f32>,
    /// Rough estimate of attainable numerical accuracy at these samples.
    pub numerical_accuracy: f32,
    /// Overall trend classification of the sampled outputs.
    pub monotonicity: Monotonicity,
}
34
/// Overall trend classification for a sampled sequence of values
/// (see `assess_monotonicity` for the exact thresholds).
#[derive(Debug, Clone, PartialEq)]
pub enum Monotonicity {
    /// More than 90% of finite consecutive steps increase.
    Increasing,
    /// More than 90% of finite consecutive steps decrease.
    Decreasing,
    /// Mixed rising and falling behavior.
    NonMonotonic,
    /// Fewer than 10% of steps rise and fewer than 10% fall, or too few samples.
    Constant,
}
43
/// Error statistics from comparing a test implementation against a
/// reference implementation on a shared sample grid.
#[derive(Debug, Clone)]
pub struct AccuracyComparison {
    /// Name of the reference (ground-truth) implementation.
    pub reference_name: String,
    /// Name of the implementation under test.
    pub test_name: String,
    /// Largest relative error observed at any valid sample point.
    pub max_relative_error: f32,
    /// Mean relative error over all valid sample points.
    pub avg_relative_error: f32,
    /// Root-mean-square of the absolute errors.
    pub rms_error: f32,
    /// Up to five worst points as `(x, absolute_error, relative_error)`,
    /// sorted by descending relative error.
    pub worst_points: Vec<(f32, f32, f32)>,
}
60
/// An ASCII rendering of a sampled function, together with the raw samples.
#[derive(Debug, Clone)]
pub struct PlotData {
    /// Sampled input values.
    pub x_values: Vec<f32>,
    /// Corresponding function outputs (may contain non-finite values).
    pub y_values: Vec<f32>,
    /// Plot width in characters.
    pub width: usize,
    /// Plot height in characters.
    pub height: usize,
    /// Rendered plot, rows joined with `\n`.
    pub ascii_plot: String,
}
75
76pub fn analyze_function_behavior<F>(
78 name: &str,
79 func: F,
80 range: (f32, f32),
81 num_points: usize,
82) -> TorshResult<FunctionAnalysis>
83where
84 F: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
85{
86 let device = DeviceType::Cpu;
87 let (start, end) = range;
88
89 let step = (end - start) / (num_points - 1) as f32;
91 let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
92 let x_tensor = Tensor::from_data(x_values.clone(), vec![num_points], device)?;
93
94 let result = func(&x_tensor)?;
96 let y_values = result.data()?;
97
98 let max_value = y_values
100 .iter()
101 .fold(f32::NEG_INFINITY, |a, &b| a.max(b.abs()));
102 let min_value = y_values.iter().fold(f32::INFINITY, |a, &b| a.min(b.abs()));
103
104 let mut singularities = Vec::new();
106 for i in 1..y_values.len() {
107 let jump = (y_values[i] - y_values[i - 1]).abs();
108 if jump > 100.0 || !y_values[i].is_finite() {
109 singularities.push(x_values[i]);
110 }
111 }
112
113 let monotonicity = assess_monotonicity(&y_values);
115
116 let numerical_accuracy = estimate_numerical_accuracy(&x_values, &y_values);
118
119 Ok(FunctionAnalysis {
120 name: name.to_string(),
121 input_range: range,
122 num_points,
123 max_value,
124 min_value,
125 singularities,
126 numerical_accuracy,
127 monotonicity,
128 })
129}
130
131pub fn compare_function_accuracy<F1, F2>(
133 reference_name: &str,
134 reference_func: F1,
135 test_name: &str,
136 test_func: F2,
137 range: (f32, f32),
138 num_points: usize,
139) -> TorshResult<AccuracyComparison>
140where
141 F1: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
142 F2: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
143{
144 let device = DeviceType::Cpu;
145 let (start, end) = range;
146
147 let step = (end - start) / (num_points - 1) as f32;
149 let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
150 let x_tensor = Tensor::from_data(x_values.clone(), vec![num_points], device)?;
151
152 let ref_result = reference_func(&x_tensor)?;
154 let test_result = test_func(&x_tensor)?;
155
156 let ref_values = ref_result.data()?;
157 let test_values = test_result.data()?;
158
159 let mut errors = Vec::new();
161 let mut relative_errors = Vec::new();
162 let mut worst_points: Vec<(f32, f32, f32)> = Vec::new();
163
164 for i in 0..num_points {
165 if ref_values[i].is_finite() && test_values[i].is_finite() && ref_values[i] != 0.0 {
166 let error = (test_values[i] - ref_values[i]).abs();
167 let rel_error = error / ref_values[i].abs();
168
169 errors.push(error);
170 relative_errors.push(rel_error);
171
172 if worst_points.len() < 5 || rel_error > worst_points[4].2 {
174 worst_points.push((x_values[i], error, rel_error));
175 worst_points.sort_by(|a, b| {
176 b.2.partial_cmp(&a.2)
177 .expect("relative error comparison should succeed for finite floats")
178 });
179 worst_points.truncate(5);
180 }
181 }
182 }
183
184 let max_relative_error = relative_errors.iter().fold(0.0f32, |a, &b| a.max(b));
185 let avg_relative_error = relative_errors.iter().sum::<f32>() / relative_errors.len() as f32;
186 let rms_error = (errors.iter().map(|&x| x * x).sum::<f32>() / errors.len() as f32).sqrt();
187
188 Ok(AccuracyComparison {
189 reference_name: reference_name.to_string(),
190 test_name: test_name.to_string(),
191 max_relative_error,
192 avg_relative_error,
193 rms_error,
194 worst_points,
195 })
196}
197
198pub fn generate_ascii_plot<F>(
200 func: F,
201 range: (f32, f32),
202 num_points: usize,
203 width: usize,
204 height: usize,
205) -> TorshResult<PlotData>
206where
207 F: Fn(&Tensor<f32>) -> TorshResult<Tensor<f32>>,
208{
209 let device = DeviceType::Cpu;
210 let (start, end) = range;
211
212 let step = (end - start) / (num_points - 1) as f32;
214 let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
215 let x_tensor = Tensor::from_data(x_values.clone(), vec![num_points], device)?;
216
217 let result = func(&x_tensor)?;
219 let y_values = result.data()?;
220
221 let y_min = y_values.iter().fold(
223 f32::INFINITY,
224 |a, &b| {
225 if b.is_finite() {
226 a.min(b)
227 } else {
228 a
229 }
230 },
231 );
232 let y_max = y_values.iter().fold(
233 f32::NEG_INFINITY,
234 |a, &b| {
235 if b.is_finite() {
236 a.max(b)
237 } else {
238 a
239 }
240 },
241 );
242
243 let mut plot = vec![vec![' '; width]; height];
245
246 for row in plot.iter_mut().take(height) {
248 row[0] = '|'; }
250 for j in 0..width {
251 plot[height - 1][j] = '-'; }
253 plot[height - 1][0] = '+'; for i in 0..num_points {
257 if y_values[i].is_finite() {
258 let x_pos = ((x_values[i] - start) / (end - start) * (width - 1) as f32) as usize;
259 let y_pos = ((y_max - y_values[i]) / (y_max - y_min) * (height - 1) as f32) as usize;
260
261 if x_pos < width && y_pos < height {
262 plot[y_pos][x_pos] = '*';
263 }
264 }
265 }
266
267 let ascii_plot = plot
269 .iter()
270 .map(|row| row.iter().collect::<String>())
271 .collect::<Vec<_>>()
272 .join("\n");
273
274 Ok(PlotData {
275 x_values,
276 y_values: y_values.to_vec(),
277 width,
278 height,
279 ascii_plot,
280 })
281}
282
283pub fn benchmark_optimization_levels(
285 range: (f32, f32),
286 num_points: usize,
287 iterations: usize,
288) -> TorshResult<HashMap<String, f64>> {
289 use std::time::Instant;
290
291 let device = DeviceType::Cpu;
292 let (start, end) = range;
293
294 let step = (end - start) / (num_points - 1) as f32;
296 let x_values: Vec<f32> = (0..num_points).map(|i| start + i as f32 * step).collect();
297 let x_tensor = Tensor::from_data(x_values, vec![num_points], device)?;
298
299 let mut results = HashMap::new();
300
301 let start_time = Instant::now();
303 for _ in 0..iterations {
304 let _ = gamma::gamma(&x_tensor)?;
305 }
306 let gamma_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
307 results.insert("gamma_standard".to_string(), gamma_time);
308
309 let start_time = Instant::now();
311 for _ in 0..iterations {
312 let _ = fast_approximations::gamma_fast(&x_tensor)?;
313 }
314 let gamma_fast_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
315 results.insert("gamma_fast".to_string(), gamma_fast_time);
316
317 let start_time = Instant::now();
319 for _ in 0..iterations {
320 let _ = error_functions::erf(&x_tensor)?;
321 }
322 let erf_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
323 results.insert("erf_standard".to_string(), erf_time);
324
325 let start_time = Instant::now();
327 for _ in 0..iterations {
328 let _ = fast_approximations::erf_fast(&x_tensor)?;
329 }
330 let erf_fast_time = start_time.elapsed().as_secs_f64() * 1e9 / iterations as f64;
331 results.insert("erf_fast".to_string(), erf_fast_time);
332
333 Ok(results)
334}
335
336fn assess_monotonicity(values: &[f32]) -> Monotonicity {
338 if values.len() < 2 {
339 return Monotonicity::Constant;
340 }
341
342 let mut increasing = 0;
343 let mut decreasing = 0;
344 let mut constant = 0;
345
346 for i in 1..values.len() {
347 if values[i].is_finite() && values[i - 1].is_finite() {
348 if values[i] > values[i - 1] {
349 increasing += 1;
350 } else if values[i] < values[i - 1] {
351 decreasing += 1;
352 } else {
353 constant += 1;
354 }
355 }
356 }
357
358 let total = increasing + decreasing + constant;
359 if total == 0 {
360 return Monotonicity::Constant;
361 }
362
363 let inc_ratio = increasing as f32 / total as f32;
364 let dec_ratio = decreasing as f32 / total as f32;
365
366 if inc_ratio > 0.9 {
367 Monotonicity::Increasing
368 } else if dec_ratio > 0.9 {
369 Monotonicity::Decreasing
370 } else if inc_ratio < 0.1 && dec_ratio < 0.1 {
371 Monotonicity::Constant
372 } else {
373 Monotonicity::NonMonotonic
374 }
375}
376
/// Crudely estimates the attainable numerical accuracy of sampled data from
/// its curvature: a second-difference approximation amplifies machine epsilon,
/// and the result is clamped to `[f32::EPSILON, 1e-3]`.
fn estimate_numerical_accuracy(x_values: &[f32], y_values: &[f32]) -> f32 {
    // Too few samples to form a second difference; return a nominal value.
    if x_values.len() < 3 {
        return 1e-6;
    }

    let mut peak_curvature = 0.0f32;

    for idx in 1..x_values.len() - 1 {
        let (y0, y1, y2) = (y_values[idx - 1], y_values[idx], y_values[idx + 1]);
        // Only finite triples contribute a meaningful curvature sample.
        if !(y0.is_finite() && y1.is_finite() && y2.is_finite()) {
            continue;
        }

        let left = x_values[idx] - x_values[idx - 1];
        let right = x_values[idx + 1] - x_values[idx];
        if left > 0.0 && right > 0.0 {
            // Non-uniform second difference, normalized by the local spacing.
            let second_diff = (y2 - y1) / right - (y1 - y0) / left;
            peak_curvature = peak_curvature.max(second_diff.abs() / (left + right));
        }
    }

    let eps = f32::EPSILON;
    (eps * (1.0 + peak_curvature)).clamp(eps, 1e-3)
}
406
407pub fn print_analysis_report(analysis: &FunctionAnalysis) {
409 println!("═══ Function Analysis Report ═══");
410 println!("Function: {}", analysis.name);
411 println!(
412 "Range: [{:.3}, {:.3}]",
413 analysis.input_range.0, analysis.input_range.1
414 );
415 println!("Sample points: {}", analysis.num_points);
416 println!(
417 "Value range: [{:.6}, {:.6}]",
418 analysis.min_value, analysis.max_value
419 );
420 println!("Monotonicity: {:?}", analysis.monotonicity);
421 println!("Numerical accuracy: {:.2e}", analysis.numerical_accuracy);
422
423 if !analysis.singularities.is_empty() {
424 println!("Singularities detected at: {:?}", analysis.singularities);
425 } else {
426 println!("No singularities detected");
427 }
428
429 println!("═══════════════════════════════");
430}
431
432pub fn print_accuracy_report(comparison: &AccuracyComparison) {
434 println!("═══ Accuracy Comparison Report ═══");
435 println!("Reference: {}", comparison.reference_name);
436 println!("Test function: {}", comparison.test_name);
437 println!("Max relative error: {:.2e}", comparison.max_relative_error);
438 println!(
439 "Average relative error: {:.2e}",
440 comparison.avg_relative_error
441 );
442 println!("RMS error: {:.2e}", comparison.rms_error);
443
444 println!("\nWorst accuracy points:");
445 for (i, &(x, err, rel_err)) in comparison.worst_points.iter().enumerate() {
446 println!(
447 " {}: x={:.4}, error={:.2e}, rel_error={:.2e}",
448 i + 1,
449 x,
450 err,
451 rel_err
452 );
453 }
454
455 println!("═══════════════════════════════════");
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: analyzing gamma over a well-behaved interval yields
    /// sensible positive statistics.
    #[test]
    fn test_function_analysis() -> TorshResult<()> {
        let analysis = analyze_function_behavior("gamma", gamma::gamma, (0.1, 3.0), 50)?;

        assert_eq!(analysis.name, "gamma");
        assert!(analysis.max_value > 0.0);
        assert!(analysis.numerical_accuracy > 0.0);

        Ok(())
    }

    /// The fast gamma approximation should produce finite, non-negative
    /// error statistics against the standard implementation.
    #[test]
    fn test_accuracy_comparison() -> TorshResult<()> {
        let comparison = compare_function_accuracy(
            "gamma_standard",
            gamma::gamma,
            "gamma_fast",
            fast_approximations::gamma_fast,
            (0.5, 2.0),
            20,
        )?;

        assert!(comparison.max_relative_error >= 0.0);
        assert!(comparison.avg_relative_error >= 0.0);
        assert!(comparison.rms_error >= 0.0);

        Ok(())
    }

    /// The rendered plot should report its dimensions and contain both data
    /// points ('*') and the left axis ('|').
    #[test]
    fn test_ascii_plot() -> TorshResult<()> {
        let plot = generate_ascii_plot(gamma::gamma, (0.5, 2.0), 20, 40, 20)?;

        assert_eq!(plot.width, 40);
        assert_eq!(plot.height, 20);
        assert!(!plot.ascii_plot.is_empty());
        assert!(plot.ascii_plot.contains('*'));
        assert!(plot.ascii_plot.contains('|'));
        Ok(())
    }

    /// Exercises all four trend classifications on small hand-built slices.
    #[test]
    fn test_monotonicity_assessment() {
        assert_eq!(
            assess_monotonicity(&[1.0, 2.0, 3.0, 4.0]),
            Monotonicity::Increasing
        );
        assert_eq!(
            assess_monotonicity(&[4.0, 3.0, 2.0, 1.0]),
            Monotonicity::Decreasing
        );
        assert_eq!(
            assess_monotonicity(&[2.0, 2.0, 2.0, 2.0]),
            Monotonicity::Constant
        );
        assert_eq!(
            assess_monotonicity(&[1.0, 3.0, 2.0, 4.0]),
            Monotonicity::NonMonotonic
        );
    }

    /// Benchmarks should produce positive timings for every variant.
    #[test]
    fn test_benchmark() -> TorshResult<()> {
        let results = benchmark_optimization_levels((0.5, 2.0), 100, 5)?;

        assert!(results.contains_key("gamma_standard"));
        assert!(results.contains_key("gamma_fast"));
        assert!(results.contains_key("erf_standard"));
        assert!(results.contains_key("erf_fast"));

        assert!(results["gamma_standard"] > 0.0);
        assert!(results["gamma_fast"] > 0.0);
        assert!(results["erf_standard"] > 0.0);
        assert!(results["erf_fast"] > 0.0);

        // NOTE(review): timing-based comparisons are inherently flaky on
        // loaded CI machines; the 10x tolerance keeps them loose, but consider
        // dropping them if they prove unstable.
        assert!(results["gamma_fast"] <= results["gamma_standard"] * 10.0);
        assert!(results["erf_fast"] <= results["erf_standard"] * 10.0);

        Ok(())
    }
}