1use crate::api_data_structures::{
8 ApiReference, ApiVisualization, ApiVisualizationData, AssociatedType, CodeExample, CrateInfo,
9 FieldInfo, MethodInfo, ParameterInfo, TraitInfo, TypeInfo, TypeKind, Visibility,
10 VisualizationConfig, VisualizationNode, VisualizationType,
11};
12use crate::api_generator_config::GeneratorConfig;
13use crate::error::{Result, SklearsError};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone)]
23pub struct TraitAnalyzer {
24 config: GeneratorConfig,
25 trait_cache: HashMap<String, TraitInfo>,
26 hierarchy_depth: usize,
27}
28
29impl TraitAnalyzer {
30 pub fn new(config: GeneratorConfig) -> Self {
32 Self {
33 config,
34 trait_cache: HashMap::new(),
35 hierarchy_depth: 0,
36 }
37 }
38
39 pub fn analyze_traits(&mut self, _crate_info: &CrateInfo) -> Result<Vec<TraitInfo>> {
41 self.trait_cache.clear();
43 self.hierarchy_depth = 0;
44
45 let traits = vec![
48 self.create_estimator_trait()?,
49 self.create_fit_trait()?,
50 self.create_predict_trait()?,
51 self.create_transform_trait()?,
52 self.create_score_trait()?,
53 ];
54
55 for trait_info in &traits {
57 self.trait_cache
58 .insert(trait_info.name.clone(), trait_info.clone());
59 }
60
61 if self.config.include_cross_refs {
63 self.build_trait_hierarchies(&traits)
64 } else {
65 Ok(traits)
66 }
67 }
68
69 pub fn analyze_trait(&mut self, trait_name: &str) -> Result<Option<TraitInfo>> {
71 if let Some(cached) = self.trait_cache.get(trait_name) {
73 return Ok(Some(cached.clone()));
74 }
75
76 match trait_name {
78 "Estimator" => Ok(Some(self.create_estimator_trait()?)),
79 "Fit" => Ok(Some(self.create_fit_trait()?)),
80 "Predict" => Ok(Some(self.create_predict_trait()?)),
81 "Transform" => Ok(Some(self.create_transform_trait()?)),
82 "Score" => Ok(Some(self.create_score_trait()?)),
83 _ => Ok(None),
84 }
85 }
86
87 pub fn get_cached_traits(&self) -> &HashMap<String, TraitInfo> {
89 &self.trait_cache
90 }
91
92 pub fn clear_cache(&mut self) {
94 self.trait_cache.clear();
95 }
96
97 pub fn validate_hierarchy_depth(&self, depth: usize) -> Result<()> {
99 if depth > self.config.max_hierarchy_depth {
100 return Err(SklearsError::InvalidInput(format!(
101 "Trait hierarchy depth {} exceeds maximum allowed depth of {}",
102 depth, self.config.max_hierarchy_depth
103 )));
104 }
105 Ok(())
106 }
107
108 fn build_trait_hierarchies(&mut self, traits: &[TraitInfo]) -> Result<Vec<TraitInfo>> {
110 let mut enhanced_traits = traits.to_vec();
111
112 for trait_info in &mut enhanced_traits {
114 trait_info.supertraits = self.find_supertraits(&trait_info.name)?;
115 trait_info.implementations = self.find_implementations(&trait_info.name)?;
116 }
117
118 Ok(enhanced_traits)
119 }
120
121 fn find_supertraits(&self, trait_name: &str) -> Result<Vec<String>> {
123 match trait_name {
125 "Fit" => Ok(vec!["Estimator".to_string()]),
126 "Predict" => Ok(vec!["Estimator".to_string()]),
127 "Transform" => Ok(vec!["Estimator".to_string()]),
128 "Score" => Ok(vec!["Predict".to_string()]),
129 _ => Ok(Vec::new()),
130 }
131 }
132
133 fn find_implementations(&self, trait_name: &str) -> Result<Vec<String>> {
135 match trait_name {
137 "Estimator" => Ok(vec![
138 "LinearRegression".to_string(),
139 "LogisticRegression".to_string(),
140 "RandomForest".to_string(),
141 "SVM".to_string(),
142 ]),
143 "Fit" => Ok(vec![
144 "LinearRegression".to_string(),
145 "LogisticRegression".to_string(),
146 "RandomForest".to_string(),
147 ]),
148 "Predict" => Ok(vec![
149 "LinearRegression".to_string(),
150 "LogisticRegression".to_string(),
151 "RandomForest".to_string(),
152 ]),
153 "Transform" => Ok(vec![
154 "StandardScaler".to_string(),
155 "PCA".to_string(),
156 "MinMaxScaler".to_string(),
157 ]),
158 "Score" => Ok(vec![
159 "LinearRegression".to_string(),
160 "LogisticRegression".to_string(),
161 ]),
162 _ => Ok(Vec::new()),
163 }
164 }
165
166 fn create_estimator_trait(&self) -> Result<TraitInfo> {
168 Ok(TraitInfo {
169 name: "Estimator".to_string(),
170 description: "Base trait for all machine learning estimators in sklears".to_string(),
171 path: "sklears_core::traits::Estimator".to_string(),
172 generics: Vec::new(),
173 associated_types: vec![AssociatedType {
174 name: "Config".to_string(),
175 description: "Configuration type for the estimator".to_string(),
176 bounds: vec!["Clone".to_string(), "Debug".to_string()],
177 }],
178 methods: vec![
179 MethodInfo {
180 name: "name".to_string(),
181 signature: "fn name(&self) -> &'static str".to_string(),
182 description: "Get the name of the estimator".to_string(),
183 parameters: Vec::new(),
184 return_type: "&'static str".to_string(),
185 required: true,
186 },
187 MethodInfo {
188 name: "default_config".to_string(),
189 signature: "fn default_config() -> Self::Config".to_string(),
190 description: "Get the default configuration for this estimator".to_string(),
191 parameters: Vec::new(),
192 return_type: "Self::Config".to_string(),
193 required: false,
194 },
195 ],
196 supertraits: Vec::new(),
197 implementations: Vec::new(),
198 })
199 }
200
201 fn create_fit_trait(&self) -> Result<TraitInfo> {
203 Ok(TraitInfo {
204 name: "Fit".to_string(),
205 description: "Trait for estimators that can be fitted to training data".to_string(),
206 path: "sklears_core::traits::Fit".to_string(),
207 generics: vec!["X".to_string(), "Y".to_string()],
208 associated_types: vec![AssociatedType {
209 name: "Fitted".to_string(),
210 description: "The type returned after fitting".to_string(),
211 bounds: vec!["Send".to_string(), "Sync".to_string()],
212 }],
213 methods: vec![MethodInfo {
214 name: "fit".to_string(),
215 signature: "fn fit(self, x: &X, y: &Y) -> Result<Self::Fitted>".to_string(),
216 description: "Fit the estimator to training data".to_string(),
217 parameters: vec![
218 ParameterInfo {
219 name: "x".to_string(),
220 param_type: "&X".to_string(),
221 description: "Training features matrix".to_string(),
222 optional: false,
223 },
224 ParameterInfo {
225 name: "y".to_string(),
226 param_type: "&Y".to_string(),
227 description: "Training targets vector".to_string(),
228 optional: false,
229 },
230 ],
231 return_type: "Result<Self::Fitted>".to_string(),
232 required: true,
233 }],
234 supertraits: Vec::new(),
235 implementations: Vec::new(),
236 })
237 }
238
239 fn create_predict_trait(&self) -> Result<TraitInfo> {
241 Ok(TraitInfo {
242 name: "Predict".to_string(),
243 description: "Trait for estimators that can make predictions".to_string(),
244 path: "sklears_core::traits::Predict".to_string(),
245 generics: vec!["X".to_string()],
246 associated_types: vec![AssociatedType {
247 name: "Output".to_string(),
248 description: "The type of prediction output".to_string(),
249 bounds: vec!["Clone".to_string()],
250 }],
251 methods: vec![MethodInfo {
252 name: "predict".to_string(),
253 signature: "fn predict(&self, x: &X) -> Result<Self::Output>".to_string(),
254 description: "Make predictions on new data".to_string(),
255 parameters: vec![ParameterInfo {
256 name: "x".to_string(),
257 param_type: "&X".to_string(),
258 description: "Input features to predict on".to_string(),
259 optional: false,
260 }],
261 return_type: "Result<Self::Output>".to_string(),
262 required: true,
263 }],
264 supertraits: Vec::new(),
265 implementations: Vec::new(),
266 })
267 }
268
269 fn create_transform_trait(&self) -> Result<TraitInfo> {
271 Ok(TraitInfo {
272 name: "Transform".to_string(),
273 description: "Trait for estimators that can transform data".to_string(),
274 path: "sklears_core::traits::Transform".to_string(),
275 generics: vec!["X".to_string()],
276 associated_types: vec![AssociatedType {
277 name: "Output".to_string(),
278 description: "The type of transformed output".to_string(),
279 bounds: vec!["Clone".to_string()],
280 }],
281 methods: vec![
282 MethodInfo {
283 name: "transform".to_string(),
284 signature: "fn transform(&self, x: &X) -> Result<Self::Output>".to_string(),
285 description: "Transform input data".to_string(),
286 parameters: vec![ParameterInfo {
287 name: "x".to_string(),
288 param_type: "&X".to_string(),
289 description: "Input data to transform".to_string(),
290 optional: false,
291 }],
292 return_type: "Result<Self::Output>".to_string(),
293 required: true,
294 },
295 MethodInfo {
296 name: "fit_transform".to_string(),
297 signature:
298 "fn fit_transform(self, x: &X) -> Result<(Self::Fitted, Self::Output)>"
299 .to_string(),
300 description: "Fit the transformer and transform data in one step".to_string(),
301 parameters: vec![ParameterInfo {
302 name: "x".to_string(),
303 param_type: "&X".to_string(),
304 description: "Input data to fit and transform".to_string(),
305 optional: false,
306 }],
307 return_type: "Result<(Self::Fitted, Self::Output)>".to_string(),
308 required: false,
309 },
310 ],
311 supertraits: Vec::new(),
312 implementations: Vec::new(),
313 })
314 }
315
316 fn create_score_trait(&self) -> Result<TraitInfo> {
318 Ok(TraitInfo {
319 name: "Score".to_string(),
320 description: "Trait for estimators that can compute accuracy scores".to_string(),
321 path: "sklears_core::traits::Score".to_string(),
322 generics: vec!["X".to_string(), "Y".to_string()],
323 associated_types: vec![AssociatedType {
324 name: "Score".to_string(),
325 description: "The type of score output".to_string(),
326 bounds: vec!["PartialOrd".to_string(), "Copy".to_string()],
327 }],
328 methods: vec![MethodInfo {
329 name: "score".to_string(),
330 signature: "fn score(&self, x: &X, y: &Y) -> Result<Self::Score>".to_string(),
331 description: "Compute accuracy score on test data".to_string(),
332 parameters: vec![
333 ParameterInfo {
334 name: "x".to_string(),
335 param_type: "&X".to_string(),
336 description: "Test features".to_string(),
337 optional: false,
338 },
339 ParameterInfo {
340 name: "y".to_string(),
341 param_type: "&Y".to_string(),
342 description: "True test targets".to_string(),
343 optional: false,
344 },
345 ],
346 return_type: "Result<Self::Score>".to_string(),
347 required: true,
348 }],
349 supertraits: Vec::new(),
350 implementations: Vec::new(),
351 })
352 }
353}
354
355impl Default for TraitAnalyzer {
356 fn default() -> Self {
357 Self::new(GeneratorConfig::default())
358 }
359}
360
361#[derive(Debug, Clone)]
367pub struct TypeExtractor {
368 #[allow(dead_code)]
369 config: GeneratorConfig,
370 type_cache: HashMap<String, TypeInfo>,
371 generic_constraints: HashMap<String, Vec<String>>,
372}
373
374impl TypeExtractor {
375 pub fn new(config: GeneratorConfig) -> Self {
377 Self {
378 config,
379 type_cache: HashMap::new(),
380 generic_constraints: HashMap::new(),
381 }
382 }
383
384 pub fn extract_types(&mut self, _crate_info: &CrateInfo) -> Result<Vec<TypeInfo>> {
386 self.type_cache.clear();
388
389 let types = vec![
391 self.create_sklears_error_type()?,
392 self.create_estimator_config_type()?,
393 self.create_matrix_type()?,
394 self.create_array_type()?,
395 self.create_model_type()?,
396 ];
397
398 for type_info in &types {
400 self.type_cache
401 .insert(type_info.name.clone(), type_info.clone());
402 }
403
404 Ok(types)
405 }
406
407 pub fn extract_type(&mut self, type_name: &str) -> Result<Option<TypeInfo>> {
409 if let Some(cached) = self.type_cache.get(type_name) {
411 return Ok(Some(cached.clone()));
412 }
413
414 match type_name {
416 "SklearsError" => Ok(Some(self.create_sklears_error_type()?)),
417 "EstimatorConfig" => Ok(Some(self.create_estimator_config_type()?)),
418 "Matrix" => Ok(Some(self.create_matrix_type()?)),
419 "Array" => Ok(Some(self.create_array_type()?)),
420 "Model" => Ok(Some(self.create_model_type()?)),
421 _ => Ok(None),
422 }
423 }
424
425 pub fn get_cached_types(&self) -> &HashMap<String, TypeInfo> {
427 &self.type_cache
428 }
429
430 pub fn analyze_generic_constraints(&mut self, type_name: &str) -> Result<Vec<String>> {
432 if let Some(constraints) = self.generic_constraints.get(type_name) {
433 return Ok(constraints.clone());
434 }
435
436 let constraints = match type_name {
438 "Matrix" => vec!["Clone".to_string(), "Debug".to_string(), "Send".to_string()],
439 "Array" => vec!["Clone".to_string(), "Debug".to_string()],
440 "Model" => vec![
441 "Clone".to_string(),
442 "Debug".to_string(),
443 "Serialize".to_string(),
444 "Deserialize".to_string(),
445 ],
446 _ => Vec::new(),
447 };
448
449 self.generic_constraints
450 .insert(type_name.to_string(), constraints.clone());
451 Ok(constraints)
452 }
453
454 fn create_sklears_error_type(&self) -> Result<TypeInfo> {
456 Ok(TypeInfo {
457 name: "SklearsError".to_string(),
458 description: "Error types for sklears operations".to_string(),
459 path: "sklears_core::error::SklearsError".to_string(),
460 kind: TypeKind::Enum,
461 generics: Vec::new(),
462 fields: vec![
463 FieldInfo {
464 name: "InvalidInput".to_string(),
465 field_type: "String".to_string(),
466 description: "Invalid input parameter provided".to_string(),
467 visibility: Visibility::Public,
468 },
469 FieldInfo {
470 name: "ComputationError".to_string(),
471 field_type: "String".to_string(),
472 description: "Computation or numerical error occurred".to_string(),
473 visibility: Visibility::Public,
474 },
475 FieldInfo {
476 name: "ConfigurationError".to_string(),
477 field_type: "String".to_string(),
478 description: "Invalid configuration provided".to_string(),
479 visibility: Visibility::Public,
480 },
481 FieldInfo {
482 name: "DataError".to_string(),
483 field_type: "String".to_string(),
484 description: "Data format or quality issue".to_string(),
485 visibility: Visibility::Public,
486 },
487 ],
488 trait_impls: vec![
489 "Debug".to_string(),
490 "Display".to_string(),
491 "Error".to_string(),
492 "Clone".to_string(),
493 ],
494 })
495 }
496
497 fn create_estimator_config_type(&self) -> Result<TypeInfo> {
499 Ok(TypeInfo {
500 name: "EstimatorConfig".to_string(),
501 description: "Base configuration for all estimators".to_string(),
502 path: "sklears_core::config::EstimatorConfig".to_string(),
503 kind: TypeKind::Struct,
504 generics: Vec::new(),
505 fields: vec![
506 FieldInfo {
507 name: "random_state".to_string(),
508 field_type: "Option<u64>".to_string(),
509 description: "Random seed for reproducible results".to_string(),
510 visibility: Visibility::Public,
511 },
512 FieldInfo {
513 name: "verbose".to_string(),
514 field_type: "bool".to_string(),
515 description: "Enable verbose output during training".to_string(),
516 visibility: Visibility::Public,
517 },
518 FieldInfo {
519 name: "max_iterations".to_string(),
520 field_type: "usize".to_string(),
521 description: "Maximum number of training iterations".to_string(),
522 visibility: Visibility::Public,
523 },
524 ],
525 trait_impls: vec![
526 "Debug".to_string(),
527 "Clone".to_string(),
528 "Default".to_string(),
529 "Serialize".to_string(),
530 "Deserialize".to_string(),
531 ],
532 })
533 }
534
535 fn create_matrix_type(&self) -> Result<TypeInfo> {
537 Ok(TypeInfo {
538 name: "Matrix".to_string(),
539 description: "Generic matrix type for numerical computations".to_string(),
540 path: "sklears_core::linalg::Matrix".to_string(),
541 kind: TypeKind::Struct,
542 generics: vec!["T".to_string()],
543 fields: vec![
544 FieldInfo {
545 name: "data".to_string(),
546 field_type: "Vec<T>".to_string(),
547 description: "Flattened matrix data in row-major order".to_string(),
548 visibility: Visibility::Private,
549 },
550 FieldInfo {
551 name: "rows".to_string(),
552 field_type: "usize".to_string(),
553 description: "Number of rows in the matrix".to_string(),
554 visibility: Visibility::Public,
555 },
556 FieldInfo {
557 name: "cols".to_string(),
558 field_type: "usize".to_string(),
559 description: "Number of columns in the matrix".to_string(),
560 visibility: Visibility::Public,
561 },
562 ],
563 trait_impls: vec![
564 "Debug".to_string(),
565 "Clone".to_string(),
566 "Index".to_string(),
567 "IndexMut".to_string(),
568 ],
569 })
570 }
571
572 fn create_array_type(&self) -> Result<TypeInfo> {
574 Ok(TypeInfo {
575 name: "Array".to_string(),
576 description: "Multi-dimensional array type".to_string(),
577 path: "sklears_core::array::Array".to_string(),
578 kind: TypeKind::Struct,
579 generics: vec!["T".to_string(), "const N: usize".to_string()],
580 fields: vec![
581 FieldInfo {
582 name: "data".to_string(),
583 field_type: "Vec<T>".to_string(),
584 description: "Array data storage".to_string(),
585 visibility: Visibility::Private,
586 },
587 FieldInfo {
588 name: "shape".to_string(),
589 field_type: "[usize; N]".to_string(),
590 description: "Shape of the array in each dimension".to_string(),
591 visibility: Visibility::Public,
592 },
593 ],
594 trait_impls: vec![
595 "Debug".to_string(),
596 "Clone".to_string(),
597 "Index".to_string(),
598 "IntoIterator".to_string(),
599 ],
600 })
601 }
602
603 fn create_model_type(&self) -> Result<TypeInfo> {
605 Ok(TypeInfo {
606 name: "Model".to_string(),
607 description: "Trained model container".to_string(),
608 path: "sklears_core::model::Model".to_string(),
609 kind: TypeKind::Struct,
610 generics: vec!["E".to_string()],
611 fields: vec![
612 FieldInfo {
613 name: "estimator".to_string(),
614 field_type: "E".to_string(),
615 description: "The trained estimator".to_string(),
616 visibility: Visibility::Private,
617 },
618 FieldInfo {
619 name: "metadata".to_string(),
620 field_type: "ModelMetadata".to_string(),
621 description: "Model training metadata".to_string(),
622 visibility: Visibility::Public,
623 },
624 FieldInfo {
625 name: "metrics".to_string(),
626 field_type: "HashMap<String, f64>".to_string(),
627 description: "Training and validation metrics".to_string(),
628 visibility: Visibility::Public,
629 },
630 ],
631 trait_impls: vec![
632 "Debug".to_string(),
633 "Clone".to_string(),
634 "Serialize".to_string(),
635 "Deserialize".to_string(),
636 ],
637 })
638 }
639}
640
641impl Default for TypeExtractor {
642 fn default() -> Self {
643 Self::new(GeneratorConfig::default())
644 }
645}
646
647#[derive(Debug, Clone)]
653pub struct ExampleValidator {
654 validation_rules: Vec<ValidationRule>,
655 #[allow(dead_code)]
656 compile_timeout_secs: u64,
657 enable_compilation: bool,
658 enable_execution: bool,
659}
660
661impl ExampleValidator {
662 pub fn new() -> Self {
664 Self {
665 validation_rules: vec![
666 ValidationRule::SyntaxCheck,
667 ValidationRule::ImportCheck,
668 ValidationRule::TypeCheck,
669 ValidationRule::SafetyCheck,
670 ],
671 compile_timeout_secs: 30,
672 enable_compilation: false, enable_execution: false, }
675 }
676
677 pub fn with_config(
679 enable_compilation: bool,
680 enable_execution: bool,
681 timeout_secs: u64,
682 ) -> Self {
683 Self {
684 validation_rules: vec![
685 ValidationRule::SyntaxCheck,
686 ValidationRule::ImportCheck,
687 ValidationRule::TypeCheck,
688 ValidationRule::SafetyCheck,
689 ],
690 compile_timeout_secs: timeout_secs,
691 enable_compilation,
692 enable_execution,
693 }
694 }
695
696 pub fn validate_examples(&self, examples: &[CodeExample]) -> Result<Vec<CodeExample>> {
698 let mut validated_examples = Vec::new();
699
700 for example in examples {
701 match self.validate_example(example) {
702 Ok(validated) => validated_examples.push(validated),
703 Err(e) => {
704 eprintln!(
706 "Warning: Failed to validate example '{}': {}",
707 example.title, e
708 );
709 validated_examples.push(example.clone());
710 }
711 }
712 }
713
714 Ok(validated_examples)
715 }
716
717 pub fn validate_example(&self, example: &CodeExample) -> Result<CodeExample> {
719 let mut validated = example.clone();
720
721 for rule in &self.validation_rules {
723 self.apply_validation_rule(rule, &mut validated)?;
724 }
725
726 if self.enable_compilation && example.runnable {
728 self.compile_check(&validated)?;
729 }
730
731 if self.enable_execution && example.runnable {
733 self.execution_check(&validated)?;
734 }
735
736 Ok(validated)
737 }
738
739 fn apply_validation_rule(
741 &self,
742 rule: &ValidationRule,
743 example: &mut CodeExample,
744 ) -> Result<()> {
745 match rule {
746 ValidationRule::SyntaxCheck => self.check_syntax(example),
747 ValidationRule::ImportCheck => self.check_imports(example),
748 ValidationRule::TypeCheck => self.check_types(example),
749 ValidationRule::SafetyCheck => self.check_safety(example),
750 }
751 }
752
753 fn check_syntax(&self, example: &CodeExample) -> Result<()> {
755 if example.code.trim().is_empty() {
757 return Err(SklearsError::InvalidInput(
758 "Example code cannot be empty".to_string(),
759 ));
760 }
761
762 let open_braces = example.code.matches('{').count();
764 let close_braces = example.code.matches('}').count();
765 if open_braces != close_braces {
766 return Err(SklearsError::InvalidInput(
767 "Unbalanced braces in example code".to_string(),
768 ));
769 }
770
771 let open_parens = example.code.matches('(').count();
773 let close_parens = example.code.matches(')').count();
774 if open_parens != close_parens {
775 return Err(SklearsError::InvalidInput(
776 "Unbalanced parentheses in example code".to_string(),
777 ));
778 }
779
780 Ok(())
781 }
782
783 fn check_imports(&self, example: &CodeExample) -> Result<()> {
785 let lines: Vec<&str> = example.code.lines().collect();
786
787 for line in lines {
788 let trimmed = line.trim();
789 if trimmed.starts_with("use ") {
790 if trimmed.contains("sklears_") && !self.is_valid_sklears_import(trimmed) {
792 return Err(SklearsError::InvalidInput(format!(
793 "Invalid sklears import: {}",
794 trimmed
795 )));
796 }
797 }
798 }
799
800 Ok(())
801 }
802
803 fn check_types(&self, _example: &CodeExample) -> Result<()> {
805 Ok(())
808 }
809
810 fn check_safety(&self, example: &CodeExample) -> Result<()> {
812 if example.code.contains("unsafe") {
814 return Err(SklearsError::InvalidInput(
815 "Unsafe code blocks are not allowed in examples".to_string(),
816 ));
817 }
818
819 let dangerous_patterns = [
821 "std::process::Command",
822 "std::fs::remove",
823 "std::ptr::",
824 "libc::",
825 "transmute",
826 ];
827
828 for pattern in &dangerous_patterns {
829 if example.code.contains(pattern) {
830 return Err(SklearsError::InvalidInput(format!(
831 "Potentially dangerous pattern '{}' found in example",
832 pattern
833 )));
834 }
835 }
836
837 Ok(())
838 }
839
840 fn is_valid_sklears_import(&self, import: &str) -> bool {
842 let valid_modules = [
843 "sklears_core",
844 "sklears_linear",
845 "sklears_tree",
846 "sklears_ensemble",
847 "sklears_preprocessing",
848 "sklears_metrics",
849 "sklears_neighbors",
850 "sklears_clustering",
851 "sklears_datasets",
852 ];
853
854 valid_modules.iter().any(|module| import.contains(module))
855 }
856
857 fn compile_check(&self, _example: &CodeExample) -> Result<()> {
859 Ok(())
866 }
867
868 fn execution_check(&self, _example: &CodeExample) -> Result<()> {
870 Ok(())
876 }
877
878 pub fn set_validation_rules(&mut self, rules: Vec<ValidationRule>) {
880 self.validation_rules = rules;
881 }
882
883 pub fn set_compilation_enabled(&mut self, enabled: bool) {
885 self.enable_compilation = enabled;
886 }
887
888 pub fn set_execution_enabled(&mut self, enabled: bool) {
890 self.enable_execution = enabled;
891 }
892}
893
894impl Default for ExampleValidator {
895 fn default() -> Self {
896 Self::new()
897 }
898}
899
/// A single check that `ExampleValidator` can apply to a code example.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValidationRule {
    /// Non-empty code with balanced braces and parentheses.
    SyntaxCheck,
    /// `use` declarations referencing sklears crates name a known crate.
    ImportCheck,
    /// Type checking (currently a no-op in `ExampleValidator`).
    TypeCheck,
    /// No `unsafe` code or known dangerous API patterns.
    SafetyCheck,
}
912
/// Builds a name → related-names map across trait and type descriptions.
#[derive(Debug, Clone)]
pub struct CrossReferenceBuilder {
    /// Result of the most recent `build_cross_references` call.
    reference_cache: HashMap<String, Vec<String>>,
    /// When true, every reference A → B also records B → A.
    bidirectional_refs: bool,
    /// Configured depth limit; not consulted by `build_cross_references` in the visible impl.
    max_depth: usize,
}
924
925impl CrossReferenceBuilder {
926 pub fn new() -> Self {
928 Self {
929 reference_cache: HashMap::new(),
930 bidirectional_refs: true,
931 max_depth: 3,
932 }
933 }
934
935 pub fn with_config(bidirectional_refs: bool, max_depth: usize) -> Self {
937 Self {
938 reference_cache: HashMap::new(),
939 bidirectional_refs,
940 max_depth,
941 }
942 }
943
944 pub fn build_cross_references(
946 &mut self,
947 traits: &[TraitInfo],
948 types: &[TypeInfo],
949 ) -> Result<HashMap<String, Vec<String>>> {
950 let mut refs = HashMap::new();
951
952 for trait_info in traits {
954 let mut trait_refs = trait_info.implementations.clone();
955
956 trait_refs.extend(trait_info.supertraits.clone());
958 trait_refs.extend(self.find_related_traits(trait_info, traits)?);
959
960 refs.insert(trait_info.name.clone(), trait_refs);
961 }
962
963 for type_info in types {
965 let mut type_refs = type_info.trait_impls.clone();
966
967 type_refs.extend(self.find_related_types(type_info, types)?);
969
970 refs.insert(type_info.name.clone(), type_refs);
971 }
972
973 if self.bidirectional_refs {
975 refs = self.add_bidirectional_references(refs)?;
976 }
977
978 self.reference_cache = refs.clone();
980
981 Ok(refs)
982 }
983
984 fn find_related_traits(
986 &self,
987 trait_info: &TraitInfo,
988 all_traits: &[TraitInfo],
989 ) -> Result<Vec<String>> {
990 let mut related = Vec::new();
991
992 for other_trait in all_traits {
993 if other_trait.name == trait_info.name {
994 continue;
995 }
996
997 let common_methods = self.count_common_methods(trait_info, other_trait);
999 if common_methods > 0 {
1000 related.push(other_trait.name.clone());
1001 }
1002
1003 if other_trait.supertraits.contains(&trait_info.name) {
1005 related.push(other_trait.name.clone());
1006 }
1007 }
1008
1009 Ok(related)
1010 }
1011
1012 fn find_related_types(
1014 &self,
1015 type_info: &TypeInfo,
1016 all_types: &[TypeInfo],
1017 ) -> Result<Vec<String>> {
1018 let mut related = Vec::new();
1019
1020 for other_type in all_types {
1021 if other_type.name == type_info.name {
1022 continue;
1023 }
1024
1025 let common_traits = self.count_common_trait_impls(type_info, other_type);
1027 if common_traits > 0 {
1028 related.push(other_type.name.clone());
1029 }
1030
1031 if self.have_similar_structure(type_info, other_type) {
1033 related.push(other_type.name.clone());
1034 }
1035 }
1036
1037 Ok(related)
1038 }
1039
1040 fn count_common_methods(&self, trait1: &TraitInfo, trait2: &TraitInfo) -> usize {
1042 let trait1_methods: std::collections::HashSet<_> =
1043 trait1.methods.iter().map(|m| &m.name).collect();
1044 let trait2_methods: std::collections::HashSet<_> =
1045 trait2.methods.iter().map(|m| &m.name).collect();
1046
1047 trait1_methods.intersection(&trait2_methods).count()
1048 }
1049
1050 fn count_common_trait_impls(&self, type1: &TypeInfo, type2: &TypeInfo) -> usize {
1052 let type1_traits: std::collections::HashSet<_> = type1.trait_impls.iter().collect();
1053 let type2_traits: std::collections::HashSet<_> = type2.trait_impls.iter().collect();
1054
1055 type1_traits.intersection(&type2_traits).count()
1056 }
1057
1058 fn have_similar_structure(&self, type1: &TypeInfo, type2: &TypeInfo) -> bool {
1060 if std::mem::discriminant(&type1.kind) != std::mem::discriminant(&type2.kind) {
1062 return false;
1063 }
1064
1065 let field_count_diff = (type1.fields.len() as i32 - type2.fields.len() as i32).abs();
1067 field_count_diff <= 2 }
1069
1070 fn add_bidirectional_references(
1072 &self,
1073 mut refs: HashMap<String, Vec<String>>,
1074 ) -> Result<HashMap<String, Vec<String>>> {
1075 let keys: Vec<String> = refs.keys().cloned().collect();
1076
1077 for key in &keys {
1078 if let Some(values) = refs.get(key).cloned() {
1079 for value in values {
1080 refs.entry(value).or_default().push(key.clone());
1082 }
1083 }
1084 }
1085
1086 for (_, values) in refs.iter_mut() {
1088 values.sort();
1089 values.dedup();
1090 }
1091
1092 Ok(refs)
1093 }
1094
1095 pub fn get_cached_references(&self) -> &HashMap<String, Vec<String>> {
1097 &self.reference_cache
1098 }
1099
1100 pub fn clear_cache(&mut self) {
1102 self.reference_cache.clear();
1103 }
1104
1105 pub fn set_max_depth(&mut self, depth: usize) {
1107 self.max_depth = depth;
1108 }
1109
1110 pub fn set_bidirectional(&mut self, enabled: bool) {
1112 self.bidirectional_refs = enabled;
1113 }
1114}
1115
1116impl Default for CrossReferenceBuilder {
1117 fn default() -> Self {
1118 Self::new()
1119 }
1120}
1121
1122#[derive(Debug, Clone)]
1128pub struct ApiVisualizationEngine {
1129 visualization_templates: HashMap<String, VisualizationTemplate>,
1130}
1131
1132impl ApiVisualizationEngine {
1133 pub fn new() -> Self {
1135 let mut engine = Self {
1136 visualization_templates: HashMap::new(),
1137 };
1138 engine.initialize_templates();
1139 engine
1140 }
1141
1142 pub fn generate_visualizations(&self, api_ref: &ApiReference) -> Result<Vec<ApiVisualization>> {
1144 let mut visualizations = Vec::new();
1145
1146 if !api_ref.traits.is_empty() {
1148 visualizations.push(self.generate_trait_hierarchy_viz(api_ref)?);
1149 }
1150
1151 if !api_ref.types.is_empty() {
1153 visualizations.push(self.generate_type_relationship_viz(api_ref)?);
1154 }
1155
1156 if !api_ref.examples.is_empty() {
1158 visualizations.push(self.generate_example_flow_viz(api_ref)?);
1159 }
1160
1161 Ok(visualizations)
1162 }
1163
1164 fn initialize_templates(&mut self) {
1166 self.visualization_templates.insert(
1168 "trait-hierarchy".to_string(),
1169 VisualizationTemplate {
1170 name: "Trait Hierarchy".to_string(),
1171 template_type: VisualizationType::Tree,
1172 default_config: VisualizationConfig {
1173 width: 800,
1174 height: 600,
1175 theme: "dark".to_string(),
1176 animation_enabled: true,
1177 },
1178 },
1179 );
1180
1181 self.visualization_templates.insert(
1183 "type-relationships".to_string(),
1184 VisualizationTemplate {
1185 name: "Type Relationships".to_string(),
1186 template_type: VisualizationType::Network,
1187 default_config: VisualizationConfig {
1188 width: 700,
1189 height: 500,
1190 theme: "light".to_string(),
1191 animation_enabled: true,
1192 },
1193 },
1194 );
1195
1196 self.visualization_templates.insert(
1198 "example-flow".to_string(),
1199 VisualizationTemplate {
1200 name: "Example Code Flow".to_string(),
1201 template_type: VisualizationType::FlowChart,
1202 default_config: VisualizationConfig {
1203 width: 600,
1204 height: 400,
1205 theme: "auto".to_string(),
1206 animation_enabled: false,
1207 },
1208 },
1209 );
1210 }
1211
1212 fn generate_trait_hierarchy_viz(&self, api_ref: &ApiReference) -> Result<ApiVisualization> {
1214 let template = self.visualization_templates.get("trait-hierarchy").unwrap();
1215
1216 Ok(ApiVisualization {
1217 title: template.name.clone(),
1218 visualization_type: template.template_type.clone(),
1219 data: ApiVisualizationData {
1220 nodes: api_ref
1221 .traits
1222 .iter()
1223 .map(|t| VisualizationNode {
1224 id: t.name.clone(),
1225 label: t.name.clone(),
1226 node_type: "trait".to_string(),
1227 properties: HashMap::new(),
1228 })
1229 .collect(),
1230 edges: Vec::new(),
1231 metadata: HashMap::new(),
1232 },
1233 config: template.default_config.clone(),
1234 })
1235 }
1236
1237 fn generate_type_relationship_viz(&self, api_ref: &ApiReference) -> Result<ApiVisualization> {
1239 let template = self
1240 .visualization_templates
1241 .get("type-relationships")
1242 .unwrap();
1243
1244 Ok(ApiVisualization {
1245 title: template.name.clone(),
1246 visualization_type: template.template_type.clone(),
1247 data: ApiVisualizationData {
1248 nodes: api_ref
1249 .types
1250 .iter()
1251 .map(|t| VisualizationNode {
1252 id: t.name.clone(),
1253 label: t.name.clone(),
1254 node_type: "type".to_string(),
1255 properties: HashMap::new(),
1256 })
1257 .collect(),
1258 edges: Vec::new(),
1259 metadata: HashMap::new(),
1260 },
1261 config: template.default_config.clone(),
1262 })
1263 }
1264
1265 fn generate_example_flow_viz(&self, api_ref: &ApiReference) -> Result<ApiVisualization> {
1267 let template = self.visualization_templates.get("example-flow").unwrap();
1268
1269 Ok(ApiVisualization {
1270 title: template.name.clone(),
1271 visualization_type: template.template_type.clone(),
1272 data: ApiVisualizationData {
1273 nodes: api_ref
1274 .examples
1275 .iter()
1276 .map(|e| VisualizationNode {
1277 id: e.title.clone(),
1278 label: e.title.clone(),
1279 node_type: "example".to_string(),
1280 properties: HashMap::new(),
1281 })
1282 .collect(),
1283 edges: Vec::new(),
1284 metadata: HashMap::new(),
1285 },
1286 config: template.default_config.clone(),
1287 })
1288 }
1289}
1290
1291impl Default for ApiVisualizationEngine {
1292 fn default() -> Self {
1293 Self::new()
1294 }
1295}
1296
/// A named preset for one kind of API visualization.
///
/// The visualization generators copy these fields (by clone) into each
/// `ApiVisualization` they build.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualizationTemplate {
    /// Human-readable template name; used as the `title` of generated visualizations.
    pub name: String,
    /// Kind of visualization; copied into the generated `visualization_type`.
    pub template_type: VisualizationType,
    /// Configuration copied into the `config` of every visualization built from this template.
    pub default_config: VisualizationConfig,
}
1307
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal crate description shared by the analyzer and extractor tests.
    fn test_crate_info() -> CrateInfo {
        CrateInfo {
            name: "test-crate".to_string(),
            version: "1.0.0".to_string(),
            description: "Test crate".to_string(),
            modules: Vec::new(),
            dependencies: Vec::new(),
        }
    }

    /// Builds a runnable Rust `CodeExample` with the given title, description, code and
    /// optional expected output.
    fn rust_example(
        title: &str,
        description: &str,
        code: &str,
        expected_output: Option<&str>,
    ) -> CodeExample {
        CodeExample {
            title: title.to_string(),
            description: description.to_string(),
            code: code.to_string(),
            language: "rust".to_string(),
            runnable: true,
            expected_output: expected_output.map(str::to_string),
        }
    }

    #[test]
    fn test_trait_analyzer() {
        let config = GeneratorConfig::new();
        let mut analyzer = TraitAnalyzer::new(config);

        let traits = analyzer.analyze_traits(&test_crate_info()).unwrap();
        assert!(!traits.is_empty());
        assert!(traits.iter().any(|t| t.name == "Estimator"));
        assert!(traits.iter().any(|t| t.name == "Fit"));
    }

    #[test]
    fn test_type_extractor() {
        let config = GeneratorConfig::new();
        let mut extractor = TypeExtractor::new(config);

        let types = extractor.extract_types(&test_crate_info()).unwrap();
        assert!(!types.is_empty());
        assert!(types.iter().any(|t| t.name == "SklearsError"));
    }

    #[test]
    fn test_example_validator() {
        let validator = ExampleValidator::new();
        let example = rust_example(
            "Test Example",
            "A test example",
            "fn main() { println!(\"Hello, world!\"); }",
            Some("Hello, world!"),
        );

        let validated = validator.validate_example(&example).unwrap();
        assert_eq!(validated.title, example.title);
    }

    #[test]
    fn test_example_validator_syntax_error() {
        let validator = ExampleValidator::new();
        // Missing closing parenthesis in the `println!` call.
        let example = rust_example(
            "Invalid Example",
            "An invalid example",
            "fn main() { println!(\"Hello, world!\"; }",
            None,
        );

        let result = validator.validate_example(&example);
        assert!(result.is_err());
    }

    #[test]
    fn test_cross_reference_builder() {
        let mut builder = CrossReferenceBuilder::new();
        let traits = vec![TraitInfo {
            name: "TestTrait".to_string(),
            description: "A test trait".to_string(),
            path: "test::TestTrait".to_string(),
            generics: Vec::new(),
            associated_types: Vec::new(),
            methods: Vec::new(),
            supertraits: Vec::new(),
            implementations: vec!["TestImpl".to_string()],
        }];
        let types = vec![TypeInfo {
            name: "TestType".to_string(),
            description: "A test type".to_string(),
            path: "test::TestType".to_string(),
            kind: TypeKind::Struct,
            generics: Vec::new(),
            fields: Vec::new(),
            trait_impls: vec!["TestTrait".to_string()],
        }];

        let refs = builder.build_cross_references(&traits, &types).unwrap();
        assert!(!refs.is_empty());
        assert!(refs.contains_key("TestTrait"));
        assert!(refs.contains_key("TestType"));
    }

    #[test]
    fn test_validation_rules() {
        let validator = ExampleValidator::new();
        let unsafe_example = rust_example(
            "Unsafe Example",
            "An unsafe example",
            "unsafe { println!(\"Dangerous!\"); }",
            None,
        );

        let result = validator.validate_example(&unsafe_example);
        assert!(result.is_err());
    }

    #[test]
    fn test_api_visualization_engine() {
        let engine = ApiVisualizationEngine::new();
        let api_ref = ApiReference {
            crate_name: "test-crate".to_string(),
            version: "1.0.0".to_string(),
            traits: vec![TraitInfo::default()],
            types: vec![],
            examples: vec![],
            cross_references: HashMap::new(),
            metadata: crate::api_data_structures::ApiMetadata::default(),
        };

        let visualizations = engine.generate_visualizations(&api_ref).unwrap();
        assert!(!visualizations.is_empty());
    }
}