mod internal {
    use crate::FxGraph;

    use torsh_core::Result;

    /// Common interface implemented by every code generation backend.
    pub trait CodeGenBackend: std::fmt::Debug {
        fn generate(&self, graph: &FxGraph) -> Result<String>;
        fn file_extension(&self) -> &'static str;
        fn language_name(&self) -> &'static str;
    }

    /// Hardware backend targeted by generated code.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum BackendType {
        CPU,
        CUDA,
        TensorRT,
    }

    /// How aggressively generated code is optimized.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum OptimizationLevel {
        Debug,
        Release,
        Aggressive,
    }

    /// Numeric precision used by generated code.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum Precision {
        Float16,
        Float32,
        Mixed,
    }

    /// Device on which the generated code runs.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum TargetDevice {
        CPU,
        CUDA,
    }

    /// SIMD instruction set available on the target.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum SimdSupport {
        None,
        AVX2,
    }

    /// Memory layout used for tensor data.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum MemoryLayout {
        RowMajor,
        ColumnMajor,
    }

    /// Complete description of a code generation target.
    #[derive(Debug, Clone)]
    pub struct TargetSpecification {
        pub device: TargetDevice,
        pub simd_support: SimdSupport,
        pub optimization_level: OptimizationLevel,
        pub precision: Precision,
        pub memory_layout: MemoryLayout,
    }
}

pub use internal::{
    BackendType, CodeGenBackend, MemoryLayout, OptimizationLevel, Precision, SimdSupport,
    TargetDevice, TargetSpecification,
};

mod enhanced_backends {
    use super::internal::CodeGenBackend;
    use crate::FxGraph;
    use torsh_core::Result;

    /// Python code generator that emits either PyTorch or NumPy code.
    #[derive(Debug, Clone)]
    pub struct PythonCodeGen {
        pub use_torch: bool,
        pub indent_size: usize,
    }

    impl Default for PythonCodeGen {
        fn default() -> Self {
            Self {
                use_torch: true,
                indent_size: 4,
            }
        }
    }

    impl PythonCodeGen {
        pub fn new() -> Self {
            Self::default()
        }

        pub fn with_torch(mut self, use_torch: bool) -> Self {
            self.use_torch = use_torch;
            self
        }
    }

    impl CodeGenBackend for PythonCodeGen {
        fn generate(&self, graph: &FxGraph) -> Result<String> {
            let mut code = String::new();
            code.push_str("# Generated Python code from FX graph\n");

            if self.use_torch {
                code.push_str("import torch\n");
                code.push_str("import torch.nn.functional as F\n");
            } else {
                code.push_str("import numpy as np\n");
            }

            code.push_str("\ndef generated_function(");
            for (i, _input) in graph.inputs().iter().enumerate() {
                if i > 0 {
                    code.push_str(", ");
                }
                code.push_str(&format!("input_{}", i));
            }
            code.push_str("):\n");

            let indent = " ".repeat(self.indent_size);
            for node_index in graph.graph.node_indices() {
                if let Some(crate::Node::Call(op_name, _)) = graph.graph.node_weight(node_index) {
                    code.push_str(&format!("{}# Operation: {}\n", indent, op_name));
                }
            }

            // Initialize `result` so the emitted stub is valid Python and does
            // not reference an undefined name.
            code.push_str(&format!("{}result = None\n", indent));
            code.push_str(&format!("{}return result\n", indent));
            Ok(code)
        }

        fn file_extension(&self) -> &'static str {
            "py"
        }

        fn language_name(&self) -> &'static str {
            if self.use_torch {
                "PyTorch"
            } else {
                "NumPy"
            }
        }
    }

    /// C++ code generator that emits either LibTorch or plain C++ code.
    #[derive(Debug, Clone)]
    pub struct CppCodeGen {
        pub use_libtorch: bool,
        pub indent_size: usize,
    }

    impl Default for CppCodeGen {
        fn default() -> Self {
            Self {
                use_libtorch: true,
                indent_size: 2,
            }
        }
    }

    impl CppCodeGen {
        pub fn new() -> Self {
            Self::default()
        }

        pub fn with_libtorch(mut self, use_libtorch: bool) -> Self {
            self.use_libtorch = use_libtorch;
            self
        }
    }

    impl CodeGenBackend for CppCodeGen {
        fn generate(&self, graph: &FxGraph) -> Result<String> {
            let mut code = String::new();
            code.push_str("// Generated C++ code from FX graph\n");

            if self.use_libtorch {
                code.push_str("#include <torch/torch.h>\n");
                code.push_str("#include <torch/script.h>\n");
            } else {
                code.push_str("#include <vector>\n");
                code.push_str("#include <cmath>\n");
            }

            code.push_str("\n");
            if self.use_libtorch {
                code.push_str("torch::Tensor generated_function(");
            } else {
                code.push_str("std::vector<float> generated_function(");
            }

            for (i, _) in graph.inputs().iter().enumerate() {
                if i > 0 {
                    code.push_str(", ");
                }
                if self.use_libtorch {
                    code.push_str(&format!("const torch::Tensor& input_{}", i));
                } else {
                    code.push_str(&format!("const std::vector<float>& input_{}", i));
                }
            }
            code.push_str(") {\n");

            let indent = " ".repeat(self.indent_size);
            code.push_str(&format!("{}// Function implementation\n", indent));

            if self.use_libtorch {
                code.push_str(&format!("{}torch::Tensor result;\n", indent));
            } else {
                code.push_str(&format!("{}std::vector<float> result;\n", indent));
            }
            code.push_str(&format!("{}return result;\n", indent));

            code.push_str("}\n");
            Ok(code)
        }

        fn file_extension(&self) -> &'static str {
            "cpp"
        }

        fn language_name(&self) -> &'static str {
            if self.use_libtorch {
                "LibTorch C++"
            } else {
                "Plain C++"
            }
        }
    }
}
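
// For reference, with `use_libtorch = true` and a graph exposing a single
// input, `CppCodeGen::generate` emits a stub along these lines (illustrative;
// the parameter list follows the graph's inputs):
//
//     // Generated C++ code from FX graph
//     #include <torch/torch.h>
//     #include <torch/script.h>
//
//     torch::Tensor generated_function(const torch::Tensor& input_0) {
//       // Function implementation
//       torch::Tensor result;
//       return result;
//     }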

pub use enhanced_backends::{CppCodeGen, PythonCodeGen};

use crate::FxGraph;
use std::collections::HashMap;
use torsh_core::error::Result;

/// Registry of named code generation backends.
#[derive(Debug)]
pub struct CodeGenerator {
    backends: HashMap<String, Box<dyn CodeGenBackend>>,
}

impl Default for CodeGenerator {
    fn default() -> Self {
        Self::new()
    }
}

impl CodeGenerator {
    /// Creates a generator with the built-in Python and C++ backends registered.
    pub fn new() -> Self {
        let mut generator = Self {
            backends: HashMap::new(),
        };

        generator.add_backend("python".to_string(), PythonCodeGen::new());
        generator.add_backend("cpp".to_string(), CppCodeGen::new());
        generator.add_backend("pytorch".to_string(), PythonCodeGen::new().with_torch(true));
        generator.add_backend("numpy".to_string(), PythonCodeGen::new().with_torch(false));
        generator.add_backend(
            "libtorch".to_string(),
            CppCodeGen::new().with_libtorch(true),
        );
        generator.add_backend(
            "plain_cpp".to_string(),
            CppCodeGen::new().with_libtorch(false),
        );

        generator
    }

    /// Registers a backend under the given target name.
    pub fn add_backend<T: CodeGenBackend + 'static>(&mut self, name: String, backend: T) {
        self.backends.insert(name, Box::new(backend));
    }

    /// Returns the names of all registered targets.
    pub fn available_targets(&self) -> Vec<String> {
        self.backends.keys().cloned().collect()
    }

    /// Generates code for the given target, or a placeholder comment if the
    /// target is not registered.
    pub fn generate_code(&self, graph: &FxGraph, target: &str) -> Result<String> {
        if let Some(backend) = self.backends.get(target) {
            backend.generate(graph)
        } else {
            Ok(format!(
                "// Code generation not implemented for target: {}",
                target
            ))
        }
    }
}
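
// Typical usage (sketch; assumes an `FxGraph` value `graph` built elsewhere in
// the crate):
//
//     let generator = CodeGenerator::new();
//     assert!(generator.available_targets().contains(&"python".to_string()));
//     let python_src = generator.generate_code(&graph, "python")?;
//     let cpp_src = generator.generate_code(&graph, "libtorch")?;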

/// Counters describing the cache behavior of the lazy compiler.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    pub hits: usize,
    pub misses: usize,
    pub evictions: usize,
}

/// Source code produced for a specific target, together with its metadata.
#[derive(Debug, Clone)]
pub struct CompiledCode {
    pub source: String,
    pub target: String,
    pub language: String,
    pub file_extension: String,
}

impl CompiledCode {
    pub fn new(source: String, target: String, language: String, file_extension: String) -> Self {
        Self {
            source,
            target,
            language,
            file_extension,
        }
    }
}

/// Compiles graphs on demand and caches the generated code per target.
#[derive(Debug)]
pub struct LazyCompiler {
    generator: CodeGenerator,
    cache: HashMap<String, CompiledCode>,
    stats: CacheStats,
}

impl Default for LazyCompiler {
    fn default() -> Self {
        Self::new()
    }
}

impl LazyCompiler {
    pub fn new() -> Self {
        Self {
            generator: CodeGenerator::new(),
            cache: HashMap::new(),
            stats: CacheStats::default(),
        }
    }

    /// Generates code for `target`, reusing a cached result when available.
    pub fn compile(&mut self, graph: &FxGraph, target: &str) -> Result<CompiledCode> {
        let cache_key = format!("{}-{}", graph.node_count(), target);

        if let Some(cached) = self.cache.get(&cache_key).cloned() {
            self.stats.hits += 1;
            return Ok(cached);
        }

        self.stats.misses += 1;
        let source = self.generator.generate_code(graph, target)?;

        let (language, file_extension) = if let Some(backend) = self.generator.backends.get(target)
        {
            (
                backend.language_name().to_string(),
                backend.file_extension().to_string(),
            )
        } else {
            ("Unknown".to_string(), "txt".to_string())
        };

        let compiled = CompiledCode::new(source, target.to_string(), language, file_extension);
        self.cache.insert(cache_key, compiled.clone());

        Ok(compiled)
    }

    /// Returns the cache hit/miss/eviction counters.
    pub fn cache_stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Empties the cache, counting every removed entry as an eviction.
    pub fn clear_cache(&mut self) {
        // Record the evictions before clearing; doing it afterwards would
        // always add zero because the cache is already empty.
        self.stats.evictions += self.cache.len();
        self.cache.clear();
    }
}
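
// Sketch of the lazy path (assumes the same `FxGraph` value `graph` as above):
//
//     let mut compiler = LazyCompiler::new();
//     let first = compiler.compile(&graph, "numpy")?;   // cache miss
//     let second = compiler.compile(&graph, "numpy")?;  // cache hit
//     assert_eq!(compiler.cache_stats().hits, 1);
//     assert_eq!(first.source, second.source);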

/// Convenience constructor for a [`CodeGenerator`] with the default backends.
pub fn create_code_generator() -> CodeGenerator {
    CodeGenerator::new()
}

/// Generates PyTorch Python code for the given graph.
pub fn generate_python_code(graph: &FxGraph) -> Result<String> {
    use internal::CodeGenBackend;
    let backend = PythonCodeGen::new();
    backend.generate(graph)
}

/// Generates LibTorch C++ code for the given graph.
pub fn generate_cpp_code(graph: &FxGraph) -> Result<String> {
    use internal::CodeGenBackend;
    let backend = CppCodeGen::new();
    backend.generate(graph)
}

/// Generates NumPy-based Python code for the given graph.
pub fn generate_numpy_code(graph: &FxGraph) -> Result<String> {
    use internal::CodeGenBackend;
    let backend = PythonCodeGen::new().with_torch(false);
    backend.generate(graph)
}

/// Generates plain C++ code (no LibTorch dependency) for the given graph.
pub fn generate_plain_cpp_code(graph: &FxGraph) -> Result<String> {
    use internal::CodeGenBackend;
    let backend = CppCodeGen::new().with_libtorch(false);
    backend.generate(graph)
}

/// Target specification for CPU execution with AVX2 and single precision.
pub fn cpu_target_spec() -> TargetSpecification {
    TargetSpecification {
        device: TargetDevice::CPU,
        simd_support: SimdSupport::AVX2,
        optimization_level: OptimizationLevel::Release,
        precision: Precision::Float32,
        memory_layout: MemoryLayout::RowMajor,
    }
}

/// Target specification for CUDA execution with single precision.
pub fn cuda_target_spec() -> TargetSpecification {
    TargetSpecification {
        device: TargetDevice::CUDA,
        simd_support: SimdSupport::None,
        optimization_level: OptimizationLevel::Aggressive,
        precision: Precision::Float32,
        memory_layout: MemoryLayout::RowMajor,
    }
}

/// Target specification for CUDA execution with mixed precision.
pub fn mixed_precision_target_spec() -> TargetSpecification {
    TargetSpecification {
        device: TargetDevice::CUDA,
        simd_support: SimdSupport::None,
        optimization_level: OptimizationLevel::Aggressive,
        precision: Precision::Mixed,
        memory_layout: MemoryLayout::RowMajor,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_code_generator() {
        let generator = create_code_generator();
        let targets = generator.available_targets();

        assert!(targets.len() >= 2);
        assert!(targets.contains(&"python".to_string()));
        assert!(targets.contains(&"cpp".to_string()));
    }

    #[test]
    fn test_convenience_functions() {
        // Placeholder: exercising the generate_* helpers requires constructing
        // an FxGraph; compiling this module already type-checks them.
        assert!(true);
    }

    #[test]
    fn test_target_specifications() {
        let cpu_spec = cpu_target_spec();
        assert_eq!(cpu_spec.device, TargetDevice::CPU);
        assert_eq!(cpu_spec.precision, Precision::Float32);

        let cuda_spec = cuda_target_spec();
        assert_eq!(cuda_spec.device, TargetDevice::CUDA);
        assert_eq!(cuda_spec.optimization_level, OptimizationLevel::Aggressive);

        let mixed_spec = mixed_precision_target_spec();
        assert_eq!(mixed_spec.precision, Precision::Mixed);
    }

    #[test]
    fn test_backend_types() {
        let cpu_backend = BackendType::CPU;
        let cuda_backend = BackendType::CUDA;
        let tensorrt_backend = BackendType::TensorRT;

        assert_ne!(cpu_backend, cuda_backend);
        assert_ne!(cuda_backend, tensorrt_backend);
    }

    #[test]
    fn test_optimization_levels() {
        let debug = OptimizationLevel::Debug;
        let release = OptimizationLevel::Release;
        let aggressive = OptimizationLevel::Aggressive;

        assert_ne!(debug, release);
        assert_ne!(release, aggressive);
    }

    #[test]
    fn test_precision_types() {
        let fp32 = Precision::Float32;
        let fp16 = Precision::Float16;
        let mixed = Precision::Mixed;

        assert_ne!(fp32, fp16);
        assert_ne!(fp16, mixed);
    }
}