1use super::performance_benchmark::BenchmarkConfig;
7use crate::{Result, Tensor, TensorError};
8use std::collections::HashMap;
9use std::process::Command;
10use std::time::{Duration, Instant};
11
/// Result of comparing a single TenfloweRS operation against external frameworks.
#[derive(Debug, Clone)]
pub struct FrameworkComparisonResult {
    /// Name of the benchmarked operation (e.g. "add", "mul").
    pub operation: String,
    /// Number of elements processed per iteration.
    pub size: usize,
    /// Measured time for TenfloweRS.
    pub tenflowers_time: Duration,
    /// Measured times for each external framework, keyed by framework name.
    pub framework_times: HashMap<String, Duration>,
    /// TenfloweRS throughput in elements per second (0.0 when the time is zero).
    pub tenflowers_throughput: f64,
    /// Per-framework throughput in elements per second.
    pub framework_throughputs: HashMap<String, f64>,
    /// `tenflowers_time / framework_time` per framework; values < 1.0 mean
    /// TenfloweRS completed the same work in less time (i.e. was faster).
    pub relative_performance: HashMap<String, f64>,
}

impl FrameworkComparisonResult {
    /// Builds a comparison result, deriving throughputs and relative
    /// performance ratios from the raw timings.
    ///
    /// Zero durations would produce `inf`/`NaN` ratios that poison the
    /// averages computed downstream (e.g. in the summary printout), so any
    /// division by a zero denominator yields 0.0 instead.
    pub fn new(
        operation: String,
        size: usize,
        tenflowers_time: Duration,
        framework_times: HashMap<String, Duration>,
    ) -> Self {
        // Shared guard: avoid inf/NaN from zero-duration measurements.
        let safe_div = |num: f64, den: f64| if den > 0.0 { num / den } else { 0.0 };

        let tenflowers_throughput = safe_div(size as f64, tenflowers_time.as_secs_f64());

        let mut framework_throughputs = HashMap::new();
        let mut relative_performance = HashMap::new();

        for (framework, time) in &framework_times {
            let throughput = safe_div(size as f64, time.as_secs_f64());
            framework_throughputs.insert(framework.clone(), throughput);

            // Ratio of TenfloweRS time to the framework's time: < 1.0 means
            // TenfloweRS was faster on this operation/size.
            let relative = safe_div(
                tenflowers_time.as_nanos() as f64,
                time.as_nanos() as f64,
            );
            relative_performance.insert(framework.clone(), relative);
        }

        Self {
            operation,
            size,
            tenflowers_time,
            framework_times,
            tenflowers_throughput,
            framework_throughputs,
            relative_performance,
        }
    }
}
57
/// Configuration for running framework-comparison benchmarks.
#[derive(Debug, Clone)]
pub struct FrameworkBenchmarkConfig {
    /// Core sizing/iteration settings shared with the main benchmark runner.
    pub base_config: BenchmarkConfig,
    /// Framework identifiers to benchmark against (e.g. "numpy", "pytorch").
    pub frameworks_to_test: Vec<String>,
    /// Python interpreter used to run the generated benchmark scripts.
    pub python_executable: String,
    /// When true, frameworks that fail the availability check are skipped with
    /// a warning; when false, a missing framework aborts the whole run.
    pub skip_missing_frameworks: bool,
}
66
67impl Default for FrameworkBenchmarkConfig {
68 fn default() -> Self {
69 Self {
70 base_config: BenchmarkConfig::default(),
71 frameworks_to_test: vec![
72 "numpy".to_string(),
73 "pytorch".to_string(),
74 "tensorflow".to_string(),
75 ],
76 python_executable: "python3".to_string(),
77 skip_missing_frameworks: true,
78 }
79 }
80}
81
/// Returns true when `framework` can be imported by `python_executable`.
///
/// Maps the user-facing framework name to its Python import name (only
/// "pytorch" differs, importing as "torch") and attempts
/// `python -c "import <name>"`. Any failure — interpreter not found, module
/// missing, import error — yields `false` rather than an error.
fn check_framework_availability(framework: &str, python_executable: &str) -> bool {
    let import_name = match framework {
        "pytorch" => "torch",
        // "numpy", "tensorflow", and any unrecognized framework string
        // import under their own names.
        other => other,
    };

    Command::new(python_executable)
        .arg("-c")
        .arg(format!("import {import_name}"))
        .output()
        .is_ok_and(|o| o.status.success())
}
98
/// Generates a self-contained Python timing script for `framework`/`operation`.
///
/// The script allocates two float32 vectors of `size` elements, performs five
/// warmup runs, times `iterations` runs with `time.perf_counter`, and prints
/// the total elapsed nanoseconds to stdout. Returns an empty string for any
/// unsupported framework/operation pair.
fn generate_python_benchmark_script(
    framework: &str,
    operation: &str,
    size: usize,
    iterations: usize,
) -> String {
    // Per-framework preamble: imports plus two random float32 operand vectors.
    let setup_code = match framework {
        "numpy" => format!(
            r#"
import numpy as np
import time
a = np.random.randn({size}).astype(np.float32)
b = np.random.randn({size}).astype(np.float32)
"#
        ),
        "pytorch" => format!(
            r#"
import torch
import time
a = torch.randn({size}, dtype=torch.float32)
b = torch.randn({size}, dtype=torch.float32)
"#
        ),
        "tensorflow" => format!(
            r#"
import tensorflow as tf
import time
a = tf.random.normal([{size}], dtype=tf.float32)
b = tf.random.normal([{size}], dtype=tf.float32)
"#
        ),
        _ => return String::new(),
    };

    // Framework-specific spelling of the element-wise binary operation.
    let operation_code = match framework {
        "numpy" => match operation {
            "add" => "result = np.add(a, b)",
            "mul" => "result = np.multiply(a, b)",
            "sub" => "result = np.subtract(a, b)",
            "div" => "result = np.divide(a, b)",
            _ => return String::new(),
        },
        "pytorch" => match operation {
            "add" => "result = torch.add(a, b)",
            "mul" => "result = torch.mul(a, b)",
            "sub" => "result = torch.sub(a, b)",
            "div" => "result = torch.div(a, b)",
            _ => return String::new(),
        },
        "tensorflow" => match operation {
            "add" => "result = tf.add(a, b)",
            "mul" => "result = tf.multiply(a, b)",
            "sub" => "result = tf.subtract(a, b)",
            "div" => "result = tf.divide(a, b)",
            _ => return String::new(),
        },
        _ => return String::new(),
    };

    format!(
        r#"
{setup_code}

# Warmup
for _ in range(5):
    {operation_code}

# Benchmark
start_time = time.perf_counter()
for _ in range({iterations}):
    {operation_code}
end_time = time.perf_counter()

elapsed_ns = (end_time - start_time) * 1e9
print(f"{{elapsed_ns:.0f}}")
"#
    )
}
169
170fn benchmark_operation_against_frameworks(
172 operation: &str,
173 size: usize,
174 config: &FrameworkBenchmarkConfig,
175) -> Result<FrameworkComparisonResult> {
176 let tenflowers_time = benchmark_tenflowers_operation(operation, size, &config.base_config)?;
178
179 let mut framework_times = HashMap::new();
181
182 for framework in &config.frameworks_to_test {
183 if !check_framework_availability(framework, &config.python_executable) {
184 if config.skip_missing_frameworks {
185 println!("Warning: {framework} not available, skipping");
186 continue;
187 } else {
188 return Err(TensorError::other(format!(
189 "Framework {framework} not available"
190 )));
191 }
192 }
193
194 if let Ok(time) = benchmark_external_framework(
195 framework,
196 operation,
197 size,
198 &config.base_config,
199 &config.python_executable,
200 ) {
201 framework_times.insert(framework.clone(), time);
202 } else {
203 println!("Warning: Failed to benchmark {framework} for {operation}");
204 }
205 }
206
207 Ok(FrameworkComparisonResult::new(
208 operation.to_string(),
209 size,
210 tenflowers_time,
211 framework_times,
212 ))
213}
214
215fn benchmark_tenflowers_operation(
217 operation: &str,
218 size: usize,
219 config: &BenchmarkConfig,
220) -> Result<Duration> {
221 let a_data: Vec<f32> = (0..size).map(|i| i as f32).collect();
223 let b_data: Vec<f32> = (0..size).map(|i| (i as f32) + 1.0).collect();
224
225 let a = Tensor::from_vec(a_data, &[size])?;
226 let b = Tensor::from_vec(b_data, &[size])?;
227
228 for _ in 0..config.warmup_iterations {
230 match operation {
231 "add" => {
232 let _ = super::binary::add(&a, &b)?;
233 }
234 "mul" => {
235 let _ = super::binary::mul(&a, &b)?;
236 }
237 "sub" => {
238 let _ = super::binary::sub(&a, &b)?;
239 }
240 "div" => {
241 let _ = super::binary::div(&a, &b)?;
242 }
243 _ => {
244 return Err(TensorError::other(format!(
245 "Unknown operation: {operation}"
246 )))
247 }
248 }
249 }
250
251 let start = Instant::now();
253 for _ in 0..config.measurement_iterations {
254 match operation {
255 "add" => {
256 let _ = super::binary::add(&a, &b)?;
257 }
258 "mul" => {
259 let _ = super::binary::mul(&a, &b)?;
260 }
261 "sub" => {
262 let _ = super::binary::sub(&a, &b)?;
263 }
264 "div" => {
265 let _ = super::binary::div(&a, &b)?;
266 }
267 _ => {
268 return Err(TensorError::other(format!(
269 "Unknown operation: {operation}"
270 )))
271 }
272 }
273 }
274 let elapsed = start.elapsed() / config.measurement_iterations as u32;
275
276 Ok(elapsed)
277}
278
279fn benchmark_external_framework(
281 framework: &str,
282 operation: &str,
283 size: usize,
284 config: &BenchmarkConfig,
285 python_executable: &str,
286) -> Result<Duration> {
287 let script =
288 generate_python_benchmark_script(framework, operation, size, config.measurement_iterations);
289
290 if script.is_empty() {
291 return Err(TensorError::other(format!(
292 "Unsupported framework/operation: {framework}/{operation}"
293 )));
294 }
295
296 let output = Command::new(python_executable)
297 .arg("-c")
298 .arg(&script)
299 .output()
300 .map_err(|e| TensorError::other(format!("Failed to execute Python script: {e}")))?;
301
302 if !output.status.success() {
303 return Err(TensorError::other(format!(
304 "Python script failed: {}",
305 String::from_utf8_lossy(&output.stderr)
306 )));
307 }
308
309 let elapsed_ns_str = String::from_utf8_lossy(&output.stdout);
310 let elapsed_ns: f64 = elapsed_ns_str
311 .trim()
312 .parse()
313 .map_err(|e| TensorError::other(format!("Failed to parse timing result: {e}")))?;
314
315 Ok(Duration::from_nanos(elapsed_ns as u64))
316}
317
318pub fn run_framework_comparison_benchmark(
320 config: FrameworkBenchmarkConfig,
321) -> Result<Vec<FrameworkComparisonResult>> {
322 println!("Running TenfloweRS Framework Comparison Benchmark");
323 println!("Testing against external frameworks...\n");
324
325 let operations = vec!["add", "mul", "sub", "div"];
326 let mut results = Vec::new();
327
328 for &size in &config.base_config.sizes {
329 println!("Benchmarking size: {size}");
330
331 for operation in &operations {
332 match benchmark_operation_against_frameworks(operation, size, &config) {
333 Ok(result) => {
334 results.push(result);
335 }
336 Err(e) => {
337 println!("Warning: Failed to benchmark {operation} at size {size}: {e}");
338 }
339 }
340 }
341 }
342
343 print_framework_comparison_results(&results);
344
345 Ok(results)
346}
347
348pub fn print_framework_comparison_results(results: &[FrameworkComparisonResult]) {
350 if results.is_empty() {
351 println!("No benchmark results to display");
352 return;
353 }
354
355 println!("\n{:-<120}", "");
356 println!(
357 "| {:^12} | {:^8} | {:^15} | {:^15} | {:^15} | {:^15} | {:^15} |",
358 "Operation",
359 "Size",
360 "TenfloweRS (μs)",
361 "NumPy (μs)",
362 "PyTorch (μs)",
363 "TensorFlow (μs)",
364 "Relative Perf."
365 );
366 println!("{:-<120}", "");
367
368 for result in results {
369 let tf_us = result.tenflowers_time.as_micros();
370 let numpy_us = result
371 .framework_times
372 .get("numpy")
373 .map(|t| t.as_micros())
374 .unwrap_or(0);
375 let pytorch_us = result
376 .framework_times
377 .get("pytorch")
378 .map(|t| t.as_micros())
379 .unwrap_or(0);
380 let tensorflow_us = result
381 .framework_times
382 .get("tensorflow")
383 .map(|t| t.as_micros())
384 .unwrap_or(0);
385
386 let avg_relative = if !result.relative_performance.is_empty() {
388 result.relative_performance.values().sum::<f64>()
389 / result.relative_performance.len() as f64
390 } else {
391 0.0
392 };
393
394 println!(
395 "| {:^12} | {:^8} | {:^15} | {:^15} | {:^15} | {:^15} | {:^15.2} |",
396 result.operation,
397 result.size,
398 if tf_us > 0 {
399 tf_us.to_string()
400 } else {
401 "-".to_string()
402 },
403 if numpy_us > 0 {
404 numpy_us.to_string()
405 } else {
406 "-".to_string()
407 },
408 if pytorch_us > 0 {
409 pytorch_us.to_string()
410 } else {
411 "-".to_string()
412 },
413 if tensorflow_us > 0 {
414 tensorflow_us.to_string()
415 } else {
416 "-".to_string()
417 },
418 avg_relative
419 );
420 }
421 println!("{:-<120}", "");
422
423 let all_relative_perfs: Vec<f64> = results
425 .iter()
426 .flat_map(|r| r.relative_performance.values())
427 .cloned()
428 .collect();
429
430 if !all_relative_perfs.is_empty() {
431 let avg_relative = all_relative_perfs.iter().sum::<f64>() / all_relative_perfs.len() as f64;
432 let best_relative = all_relative_perfs
433 .iter()
434 .fold(f64::INFINITY, |a, &b| a.min(b));
435 let worst_relative = all_relative_perfs.iter().fold(0.0f64, |a, &b| a.max(b));
436
437 println!("Summary:");
438 println!(" Average relative performance: {avg_relative:.2}x");
439 println!(" Best relative performance: {best_relative:.2}x");
440 println!(" Worst relative performance: {worst_relative:.2}x");
441
442 if avg_relative < 1.0 {
443 println!(
444 " 🚀 TenfloweRS is on average {:.2}x faster than other frameworks",
445 1.0 / avg_relative
446 );
447 } else {
448 println!(
449 " ⚠️ TenfloweRS is on average {avg_relative:.2}x slower than other frameworks"
450 );
451 }
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: probing for the stdlib "sys" module succeeds whenever a
    // Python interpreter is on PATH. The result is only logged (not asserted)
    // because CI environments may not ship Python at all.
    #[test]
    fn test_framework_availability_check() {
        let has_python = check_framework_availability("sys", "python3")
            || check_framework_availability("sys", "python");

        println!("Python available: {}", has_python);
    }

    // The generated scripts must import the correct module and emit the
    // framework-specific spelling of the requested operation.
    #[test]
    fn test_benchmark_script_generation() {
        let script = generate_python_benchmark_script("numpy", "add", 1000, 10);
        assert!(script.contains("import numpy"));
        assert!(script.contains("np.add"));

        let script = generate_python_benchmark_script("pytorch", "mul", 1000, 10);
        assert!(script.contains("import torch"));
        assert!(script.contains("torch.mul"));
    }

    // With TenfloweRS at 1 ms versus 2-3 ms for the other frameworks, every
    // relative-performance ratio (tenflowers_time / framework_time) must be
    // below 1.0, i.e. TenfloweRS reads as faster.
    #[test]
    fn test_framework_comparison_result() {
        let mut framework_times = HashMap::new();
        framework_times.insert("numpy".to_string(), Duration::from_millis(2));
        framework_times.insert("pytorch".to_string(), Duration::from_millis(3));

        let result = FrameworkComparisonResult::new(
            "add".to_string(),
            1000,
            Duration::from_millis(1),
            framework_times,
        );

        assert_eq!(result.operation, "add");
        assert_eq!(result.size, 1000);
        assert!(result.relative_performance.contains_key("numpy"));
        assert!(result.relative_performance.contains_key("pytorch"));

        assert!(result.relative_performance["numpy"] < 1.0);
        assert!(result.relative_performance["pytorch"] < 1.0);
    }
}