1use serde::{Deserialize, Serialize};
8use sklears_core::error::{Result as SklResult, SklearsError};
9use std::collections::{BTreeMap, HashMap};
10
11use super::workflow_definitions::{ExecutionMode, StepDefinition, WorkflowDefinition};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum CodeLanguage {
16 Rust,
18 Python,
20 Json,
22 Yaml,
24 JavaScript,
26 Cpp,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub enum FileFormat {
33 Json,
35 Yaml,
37 Toml,
39 Binary,
41}
42
43#[derive(Debug, Clone)]
45pub struct CodeGenerationConfig {
46 pub language: CodeLanguage,
48 pub style: CodeStyle,
50 pub optimization_level: OptimizationLevel,
52 pub include_comments: bool,
54 pub include_type_annotations: bool,
56 pub deployment_target: DeploymentTarget,
58 pub custom_templates: HashMap<String, String>,
60}
61
62#[derive(Debug, Clone)]
64pub struct CodeStyle {
65 pub indent_size: usize,
67 pub use_tabs: bool,
69 pub max_line_length: usize,
71 pub naming_convention: NamingConvention,
73 pub include_error_handling: bool,
75}
76
/// Identifier naming convention used when turning workflow/step names into
/// identifiers.
///
/// Derives `PartialEq`, `Eq` and `Hash` so conventions can be compared
/// (e.g. in configuration checks) and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NamingConvention {
    /// Lower-case words joined by `_` (`"Test Workflow"` → `"test_workflow"`).
    SnakeCase,
    /// First word lower-case, later words capitalized (`"test_workflow"` → `"testWorkflow"`).
    CamelCase,
    /// Every word capitalized (`"test_workflow"` → `"TestWorkflow"`).
    PascalCase,
    /// Lower-case words joined by `-` (`"test_workflow"` → `"test-workflow"`).
    KebabCase,
}
88
/// Requested optimization level for generated code.
///
/// NOTE(review): `CodeGenerator` in this module does not read this value; the
/// exact meaning of each level is not established here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No optimization requested.
    None,
    /// Basic optimization (the default in `CodeGenerationConfig::default`).
    Basic,
    /// Aggressive optimization.
    Aggressive,
    /// Production-oriented optimization.
    Production,
}
101
/// Environment the generated code is intended to be deployed to.
///
/// NOTE(review): `CodeGenerator` in this module does not read this value.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeploymentTarget {
    /// Local machine (the default in `CodeGenerationConfig::default`).
    Local,
    /// Docker container.
    Docker,
    /// Kubernetes cluster.
    Kubernetes,
    /// Serverless cloud function.
    CloudFunction,
    /// WebAssembly module.
    WebAssembly,
    /// Embedded device.
    Embedded,
}
118
119#[derive(Debug)]
121pub struct CodeGenerator {
122 config: CodeGenerationConfig,
124 templates: TemplateEngine,
126 stats: GenerationStatistics,
128}
129
130#[derive(Debug)]
132pub struct TemplateEngine {
133 templates: HashMap<CodeLanguage, LanguageTemplate>,
135 custom_overrides: HashMap<String, String>,
137}
138
/// Text fragments making up a code template for one language.
///
/// Every field is a `String`, so the template is `Default` (all fragments
/// empty) and comparable with `==`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LanguageTemplate {
    /// Text placed at the top of a generated file.
    pub header: String,
    /// Import/include section.
    pub imports: String,
    /// Function definition fragment.
    pub function_def: String,
    /// Fragment for executing a single step.
    pub step_execution: String,
    /// Fragment connecting steps together.
    pub connection: String,
    /// Text placed at the end of a generated file.
    pub footer: String,
}
155
/// Metrics describing a single code-generation run.
///
/// `Default` yields all counters at zero and a zero duration.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GenerationStatistics {
    /// Number of lines in the generated source.
    pub total_lines: usize,
    /// Non-blank lines that are not line comments.
    pub code_lines: usize,
    /// Line-comment lines.
    pub comment_lines: usize,
    /// Detected function definitions.
    pub function_count: usize,
    /// Detected import/include statements.
    pub import_count: usize,
    /// Wall-clock time spent generating the source.
    pub generation_time: std::time::Duration,
}
166
167#[derive(Debug, Clone)]
169pub struct GeneratedCode {
170 pub source_code: String,
172 pub language: CodeLanguage,
174 pub dependencies: Vec<String>,
176 pub instructions: String,
178 pub statistics: GenerationStatistics,
180}
181
182impl CodeGenerator {
183 #[must_use]
185 pub fn new(config: CodeGenerationConfig) -> Self {
186 Self {
187 config,
188 templates: TemplateEngine::new(),
189 stats: GenerationStatistics::new(),
190 }
191 }
192
193 #[must_use]
195 pub fn language(&self) -> &CodeLanguage {
196 &self.config.language
197 }
198
199 #[must_use]
201 pub fn include_comments(&self) -> bool {
202 self.config.include_comments
203 }
204
205 pub fn generate_code(&mut self, workflow: &WorkflowDefinition) -> SklResult<GeneratedCode> {
207 let generation_start = std::time::Instant::now();
208
209 let source_code = match self.config.language {
210 CodeLanguage::Rust => self.generate_rust_code(workflow)?,
211 CodeLanguage::Python => self.generate_python_code(workflow)?,
212 CodeLanguage::Json => self.generate_json_code(workflow)?,
213 CodeLanguage::Yaml => self.generate_yaml_code(workflow)?,
214 CodeLanguage::JavaScript => self.generate_javascript_code(workflow)?,
215 CodeLanguage::Cpp => self.generate_cpp_code(workflow)?,
216 };
217
218 self.stats.generation_time = generation_start.elapsed();
219 self.update_statistics(&source_code);
220
221 Ok(GeneratedCode {
222 source_code,
223 language: self.config.language.clone(),
224 dependencies: self.get_required_dependencies(workflow),
225 instructions: self.generate_instructions(),
226 statistics: self.stats.clone(),
227 })
228 }
229
230 fn generate_rust_code(&self, workflow: &WorkflowDefinition) -> SklResult<String> {
232 let mut code = String::new();
233
234 code.push_str(&self.generate_rust_header(workflow));
236 code.push_str(&self.generate_rust_imports(workflow));
237
238 code.push_str(&self.generate_rust_main_function(workflow));
240
241 for step in &workflow.steps {
243 code.push_str(&self.generate_rust_step_function(step, workflow)?);
244 }
245
246 Ok(code)
247 }
248
249 fn generate_rust_header(&self, workflow: &WorkflowDefinition) -> String {
251 let mut header = String::new();
252
253 if self.config.include_comments {
254 header.push_str(&format!(
255 "//! Generated Rust code for workflow: {}\n",
256 workflow.metadata.name
257 ));
258 header.push_str(&format!("//! Version: {}\n", workflow.metadata.version));
259 if let Some(description) = &workflow.metadata.description {
260 header.push_str(&format!("//! Description: {description}\n"));
261 }
262 header.push_str("//!\n");
263 header.push_str(
264 "//! This code was automatically generated from a workflow definition.\n",
265 );
266 header.push_str("//! Do not edit this file directly.\n\n");
267 }
268
269 header
270 }
271
272 fn generate_rust_imports(&self, _workflow: &WorkflowDefinition) -> String {
274 let mut imports = String::new();
275
276 imports.push_str("use sklears_core::{\n");
277 imports.push_str(" error::{Result as SklResult, SklearsError},\n");
278 imports.push_str(" types::Float,\n");
279 imports.push_str("};\n");
280 imports.push_str("use scirs2_core::ndarray::{Array1, Array2};\n");
281 imports.push_str("use std::collections::HashMap;\n");
282 imports.push_str("use serde::{Serialize, Deserialize};\n\n");
283
284 imports
285 }
286
287 fn generate_rust_main_function(&self, workflow: &WorkflowDefinition) -> String {
289 let function_name = self.convert_name(
290 &workflow.metadata.name,
291 &self.config.style.naming_convention,
292 );
293 let mut code = String::new();
294
295 if self.config.include_comments {
296 code.push_str(&format!(
297 "/// Execute the {} workflow\n",
298 workflow.metadata.name
299 ));
300 }
301
302 code.push_str(&format!(
303 "pub fn {function_name}() -> SklResult<HashMap<String, Array2<Float>>> {{\n"
304 ));
305
306 match workflow.execution.mode {
308 ExecutionMode::Parallel => {
309 code.push_str(" // Parallel execution mode\n");
310 }
311 ExecutionMode::Sequential => {
312 code.push_str(" // Sequential execution mode\n");
313 }
314 _ => {
315 code.push_str(" // Default execution mode\n");
316 }
317 }
318
319 code.push_str(" let mut results = HashMap::new();\n\n");
320
321 for step in &workflow.steps {
323 let step_func_name = self.convert_name(&step.id, &self.config.style.naming_convention);
324 code.push_str(&format!(
325 " let {step_func_name} = {step_func_name}()?;\n"
326 ));
327 }
328
329 code.push_str("\n Ok(results)\n");
330 code.push_str("}\n\n");
331
332 code
333 }
334
335 fn generate_rust_step_function(
337 &self,
338 step: &StepDefinition,
339 _workflow: &WorkflowDefinition,
340 ) -> SklResult<String> {
341 let function_name = self.convert_name(&step.id, &self.config.style.naming_convention);
342 let mut code = String::new();
343
344 if self.config.include_comments {
345 code.push_str(&format!(
346 "/// Execute step: {} ({})\n",
347 step.id, step.algorithm
348 ));
349 if let Some(description) = &step.description {
350 code.push_str(&format!("/// {description}\n"));
351 }
352 }
353
354 code.push_str(&format!(
355 "fn {function_name}() -> SklResult<Array2<Float>> {{\n"
356 ));
357
358 match step.algorithm.as_str() {
360 "StandardScaler" => {
361 code.push_str(" // StandardScaler implementation\n");
362 code.push_str(" // TODO: Implement actual scaling logic\n");
363 code.push_str(" let scaled_data = Array2::zeros((0, 0));\n");
364 code.push_str(" Ok(scaled_data)\n");
365 }
366 "LinearRegression" => {
367 code.push_str(" // LinearRegression implementation\n");
368 code.push_str(" // TODO: Implement actual regression logic\n");
369 code.push_str(" let predictions = Array2::zeros((0, 0));\n");
370 code.push_str(" Ok(predictions)\n");
371 }
372 _ => {
373 code.push_str(&format!(" // {} implementation\n", step.algorithm));
374 code.push_str(" // TODO: Implement component logic\n");
375 code.push_str(" let result = Array2::zeros((0, 0));\n");
376 code.push_str(" Ok(result)\n");
377 }
378 }
379
380 code.push_str("}\n\n");
381
382 Ok(code)
383 }
384
385 fn generate_python_code(&self, workflow: &WorkflowDefinition) -> SklResult<String> {
387 let mut code = String::new();
388
389 code.push_str(&self.generate_python_header(workflow));
391 code.push_str(&self.generate_python_imports());
392
393 code.push_str(&self.generate_python_main_function(workflow));
395
396 for step in &workflow.steps {
398 code.push_str(&self.generate_python_step_function(step)?);
399 }
400
401 Ok(code)
402 }
403
404 fn generate_python_header(&self, workflow: &WorkflowDefinition) -> String {
406 let mut header = String::new();
407
408 if self.config.include_comments {
409 header.push_str(&format!(
410 "\"\"\"Generated Python code for workflow: {}\n",
411 workflow.metadata.name
412 ));
413 header.push_str(&format!("Version: {}\n", workflow.metadata.version));
414 if let Some(description) = &workflow.metadata.description {
415 header.push_str(&format!("Description: {description}\n"));
416 }
417 header
418 .push_str("\nThis code was automatically generated from a workflow definition.\n");
419 header.push_str("Do not edit this file directly.\n");
420 header.push_str("\"\"\"\n\n");
421 }
422
423 header
424 }
425
426 fn generate_python_imports(&self) -> String {
428 let mut imports = String::new();
429
430 imports.push_str("import numpy as np\n");
431 imports.push_str("from sklearn.preprocessing import StandardScaler\n");
432 imports.push_str("from sklearn.linear_model import LinearRegression\n");
433 imports.push_str("from typing import Dict, Any, Optional\n\n");
434
435 imports
436 }
437
438 fn generate_python_main_function(&self, workflow: &WorkflowDefinition) -> String {
440 let function_name =
441 self.convert_name(&workflow.metadata.name, &NamingConvention::SnakeCase);
442 let mut code = String::new();
443
444 if self.config.include_comments {
445 code.push_str(&format!("def {function_name}() -> Dict[str, Any]:\n"));
446 code.push_str(&format!(
447 " \"\"\"Execute the {} workflow.\"\"\"\n",
448 workflow.metadata.name
449 ));
450 } else {
451 code.push_str(&format!("def {function_name}():\n"));
452 }
453
454 code.push_str(" results = {}\n\n");
455
456 for step in &workflow.steps {
458 let step_func_name = self.convert_name(&step.id, &NamingConvention::SnakeCase);
459 code.push_str(&format!(
460 " results['{}'] = {}()\n",
461 step.id, step_func_name
462 ));
463 }
464
465 code.push_str("\n return results\n\n");
466
467 code
468 }
469
470 fn generate_python_step_function(&self, step: &StepDefinition) -> SklResult<String> {
472 let function_name = self.convert_name(&step.id, &NamingConvention::SnakeCase);
473 let mut code = String::new();
474
475 if self.config.include_comments {
476 code.push_str(&format!("def {function_name}():\n"));
477 code.push_str(&format!(
478 " \"\"\"Execute step: {} ({}).\"\"\"\n",
479 step.id, step.algorithm
480 ));
481 } else {
482 code.push_str(&format!("def {function_name}():\n"));
483 }
484
485 match step.algorithm.as_str() {
487 "StandardScaler" => {
488 code.push_str(" scaler = StandardScaler()\n");
489 code.push_str(" # TODO: Implement actual scaling logic\n");
490 code.push_str(" return scaler\n");
491 }
492 "LinearRegression" => {
493 code.push_str(" model = LinearRegression()\n");
494 code.push_str(" # TODO: Implement actual training logic\n");
495 code.push_str(" return model\n");
496 }
497 _ => {
498 code.push_str(&format!(" # {} implementation\n", step.algorithm));
499 code.push_str(" # TODO: Implement component logic\n");
500 code.push_str(" return None\n");
501 }
502 }
503
504 code.push('\n');
505
506 Ok(code)
507 }
508
509 fn generate_json_code(&self, workflow: &WorkflowDefinition) -> SklResult<String> {
511 match serde_json::to_string_pretty(workflow) {
512 Ok(json) => Ok(json),
513 Err(e) => Err(SklearsError::InvalidInput(format!(
514 "JSON serialization failed: {e}"
515 ))),
516 }
517 }
518
519 fn generate_yaml_code(&self, workflow: &WorkflowDefinition) -> SklResult<String> {
521 match serde_yaml::to_string(workflow) {
522 Ok(yaml) => Ok(yaml),
523 Err(e) => Err(SklearsError::InvalidInput(format!(
524 "YAML serialization failed: {e}"
525 ))),
526 }
527 }
528
529 fn generate_javascript_code(&self, workflow: &WorkflowDefinition) -> SklResult<String> {
531 let mut code = String::new();
532
533 code.push_str(&format!(
534 "// Generated JavaScript code for workflow: {}\n\n",
535 workflow.metadata.name
536 ));
537
538 code.push_str(&format!(
539 "function {}() {{\n",
540 self.convert_name(&workflow.metadata.name, &NamingConvention::CamelCase)
541 ));
542
543 code.push_str(" const results = {};\n\n");
544
545 for step in &workflow.steps {
546 code.push_str(&format!(
547 " results['{}'] = {}();\n",
548 step.id,
549 self.convert_name(&step.id, &NamingConvention::CamelCase)
550 ));
551 }
552
553 code.push_str("\n return results;\n");
554 code.push_str("}\n\n");
555
556 for step in &workflow.steps {
558 code.push_str(&format!(
559 "function {}() {{\n",
560 self.convert_name(&step.id, &NamingConvention::CamelCase)
561 ));
562 code.push_str(&format!(" // {} implementation\n", step.algorithm));
563 code.push_str(" // TODO: Implement component logic\n");
564 code.push_str(" return null;\n");
565 code.push_str("}\n\n");
566 }
567
568 Ok(code)
569 }
570
571 fn generate_cpp_code(&self, workflow: &WorkflowDefinition) -> SklResult<String> {
573 let mut code = String::new();
574
575 code.push_str(&format!(
577 "// Generated C++ code for workflow: {}\n\n",
578 workflow.metadata.name
579 ));
580
581 code.push_str("#include <iostream>\n");
583 code.push_str("#include <vector>\n");
584 code.push_str("#include <map>\n");
585 code.push_str("#include <string>\n\n");
586
587 code.push_str(&format!(
589 "std::map<std::string, void*> {}() {{\n",
590 self.convert_name(&workflow.metadata.name, &NamingConvention::SnakeCase)
591 ));
592
593 code.push_str(" std::map<std::string, void*> results;\n\n");
594
595 for step in &workflow.steps {
596 code.push_str(&format!(
597 " results[\"{}\"] = {}();\n",
598 step.id,
599 self.convert_name(&step.id, &NamingConvention::SnakeCase)
600 ));
601 }
602
603 code.push_str("\n return results;\n");
604 code.push_str("}\n\n");
605
606 for step in &workflow.steps {
608 code.push_str(&format!(
609 "void* {}() {{\n",
610 self.convert_name(&step.id, &NamingConvention::SnakeCase)
611 ));
612 code.push_str(&format!(" // {} implementation\n", step.algorithm));
613 code.push_str(" // TODO: Implement component logic\n");
614 code.push_str(" return nullptr;\n");
615 code.push_str("}\n\n");
616 }
617
618 Ok(code)
619 }
620
621 #[must_use]
623 pub fn convert_name(&self, name: &str, convention: &NamingConvention) -> String {
624 match convention {
625 NamingConvention::SnakeCase => name
626 .split(|c: char| c.is_whitespace() || c == '-' || c == '_')
627 .filter(|segment| !segment.is_empty())
628 .map(|segment| segment.to_lowercase())
629 .collect::<Vec<_>>()
630 .join("_"),
631 NamingConvention::CamelCase => {
632 let mut result = String::new();
633 let mut capitalize_next = false;
634 for (i, c) in name.chars().enumerate() {
635 if c.is_whitespace() || c == '_' || c == '-' {
636 capitalize_next = true;
637 } else if i == 0 {
638 result.push(c.to_lowercase().next().unwrap());
639 } else if capitalize_next {
640 result.push(c.to_uppercase().next().unwrap());
641 capitalize_next = false;
642 } else {
643 result.push(c.to_lowercase().next().unwrap());
644 }
645 }
646 result
647 }
648 NamingConvention::PascalCase => {
649 let mut result = String::new();
650 let mut capitalize_next = true;
651 for c in name.chars() {
652 if c.is_whitespace() || c == '_' || c == '-' {
653 capitalize_next = true;
654 } else if capitalize_next {
655 result.push(c.to_uppercase().next().unwrap());
656 capitalize_next = false;
657 } else {
658 result.push(c.to_lowercase().next().unwrap());
659 }
660 }
661 result
662 }
663 NamingConvention::KebabCase => name
664 .to_lowercase()
665 .chars()
666 .map(|c| {
667 if c.is_whitespace() || c == '_' {
668 '-'
669 } else {
670 c
671 }
672 })
673 .collect::<String>()
674 .replace("--", "-"),
675 }
676 }
677
678 fn get_required_dependencies(&self, workflow: &WorkflowDefinition) -> Vec<String> {
680 let mut dependencies = Vec::new();
681
682 match self.config.language {
683 CodeLanguage::Rust => {
684 dependencies.push("sklears-core".to_string());
685 dependencies.push("scirs2-autograd".to_string());
686 dependencies.push("serde".to_string());
687 }
688 CodeLanguage::Python => {
689 dependencies.push("numpy".to_string());
690 dependencies.push("scikit-learn".to_string());
691 }
692 _ => {}
693 }
694
695 for step in &workflow.steps {
697 match step.algorithm.as_str() {
698 "StandardScaler" | "LinearRegression" => {
699 if matches!(self.config.language, CodeLanguage::Python)
700 && !dependencies.contains(&"scikit-learn".to_string())
701 {
702 dependencies.push("scikit-learn".to_string());
703 }
704 }
705 _ => {}
706 }
707 }
708
709 dependencies
710 }
711
712 fn generate_instructions(&self) -> String {
714 match self.config.language {
715 CodeLanguage::Rust => "To compile and run:\n\
716 1. Add dependencies to Cargo.toml\n\
717 2. Run 'cargo build'\n\
718 3. Run 'cargo run'"
719 .to_string(),
720 CodeLanguage::Python => "To run:\n\
721 1. Install dependencies: pip install numpy scikit-learn\n\
722 2. Run: python workflow.py"
723 .to_string(),
724 CodeLanguage::JavaScript => "To run:\n\
725 1. Install Node.js\n\
726 2. Run: node workflow.js"
727 .to_string(),
728 CodeLanguage::Cpp => "To compile and run:\n\
729 1. Compile: g++ -o workflow workflow.cpp\n\
730 2. Run: ./workflow"
731 .to_string(),
732 _ => "See language-specific documentation".to_string(),
733 }
734 }
735
736 fn update_statistics(&mut self, source_code: &str) {
738 let lines: Vec<&str> = source_code.lines().collect();
739 self.stats.total_lines = lines.len();
740
741 self.stats.code_lines = lines
742 .iter()
743 .filter(|line| {
744 !line.trim().is_empty()
745 && !line.trim().starts_with("//")
746 && !line.trim().starts_with('#')
747 })
748 .count();
749
750 self.stats.comment_lines = lines
751 .iter()
752 .filter(|line| line.trim().starts_with("//") || line.trim().starts_with('#'))
753 .count();
754
755 self.stats.function_count = source_code.matches("fn ").count()
757 + source_code.matches("def ").count()
758 + source_code.matches("function ").count();
759
760 self.stats.import_count = source_code.matches("use ").count()
762 + source_code.matches("import ").count()
763 + source_code.matches("#include").count();
764 }
765}
766
767impl TemplateEngine {
768 fn new() -> Self {
769 Self {
770 templates: HashMap::new(),
771 custom_overrides: HashMap::new(),
772 }
773 }
774}
775
776impl GenerationStatistics {
777 fn new() -> Self {
778 Self {
779 total_lines: 0,
780 code_lines: 0,
781 comment_lines: 0,
782 function_count: 0,
783 import_count: 0,
784 generation_time: std::time::Duration::from_secs(0),
785 }
786 }
787}
788
789impl Default for CodeGenerationConfig {
790 fn default() -> Self {
791 Self {
792 language: CodeLanguage::Rust,
793 style: CodeStyle::default(),
794 optimization_level: OptimizationLevel::Basic,
795 include_comments: true,
796 include_type_annotations: true,
797 deployment_target: DeploymentTarget::Local,
798 custom_templates: HashMap::new(),
799 }
800 }
801}
802
803impl Default for CodeStyle {
804 fn default() -> Self {
805 Self {
806 indent_size: 4,
807 use_tabs: false,
808 max_line_length: 100,
809 naming_convention: NamingConvention::SnakeCase,
810 include_error_handling: true,
811 }
812 }
813}
814
815#[derive(Debug, Clone, thiserror::Error)]
817pub enum CodeGenerationError {
818 #[error("Template compilation error: {0}")]
820 TemplateError(String),
821 #[error("Language backend error: {0}")]
823 LanguageError(String),
824 #[error("Configuration error: {0}")]
826 ConfigError(String),
827 #[error("Workflow validation error: {0}")]
829 WorkflowError(String),
830 #[error("IO error: {0}")]
832 IoError(String),
833 #[error("Syntax error in generated code: {0}")]
835 SyntaxError(String),
836 #[error("Dependency resolution error: {0}")]
838 DependencyError(String),
839 #[error("Sklears error: {0}")]
841 SklearsError(#[from] sklears_core::error::SklearsError),
842}
843
844#[derive(Debug, Clone, Serialize, Deserialize)]
846pub struct CodeTemplate {
847 pub name: String,
849 pub content: String,
851 pub language: CodeLanguage,
853 pub variables: Vec<String>,
855 pub metadata: TemplateMetadata,
857}
858
859#[derive(Debug, Clone, Serialize, Deserialize)]
861pub struct TemplateMetadata {
862 pub author: Option<String>,
864 pub version: String,
866 pub description: Option<String>,
868 pub tags: Vec<String>,
870}
871
872#[derive(Debug, Clone, Serialize, Deserialize)]
874pub struct LanguageBackend {
875 pub language: CodeLanguage,
877 pub generator: String,
879 pub template_engine: String,
881 pub features: Vec<String>,
883 pub config: BTreeMap<String, String>,
885}
886
887pub type TargetLanguage = CodeLanguage;
889
890#[derive(Debug, Clone, Serialize, Deserialize, Default)]
892pub struct TemplateContext {
893 pub variables: BTreeMap<String, TemplateValue>,
895 pub constants: BTreeMap<String, String>,
897 pub include_paths: Vec<String>,
899 pub flags: BTreeMap<String, bool>,
901}
902
903#[derive(Debug, Clone, Serialize, Deserialize)]
905pub enum TemplateValue {
906 String(String),
908 Number(f64),
910 Boolean(bool),
912 Array(Vec<TemplateValue>),
914 Object(BTreeMap<String, TemplateValue>),
916}
917
918#[derive(Debug, Clone, Serialize, Deserialize)]
920pub struct TemplateRegistry {
921 pub templates: BTreeMap<String, CodeTemplate>,
923 pub categories: BTreeMap<String, Vec<String>>,
925 pub metadata: RegistryMetadata,
927}
928
929#[derive(Debug, Clone, Serialize, Deserialize)]
931pub struct RegistryMetadata {
932 pub name: String,
934 pub version: String,
936 pub updated_at: String,
938 pub template_count: usize,
940}
941
942impl TemplateRegistry {
943 #[must_use]
945 pub fn new() -> Self {
946 Self {
947 templates: BTreeMap::new(),
948 categories: BTreeMap::new(),
949 metadata: RegistryMetadata {
950 name: "Default Registry".to_string(),
951 version: "1.0.0".to_string(),
952 updated_at: chrono::Utc::now().to_rfc3339(),
953 template_count: 0,
954 },
955 }
956 }
957
958 pub fn register_template(&mut self, template: CodeTemplate) {
960 self.templates.insert(template.name.clone(), template);
961 self.metadata.template_count = self.templates.len();
962 self.metadata.updated_at = chrono::Utc::now().to_rfc3339();
963 }
964
965 #[must_use]
967 pub fn get_template(&self, name: &str) -> Option<&CodeTemplate> {
968 self.templates.get(name)
969 }
970}
971
972impl Default for TemplateRegistry {
973 fn default() -> Self {
974 Self::new()
975 }
976}
977
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow_language::workflow_definitions::StepType;

    /// Workflow named `test_workflow` with one `StandardScaler` step whose id is `step1`.
    fn scaler_workflow() -> WorkflowDefinition {
        let mut workflow = WorkflowDefinition::default();
        workflow.metadata.name = "test_workflow".to_string();
        workflow.steps.push(StepDefinition::new(
            "step1",
            StepType::Transformer,
            "StandardScaler",
        ));
        workflow
    }

    /// Generator with the default configuration except for the target `language`.
    fn generator_for(language: CodeLanguage) -> CodeGenerator {
        CodeGenerator::new(CodeGenerationConfig {
            language,
            ..Default::default()
        })
    }

    #[test]
    fn test_code_generator_creation() {
        let generator = CodeGenerator::new(CodeGenerationConfig::default());
        assert_eq!(generator.stats.total_lines, 0);
    }

    #[test]
    fn test_rust_code_generation() {
        let generated = generator_for(CodeLanguage::Rust)
            .generate_code(&scaler_workflow())
            .expect("Rust generation should succeed");

        assert!(!generated.source_code.is_empty());
        assert!(matches!(generated.language, CodeLanguage::Rust));
        assert!(!generated.dependencies.is_empty());
        assert!(generated.dependencies.iter().any(|dep| dep == "sklears-core"));
        assert!(generated.source_code.contains("pub fn test_workflow()"));
        assert!(generated.source_code.contains("fn step1()"));
    }

    #[test]
    fn test_python_code_generation() {
        let generated = generator_for(CodeLanguage::Python)
            .generate_code(&scaler_workflow())
            .expect("Python generation should succeed");

        assert!(!generated.source_code.is_empty());
        assert!(matches!(generated.language, CodeLanguage::Python));
        assert!(generated.source_code.contains("def "));
        assert!(generated.source_code.contains("def step1():"));
        assert!(generated.source_code.contains("results['step1'] = step1()"));
    }

    #[test]
    fn test_json_code_generation() {
        let generated = generator_for(CodeLanguage::Json)
            .generate_code(&WorkflowDefinition::default())
            .expect("JSON generation should succeed");

        assert!(!generated.source_code.is_empty());
        assert!(generated.source_code.contains('{'));
        assert!(
            serde_json::from_str::<serde_json::Value>(&generated.source_code).is_ok(),
            "generated JSON should parse"
        );
    }

    #[test]
    fn test_naming_convention_conversion() {
        let generator = CodeGenerator::new(CodeGenerationConfig::default());

        assert_eq!(
            generator.convert_name("Test Workflow", &NamingConvention::SnakeCase),
            "test_workflow"
        );
        assert_eq!(
            generator.convert_name("test_workflow", &NamingConvention::CamelCase),
            "testWorkflow"
        );
        assert_eq!(
            generator.convert_name("test_workflow", &NamingConvention::PascalCase),
            "TestWorkflow"
        );
        assert_eq!(
            generator.convert_name("test_workflow", &NamingConvention::KebabCase),
            "test-workflow"
        );
    }
}