1use crate::api_data_structures::{
8 ApiReference, ApiVisualization, ApiVisualizationData, AssociatedType, CodeExample, CrateInfo,
9 FieldInfo, MethodInfo, ParameterInfo, TraitInfo, TypeInfo, TypeKind, Visibility,
10 VisualizationConfig, VisualizationNode, VisualizationType,
11};
12use crate::api_generator_config::GeneratorConfig;
13use crate::error::{Result, SklearsError};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17#[derive(Debug, Clone)]
23pub struct TraitAnalyzer {
24 config: GeneratorConfig,
25 trait_cache: HashMap<String, TraitInfo>,
26 hierarchy_depth: usize,
27}
28
29impl TraitAnalyzer {
30 pub fn new(config: GeneratorConfig) -> Self {
32 Self {
33 config,
34 trait_cache: HashMap::new(),
35 hierarchy_depth: 0,
36 }
37 }
38
39 pub fn analyze_traits(&mut self, _crate_info: &CrateInfo) -> Result<Vec<TraitInfo>> {
41 self.trait_cache.clear();
43 self.hierarchy_depth = 0;
44
45 let traits = vec![
48 self.create_estimator_trait()?,
49 self.create_fit_trait()?,
50 self.create_predict_trait()?,
51 self.create_transform_trait()?,
52 self.create_score_trait()?,
53 ];
54
55 for trait_info in &traits {
57 self.trait_cache
58 .insert(trait_info.name.clone(), trait_info.clone());
59 }
60
61 if self.config.include_cross_refs {
63 self.build_trait_hierarchies(&traits)
64 } else {
65 Ok(traits)
66 }
67 }
68
69 pub fn analyze_trait(&mut self, trait_name: &str) -> Result<Option<TraitInfo>> {
71 if let Some(cached) = self.trait_cache.get(trait_name) {
73 return Ok(Some(cached.clone()));
74 }
75
76 match trait_name {
78 "Estimator" => Ok(Some(self.create_estimator_trait()?)),
79 "Fit" => Ok(Some(self.create_fit_trait()?)),
80 "Predict" => Ok(Some(self.create_predict_trait()?)),
81 "Transform" => Ok(Some(self.create_transform_trait()?)),
82 "Score" => Ok(Some(self.create_score_trait()?)),
83 _ => Ok(None),
84 }
85 }
86
87 pub fn get_cached_traits(&self) -> &HashMap<String, TraitInfo> {
89 &self.trait_cache
90 }
91
92 pub fn clear_cache(&mut self) {
94 self.trait_cache.clear();
95 }
96
97 pub fn validate_hierarchy_depth(&self, depth: usize) -> Result<()> {
99 if depth > self.config.max_hierarchy_depth {
100 return Err(SklearsError::InvalidInput(format!(
101 "Trait hierarchy depth {} exceeds maximum allowed depth of {}",
102 depth, self.config.max_hierarchy_depth
103 )));
104 }
105 Ok(())
106 }
107
108 fn build_trait_hierarchies(&mut self, traits: &[TraitInfo]) -> Result<Vec<TraitInfo>> {
110 let mut enhanced_traits = traits.to_vec();
111
112 for trait_info in &mut enhanced_traits {
114 trait_info.supertraits = self.find_supertraits(&trait_info.name)?;
115 trait_info.implementations = self.find_implementations(&trait_info.name)?;
116 }
117
118 Ok(enhanced_traits)
119 }
120
121 fn find_supertraits(&self, trait_name: &str) -> Result<Vec<String>> {
123 match trait_name {
125 "Fit" => Ok(vec!["Estimator".to_string()]),
126 "Predict" => Ok(vec!["Estimator".to_string()]),
127 "Transform" => Ok(vec!["Estimator".to_string()]),
128 "Score" => Ok(vec!["Predict".to_string()]),
129 _ => Ok(Vec::new()),
130 }
131 }
132
133 fn find_implementations(&self, trait_name: &str) -> Result<Vec<String>> {
135 match trait_name {
137 "Estimator" => Ok(vec![
138 "LinearRegression".to_string(),
139 "LogisticRegression".to_string(),
140 "RandomForest".to_string(),
141 "SVM".to_string(),
142 ]),
143 "Fit" => Ok(vec![
144 "LinearRegression".to_string(),
145 "LogisticRegression".to_string(),
146 "RandomForest".to_string(),
147 ]),
148 "Predict" => Ok(vec![
149 "LinearRegression".to_string(),
150 "LogisticRegression".to_string(),
151 "RandomForest".to_string(),
152 ]),
153 "Transform" => Ok(vec![
154 "StandardScaler".to_string(),
155 "PCA".to_string(),
156 "MinMaxScaler".to_string(),
157 ]),
158 "Score" => Ok(vec![
159 "LinearRegression".to_string(),
160 "LogisticRegression".to_string(),
161 ]),
162 _ => Ok(Vec::new()),
163 }
164 }
165
166 fn create_estimator_trait(&self) -> Result<TraitInfo> {
168 Ok(TraitInfo {
169 name: "Estimator".to_string(),
170 description: "Base trait for all machine learning estimators in sklears".to_string(),
171 path: "sklears_core::traits::Estimator".to_string(),
172 generics: Vec::new(),
173 associated_types: vec![AssociatedType {
174 name: "Config".to_string(),
175 description: "Configuration type for the estimator".to_string(),
176 bounds: vec!["Clone".to_string(), "Debug".to_string()],
177 }],
178 methods: vec![
179 MethodInfo {
180 name: "name".to_string(),
181 signature: "fn name(&self) -> &'static str".to_string(),
182 description: "Get the name of the estimator".to_string(),
183 parameters: Vec::new(),
184 return_type: "&'static str".to_string(),
185 required: true,
186 },
187 MethodInfo {
188 name: "default_config".to_string(),
189 signature: "fn default_config() -> Self::Config".to_string(),
190 description: "Get the default configuration for this estimator".to_string(),
191 parameters: Vec::new(),
192 return_type: "Self::Config".to_string(),
193 required: false,
194 },
195 ],
196 supertraits: Vec::new(),
197 implementations: Vec::new(),
198 })
199 }
200
201 fn create_fit_trait(&self) -> Result<TraitInfo> {
203 Ok(TraitInfo {
204 name: "Fit".to_string(),
205 description: "Trait for estimators that can be fitted to training data".to_string(),
206 path: "sklears_core::traits::Fit".to_string(),
207 generics: vec!["X".to_string(), "Y".to_string()],
208 associated_types: vec![AssociatedType {
209 name: "Fitted".to_string(),
210 description: "The type returned after fitting".to_string(),
211 bounds: vec!["Send".to_string(), "Sync".to_string()],
212 }],
213 methods: vec![MethodInfo {
214 name: "fit".to_string(),
215 signature: "fn fit(self, x: &X, y: &Y) -> Result<Self::Fitted>".to_string(),
216 description: "Fit the estimator to training data".to_string(),
217 parameters: vec![
218 ParameterInfo {
219 name: "x".to_string(),
220 param_type: "&X".to_string(),
221 description: "Training features matrix".to_string(),
222 optional: false,
223 },
224 ParameterInfo {
225 name: "y".to_string(),
226 param_type: "&Y".to_string(),
227 description: "Training targets vector".to_string(),
228 optional: false,
229 },
230 ],
231 return_type: "Result<Self::Fitted>".to_string(),
232 required: true,
233 }],
234 supertraits: Vec::new(),
235 implementations: Vec::new(),
236 })
237 }
238
239 fn create_predict_trait(&self) -> Result<TraitInfo> {
241 Ok(TraitInfo {
242 name: "Predict".to_string(),
243 description: "Trait for estimators that can make predictions".to_string(),
244 path: "sklears_core::traits::Predict".to_string(),
245 generics: vec!["X".to_string()],
246 associated_types: vec![AssociatedType {
247 name: "Output".to_string(),
248 description: "The type of prediction output".to_string(),
249 bounds: vec!["Clone".to_string()],
250 }],
251 methods: vec![MethodInfo {
252 name: "predict".to_string(),
253 signature: "fn predict(&self, x: &X) -> Result<Self::Output>".to_string(),
254 description: "Make predictions on new data".to_string(),
255 parameters: vec![ParameterInfo {
256 name: "x".to_string(),
257 param_type: "&X".to_string(),
258 description: "Input features to predict on".to_string(),
259 optional: false,
260 }],
261 return_type: "Result<Self::Output>".to_string(),
262 required: true,
263 }],
264 supertraits: Vec::new(),
265 implementations: Vec::new(),
266 })
267 }
268
269 fn create_transform_trait(&self) -> Result<TraitInfo> {
271 Ok(TraitInfo {
272 name: "Transform".to_string(),
273 description: "Trait for estimators that can transform data".to_string(),
274 path: "sklears_core::traits::Transform".to_string(),
275 generics: vec!["X".to_string()],
276 associated_types: vec![AssociatedType {
277 name: "Output".to_string(),
278 description: "The type of transformed output".to_string(),
279 bounds: vec!["Clone".to_string()],
280 }],
281 methods: vec![
282 MethodInfo {
283 name: "transform".to_string(),
284 signature: "fn transform(&self, x: &X) -> Result<Self::Output>".to_string(),
285 description: "Transform input data".to_string(),
286 parameters: vec![ParameterInfo {
287 name: "x".to_string(),
288 param_type: "&X".to_string(),
289 description: "Input data to transform".to_string(),
290 optional: false,
291 }],
292 return_type: "Result<Self::Output>".to_string(),
293 required: true,
294 },
295 MethodInfo {
296 name: "fit_transform".to_string(),
297 signature:
298 "fn fit_transform(self, x: &X) -> Result<(Self::Fitted, Self::Output)>"
299 .to_string(),
300 description: "Fit the transformer and transform data in one step".to_string(),
301 parameters: vec![ParameterInfo {
302 name: "x".to_string(),
303 param_type: "&X".to_string(),
304 description: "Input data to fit and transform".to_string(),
305 optional: false,
306 }],
307 return_type: "Result<(Self::Fitted, Self::Output)>".to_string(),
308 required: false,
309 },
310 ],
311 supertraits: Vec::new(),
312 implementations: Vec::new(),
313 })
314 }
315
316 fn create_score_trait(&self) -> Result<TraitInfo> {
318 Ok(TraitInfo {
319 name: "Score".to_string(),
320 description: "Trait for estimators that can compute accuracy scores".to_string(),
321 path: "sklears_core::traits::Score".to_string(),
322 generics: vec!["X".to_string(), "Y".to_string()],
323 associated_types: vec![AssociatedType {
324 name: "Score".to_string(),
325 description: "The type of score output".to_string(),
326 bounds: vec!["PartialOrd".to_string(), "Copy".to_string()],
327 }],
328 methods: vec![MethodInfo {
329 name: "score".to_string(),
330 signature: "fn score(&self, x: &X, y: &Y) -> Result<Self::Score>".to_string(),
331 description: "Compute accuracy score on test data".to_string(),
332 parameters: vec![
333 ParameterInfo {
334 name: "x".to_string(),
335 param_type: "&X".to_string(),
336 description: "Test features".to_string(),
337 optional: false,
338 },
339 ParameterInfo {
340 name: "y".to_string(),
341 param_type: "&Y".to_string(),
342 description: "True test targets".to_string(),
343 optional: false,
344 },
345 ],
346 return_type: "Result<Self::Score>".to_string(),
347 required: true,
348 }],
349 supertraits: Vec::new(),
350 implementations: Vec::new(),
351 })
352 }
353}
354
355impl Default for TraitAnalyzer {
356 fn default() -> Self {
357 Self::new(GeneratorConfig::default())
358 }
359}
360
361#[derive(Debug, Clone)]
367pub struct TypeExtractor {
368 #[allow(dead_code)]
369 config: GeneratorConfig,
370 type_cache: HashMap<String, TypeInfo>,
371 generic_constraints: HashMap<String, Vec<String>>,
372}
373
374impl TypeExtractor {
375 pub fn new(config: GeneratorConfig) -> Self {
377 Self {
378 config,
379 type_cache: HashMap::new(),
380 generic_constraints: HashMap::new(),
381 }
382 }
383
384 pub fn extract_types(&mut self, _crate_info: &CrateInfo) -> Result<Vec<TypeInfo>> {
386 self.type_cache.clear();
388
389 let types = vec![
391 self.create_sklears_error_type()?,
392 self.create_estimator_config_type()?,
393 self.create_matrix_type()?,
394 self.create_array_type()?,
395 self.create_model_type()?,
396 ];
397
398 for type_info in &types {
400 self.type_cache
401 .insert(type_info.name.clone(), type_info.clone());
402 }
403
404 Ok(types)
405 }
406
407 pub fn extract_type(&mut self, type_name: &str) -> Result<Option<TypeInfo>> {
409 if let Some(cached) = self.type_cache.get(type_name) {
411 return Ok(Some(cached.clone()));
412 }
413
414 match type_name {
416 "SklearsError" => Ok(Some(self.create_sklears_error_type()?)),
417 "EstimatorConfig" => Ok(Some(self.create_estimator_config_type()?)),
418 "Matrix" => Ok(Some(self.create_matrix_type()?)),
419 "Array" => Ok(Some(self.create_array_type()?)),
420 "Model" => Ok(Some(self.create_model_type()?)),
421 _ => Ok(None),
422 }
423 }
424
425 pub fn get_cached_types(&self) -> &HashMap<String, TypeInfo> {
427 &self.type_cache
428 }
429
430 pub fn analyze_generic_constraints(&mut self, type_name: &str) -> Result<Vec<String>> {
432 if let Some(constraints) = self.generic_constraints.get(type_name) {
433 return Ok(constraints.clone());
434 }
435
436 let constraints = match type_name {
438 "Matrix" => vec!["Clone".to_string(), "Debug".to_string(), "Send".to_string()],
439 "Array" => vec!["Clone".to_string(), "Debug".to_string()],
440 "Model" => vec![
441 "Clone".to_string(),
442 "Debug".to_string(),
443 "Serialize".to_string(),
444 "Deserialize".to_string(),
445 ],
446 _ => Vec::new(),
447 };
448
449 self.generic_constraints
450 .insert(type_name.to_string(), constraints.clone());
451 Ok(constraints)
452 }
453
454 fn create_sklears_error_type(&self) -> Result<TypeInfo> {
456 Ok(TypeInfo {
457 name: "SklearsError".to_string(),
458 description: "Error types for sklears operations".to_string(),
459 path: "sklears_core::error::SklearsError".to_string(),
460 kind: TypeKind::Enum,
461 generics: Vec::new(),
462 fields: vec![
463 FieldInfo {
464 name: "InvalidInput".to_string(),
465 field_type: "String".to_string(),
466 description: "Invalid input parameter provided".to_string(),
467 visibility: Visibility::Public,
468 },
469 FieldInfo {
470 name: "ComputationError".to_string(),
471 field_type: "String".to_string(),
472 description: "Computation or numerical error occurred".to_string(),
473 visibility: Visibility::Public,
474 },
475 FieldInfo {
476 name: "ConfigurationError".to_string(),
477 field_type: "String".to_string(),
478 description: "Invalid configuration provided".to_string(),
479 visibility: Visibility::Public,
480 },
481 FieldInfo {
482 name: "DataError".to_string(),
483 field_type: "String".to_string(),
484 description: "Data format or quality issue".to_string(),
485 visibility: Visibility::Public,
486 },
487 ],
488 trait_impls: vec![
489 "Debug".to_string(),
490 "Display".to_string(),
491 "Error".to_string(),
492 "Clone".to_string(),
493 ],
494 })
495 }
496
497 fn create_estimator_config_type(&self) -> Result<TypeInfo> {
499 Ok(TypeInfo {
500 name: "EstimatorConfig".to_string(),
501 description: "Base configuration for all estimators".to_string(),
502 path: "sklears_core::config::EstimatorConfig".to_string(),
503 kind: TypeKind::Struct,
504 generics: Vec::new(),
505 fields: vec![
506 FieldInfo {
507 name: "random_state".to_string(),
508 field_type: "Option<u64>".to_string(),
509 description: "Random seed for reproducible results".to_string(),
510 visibility: Visibility::Public,
511 },
512 FieldInfo {
513 name: "verbose".to_string(),
514 field_type: "bool".to_string(),
515 description: "Enable verbose output during training".to_string(),
516 visibility: Visibility::Public,
517 },
518 FieldInfo {
519 name: "max_iterations".to_string(),
520 field_type: "usize".to_string(),
521 description: "Maximum number of training iterations".to_string(),
522 visibility: Visibility::Public,
523 },
524 ],
525 trait_impls: vec![
526 "Debug".to_string(),
527 "Clone".to_string(),
528 "Default".to_string(),
529 "Serialize".to_string(),
530 "Deserialize".to_string(),
531 ],
532 })
533 }
534
535 fn create_matrix_type(&self) -> Result<TypeInfo> {
537 Ok(TypeInfo {
538 name: "Matrix".to_string(),
539 description: "Generic matrix type for numerical computations".to_string(),
540 path: "sklears_core::linalg::Matrix".to_string(),
541 kind: TypeKind::Struct,
542 generics: vec!["T".to_string()],
543 fields: vec![
544 FieldInfo {
545 name: "data".to_string(),
546 field_type: "Vec<T>".to_string(),
547 description: "Flattened matrix data in row-major order".to_string(),
548 visibility: Visibility::Private,
549 },
550 FieldInfo {
551 name: "rows".to_string(),
552 field_type: "usize".to_string(),
553 description: "Number of rows in the matrix".to_string(),
554 visibility: Visibility::Public,
555 },
556 FieldInfo {
557 name: "cols".to_string(),
558 field_type: "usize".to_string(),
559 description: "Number of columns in the matrix".to_string(),
560 visibility: Visibility::Public,
561 },
562 ],
563 trait_impls: vec![
564 "Debug".to_string(),
565 "Clone".to_string(),
566 "Index".to_string(),
567 "IndexMut".to_string(),
568 ],
569 })
570 }
571
572 fn create_array_type(&self) -> Result<TypeInfo> {
574 Ok(TypeInfo {
575 name: "Array".to_string(),
576 description: "Multi-dimensional array type".to_string(),
577 path: "sklears_core::array::Array".to_string(),
578 kind: TypeKind::Struct,
579 generics: vec!["T".to_string(), "const N: usize".to_string()],
580 fields: vec![
581 FieldInfo {
582 name: "data".to_string(),
583 field_type: "Vec<T>".to_string(),
584 description: "Array data storage".to_string(),
585 visibility: Visibility::Private,
586 },
587 FieldInfo {
588 name: "shape".to_string(),
589 field_type: "[usize; N]".to_string(),
590 description: "Shape of the array in each dimension".to_string(),
591 visibility: Visibility::Public,
592 },
593 ],
594 trait_impls: vec![
595 "Debug".to_string(),
596 "Clone".to_string(),
597 "Index".to_string(),
598 "IntoIterator".to_string(),
599 ],
600 })
601 }
602
603 fn create_model_type(&self) -> Result<TypeInfo> {
605 Ok(TypeInfo {
606 name: "Model".to_string(),
607 description: "Trained model container".to_string(),
608 path: "sklears_core::model::Model".to_string(),
609 kind: TypeKind::Struct,
610 generics: vec!["E".to_string()],
611 fields: vec![
612 FieldInfo {
613 name: "estimator".to_string(),
614 field_type: "E".to_string(),
615 description: "The trained estimator".to_string(),
616 visibility: Visibility::Private,
617 },
618 FieldInfo {
619 name: "metadata".to_string(),
620 field_type: "ModelMetadata".to_string(),
621 description: "Model training metadata".to_string(),
622 visibility: Visibility::Public,
623 },
624 FieldInfo {
625 name: "metrics".to_string(),
626 field_type: "HashMap<String, f64>".to_string(),
627 description: "Training and validation metrics".to_string(),
628 visibility: Visibility::Public,
629 },
630 ],
631 trait_impls: vec![
632 "Debug".to_string(),
633 "Clone".to_string(),
634 "Serialize".to_string(),
635 "Deserialize".to_string(),
636 ],
637 })
638 }
639}
640
641impl Default for TypeExtractor {
642 fn default() -> Self {
643 Self::new(GeneratorConfig::default())
644 }
645}
646
647#[derive(Debug, Clone)]
653pub struct ExampleValidator {
654 validation_rules: Vec<ValidationRule>,
655 #[allow(dead_code)]
656 compile_timeout_secs: u64,
657 enable_compilation: bool,
658 enable_execution: bool,
659}
660
661impl ExampleValidator {
662 pub fn new() -> Self {
664 Self {
665 validation_rules: vec![
666 ValidationRule::SyntaxCheck,
667 ValidationRule::ImportCheck,
668 ValidationRule::TypeCheck,
669 ValidationRule::SafetyCheck,
670 ],
671 compile_timeout_secs: 30,
672 enable_compilation: false, enable_execution: false, }
675 }
676
677 pub fn with_config(
679 enable_compilation: bool,
680 enable_execution: bool,
681 timeout_secs: u64,
682 ) -> Self {
683 Self {
684 validation_rules: vec![
685 ValidationRule::SyntaxCheck,
686 ValidationRule::ImportCheck,
687 ValidationRule::TypeCheck,
688 ValidationRule::SafetyCheck,
689 ],
690 compile_timeout_secs: timeout_secs,
691 enable_compilation,
692 enable_execution,
693 }
694 }
695
696 pub fn validate_examples(&self, examples: &[CodeExample]) -> Result<Vec<CodeExample>> {
698 let mut validated_examples = Vec::new();
699
700 for example in examples {
701 match self.validate_example(example) {
702 Ok(validated) => validated_examples.push(validated),
703 Err(e) => {
704 eprintln!(
706 "Warning: Failed to validate example '{}': {}",
707 example.title, e
708 );
709 validated_examples.push(example.clone());
710 }
711 }
712 }
713
714 Ok(validated_examples)
715 }
716
717 pub fn validate_example(&self, example: &CodeExample) -> Result<CodeExample> {
719 let mut validated = example.clone();
720
721 for rule in &self.validation_rules {
723 self.apply_validation_rule(rule, &mut validated)?;
724 }
725
726 if self.enable_compilation && example.runnable {
728 self.compile_check(&validated)?;
729 }
730
731 if self.enable_execution && example.runnable {
733 self.execution_check(&validated)?;
734 }
735
736 Ok(validated)
737 }
738
739 fn apply_validation_rule(
741 &self,
742 rule: &ValidationRule,
743 example: &mut CodeExample,
744 ) -> Result<()> {
745 match rule {
746 ValidationRule::SyntaxCheck => self.check_syntax(example),
747 ValidationRule::ImportCheck => self.check_imports(example),
748 ValidationRule::TypeCheck => self.check_types(example),
749 ValidationRule::SafetyCheck => self.check_safety(example),
750 }
751 }
752
753 fn check_syntax(&self, example: &CodeExample) -> Result<()> {
755 if example.code.trim().is_empty() {
757 return Err(SklearsError::InvalidInput(
758 "Example code cannot be empty".to_string(),
759 ));
760 }
761
762 let open_braces = example.code.matches('{').count();
764 let close_braces = example.code.matches('}').count();
765 if open_braces != close_braces {
766 return Err(SklearsError::InvalidInput(
767 "Unbalanced braces in example code".to_string(),
768 ));
769 }
770
771 let open_parens = example.code.matches('(').count();
773 let close_parens = example.code.matches(')').count();
774 if open_parens != close_parens {
775 return Err(SklearsError::InvalidInput(
776 "Unbalanced parentheses in example code".to_string(),
777 ));
778 }
779
780 Ok(())
781 }
782
783 fn check_imports(&self, example: &CodeExample) -> Result<()> {
785 let lines: Vec<&str> = example.code.lines().collect();
786
787 for line in lines {
788 let trimmed = line.trim();
789 if trimmed.starts_with("use ") {
790 if trimmed.contains("sklears_") && !self.is_valid_sklears_import(trimmed) {
792 return Err(SklearsError::InvalidInput(format!(
793 "Invalid sklears import: {}",
794 trimmed
795 )));
796 }
797 }
798 }
799
800 Ok(())
801 }
802
803 fn check_types(&self, _example: &CodeExample) -> Result<()> {
805 Ok(())
808 }
809
810 fn check_safety(&self, example: &CodeExample) -> Result<()> {
812 if example.code.contains("unsafe") {
814 return Err(SklearsError::InvalidInput(
815 "Unsafe code blocks are not allowed in examples".to_string(),
816 ));
817 }
818
819 let dangerous_patterns = [
821 "std::process::Command",
822 "std::fs::remove",
823 "std::ptr::",
824 "libc::",
825 "transmute",
826 ];
827
828 for pattern in &dangerous_patterns {
829 if example.code.contains(pattern) {
830 return Err(SklearsError::InvalidInput(format!(
831 "Potentially dangerous pattern '{}' found in example",
832 pattern
833 )));
834 }
835 }
836
837 Ok(())
838 }
839
840 fn is_valid_sklears_import(&self, import: &str) -> bool {
842 let valid_modules = [
843 "sklears_core",
844 "sklears_linear",
845 "sklears_tree",
846 "sklears_ensemble",
847 "sklears_preprocessing",
848 "sklears_metrics",
849 "sklears_neighbors",
850 "sklears_clustering",
851 "sklears_datasets",
852 ];
853
854 valid_modules.iter().any(|module| import.contains(module))
855 }
856
857 fn compile_check(&self, _example: &CodeExample) -> Result<()> {
859 Ok(())
866 }
867
868 fn execution_check(&self, _example: &CodeExample) -> Result<()> {
870 Ok(())
876 }
877
878 pub fn set_validation_rules(&mut self, rules: Vec<ValidationRule>) {
880 self.validation_rules = rules;
881 }
882
883 pub fn set_compilation_enabled(&mut self, enabled: bool) {
885 self.enable_compilation = enabled;
886 }
887
888 pub fn set_execution_enabled(&mut self, enabled: bool) {
890 self.enable_execution = enabled;
891 }
892}
893
894impl Default for ExampleValidator {
895 fn default() -> Self {
896 Self::new()
897 }
898}
899
900#[derive(Debug, Clone, PartialEq, Eq)]
902pub enum ValidationRule {
903 SyntaxCheck,
905 ImportCheck,
907 TypeCheck,
909 SafetyCheck,
911}
912
913#[derive(Debug, Clone)]
919pub struct CrossReferenceBuilder {
920 reference_cache: HashMap<String, Vec<String>>,
921 bidirectional_refs: bool,
922 max_depth: usize,
923}
924
925impl CrossReferenceBuilder {
926 pub fn new() -> Self {
928 Self {
929 reference_cache: HashMap::new(),
930 bidirectional_refs: true,
931 max_depth: 3,
932 }
933 }
934
935 pub fn with_config(bidirectional_refs: bool, max_depth: usize) -> Self {
937 Self {
938 reference_cache: HashMap::new(),
939 bidirectional_refs,
940 max_depth,
941 }
942 }
943
944 pub fn build_cross_references(
946 &mut self,
947 traits: &[TraitInfo],
948 types: &[TypeInfo],
949 ) -> Result<HashMap<String, Vec<String>>> {
950 let mut refs = HashMap::new();
951
952 for trait_info in traits {
954 let mut trait_refs = trait_info.implementations.clone();
955
956 trait_refs.extend(trait_info.supertraits.clone());
958 trait_refs.extend(self.find_related_traits(trait_info, traits)?);
959
960 refs.insert(trait_info.name.clone(), trait_refs);
961 }
962
963 for type_info in types {
965 let mut type_refs = type_info.trait_impls.clone();
966
967 type_refs.extend(self.find_related_types(type_info, types)?);
969
970 refs.insert(type_info.name.clone(), type_refs);
971 }
972
973 if self.bidirectional_refs {
975 refs = self.add_bidirectional_references(refs)?;
976 }
977
978 self.reference_cache = refs.clone();
980
981 Ok(refs)
982 }
983
984 fn find_related_traits(
986 &self,
987 trait_info: &TraitInfo,
988 all_traits: &[TraitInfo],
989 ) -> Result<Vec<String>> {
990 let mut related = Vec::new();
991
992 for other_trait in all_traits {
993 if other_trait.name == trait_info.name {
994 continue;
995 }
996
997 let common_methods = self.count_common_methods(trait_info, other_trait);
999 if common_methods > 0 {
1000 related.push(other_trait.name.clone());
1001 }
1002
1003 if other_trait.supertraits.contains(&trait_info.name) {
1005 related.push(other_trait.name.clone());
1006 }
1007 }
1008
1009 Ok(related)
1010 }
1011
1012 fn find_related_types(
1014 &self,
1015 type_info: &TypeInfo,
1016 all_types: &[TypeInfo],
1017 ) -> Result<Vec<String>> {
1018 let mut related = Vec::new();
1019
1020 for other_type in all_types {
1021 if other_type.name == type_info.name {
1022 continue;
1023 }
1024
1025 let common_traits = self.count_common_trait_impls(type_info, other_type);
1027 if common_traits > 0 {
1028 related.push(other_type.name.clone());
1029 }
1030
1031 if self.have_similar_structure(type_info, other_type) {
1033 related.push(other_type.name.clone());
1034 }
1035 }
1036
1037 Ok(related)
1038 }
1039
1040 fn count_common_methods(&self, trait1: &TraitInfo, trait2: &TraitInfo) -> usize {
1042 let trait1_methods: std::collections::HashSet<_> =
1043 trait1.methods.iter().map(|m| &m.name).collect();
1044 let trait2_methods: std::collections::HashSet<_> =
1045 trait2.methods.iter().map(|m| &m.name).collect();
1046
1047 trait1_methods.intersection(&trait2_methods).count()
1048 }
1049
1050 fn count_common_trait_impls(&self, type1: &TypeInfo, type2: &TypeInfo) -> usize {
1052 let type1_traits: std::collections::HashSet<_> = type1.trait_impls.iter().collect();
1053 let type2_traits: std::collections::HashSet<_> = type2.trait_impls.iter().collect();
1054
1055 type1_traits.intersection(&type2_traits).count()
1056 }
1057
1058 fn have_similar_structure(&self, type1: &TypeInfo, type2: &TypeInfo) -> bool {
1060 if std::mem::discriminant(&type1.kind) != std::mem::discriminant(&type2.kind) {
1062 return false;
1063 }
1064
1065 let field_count_diff = (type1.fields.len() as i32 - type2.fields.len() as i32).abs();
1067 field_count_diff <= 2 }
1069
1070 fn add_bidirectional_references(
1072 &self,
1073 mut refs: HashMap<String, Vec<String>>,
1074 ) -> Result<HashMap<String, Vec<String>>> {
1075 let keys: Vec<String> = refs.keys().cloned().collect();
1076
1077 for key in &keys {
1078 if let Some(values) = refs.get(key).cloned() {
1079 for value in values {
1080 refs.entry(value).or_default().push(key.clone());
1082 }
1083 }
1084 }
1085
1086 for (_, values) in refs.iter_mut() {
1088 values.sort();
1089 values.dedup();
1090 }
1091
1092 Ok(refs)
1093 }
1094
1095 pub fn get_cached_references(&self) -> &HashMap<String, Vec<String>> {
1097 &self.reference_cache
1098 }
1099
1100 pub fn clear_cache(&mut self) {
1102 self.reference_cache.clear();
1103 }
1104
1105 pub fn set_max_depth(&mut self, depth: usize) {
1107 self.max_depth = depth;
1108 }
1109
1110 pub fn set_bidirectional(&mut self, enabled: bool) {
1112 self.bidirectional_refs = enabled;
1113 }
1114}
1115
1116impl Default for CrossReferenceBuilder {
1117 fn default() -> Self {
1118 Self::new()
1119 }
1120}
1121
1122#[derive(Debug, Clone)]
1128pub struct ApiVisualizationEngine {
1129 visualization_templates: HashMap<String, VisualizationTemplate>,
1130}
1131
1132impl ApiVisualizationEngine {
1133 pub fn new() -> Self {
1135 let mut engine = Self {
1136 visualization_templates: HashMap::new(),
1137 };
1138 engine.initialize_templates();
1139 engine
1140 }
1141
1142 pub fn generate_visualizations(&self, api_ref: &ApiReference) -> Result<Vec<ApiVisualization>> {
1144 let mut visualizations = Vec::new();
1145
1146 if !api_ref.traits.is_empty() {
1148 visualizations.push(self.generate_trait_hierarchy_viz(api_ref)?);
1149 }
1150
1151 if !api_ref.types.is_empty() {
1153 visualizations.push(self.generate_type_relationship_viz(api_ref)?);
1154 }
1155
1156 if !api_ref.examples.is_empty() {
1158 visualizations.push(self.generate_example_flow_viz(api_ref)?);
1159 }
1160
1161 Ok(visualizations)
1162 }
1163
1164 fn initialize_templates(&mut self) {
1166 self.visualization_templates.insert(
1168 "trait-hierarchy".to_string(),
1169 VisualizationTemplate {
1170 name: "Trait Hierarchy".to_string(),
1171 template_type: VisualizationType::Tree,
1172 default_config: VisualizationConfig {
1173 width: 800,
1174 height: 600,
1175 theme: "dark".to_string(),
1176 animation_enabled: true,
1177 },
1178 },
1179 );
1180
1181 self.visualization_templates.insert(
1183 "type-relationships".to_string(),
1184 VisualizationTemplate {
1185 name: "Type Relationships".to_string(),
1186 template_type: VisualizationType::Network,
1187 default_config: VisualizationConfig {
1188 width: 700,
1189 height: 500,
1190 theme: "light".to_string(),
1191 animation_enabled: true,
1192 },
1193 },
1194 );
1195
1196 self.visualization_templates.insert(
1198 "example-flow".to_string(),
1199 VisualizationTemplate {
1200 name: "Example Code Flow".to_string(),
1201 template_type: VisualizationType::FlowChart,
1202 default_config: VisualizationConfig {
1203 width: 600,
1204 height: 400,
1205 theme: "auto".to_string(),
1206 animation_enabled: false,
1207 },
1208 },
1209 );
1210 }
1211
1212 fn generate_trait_hierarchy_viz(&self, api_ref: &ApiReference) -> Result<ApiVisualization> {
1214 let template = self
1215 .visualization_templates
1216 .get("trait-hierarchy")
1217 .expect("key should exist");
1218
1219 Ok(ApiVisualization {
1220 title: template.name.clone(),
1221 visualization_type: template.template_type.clone(),
1222 data: ApiVisualizationData {
1223 nodes: api_ref
1224 .traits
1225 .iter()
1226 .map(|t| VisualizationNode {
1227 id: t.name.clone(),
1228 label: t.name.clone(),
1229 node_type: "trait".to_string(),
1230 properties: HashMap::new(),
1231 })
1232 .collect(),
1233 edges: Vec::new(),
1234 metadata: HashMap::new(),
1235 },
1236 config: template.default_config.clone(),
1237 })
1238 }
1239
1240 fn generate_type_relationship_viz(&self, api_ref: &ApiReference) -> Result<ApiVisualization> {
1242 let template = self
1243 .visualization_templates
1244 .get("type-relationships")
1245 .expect("expected valid value");
1246
1247 Ok(ApiVisualization {
1248 title: template.name.clone(),
1249 visualization_type: template.template_type.clone(),
1250 data: ApiVisualizationData {
1251 nodes: api_ref
1252 .types
1253 .iter()
1254 .map(|t| VisualizationNode {
1255 id: t.name.clone(),
1256 label: t.name.clone(),
1257 node_type: "type".to_string(),
1258 properties: HashMap::new(),
1259 })
1260 .collect(),
1261 edges: Vec::new(),
1262 metadata: HashMap::new(),
1263 },
1264 config: template.default_config.clone(),
1265 })
1266 }
1267
1268 fn generate_example_flow_viz(&self, api_ref: &ApiReference) -> Result<ApiVisualization> {
1270 let template = self
1271 .visualization_templates
1272 .get("example-flow")
1273 .expect("key should exist");
1274
1275 Ok(ApiVisualization {
1276 title: template.name.clone(),
1277 visualization_type: template.template_type.clone(),
1278 data: ApiVisualizationData {
1279 nodes: api_ref
1280 .examples
1281 .iter()
1282 .map(|e| VisualizationNode {
1283 id: e.title.clone(),
1284 label: e.title.clone(),
1285 node_type: "example".to_string(),
1286 properties: HashMap::new(),
1287 })
1288 .collect(),
1289 edges: Vec::new(),
1290 metadata: HashMap::new(),
1291 },
1292 config: template.default_config.clone(),
1293 })
1294 }
1295}
1296
1297impl Default for ApiVisualizationEngine {
1298 fn default() -> Self {
1299 Self::new()
1300 }
1301}
1302
1303#[derive(Debug, Clone, Serialize, Deserialize)]
1305pub struct VisualizationTemplate {
1306 pub name: String,
1308 pub template_type: VisualizationType,
1310 pub default_config: VisualizationConfig,
1312}
1313
1314#[allow(non_snake_case)]
1319#[cfg(test)]
1320mod tests {
1321 use super::*;
1322
1323 #[test]
1324 fn test_trait_analyzer() {
1325 let config = GeneratorConfig::new();
1326 let mut analyzer = TraitAnalyzer::new(config);
1327 let crate_info = CrateInfo {
1328 name: "test-crate".to_string(),
1329 version: "1.0.0".to_string(),
1330 description: "Test crate".to_string(),
1331 modules: Vec::new(),
1332 dependencies: Vec::new(),
1333 };
1334
1335 let traits = analyzer
1336 .analyze_traits(&crate_info)
1337 .expect("analyze_traits should succeed");
1338 assert!(!traits.is_empty());
1339 assert!(traits.iter().any(|t| t.name == "Estimator"));
1340 assert!(traits.iter().any(|t| t.name == "Fit"));
1341 }
1342
1343 #[test]
1344 fn test_type_extractor() {
1345 let config = GeneratorConfig::new();
1346 let mut extractor = TypeExtractor::new(config);
1347 let crate_info = CrateInfo {
1348 name: "test-crate".to_string(),
1349 version: "1.0.0".to_string(),
1350 description: "Test crate".to_string(),
1351 modules: Vec::new(),
1352 dependencies: Vec::new(),
1353 };
1354
1355 let types = extractor
1356 .extract_types(&crate_info)
1357 .expect("extract_types should succeed");
1358 assert!(!types.is_empty());
1359 assert!(types.iter().any(|t| t.name == "SklearsError"));
1360 }
1361
1362 #[test]
1363 fn test_example_validator() {
1364 let validator = ExampleValidator::new();
1365 let example = CodeExample {
1366 title: "Test Example".to_string(),
1367 description: "A test example".to_string(),
1368 code: "fn main() { println!(\"Hello, world!\"); }".to_string(),
1369 language: "rust".to_string(),
1370 runnable: true,
1371 expected_output: Some("Hello, world!".to_string()),
1372 };
1373
1374 let validated = validator
1375 .validate_example(&example)
1376 .expect("validate_example should succeed");
1377 assert_eq!(validated.title, example.title);
1378 }
1379
1380 #[test]
1381 fn test_example_validator_syntax_error() {
1382 let validator = ExampleValidator::new();
1383 let example = CodeExample {
1384 title: "Invalid Example".to_string(),
1385 description: "An invalid example".to_string(),
1386 code: "fn main() { println!(\"Hello, world!\"; }".to_string(), language: "rust".to_string(),
1388 runnable: true,
1389 expected_output: None,
1390 };
1391
1392 let result = validator.validate_example(&example);
1393 assert!(result.is_err());
1394 }
1395
1396 #[test]
1397 fn test_cross_reference_builder() {
1398 let mut builder = CrossReferenceBuilder::new();
1399 let traits = vec![TraitInfo {
1400 name: "TestTrait".to_string(),
1401 description: "A test trait".to_string(),
1402 path: "test::TestTrait".to_string(),
1403 generics: Vec::new(),
1404 associated_types: Vec::new(),
1405 methods: Vec::new(),
1406 supertraits: Vec::new(),
1407 implementations: vec!["TestImpl".to_string()],
1408 }];
1409 let types = vec![TypeInfo {
1410 name: "TestType".to_string(),
1411 description: "A test type".to_string(),
1412 path: "test::TestType".to_string(),
1413 kind: TypeKind::Struct,
1414 generics: Vec::new(),
1415 fields: Vec::new(),
1416 trait_impls: vec!["TestTrait".to_string()],
1417 }];
1418
1419 let refs = builder
1420 .build_cross_references(&traits, &types)
1421 .expect("build_cross_references should succeed");
1422 assert!(!refs.is_empty());
1423 assert!(refs.contains_key("TestTrait"));
1424 assert!(refs.contains_key("TestType"));
1425 }
1426
1427 #[test]
1428 fn test_validation_rules() {
1429 let validator = ExampleValidator::new();
1430 let unsafe_example = CodeExample {
1431 title: "Unsafe Example".to_string(),
1432 description: "An unsafe example".to_string(),
1433 code: "unsafe { println!(\"Dangerous!\"); }".to_string(),
1434 language: "rust".to_string(),
1435 runnable: true,
1436 expected_output: None,
1437 };
1438
1439 let result = validator.validate_example(&unsafe_example);
1440 assert!(result.is_err());
1441 }
1442
1443 #[test]
1444 fn test_api_visualization_engine() {
1445 let engine = ApiVisualizationEngine::new();
1446 let api_ref = ApiReference {
1447 crate_name: "test-crate".to_string(),
1448 version: "1.0.0".to_string(),
1449 traits: vec![TraitInfo::default()],
1450 types: vec![],
1451 examples: vec![],
1452 cross_references: HashMap::new(),
1453 metadata: crate::api_data_structures::ApiMetadata::default(),
1454 };
1455
1456 let visualizations = engine
1457 .generate_visualizations(&api_ref)
1458 .expect("generate_visualizations should succeed");
1459 assert!(!visualizations.is_empty());
1460 }
1461}