use anyhow::Result;
use scirs2_core::ndarray::{Array, IxDyn};
use scirs2_core::random::thread_rng;
use std::collections::HashMap;
use tensorlogic_infer::TlAutodiff;
use tensorlogic_ir::EinsumGraph;
use tensorlogic_scirs_backend::{
    DeviceType, ParallelScirs2Exec, ProfiledScirs2Exec, Scirs2Exec, Scirs2Tensor,
};

use crate::output::{print_info, print_success};

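/// Identifier for a tensor within an `EinsumGraph`.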
type TensorId = usize;

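/// Tensor execution backend.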
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Backend {
    #[default]
    SciRS2CPU,
    #[cfg(feature = "simd")]
    SciRS2SIMD,
    #[cfg(feature = "gpu")]
    SciRS2GPU,
    Parallel,
    Profiled,
}

impl Backend {
    pub fn name(&self) -> &'static str {
        match self {
            Backend::SciRS2CPU => "SciRS2 CPU",
            #[cfg(feature = "simd")]
            Backend::SciRS2SIMD => "SciRS2 SIMD",
            #[cfg(feature = "gpu")]
            Backend::SciRS2GPU => "SciRS2 GPU",
            Backend::Parallel => "SciRS2 Parallel",
            Backend::Profiled => "SciRS2 Profiled",
        }
    }

    pub fn is_available(&self) -> bool {
        match self {
            Backend::SciRS2CPU => true,
            #[cfg(feature = "simd")]
            Backend::SciRS2SIMD => true,
            #[cfg(feature = "gpu")]
            Backend::SciRS2GPU => true,
            Backend::Parallel => true,
            Backend::Profiled => true,
        }
    }

    pub fn available_backends() -> Vec<Backend> {
        let backends = vec![Backend::SciRS2CPU, Backend::Parallel, Backend::Profiled];

        backends
    }

    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "cpu" | "scirs2-cpu" => Ok(Backend::SciRS2CPU),
            #[cfg(feature = "simd")]
            "simd" | "scirs2-simd" => Ok(Backend::SciRS2SIMD),
            #[cfg(feature = "gpu")]
            "gpu" | "scirs2-gpu" => Ok(Backend::SciRS2GPU),
            "parallel" => Ok(Backend::Parallel),
            "profiled" => Ok(Backend::Profiled),
            _ => anyhow::bail!("Unknown backend: {}", s),
        }
    }
}

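/// Configuration for how a graph is executed.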
#[derive(Debug, Clone)]
pub struct ExecutionConfig {
    /// Backend used to run the graph.
    pub backend: Backend,
    /// Target device (not yet consulted by the executors).
    #[allow(dead_code)]
    pub device: DeviceType,
    /// Print performance metrics after execution.
    pub show_metrics: bool,
    /// Print shapes of intermediate tensors.
    pub show_intermediates: bool,
    /// Validate tensor shapes before execution (not yet consulted).
    #[allow(dead_code)]
    pub validate_shapes: bool,
    /// Print trace information during execution.
    pub trace: bool,
}

impl Default for ExecutionConfig {
    fn default() -> Self {
        Self {
            backend: Backend::default(),
            device: DeviceType::Cpu,
            show_metrics: false,
            show_intermediates: false,
            validate_shapes: true,
            trace: false,
        }
    }
}

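/// Output and metrics from a completed graph execution.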
#[derive(Debug)]
pub struct ExecutionResult {
    /// Final output tensor.
    pub output: Scirs2Tensor,
    /// Wall-clock execution time in milliseconds.
    pub execution_time_ms: f64,
    /// Intermediate tensors captured during execution, keyed by tensor id.
    pub intermediates: HashMap<TensorId, Scirs2Tensor>,
    /// Backend that produced this result.
    pub backend: Backend,
    /// Estimated memory used by the output and intermediates.
    pub memory_bytes: usize,
}

impl ExecutionResult {
    /// Print a human-readable summary of the execution.
    pub fn print_summary(&self, config: &ExecutionConfig) {
        print_success(&format!(
            "Execution completed with {} in {:.3} ms",
            self.backend.name(),
            self.execution_time_ms
        ));

        println!("\nOutput shape: {:?}", self.output.shape());
        println!("Output dtype: f64");

        if config.show_metrics {
            println!("\nPerformance Metrics:");
            println!(" Execution time: {:.3} ms", self.execution_time_ms);
            println!(" Memory used: {} bytes", self.memory_bytes);
            println!(
                " Throughput: {:.2} GFLOPS",
                // GFLOPS = FLOPs / ms / 1e6 (1 ms = 1e-3 s; 1 GFLOP = 1e9 FLOPs).
                self.estimate_flops() / self.execution_time_ms / 1_000_000.0
            );
        }

        if config.show_intermediates && !self.intermediates.is_empty() {
            println!("\nIntermediate Tensors:");
            for (id, tensor) in &self.intermediates {
                println!(" Tensor {}: shape {:?}", id, tensor.shape());
            }
        }

        println!("\nOutput preview (first 10 elements):");
        print_tensor_preview(&self.output, 10);
    }

    /// Rough FLOP estimate: assume ~2 floating-point ops per output element.
    fn estimate_flops(&self) -> f64 {
        let elements = self.output.len() as f64;
        elements * 2.0
    }
}

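/// Print up to `max_elements` values of a tensor as a one-line preview.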
fn print_tensor_preview(tensor: &Scirs2Tensor, max_elements: usize) {
    let flat = tensor.as_slice().unwrap_or(&[]);
    let preview: Vec<String> = flat
        .iter()
        .take(max_elements)
        .map(|v| format!("{:.4}", v))
        .collect();

    println!(
        " [{}{}]",
        preview.join(", "),
        if flat.len() > max_elements {
            ", ..."
        } else {
            ""
        }
    );
}

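/// Executes an `EinsumGraph` on the backend selected in the config.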
pub struct CliExecutor {
    config: ExecutionConfig,
}

impl CliExecutor {
    /// Create an executor, failing if the requested backend is unavailable.
    pub fn new(config: ExecutionConfig) -> Result<Self> {
        if !config.backend.is_available() {
            anyhow::bail!("Backend {} is not available", config.backend.name());
        }

        Ok(Self { config })
    }

    /// Execute the graph, returning the output plus timing and memory metrics.
    pub fn execute(&self, graph: &EinsumGraph) -> Result<ExecutionResult> {
        let start = std::time::Instant::now();

        let inputs = self.generate_inputs(graph)?;

        if self.config.trace {
            print_info("Generated input tensors");
            for (id, tensor) in &inputs {
                println!(" Tensor {}: shape {:?}", id, tensor.shape());
            }
        }

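        // Note: `inputs` is currently only used for the trace output above;
        // each backend's forward() runs the graph without taking them.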
        let (output, intermediates) = match self.config.backend {
            Backend::SciRS2CPU => self.execute_cpu(graph)?,
            #[cfg(feature = "simd")]
            Backend::SciRS2SIMD => self.execute_simd(graph)?,
            #[cfg(feature = "gpu")]
            Backend::SciRS2GPU => self.execute_gpu(graph)?,
            Backend::Parallel => self.execute_parallel(graph)?,
            Backend::Profiled => self.execute_profiled(graph)?,
        };

        let execution_time_ms = start.elapsed().as_secs_f64() * 1000.0;

        let memory_bytes = self.estimate_memory(&output, &intermediates);

        Ok(ExecutionResult {
            output,
            execution_time_ms,
            intermediates,
            backend: self.config.backend,
            memory_bytes,
        })
    }

    fn execute_cpu(
        &self,
        graph: &EinsumGraph,
    ) -> Result<(Scirs2Tensor, HashMap<TensorId, Scirs2Tensor>)> {
        let mut executor = Scirs2Exec::new();
        let output = executor
            .forward(graph)
            .map_err(|e| anyhow::anyhow!("Execution failed: {:?}", e))?;

        Ok((output, HashMap::new()))
    }

    #[cfg(feature = "simd")]
    fn execute_simd(
        &self,
        graph: &EinsumGraph,
    ) -> Result<(Scirs2Tensor, HashMap<TensorId, Scirs2Tensor>)> {
        // No dedicated SIMD path yet; fall back to the standard CPU executor.
        self.execute_cpu(graph)
    }

    #[cfg(feature = "gpu")]
    fn execute_gpu(
        &self,
        graph: &EinsumGraph,
    ) -> Result<(Scirs2Tensor, HashMap<TensorId, Scirs2Tensor>)> {
        // No dedicated GPU path yet; fall back to the standard CPU executor.
        self.execute_cpu(graph)
    }

    fn execute_parallel(
        &self,
        graph: &EinsumGraph,
    ) -> Result<(Scirs2Tensor, HashMap<TensorId, Scirs2Tensor>)> {
        let mut executor = ParallelScirs2Exec::new();
        let output = executor
            .forward(graph)
            .map_err(|e| anyhow::anyhow!("Parallel execution failed: {:?}", e))?;

        Ok((output, HashMap::new()))
    }

    fn execute_profiled(
        &self,
        graph: &EinsumGraph,
    ) -> Result<(Scirs2Tensor, HashMap<TensorId, Scirs2Tensor>)> {
        let mut executor = ProfiledScirs2Exec::new();
        let output = executor
            .forward(graph)
            .map_err(|e| anyhow::anyhow!("Profiled execution failed: {:?}", e))?;

        if self.config.show_metrics {
            println!("\nProfiling Results:");
            println!(" (Profiling data collection not yet implemented)");
        }

        Ok((output, HashMap::new()))
    }

    /// Generate a random 1-D input tensor for each input of the graph.
    fn generate_inputs(&self, graph: &EinsumGraph) -> Result<HashMap<TensorId, Scirs2Tensor>> {
        let mut inputs = HashMap::new();
        let mut rng = thread_rng();

        // Every input defaults to a 1-D tensor of 100 uniform values in [0, 1).
        let default_dim = 100;

        for &input_idx in &graph.inputs {
            let data: Vec<f64> = (0..default_dim)
                .map(|_| rng.random_range(0.0..1.0))
                .collect();

            let tensor = Array::from_shape_vec(IxDyn(&[default_dim]), data)
                .map_err(|e| anyhow::anyhow!("Failed to create tensor: {}", e))?;

            inputs.insert(input_idx, tensor);
        }

        Ok(inputs)
    }

    /// Estimate total memory held by the output and intermediate tensors.
    fn estimate_memory(
        &self,
        output: &Scirs2Tensor,
        intermediates: &HashMap<TensorId, Scirs2Tensor>,
    ) -> usize {
        let mut total = output.len() * std::mem::size_of::<f64>();

        for tensor in intermediates.values() {
            total += tensor.len() * std::mem::size_of::<f64>();
        }

        total
    }
}

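/// Print every known backend with its availability and a short description.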
pub fn list_backends() {
    println!("Available Backends:");
    println!();

    for backend in Backend::available_backends() {
        let status = if backend.is_available() { "✓" } else { "✗" };
        println!(
            " {} {} - {}",
            status,
            backend.name(),
            backend_description(backend)
        );
    }

    println!();
    println!("Backend Capabilities:");
    println!(" CPU: ✓");
    println!(" SIMD: {}", if cfg!(feature = "simd") { "✓" } else { "✗" });
    println!(" GPU: {}", if cfg!(feature = "gpu") { "✓" } else { "✗" });
    println!(" Parallel: ✓ (Rayon)");
}

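/// One-line description for each backend.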
fn backend_description(backend: Backend) -> &'static str {
    match backend {
        Backend::SciRS2CPU => "Standard CPU execution",
        #[cfg(feature = "simd")]
        Backend::SciRS2SIMD => "Vectorized SIMD acceleration",
        #[cfg(feature = "gpu")]
        Backend::SciRS2GPU => "GPU-accelerated execution",
        Backend::Parallel => "Multi-threaded parallel execution",
        Backend::Profiled => "Execution with performance profiling",
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backend_availability() {
        let cpu = Backend::SciRS2CPU;
        assert!(cpu.is_available());
        assert_eq!(cpu.name(), "SciRS2 CPU");
    }

    #[test]
    fn test_backend_from_str() {
        assert!(matches!(
            Backend::from_str("cpu").unwrap(),
            Backend::SciRS2CPU
        ));
        assert!(matches!(
            Backend::from_str("parallel").unwrap(),
            Backend::Parallel
        ));
        assert!(Backend::from_str("invalid").is_err());
    }

    #[test]
    fn test_default_backend() {
        let backend = Backend::default();
        assert!(backend.is_available());
    }

    #[test]
    fn test_execution_config_default() {
        let config = ExecutionConfig::default();
        assert!(config.backend.is_available());
        assert_eq!(config.device, DeviceType::Cpu);
    }

    #[test]
    fn test_available_backends() {
        let backends = Backend::available_backends();
        assert!(!backends.is_empty());
        assert!(backends.contains(&Backend::SciRS2CPU));
    }
}