// tenflowers_core/ops/framework_comparison.rs

1//! Framework comparison benchmarking
2//!
3//! This module provides utilities to benchmark TenfloweRS against other
4//! machine learning frameworks like PyTorch, TensorFlow, and NumPy.
5
6use super::performance_benchmark::BenchmarkConfig;
7use crate::{Result, Tensor, TensorError};
8use std::collections::HashMap;
9use std::process::Command;
10use std::time::{Duration, Instant};
11
/// Result of benchmarking a single operation at a single size against
/// one or more external frameworks.
#[derive(Debug, Clone)]
pub struct FrameworkComparisonResult {
    /// Name of the benchmarked operation (e.g. "add").
    pub operation: String,
    /// Number of elements processed per iteration.
    pub size: usize,
    /// Measured TenfloweRS wall time.
    pub tenflowers_time: Duration,
    /// Measured wall time per external framework, keyed by framework name.
    pub framework_times: HashMap<String, Duration>,
    /// TenfloweRS throughput in elements/second (0.0 when the time was zero).
    pub tenflowers_throughput: f64,
    /// Throughput per external framework in elements/second.
    pub framework_throughputs: HashMap<String, f64>,
    /// TenfloweRS time / framework time: > 1.0 means TenfloweRS is slower,
    /// < 1.0 means TenfloweRS is faster.
    pub relative_performance: HashMap<String, f64>,
}

impl FrameworkComparisonResult {
    /// Build a result, deriving throughputs and relative performance from
    /// the raw timings.
    ///
    /// Zero durations (possible when a run resolves below timer granularity,
    /// or after integer truncation of an average) are guarded so the derived
    /// metrics stay finite instead of becoming `inf`/`NaN`.
    pub fn new(
        operation: String,
        size: usize,
        tenflowers_time: Duration,
        framework_times: HashMap<String, Duration>,
    ) -> Self {
        // Elements per second; report 0.0 rather than +inf for a zero duration.
        let tf_secs = tenflowers_time.as_secs_f64();
        let tenflowers_throughput = if tf_secs > 0.0 {
            size as f64 / tf_secs
        } else {
            0.0
        };

        let mut framework_throughputs = HashMap::new();
        let mut relative_performance = HashMap::new();

        for (framework, time) in &framework_times {
            let secs = time.as_secs_f64();
            let throughput = if secs > 0.0 { size as f64 / secs } else { 0.0 };
            framework_throughputs.insert(framework.clone(), throughput);

            // Relative performance: TenfloweRS time / framework time.
            // > 1.0 means TenfloweRS is slower, < 1.0 means TenfloweRS is faster.
            // A zero framework time yields no meaningful ratio (inf or NaN),
            // so the entry is skipped instead of poisoning later averages.
            let denom = time.as_nanos() as f64;
            if denom > 0.0 {
                let relative = tenflowers_time.as_nanos() as f64 / denom;
                relative_performance.insert(framework.clone(), relative);
            }
        }

        Self {
            operation,
            size,
            tenflowers_time,
            framework_times,
            tenflowers_throughput,
            framework_throughputs,
            relative_performance,
        }
    }
}
57
/// Framework benchmarking configuration
#[derive(Debug, Clone)]
pub struct FrameworkBenchmarkConfig {
    /// Shared timing parameters (sizes, warmup/measurement iteration counts).
    pub base_config: BenchmarkConfig,
    /// Framework identifiers to compare against (e.g. "numpy", "pytorch").
    pub frameworks_to_test: Vec<String>,
    /// Python interpreter used to run the generated benchmark scripts.
    pub python_executable: String,
    /// When true, unavailable frameworks are skipped with a warning;
    /// when false, a missing framework aborts the benchmark with an error.
    pub skip_missing_frameworks: bool,
}
66
67impl Default for FrameworkBenchmarkConfig {
68    fn default() -> Self {
69        Self {
70            base_config: BenchmarkConfig::default(),
71            frameworks_to_test: vec![
72                "numpy".to_string(),
73                "pytorch".to_string(),
74                "tensorflow".to_string(),
75            ],
76            python_executable: "python3".to_string(),
77            skip_missing_frameworks: true,
78        }
79    }
80}
81
/// Check if a framework is available by attempting to import it in Python.
///
/// Returns `false` both when the import fails and when the interpreter
/// itself cannot be spawned.
fn check_framework_availability(framework: &str, python_executable: &str) -> bool {
    // Map the user-facing framework name to its Python module name.
    let import_name = match framework {
        "numpy" => "numpy",
        "pytorch" => "torch",
        "tensorflow" => "tensorflow",
        _ => framework,
    };

    Command::new(python_executable)
        .arg("-c")
        .arg(format!("import {import_name}"))
        .output()
        // `output()` yields an owned `Result`, so no `as_ref()` is needed;
        // a spawn failure (e.g. missing interpreter) counts as "unavailable".
        .map(|o| o.status.success())
        .unwrap_or(false)
}
98
/// Generate a self-contained Python script that benchmarks one element-wise
/// binary operation in the given framework.
///
/// The script builds two random float32 operands of length `size`, warms up,
/// runs `iterations` timed repetitions, and prints the total elapsed
/// nanoseconds on stdout. Returns an empty string for unsupported
/// framework/operation combinations.
fn generate_python_benchmark_script(
    framework: &str,
    operation: &str,
    size: usize,
    iterations: usize,
) -> String {
    // Imports plus operand setup; unknown frameworks bail out early.
    let setup_code = match framework {
        "numpy" => format!(
            r#"
import numpy as np
import time
a = np.random.randn({size}).astype(np.float32)
b = np.random.randn({size}).astype(np.float32)
"#
        ),
        "pytorch" => format!(
            r#"
import torch
import time
a = torch.randn({size}, dtype=torch.float32)
b = torch.randn({size}, dtype=torch.float32)
"#
        ),
        "tensorflow" => format!(
            r#"
import tensorflow as tf
import time
a = tf.random.normal([{size}], dtype=tf.float32)
b = tf.random.normal([{size}], dtype=tf.float32)
"#
        ),
        _ => return String::new(),
    };

    // Module alias used in the generated call; matches the setup imports.
    let module = match framework {
        "numpy" => "np",
        "pytorch" => "torch",
        _ => "tf",
    };

    // PyTorch uses the short function names; NumPy/TensorFlow spell them out.
    let function = match (framework, operation) {
        ("pytorch", "add" | "mul" | "sub" | "div") => operation,
        (_, "add") => "add",
        (_, "mul") => "multiply",
        (_, "sub") => "subtract",
        (_, "div") => "divide",
        _ => return String::new(),
    };
    let operation_code = format!("result = {module}.{function}(a, b)");

    format!(
        r#"
{setup_code}

# Warmup
for _ in range(5):
    {operation_code}

# Benchmark
start_time = time.perf_counter()
for _ in range({iterations}):
    {operation_code}
end_time = time.perf_counter()

elapsed_ns = (end_time - start_time) * 1e9
print(f"{{elapsed_ns:.0f}}")
"#
    )
}
169
170/// Benchmark a specific operation against external frameworks
171fn benchmark_operation_against_frameworks(
172    operation: &str,
173    size: usize,
174    config: &FrameworkBenchmarkConfig,
175) -> Result<FrameworkComparisonResult> {
176    // Benchmark TenfloweRS
177    let tenflowers_time = benchmark_tenflowers_operation(operation, size, &config.base_config)?;
178
179    // Benchmark other frameworks
180    let mut framework_times = HashMap::new();
181
182    for framework in &config.frameworks_to_test {
183        if !check_framework_availability(framework, &config.python_executable) {
184            if config.skip_missing_frameworks {
185                println!("Warning: {framework} not available, skipping");
186                continue;
187            } else {
188                return Err(TensorError::other(format!(
189                    "Framework {framework} not available"
190                )));
191            }
192        }
193
194        if let Ok(time) = benchmark_external_framework(
195            framework,
196            operation,
197            size,
198            &config.base_config,
199            &config.python_executable,
200        ) {
201            framework_times.insert(framework.clone(), time);
202        } else {
203            println!("Warning: Failed to benchmark {framework} for {operation}");
204        }
205    }
206
207    Ok(FrameworkComparisonResult::new(
208        operation.to_string(),
209        size,
210        tenflowers_time,
211        framework_times,
212    ))
213}
214
215/// Benchmark TenfloweRS operation
216fn benchmark_tenflowers_operation(
217    operation: &str,
218    size: usize,
219    config: &BenchmarkConfig,
220) -> Result<Duration> {
221    // Create test data
222    let a_data: Vec<f32> = (0..size).map(|i| i as f32).collect();
223    let b_data: Vec<f32> = (0..size).map(|i| (i as f32) + 1.0).collect();
224
225    let a = Tensor::from_vec(a_data, &[size])?;
226    let b = Tensor::from_vec(b_data, &[size])?;
227
228    // Warmup
229    for _ in 0..config.warmup_iterations {
230        match operation {
231            "add" => {
232                let _ = super::binary::add(&a, &b)?;
233            }
234            "mul" => {
235                let _ = super::binary::mul(&a, &b)?;
236            }
237            "sub" => {
238                let _ = super::binary::sub(&a, &b)?;
239            }
240            "div" => {
241                let _ = super::binary::div(&a, &b)?;
242            }
243            _ => {
244                return Err(TensorError::other(format!(
245                    "Unknown operation: {operation}"
246                )))
247            }
248        }
249    }
250
251    // Benchmark
252    let start = Instant::now();
253    for _ in 0..config.measurement_iterations {
254        match operation {
255            "add" => {
256                let _ = super::binary::add(&a, &b)?;
257            }
258            "mul" => {
259                let _ = super::binary::mul(&a, &b)?;
260            }
261            "sub" => {
262                let _ = super::binary::sub(&a, &b)?;
263            }
264            "div" => {
265                let _ = super::binary::div(&a, &b)?;
266            }
267            _ => {
268                return Err(TensorError::other(format!(
269                    "Unknown operation: {operation}"
270                )))
271            }
272        }
273    }
274    let elapsed = start.elapsed() / config.measurement_iterations as u32;
275
276    Ok(elapsed)
277}
278
279/// Benchmark external framework via Python
280fn benchmark_external_framework(
281    framework: &str,
282    operation: &str,
283    size: usize,
284    config: &BenchmarkConfig,
285    python_executable: &str,
286) -> Result<Duration> {
287    let script =
288        generate_python_benchmark_script(framework, operation, size, config.measurement_iterations);
289
290    if script.is_empty() {
291        return Err(TensorError::other(format!(
292            "Unsupported framework/operation: {framework}/{operation}"
293        )));
294    }
295
296    let output = Command::new(python_executable)
297        .arg("-c")
298        .arg(&script)
299        .output()
300        .map_err(|e| TensorError::other(format!("Failed to execute Python script: {e}")))?;
301
302    if !output.status.success() {
303        return Err(TensorError::other(format!(
304            "Python script failed: {}",
305            String::from_utf8_lossy(&output.stderr)
306        )));
307    }
308
309    let elapsed_ns_str = String::from_utf8_lossy(&output.stdout);
310    let elapsed_ns: f64 = elapsed_ns_str
311        .trim()
312        .parse()
313        .map_err(|e| TensorError::other(format!("Failed to parse timing result: {e}")))?;
314
315    Ok(Duration::from_nanos(elapsed_ns as u64))
316}
317
318/// Run framework comparison benchmark suite
319pub fn run_framework_comparison_benchmark(
320    config: FrameworkBenchmarkConfig,
321) -> Result<Vec<FrameworkComparisonResult>> {
322    println!("Running TenfloweRS Framework Comparison Benchmark");
323    println!("Testing against external frameworks...\n");
324
325    let operations = vec!["add", "mul", "sub", "div"];
326    let mut results = Vec::new();
327
328    for &size in &config.base_config.sizes {
329        println!("Benchmarking size: {size}");
330
331        for operation in &operations {
332            match benchmark_operation_against_frameworks(operation, size, &config) {
333                Ok(result) => {
334                    results.push(result);
335                }
336                Err(e) => {
337                    println!("Warning: Failed to benchmark {operation} at size {size}: {e}");
338                }
339            }
340        }
341    }
342
343    print_framework_comparison_results(&results);
344
345    Ok(results)
346}
347
348/// Print framework comparison results in formatted table
349pub fn print_framework_comparison_results(results: &[FrameworkComparisonResult]) {
350    if results.is_empty() {
351        println!("No benchmark results to display");
352        return;
353    }
354
355    println!("\n{:-<120}", "");
356    println!(
357        "| {:^12} | {:^8} | {:^15} | {:^15} | {:^15} | {:^15} | {:^15} |",
358        "Operation",
359        "Size",
360        "TenfloweRS (μs)",
361        "NumPy (μs)",
362        "PyTorch (μs)",
363        "TensorFlow (μs)",
364        "Relative Perf."
365    );
366    println!("{:-<120}", "");
367
368    for result in results {
369        let tf_us = result.tenflowers_time.as_micros();
370        let numpy_us = result
371            .framework_times
372            .get("numpy")
373            .map(|t| t.as_micros())
374            .unwrap_or(0);
375        let pytorch_us = result
376            .framework_times
377            .get("pytorch")
378            .map(|t| t.as_micros())
379            .unwrap_or(0);
380        let tensorflow_us = result
381            .framework_times
382            .get("tensorflow")
383            .map(|t| t.as_micros())
384            .unwrap_or(0);
385
386        // Calculate average relative performance (lower is better for TenfloweRS)
387        let avg_relative = if !result.relative_performance.is_empty() {
388            result.relative_performance.values().sum::<f64>()
389                / result.relative_performance.len() as f64
390        } else {
391            0.0
392        };
393
394        println!(
395            "| {:^12} | {:^8} | {:^15} | {:^15} | {:^15} | {:^15} | {:^15.2} |",
396            result.operation,
397            result.size,
398            if tf_us > 0 {
399                tf_us.to_string()
400            } else {
401                "-".to_string()
402            },
403            if numpy_us > 0 {
404                numpy_us.to_string()
405            } else {
406                "-".to_string()
407            },
408            if pytorch_us > 0 {
409                pytorch_us.to_string()
410            } else {
411                "-".to_string()
412            },
413            if tensorflow_us > 0 {
414                tensorflow_us.to_string()
415            } else {
416                "-".to_string()
417            },
418            avg_relative
419        );
420    }
421    println!("{:-<120}", "");
422
423    // Summary statistics
424    let all_relative_perfs: Vec<f64> = results
425        .iter()
426        .flat_map(|r| r.relative_performance.values())
427        .cloned()
428        .collect();
429
430    if !all_relative_perfs.is_empty() {
431        let avg_relative = all_relative_perfs.iter().sum::<f64>() / all_relative_perfs.len() as f64;
432        let best_relative = all_relative_perfs
433            .iter()
434            .fold(f64::INFINITY, |a, &b| a.min(b));
435        let worst_relative = all_relative_perfs.iter().fold(0.0f64, |a, &b| a.max(b));
436
437        println!("Summary:");
438        println!("  Average relative performance: {avg_relative:.2}x");
439        println!("  Best relative performance: {best_relative:.2}x");
440        println!("  Worst relative performance: {worst_relative:.2}x");
441
442        if avg_relative < 1.0 {
443            println!(
444                "  🚀 TenfloweRS is on average {:.2}x faster than other frameworks",
445                1.0 / avg_relative
446            );
447        } else {
448            println!(
449                "  ⚠️  TenfloweRS is on average {avg_relative:.2}x slower than other frameworks"
450            );
451        }
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_framework_availability_check() {
        // `any` short-circuits like the original `||` chain: try python3,
        // then fall back to python.
        let has_python = ["python3", "python"]
            .iter()
            .any(|py| check_framework_availability("sys", py));

        // Environment-dependent; we only verify it doesn't crash.
        println!("Python available: {}", has_python);
    }

    #[test]
    fn test_benchmark_script_generation() {
        // Each generated script must import the framework and call its op.
        for (framework, op, import_marker, call_marker) in [
            ("numpy", "add", "import numpy", "np.add"),
            ("pytorch", "mul", "import torch", "torch.mul"),
        ] {
            let script = generate_python_benchmark_script(framework, op, 1000, 10);
            assert!(script.contains(import_marker));
            assert!(script.contains(call_marker));
        }
    }

    #[test]
    fn test_framework_comparison_result() {
        let framework_times: HashMap<String, Duration> = [
            ("numpy".to_string(), Duration::from_millis(2)),
            ("pytorch".to_string(), Duration::from_millis(3)),
        ]
        .into_iter()
        .collect();

        let result = FrameworkComparisonResult::new(
            "add".to_string(),
            1000,
            Duration::from_millis(1),
            framework_times,
        );

        assert_eq!(result.operation, "add");
        assert_eq!(result.size, 1000);

        for fw in ["numpy", "pytorch"] {
            assert!(result.relative_performance.contains_key(fw));
            // TenfloweRS is faster (1ms vs 2ms/3ms), so each ratio is < 1.0.
            assert!(result.relative_performance[fw] < 1.0);
        }
    }
}