1use crate::{ComputationGraph, JitError, JitResult, NodeId};
10use std::collections::HashMap;
11use torsh_core::{DType, Shape};
13
14pub struct MetaprogrammingEngine {
16 templates: HashMap<String, CodeTemplate>,
17 macros: HashMap<String, MacroDefinition>,
18 reflector: RuntimeReflector,
19 code_generator: DynamicCodeGenerator,
20}
21
22impl MetaprogrammingEngine {
23 pub fn new() -> Self {
25 Self {
26 templates: HashMap::new(),
27 macros: HashMap::new(),
28 reflector: RuntimeReflector::new(),
29 code_generator: DynamicCodeGenerator::new(),
30 }
31 }
32
33 pub fn register_template(&mut self, name: String, template: CodeTemplate) {
35 self.templates.insert(name, template);
36 }
37
38 pub fn register_macro(&mut self, name: String, macro_def: MacroDefinition) {
40 self.macros.insert(name, macro_def);
41 }
42
43 pub fn generate_from_template(
45 &self,
46 template_name: &str,
47 parameters: &TemplateParameters,
48 ) -> JitResult<GeneratedCode> {
49 let template = self.templates.get(template_name).ok_or_else(|| {
50 JitError::CompilationError(format!("Template '{}' not found", template_name))
51 })?;
52
53 template.instantiate(parameters, &self.code_generator)
54 }
55
56 pub fn expand_macro(
58 &self,
59 macro_name: &str,
60 args: &[MacroArgument],
61 ) -> JitResult<GeneratedCode> {
62 let macro_def = self.macros.get(macro_name).ok_or_else(|| {
63 JitError::CompilationError(format!("Macro '{}' not found", macro_name))
64 })?;
65
66 macro_def.expand(args, &self.code_generator)
67 }
68
69 pub fn reflect_graph(&self, graph: &ComputationGraph) -> GraphReflection {
71 self.reflector.reflect_graph(graph)
72 }
73
74 pub fn generate_specialized_code(
76 &self,
77 base_template: &str,
78 specialization_info: &SpecializationInfo,
79 ) -> JitResult<GeneratedCode> {
80 let mut params = TemplateParameters::new();
81
82 for (key, value) in &specialization_info.type_info {
84 params.add_type(key.clone(), value.clone());
85 }
86
87 for (key, value) in &specialization_info.shape_info {
88 params.add_shape(key.clone(), value.clone());
89 }
90
91 for (key, value) in &specialization_info.constants {
92 params.add_constant(key.clone(), value.clone());
93 }
94
95 self.generate_from_template(base_template, ¶ms)
96 }
97}
98
99#[derive(Debug, Clone)]
101pub struct CodeTemplate {
102 pub name: String,
103 pub template_string: String,
104 pub parameters: Vec<TemplateParameter>,
105 pub constraints: Vec<TemplateConstraint>,
106}
107
108impl CodeTemplate {
109 pub fn new(name: String, template_string: String) -> Self {
111 Self {
112 name,
113 template_string,
114 parameters: Vec::new(),
115 constraints: Vec::new(),
116 }
117 }
118
119 pub fn add_parameter(&mut self, param: TemplateParameter) {
121 self.parameters.push(param);
122 }
123
124 pub fn add_constraint(&mut self, constraint: TemplateConstraint) {
126 self.constraints.push(constraint);
127 }
128
129 pub fn instantiate(
131 &self,
132 parameters: &TemplateParameters,
133 generator: &DynamicCodeGenerator,
134 ) -> JitResult<GeneratedCode> {
135 self.validate_constraints(parameters)?;
137
138 let mut code = self.template_string.clone();
140
141 for (name, dtype) in ¶meters.types {
143 let replacement = generator.format_type(dtype);
144 code = code.replace(&format!("${{{}}}", name), &replacement);
145 }
146
147 for (name, shape) in ¶meters.shapes {
149 let replacement = generator.format_shape(shape);
150 code = code.replace(&format!("$shape{{{}}}", name), &replacement);
151 }
152
153 for (name, value) in ¶meters.constants {
155 let replacement = generator.format_constant(value);
156 code = code.replace(&format!("$const{{{}}}", name), &replacement);
157 }
158
159 for (name, block) in ¶meters.code_blocks {
161 code = code.replace(&format!("$code{{{}}}", name), block);
162 }
163
164 Ok(GeneratedCode {
165 source: code,
166 metadata: CodeMetadata {
167 template_name: self.name.clone(),
168 parameters: parameters.clone(),
169 generated_at: std::time::SystemTime::now(),
170 },
171 })
172 }
173
174 fn validate_constraints(&self, parameters: &TemplateParameters) -> JitResult<()> {
176 for constraint in &self.constraints {
177 match constraint {
178 TemplateConstraint::TypeConstraint {
179 param_name,
180 allowed_types,
181 } => {
182 if let Some(dtype) = parameters.types.get(param_name) {
183 if !allowed_types.contains(dtype) {
184 return Err(JitError::CompilationError(format!(
185 "Type constraint violated for parameter '{}'",
186 param_name
187 )));
188 }
189 }
190 }
191 TemplateConstraint::ShapeConstraint {
192 param_name,
193 dimension_count,
194 } => {
195 if let Some(shape) = parameters.shapes.get(param_name) {
196 if shape.ndim() != *dimension_count {
197 return Err(JitError::CompilationError(format!(
198 "Shape constraint violated for parameter '{}'",
199 param_name
200 )));
201 }
202 }
203 }
204 TemplateConstraint::ValueConstraint {
205 param_name,
206 min_value,
207 max_value,
208 } => {
209 if let Some(value) = parameters.constants.get(param_name) {
210 if let ConstantValue::Integer(val) = value {
211 if *val < *min_value || *val > *max_value {
212 return Err(JitError::CompilationError(format!(
213 "Value constraint violated for parameter '{}'",
214 param_name
215 )));
216 }
217 }
218 }
219 }
220 }
221 }
222 Ok(())
223 }
224}
225
/// Declaration of a template parameter.
///
/// Holds the parameter's name, its kind, an optional textual default and a
/// human-readable description.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TemplateParameter {
    /// Placeholder name, e.g. `T` for `${T}`.
    pub name: String,
    /// Kind of value the parameter accepts.
    pub param_type: ParameterType,
    /// Textual default value, if any.
    pub default_value: Option<String>,
    /// Human-readable description.
    pub description: String,
}

/// Kind of value a template parameter accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParameterType {
    /// An element/data type.
    Type,
    /// A tensor shape.
    Shape,
    /// A constant value.
    Constant,
    /// A verbatim block of code.
    CodeBlock,
}
243
244#[derive(Debug, Clone)]
246pub enum TemplateConstraint {
247 TypeConstraint {
248 param_name: String,
249 allowed_types: Vec<DType>,
250 },
251 ShapeConstraint {
252 param_name: String,
253 dimension_count: usize,
254 },
255 ValueConstraint {
256 param_name: String,
257 min_value: i64,
258 max_value: i64,
259 },
260}
261
262#[derive(Debug, Clone)]
264pub struct TemplateParameters {
265 pub types: HashMap<String, DType>,
266 pub shapes: HashMap<String, Shape>,
267 pub constants: HashMap<String, ConstantValue>,
268 pub code_blocks: HashMap<String, String>,
269}
270
271impl TemplateParameters {
272 pub fn new() -> Self {
274 Self {
275 types: HashMap::new(),
276 shapes: HashMap::new(),
277 constants: HashMap::new(),
278 code_blocks: HashMap::new(),
279 }
280 }
281
282 pub fn add_type(&mut self, name: String, dtype: DType) {
284 self.types.insert(name, dtype);
285 }
286
287 pub fn add_shape(&mut self, name: String, shape: Shape) {
289 self.shapes.insert(name, shape);
290 }
291
292 pub fn add_constant(&mut self, name: String, value: ConstantValue) {
294 self.constants.insert(name, value);
295 }
296
297 pub fn add_code_block(&mut self, name: String, code: String) {
299 self.code_blocks.insert(name, code);
300 }
301}
302
/// A constant value that can be substituted into a template.
#[derive(Debug, Clone, PartialEq)]
pub enum ConstantValue {
    /// Signed integer constant.
    Integer(i64),
    /// Floating-point constant.
    Float(f64),
    /// Boolean constant.
    Boolean(bool),
    /// String constant; rendered as a quoted literal.
    String(String),
}
311
312#[derive(Debug, Clone)]
314pub struct MacroDefinition {
315 pub name: String,
316 pub parameters: Vec<String>,
317 pub body: String,
318 pub expansion_rules: Vec<ExpansionRule>,
319}
320
321impl MacroDefinition {
322 pub fn new(name: String, parameters: Vec<String>, body: String) -> Self {
324 Self {
325 name,
326 parameters,
327 body,
328 expansion_rules: Vec::new(),
329 }
330 }
331
332 pub fn add_rule(&mut self, rule: ExpansionRule) {
334 self.expansion_rules.push(rule);
335 }
336
337 pub fn expand(
339 &self,
340 args: &[MacroArgument],
341 generator: &DynamicCodeGenerator,
342 ) -> JitResult<GeneratedCode> {
343 if args.len() != self.parameters.len() {
344 return Err(JitError::CompilationError(format!(
345 "Macro '{}' expects {} arguments, got {}",
346 self.name,
347 self.parameters.len(),
348 args.len()
349 )));
350 }
351
352 let mut expanded = self.body.clone();
353
354 for (param, arg) in self.parameters.iter().zip(args.iter()) {
356 let replacement = match arg {
357 MacroArgument::Code(code) => code.clone(),
358 MacroArgument::Literal(lit) => lit.clone(),
359 MacroArgument::Expression(expr) => generator.format_expression(expr),
360 };
361 expanded = expanded.replace(&format!("${}", param), &replacement);
362 }
363
364 for rule in &self.expansion_rules {
366 expanded = rule.apply(&expanded);
367 }
368
369 Ok(GeneratedCode {
370 source: expanded,
371 metadata: CodeMetadata {
372 template_name: self.name.clone(),
373 parameters: TemplateParameters::new(), generated_at: std::time::SystemTime::now(),
375 },
376 })
377 }
378}
379
/// A positional argument passed to [`MacroDefinition::expand`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MacroArgument {
    /// Code inserted verbatim.
    Code(String),
    /// Literal text inserted verbatim.
    Literal(String),
    /// Expression text passed through the generator's `format_expression`.
    Expression(String),
}
387
/// A literal find-and-replace rule, applied to macro output after parameter
/// substitution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpansionRule {
    /// Text to search for (matched literally, not as a regex).
    pub pattern: String,
    /// Text substituted for every occurrence of `pattern`.
    pub replacement: String,
}

impl ExpansionRule {
    /// Replaces every non-overlapping occurrence of `pattern` in `code` with
    /// `replacement`.
    ///
    /// An empty `pattern` leaves `code` unchanged. `str::replace` would
    /// otherwise insert `replacement` between every character.
    pub fn apply(&self, code: &str) -> String {
        if self.pattern.is_empty() {
            return code.to_string();
        }
        code.replace(&self.pattern, &self.replacement)
    }
}
401
402pub struct RuntimeReflector {
404 type_registry: HashMap<String, TypeInfo>,
405 operation_registry: HashMap<String, OperationInfo>,
406}
407
408impl RuntimeReflector {
409 pub fn new() -> Self {
411 Self {
412 type_registry: HashMap::new(),
413 operation_registry: HashMap::new(),
414 }
415 }
416
417 pub fn register_type(&mut self, name: String, info: TypeInfo) {
419 self.type_registry.insert(name, info);
420 }
421
422 pub fn register_operation(&mut self, name: String, info: OperationInfo) {
424 self.operation_registry.insert(name, info);
425 }
426
427 pub fn reflect_graph(&self, graph: &ComputationGraph) -> GraphReflection {
429 let mut node_info = HashMap::new();
430 let mut edge_info = Vec::new();
431 let mut type_analysis = HashMap::new();
432
433 for (node_id, node) in graph.nodes() {
435 let input_types: Vec<DType> = graph
437 .get_node_inputs(node_id)
438 .iter()
439 .filter_map(|&input_id| graph.node(input_id).map(|n| n.dtype))
440 .collect();
441
442 let inferred_type = if !input_types.is_empty() {
444 input_types[0]
446 } else {
447 node.dtype
448 };
449
450 let reflection = NodeReflection {
451 id: node_id,
452 operation: node.operation_type().to_string(),
453 input_types,
454 output_type: node.dtype,
455 output_shape: node.output_shape.clone(),
456 metadata: self.get_operation_metadata(&node.operation_type()),
457 };
458 node_info.insert(node_id, reflection);
459
460 type_analysis.insert(
461 node_id,
462 TypeAnalysis {
463 declared_type: node.dtype,
464 inferred_type,
465 type_constraints: Vec::new(),
466 },
467 );
468 }
469
470 for (from, to, _edge_data) in graph.edges() {
472 let (data_type, tensor_shape) = graph
474 .node(from)
475 .map(|source_node| (source_node.dtype, source_node.output_shape.clone()))
476 .unwrap_or((DType::F32, Shape::new(vec![])));
477
478 edge_info.push(EdgeReflection {
479 from,
480 to,
481 data_type,
482 tensor_shape,
483 });
484 }
485
486 GraphReflection {
487 node_info,
488 edge_info,
489 type_analysis,
490 graph_properties: self.analyze_graph_properties(graph),
491 }
492 }
493
494 fn get_operation_metadata(&self, op_name: &str) -> Option<OperationInfo> {
496 self.operation_registry.get(op_name).cloned()
497 }
498
499 fn analyze_graph_properties(&self, graph: &ComputationGraph) -> GraphProperties {
501 let has_control_flow = graph.nodes().any(|(_, node)| {
503 matches!(
504 node.operation_type(),
505 "If" | "While" | "For" | "Loop" | "Branch" | "Cond" | "Switch"
506 )
507 });
508
509 GraphProperties {
510 node_count: graph.node_count(),
511 edge_count: graph.edge_count(),
512 is_acyclic: graph.is_acyclic(),
513 has_control_flow,
514 complexity_estimate: self.estimate_complexity(graph),
515 }
516 }
517
518 fn estimate_complexity(&self, graph: &ComputationGraph) -> ComplexityEstimate {
520 let mut total_ops = 0;
521 let mut memory_usage = 0;
522
523 for (_, node) in graph.nodes() {
524 let ops = match node.op.as_str() {
526 "add" | "sub" | "mul" | "div" => 1,
527 "matmul" => node.output_shape.size(0).unwrap_or(1).pow(3), "conv2d" => node.output_shape.size(0).unwrap_or(1) * 9, _ => 1,
530 };
531 total_ops += ops;
532
533 memory_usage += node.output_shape.size(0).unwrap_or(1) * node.dtype.size_bytes();
535 }
536
537 ComplexityEstimate {
538 operation_count: total_ops,
539 memory_usage_bytes: memory_usage,
540 estimated_flops: total_ops as f64,
541 }
542 }
543}
544
/// Descriptive record for a type, registered with `RuntimeReflector::register_type`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeInfo {
    pub name: String,
    pub size_bytes: usize,
    pub alignment: usize,
    pub is_numeric: bool,
    pub is_floating_point: bool,
}

/// Metadata for an operation.
///
/// Registered with `RuntimeReflector::register_operation` and attached to
/// nodes with a matching operation name during reflection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OperationInfo {
    pub name: String,
    pub input_count: usize,
    pub output_count: usize,
    pub is_commutative: bool,
    pub is_associative: bool,
    pub complexity_class: ComplexityClass,
}

/// Asymptotic cost class of an operation.
///
/// Variants are declared, and therefore ordered (`Ord`), from slowest- to
/// fastest-growing: `Constant < Linear < ... < Exponential`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ComplexityClass {
    Constant,
    Linear,
    Quadratic,
    Cubic,
    Exponential,
}
575
576pub struct DynamicCodeGenerator {
578 backend: CodegenBackend,
579}
580
581impl DynamicCodeGenerator {
582 pub fn new() -> Self {
584 Self {
585 backend: CodegenBackend::Rust,
586 }
587 }
588
589 pub fn set_backend(&mut self, backend: CodegenBackend) {
591 self.backend = backend;
592 }
593
594 pub fn format_type(&self, dtype: &DType) -> String {
596 match self.backend {
597 CodegenBackend::Rust => match dtype {
598 DType::F32 => "f32".to_string(),
599 DType::F64 => "f64".to_string(),
600 DType::I32 => "i32".to_string(),
601 DType::I64 => "i64".to_string(),
602 DType::Bool => "bool".to_string(),
603 _ => "f32".to_string(), },
605 CodegenBackend::C => match dtype {
606 DType::F32 => "float".to_string(),
607 DType::F64 => "double".to_string(),
608 DType::I32 => "int".to_string(),
609 DType::I64 => "long".to_string(),
610 DType::Bool => "bool".to_string(),
611 _ => "float".to_string(),
612 },
613 }
614 }
615
616 pub fn format_shape(&self, shape: &Shape) -> String {
618 let dims: Vec<String> = shape.dims().iter().map(|d| d.to_string()).collect();
619 format!("[{}]", dims.join(", "))
620 }
621
622 pub fn format_constant(&self, value: &ConstantValue) -> String {
624 match value {
625 ConstantValue::Integer(i) => i.to_string(),
626 ConstantValue::Float(f) => f.to_string(),
627 ConstantValue::Boolean(b) => b.to_string(),
628 ConstantValue::String(s) => format!("\"{}\"", s),
629 }
630 }
631
632 pub fn format_expression(&self, expr: &str) -> String {
634 expr.to_string()
637 }
638}
639
/// Target language for generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodegenBackend {
    /// Emit Rust source.
    Rust,
    /// Emit C source.
    C,
}
646
/// Source text produced by template instantiation or macro expansion,
/// together with provenance metadata.
#[derive(Debug, Clone)]
pub struct GeneratedCode {
    /// The generated source text.
    pub source: String,
    /// Provenance of `source`.
    pub metadata: CodeMetadata,
}

/// Provenance of a [`GeneratedCode`] value.
#[derive(Debug, Clone)]
pub struct CodeMetadata {
    /// Name of the template or macro that produced the code.
    pub template_name: String,
    /// Parameters used for a template instantiation. Macro expansions
    /// record an empty set.
    pub parameters: TemplateParameters,
    /// Wall-clock time at which the code was generated.
    pub generated_at: std::time::SystemTime,
}
661
662#[derive(Debug, Clone)]
664pub struct SpecializationInfo {
665 pub type_info: HashMap<String, DType>,
666 pub shape_info: HashMap<String, Shape>,
667 pub constants: HashMap<String, ConstantValue>,
668}
669
670#[derive(Debug)]
672pub struct GraphReflection {
673 pub node_info: HashMap<NodeId, NodeReflection>,
674 pub edge_info: Vec<EdgeReflection>,
675 pub type_analysis: HashMap<NodeId, TypeAnalysis>,
676 pub graph_properties: GraphProperties,
677}
678
679#[derive(Debug)]
681pub struct NodeReflection {
682 pub id: NodeId,
683 pub operation: String,
684 pub input_types: Vec<DType>,
685 pub output_type: DType,
686 pub output_shape: Shape,
687 pub metadata: Option<OperationInfo>,
688}
689
690#[derive(Debug)]
692pub struct EdgeReflection {
693 pub from: NodeId,
694 pub to: NodeId,
695 pub data_type: DType,
696 pub tensor_shape: Shape,
697}
698
699#[derive(Debug)]
701pub struct TypeAnalysis {
702 pub declared_type: DType,
703 pub inferred_type: DType,
704 pub type_constraints: Vec<String>,
705}
706
707#[derive(Debug)]
709pub struct GraphProperties {
710 pub node_count: usize,
711 pub edge_count: usize,
712 pub is_acyclic: bool,
713 pub has_control_flow: bool,
714 pub complexity_estimate: ComplexityEstimate,
715}
716
717#[derive(Debug)]
719pub struct ComplexityEstimate {
720 pub operation_count: usize,
721 pub memory_usage_bytes: usize,
722 pub estimated_flops: f64,
723}
724
725pub fn create_elementwise_template(op_name: &str) -> CodeTemplate {
727 let template_string = format!(
728 r#"
729fn {}(a: &Tensor<${{T}}>, b: &Tensor<${{T}}>) -> Tensor<${{T}}> {{
730 let shape = a.shape();
731 let mut result = Tensor::zeros(shape.clone());
732
733 for i in 0..shape.size(0).unwrap_or(1) {{
734 let a_val = a.data()[i];
735 let b_val = b.data()[i];
736 result.data_mut()[i] = a_val {} b_val;
737 }}
738
739 result
740}}
741"#,
742 op_name,
743 match op_name {
744 "add" => "+",
745 "sub" => "-",
746 "mul" => "*",
747 "div" => "/",
748 _ => "+",
749 }
750 );
751
752 let mut template = CodeTemplate::new(format!("{}_template", op_name), template_string);
753
754 template.add_parameter(TemplateParameter {
755 name: "T".to_string(),
756 param_type: ParameterType::Type,
757 default_value: Some("f32".to_string()),
758 description: "Element type".to_string(),
759 });
760
761 template.add_constraint(TemplateConstraint::TypeConstraint {
762 param_name: "T".to_string(),
763 allowed_types: vec![DType::F32, DType::F64, DType::I32, DType::I64],
764 });
765
766 template
767}
768
769pub fn create_conv2d_template() -> CodeTemplate {
771 let template_string = r#"
772fn conv2d(
773 input: &Tensor<${T}>,
774 weight: &Tensor<${T}>,
775 stride: $const{stride},
776 padding: $const{padding}
777) -> Tensor<${T}> {
778 let batch_size = input.shape()[0];
779 let in_channels = input.shape()[1];
780 let in_height = input.shape()[2];
781 let in_width = input.shape()[3];
782
783 let out_channels = weight.shape()[0];
784 let kernel_height = weight.shape()[2];
785 let kernel_width = weight.shape()[3];
786
787 let out_height = (in_height + 2 * padding - kernel_height) / stride + 1;
788 let out_width = (in_width + 2 * padding - kernel_width) / stride + 1;
789
790 let output_shape = vec![batch_size, out_channels, out_height, out_width];
791 let mut output = Tensor::zeros(output_shape);
792
793 $code{convolution_kernel}
794
795 output
796}
797"#
798 .to_string();
799
800 let mut template = CodeTemplate::new("conv2d_template".to_string(), template_string);
801
802 template.add_parameter(TemplateParameter {
803 name: "T".to_string(),
804 param_type: ParameterType::Type,
805 default_value: Some("f32".to_string()),
806 description: "Element type".to_string(),
807 });
808
809 template.add_parameter(TemplateParameter {
810 name: "stride".to_string(),
811 param_type: ParameterType::Constant,
812 default_value: Some("1".to_string()),
813 description: "Convolution stride".to_string(),
814 });
815
816 template.add_parameter(TemplateParameter {
817 name: "padding".to_string(),
818 param_type: ParameterType::Constant,
819 default_value: Some("0".to_string()),
820 description: "Convolution padding".to_string(),
821 });
822
823 template.add_parameter(TemplateParameter {
824 name: "convolution_kernel".to_string(),
825 param_type: ParameterType::CodeBlock,
826 default_value: None,
827 description: "Convolution kernel implementation".to_string(),
828 });
829
830 template
831}
832
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_template_creation() {
        let template = create_elementwise_template("add");
        assert_eq!(template.name, "add_template");
        assert_eq!(template.parameters.len(), 1);
        assert_eq!(template.constraints.len(), 1);
    }

    #[test]
    fn test_template_instantiation() {
        let template = create_elementwise_template("add");
        let generator = DynamicCodeGenerator::new();

        let mut params = TemplateParameters::new();
        params.add_type("T".to_string(), DType::F32);

        let result = template.instantiate(&params, &generator);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.source.contains("f32"));
        assert!(code.source.contains("fn add"));
    }

    #[test]
    fn test_metaprogramming_engine() {
        let mut engine = MetaprogrammingEngine::new();
        let template = create_elementwise_template("mul");
        engine.register_template("mul_op".to_string(), template);

        let mut params = TemplateParameters::new();
        params.add_type("T".to_string(), DType::F64);

        let result = engine.generate_from_template("mul_op", &params);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.source.contains("f64"));
        assert!(code.source.contains("fn mul"));
    }

    #[test]
    fn test_missing_template_is_an_error() {
        let engine = MetaprogrammingEngine::new();
        let params = TemplateParameters::new();
        assert!(engine.generate_from_template("does_not_exist", &params).is_err());
    }

    #[test]
    fn test_macro_definition() {
        let macro_def = MacroDefinition::new(
            "BINARY_OP".to_string(),
            vec!["op".to_string(), "T".to_string()],
            "fn $op(a: $T, b: $T) -> $T { a $op b }".to_string(),
        );

        let generator = DynamicCodeGenerator::new();
        let args = vec![
            MacroArgument::Literal("+".to_string()),
            MacroArgument::Literal("f32".to_string()),
        ];

        let result = macro_def.expand(&args, &generator);
        assert!(result.is_ok());

        let code = result.unwrap();
        assert!(code.source.contains("fn +(a: f32, b: f32) -> f32"));
    }

    #[test]
    fn test_macro_argument_count_mismatch() {
        let macro_def = MacroDefinition::new(
            "NEGATE".to_string(),
            vec!["x".to_string()],
            "-$x".to_string(),
        );
        let generator = DynamicCodeGenerator::new();
        assert!(macro_def.expand(&[], &generator).is_err());
    }

    #[test]
    fn test_macro_expansion_rules_run_after_substitution() {
        let mut macro_def = MacroDefinition::new(
            "WRAP".to_string(),
            vec!["x".to_string()],
            "[$x]".to_string(),
        );
        macro_def.add_rule(ExpansionRule {
            pattern: "foo".to_string(),
            replacement: "bar".to_string(),
        });
        let generator = DynamicCodeGenerator::new();

        let code = macro_def
            .expand(&[MacroArgument::Literal("foo".to_string())], &generator)
            .unwrap();
        assert_eq!(code.source, "[bar]");
    }

    #[test]
    fn test_constant_value_formatting() {
        let generator = DynamicCodeGenerator::new();

        assert_eq!(generator.format_constant(&ConstantValue::Integer(42)), "42");
        assert_eq!(
            generator.format_constant(&ConstantValue::Float(3.14)),
            "3.14"
        );
        assert_eq!(
            generator.format_constant(&ConstantValue::Boolean(true)),
            "true"
        );
        assert_eq!(
            generator.format_constant(&ConstantValue::String("hello".to_string())),
            "\"hello\""
        );
    }

    #[test]
    fn test_template_constraints() {
        let mut template = CodeTemplate::new("test_template".to_string(), "test ${T}".to_string());

        template.add_constraint(TemplateConstraint::TypeConstraint {
            param_name: "T".to_string(),
            allowed_types: vec![DType::F32],
        });

        let generator = DynamicCodeGenerator::new();

        let mut valid_params = TemplateParameters::new();
        valid_params.add_type("T".to_string(), DType::F32);

        let result = template.instantiate(&valid_params, &generator);
        assert!(result.is_ok());

        let mut invalid_params = TemplateParameters::new();
        invalid_params.add_type("T".to_string(), DType::I32);

        let result = template.instantiate(&invalid_params, &generator);
        assert!(result.is_err());
    }

    #[test]
    fn test_value_constraint() {
        let mut template =
            CodeTemplate::new("value_template".to_string(), "n = $const{n}".to_string());
        template.add_constraint(TemplateConstraint::ValueConstraint {
            param_name: "n".to_string(),
            min_value: 1,
            max_value: 4,
        });
        let generator = DynamicCodeGenerator::new();

        let mut in_range = TemplateParameters::new();
        in_range.add_constant("n".to_string(), ConstantValue::Integer(2));
        let code = template.instantiate(&in_range, &generator).unwrap();
        assert_eq!(code.source, "n = 2");

        let mut out_of_range = TemplateParameters::new();
        out_of_range.add_constant("n".to_string(), ConstantValue::Integer(8));
        assert!(template.instantiate(&out_of_range, &generator).is_err());
    }

    #[test]
    fn test_shape_and_code_block_substitution() {
        let mut template = CodeTemplate::new(
            "shape_template".to_string(),
            "$shape{S} -> $code{body}".to_string(),
        );
        template.add_constraint(TemplateConstraint::ShapeConstraint {
            param_name: "S".to_string(),
            dimension_count: 2,
        });
        let generator = DynamicCodeGenerator::new();

        let mut params = TemplateParameters::new();
        params.add_shape("S".to_string(), Shape::new(vec![2, 3]));
        params.add_code_block("body".to_string(), "run()".to_string());
        let code = template.instantiate(&params, &generator).unwrap();
        assert_eq!(code.source, "[2, 3] -> run()");

        let mut wrong_rank = TemplateParameters::new();
        wrong_rank.add_shape("S".to_string(), Shape::new(vec![6]));
        assert!(template.instantiate(&wrong_rank, &generator).is_err());
    }

    #[test]
    fn test_elementwise_template_rejects_bool() {
        let template = create_elementwise_template("sub");
        let generator = DynamicCodeGenerator::new();

        let mut params = TemplateParameters::new();
        params.add_type("T".to_string(), DType::Bool);
        assert!(template.instantiate(&params, &generator).is_err());
    }
}