1use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::time::{Duration, Instant};
11use tracing::{debug, info};
12
/// AI-powered analyzer for ML model code: pattern detection, deep analysis,
/// optimization suggestions, vulnerability scanning, and performance
/// prediction, with result caching and self-monitoring.
#[derive(Debug)]
pub struct AICodeAnalyzer {
    /// Feature toggles and thresholds controlling which stages run.
    config: AIAnalysisConfig,
    /// Memoized results keyed by the code hash from `compute_code_hash`.
    analysis_cache: HashMap<String, CachedAnalysis>,
    #[allow(dead_code)]
    pattern_database: ModelPatternDatabase,
    /// Run counts, timing, and cache hit/miss statistics.
    performance_monitor: AnalysisPerformanceMonitor,
}
22
/// Feature switches and tuning knobs for `AICodeAnalyzer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AIAnalysisConfig {
    pub enable_deep_analysis: bool,
    pub enable_pattern_recognition: bool,
    pub enable_optimization_suggestions: bool,
    pub enable_vulnerability_detection: bool,
    pub enable_performance_prediction: bool,
    /// Time budget per analysis run, in seconds.
    /// NOTE(review): not enforced anywhere in this file — confirm callers use it.
    pub max_analysis_time_secs: u64,
    /// Minimum confidence for reported findings.
    /// NOTE(review): not consulted in this file — findings are reported regardless.
    pub confidence_threshold: f64,
    /// When true, analysis results are memoized by code hash.
    pub enable_caching: bool,
    /// Cache entry time-to-live, in hours (checked in `get_cached_analysis`).
    pub cache_expiration_hours: u64,
}
45
impl Default for AIAnalysisConfig {
    /// Defaults: every analysis stage enabled, 30s budget, 0.75 confidence
    /// threshold, caching on with a 24h TTL.
    fn default() -> Self {
        Self {
            enable_deep_analysis: true,
            enable_pattern_recognition: true,
            enable_optimization_suggestions: true,
            enable_vulnerability_detection: true,
            enable_performance_prediction: true,
            max_analysis_time_secs: 30,
            confidence_threshold: 0.75,
            enable_caching: true,
            cache_expiration_hours: 24,
        }
    }
}
61
/// A memoized analysis result plus the metadata needed for TTL eviction.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedAnalysis {
    result: CodeAnalysisResult,
    /// Insertion time, compared against `cache_expiration_hours`.
    timestamp: std::time::SystemTime,
    /// Hash of the analyzed source (duplicate of the map key).
    code_hash: String,
}
69
/// Running counters for analyzer throughput and cache effectiveness.
#[derive(Debug)]
struct AnalysisPerformanceMonitor {
    analysis_count: u64,
    total_analysis_time: Duration,
    cache_hits: u64,
    cache_misses: u64,
}
78
79impl AnalysisPerformanceMonitor {
80 fn new() -> Self {
81 Self {
82 analysis_count: 0,
83 total_analysis_time: Duration::from_secs(0),
84 cache_hits: 0,
85 cache_misses: 0,
86 }
87 }
88
89 fn record_analysis(&mut self, duration: Duration, cache_hit: bool) {
90 self.analysis_count += 1;
91 self.total_analysis_time += duration;
92 if cache_hit {
93 self.cache_hits += 1;
94 } else {
95 self.cache_misses += 1;
96 }
97 }
98
99 fn average_analysis_time(&self) -> Duration {
100 if self.analysis_count > 0 {
101 self.total_analysis_time / self.analysis_count as u32
102 } else {
103 Duration::from_secs(0)
104 }
105 }
106
107 fn cache_hit_rate(&self) -> f64 {
108 let total = self.cache_hits + self.cache_misses;
109 if total > 0 {
110 self.cache_hits as f64 / total as f64
111 } else {
112 0.0
113 }
114 }
115}
116
117impl AICodeAnalyzer {
    /// Constructs an analyzer with the given config, an empty cache, the
    /// built-in pattern database, and zeroed performance counters.
    pub fn new(config: AIAnalysisConfig) -> Self {
        Self {
            config,
            analysis_cache: HashMap::new(),
            pattern_database: ModelPatternDatabase::new(),
            performance_monitor: AnalysisPerformanceMonitor::new(),
        }
    }
127
    /// Runs the full analysis pipeline over `code`.
    ///
    /// Checks the cache first (when enabled), then runs each stage that is
    /// switched on in the config, derives quality/confidence scores, caches
    /// the result, and records timing in the performance monitor.
    ///
    /// # Errors
    /// Propagates any error returned by an individual analysis stage.
    pub async fn analyze_model_code(
        &mut self,
        code: &str,
        context: ModelContext,
    ) -> Result<CodeAnalysisResult> {
        let start_time = Instant::now();
        let code_hash = self.compute_code_hash(code);

        // Fast path: serve a fresh cached result and record it as a hit.
        if self.config.enable_caching {
            if let Some(cached) = self.get_cached_analysis(&code_hash) {
                let result = cached.result.clone();
                self.performance_monitor.record_analysis(start_time.elapsed(), true);
                return Ok(result);
            }
        }

        info!(
            "Starting AI code analysis for {} lines of code",
            code.lines().count()
        );

        let mut result = CodeAnalysisResult::new();

        // Each stage below is independently toggled by the config.
        if self.config.enable_pattern_recognition {
            let patterns = self.detect_code_patterns(code, &context).await?;
            result.detected_patterns = patterns;
        }

        if self.config.enable_deep_analysis {
            let issues = self.perform_deep_analysis(code, &context).await?;
            result.identified_issues = issues;
        }

        if self.config.enable_optimization_suggestions {
            let optimizations = self.generate_optimization_suggestions(code, &context).await?;
            result.optimization_suggestions = optimizations;
        }

        if self.config.enable_vulnerability_detection {
            let vulnerabilities = self.detect_vulnerabilities(code, &context).await?;
            result.security_issues = vulnerabilities;
        }

        if self.config.enable_performance_prediction {
            let predictions = self.predict_performance_characteristics(code, &context).await?;
            result.performance_predictions = predictions;
        }

        // Derived scores and run metadata computed from the collected findings.
        result.quality_score = self.calculate_quality_score(&result);
        result.analysis_metadata = AnalysisMetadata {
            analysis_duration: start_time.elapsed(),
            confidence_score: self.calculate_confidence_score(&result),
            analyzer_version: "1.0.0".to_string(),
            timestamp: std::time::SystemTime::now(),
        };

        if self.config.enable_caching {
            self.cache_analysis(code_hash, &result);
        }

        self.performance_monitor.record_analysis(start_time.elapsed(), false);

        info!(
            "AI code analysis completed in {:?} with quality score: {:.2}",
            start_time.elapsed(),
            result.quality_score
        );

        Ok(result)
    }
207
    /// Builds an optimization report for a sequence of tensor operations:
    /// fusion, memory, and parallelization opportunities plus redundant
    /// operations, with aggregate speedup / memory-savings estimates.
    pub async fn analyze_tensor_operations(
        &self,
        operations: &[TensorOperation],
    ) -> Result<TensorOptimizationReport> {
        debug!("Analyzing {} tensor operations", operations.len());

        let mut report = TensorOptimizationReport::new();

        report.fusion_opportunities = self.detect_fusion_opportunities(operations).await?;
        report.memory_optimizations = self.detect_memory_optimizations(operations).await?;
        report.parallelization_opportunities =
            self.detect_parallelization_opportunities(operations).await?;
        report.redundant_operations = self.detect_redundant_operations(operations).await?;

        // Aggregates are derived from the opportunity lists filled in above.
        report.estimated_speedup = self.estimate_optimization_speedup(&report);
        report.estimated_memory_savings = self.estimate_memory_savings(&report);

        Ok(report)
    }
230
    /// Produces debugging guidance for a reported error: probable causes,
    /// suggested fixes, step-by-step checks, and documentation links, with
    /// an overall confidence score.
    pub async fn automated_debugging_assistance(
        &self,
        error_context: &ErrorContext,
    ) -> Result<DebuggingAssistance> {
        info!(
            "Providing automated debugging assistance for error: {}",
            error_context.error_type
        );

        let mut assistance = DebuggingAssistance::new();

        assistance.probable_causes = self.analyze_error_patterns(error_context).await?;
        assistance.suggested_fixes = self.generate_suggested_fixes(error_context).await?;
        assistance.debugging_steps = self.generate_debugging_steps(error_context).await?;
        assistance.related_documentation = self.find_related_documentation(error_context).await?;

        // Confidence is derived from the causes and fixes gathered above.
        assistance.confidence_score = self.calculate_debugging_confidence(&assistance);

        Ok(assistance)
    }
254
    /// Snapshot of analyzer throughput statistics and current cache size.
    pub fn get_performance_metrics(&self) -> AnalysisPerformanceMetrics {
        AnalysisPerformanceMetrics {
            total_analyses: self.performance_monitor.analysis_count,
            average_analysis_time: self.performance_monitor.average_analysis_time(),
            cache_hit_rate: self.performance_monitor.cache_hit_rate(),
            cached_results: self.analysis_cache.len(),
        }
    }
264
    /// Scans `code` for known good/anti patterns.
    ///
    /// Purely lexical: every check is a substring match, so confidence
    /// values are heuristic rather than derived from parsing.
    async fn detect_code_patterns(
        &self,
        code: &str,
        context: &ModelContext,
    ) -> Result<Vec<DetectedPattern>> {
        debug!("Detecting code patterns");

        let mut patterns = Vec::new();

        // Anti-pattern: CUDA cache clearing flagged only for production models.
        if code.contains("torch.cuda.empty_cache()") && context.model_type == ModelType::Production
        {
            patterns.push(DetectedPattern {
                pattern_type: PatternType::AntiPattern,
                name: "Frequent CUDA Cache Clearing".to_string(),
                description: "Frequent CUDA cache clearing can hurt performance".to_string(),
                severity: Severity::Medium,
                confidence: 0.85,
                recommendations: vec![
                    "Consider using gradient accumulation instead".to_string(),
                    "Review memory management strategy".to_string(),
                ],
            });
        }

        // Good pattern: gradient clipping keywords present together.
        if code.contains("grad_norm") && code.contains("clip") {
            patterns.push(DetectedPattern {
                pattern_type: PatternType::GoodPattern,
                name: "Gradient Clipping".to_string(),
                description: "Proper gradient clipping implementation detected".to_string(),
                severity: Severity::Info,
                confidence: 0.9,
                recommendations: vec!["Consider adaptive gradient clipping".to_string()],
            });
        }

        // Optimization opportunity: manual detach()/requires_grad juggling.
        if code.contains("detach()") && code.contains("requires_grad") {
            patterns.push(DetectedPattern {
                pattern_type: PatternType::OptimizationOpportunity,
                name: "Gradient Computation Inefficiency".to_string(),
                description: "Potential inefficient gradient computation detected".to_string(),
                severity: Severity::Medium,
                confidence: 0.75,
                recommendations: vec![
                    "Consider using torch.no_grad() context".to_string(),
                    "Review gradient requirements".to_string(),
                ],
            });
        }

        Ok(patterns)
    }
321
    /// Deep heuristic analysis of `code` for correctness-affecting issues
    /// (numerical stability, performance, memory leaks). Checks are
    /// substring-based; `_context` is currently unused.
    async fn perform_deep_analysis(
        &self,
        code: &str,
        _context: &ModelContext,
    ) -> Result<Vec<IdentifiedIssue>> {
        debug!("Performing deep AI analysis");

        let mut issues = Vec::new();

        // NOTE(review): fixed 100ms sleep — presumably simulating analysis
        // latency for a stubbed model backend; confirm before shipping.
        tokio::time::sleep(Duration::from_millis(100)).await;

        // log(softmax(x)) composition is numerically unstable.
        if code.contains("log") && !code.contains("log1p") && code.contains("softmax") {
            issues.push(IdentifiedIssue {
                issue_type: IssueType::NumericalStability,
                title: "Potential Numerical Instability in Log-Softmax".to_string(),
                description: "Using log(softmax(x)) can cause numerical instability. Consider using log_softmax directly.".to_string(),
                severity: Severity::High,
                confidence: 0.88,
                suggested_fix: "Replace log(softmax(x)) with log_softmax(x)".to_string(),
                code_location: None,
            });
        }

        // Naive attention (no flash-attention keyword present).
        if code.contains("attention") && code.contains("matmul") && !code.contains("flash") {
            issues.push(IdentifiedIssue {
                issue_type: IssueType::Performance,
                title: "Inefficient Attention Implementation".to_string(),
                description:
                    "Standard attention implementation may be inefficient for large sequences."
                        .to_string(),
                severity: Severity::Medium,
                confidence: 0.75,
                suggested_fix:
                    "Consider using Flash Attention or other optimized attention mechanisms"
                        .to_string(),
                code_location: None,
            });
        }

        // Gradient accumulation without an accompanying zero_grad call.
        if code.contains("accumulate") && !code.contains("zero_grad") {
            issues.push(IdentifiedIssue {
                issue_type: IssueType::MemoryLeak,
                title: "Potential Gradient Accumulation Memory Leak".to_string(),
                description: "Gradient accumulation without zero_grad() can cause memory leaks."
                    .to_string(),
                severity: Severity::High,
                confidence: 0.82,
                suggested_fix: "Ensure optimizer.zero_grad() is called appropriately".to_string(),
                code_location: None,
            });
        }

        Ok(issues)
    }
380
    /// Suggests optimizations keyed off the model context and the absence of
    /// relevant keywords in `code` (mixed precision, compilation, gradient
    /// checkpointing). Speedup/savings figures are heuristic estimates.
    async fn generate_optimization_suggestions(
        &self,
        code: &str,
        context: &ModelContext,
    ) -> Result<Vec<OptimizationSuggestion>> {
        debug!("Generating optimization suggestions");

        let mut suggestions = Vec::new();

        // Training model with no autocast usage -> suggest mixed precision.
        if context.model_type == ModelType::Training && !code.contains("autocast") {
            suggestions.push(OptimizationSuggestion {
                optimization_type: OptimizationType::MixedPrecision,
                title: "Enable Mixed Precision Training".to_string(),
                description: "Mixed precision training can significantly speed up training and reduce memory usage.".to_string(),
                potential_speedup: 1.5,
                memory_savings: 0.4,
                implementation_effort: ImplementationEffort::Low,
                confidence: 0.9,
                code_example: Some("with torch.autocast(device_type='cuda', dtype=torch.float16):".to_string()),
            });
        }

        // Production model with no compile call -> suggest torch.compile.
        if context.model_type == ModelType::Production && !code.contains("compile") {
            suggestions.push(OptimizationSuggestion {
                optimization_type: OptimizationType::ModelCompilation,
                title: "Enable Model Compilation".to_string(),
                description: "Model compilation can provide significant inference speedups."
                    .to_string(),
                potential_speedup: 2.0,
                memory_savings: 0.0,
                implementation_effort: ImplementationEffort::Low,
                confidence: 0.85,
                code_example: Some("model = torch.compile(model)".to_string()),
            });
        }

        // Models above 1B parameters without checkpointing -> suggest it.
        // Note: speedup 0.8 encodes the expected slowdown trade-off.
        if context.model_size > 1_000_000_000 && !code.contains("checkpoint") {
            suggestions.push(OptimizationSuggestion {
                optimization_type: OptimizationType::MemoryOptimization,
                title: "Enable Gradient Checkpointing".to_string(),
                description:
                    "Gradient checkpointing can significantly reduce memory usage for large models."
                        .to_string(),
                potential_speedup: 0.8,
                memory_savings: 0.6,
                implementation_effort: ImplementationEffort::Medium,
                confidence: 0.88,
                code_example: Some("torch.utils.checkpoint.checkpoint(layer, x)".to_string()),
            });
        }

        Ok(suggestions)
    }
437
    /// Flags security concerns via keyword heuristics: unsafe pickle loads,
    /// full state_dict saves in production, and missing input validation.
    async fn detect_vulnerabilities(
        &self,
        code: &str,
        context: &ModelContext,
    ) -> Result<Vec<SecurityIssue>> {
        debug!("Detecting security vulnerabilities");

        let mut vulnerabilities = Vec::new();

        // pickle.load executes arbitrary code from untrusted files (CWE-502).
        if code.contains("pickle.load") && !code.contains("safe_load") {
            vulnerabilities.push(SecurityIssue {
                vulnerability_type: VulnerabilityType::CodeExecution,
                title: "Unsafe Pickle Loading".to_string(),
                description:
                    "Loading pickle files can execute arbitrary code. Use safe alternatives."
                        .to_string(),
                severity: Severity::Critical,
                confidence: 0.95,
                mitigation: "Use torch.load with weights_only=True or safetensors".to_string(),
                cve_references: vec!["CWE-502".to_string()],
            });
        }

        // Saving full state in production may leak sensitive parameters.
        if code.contains("state_dict")
            && code.contains("save")
            && context.model_type == ModelType::Production
        {
            vulnerabilities.push(SecurityIssue {
                vulnerability_type: VulnerabilityType::DataExposure,
                title: "Potential Model Parameter Exposure".to_string(),
                description: "Saving full model state may expose sensitive parameters.".to_string(),
                severity: Severity::Medium,
                confidence: 0.7,
                mitigation: "Consider differential privacy or parameter encryption".to_string(),
                cve_references: vec![],
            });
        }

        // Input handling with no validate/sanitize keywords anywhere.
        if code.contains("input") && !code.contains("validate") && !code.contains("sanitize") {
            vulnerabilities.push(SecurityIssue {
                vulnerability_type: VulnerabilityType::InputValidation,
                title: "Missing Input Validation".to_string(),
                description: "Input validation is important for preventing adversarial attacks."
                    .to_string(),
                severity: Severity::Medium,
                confidence: 0.65,
                mitigation: "Implement input validation and sanitization".to_string(),
                cve_references: vec![],
            });
        }

        Ok(vulnerabilities)
    }
494
    /// Estimates memory, training time, inference latency, and scaling
    /// behavior from the code and model context.
    async fn predict_performance_characteristics(
        &self,
        code: &str,
        context: &ModelContext,
    ) -> Result<PerformancePredictions> {
        debug!("Predicting performance characteristics");

        // NOTE(review): fixed 50ms sleep — presumably simulating prediction
        // latency for a stubbed backend; confirm.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut predictions = PerformancePredictions::new();

        predictions.estimated_memory_usage = self.estimate_memory_usage(code, context);
        predictions.estimated_training_time = self.estimate_training_time(code, context);
        predictions.estimated_inference_latency = self.estimate_inference_latency(code, context);
        predictions.scaling_characteristics = self.predict_scaling_behavior(code, context);

        // NOTE(review): bottleneck list and confidence are hard-coded, not
        // derived from the code under analysis.
        predictions.predicted_bottlenecks = vec![
            "Attention computation may become bottleneck for long sequences".to_string(),
            "Memory bandwidth may limit performance for large batch sizes".to_string(),
        ];

        predictions.confidence_score = 0.75;

        Ok(predictions)
    }
523
524 async fn detect_fusion_opportunities(
525 &self,
526 operations: &[TensorOperation],
527 ) -> Result<Vec<FusionOpportunity>> {
528 let mut opportunities = Vec::new();
529
530 for window in operations.windows(2) {
532 if let [op1, op2] = window {
533 if matches!(op1.op_type, OperationType::MatMul)
534 && matches!(op2.op_type, OperationType::Add)
535 {
536 opportunities.push(FusionOpportunity {
537 operations: vec![op1.clone(), op2.clone()],
538 fusion_type: FusionType::GEMM,
539 estimated_speedup: 1.3,
540 description: "MatMul + Add can be fused into GEMM operation".to_string(),
541 });
542 }
543 }
544 }
545
546 for window in operations.windows(2) {
548 if let [op1, op2] = window {
549 if matches!(op1.op_type, OperationType::Linear)
550 && matches!(op2.op_type, OperationType::Activation)
551 {
552 opportunities.push(FusionOpportunity {
553 operations: vec![op1.clone(), op2.clone()],
554 fusion_type: FusionType::LinearActivation,
555 estimated_speedup: 1.2,
556 description: "Linear + Activation can be fused".to_string(),
557 });
558 }
559 }
560 }
561
562 Ok(opportunities)
563 }
564
    /// Finds memory-saving opportunities: operations that could run
    /// in-place, and tensors consumed exactly once (reuse candidates).
    async fn detect_memory_optimizations(
        &self,
        operations: &[TensorOperation],
    ) -> Result<Vec<MemoryOptimization>> {
        let mut optimizations = Vec::new();

        // In-place candidates: the op supports it but is not flagged in-place.
        for op in operations {
            if op.can_be_inplace() && !op.is_inplace {
                optimizations.push(MemoryOptimization {
                    operation: op.clone(),
                    optimization_type: MemoryOptimizationType::InPlace,
                    memory_savings: op.output_size_bytes,
                    description: format!("Operation {} can be performed in-place", op.name),
                });
            }
        }

        // Count how many operations consume each input tensor.
        let mut tensor_usage = HashMap::new();
        for op in operations {
            for input in &op.inputs {
                *tensor_usage.entry(input.clone()).or_insert(0) += 1;
            }
        }

        // Single-use tensors are reported as reuse candidates.
        // NOTE(review): iterating a HashMap makes the order of these entries
        // nondeterministic across runs; placeholder default op and 0 savings
        // suggest this branch is informational only — confirm.
        for (tensor, usage_count) in tensor_usage {
            if usage_count == 1 {
                optimizations.push(MemoryOptimization {
                    operation: TensorOperation::default(),
                    optimization_type: MemoryOptimizationType::TensorReuse,
                    memory_savings: 0,
                    description: format!("Tensor {} can be reused", tensor),
                });
            }
        }

        Ok(optimizations)
    }
604
    /// Reports every pair of operations with no data dependency in either
    /// direction as parallelizable.
    ///
    /// NOTE(review): pairwise comparison is O(n^2) in the number of
    /// operations, and output can be quadratic too.
    async fn detect_parallelization_opportunities(
        &self,
        operations: &[TensorOperation],
    ) -> Result<Vec<ParallelizationOpportunity>> {
        let mut opportunities = Vec::new();

        // Compare each operation only with later ones so a pair appears once.
        for (i, op1) in operations.iter().enumerate() {
            for op2 in operations.iter().skip(i + 1) {
                if self.operations_are_independent(op1, op2) {
                    opportunities.push(ParallelizationOpportunity {
                        operations: vec![op1.clone(), op2.clone()],
                        parallelization_type: ParallelizationType::DataParallel,
                        estimated_speedup: 1.8,
                        description: "Operations can run in parallel".to_string(),
                    });
                }
            }
        }

        Ok(opportunities)
    }
627
628 async fn detect_redundant_operations(
629 &self,
630 operations: &[TensorOperation],
631 ) -> Result<Vec<RedundantOperation>> {
632 let mut redundant = Vec::new();
633
634 for (i, op1) in operations.iter().enumerate() {
636 for (_j, op2) in operations.iter().enumerate().skip(i + 1) {
637 if self.operations_are_equivalent(op1, op2) {
638 redundant.push(RedundantOperation {
639 original_operation: op1.clone(),
640 redundant_operation: op2.clone(),
641 redundancy_type: RedundancyType::Duplicate,
642 description: "Operations produce identical results".to_string(),
643 });
644 }
645 }
646 }
647
648 Ok(redundant)
649 }
650
651 fn operations_are_independent(&self, op1: &TensorOperation, op2: &TensorOperation) -> bool {
654 for input1 in &op1.inputs {
656 for output2 in &op2.outputs {
657 if input1 == output2 {
658 return false;
659 }
660 }
661 }
662 for input2 in &op2.inputs {
663 for output1 in &op1.outputs {
664 if input2 == output1 {
665 return false;
666 }
667 }
668 }
669 true
670 }
671
    /// True when two operations are interchangeable: same operation type,
    /// same inputs, and same parameters. Outputs are deliberately ignored.
    fn operations_are_equivalent(&self, op1: &TensorOperation, op2: &TensorOperation) -> bool {
        op1.op_type == op2.op_type && op1.inputs == op2.inputs && op1.parameters == op2.parameters
    }
675
676 fn compute_code_hash(&self, code: &str) -> String {
677 use std::collections::hash_map::DefaultHasher;
678 use std::hash::{Hash, Hasher};
679
680 let mut hasher = DefaultHasher::new();
681 code.hash(&mut hasher);
682 format!("{:x}", hasher.finish())
683 }
684
685 fn get_cached_analysis(&self, code_hash: &str) -> Option<&CachedAnalysis> {
686 self.analysis_cache.get(code_hash).and_then(|cached| {
687 let age = std::time::SystemTime::now()
688 .duration_since(cached.timestamp)
689 .unwrap_or_default();
690
691 if age.as_secs() < self.config.cache_expiration_hours * 3600 {
692 Some(cached)
693 } else {
694 None
695 }
696 })
697 }
698
699 fn cache_analysis(&mut self, code_hash: String, result: &CodeAnalysisResult) {
700 self.analysis_cache.insert(
701 code_hash.clone(),
702 CachedAnalysis {
703 result: result.clone(),
704 timestamp: std::time::SystemTime::now(),
705 code_hash,
706 },
707 );
708 }
709
710 fn calculate_quality_score(&self, result: &CodeAnalysisResult) -> f64 {
711 let mut score: f64 = 100.0;
712
713 for issue in &result.identified_issues {
715 match issue.severity {
716 Severity::Critical => score -= 20.0,
717 Severity::High => score -= 10.0,
718 Severity::Medium => score -= 5.0,
719 Severity::Low => score -= 2.0,
720 Severity::Info => score -= 0.0,
721 }
722 }
723
724 for vulnerability in &result.security_issues {
726 match vulnerability.severity {
727 Severity::Critical => score -= 25.0,
728 Severity::High => score -= 15.0,
729 Severity::Medium => score -= 8.0,
730 Severity::Low => score -= 3.0,
731 Severity::Info => score -= 0.0,
732 }
733 }
734
735 for pattern in &result.detected_patterns {
737 if pattern.pattern_type == PatternType::GoodPattern {
738 score += 2.0;
739 }
740 }
741
742 score.max(0.0).min(100.0)
743 }
744
745 fn calculate_confidence_score(&self, result: &CodeAnalysisResult) -> f64 {
746 let mut total_confidence = 0.0;
747 let mut count = 0;
748
749 for issue in &result.identified_issues {
750 total_confidence += issue.confidence;
751 count += 1;
752 }
753
754 for pattern in &result.detected_patterns {
755 total_confidence += pattern.confidence;
756 count += 1;
757 }
758
759 if count > 0 {
760 total_confidence / count as f64
761 } else {
762 1.0
763 }
764 }
765
766 fn estimate_memory_usage(&self, code: &str, context: &ModelContext) -> f64 {
767 let base_memory = context.model_size as f64 * 4.0; let mut multiplier = 1.0;
771 if code.contains("gradient_accumulation") {
772 multiplier += 0.5;
773 }
774 if code.contains("mixed_precision") {
775 multiplier *= 0.7;
776 }
777
778 base_memory * multiplier / 1_000_000.0 }
780
781 fn estimate_training_time(&self, code: &str, context: &ModelContext) -> f64 {
782 let base_time = (context.model_size as f64).log10() * 10.0;
784
785 let mut multiplier = 1.0;
786 if code.contains("mixed_precision") {
787 multiplier *= 0.6;
788 }
789 if code.contains("gradient_checkpointing") {
790 multiplier *= 1.3;
791 }
792
793 base_time * multiplier
794 }
795
796 fn estimate_inference_latency(&self, code: &str, context: &ModelContext) -> f64 {
797 let base_latency = (context.model_size as f64).log10() * 5.0;
799
800 let mut multiplier = 1.0;
801 if code.contains("compile") {
802 multiplier *= 0.5;
803 }
804 if code.contains("quantization") {
805 multiplier *= 0.7;
806 }
807
808 base_latency * multiplier
809 }
810
811 fn predict_scaling_behavior(
812 &self,
813 _code: &str,
814 context: &ModelContext,
815 ) -> ScalingCharacteristics {
816 ScalingCharacteristics {
817 batch_size_scaling: if context.model_size > 1_000_000_000 {
818 ScalingBehavior::Sublinear
819 } else {
820 ScalingBehavior::Linear
821 },
822 sequence_length_scaling: ScalingBehavior::Quadratic, model_size_scaling: ScalingBehavior::Linear,
824 memory_scaling: ScalingBehavior::Linear,
825 }
826 }
827
828 fn estimate_optimization_speedup(&self, report: &TensorOptimizationReport) -> f64 {
829 let mut speedup = 1.0;
830
831 for fusion in &report.fusion_opportunities {
832 speedup *= fusion.estimated_speedup;
833 }
834
835 for parallel in &report.parallelization_opportunities {
836 speedup *= parallel.estimated_speedup;
837 }
838
839 speedup.min(10.0) }
841
842 fn estimate_memory_savings(&self, report: &TensorOptimizationReport) -> f64 {
843 let total_savings: u64 =
844 report.memory_optimizations.iter().map(|opt| opt.memory_savings).sum();
845
846 total_savings as f64 / 1_000_000.0 }
848
    /// Maps a known error type to a ranked list of probable causes; unknown
    /// error types get a single low-probability placeholder cause.
    async fn analyze_error_patterns(
        &self,
        error_context: &ErrorContext,
    ) -> Result<Vec<ProbableCause>> {
        let mut causes = Vec::new();

        match error_context.error_type.as_str() {
            "OutOfMemoryError" => {
                causes.push(ProbableCause {
                    cause: "Batch size too large".to_string(),
                    probability: 0.8,
                    evidence: vec!["GPU memory limit exceeded".to_string()],
                });
                causes.push(ProbableCause {
                    cause: "Model too large for available memory".to_string(),
                    probability: 0.6,
                    evidence: vec!["Model parameter count".to_string()],
                });
            },
            "GradientExplosion" => {
                causes.push(ProbableCause {
                    cause: "Learning rate too high".to_string(),
                    probability: 0.7,
                    evidence: vec!["Gradient norm increasing rapidly".to_string()],
                });
            },
            _ => {
                causes.push(ProbableCause {
                    cause: "Unknown error pattern".to_string(),
                    probability: 0.3,
                    evidence: vec![],
                });
            },
        }

        Ok(causes)
    }
886
    /// Maps a known error type to concrete code-level fixes; unknown error
    /// types yield an empty list.
    async fn generate_suggested_fixes(
        &self,
        error_context: &ErrorContext,
    ) -> Result<Vec<SuggestedFix>> {
        let mut fixes = Vec::new();

        match error_context.error_type.as_str() {
            "OutOfMemoryError" => {
                fixes.push(SuggestedFix {
                    description: "Reduce batch size".to_string(),
                    implementation: "batch_size = batch_size // 2".to_string(),
                    confidence: 0.9,
                    estimated_impact: "Should free ~50% of memory".to_string(),
                });
                fixes.push(SuggestedFix {
                    description: "Enable gradient checkpointing".to_string(),
                    implementation: "model.gradient_checkpointing_enable()".to_string(),
                    confidence: 0.8,
                    estimated_impact: "Reduces memory by ~40% with 10-20% speed penalty"
                        .to_string(),
                });
            },
            "GradientExplosion" => {
                fixes.push(SuggestedFix {
                    description: "Add gradient clipping".to_string(),
                    implementation:
                        "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)"
                            .to_string(),
                    confidence: 0.95,
                    estimated_impact: "Prevents gradient explosion".to_string(),
                });
            },
            _ => {},
        }

        Ok(fixes)
    }
924
    /// Builds an ordered checklist of debugging steps: two generic steps for
    /// every error, plus an OOM-specific shape/batch-size check.
    async fn generate_debugging_steps(
        &self,
        error_context: &ErrorContext,
    ) -> Result<Vec<DebuggingStep>> {
        let mut steps = Vec::new();

        steps.push(DebuggingStep {
            step_number: 1,
            description: "Check system resources".to_string(),
            command: Some("nvidia-smi".to_string()),
            expected_output: "GPU memory usage and availability".to_string(),
        });

        steps.push(DebuggingStep {
            step_number: 2,
            description: "Verify model configuration".to_string(),
            command: Some("print(model)".to_string()),
            expected_output: "Model architecture and parameter count".to_string(),
        });

        // Extra step only for out-of-memory errors.
        if error_context.error_type.as_str() == "OutOfMemoryError" {
            steps.push(DebuggingStep {
                step_number: 3,
                description: "Check tensor shapes and batch size".to_string(),
                command: Some(
                    "print(f'Batch size: {batch_size}, Input shape: {input.shape}')".to_string(),
                ),
                expected_output: "Current batch size and input dimensions".to_string(),
            });
        }

        Ok(steps)
    }
958
    /// Returns documentation links relevant to a known error type, ranked by
    /// relevance; unknown error types yield an empty list.
    async fn find_related_documentation(
        &self,
        error_context: &ErrorContext,
    ) -> Result<Vec<DocumentationReference>> {
        let mut references = Vec::new();

        match error_context.error_type.as_str() {
            "OutOfMemoryError" => {
                references.push(DocumentationReference {
                    title: "Memory Management Best Practices".to_string(),
                    url: "https://docs.trustformers.ai/memory-management".to_string(),
                    relevance_score: 0.95,
                });
                references.push(DocumentationReference {
                    title: "Gradient Checkpointing Guide".to_string(),
                    url: "https://docs.trustformers.ai/gradient-checkpointing".to_string(),
                    relevance_score: 0.8,
                });
            },
            "GradientExplosion" => {
                references.push(DocumentationReference {
                    title: "Training Stability Guide".to_string(),
                    url: "https://docs.trustformers.ai/training-stability".to_string(),
                    relevance_score: 0.9,
                });
            },
            _ => {},
        }

        Ok(references)
    }
990
991 fn calculate_debugging_confidence(&self, assistance: &DebuggingAssistance) -> f64 {
992 let avg_cause_probability =
993 assistance.probable_causes.iter().map(|cause| cause.probability).sum::<f64>()
994 / assistance.probable_causes.len().max(1) as f64;
995
996 let avg_fix_confidence =
997 assistance.suggested_fixes.iter().map(|fix| fix.confidence).sum::<f64>()
998 / assistance.suggested_fixes.len().max(1) as f64;
999
1000 (avg_cause_probability + avg_fix_confidence) / 2.0
1001 }
1002}
1003
/// In-memory catalog of known model-code patterns, keyed by pattern id.
#[derive(Debug)]
struct ModelPatternDatabase {
    #[allow(dead_code)]
    patterns: HashMap<String, PatternDefinition>,
}
1012
impl ModelPatternDatabase {
    /// Builds the database pre-seeded with the built-in pattern definitions.
    fn new() -> Self {
        let mut patterns = HashMap::new();

        // Currently the only seeded entry: gradient clipping as a good practice.
        patterns.insert(
            "gradient_clipping".to_string(),
            PatternDefinition {
                name: "Gradient Clipping".to_string(),
                pattern_type: PatternType::GoodPattern,
                keywords: vec![
                    "clip_grad_norm".to_string(),
                    "gradient".to_string(),
                    "clip".to_string(),
                ],
                severity: Severity::Info,
                description: "Proper gradient clipping prevents gradient explosion".to_string(),
            },
        );

        Self { patterns }
    }
}
1036
/// Static description of a recognizable code pattern, as stored in
/// `ModelPatternDatabase`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct PatternDefinition {
    #[allow(dead_code)]
    name: String,
    pattern_type: PatternType,
    /// Substrings whose presence indicates the pattern.
    keywords: Vec<String>,
    severity: Severity,
    description: String,
}
1047
/// Caller-supplied context describing the model whose code is analyzed.
#[derive(Debug, Clone)]
pub struct ModelContext {
    pub model_type: ModelType,
    /// NOTE(review): treated as a parameter count by the estimators
    /// (multiplied by 4 bytes each in `estimate_memory_usage`) — confirm unit.
    pub model_size: u64,
    pub framework: String,
    pub target_hardware: String,
    pub training_stage: TrainingStage,
}
1057
/// Deployment/usage mode of the analyzed model; several heuristics key off
/// `Training` and `Production`.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelType {
    Training,
    Inference,
    Production,
    Development,
}
1065
/// Lifecycle stage of the model.
/// NOTE(review): not consulted by any logic in this file.
#[derive(Debug, Clone)]
pub enum TrainingStage {
    Training,
    Development,
    Pretraining,
    Finetuning,
    Evaluation,
    Inference,
}
1075
/// Aggregate output of `analyze_model_code`: findings from every enabled
/// stage plus derived scores and run metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeAnalysisResult {
    /// Overall quality in [0, 100]; see `calculate_quality_score`.
    pub quality_score: f64,
    pub detected_patterns: Vec<DetectedPattern>,
    pub identified_issues: Vec<IdentifiedIssue>,
    pub optimization_suggestions: Vec<OptimizationSuggestion>,
    pub security_issues: Vec<SecurityIssue>,
    pub performance_predictions: PerformancePredictions,
    pub analysis_metadata: AnalysisMetadata,
}
1087
impl CodeAnalysisResult {
    /// Creates an empty result; analysis stages fill the fields in later.
    fn new() -> Self {
        Self {
            quality_score: 0.0,
            detected_patterns: Vec::new(),
            identified_issues: Vec::new(),
            optimization_suggestions: Vec::new(),
            security_issues: Vec::new(),
            performance_predictions: PerformancePredictions::new(),
            analysis_metadata: AnalysisMetadata::default(),
        }
    }
}
1101
/// A pattern match found in the analyzed code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedPattern {
    pub pattern_type: PatternType,
    pub name: String,
    pub description: String,
    pub severity: Severity,
    /// Heuristic confidence in [0.0, 1.0].
    pub confidence: f64,
    pub recommendations: Vec<String>,
}
1111
/// Classification of a detected pattern; `GoodPattern` earns a quality-score
/// bonus in `calculate_quality_score`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PatternType {
    GoodPattern,
    AntiPattern,
    OptimizationOpportunity,
    SecurityConcern,
}
1119
/// A concrete problem found during deep analysis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdentifiedIssue {
    pub issue_type: IssueType,
    pub title: String,
    pub description: String,
    pub severity: Severity,
    /// Heuristic confidence in [0.0, 1.0].
    pub confidence: f64,
    pub suggested_fix: String,
    /// Currently always `None` — detections are lexical, not positional.
    pub code_location: Option<CodeLocation>,
}
1130
/// Category of an identified issue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IssueType {
    NumericalStability,
    Performance,
    MemoryLeak,
    LogicError,
    TypeMismatch,
    ResourceLeak,
}
1140
/// Source position of a finding (file, line, column).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeLocation {
    pub file: String,
    pub line: u32,
    pub column: u32,
}
1147
/// A recommended optimization with heuristic impact estimates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSuggestion {
    pub optimization_type: OptimizationType,
    pub title: String,
    pub description: String,
    /// Multiplicative speedup estimate (values below 1.0 mean a slowdown
    /// traded for other benefits, e.g. gradient checkpointing).
    pub potential_speedup: f64,
    /// Fractional memory reduction estimate in [0.0, 1.0].
    pub memory_savings: f64,
    pub implementation_effort: ImplementationEffort,
    pub confidence: f64,
    pub code_example: Option<String>,
}
1159
/// Classes of optimizations the analyzer can suggest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationType {
    MixedPrecision,
    ModelCompilation,
    MemoryOptimization,
    ComputationOptimization,
    IOOptimization,
    ParallelizationOptimization,
}
1169
/// Rough effort estimate for applying an [`OptimizationSuggestion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImplementationEffort {
    Low,
    Medium,
    High,
}
1176
/// A potential security vulnerability detected in the analyzed code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityIssue {
    /// Vulnerability class.
    pub vulnerability_type: VulnerabilityType,
    /// Short title.
    pub title: String,
    /// Detailed description of the vulnerability.
    pub description: String,
    /// Severity of the vulnerability.
    pub severity: Severity,
    /// Confidence in the finding (assumed 0.0–1.0).
    pub confidence: f64,
    /// Recommended mitigation, as free text.
    pub mitigation: String,
    /// Related CVE identifiers, if any.
    pub cve_references: Vec<String>,
}
1187
1188#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1189pub enum VulnerabilityType {
1190 CodeExecution,
1191 DataExposure,
1192 InputValidation,
1193 AuthenticationBypass,
1194 PrivilegeEscalation,
1195}
1196
1197#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1198pub enum Severity {
1199 Critical,
1200 High,
1201 Medium,
1202 Low,
1203 Info,
1204}
1205
1206#[derive(Debug, Clone, Serialize, Deserialize)]
1207pub struct PerformancePredictions {
1208 pub estimated_memory_usage: f64, pub estimated_training_time: f64, pub estimated_inference_latency: f64, pub scaling_characteristics: ScalingCharacteristics,
1212 pub predicted_bottlenecks: Vec<String>,
1213 pub confidence_score: f64,
1214}
1215
1216impl PerformancePredictions {
1217 fn new() -> Self {
1218 Self {
1219 estimated_memory_usage: 0.0,
1220 estimated_training_time: 0.0,
1221 estimated_inference_latency: 0.0,
1222 scaling_characteristics: ScalingCharacteristics::default(),
1223 predicted_bottlenecks: Vec::new(),
1224 confidence_score: 0.0,
1225 }
1226 }
1227}
1228
/// Expected cost-scaling behavior along common model dimensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalingCharacteristics {
    /// Scaling with batch size.
    pub batch_size_scaling: ScalingBehavior,
    /// Scaling with input sequence length.
    pub sequence_length_scaling: ScalingBehavior,
    /// Scaling with model size.
    pub model_size_scaling: ScalingBehavior,
    /// Scaling of memory consumption.
    pub memory_scaling: ScalingBehavior,
}
1236
1237impl Default for ScalingCharacteristics {
1238 fn default() -> Self {
1239 Self {
1240 batch_size_scaling: ScalingBehavior::Linear,
1241 sequence_length_scaling: ScalingBehavior::Linear,
1242 model_size_scaling: ScalingBehavior::Linear,
1243 memory_scaling: ScalingBehavior::Linear,
1244 }
1245 }
1246}
1247
/// Asymptotic growth classes used in scaling predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScalingBehavior {
    Constant,
    Linear,
    Quadratic,
    Exponential,
    Sublinear,
}
1256
/// Bookkeeping attached to every analysis result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisMetadata {
    /// Wall-clock time the analysis took.
    pub analysis_duration: Duration,
    /// Overall confidence in the result (assumed 0.0–1.0).
    pub confidence_score: f64,
    /// Version string of the analyzer that produced the result.
    pub analyzer_version: String,
    /// When the analysis was performed.
    pub timestamp: std::time::SystemTime,
}
1264
1265impl Default for AnalysisMetadata {
1266 fn default() -> Self {
1267 Self {
1268 analysis_duration: Duration::from_secs(0),
1269 confidence_score: 0.0,
1270 analyzer_version: "1.0.0".to_string(),
1271 timestamp: std::time::SystemTime::now(),
1272 }
1273 }
1274}
1275
/// A single node in a tensor computation graph, as consumed by the
/// tensor-operation analysis passes.
#[derive(Debug, Clone)]
pub struct TensorOperation {
    /// Name of the op within the graph.
    pub name: String,
    /// Kind of computation performed.
    pub op_type: OperationType,
    /// Names of input tensors.
    pub inputs: Vec<String>,
    /// Names of output tensors.
    pub outputs: Vec<String>,
    /// Free-form op attributes as string key/value pairs.
    pub parameters: HashMap<String, String>,
    /// Size of the produced output in bytes.
    pub output_size_bytes: u64,
    /// Whether the op already writes its result in place.
    pub is_inplace: bool,
}

impl Default for TensorOperation {
    /// An empty, unnamed op of type [`OperationType::Unknown`].
    fn default() -> Self {
        Self {
            name: String::new(),
            op_type: OperationType::Unknown,
            inputs: Vec::new(),
            outputs: Vec::new(),
            parameters: HashMap::new(),
            output_size_bytes: 0,
            is_inplace: false,
        }
    }
}

impl TensorOperation {
    /// Returns true for op kinds (`Add`, `Mul`, `Activation`) that are
    /// candidates for in-place rewriting, i.e. could reuse an input buffer
    /// for their output.
    fn can_be_inplace(&self) -> bool {
        matches!(
            self.op_type,
            OperationType::Add | OperationType::Mul | OperationType::Activation
        )
    }
}

/// Tensor operation kinds recognized by the analyzer.
///
/// Unit-only enum, so `Copy`, `Eq`, and `Hash` are derived alongside
/// `PartialEq` (free to add, enables map/set keys, and silences clippy's
/// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    MatMul,
    Add,
    Mul,
    Conv2D,
    Linear,
    Activation,
    Pooling,
    BatchNorm,
    LayerNorm,
    Attention,
    /// Fallback for ops the analyzer does not recognize.
    Unknown,
}
1326
/// Aggregated findings from analyzing a sequence of tensor operations.
#[derive(Debug, Clone)]
pub struct TensorOptimizationReport {
    /// Groups of ops that could be fused into a single kernel.
    pub fusion_opportunities: Vec<FusionOpportunity>,
    /// Per-op memory-saving opportunities.
    pub memory_optimizations: Vec<MemoryOptimization>,
    /// Groups of ops that could run in parallel.
    pub parallelization_opportunities: Vec<ParallelizationOpportunity>,
    /// Ops whose work duplicates or is subsumed by other ops.
    pub redundant_operations: Vec<RedundantOperation>,
    /// Combined estimated speedup (1.0 = no change; see `new`).
    pub estimated_speedup: f64,
    /// Combined estimated memory savings (units not specified in this file).
    pub estimated_memory_savings: f64,
}
1336
1337impl TensorOptimizationReport {
1338 fn new() -> Self {
1339 Self {
1340 fusion_opportunities: Vec::new(),
1341 memory_optimizations: Vec::new(),
1342 parallelization_opportunities: Vec::new(),
1343 redundant_operations: Vec::new(),
1344 estimated_speedup: 1.0,
1345 estimated_memory_savings: 0.0,
1346 }
1347 }
1348}
1349
/// A group of operations that could be fused into a single kernel.
#[derive(Debug, Clone)]
pub struct FusionOpportunity {
    /// The ops participating in the fusion.
    pub operations: Vec<TensorOperation>,
    /// Which fusion pattern applies.
    pub fusion_type: FusionType,
    /// Estimated speedup from fusing.
    pub estimated_speedup: f64,
    /// Human-readable explanation.
    pub description: String,
}
1357
1358#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1359pub enum FusionType {
1360 GEMM,
1361 LinearActivation,
1362 ConvBatchNorm,
1363 AttentionQKV,
1364}
1365
/// A memory-saving opportunity for a single operation.
#[derive(Debug, Clone)]
pub struct MemoryOptimization {
    /// The operation the optimization applies to.
    pub operation: TensorOperation,
    /// Which technique is applicable.
    pub optimization_type: MemoryOptimizationType,
    /// Estimated memory saved (presumably bytes, matching
    /// `TensorOperation::output_size_bytes` — confirm).
    pub memory_savings: u64,
    /// Human-readable explanation.
    pub description: String,
}
1373
/// Memory-optimization techniques the analyzer can recommend.
#[derive(Debug, Clone)]
pub enum MemoryOptimizationType {
    /// Write results in place instead of allocating a new output.
    InPlace,
    /// Reuse an existing tensor buffer.
    TensorReuse,
    /// Allocate from a shared memory pool.
    MemoryPool,
    /// Gradient checkpointing (recompute instead of store).
    GradientCheckpointing,
}
1381
/// A group of operations that could be executed in parallel.
#[derive(Debug, Clone)]
pub struct ParallelizationOpportunity {
    /// The ops that could be parallelized.
    pub operations: Vec<TensorOperation>,
    /// Which parallelization strategy applies.
    pub parallelization_type: ParallelizationType,
    /// Estimated speedup from parallelizing.
    pub estimated_speedup: f64,
    /// Human-readable explanation.
    pub description: String,
}
1389
/// Parallelization strategies the analyzer can recommend.
#[derive(Debug, Clone)]
pub enum ParallelizationType {
    DataParallel,
    ModelParallel,
    PipelineParallel,
    TensorParallel,
}
1397
/// A pair of operations where one makes the other unnecessary.
#[derive(Debug, Clone)]
pub struct RedundantOperation {
    /// The op whose result makes the other redundant.
    pub original_operation: TensorOperation,
    /// The op that could be removed.
    pub redundant_operation: TensorOperation,
    /// How the redundancy arises.
    pub redundancy_type: RedundancyType,
    /// Human-readable explanation.
    pub description: String,
}
1405
/// Ways in which an operation can be redundant.
#[derive(Debug, Clone)]
pub enum RedundancyType {
    /// Exact duplicate of another op.
    Duplicate,
    /// Its result is already covered by another op.
    Subsumed,
    /// Not needed at all.
    Unnecessary,
}
1412
/// Snapshot of a runtime failure, used as input to the debugging assistant.
#[derive(Debug, Clone)]
pub struct ErrorContext {
    /// Error class name (e.g. "OutOfMemoryError" in the tests).
    pub error_type: String,
    /// The error message text.
    pub error_message: String,
    /// Captured stack trace, if any.
    pub stack_trace: Option<String>,
    /// Hardware/OS state at the time of the error.
    pub system_info: SystemInfo,
    /// Information about the model involved, if available.
    pub model_info: Option<ModelContext>,
}
1423
/// Host resource usage at the time an error occurred.
///
/// NOTE(review): memory fields appear to be in bytes given the magnitudes
/// used in tests (e.g. 8_000_000_000 for an 8 GB GPU) — confirm with the
/// code that populates this struct.
#[derive(Debug, Clone)]
pub struct SystemInfo {
    pub gpu_memory_total: u64,
    pub gpu_memory_used: u64,
    pub cpu_count: u32,
    pub ram_total: u64,
    pub ram_used: u64,
}
1432
/// Output of the automated debugging assistant.
#[derive(Debug, Clone)]
pub struct DebuggingAssistance {
    /// Likely root causes, with probabilities and supporting evidence.
    pub probable_causes: Vec<ProbableCause>,
    /// Concrete fixes to try.
    pub suggested_fixes: Vec<SuggestedFix>,
    /// Step-by-step diagnostic procedure.
    pub debugging_steps: Vec<DebuggingStep>,
    /// Pointers to relevant external documentation.
    pub related_documentation: Vec<DocumentationReference>,
    /// Overall confidence in the assistance (assumed 0.0–1.0).
    pub confidence_score: f64,
}
1441
1442impl DebuggingAssistance {
1443 fn new() -> Self {
1444 Self {
1445 probable_causes: Vec::new(),
1446 suggested_fixes: Vec::new(),
1447 debugging_steps: Vec::new(),
1448 related_documentation: Vec::new(),
1449 confidence_score: 0.0,
1450 }
1451 }
1452}
1453
/// A candidate root cause for an observed error.
#[derive(Debug, Clone)]
pub struct ProbableCause {
    /// Description of the cause.
    pub cause: String,
    /// Estimated probability (assumed 0.0–1.0).
    pub probability: f64,
    /// Observations supporting this hypothesis.
    pub evidence: Vec<String>,
}
1460
/// A concrete remediation the user can apply.
#[derive(Debug, Clone)]
pub struct SuggestedFix {
    /// What the fix does.
    pub description: String,
    /// How to apply it (free text).
    pub implementation: String,
    /// Confidence that the fix resolves the error (assumed 0.0–1.0).
    pub confidence: f64,
    /// Expected effect of applying the fix, as free text.
    pub estimated_impact: String,
}
1468
/// One step in a guided debugging procedure.
#[derive(Debug, Clone)]
pub struct DebuggingStep {
    /// Position in the procedure (presumably 1-based — confirm with producer).
    pub step_number: u32,
    /// What to do in this step.
    pub description: String,
    /// Command to run, if applicable.
    pub command: Option<String>,
    /// What the user should observe if the step succeeds.
    pub expected_output: String,
}
1476
/// A link to external documentation relevant to the problem.
#[derive(Debug, Clone)]
pub struct DocumentationReference {
    pub title: String,
    pub url: String,
    /// Relevance ranking score (assumed 0.0–1.0).
    pub relevance_score: f64,
}
1483
/// Aggregate runtime statistics for the analyzer itself.
#[derive(Debug, Serialize, Deserialize)]
pub struct AnalysisPerformanceMetrics {
    /// Total analyses performed (the monitor counts cache hits too —
    /// see `AnalysisPerformanceMonitor::record_analysis`).
    pub total_analyses: u64,
    /// Mean wall-clock duration per analysis.
    pub average_analysis_time: Duration,
    /// Fraction of analyses served from the cache.
    pub cache_hit_rate: f64,
    /// Number of entries currently held in the analysis cache.
    pub cached_results: usize,
}
1493
/// Convenience macro: builds a throwaway `AICodeAnalyzer` with the default
/// `AIAnalysisConfig` and analyzes `$code` under `$context`.
///
/// Expands to an expression containing `.await`, so it can only be invoked
/// inside an async context. Each invocation creates a fresh analyzer, so the
/// internal analysis cache is NOT shared between calls.
///
/// NOTE(review): the expansion uses unqualified `AICodeAnalyzer` /
/// `AIAnalysisConfig` rather than `$crate::…` paths, so callers of this
/// exported macro must import those types themselves — confirm intended.
#[macro_export]
macro_rules! ai_analyze {
    ($code:expr, $context:expr) => {{
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());
        analyzer.analyze_model_code($code, $context).await
    }};
}
1502
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_ai_code_analyzer_creation() {
        // A default-configured analyzer should come up with deep analysis enabled.
        let analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());
        assert!(analyzer.config.enable_deep_analysis);
    }

    #[tokio::test]
    async fn test_pattern_detection() {
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        // One known anti-pattern (empty_cache) and one good pattern (grad clipping).
        let code = r#"
import torch

def train_step(model, data):
    torch.cuda.empty_cache() # Should trigger anti-pattern
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Good pattern
    return grad_norm
"#;

        let context = ModelContext {
            model_type: ModelType::Production,
            model_size: 1_000_000,
            framework: "PyTorch".to_string(),
            target_hardware: "CUDA".to_string(),
            training_stage: TrainingStage::Training,
        };

        let result = analyzer
            .analyze_model_code(code, context)
            .await
            .expect("async operation failed");
        assert!(!result.detected_patterns.is_empty());
    }

    #[tokio::test]
    async fn test_security_vulnerability_detection() {
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        // Unpickling untrusted data should be flagged as a code-execution risk.
        let code = r#"
import pickle

def load_model(path):
    with open(path, 'rb') as f:
        model = pickle.load(f) # Should trigger security warning
    return model
"#;

        let context = ModelContext {
            model_type: ModelType::Production,
            model_size: 1_000_000,
            framework: "PyTorch".to_string(),
            target_hardware: "CUDA".to_string(),
            training_stage: TrainingStage::Inference,
        };

        let result = analyzer
            .analyze_model_code(code, context)
            .await
            .expect("async operation failed");
        assert!(!result.security_issues.is_empty());
        assert_eq!(
            result.security_issues[0].vulnerability_type,
            VulnerabilityType::CodeExecution
        );
    }

    #[tokio::test]
    async fn test_tensor_operation_analysis() {
        let analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        // A MatMul feeding an Add is the classic GEMM fusion candidate.
        let operations = vec![
            TensorOperation {
                name: "matmul1".to_string(),
                op_type: OperationType::MatMul,
                inputs: vec!["a".to_string(), "b".to_string()],
                outputs: vec!["c".to_string()],
                parameters: HashMap::new(),
                output_size_bytes: 1024,
                is_inplace: false,
            },
            TensorOperation {
                name: "add1".to_string(),
                op_type: OperationType::Add,
                inputs: vec!["c".to_string(), "bias".to_string()],
                outputs: vec!["d".to_string()],
                parameters: HashMap::new(),
                output_size_bytes: 1024,
                is_inplace: false,
            },
        ];

        let report = analyzer
            .analyze_tensor_operations(&operations)
            .await
            .expect("tensor operation failed");
        assert!(!report.fusion_opportunities.is_empty());
        assert_eq!(report.fusion_opportunities[0].fusion_type, FusionType::GEMM);
    }

    #[tokio::test]
    async fn test_performance_metrics() {
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        let code = "print('hello')";
        let context = ModelContext {
            model_type: ModelType::Development,
            model_size: 1000,
            framework: "PyTorch".to_string(),
            target_hardware: "CPU".to_string(),
            training_stage: TrainingStage::Development,
        };

        // Analyze the same code twice: the second run should hit the cache,
        // so the hit rate becomes non-zero while both runs are still counted.
        analyzer
            .analyze_model_code(code, context.clone())
            .await
            .expect("async operation failed");
        analyzer
            .analyze_model_code(code, context)
            .await
            .expect("async operation failed");

        let metrics = analyzer.get_performance_metrics();
        assert_eq!(metrics.total_analyses, 2);
        assert!(metrics.cache_hit_rate > 0.0);
    }

    #[tokio::test]
    async fn test_debugging_assistance() {
        let analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        // Simulated CUDA OOM on a nearly-full 8 GB GPU.
        let error_context = ErrorContext {
            error_type: "OutOfMemoryError".to_string(),
            error_message: "CUDA out of memory".to_string(),
            stack_trace: None,
            system_info: SystemInfo {
                gpu_memory_total: 8_000_000_000,
                gpu_memory_used: 7_500_000_000,
                cpu_count: 8,
                ram_total: 32_000_000_000,
                ram_used: 16_000_000_000,
            },
            model_info: None,
        };

        let assistance = analyzer
            .automated_debugging_assistance(&error_context)
            .await
            .expect("async operation failed");
        assert!(!assistance.probable_causes.is_empty());
        assert!(!assistance.suggested_fixes.is_empty());
        assert!(assistance.confidence_score > 0.0);
    }
}