Skip to main content

torsh_jit/
metaprogramming.rs

1//! Metaprogramming support for dynamic code generation and reflection
2//!
3//! This module provides comprehensive metaprogramming capabilities including:
4//! - Dynamic code generation from templates
5//! - Runtime reflection and introspection
6//! - Compile-time code transformation
7//! - Template-based code specialization
8
9use crate::{ComputationGraph, JitError, JitResult, NodeId};
10use std::collections::HashMap;
11// use std::fmt::Write; // Reserved for future template string manipulation
12use torsh_core::{DType, Shape};
13
14/// Metaprogramming engine for dynamic code generation
15pub struct MetaprogrammingEngine {
16    templates: HashMap<String, CodeTemplate>,
17    macros: HashMap<String, MacroDefinition>,
18    reflector: RuntimeReflector,
19    code_generator: DynamicCodeGenerator,
20}
21
22impl MetaprogrammingEngine {
23    /// Create a new metaprogramming engine
24    pub fn new() -> Self {
25        Self {
26            templates: HashMap::new(),
27            macros: HashMap::new(),
28            reflector: RuntimeReflector::new(),
29            code_generator: DynamicCodeGenerator::new(),
30        }
31    }
32
33    /// Register a code template
34    pub fn register_template(&mut self, name: String, template: CodeTemplate) {
35        self.templates.insert(name, template);
36    }
37
38    /// Register a macro definition
39    pub fn register_macro(&mut self, name: String, macro_def: MacroDefinition) {
40        self.macros.insert(name, macro_def);
41    }
42
43    /// Generate code from a template
44    pub fn generate_from_template(
45        &self,
46        template_name: &str,
47        parameters: &TemplateParameters,
48    ) -> JitResult<GeneratedCode> {
49        let template = self.templates.get(template_name).ok_or_else(|| {
50            JitError::CompilationError(format!("Template '{}' not found", template_name))
51        })?;
52
53        template.instantiate(parameters, &self.code_generator)
54    }
55
56    /// Expand a macro with given arguments
57    pub fn expand_macro(
58        &self,
59        macro_name: &str,
60        args: &[MacroArgument],
61    ) -> JitResult<GeneratedCode> {
62        let macro_def = self.macros.get(macro_name).ok_or_else(|| {
63            JitError::CompilationError(format!("Macro '{}' not found", macro_name))
64        })?;
65
66        macro_def.expand(args, &self.code_generator)
67    }
68
69    /// Reflect on a computation graph
70    pub fn reflect_graph(&self, graph: &ComputationGraph) -> GraphReflection {
71        self.reflector.reflect_graph(graph)
72    }
73
74    /// Generate specialized code based on runtime information
75    pub fn generate_specialized_code(
76        &self,
77        base_template: &str,
78        specialization_info: &SpecializationInfo,
79    ) -> JitResult<GeneratedCode> {
80        let mut params = TemplateParameters::new();
81
82        // Add specialization parameters
83        for (key, value) in &specialization_info.type_info {
84            params.add_type(key.clone(), value.clone());
85        }
86
87        for (key, value) in &specialization_info.shape_info {
88            params.add_shape(key.clone(), value.clone());
89        }
90
91        for (key, value) in &specialization_info.constants {
92            params.add_constant(key.clone(), value.clone());
93        }
94
95        self.generate_from_template(base_template, &params)
96    }
97}
98
99/// Code template for dynamic code generation
100#[derive(Debug, Clone)]
101pub struct CodeTemplate {
102    pub name: String,
103    pub template_string: String,
104    pub parameters: Vec<TemplateParameter>,
105    pub constraints: Vec<TemplateConstraint>,
106}
107
108impl CodeTemplate {
109    /// Create a new code template
110    pub fn new(name: String, template_string: String) -> Self {
111        Self {
112            name,
113            template_string,
114            parameters: Vec::new(),
115            constraints: Vec::new(),
116        }
117    }
118
119    /// Add a template parameter
120    pub fn add_parameter(&mut self, param: TemplateParameter) {
121        self.parameters.push(param);
122    }
123
124    /// Add a template constraint
125    pub fn add_constraint(&mut self, constraint: TemplateConstraint) {
126        self.constraints.push(constraint);
127    }
128
129    /// Instantiate the template with given parameters
130    pub fn instantiate(
131        &self,
132        parameters: &TemplateParameters,
133        generator: &DynamicCodeGenerator,
134    ) -> JitResult<GeneratedCode> {
135        // Validate constraints
136        self.validate_constraints(parameters)?;
137
138        // Perform template substitution
139        let mut code = self.template_string.clone();
140
141        // Replace type parameters
142        for (name, dtype) in &parameters.types {
143            let replacement = generator.format_type(dtype);
144            code = code.replace(&format!("${{{}}}", name), &replacement);
145        }
146
147        // Replace shape parameters
148        for (name, shape) in &parameters.shapes {
149            let replacement = generator.format_shape(shape);
150            code = code.replace(&format!("$shape{{{}}}", name), &replacement);
151        }
152
153        // Replace constant parameters
154        for (name, value) in &parameters.constants {
155            let replacement = generator.format_constant(value);
156            code = code.replace(&format!("$const{{{}}}", name), &replacement);
157        }
158
159        // Replace code block parameters
160        for (name, block) in &parameters.code_blocks {
161            code = code.replace(&format!("$code{{{}}}", name), block);
162        }
163
164        Ok(GeneratedCode {
165            source: code,
166            metadata: CodeMetadata {
167                template_name: self.name.clone(),
168                parameters: parameters.clone(),
169                generated_at: std::time::SystemTime::now(),
170            },
171        })
172    }
173
174    /// Validate template constraints
175    fn validate_constraints(&self, parameters: &TemplateParameters) -> JitResult<()> {
176        for constraint in &self.constraints {
177            match constraint {
178                TemplateConstraint::TypeConstraint {
179                    param_name,
180                    allowed_types,
181                } => {
182                    if let Some(dtype) = parameters.types.get(param_name) {
183                        if !allowed_types.contains(dtype) {
184                            return Err(JitError::CompilationError(format!(
185                                "Type constraint violated for parameter '{}'",
186                                param_name
187                            )));
188                        }
189                    }
190                }
191                TemplateConstraint::ShapeConstraint {
192                    param_name,
193                    dimension_count,
194                } => {
195                    if let Some(shape) = parameters.shapes.get(param_name) {
196                        if shape.ndim() != *dimension_count {
197                            return Err(JitError::CompilationError(format!(
198                                "Shape constraint violated for parameter '{}'",
199                                param_name
200                            )));
201                        }
202                    }
203                }
204                TemplateConstraint::ValueConstraint {
205                    param_name,
206                    min_value,
207                    max_value,
208                } => {
209                    if let Some(value) = parameters.constants.get(param_name) {
210                        if let ConstantValue::Integer(val) = value {
211                            if *val < *min_value || *val > *max_value {
212                                return Err(JitError::CompilationError(format!(
213                                    "Value constraint violated for parameter '{}'",
214                                    param_name
215                                )));
216                            }
217                        }
218                    }
219                }
220            }
221        }
222        Ok(())
223    }
224}
225
226/// Template parameter definition
227#[derive(Debug, Clone)]
228pub struct TemplateParameter {
229    pub name: String,
230    pub param_type: ParameterType,
231    pub default_value: Option<String>,
232    pub description: String,
233}
234
/// Kinds of template parameters.
///
/// The variants are plain tags, so the type is `Copy`. It can be compared with `==` and
/// used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParameterType {
    /// An element type (e.g. `f32`).
    Type,
    /// A tensor shape.
    Shape,
    /// A scalar constant.
    Constant,
    /// A raw code fragment.
    CodeBlock,
}
243
244/// Template constraints for validation
245#[derive(Debug, Clone)]
246pub enum TemplateConstraint {
247    TypeConstraint {
248        param_name: String,
249        allowed_types: Vec<DType>,
250    },
251    ShapeConstraint {
252        param_name: String,
253        dimension_count: usize,
254    },
255    ValueConstraint {
256        param_name: String,
257        min_value: i64,
258        max_value: i64,
259    },
260}
261
262/// Parameters for template instantiation
263#[derive(Debug, Clone)]
264pub struct TemplateParameters {
265    pub types: HashMap<String, DType>,
266    pub shapes: HashMap<String, Shape>,
267    pub constants: HashMap<String, ConstantValue>,
268    pub code_blocks: HashMap<String, String>,
269}
270
271impl TemplateParameters {
272    /// Create new empty parameters
273    pub fn new() -> Self {
274        Self {
275            types: HashMap::new(),
276            shapes: HashMap::new(),
277            constants: HashMap::new(),
278            code_blocks: HashMap::new(),
279        }
280    }
281
282    /// Add a type parameter
283    pub fn add_type(&mut self, name: String, dtype: DType) {
284        self.types.insert(name, dtype);
285    }
286
287    /// Add a shape parameter
288    pub fn add_shape(&mut self, name: String, shape: Shape) {
289        self.shapes.insert(name, shape);
290    }
291
292    /// Add a constant parameter
293    pub fn add_constant(&mut self, name: String, value: ConstantValue) {
294        self.constants.insert(name, value);
295    }
296
297    /// Add a code block parameter
298    pub fn add_code_block(&mut self, name: String, code: String) {
299        self.code_blocks.insert(name, code);
300    }
301}
302
/// Scalar constant values that can be substituted into templates.
///
/// Derives `PartialEq` so supplied constants can be compared. It is not `Eq`, because
/// `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum ConstantValue {
    /// A signed integer constant.
    Integer(i64),
    /// A floating-point constant.
    Float(f64),
    /// A boolean constant.
    Boolean(bool),
    /// A string constant.
    String(String),
}
311
312/// Macro definition for code expansion
313#[derive(Debug, Clone)]
314pub struct MacroDefinition {
315    pub name: String,
316    pub parameters: Vec<String>,
317    pub body: String,
318    pub expansion_rules: Vec<ExpansionRule>,
319}
320
321impl MacroDefinition {
322    /// Create a new macro definition
323    pub fn new(name: String, parameters: Vec<String>, body: String) -> Self {
324        Self {
325            name,
326            parameters,
327            body,
328            expansion_rules: Vec::new(),
329        }
330    }
331
332    /// Add an expansion rule
333    pub fn add_rule(&mut self, rule: ExpansionRule) {
334        self.expansion_rules.push(rule);
335    }
336
337    /// Expand the macro with given arguments
338    pub fn expand(
339        &self,
340        args: &[MacroArgument],
341        generator: &DynamicCodeGenerator,
342    ) -> JitResult<GeneratedCode> {
343        if args.len() != self.parameters.len() {
344            return Err(JitError::CompilationError(format!(
345                "Macro '{}' expects {} arguments, got {}",
346                self.name,
347                self.parameters.len(),
348                args.len()
349            )));
350        }
351
352        let mut expanded = self.body.clone();
353
354        // Replace parameters with arguments
355        for (param, arg) in self.parameters.iter().zip(args.iter()) {
356            let replacement = match arg {
357                MacroArgument::Code(code) => code.clone(),
358                MacroArgument::Literal(lit) => lit.clone(),
359                MacroArgument::Expression(expr) => generator.format_expression(expr),
360            };
361            expanded = expanded.replace(&format!("${}", param), &replacement);
362        }
363
364        // Apply expansion rules
365        for rule in &self.expansion_rules {
366            expanded = rule.apply(&expanded);
367        }
368
369        Ok(GeneratedCode {
370            source: expanded,
371            metadata: CodeMetadata {
372                template_name: self.name.clone(),
373                parameters: TemplateParameters::new(), // Macros don't use template parameters
374                generated_at: std::time::SystemTime::now(),
375            },
376        })
377    }
378}
379
/// Arguments passed to [`MacroDefinition::expand`].
///
/// Derives `PartialEq`, `Eq` and `Hash` so argument lists can be compared or deduplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MacroArgument {
    /// Source code, inserted verbatim.
    Code(String),
    /// A literal token, inserted verbatim.
    Literal(String),
    /// An expression, rendered through the code generator before insertion.
    Expression(String),
}
387
/// A plain-text rewrite rule applied to macro output after parameter substitution.
#[derive(Clone, Debug)]
pub struct ExpansionRule {
    /// Literal (not regex) text to search for.
    pub pattern: String,
    /// Text inserted in place of every occurrence of `pattern`.
    pub replacement: String,
}

impl ExpansionRule {
    /// Return a copy of `code` with every non-overlapping occurrence of `pattern`
    /// replaced by `replacement`.
    pub fn apply(&self, code: &str) -> String {
        let ExpansionRule {
            pattern,
            replacement,
        } = self;
        code.replace(pattern.as_str(), replacement.as_str())
    }
}
401
402/// Runtime reflection capabilities
403pub struct RuntimeReflector {
404    type_registry: HashMap<String, TypeInfo>,
405    operation_registry: HashMap<String, OperationInfo>,
406}
407
408impl RuntimeReflector {
409    /// Create a new runtime reflector
410    pub fn new() -> Self {
411        Self {
412            type_registry: HashMap::new(),
413            operation_registry: HashMap::new(),
414        }
415    }
416
417    /// Register type information
418    pub fn register_type(&mut self, name: String, info: TypeInfo) {
419        self.type_registry.insert(name, info);
420    }
421
422    /// Register operation information
423    pub fn register_operation(&mut self, name: String, info: OperationInfo) {
424        self.operation_registry.insert(name, info);
425    }
426
427    /// Reflect on a computation graph
428    pub fn reflect_graph(&self, graph: &ComputationGraph) -> GraphReflection {
429        let mut node_info = HashMap::new();
430        let mut edge_info = Vec::new();
431        let mut type_analysis = HashMap::new();
432
433        // Analyze nodes
434        for (node_id, node) in graph.nodes() {
435            // Derive input types from incoming edges
436            let input_types: Vec<DType> = graph
437                .get_node_inputs(node_id)
438                .iter()
439                .filter_map(|&input_id| graph.node(input_id).map(|n| n.dtype))
440                .collect();
441
442            // Infer type before moving input_types
443            let inferred_type = if !input_types.is_empty() {
444                // Use the most precise type among inputs
445                input_types[0]
446            } else {
447                node.dtype
448            };
449
450            let reflection = NodeReflection {
451                id: node_id,
452                operation: node.operation_type().to_string(),
453                input_types,
454                output_type: node.dtype,
455                output_shape: node.output_shape.clone(),
456                metadata: self.get_operation_metadata(&node.operation_type()),
457            };
458            node_info.insert(node_id, reflection);
459
460            type_analysis.insert(
461                node_id,
462                TypeAnalysis {
463                    declared_type: node.dtype,
464                    inferred_type,
465                    type_constraints: Vec::new(),
466                },
467            );
468        }
469
470        // Analyze edges
471        for (from, to, _edge_data) in graph.edges() {
472            // Get actual edge type and shape from source node
473            let (data_type, tensor_shape) = graph
474                .node(from)
475                .map(|source_node| (source_node.dtype, source_node.output_shape.clone()))
476                .unwrap_or((DType::F32, Shape::new(vec![])));
477
478            edge_info.push(EdgeReflection {
479                from,
480                to,
481                data_type,
482                tensor_shape,
483            });
484        }
485
486        GraphReflection {
487            node_info,
488            edge_info,
489            type_analysis,
490            graph_properties: self.analyze_graph_properties(graph),
491        }
492    }
493
494    /// Get operation metadata
495    fn get_operation_metadata(&self, op_name: &str) -> Option<OperationInfo> {
496        self.operation_registry.get(op_name).cloned()
497    }
498
499    /// Analyze graph properties
500    fn analyze_graph_properties(&self, graph: &ComputationGraph) -> GraphProperties {
501        // Detect control flow by checking for control flow operations
502        let has_control_flow = graph.nodes().any(|(_, node)| {
503            matches!(
504                node.operation_type(),
505                "If" | "While" | "For" | "Loop" | "Branch" | "Cond" | "Switch"
506            )
507        });
508
509        GraphProperties {
510            node_count: graph.node_count(),
511            edge_count: graph.edge_count(),
512            is_acyclic: graph.is_acyclic(),
513            has_control_flow,
514            complexity_estimate: self.estimate_complexity(graph),
515        }
516    }
517
518    /// Estimate computational complexity
519    fn estimate_complexity(&self, graph: &ComputationGraph) -> ComplexityEstimate {
520        let mut total_ops = 0;
521        let mut memory_usage = 0;
522
523        for (_, node) in graph.nodes() {
524            // Estimate operations based on node type
525            let ops = match node.op.as_str() {
526                "add" | "sub" | "mul" | "div" => 1,
527                "matmul" => node.output_shape.size(0).unwrap_or(1).pow(3), // O(n^3) for matrix multiplication
528                "conv2d" => node.output_shape.size(0).unwrap_or(1) * 9,    // Rough estimate
529                _ => 1,
530            };
531            total_ops += ops;
532
533            // Estimate memory usage
534            memory_usage += node.output_shape.size(0).unwrap_or(1) * node.dtype.size_bytes();
535        }
536
537        ComplexityEstimate {
538            operation_count: total_ops,
539            memory_usage_bytes: memory_usage,
540            estimated_flops: total_ops as f64,
541        }
542    }
543}
544
/// Descriptive information about a named type, registered with a [`RuntimeReflector`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeInfo {
    /// Type name.
    pub name: String,
    /// Size of one value in bytes.
    pub size_bytes: usize,
    /// Required alignment in bytes.
    pub alignment: usize,
    /// Whether the type is numeric.
    pub is_numeric: bool,
    /// Whether the type is a floating-point type.
    pub is_floating_point: bool,
}

/// Descriptive information about a named operation, registered with a
/// [`RuntimeReflector`] and attached to matching nodes during reflection.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OperationInfo {
    /// Operation name.
    pub name: String,
    /// Number of inputs the operation takes.
    pub input_count: usize,
    /// Number of outputs the operation produces.
    pub output_count: usize,
    /// Whether operand order does not matter.
    pub is_commutative: bool,
    /// Whether grouping of repeated applications does not matter.
    pub is_associative: bool,
    /// Asymptotic complexity class.
    pub complexity_class: ComplexityClass,
}

/// Complexity classifications.
///
/// Variants are declared in increasing order of growth, so the derived `Ord` ranks
/// cheaper classes first (`Constant < Linear < … < Exponential`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ComplexityClass {
    Constant,
    Linear,
    Quadratic,
    Cubic,
    Exponential,
}
575
576/// Dynamic code generator
577pub struct DynamicCodeGenerator {
578    backend: CodegenBackend,
579}
580
581impl DynamicCodeGenerator {
582    /// Create a new dynamic code generator
583    pub fn new() -> Self {
584        Self {
585            backend: CodegenBackend::Rust,
586        }
587    }
588
589    /// Set the code generation backend
590    pub fn set_backend(&mut self, backend: CodegenBackend) {
591        self.backend = backend;
592    }
593
594    /// Format a type for the current backend
595    pub fn format_type(&self, dtype: &DType) -> String {
596        match self.backend {
597            CodegenBackend::Rust => match dtype {
598                DType::F32 => "f32".to_string(),
599                DType::F64 => "f64".to_string(),
600                DType::I32 => "i32".to_string(),
601                DType::I64 => "i64".to_string(),
602                DType::Bool => "bool".to_string(),
603                _ => "f32".to_string(), // Default fallback
604            },
605            CodegenBackend::C => match dtype {
606                DType::F32 => "float".to_string(),
607                DType::F64 => "double".to_string(),
608                DType::I32 => "int".to_string(),
609                DType::I64 => "long".to_string(),
610                DType::Bool => "bool".to_string(),
611                _ => "float".to_string(),
612            },
613        }
614    }
615
616    /// Format a shape for the current backend
617    pub fn format_shape(&self, shape: &Shape) -> String {
618        let dims: Vec<String> = shape.dims().iter().map(|d| d.to_string()).collect();
619        format!("[{}]", dims.join(", "))
620    }
621
622    /// Format a constant for the current backend
623    pub fn format_constant(&self, value: &ConstantValue) -> String {
624        match value {
625            ConstantValue::Integer(i) => i.to_string(),
626            ConstantValue::Float(f) => f.to_string(),
627            ConstantValue::Boolean(b) => b.to_string(),
628            ConstantValue::String(s) => format!("\"{}\"", s),
629        }
630    }
631
632    /// Format an expression for the current backend
633    pub fn format_expression(&self, expr: &str) -> String {
634        // For now, just return the expression as-is
635        // In a real implementation, this would parse and transform the expression
636        expr.to_string()
637    }
638}
639
/// Target languages supported by [`DynamicCodeGenerator`].
///
/// A fieldless tag enum, so it is `Copy` and can be compared or used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodegenBackend {
    /// Emit Rust source.
    Rust,
    /// Emit C source.
    C,
}
646
647/// Generated code with metadata
648#[derive(Debug, Clone)]
649pub struct GeneratedCode {
650    pub source: String,
651    pub metadata: CodeMetadata,
652}
653
654/// Metadata for generated code
655#[derive(Debug, Clone)]
656pub struct CodeMetadata {
657    pub template_name: String,
658    pub parameters: TemplateParameters,
659    pub generated_at: std::time::SystemTime,
660}
661
662/// Specialization information for code generation
663#[derive(Debug, Clone)]
664pub struct SpecializationInfo {
665    pub type_info: HashMap<String, DType>,
666    pub shape_info: HashMap<String, Shape>,
667    pub constants: HashMap<String, ConstantValue>,
668}
669
670/// Graph reflection information
671#[derive(Debug)]
672pub struct GraphReflection {
673    pub node_info: HashMap<NodeId, NodeReflection>,
674    pub edge_info: Vec<EdgeReflection>,
675    pub type_analysis: HashMap<NodeId, TypeAnalysis>,
676    pub graph_properties: GraphProperties,
677}
678
679/// Reflection information for a single node
680#[derive(Debug)]
681pub struct NodeReflection {
682    pub id: NodeId,
683    pub operation: String,
684    pub input_types: Vec<DType>,
685    pub output_type: DType,
686    pub output_shape: Shape,
687    pub metadata: Option<OperationInfo>,
688}
689
690/// Reflection information for an edge
691#[derive(Debug)]
692pub struct EdgeReflection {
693    pub from: NodeId,
694    pub to: NodeId,
695    pub data_type: DType,
696    pub tensor_shape: Shape,
697}
698
699/// Type analysis information
700#[derive(Debug)]
701pub struct TypeAnalysis {
702    pub declared_type: DType,
703    pub inferred_type: DType,
704    pub type_constraints: Vec<String>,
705}
706
707/// Graph-level properties
708#[derive(Debug)]
709pub struct GraphProperties {
710    pub node_count: usize,
711    pub edge_count: usize,
712    pub is_acyclic: bool,
713    pub has_control_flow: bool,
714    pub complexity_estimate: ComplexityEstimate,
715}
716
717/// Computational complexity estimate
718#[derive(Debug)]
719pub struct ComplexityEstimate {
720    pub operation_count: usize,
721    pub memory_usage_bytes: usize,
722    pub estimated_flops: f64,
723}
724
725/// Create a standard element-wise operation template
726pub fn create_elementwise_template(op_name: &str) -> CodeTemplate {
727    let template_string = format!(
728        r#"
729fn {}(a: &Tensor<${{T}}>, b: &Tensor<${{T}}>) -> Tensor<${{T}}> {{
730    let shape = a.shape();
731    let mut result = Tensor::zeros(shape.clone());
732    
733    for i in 0..shape.size(0).unwrap_or(1) {{
734        let a_val = a.data()[i];
735        let b_val = b.data()[i];
736        result.data_mut()[i] = a_val {} b_val;
737    }}
738    
739    result
740}}
741"#,
742        op_name,
743        match op_name {
744            "add" => "+",
745            "sub" => "-",
746            "mul" => "*",
747            "div" => "/",
748            _ => "+",
749        }
750    );
751
752    let mut template = CodeTemplate::new(format!("{}_template", op_name), template_string);
753
754    template.add_parameter(TemplateParameter {
755        name: "T".to_string(),
756        param_type: ParameterType::Type,
757        default_value: Some("f32".to_string()),
758        description: "Element type".to_string(),
759    });
760
761    template.add_constraint(TemplateConstraint::TypeConstraint {
762        param_name: "T".to_string(),
763        allowed_types: vec![DType::F32, DType::F64, DType::I32, DType::I64],
764    });
765
766    template
767}
768
769/// Create a convolution operation template
770pub fn create_conv2d_template() -> CodeTemplate {
771    let template_string = r#"
772fn conv2d(
773    input: &Tensor<${T}>, 
774    weight: &Tensor<${T}>,
775    stride: $const{stride},
776    padding: $const{padding}
777) -> Tensor<${T}> {
778    let batch_size = input.shape()[0];
779    let in_channels = input.shape()[1];
780    let in_height = input.shape()[2];
781    let in_width = input.shape()[3];
782    
783    let out_channels = weight.shape()[0];
784    let kernel_height = weight.shape()[2];
785    let kernel_width = weight.shape()[3];
786    
787    let out_height = (in_height + 2 * padding - kernel_height) / stride + 1;
788    let out_width = (in_width + 2 * padding - kernel_width) / stride + 1;
789    
790    let output_shape = vec![batch_size, out_channels, out_height, out_width];
791    let mut output = Tensor::zeros(output_shape);
792    
793    $code{convolution_kernel}
794    
795    output
796}
797"#
798    .to_string();
799
800    let mut template = CodeTemplate::new("conv2d_template".to_string(), template_string);
801
802    template.add_parameter(TemplateParameter {
803        name: "T".to_string(),
804        param_type: ParameterType::Type,
805        default_value: Some("f32".to_string()),
806        description: "Element type".to_string(),
807    });
808
809    template.add_parameter(TemplateParameter {
810        name: "stride".to_string(),
811        param_type: ParameterType::Constant,
812        default_value: Some("1".to_string()),
813        description: "Convolution stride".to_string(),
814    });
815
816    template.add_parameter(TemplateParameter {
817        name: "padding".to_string(),
818        param_type: ParameterType::Constant,
819        default_value: Some("0".to_string()),
820        description: "Convolution padding".to_string(),
821    });
822
823    template.add_parameter(TemplateParameter {
824        name: "convolution_kernel".to_string(),
825        param_type: ParameterType::CodeBlock,
826        default_value: None,
827        description: "Convolution kernel implementation".to_string(),
828    });
829
830    template
831}
832
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_template_creation() {
        let template = create_elementwise_template("add");
        assert_eq!(template.name, "add_template");
        assert_eq!(template.parameters.len(), 1);
        assert_eq!(template.constraints.len(), 1);
    }

    #[test]
    fn test_template_instantiation() {
        let template = create_elementwise_template("add");
        let generator = DynamicCodeGenerator::new();

        let mut params = TemplateParameters::new();
        params.add_type("T".to_string(), DType::F32);

        let result = template.instantiate(&params, &generator);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.source.contains("f32"));
        assert!(code.source.contains("fn add"));
    }

    #[test]
    fn test_metaprogramming_engine() {
        let mut engine = MetaprogrammingEngine::new();
        let template = create_elementwise_template("mul");
        engine.register_template("mul_op".to_string(), template);

        let mut params = TemplateParameters::new();
        params.add_type("T".to_string(), DType::F64);

        let result = engine.generate_from_template("mul_op", &params);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.source.contains("f64"));
        assert!(code.source.contains("fn mul"));
    }

    #[test]
    fn test_macro_definition() {
        let macro_def = MacroDefinition::new(
            "BINARY_OP".to_string(),
            vec!["op".to_string(), "T".to_string()],
            "fn $op(a: $T, b: $T) -> $T { a $op b }".to_string(),
        );

        let generator = DynamicCodeGenerator::new();
        // A plain array suffices here: `expand` takes a slice.
        let args = [
            MacroArgument::Literal("+".to_string()),
            MacroArgument::Literal("f32".to_string()),
        ];

        let result = macro_def.expand(&args, &generator);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.source.contains("fn +(a: f32, b: f32) -> f32"));
    }

    #[test]
    fn test_constant_value_formatting() {
        let generator = DynamicCodeGenerator::new();

        assert_eq!(generator.format_constant(&ConstantValue::Integer(42)), "42");
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` lint rejects
        // literals that approximate `std::f64::consts::PI`.
        assert_eq!(generator.format_constant(&ConstantValue::Float(2.5)), "2.5");
        assert_eq!(
            generator.format_constant(&ConstantValue::Boolean(true)),
            "true"
        );
        assert_eq!(
            generator.format_constant(&ConstantValue::String("hello".to_string())),
            "\"hello\""
        );
    }

    #[test]
    fn test_template_constraints() {
        let mut template = CodeTemplate::new("test_template".to_string(), "test ${T}".to_string());

        template.add_constraint(TemplateConstraint::TypeConstraint {
            param_name: "T".to_string(),
            allowed_types: vec![DType::F32],
        });

        let generator = DynamicCodeGenerator::new();

        // Valid parameters
        let mut valid_params = TemplateParameters::new();
        valid_params.add_type("T".to_string(), DType::F32);

        let result = template.instantiate(&valid_params, &generator);
        assert!(result.is_ok());

        // Invalid parameters
        let mut invalid_params = TemplateParameters::new();
        invalid_params.add_type("T".to_string(), DType::I32);

        let result = template.instantiate(&invalid_params, &generator);
        assert!(result.is_err());
    }
}