1use scirs2_core::ndarray::Array1;
8use serde::{Deserialize, Serialize};
9use sklears_core::{
10 error::Result as SklResult,
11 traits::{Estimator, Transform},
12};
13use std::collections::HashMap;
14use std::fmt::Debug;
15
/// Common interface for component configuration objects.
///
/// Implementors must also be `Debug + Clone + Default`. The
/// `impl_standard_config!` macro in this module generates a minimal
/// implementation (always valid, no exported parameters).
pub trait StandardConfig: Debug + Clone + Default {
    /// Checks the configuration, returning an error if it is not usable.
    fn validate(&self) -> SklResult<()>;

    /// Returns a human-readable summary of the configuration.
    fn summary(&self) -> ConfigSummary;

    /// Exports the configuration as a parameter-name → [`ConfigValue`] map.
    fn to_params(&self) -> HashMap<String, ConfigValue>;

    /// Builds a configuration from a parameter-name → [`ConfigValue`] map.
    ///
    /// Returns an error if the parameters cannot be turned into a configuration.
    fn from_params(params: HashMap<String, ConfigValue>) -> SklResult<Self>;
}
30
/// A dynamically typed configuration parameter value, used by
/// [`StandardConfig::to_params`] / [`StandardConfig::from_params`].
// NOTE: variant order is part of the derived Debug/serde representation;
// do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConfigValue {
    /// A boolean flag.
    Boolean(bool),
    /// A signed 64-bit integer.
    Integer(i64),
    /// A 64-bit float.
    Float(f64),
    /// A string value.
    String(String),
    /// A list of floats.
    FloatArray(Vec<f64>),
    /// A list of signed 64-bit integers.
    IntegerArray(Vec<i64>),
    /// A list of strings.
    StringArray(Vec<String>),
}
49
/// Human-readable description of a configuration, returned by
/// [`StandardConfig::summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigSummary {
    /// Name of the component type this configuration belongs to.
    pub component_type: String,
    /// Free-form description of the configuration.
    pub description: String,
    /// Parameter name → stringified value.
    pub parameters: HashMap<String, String>,
    /// Whether the configuration was considered valid when summarized.
    pub is_valid: bool,
    /// Messages explaining validation results (empty when nothing to report).
    pub validation_messages: Vec<String>,
}
64
/// Builder interface producing a `T`.
pub trait StandardBuilder<T>: Default {
    /// Consumes the builder and produces the built value, or an error.
    fn build(self) -> SklResult<T>;

    /// Builds the value and then runs `validator` on it.
    ///
    /// Implementors are expected to return the validator's error if it fails.
    // NOTE(review): the propagation of the validator's error is the intended
    // contract; this trait cannot enforce it.
    fn build_with_validation<F>(self, validator: F) -> SklResult<T>
    where
        F: FnOnce(&T) -> SklResult<()>;

    /// Discards all settings on this builder and returns a fresh `Self::default()`.
    fn reset(self) -> Self {
        Self::default()
    }
}
80
/// Metadata recorded for one execution (fit/predict/transform) of a component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionMetadata {
    /// Name of the component that ran.
    pub component_name: String,
    /// Start timestamp.
    // NOTE(review): the unit/epoch is not fixed here (tests use values like
    // 1630000000, which look like Unix seconds) — confirm with producers.
    pub start_time: u64,
    /// End timestamp, same unit as `start_time`; `None` if not recorded.
    pub end_time: Option<u64>,
    /// Wall-clock duration in milliseconds, if measured.
    pub duration_ms: Option<f64>,
    /// `(rows, cols)` of the input, if known.
    pub input_shape: Option<(usize, usize)>,
    /// `(rows, cols)` of the output, if known.
    pub output_shape: Option<(usize, usize)>,
    /// Memory usage before execution, in megabytes (per the field name).
    pub memory_before_mb: Option<f64>,
    /// Memory usage after execution, in megabytes (per the field name).
    pub memory_after_mb: Option<f64>,
    /// CPU utilization during execution.
    // NOTE(review): scale (0..1 fraction vs percent) is not established here;
    // the test uses 0.75 — confirm.
    pub cpu_utilization: Option<f64>,
    /// Warnings emitted during execution.
    pub warnings: Vec<String>,
    /// Additional free-form key/value metadata.
    pub extra_metadata: HashMap<String, String>,
}
107
/// A computed value paired with the metadata describing how it was produced.
#[derive(Debug, Clone)]
pub struct StandardResult<T> {
    /// The produced value (fitted model, predictions, transformed data, ...).
    pub result: T,
    /// Metadata about the execution that produced `result`.
    pub metadata: ExecutionMetadata,
}
116
/// Access to the execution metadata a component has recorded.
pub trait MetadataProvider {
    /// Metadata of the most recent execution, or `None` if nothing was recorded.
    fn last_execution_metadata(&self) -> Option<&ExecutionMetadata>;

    /// All recorded execution metadata.
    // NOTE(review): ordering (oldest-first?) is up to implementors — confirm.
    fn execution_history(&self) -> &[ExecutionMetadata];

    /// Removes all recorded execution metadata.
    fn clear_history(&mut self);
}
128
/// Standardized estimator interface layered on top of `Estimator`, with
/// typed configuration, metadata-producing fitting, and input validation.
pub trait StandardEstimator<X, Y>: Estimator + MetadataProvider + Send + Sync {
    /// Configuration type for this estimator.
    type Config: StandardConfig;

    /// Type produced by fitting.
    type Fitted: StandardFittedEstimator<X, Y>;

    /// Current configuration.
    // The fully qualified path disambiguates from any `Config` associated
    // type on the `Estimator` supertrait.
    fn config(&self) -> &<Self as StandardEstimator<X, Y>>::Config;

    /// Replaces the configuration; returns an error if it is rejected.
    fn with_config(self, config: <Self as StandardEstimator<X, Y>>::Config) -> SklResult<Self>
    where
        Self: Sized;

    /// Fits on `(x, y)` and returns the fitted model with execution metadata.
    fn fit_with_metadata(self, x: X, y: Y) -> SklResult<StandardResult<Self::Fitted>>;

    /// Checks that `(x, y)` are acceptable training inputs.
    fn validate_input(&self, x: &X, y: &Y) -> SklResult<()>;

    /// Describes the (unfitted) model.
    fn model_summary(&self) -> ModelSummary;
}
154
/// Interface of a model produced by [`StandardEstimator::fit_with_metadata`].
pub trait StandardFittedEstimator<X, Y>: Send + Sync {
    /// Configuration type the model was fitted with.
    type Config: StandardConfig;

    /// Prediction output type.
    type Output;

    /// Configuration used for fitting.
    fn config(&self) -> &Self::Config;

    /// Predicts for `x` and returns the output with execution metadata.
    fn predict_with_metadata(&self, x: X) -> SklResult<StandardResult<Self::Output>>;

    /// Describes the fitted model.
    fn fitted_summary(&self) -> FittedModelSummary;

    /// Per-feature importance scores, or `None` if the model does not provide them.
    fn feature_importance(&self) -> Option<Array1<f64>>;
}
175
/// Standardized transformer interface layered on `Transform<X, X>`
/// (input and output share the same type `X`).
pub trait StandardTransformer<X>: Transform<X, X> + MetadataProvider + Send + Sync {
    /// Configuration type for this transformer.
    type Config: StandardConfig;

    /// Type produced by fitting.
    type Fitted: StandardFittedTransformer<X>;

    /// Current configuration.
    fn config(&self) -> &Self::Config;

    /// Fits on `x` and returns the fitted transformer with execution metadata.
    fn fit_with_metadata(self, x: X) -> SklResult<StandardResult<Self::Fitted>>;

    /// Transforms `x` and returns the result with execution metadata.
    fn transform_with_metadata(&self, x: X) -> SklResult<StandardResult<X>>;

    /// Fits on `x` and transforms it in one step, consuming the transformer.
    fn fit_transform_with_metadata(self, x: X) -> SklResult<StandardResult<X>>;
}
196
/// Interface of a transformer produced by [`StandardTransformer::fit_with_metadata`].
pub trait StandardFittedTransformer<X>: Send + Sync {
    /// Configuration type the transformer was fitted with.
    type Config: StandardConfig;

    /// Transforms `x` and returns the result with execution metadata.
    fn transform_with_metadata(&self, x: X) -> SklResult<StandardResult<X>>;

    /// Describes the fitted transformer.
    fn fitted_summary(&self) -> FittedTransformerSummary;
}
208
/// Static description of a model, returned by [`StandardEstimator::model_summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSummary {
    /// Model type name (e.g. `"LinearRegression"`).
    pub model_type: String,
    /// Free-form description.
    pub description: String,
    /// Number of model parameters, if known.
    pub parameter_count: Option<usize>,
    /// Model complexity measure, if available.
    // NOTE(review): scale/meaning is not defined in this file — confirm.
    pub complexity: Option<f64>,
    /// Whether the model supports incremental (online) training.
    pub supports_incremental: bool,
    /// Whether feature importances are available.
    pub provides_feature_importance: bool,
    /// Whether prediction intervals are available.
    pub provides_prediction_intervals: bool,
    /// Additional free-form key/value information.
    pub extra_info: HashMap<String, String>,
}
229
/// Description of a fitted model: the static [`ModelSummary`] plus training facts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedModelSummary {
    /// Summary of the underlying (unfitted) model.
    pub base_summary: ModelSummary,
    /// `(rows, cols)` of the training data, if known.
    pub training_shape: Option<(usize, usize)>,
    /// Training wall-clock time in milliseconds, if measured.
    pub training_duration_ms: Option<f64>,
    /// Score on the training data, if computed.
    pub training_score: Option<f64>,
    /// Cross-validation score, if computed.
    pub cv_score: Option<f64>,
    /// Number of training iterations, for iterative models.
    pub iterations: Option<usize>,
    /// Whether training converged, for iterative models.
    pub converged: Option<bool>,
    /// Names of the training features, if known.
    pub feature_names: Option<Vec<String>>,
}
250
/// Description of a fitted transformer, returned by
/// [`StandardFittedTransformer::fitted_summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FittedTransformerSummary {
    /// Transformer type name.
    pub transformer_type: String,
    /// `(rows, cols)` of the data the transformer was fitted on, if known.
    pub input_shape: Option<(usize, usize)>,
    /// `(rows, cols)` of the transformed output, if known.
    pub output_shape: Option<(usize, usize)>,
    /// Whether the transformation can be inverted.
    pub is_invertible: bool,
    /// Learned parameter name → stringified value.
    pub learned_parameters: HashMap<String, String>,
}
265
/// Heuristic checker that analyzes component *types* (by their
/// `std::any::type_name`) and reports API-consistency issues.
pub struct ApiConsistencyChecker {
    /// Which analyses to run.
    config: ConsistencyCheckConfig,
    /// Explicitly registered type information, keyed by type name; takes
    /// precedence over name-based inference.
    type_registry: HashMap<String, ComponentTypeInfo>,
    /// Reports already computed, keyed by type name.
    cached_reports: HashMap<String, ConsistencyReport>,
}
272
/// Settings for [`ApiConsistencyChecker`].
///
/// `Debug` and `Clone` are implemented by hand because `custom_rules` holds
/// boxed closures, which implement neither.
pub struct ConsistencyCheckConfig {
    /// Run trait/category consistency analysis.
    pub enable_type_analysis: bool,
    /// Run performance-pattern analysis (cache efficiency).
    pub enable_performance_analysis: bool,
    /// Run thread-safety analysis.
    pub enable_thread_safety_check: bool,
    /// Run memory-complexity analysis.
    pub enable_memory_analysis: bool,
    /// How strict the analysis should be.
    // NOTE(review): the checker's analysis methods in this file do not read
    // this field — confirm whether it is meant to affect results.
    pub strictness_level: CheckStrictnessLevel,
    /// Extra caller-supplied rules; each takes a component type name and
    /// returns the issues it finds.
    pub custom_rules: Vec<Box<dyn Fn(&str) -> Vec<ConsistencyIssue>>>,
}
288
289impl std::fmt::Debug for ConsistencyCheckConfig {
290 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
291 f.debug_struct("ConsistencyCheckConfig")
292 .field("enable_type_analysis", &self.enable_type_analysis)
293 .field(
294 "enable_performance_analysis",
295 &self.enable_performance_analysis,
296 )
297 .field(
298 "enable_thread_safety_check",
299 &self.enable_thread_safety_check,
300 )
301 .field("enable_memory_analysis", &self.enable_memory_analysis)
302 .field("strictness_level", &self.strictness_level)
303 .field(
304 "custom_rules",
305 &format!("<{} custom rules>", self.custom_rules.len()),
306 )
307 .finish()
308 }
309}
310
311impl Clone for ConsistencyCheckConfig {
312 fn clone(&self) -> Self {
313 Self {
315 enable_type_analysis: self.enable_type_analysis,
316 enable_performance_analysis: self.enable_performance_analysis,
317 enable_thread_safety_check: self.enable_thread_safety_check,
318 enable_memory_analysis: self.enable_memory_analysis,
319 strictness_level: self.strictness_level.clone(),
320 custom_rules: Vec::new(), }
322 }
323}
324
/// Requested strictness of consistency checking, from most to least forgiving.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CheckStrictnessLevel {
    /// Most forgiving level.
    Lenient,
    /// Default level (used by `ConsistencyCheckConfig::default`).
    Standard,
    /// Stricter than `Standard`.
    Strict,
    /// Strictest level.
    Pedantic,
}
337
/// Type-level facts the checker analyzes for one component type.
///
/// Either registered explicitly via
/// `ApiConsistencyChecker::register_component_type` or inferred from the type name.
#[derive(Debug, Clone)]
pub struct ComponentTypeInfo {
    /// Fully qualified type name (as produced by `std::any::type_name`); used as the registry key.
    pub name: String,
    /// Kind of component.
    pub category: ComponentCategory,
    /// Names of traits the type implements (e.g. `"Fit"`, `"Predict"`, `"Transform"`).
    pub implemented_traits: Vec<String>,
    /// Known method signatures.
    pub method_signatures: Vec<MethodSignature>,
    /// Performance characteristics used by the performance/memory/thread-safety analyses.
    pub performance_characteristics: PerformanceCharacteristics,
}
347
/// Kind of component, as classified by the checker.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComponentCategory {
    /// Model that is fitted and predicts.
    Estimator,
    /// Component that transforms data.
    Transformer,
    /// Composition of components.
    Pipeline,
    /// Validation component.
    Validator,
    /// Debugging component.
    Debugger,
    /// Could not be classified.
    Unknown,
}
364
/// Textual description of one method's signature.
#[derive(Debug, Clone)]
pub struct MethodSignature {
    /// Method name.
    pub name: String,
    /// Parameter types, as strings (e.g. `"X"`, `"Y"`).
    pub input_types: Vec<String>,
    /// Return type, as a string (e.g. `"Result<Output>"`).
    pub output_type: String,
    /// Whether the method is `async`.
    pub is_async: bool,
    /// How the method reports failure.
    pub error_handling: ErrorHandlingPattern,
}
374
/// How a method reports failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorHandlingPattern {
    /// Returns a `Result`.
    Result,
    /// Returns an `Option`.
    Option,
    /// Panics on failure.
    Panic,
    /// Some other, named mechanism.
    Custom(String),
    /// Does not report failure.
    None,
}
388
/// Performance facts about a component type.
#[derive(Debug, Clone)]
pub struct PerformanceCharacteristics {
    /// Time complexity as text, e.g. `"O(n)"`.
    pub computational_complexity: String,
    /// Memory complexity as text, e.g. `"O(n)"`; the checker's memory
    /// analysis flags values containing `"exponential"`.
    pub memory_complexity: String,
    /// Thread-safety classification.
    pub thread_safety: ThreadSafetyLevel,
    /// Cache-efficiency score; the checker flags values below 0.5.
    // NOTE(review): presumably in [0, 1] (inferred default is 0.8) — confirm.
    pub cache_efficiency: f64,
}
397
/// Thread-safety classification of a component.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ThreadSafetyLevel {
    /// Safe to use across threads.
    Safe,
    /// Safe only under some conditions.
    Conditional,
    /// Not thread-safe; the checker reports a Major issue for this level.
    Unsafe,
    /// Not known.
    Unknown,
}
410
411impl Default for ConsistencyCheckConfig {
412 fn default() -> Self {
413 Self {
414 enable_type_analysis: true,
415 enable_performance_analysis: true,
416 enable_thread_safety_check: true,
417 enable_memory_analysis: true,
418 strictness_level: CheckStrictnessLevel::Standard,
419 custom_rules: Vec::new(),
420 }
421 }
422}
423
424impl Default for ApiConsistencyChecker {
425 fn default() -> Self {
426 Self::new()
427 }
428}
429
430impl ApiConsistencyChecker {
431 #[must_use]
433 pub fn new() -> Self {
434 Self {
435 config: ConsistencyCheckConfig::default(),
436 type_registry: HashMap::new(),
437 cached_reports: HashMap::new(),
438 }
439 }
440
441 #[must_use]
443 pub fn with_config(config: ConsistencyCheckConfig) -> Self {
444 Self {
445 config,
446 type_registry: HashMap::new(),
447 cached_reports: HashMap::new(),
448 }
449 }
450
451 pub fn check_component<T>(&mut self, component: &T) -> ConsistencyReport
453 where
454 T: Debug,
455 {
456 let component_name = std::any::type_name::<T>().to_string();
457
458 if let Some(cached_report) = self.cached_reports.get(&component_name) {
460 return cached_report.clone();
461 }
462
463 let mut issues = Vec::new();
464 let mut recommendations = Vec::new();
465
466 let component_info = self.analyze_component_type(&component_name);
468
469 if self.config.enable_type_analysis {
471 issues.extend(self.analyze_type_consistency(&component_info));
472 }
473
474 if self.config.enable_performance_analysis {
475 issues.extend(self.analyze_performance_patterns(&component_info));
476 }
477
478 if self.config.enable_thread_safety_check {
479 issues.extend(self.analyze_thread_safety(&component_info));
480 }
481
482 if self.config.enable_memory_analysis {
483 issues.extend(self.analyze_memory_patterns(&component_info));
484 }
485
486 recommendations = self.generate_recommendations(&issues, &component_info);
488
489 let score = self.calculate_consistency_score(&issues);
491
492 let report = ConsistencyReport {
493 component_name: component_name.clone(),
494 is_consistent: issues
495 .iter()
496 .all(|i| matches!(i.severity, IssueSeverity::Suggestion)),
497 issues,
498 recommendations,
499 score,
500 };
501
502 self.cached_reports.insert(component_name, report.clone());
504 report
505 }
506
507 pub fn check_pipeline_consistency<T>(&mut self, components: &[T]) -> PipelineConsistencyReport
509 where
510 T: Debug,
511 {
512 let mut component_reports = Vec::new();
513 let mut cross_component_issues = Vec::new();
514
515 for component in components {
517 let report = self.check_component(component);
518 component_reports.push(report);
519 }
520
521 cross_component_issues.extend(self.analyze_cross_component_consistency(&component_reports));
523
524 let total_components = component_reports.len();
526 let consistent_components = component_reports.iter().filter(|r| r.is_consistent).count();
527
528 let overall_score = if total_components > 0 {
529 let individual_scores: f64 = component_reports.iter().map(|r| r.score).sum();
530 let cross_penalty = cross_component_issues.len() as f64 * 0.05;
531 ((individual_scores / total_components as f64) - cross_penalty).max(0.0)
532 } else {
533 1.0
534 };
535
536 let improvement_suggestions = self
538 .generate_pipeline_improvement_suggestions(&component_reports, &cross_component_issues);
539
540 PipelineConsistencyReport {
542 total_components,
543 consistent_components,
544 component_reports,
545 overall_score,
546 critical_issues: cross_component_issues,
547 improvement_suggestions,
548 }
549 }
550
551 pub fn register_component_type(&mut self, info: ComponentTypeInfo) {
553 self.type_registry.insert(info.name.clone(), info);
554 }
555
556 pub fn clear_cache(&mut self) {
558 self.cached_reports.clear();
559 }
560
561 #[must_use]
563 pub fn get_analysis_statistics(&self) -> AnalysisStatistics {
564 AnalysisStatistics {
566 total_components_analyzed: self.cached_reports.len(),
567 average_consistency_score: self.cached_reports.values().map(|r| r.score).sum::<f64>()
568 / self.cached_reports.len().max(1) as f64,
569 most_common_issues: self.get_most_common_issues(),
570 registered_types: self.type_registry.len(),
571 }
572 }
573
574 fn analyze_component_type(&self, component_name: &str) -> ComponentTypeInfo {
577 if let Some(info) = self.type_registry.get(component_name) {
579 return info.clone();
580 }
581
582 let category =
584 if component_name.contains("Predictor") || component_name.contains("Estimator") {
585 ComponentCategory::Estimator
586 } else if component_name.contains("Transformer") || component_name.contains("Scaler") {
587 ComponentCategory::Transformer
588 } else if component_name.contains("Pipeline") {
589 ComponentCategory::Pipeline
590 } else if component_name.contains("Validator") {
591 ComponentCategory::Validator
592 } else if component_name.contains("Debugger") {
593 ComponentCategory::Debugger
594 } else {
595 ComponentCategory::Unknown
596 };
597
598 ComponentTypeInfo {
600 name: component_name.to_string(),
601 category,
602 implemented_traits: self.infer_implemented_traits(component_name),
603 method_signatures: self.infer_method_signatures(component_name),
604 performance_characteristics: self.infer_performance_characteristics(component_name),
605 }
606 }
607
608 fn analyze_type_consistency(&self, info: &ComponentTypeInfo) -> Vec<ConsistencyIssue> {
609 let mut issues = Vec::new();
610
611 match info.category {
613 ComponentCategory::Estimator => {
614 if !info.implemented_traits.contains(&"Fit".to_string()) {
615 issues.push(ConsistencyIssue {
616 category: IssueCategory::ConfigurationPattern,
617 severity: IssueSeverity::Major,
618 description: "Estimator should implement Fit trait".to_string(),
619 location: Some("Type definition".to_string()),
620 suggested_fix: Some("impl Fit<X, Y> for YourEstimator".to_string()),
621 });
622 }
623 if !info.implemented_traits.contains(&"Predict".to_string()) {
624 issues.push(ConsistencyIssue {
625 category: IssueCategory::ConfigurationPattern,
626 severity: IssueSeverity::Major,
627 description: "Estimator should implement Predict trait".to_string(),
628 location: Some("Type definition".to_string()),
629 suggested_fix: Some("impl Predict<X> for YourEstimator".to_string()),
630 });
631 }
632 }
633 ComponentCategory::Transformer => {
634 if !info.implemented_traits.contains(&"Transform".to_string()) {
635 issues.push(ConsistencyIssue {
636 category: IssueCategory::ConfigurationPattern,
637 severity: IssueSeverity::Major,
638 description: "Transformer should implement Transform trait".to_string(),
639 location: Some("Type definition".to_string()),
640 suggested_fix: Some("impl Transform<X> for YourTransformer".to_string()),
641 });
642 }
643 }
644 _ => {} }
646
647 issues
648 }
649
650 fn analyze_performance_patterns(&self, info: &ComponentTypeInfo) -> Vec<ConsistencyIssue> {
651 let mut issues = Vec::new();
652
653 if info.performance_characteristics.cache_efficiency < 0.5 {
654 issues.push(ConsistencyIssue {
655 category: IssueCategory::ReturnTypes,
656 severity: IssueSeverity::Minor,
657 description: "Low cache efficiency detected".to_string(),
658 location: Some("Performance analysis".to_string()),
659 suggested_fix: Some("Consider memory access pattern optimizations".to_string()),
660 });
661 }
662
663 issues
664 }
665
666 fn analyze_thread_safety(&self, info: &ComponentTypeInfo) -> Vec<ConsistencyIssue> {
667 let mut issues = Vec::new();
668
669 if matches!(
670 info.performance_characteristics.thread_safety,
671 ThreadSafetyLevel::Unsafe
672 ) {
673 issues.push(ConsistencyIssue {
674 category: IssueCategory::ConfigurationPattern,
675 severity: IssueSeverity::Major,
676 description: "Component is not thread-safe".to_string(),
677 location: Some("Thread safety analysis".to_string()),
678 suggested_fix: Some(
679 "Add synchronization or document thread safety requirements".to_string(),
680 ),
681 });
682 }
683
684 issues
685 }
686
687 fn analyze_memory_patterns(&self, info: &ComponentTypeInfo) -> Vec<ConsistencyIssue> {
688 let mut issues = Vec::new();
689
690 if info
691 .performance_characteristics
692 .memory_complexity
693 .contains("exponential")
694 {
695 issues.push(ConsistencyIssue {
696 category: IssueCategory::ReturnTypes,
697 severity: IssueSeverity::Critical,
698 description: "Exponential memory complexity detected".to_string(),
699 location: Some("Memory analysis".to_string()),
700 suggested_fix: Some("Optimize data structures and algorithms".to_string()),
701 });
702 }
703
704 issues
705 }
706
707 fn generate_recommendations(
708 &self,
709 issues: &[ConsistencyIssue],
710 info: &ComponentTypeInfo,
711 ) -> Vec<ApiRecommendation> {
712 let mut recommendations = Vec::new();
713
714 match info.category {
716 ComponentCategory::Estimator => {
717 recommendations.push(ApiRecommendation {
718 category: RecommendationCategory::InterfaceDesign,
719 priority: RecommendationPriority::Medium,
720 title: "Implement StandardEstimator trait".to_string(),
721 description: "Consider implementing StandardEstimator for enhanced consistency"
722 .to_string(),
723 example_code: Some(
724 "impl StandardEstimator<X, Y> for YourEstimator { ... }".to_string(),
725 ),
726 });
727 }
728 ComponentCategory::Transformer => {
729 recommendations.push(ApiRecommendation {
730 category: RecommendationCategory::InterfaceDesign,
731 priority: RecommendationPriority::Medium,
732 title: "Implement StandardTransformer trait".to_string(),
733 description:
734 "Consider implementing StandardTransformer for enhanced consistency"
735 .to_string(),
736 example_code: Some(
737 "impl StandardTransformer<X> for YourTransformer { ... }".to_string(),
738 ),
739 });
740 }
741 _ => {}
742 }
743
744 for issue in issues {
746 if matches!(
747 issue.severity,
748 IssueSeverity::Critical | IssueSeverity::Major
749 ) {
750 recommendations.push(ApiRecommendation {
751 category: RecommendationCategory::ErrorHandling,
752 priority: RecommendationPriority::High,
753 title: format!("Address {}", issue.description),
754 description: issue
755 .suggested_fix
756 .clone()
757 .unwrap_or_else(|| "No specific fix available".to_string()),
758 example_code: None,
759 });
760 }
761 }
762
763 if recommendations.is_empty() {
764 recommendations.push(ApiRecommendation {
765 category: RecommendationCategory::Documentation,
766 priority: RecommendationPriority::Low,
767 title: "Document component API contract".to_string(),
768 description:
769 "Add high-level documentation describing expected inputs, outputs, and lifecycle to improve discoverability."
770 .to_string(),
771 example_code: None,
772 });
773 }
774
775 recommendations
776 }
777
778 fn calculate_consistency_score(&self, issues: &[ConsistencyIssue]) -> f64 {
779 if issues.is_empty() {
780 return 1.0;
781 }
782
783 let total_penalty: f64 = issues
784 .iter()
785 .map(|issue| match issue.severity {
786 IssueSeverity::Critical => 0.5,
787 IssueSeverity::Major => 0.3,
788 IssueSeverity::Minor => 0.1,
789 IssueSeverity::Suggestion => 0.02,
790 })
791 .sum();
792
793 (1.0 - total_penalty).max(0.0)
794 }
795
796 fn analyze_cross_component_consistency(&self, reports: &[ConsistencyReport]) -> Vec<String> {
797 let mut issues = Vec::new();
798
799 let error_patterns: Vec<_> = reports
801 .iter()
802 .flat_map(|r| &r.issues)
803 .filter(|i| matches!(i.category, IssueCategory::ErrorHandling))
804 .collect();
805
806 if error_patterns.len() > 1 {
807 issues.push("Inconsistent error handling patterns across components".to_string());
808 }
809
810 let naming_issues: Vec<_> = reports
812 .iter()
813 .flat_map(|r| &r.issues)
814 .filter(|i| matches!(i.category, IssueCategory::NamingConvention))
815 .collect();
816
817 if naming_issues.len() > reports.len() / 2 {
818 issues.push("Multiple components have naming convention issues".to_string());
819 }
820
821 issues
822 }
823
824 fn generate_pipeline_improvement_suggestions(
825 &self,
826 reports: &[ConsistencyReport],
827 cross_issues: &[String],
828 ) -> Vec<String> {
829 let mut suggestions = vec![
830 "Standardize error handling across all components".to_string(),
831 "Implement consistent metadata collection".to_string(),
832 "Add configuration validation to all components".to_string(),
833 ];
834
835 if reports.iter().any(|r| r.score < 0.7) {
837 suggestions.push("Focus on improving low-scoring components first".to_string());
838 }
839
840 if !cross_issues.is_empty() {
841 suggestions.push("Address cross-component consistency issues".to_string());
842 }
843
844 suggestions
845 }
846
847 fn infer_implemented_traits(&self, component_name: &str) -> Vec<String> {
848 let mut traits = Vec::new();
849
850 if component_name.contains("Predictor") || component_name.contains("Mock") {
851 traits.extend_from_slice(&["Fit".to_string(), "Predict".to_string()]);
852 }
853 if component_name.contains("Transformer") {
854 traits.push("Transform".to_string());
855 }
856 if component_name.contains("Estimator") {
857 traits.push("Estimator".to_string());
858 }
859
860 traits
861 }
862
863 fn infer_method_signatures(&self, _component_name: &str) -> Vec<MethodSignature> {
864 vec![
866 MethodSignature {
868 name: "fit".to_string(),
869 input_types: vec!["X".to_string(), "Y".to_string()],
870 output_type: "Result<Self::Fitted>".to_string(),
871 is_async: false,
872 error_handling: ErrorHandlingPattern::Result,
873 },
874 MethodSignature {
876 name: "predict".to_string(),
877 input_types: vec!["X".to_string()],
878 output_type: "Result<Output>".to_string(),
879 is_async: false,
880 error_handling: ErrorHandlingPattern::Result,
881 },
882 ]
883 }
884
885 fn infer_performance_characteristics(
886 &self,
887 _component_name: &str,
888 ) -> PerformanceCharacteristics {
889 PerformanceCharacteristics {
891 computational_complexity: "O(n)".to_string(),
892 memory_complexity: "O(n)".to_string(),
893 thread_safety: ThreadSafetyLevel::Safe,
894 cache_efficiency: 0.8,
895 }
896 }
897
898 fn get_most_common_issues(&self) -> Vec<String> {
899 let mut issue_counts: HashMap<String, usize> = HashMap::new();
900
901 for report in self.cached_reports.values() {
902 for issue in &report.issues {
903 *issue_counts.entry(issue.description.clone()).or_insert(0) += 1;
904 }
905 }
906
907 let mut issues: Vec<_> = issue_counts.into_iter().collect();
908 issues.sort_by(|a, b| b.1.cmp(&a.1));
909 issues.into_iter().take(5).map(|(desc, _)| desc).collect()
910 }
911}
912
/// Aggregate statistics over the reports cached by an [`ApiConsistencyChecker`].
#[derive(Debug, Clone)]
pub struct AnalysisStatistics {
    /// Number of cached reports (one per analyzed type name).
    pub total_components_analyzed: usize,
    /// Mean report score; 0.0 when nothing has been analyzed.
    pub average_consistency_score: f64,
    /// Up to five most frequent issue descriptions.
    pub most_common_issues: Vec<String>,
    /// Number of explicitly registered component types.
    pub registered_types: usize,
}
921
/// Result of checking one component type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsistencyReport {
    /// Type name of the checked component (from `std::any::type_name`).
    pub component_name: String,
    /// `true` when every issue is only a `Suggestion` (or there are none).
    pub is_consistent: bool,
    /// All issues found.
    pub issues: Vec<ConsistencyIssue>,
    /// Recommendations derived from the category and the issues.
    pub recommendations: Vec<ApiRecommendation>,
    /// Score in `[0, 1]`; 1.0 means no issues.
    pub score: f64,
}
936
/// Aggregated result of checking every component in a pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineConsistencyReport {
    /// Number of components checked.
    pub total_components: usize,
    /// Number of components whose report `is_consistent`.
    pub consistent_components: usize,
    /// Per-component reports, in input order.
    pub component_reports: Vec<ConsistencyReport>,
    /// Mean component score minus cross-component penalties, clamped at 0.0
    /// (1.0 for an empty pipeline).
    pub overall_score: f64,
    /// Cross-component issues.
    pub critical_issues: Vec<String>,
    /// Improvement suggestions for the pipeline as a whole.
    pub improvement_suggestions: Vec<String>,
}
953
/// One problem found by a consistency analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsistencyIssue {
    /// Kind of problem.
    pub category: IssueCategory,
    /// How serious it is; drives the score penalty.
    pub severity: IssueSeverity,
    /// Human-readable description; also used to rank the most common issues.
    pub description: String,
    /// Where the problem was found, if known.
    pub location: Option<String>,
    /// Suggested fix, if any.
    pub suggested_fix: Option<String>,
}
968
/// Kind of consistency problem.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IssueCategory {
    /// Naming does not follow conventions.
    NamingConvention,
    /// Parameter handling problems.
    ParameterHandling,
    /// Error-handling problems.
    ErrorHandling,
    /// Missing or poor documentation.
    Documentation,
    /// Return-type problems (the checker also reuses this for performance and memory findings).
    ReturnTypes,
    /// Configuration / trait-implementation pattern problems.
    ConfigurationPattern,
}
985
/// Severity of an issue. The checker's score penalties per issue are
/// Critical 0.5, Major 0.3, Minor 0.1, Suggestion 0.02.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum IssueSeverity {
    /// Must be fixed.
    Critical,
    /// Should be fixed; also turned into a High-priority recommendation.
    Major,
    /// Worth fixing.
    Minor,
    /// Optional improvement; the only severity compatible with `is_consistent`.
    Suggestion,
}
998
/// An actionable recommendation attached to a [`ConsistencyReport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiRecommendation {
    /// Area the recommendation concerns.
    pub category: RecommendationCategory,
    /// How urgent it is.
    pub priority: RecommendationPriority,
    /// Short title.
    pub title: String,
    /// Longer explanation.
    pub description: String,
    /// Optional code snippet illustrating the recommendation.
    pub example_code: Option<String>,
}
1013
/// Area a recommendation concerns.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RecommendationCategory {
    /// Trait/interface design.
    InterfaceDesign,
    /// Error handling.
    ErrorHandling,
    /// Documentation.
    Documentation,
    /// Performance.
    Performance,
    /// General usability.
    Usability,
}
1028
/// Urgency of a recommendation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RecommendationPriority {
    /// Address first.
    High,
    /// Address when convenient.
    Medium,
    /// Nice to have.
    Low,
}
1039
/// Implements `StandardConfig` for `$config_type` with a minimal behavior:
/// validation always succeeds, no parameters are exported, and
/// `from_params` ignores its input and returns `Self::default()`.
///
/// `$component_type` and `$description` may be any expressions implementing
/// `ToString`; they populate the generated `ConfigSummary`.
///
/// The expansion refers to `StandardConfig`, `SklResult`, `ConfigSummary` and
/// `ConfigValue` by name, so those must be in scope at the call site. Standard
/// library items are referenced by absolute path, so callers do not need to
/// import `HashMap` themselves.
#[macro_export]
macro_rules! impl_standard_config {
    ($config_type:ty, $component_type:expr, $description:expr) => {
        impl StandardConfig for $config_type {
            fn validate(&self) -> SklResult<()> {
                Ok(())
            }

            fn summary(&self) -> ConfigSummary {
                ConfigSummary {
                    component_type: $component_type.to_string(),
                    description: $description.to_string(),
                    parameters: ::std::collections::HashMap::new(),
                    is_valid: true,
                    validation_messages: ::std::vec::Vec::new(),
                }
            }

            fn to_params(&self) -> ::std::collections::HashMap<::std::string::String, ConfigValue> {
                ::std::collections::HashMap::new()
            }

            fn from_params(
                _params: ::std::collections::HashMap<::std::string::String, ConfigValue>,
            ) -> SklResult<Self> {
                Ok(Self::default())
            }
        }
    };
}
1071
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration type used to exercise the macro and the checker.
    #[derive(Debug, Clone, Default)]
    struct TestConfig {
        pub param1: f64,
        pub param2: bool,
    }

    impl_standard_config!(TestConfig, "TestComponent", "Test configuration");

    #[test]
    fn test_standard_config_implementation() {
        let summary = {
            let config = TestConfig::default();
            assert!(config.validate().is_ok());
            config.summary()
        };

        assert_eq!(summary.component_type, "TestComponent");
        assert_eq!(summary.description, "Test configuration");
        assert!(summary.is_valid);
    }

    #[test]
    fn test_api_consistency_checker() {
        let expected_name = std::any::type_name::<TestConfig>();
        let report = ApiConsistencyChecker::new().check_component(&TestConfig::default());

        assert_eq!(report.component_name, expected_name);
        assert!(!report.recommendations.is_empty());
        assert!(report.score > 0.0);
    }

    #[test]
    fn test_enhanced_consistency_checking() {
        let strict = ConsistencyCheckConfig {
            strictness_level: CheckStrictnessLevel::Strict,
            ..ConsistencyCheckConfig::default()
        };
        let mut checker = ApiConsistencyChecker::with_config(strict);
        let subject = TestConfig::default();

        let first = checker.check_component(&subject);
        assert!((0.0..=1.0).contains(&first.score));

        // The second call for the same type is answered from the cache.
        let second = checker.check_component(&subject);
        assert_eq!(first.component_name, second.component_name);
    }

    #[test]
    fn test_pipeline_consistency_checking() {
        let first = TestConfig::default();
        let second = TestConfig {
            param1: 1.0,
            param2: false,
        };

        let mut checker = ApiConsistencyChecker::new();
        let pipeline_report = checker.check_pipeline_consistency(&[&first, &second]);

        assert_eq!(pipeline_report.total_components, 2);
        assert!((0.0..=1.0).contains(&pipeline_report.overall_score));
        assert!(!pipeline_report.improvement_suggestions.is_empty());
    }

    #[test]
    fn test_analysis_statistics() {
        let mut checker = ApiConsistencyChecker::new();
        let _ = checker.check_component(&TestConfig::default());

        let stats = checker.get_analysis_statistics();
        assert_eq!(stats.total_components_analyzed, 1);
        assert!(stats.average_consistency_score >= 0.0);
    }

    #[test]
    fn test_execution_metadata() {
        let input_shape = Some((100, 10));
        let metadata = ExecutionMetadata {
            component_name: String::from("test_component"),
            start_time: 1_630_000_000,
            end_time: Some(1_630_000_001),
            duration_ms: Some(1000.0),
            input_shape,
            output_shape: Some((100, 5)),
            memory_before_mb: Some(50.0),
            memory_after_mb: Some(55.0),
            cpu_utilization: Some(0.75),
            warnings: Vec::new(),
            extra_metadata: HashMap::default(),
        };

        assert_eq!(metadata.component_name, "test_component");
        assert_eq!(metadata.duration_ms, Some(1000.0));
        assert_eq!(metadata.input_shape, input_shape);
    }

    #[test]
    fn test_model_summary() {
        let summary = ModelSummary {
            model_type: String::from("LinearRegression"),
            description: String::from("Linear regression model"),
            parameter_count: Some(10),
            complexity: Some(0.3),
            supports_incremental: true,
            provides_feature_importance: true,
            provides_prediction_intervals: false,
            extra_info: HashMap::default(),
        };

        assert_eq!(summary.model_type, "LinearRegression");
        assert_eq!(summary.parameter_count, Some(10));
        assert!(summary.supports_incremental);
    }
}
1192
1193pub mod pattern_detection {
1195 use super::{Debug, HashMap};
1196
    /// Detects known API design patterns in component types.
    pub struct ApiPatternDetector {
        /// Registered patterns, keyed by pattern name.
        known_patterns: HashMap<String, ApiPattern>,
        /// Detection results already computed, keyed by component type name.
        pattern_cache: HashMap<String, Vec<DetectedPattern>>,
    }
1202
    /// Definition of a design pattern and the rules used to detect it.
    #[derive(Debug, Clone)]
    pub struct ApiPattern {
        /// Pattern name; also the registry key.
        pub name: String,
        /// Human-readable description.
        pub description: String,
        /// Which pattern family this is.
        pub pattern_type: PatternType,
        /// Weighted rules that provide evidence for the pattern.
        pub detection_rules: Vec<DetectionRule>,
        /// How strongly components are expected to follow the pattern.
        pub compliance_level: ComplianceLevel,
    }
1212
    /// Family of design pattern.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum PatternType {
        /// Fluent builder.
        Builder,
        /// Data-access repository.
        Repository,
        /// Object factory.
        Factory,
        /// Interchangeable strategy.
        Strategy,
        /// Observer / listener.
        Observer,
        /// Decorator / wrapper.
        Decorator,
        /// Processing pipeline (fit/transform/predict).
        Pipeline,
    }
1231
    /// One piece of evidence used to detect a pattern.
    #[derive(Debug, Clone)]
    pub struct DetectionRule {
        /// What the rule inspects.
        pub rule_type: RuleType,
        /// Text to match, written in regex-like form (e.g. `".*Builder$"`).
        // NOTE(review): the matching semantics live in the detection code,
        // which is outside this view — confirm these are real regexes.
        pub pattern: String,
        /// Weight of this evidence (built-in rules use 0.7–0.9).
        pub weight: f64,
    }
1239
    /// What a [`DetectionRule`] inspects.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum RuleType {
        /// Method names.
        MethodName,
        /// The type name.
        TypeName,
        /// Implemented traits.
        TraitImplementation,
        /// Method signatures.
        MethodSignature,
        /// Field names.
        FieldName,
    }
1254
    /// How strongly a pattern should be followed.
    ///
    /// The derived ordering follows declaration order
    /// (`Mandatory < Recommended < Optional`); do not reorder variants.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
    pub enum ComplianceLevel {
        /// Required.
        Mandatory,
        /// Strongly suggested.
        Recommended,
        /// Optional.
        Optional,
    }
1265
    /// A pattern detected in a component, with supporting evidence.
    #[derive(Debug, Clone)]
    pub struct DetectedPattern {
        /// The matched pattern definition.
        pub pattern: ApiPattern,
        /// Detection confidence.
        pub confidence: f64,
        /// Human-readable evidence for the detection.
        pub evidence: Vec<String>,
        /// How well the component complies with the pattern.
        pub compliance_score: f64,
        /// Deviations from the pattern.
        pub violations: Vec<PatternViolation>,
    }
1275
    /// A deviation from a detected pattern.
    #[derive(Debug, Clone)]
    pub struct PatternViolation {
        /// Kind of deviation.
        pub violation_type: ViolationType,
        /// Human-readable description.
        pub description: String,
        /// Where it occurs.
        pub location: String,
        /// How serious it is.
        pub severity: ViolationSeverity,
        /// Suggested fix.
        pub suggestion: String,
    }
1285
    /// Kind of pattern violation.
    #[derive(Debug, Clone, PartialEq, Eq)]
    pub enum ViolationType {
        /// An expected method is absent.
        MissingMethod,
        /// A method's signature differs from the pattern.
        IncorrectSignature,
        /// Naming does not follow the pattern.
        InconsistentNaming,
        /// An expected trait is not implemented.
        MissingTrait,
        /// Error handling does not follow the pattern.
        ErrorHandling,
    }
1300
    /// Severity of a pattern violation.
    ///
    /// The derived ordering follows declaration order
    /// (`Critical < Major < Minor < Style`); do not reorder variants.
    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
    pub enum ViolationSeverity {
        /// Must be fixed.
        Critical,
        /// Should be fixed.
        Major,
        /// Worth fixing.
        Minor,
        /// Stylistic only.
        Style,
    }
1313
1314 impl Default for ApiPatternDetector {
1315 fn default() -> Self {
1316 Self::new()
1317 }
1318 }
1319
1320 impl ApiPatternDetector {
1321 #[must_use]
1323 pub fn new() -> Self {
1324 let mut detector = Self {
1325 known_patterns: HashMap::new(),
1326 pattern_cache: HashMap::new(),
1327 };
1328
1329 detector.register_standard_patterns();
1330 detector
1331 }
1332
1333 fn register_standard_patterns(&mut self) {
1335 self.register_pattern(ApiPattern {
1337 name: "Builder".to_string(),
1338 description: "Builder pattern for fluent configuration".to_string(),
1339 pattern_type: PatternType::Builder,
1340 detection_rules: vec![
1341 DetectionRule {
1343 rule_type: RuleType::TypeName,
1344 pattern: ".*Builder$".to_string(),
1345 weight: 0.8,
1346 },
1347 DetectionRule {
1349 rule_type: RuleType::MethodName,
1350 pattern: "build".to_string(),
1351 weight: 0.9,
1352 },
1353 DetectionRule {
1355 rule_type: RuleType::MethodSignature,
1356 pattern: "-> Self".to_string(),
1357 weight: 0.7,
1358 },
1359 ],
1360 compliance_level: ComplianceLevel::Recommended,
1361 });
1362
1363 self.register_pattern(ApiPattern {
1365 name: "Repository".to_string(),
1366 description: "Repository pattern for data access".to_string(),
1367 pattern_type: PatternType::Repository,
1368 detection_rules: vec![
1369 DetectionRule {
1371 rule_type: RuleType::MethodName,
1372 pattern: "find.*|get.*|save.*|delete.*".to_string(),
1373 weight: 0.8,
1374 },
1375 DetectionRule {
1377 rule_type: RuleType::TypeName,
1378 pattern: ".*Repository$".to_string(),
1379 weight: 0.9,
1380 },
1381 ],
1382 compliance_level: ComplianceLevel::Optional,
1383 });
1384
1385 self.register_pattern(ApiPattern {
1387 name: "Factory".to_string(),
1388 description: "Factory pattern for object creation".to_string(),
1389 pattern_type: PatternType::Factory,
1390 detection_rules: vec![
1391 DetectionRule {
1393 rule_type: RuleType::MethodName,
1394 pattern: "create.*|new.*|make.*".to_string(),
1395 weight: 0.7,
1396 },
1397 DetectionRule {
1399 rule_type: RuleType::TypeName,
1400 pattern: ".*Factory$".to_string(),
1401 weight: 0.9,
1402 },
1403 ],
1404 compliance_level: ComplianceLevel::Optional,
1405 });
1406
1407 self.register_pattern(ApiPattern {
1409 name: "Pipeline".to_string(),
1410 description: "Pipeline pattern for data processing".to_string(),
1411 pattern_type: PatternType::Pipeline,
1412 detection_rules: vec![
1413 DetectionRule {
1415 rule_type: RuleType::MethodName,
1416 pattern: "fit.*|transform.*|predict.*".to_string(),
1417 weight: 0.8,
1418 },
1419 DetectionRule {
1421 rule_type: RuleType::TypeName,
1422 pattern: ".*Pipeline$".to_string(),
1423 weight: 0.9,
1424 },
1425 DetectionRule {
1427 rule_type: RuleType::TraitImplementation,
1428 pattern: "Estimator|Transform".to_string(),
1429 weight: 0.85,
1430 },
1431 ],
1432 compliance_level: ComplianceLevel::Mandatory,
1433 });
1434 }
1435
1436 pub fn register_pattern(&mut self, pattern: ApiPattern) {
1438 self.known_patterns.insert(pattern.name.clone(), pattern);
1439 }
1440
1441 pub fn detect_patterns<T>(&mut self, component: &T) -> Vec<DetectedPattern>
1443 where
1444 T: Debug,
1445 {
1446 let component_name = std::any::type_name::<T>().to_string();
1447
1448 if let Some(cached_patterns) = self.pattern_cache.get(&component_name) {
1450 return cached_patterns.clone();
1451 }
1452
1453 let mut detected_patterns = Vec::new();
1454
1455 for pattern in self.known_patterns.values() {
1456 if let Some(detected) = self.analyze_pattern_compliance(pattern, &component_name) {
1457 detected_patterns.push(detected);
1458 }
1459 }
1460
1461 self.pattern_cache
1463 .insert(component_name, detected_patterns.clone());
1464 detected_patterns
1465 }
1466
1467 fn analyze_pattern_compliance(
1469 &self,
1470 pattern: &ApiPattern,
1471 component_name: &str,
1472 ) -> Option<DetectedPattern> {
1473 let mut total_score = 0.0;
1474 let mut max_score = 0.0;
1475 let mut evidence = Vec::new();
1476 let mut violations = Vec::new();
1477
1478 for rule in &pattern.detection_rules {
1479 max_score += rule.weight;
1480
1481 match self.evaluate_rule(rule, component_name) {
1482 Some(score) => {
1483 total_score += score * rule.weight;
1484 evidence.push(format!(
1485 "Rule '{}' matched with score {:.2}",
1486 rule.pattern, score
1487 ));
1488 }
1489 None => {
1490 violations.push(PatternViolation {
1491 violation_type: ViolationType::MissingMethod,
1492 description: format!("Rule '{}' not satisfied", rule.pattern),
1493 location: component_name.to_string(),
1494 severity: match pattern.compliance_level {
1495 ComplianceLevel::Mandatory => ViolationSeverity::Critical,
1496 ComplianceLevel::Recommended => ViolationSeverity::Major,
1497 ComplianceLevel::Optional => ViolationSeverity::Minor,
1498 },
1499 suggestion: format!(
1500 "Consider implementing pattern: {}",
1501 pattern.description
1502 ),
1503 });
1504 }
1505 }
1506 }
1507
1508 let confidence = if max_score > 0.0 {
1509 total_score / max_score
1510 } else {
1511 0.0
1512 };
1513
1514 if confidence > 0.3 {
1516 Some(DetectedPattern {
1517 pattern: pattern.clone(),
1518 confidence,
1519 evidence,
1520 compliance_score: confidence,
1521 violations,
1522 })
1523 } else {
1524 None
1525 }
1526 }
1527
1528 fn evaluate_rule(&self, rule: &DetectionRule, component_name: &str) -> Option<f64> {
1530 match rule.rule_type {
1531 RuleType::TypeName => {
1532 if self.matches_pattern(&rule.pattern, component_name) {
1533 Some(1.0)
1534 } else {
1535 None
1536 }
1537 }
1538 RuleType::MethodName
1539 | RuleType::MethodSignature
1540 | RuleType::TraitImplementation => {
1541 if component_name.contains("Builder") && rule.pattern.contains("build") {
1543 Some(0.8)
1544 } else if component_name.contains("Pipeline") && rule.pattern.contains("fit") {
1545 Some(0.9)
1546 } else {
1547 None
1548 }
1549 }
1550 RuleType::FieldName => {
1551 Some(0.5)
1553 }
1554 }
1555 }
1556
1557 fn matches_pattern(&self, pattern: &str, text: &str) -> bool {
1559 if pattern.ends_with('$') {
1561 let prefix = pattern.trim_end_matches('$');
1562 text.ends_with(prefix)
1563 } else if pattern.starts_with(".*") {
1564 let suffix = pattern.trim_start_matches(".*");
1565 text.contains(suffix)
1566 } else {
1567 text.contains(pattern)
1568 }
1569 }
1570
1571 #[must_use]
1573 pub fn get_compliance_report(&self, component_name: &str) -> PatternComplianceReport {
1574 let detected_patterns = self
1575 .pattern_cache
1576 .get(component_name)
1577 .cloned()
1578 .unwrap_or_default();
1579
1580 let total_patterns = self.known_patterns.len();
1581 let compliant_patterns = detected_patterns
1582 .iter()
1583 .filter(|p| p.compliance_score > 0.8)
1584 .count();
1585
1586 let average_compliance = if detected_patterns.is_empty() {
1587 0.0
1588 } else {
1589 detected_patterns
1590 .iter()
1591 .map(|p| p.compliance_score)
1592 .sum::<f64>()
1593 / detected_patterns.len() as f64
1594 };
1595
1596 let critical_violations = detected_patterns
1597 .iter()
1598 .flat_map(|p| &p.violations)
1599 .filter(|v| v.severity == ViolationSeverity::Critical)
1600 .count();
1601
1602 PatternComplianceReport {
1604 component_name: component_name.to_string(),
1605 total_patterns,
1606 compliant_patterns,
1607 detected_patterns: detected_patterns.clone(),
1608 average_compliance,
1609 critical_violations,
1610 recommendations: self.generate_pattern_recommendations(&detected_patterns),
1611 }
1612 }
1613
1614 fn generate_pattern_recommendations(&self, patterns: &[DetectedPattern]) -> Vec<String> {
1616 let mut recommendations = Vec::new();
1617
1618 for pattern in patterns {
1619 if pattern.compliance_score < 0.7 {
1620 recommendations.push(format!(
1621 "Improve compliance with {} pattern (current score: {:.2})",
1622 pattern.pattern.name, pattern.compliance_score
1623 ));
1624 }
1625
1626 for violation in &pattern.violations {
1627 if violation.severity >= ViolationSeverity::Major {
1628 recommendations.push(violation.suggestion.clone());
1629 }
1630 }
1631 }
1632
1633 if recommendations.is_empty() {
1634 recommendations.push("Component shows good pattern compliance".to_string());
1635 }
1636
1637 recommendations
1638 }
1639 }
1640
1641 #[derive(Debug, Clone)]
1643 pub struct PatternComplianceReport {
1644 pub component_name: String,
1645 pub total_patterns: usize,
1646 pub compliant_patterns: usize,
1647 pub detected_patterns: Vec<DetectedPattern>,
1648 pub average_compliance: f64,
1649 pub critical_violations: usize,
1650 pub recommendations: Vec<String>,
1651 }
1652}