1use crate::ir::{IrModule, TypeKind};
7use crate::{JitError, JitResult};
8use indexmap::IndexMap;
9use std::collections::HashMap;
10
/// Registry of generic function templates together with a cache of the
/// instances built from them, keyed by [`InstantiationKey`].
#[derive(Debug, Clone)]
pub struct GenericFunctionManager {
    /// Registered templates, keyed by template name (insertion-ordered).
    templates: IndexMap<String, GenericFunctionTemplate>,

    /// Instances built from templates, keyed by template name plus arguments.
    instances: IndexMap<InstantiationKey, InstantiatedFunction>,

    /// Running counters returned by `stats()`.
    stats: GenericFunctionStats,

    /// Configuration supplied at construction time.
    config: GenericFunctionConfig,
}
26
/// A generic function definition that can be instantiated for concrete type,
/// shape and constant arguments.
#[derive(Debug, Clone)]
pub struct GenericFunctionTemplate {
    /// Template name: the registry key and the prefix of instantiated module names.
    pub name: String,

    /// Declared parameters; type arguments are matched to these positionally and
    /// parameter names must be unique within the template.
    pub type_params: Vec<TypeParameter>,

    /// Requirements the arguments must satisfy; they may only reference names
    /// declared in `type_params`.
    pub constraints: Vec<TypeConstraint>,

    /// IR body that is cloned and specialised for each instantiation.
    pub template_ir: IrModule,

    /// Presumably hand-written specialisations keyed by type arguments.
    /// NOTE(review): not consulted by the instantiation path in this file — confirm intended use.
    pub default_impls: HashMap<Vec<TypeKind>, IrModule>,

    /// Descriptive information about the template.
    pub metadata: TemplateMetadata,
}
48
/// A single named parameter of a generic template.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypeParameter {
    /// Name referenced by constraints; must be unique within a template.
    pub name: String,

    /// The category of argument this parameter accepts.
    pub kind: ParameterKind,

    /// Optional default type.
    /// NOTE(review): the argument-count check in this file requires every type
    /// argument to be passed explicitly, so this default does not appear to be applied.
    pub default: Option<TypeKind>,

    /// Declared variance of the parameter.
    pub variance: Variance,
}
64
/// The category of argument a [`TypeParameter`] stands for.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ParameterKind {
    /// An element/data type (a `TypeKind` argument).
    Type,

    /// A shape — presumably supplied through `shape_args` (TODO confirm).
    Shape,

    /// A compile-time constant — presumably supplied through `const_args` (TODO confirm).
    Constant,

    /// A memory layout.
    /// NOTE(review): `InstantiationKey` has no layout argument; semantics unclear.
    Layout,
}
80
/// Declared variance of a type parameter.
///
/// NOTE(review): the constraint checks in this file do not consult variance;
/// it appears to be recorded for information only.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Variance {
    /// Presumably: subtyping of the argument is preserved by the template.
    Covariant,

    /// Presumably: subtyping of the argument is reversed by the template.
    Contravariant,

    /// No subtyping relationship is implied; used by `create_type_param`.
    Invariant,
}
93
/// A requirement a template's arguments must satisfy before it is instantiated.
#[derive(Debug, Clone)]
pub enum TypeConstraint {
    /// The type bound to `param` must belong to the named family. The checker
    /// recognises `"Float"`, `"Integer"`, `"Numeric"` and `"Complex"`; any other
    /// name is never satisfied.
    Trait { param: String, trait_name: String },

    /// The types bound to `param1` and `param2` must be identical.
    Equality { param1: String, param2: String },

    /// The type bound to `subtype` must equal, or widen to, the type bound to
    /// `supertype` (e.g. `I8` to `I32`, `F32` to `F64`, `C64` to `C128`).
    Subtype { subtype: String, supertype: String },

    /// The shape argument at `param`'s position must satisfy `constraint`.
    Shape {
        param: String,
        constraint: ShapeConstraint,
    },

    /// Arbitrary predicate over the full type-argument list; `name` is not used
    /// by the checker and serves as a label only.
    Custom {
        name: String,
        checker: fn(&[TypeKind]) -> bool,
    },
}
118
/// A predicate applied to every dimension of a shape argument.
///
/// An empty shape vacuously satisfies every variant except `Equal`.
/// Derives `PartialEq`, `Eq` and `Hash` so constraints can be compared, asserted
/// on in tests, and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShapeConstraint {
    /// Every dimension is non-zero.
    Positive,

    /// Every dimension is a power of two.
    PowerOfTwo,

    /// Names another parameter whose shape should match.
    /// NOTE: the per-shape checker only rejects empty shapes; the cross-parameter
    /// comparison is not implemented.
    Equal(String),

    /// Every dimension is strictly less than the bound.
    LessThan(usize),

    /// Every dimension is strictly greater than the bound.
    GreaterThan(usize),

    /// Every dimension lies in the inclusive range `[min, max]`.
    InRange(usize, usize),
}
140
/// Identifies one instantiation: a template name plus its full argument list.
///
/// Used as the instance-cache key; its hash also forms the suffix of the
/// instantiated module's name, so field order matters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InstantiationKey {
    /// Name of the template being instantiated.
    pub template_name: String,

    /// Concrete type arguments, positionally matching the template's `type_params`.
    pub type_args: Vec<TypeKind>,

    /// Shape arguments, indexed by the position of the parameter they constrain.
    pub shape_args: Vec<Vec<usize>>,

    /// Constant arguments.
    pub const_args: Vec<ConstantValue>,
}
156
/// A compile-time constant argument to a template.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConstantValue {
    /// Signed integer constant.
    Int(i64),
    /// Unsigned integer constant.
    UInt(u64),
    /// Floating-point constant held as raw bits (`f64` itself cannot derive
    /// `Eq`/`Hash`); presumably produced with `f64::to_bits` — TODO confirm.
    Float(u64),
    /// Boolean constant.
    Bool(bool),
    /// String constant.
    String(String),
}
166
/// A template specialised for one [`InstantiationKey`], as returned by `instantiate`.
#[derive(Debug, Clone)]
pub struct InstantiatedFunction {
    /// The key this instance was built for.
    pub key: InstantiationKey,

    /// The specialised IR module.
    pub module: IrModule,

    /// Heuristic cost estimates derived from the module's instruction count.
    pub perf_info: GenericFunctionPerfInfo,

    /// Number of times this instance has been handed out (1 when first built).
    pub usage_count: usize,

    /// Time measured around the IR instantiation step (constraint checking is
    /// not included), in nanoseconds.
    pub instantiation_time_ns: u64,
}
185
/// Rough static cost estimates for an instantiated function.
///
/// These are heuristics computed from the instruction count, not measurements.
#[derive(Debug, Clone, Default)]
pub struct GenericFunctionPerfInfo {
    /// Estimated execution time in nanoseconds.
    pub estimated_exec_time_ns: u64,

    /// Estimated code size (presumably bytes — TODO confirm).
    pub code_size: usize,

    /// Estimated number of registers used.
    pub register_usage: u32,

    /// Estimated memory footprint (presumably bytes — TODO confirm).
    pub memory_usage: usize,

    /// Optimisation level the estimate is associated with; `Basic` by default.
    pub optimization_level: OptimizationLevel,
}
204
/// Optimisation level attached to an instantiated function; defaults to `Basic`.
///
/// This is a field-less enum, so it derives `Copy` (copies are free). It also derives
/// `PartialEq`/`Eq`/`Hash`, so callers can write `level == OptimizationLevel::Speed`
/// and use levels as keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No optimisation.
    None,

    /// Basic optimisation (the default).
    #[default]
    Basic,

    /// Aggressive optimisation.
    Aggressive,

    /// Optimise for size.
    Size,

    /// Optimise for speed.
    Speed,
}
224
/// Descriptive information attached to a template.
#[derive(Debug, Clone, Default)]
pub struct TemplateMetadata {
    /// Optional author name.
    pub author: Option<String>,

    /// Optional free-form documentation.
    pub documentation: Option<String>,

    /// Optional version string.
    pub version: Option<String>,

    /// Free-form tags.
    pub tags: Vec<String>,

    /// Free-form complexity annotations.
    pub complexity: ComplexityEstimate,
}
243
/// Human-readable complexity notes for a template; all fields are free-form
/// strings (empty by default).
#[derive(Debug, Clone, Default)]
pub struct ComplexityEstimate {
    /// Runtime complexity description, e.g. `"O(n)"`.
    pub time_complexity: String,

    /// Memory complexity description.
    pub space_complexity: String,

    /// Description of the cost of instantiating the template.
    pub instantiation_complexity: String,
}
256
/// Configuration for [`GenericFunctionManager`].
///
/// The `Default` impl gives: 32 instances per template, recursion depth 16,
/// profiling disabled, and monomorphisation, type inference and caching enabled.
#[derive(Debug, Clone)]
pub struct GenericFunctionConfig {
    /// Intended cap on cached instances per template.
    /// NOTE(review): confirm where this limit is enforced.
    pub max_instantiations_per_template: usize,

    /// Whether monomorphisation is enabled. NOTE(review): confirm consumer.
    pub enable_monomorphization: bool,

    /// Whether argument type inference is enabled. NOTE(review): confirm consumer.
    pub enable_type_inference: bool,

    /// Whether instantiated functions are cached and reused.
    pub enable_caching: bool,

    /// Intended recursion limit for nested instantiation. NOTE(review): confirm consumer.
    pub max_recursion_depth: usize,

    /// Whether profiling data should be collected. NOTE(review): confirm consumer.
    pub enable_profiling: bool,
}
278
/// Counters maintained by [`GenericFunctionManager`] and returned by `stats()`.
#[derive(Debug, Clone, Default)]
pub struct GenericFunctionStats {
    /// Count of template registrations recorded.
    pub total_templates: usize,

    /// Count of instances actually built (cache hits are not included).
    pub total_instantiations: usize,

    /// Requests served from the instance cache.
    pub cache_hits: usize,

    /// Requests not served from the cache (recorded before the template lookup,
    /// so requests for unknown templates are counted too).
    pub cache_misses: usize,

    /// Running mean of instantiation time, in nanoseconds.
    pub avg_instantiation_time_ns: u64,

    /// `(template name, usage count)` pairs.
    /// NOTE(review): nothing in this file appears to populate this list.
    pub most_used_templates: Vec<(String, usize)>,
}
300
301impl Default for GenericFunctionConfig {
302 fn default() -> Self {
303 Self {
304 max_instantiations_per_template: 32,
305 enable_monomorphization: true,
306 enable_type_inference: true,
307 enable_caching: true,
308 max_recursion_depth: 16,
309 enable_profiling: false,
310 }
311 }
312}
313
314impl GenericFunctionManager {
315 pub fn new(config: GenericFunctionConfig) -> Self {
317 Self {
318 templates: IndexMap::new(),
319 instances: IndexMap::new(),
320 stats: GenericFunctionStats::default(),
321 config,
322 }
323 }
324
325 pub fn with_defaults() -> Self {
327 Self::new(GenericFunctionConfig::default())
328 }
329
330 pub fn register_template(&mut self, template: GenericFunctionTemplate) -> JitResult<()> {
332 self.validate_template(&template)?;
334
335 self.templates.insert(template.name.clone(), template);
337 self.stats.total_templates += 1;
338
339 Ok(())
340 }
341
342 pub fn instantiate(
344 &mut self,
345 template_name: &str,
346 type_args: &[TypeKind],
347 shape_args: &[Vec<usize>],
348 const_args: &[ConstantValue],
349 ) -> JitResult<InstantiatedFunction> {
350 let key = InstantiationKey {
351 template_name: template_name.to_string(),
352 type_args: type_args.to_vec(),
353 shape_args: shape_args.to_vec(),
354 const_args: const_args.to_vec(),
355 };
356
357 if let Some(instance) = self.instances.get_mut(&key) {
359 instance.usage_count += 1;
360 self.stats.cache_hits += 1;
361 return Ok(instance.clone());
362 }
363
364 self.stats.cache_misses += 1;
365
366 let template = self.templates.get(template_name).ok_or_else(|| {
368 JitError::CompilationError(format!("Template '{}' not found", template_name))
369 })?;
370
371 self.check_constraints(template, type_args, shape_args, const_args)?;
373
374 let start_time = std::time::Instant::now();
376 let instantiated_module = self.perform_instantiation(template, &key)?;
377 let instantiation_time = start_time.elapsed().as_nanos() as u64;
378
379 let perf_info = self.estimate_performance(&instantiated_module)?;
381 let instance = InstantiatedFunction {
382 key: key.clone(),
383 module: instantiated_module,
384 perf_info,
385 usage_count: 1,
386 instantiation_time_ns: instantiation_time,
387 };
388
389 self.instances.insert(key, instance.clone());
391 self.stats.total_instantiations += 1;
392 self.stats.avg_instantiation_time_ns = (self.stats.avg_instantiation_time_ns
393 * (self.stats.total_instantiations - 1) as u64
394 + instantiation_time)
395 / self.stats.total_instantiations as u64;
396
397 Ok(instance)
398 }
399
400 fn validate_template(&self, template: &GenericFunctionTemplate) -> JitResult<()> {
402 let mut param_names = std::collections::HashSet::new();
404 for param in &template.type_params {
405 if !param_names.insert(¶m.name) {
406 return Err(JitError::CompilationError(format!(
407 "Duplicate type parameter name: {}",
408 param.name
409 )));
410 }
411 }
412
413 for constraint in &template.constraints {
415 match constraint {
416 TypeConstraint::Trait { param, .. } => {
417 if !param_names.contains(param) {
418 return Err(JitError::CompilationError(format!(
419 "Constraint references unknown parameter: {}",
420 param
421 )));
422 }
423 }
424 TypeConstraint::Equality { param1, param2 } => {
425 if !param_names.contains(param1) || !param_names.contains(param2) {
426 return Err(JitError::CompilationError(
427 "Equality constraint references unknown parameter".to_string(),
428 ));
429 }
430 }
431 TypeConstraint::Subtype { subtype, supertype } => {
432 if !param_names.contains(subtype) || !param_names.contains(supertype) {
433 return Err(JitError::CompilationError(
434 "Subtype constraint references unknown parameter".to_string(),
435 ));
436 }
437 }
438 TypeConstraint::Shape { param, .. } => {
439 if !param_names.contains(param) {
440 return Err(JitError::CompilationError(format!(
441 "Shape constraint references unknown parameter: {}",
442 param
443 )));
444 }
445 }
446 TypeConstraint::Custom { .. } => {
447 }
449 }
450 }
451
452 Ok(())
453 }
454
455 fn check_constraints(
457 &self,
458 template: &GenericFunctionTemplate,
459 type_args: &[TypeKind],
460 shape_args: &[Vec<usize>],
461 _const_args: &[ConstantValue],
462 ) -> JitResult<()> {
463 if type_args.len() != template.type_params.len() {
464 return Err(JitError::CompilationError(
465 "Wrong number of type arguments".to_string(),
466 ));
467 }
468
469 for constraint in &template.constraints {
470 match constraint {
471 TypeConstraint::Trait { param, trait_name } => {
472 let param_index = template
473 .type_params
474 .iter()
475 .position(|p| p.name == *param)
476 .expect("param should exist in type_params");
477
478 if !self.check_trait_constraint(&type_args[param_index], trait_name) {
479 return Err(JitError::CompilationError(format!(
480 "Type {:?} does not implement trait {}",
481 type_args[param_index], trait_name
482 )));
483 }
484 }
485 TypeConstraint::Shape {
486 param,
487 constraint: shape_constraint,
488 } => {
489 let param_index = template
490 .type_params
491 .iter()
492 .position(|p| p.name == *param)
493 .expect("param should exist in type_params");
494
495 if param_index < shape_args.len() {
496 if !self.check_shape_constraint(&shape_args[param_index], shape_constraint)
497 {
498 return Err(JitError::CompilationError(format!(
499 "Shape constraint violation for parameter {}",
500 param
501 )));
502 }
503 }
504 }
505 TypeConstraint::Custom { checker, .. } => {
506 if !checker(type_args) {
507 return Err(JitError::CompilationError(
508 "Custom constraint violation".to_string(),
509 ));
510 }
511 }
512 TypeConstraint::Equality { param1, param2 } => {
513 let type1 = self.find_type_for_param(template, param1, type_args)?;
515 let type2 = self.find_type_for_param(template, param2, type_args)?;
516
517 if type1 != type2 {
518 return Err(JitError::CompilationError(format!(
519 "Type equality constraint violated: {} != {}",
520 param1, param2
521 )));
522 }
523 }
524 TypeConstraint::Subtype { subtype, supertype } => {
525 let sub_type = self.find_type_for_param(template, subtype, type_args)?;
527 let super_type = self.find_type_for_param(template, supertype, type_args)?;
528
529 if !self.is_subtype(&sub_type, &super_type) {
530 return Err(JitError::CompilationError(format!(
531 "Subtype constraint violated: {} is not a subtype of {}",
532 subtype, supertype
533 )));
534 }
535 }
536 }
537 }
538
539 Ok(())
540 }
541
542 fn find_type_for_param(
544 &self,
545 template: &GenericFunctionTemplate,
546 param_name: &str,
547 type_args: &[TypeKind],
548 ) -> JitResult<TypeKind> {
549 for (i, param) in template.type_params.iter().enumerate() {
550 if param.name == param_name {
551 return Ok(type_args[i].clone());
552 }
553 }
554 Err(JitError::CompilationError(format!(
555 "Type parameter '{}' not found",
556 param_name
557 )))
558 }
559
560 fn is_subtype(&self, sub: &TypeKind, sup: &TypeKind) -> bool {
562 match (sub, sup) {
564 (a, b) if a == b => true,
566 (TypeKind::I8, TypeKind::I16 | TypeKind::I32 | TypeKind::I64) => true,
568 (TypeKind::I16, TypeKind::I32 | TypeKind::I64) => true,
569 (TypeKind::I32, TypeKind::I64) => true,
570 (TypeKind::U8, TypeKind::U16 | TypeKind::U32 | TypeKind::U64) => true,
572 (TypeKind::U16, TypeKind::U32 | TypeKind::U64) => true,
573 (TypeKind::U32, TypeKind::U64) => true,
574 (TypeKind::F16, TypeKind::F32 | TypeKind::F64) => true,
576 (TypeKind::F32, TypeKind::F64) => true,
577 (TypeKind::C64, TypeKind::C128) => true,
579 _ => false,
580 }
581 }
582
583 fn check_trait_constraint(&self, type_kind: &TypeKind, trait_name: &str) -> bool {
585 match trait_name {
586 "Float" => matches!(type_kind, TypeKind::F16 | TypeKind::F32 | TypeKind::F64),
587 "Integer" => matches!(
588 type_kind,
589 TypeKind::I8
590 | TypeKind::I16
591 | TypeKind::I32
592 | TypeKind::I64
593 | TypeKind::U8
594 | TypeKind::U16
595 | TypeKind::U32
596 | TypeKind::U64
597 ),
598 "Numeric" => matches!(
599 type_kind,
600 TypeKind::I8
601 | TypeKind::I16
602 | TypeKind::I32
603 | TypeKind::I64
604 | TypeKind::U8
605 | TypeKind::U16
606 | TypeKind::U32
607 | TypeKind::U64
608 | TypeKind::F16
609 | TypeKind::F32
610 | TypeKind::F64
611 ),
612 "Complex" => matches!(type_kind, TypeKind::C64 | TypeKind::C128),
613 _ => false, }
615 }
616
617 fn check_shape_constraint(&self, shape: &[usize], constraint: &ShapeConstraint) -> bool {
619 match constraint {
620 ShapeConstraint::Positive => shape.iter().all(|&dim| dim > 0),
621 ShapeConstraint::PowerOfTwo => shape.iter().all(|&dim| dim.is_power_of_two()),
622 ShapeConstraint::LessThan(limit) => shape.iter().all(|&dim| dim < *limit),
623 ShapeConstraint::GreaterThan(limit) => shape.iter().all(|&dim| dim > *limit),
624 ShapeConstraint::InRange(min, max) => {
625 shape.iter().all(|&dim| dim >= *min && dim <= *max)
626 }
627 ShapeConstraint::Equal(_param_name) => {
628 !shape.is_empty()
632 }
633 }
634 }
635
636 fn perform_instantiation(
638 &self,
639 template: &GenericFunctionTemplate,
640 key: &InstantiationKey,
641 ) -> JitResult<IrModule> {
642 let mut instantiated = template.template_ir.clone();
644 instantiated.name = format!("{}_{}", template.name, self.generate_mangled_name(key));
645
646 self.substitute_types(&mut instantiated, template, key)?;
648
649 self.optimize_instantiated_function(&mut instantiated)?;
651
652 Ok(instantiated)
653 }
654
655 fn generate_mangled_name(&self, key: &InstantiationKey) -> String {
657 use std::collections::hash_map::DefaultHasher;
658 use std::hash::{Hash, Hasher};
659
660 let mut hasher = DefaultHasher::new();
661 key.hash(&mut hasher);
662 format!("{:x}", hasher.finish())
663 }
664
665 fn substitute_types(
667 &self,
668 module: &mut IrModule,
669 template: &GenericFunctionTemplate,
670 key: &InstantiationKey,
671 ) -> JitResult<()> {
672 let mut type_map = HashMap::new();
674 for (param, type_arg) in template.type_params.iter().zip(key.type_args.iter()) {
675 type_map.insert(param.name.clone(), type_arg.clone());
676 }
677
678 for (_val_id, val_def) in module.values.iter_mut() {
680 self.substitute_ir_type(&mut val_def.ty, &type_map)?;
682 }
683
684 for (_block_id, block) in module.blocks.iter_mut() {
686 for instruction in &mut block.instructions {
687 if let Some(result) = instruction.result {
689 if let Some(val_def) = module.values.get_mut(&result) {
690 self.substitute_ir_type(&mut val_def.ty, &type_map)?;
691 }
692 }
693 }
694 }
695
696 Ok(())
697 }
698
699 fn substitute_ir_type(
701 &self,
702 _ir_type: &mut crate::ir::IrType,
703 _type_map: &HashMap<String, TypeKind>,
704 ) -> JitResult<()> {
705 Ok(())
713 }
714
715 fn optimize_instantiated_function(&self, module: &mut IrModule) -> JitResult<()> {
717 self.eliminate_dead_code_in_generics(module)?;
721
722 self.fold_constants_with_type_info(module)?;
724
725 self.inline_small_functions(module)?;
727
728 self.apply_type_specific_optimizations(module)?;
730
731 Ok(())
732 }
733
734 fn eliminate_dead_code_in_generics(&self, module: &mut IrModule) -> JitResult<()> {
736 use std::collections::HashSet;
737
738 let mut reachable = HashSet::new();
740 let mut worklist = vec![module.entry_block];
741
742 while let Some(block_id) = worklist.pop() {
743 if reachable.insert(block_id) {
744 if let Some(_block) = module.blocks.get(&block_id) {
746 }
750 }
751 }
752
753 module
755 .blocks
756 .retain(|block_id, _| reachable.contains(block_id));
757
758 Ok(())
759 }
760
761 fn fold_constants_with_type_info(&self, module: &mut IrModule) -> JitResult<()> {
763 use crate::ir::IrOpcode;
764
765 for (_block_id, block) in module.blocks.iter_mut() {
767 let mut folded_instructions = Vec::new();
768
769 for instruction in &block.instructions {
770 match instruction.opcode {
772 IrOpcode::Add | IrOpcode::Sub | IrOpcode::Mul | IrOpcode::Div => {
773 folded_instructions.push(instruction.clone());
775 }
776 _ => {
777 folded_instructions.push(instruction.clone());
778 }
779 }
780 }
781
782 }
784
785 Ok(())
786 }
787
788 fn inline_small_functions(&self, _module: &mut IrModule) -> JitResult<()> {
790 Ok(())
793 }
794
795 fn apply_type_specific_optimizations(&self, module: &mut IrModule) -> JitResult<()> {
797 use crate::ir::{IrOpcode, ValueKind};
798
799 for (_val_id, val_def) in &module.values {
801 match &val_def.kind {
802 ValueKind::Constant { .. } => {
803 }
805 _ => {}
806 }
807 }
808
809 for (_block_id, block) in module.blocks.iter_mut() {
815 for instruction in &mut block.instructions {
816 match instruction.opcode {
817 IrOpcode::Mul => {
818 }
820 IrOpcode::Div => {
821 }
823 _ => {}
824 }
825 }
826 }
827
828 Ok(())
829 }
830
831 fn estimate_performance(&self, module: &IrModule) -> JitResult<GenericFunctionPerfInfo> {
833 let mut perf_info = GenericFunctionPerfInfo::default();
834
835 let mut instruction_count = 0;
837 for (_, block) in &module.blocks {
838 instruction_count += block.instructions.len();
839 }
840
841 perf_info.estimated_exec_time_ns = instruction_count as u64 * 10; perf_info.code_size = instruction_count * 4; perf_info.register_usage = (instruction_count / 4).min(32) as u32; perf_info.memory_usage = instruction_count * 8; Ok(perf_info)
847 }
848
849 pub fn stats(&self) -> &GenericFunctionStats {
851 &self.stats
852 }
853
854 pub fn list_templates(&self) -> Vec<&str> {
856 self.templates.keys().map(|s| s.as_str()).collect()
857 }
858
859 pub fn get_template(&self, name: &str) -> Option<&GenericFunctionTemplate> {
861 self.templates.get(name)
862 }
863
864 pub fn clear_instances(&mut self) {
866 self.instances.clear();
867 self.stats.total_instantiations = 0;
868 self.stats.cache_hits = 0;
869 self.stats.cache_misses = 0;
870 }
871
872 pub fn instantiation_count(&self, template_name: &str) -> usize {
874 self.instances
875 .keys()
876 .filter(|k| k.template_name == template_name)
877 .count()
878 }
879}
880
881pub fn create_type_param(name: &str, kind: ParameterKind) -> TypeParameter {
883 TypeParameter {
884 name: name.to_string(),
885 kind,
886 default: None,
887 variance: Variance::Invariant,
888 }
889}
890
891pub fn trait_constraint(param: &str, trait_name: &str) -> TypeConstraint {
893 TypeConstraint::Trait {
894 param: param.to_string(),
895 trait_name: trait_name.to_string(),
896 }
897}
898
899pub fn shape_constraint(param: &str, constraint: ShapeConstraint) -> TypeConstraint {
901 TypeConstraint::Shape {
902 param: param.to_string(),
903 constraint,
904 }
905}
906
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generic_function_manager_creation() {
        let fresh = GenericFunctionManager::with_defaults();
        assert!(fresh.templates.is_empty(), "new manager should have no templates");
        assert!(fresh.instances.is_empty(), "new manager should have no instances");
    }

    #[test]
    fn test_type_parameter_creation() {
        let TypeParameter {
            name,
            kind,
            variance,
            ..
        } = create_type_param("T", ParameterKind::Type);
        assert_eq!(name, "T");
        assert_eq!(kind, ParameterKind::Type);
        assert_eq!(variance, Variance::Invariant);
    }

    #[test]
    fn test_trait_constraint_creation() {
        if let TypeConstraint::Trait { param, trait_name } = trait_constraint("T", "Float") {
            assert_eq!(param, "T");
            assert_eq!(trait_name, "Float");
        } else {
            panic!("Expected trait constraint");
        }
    }

    #[test]
    fn test_trait_constraint_checking() {
        let manager = GenericFunctionManager::with_defaults();
        // (type, trait family, expected membership)
        let cases = [
            (TypeKind::F32, "Float", true),
            (TypeKind::F64, "Float", true),
            (TypeKind::I32, "Float", false),
            (TypeKind::I32, "Integer", true),
            (TypeKind::U64, "Integer", true),
        ];
        for (ty, trait_name, expected) in cases.iter() {
            assert_eq!(
                manager.check_trait_constraint(ty, trait_name),
                *expected,
                "{:?} as {}",
                ty,
                trait_name
            );
        }
    }

    #[test]
    fn test_shape_constraint_checking() {
        let manager = GenericFunctionManager::with_defaults();
        // (shape, constraint, expected result)
        let cases: [(&[usize], ShapeConstraint, bool); 4] = [
            (&[2, 4, 8], ShapeConstraint::Positive, true),
            (&[0, 4, 8], ShapeConstraint::Positive, false),
            (&[2, 4, 8], ShapeConstraint::PowerOfTwo, true),
            (&[3, 5, 7], ShapeConstraint::PowerOfTwo, false),
        ];
        for (shape, constraint, expected) in cases.iter() {
            assert_eq!(
                manager.check_shape_constraint(shape, constraint),
                *expected,
                "{:?} / {:?}",
                shape,
                constraint
            );
        }
    }
}