1use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::time::{Duration, Instant};
11use tracing::{debug, info};
12
/// AI-assisted analyzer for ML model code.
///
/// Wraps a configurable pipeline (pattern recognition, deep analysis,
/// optimization/vulnerability detection, performance prediction) with an
/// expiring result cache and simple performance accounting.
#[derive(Debug)]
pub struct AICodeAnalyzer {
    // Feature flags and thresholds controlling which analysis passes run.
    config: AIAnalysisConfig,
    // Completed analyses keyed by a hash of the analyzed code.
    analysis_cache: HashMap<String, CachedAnalysis>,
    // Built-in pattern definitions; not consulted at runtime yet (hence allow).
    #[allow(dead_code)]
    pattern_database: ModelPatternDatabase,
    // Counters for analysis timing and cache hit/miss rates.
    performance_monitor: AnalysisPerformanceMonitor,
}
22
/// Tunables for the analysis pipeline: per-pass feature flags plus caching
/// and budget settings.
///
/// NOTE(review): `max_analysis_time_secs` and `confidence_threshold` are not
/// read anywhere in this file — confirm they are enforced elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AIAnalysisConfig {
    // Per-pass toggles; each gates one stage of `analyze_model_code`.
    pub enable_deep_analysis: bool,
    pub enable_pattern_recognition: bool,
    pub enable_optimization_suggestions: bool,
    pub enable_vulnerability_detection: bool,
    pub enable_performance_prediction: bool,
    // Intended wall-clock budget for one analysis, in seconds.
    pub max_analysis_time_secs: u64,
    // Intended minimum confidence for reporting a finding.
    pub confidence_threshold: f64,
    // Result caching: on/off switch and entry lifetime in hours.
    pub enable_caching: bool,
    pub cache_expiration_hours: u64,
}
45
46impl Default for AIAnalysisConfig {
47 fn default() -> Self {
48 Self {
49 enable_deep_analysis: true,
50 enable_pattern_recognition: true,
51 enable_optimization_suggestions: true,
52 enable_vulnerability_detection: true,
53 enable_performance_prediction: true,
54 max_analysis_time_secs: 30,
55 confidence_threshold: 0.75,
56 enable_caching: true,
57 cache_expiration_hours: 24,
58 }
59 }
60}
61
/// A cached analysis result plus the metadata needed to expire it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedAnalysis {
    // The completed analysis being cached.
    result: CodeAnalysisResult,
    // Insertion time; compared against `cache_expiration_hours` on lookup.
    timestamp: std::time::SystemTime,
    // Hash of the analyzed code (the same value used as the map key).
    code_hash: String,
}
69
/// Running counters for analysis throughput and cache effectiveness.
#[derive(Debug)]
struct AnalysisPerformanceMonitor {
    // Total analyses recorded, cache hits included.
    analysis_count: u64,
    // Sum of wall-clock durations across all recorded analyses.
    total_analysis_time: Duration,
    cache_hits: u64,
    cache_misses: u64,
}
78
79impl AnalysisPerformanceMonitor {
80 fn new() -> Self {
81 Self {
82 analysis_count: 0,
83 total_analysis_time: Duration::from_secs(0),
84 cache_hits: 0,
85 cache_misses: 0,
86 }
87 }
88
89 fn record_analysis(&mut self, duration: Duration, cache_hit: bool) {
90 self.analysis_count += 1;
91 self.total_analysis_time += duration;
92 if cache_hit {
93 self.cache_hits += 1;
94 } else {
95 self.cache_misses += 1;
96 }
97 }
98
99 fn average_analysis_time(&self) -> Duration {
100 if self.analysis_count > 0 {
101 self.total_analysis_time / self.analysis_count as u32
102 } else {
103 Duration::from_secs(0)
104 }
105 }
106
107 fn cache_hit_rate(&self) -> f64 {
108 let total = self.cache_hits + self.cache_misses;
109 if total > 0 {
110 self.cache_hits as f64 / total as f64
111 } else {
112 0.0
113 }
114 }
115}
116
117impl AICodeAnalyzer {
118 pub fn new(config: AIAnalysisConfig) -> Self {
120 Self {
121 config,
122 analysis_cache: HashMap::new(),
123 pattern_database: ModelPatternDatabase::new(),
124 performance_monitor: AnalysisPerformanceMonitor::new(),
125 }
126 }
127
    /// Runs the full (configurable) analysis pipeline over `code`.
    ///
    /// Checks the expiring result cache first; on a miss, runs each enabled
    /// pass in sequence, derives aggregate quality/confidence scores, caches
    /// the result, and records timing in the performance monitor.
    ///
    /// NOTE(review): `config.max_analysis_time_secs` is not enforced here —
    /// confirm whether a timeout is expected at this layer.
    pub async fn analyze_model_code(
        &mut self,
        code: &str,
        context: ModelContext,
    ) -> Result<CodeAnalysisResult> {
        let start_time = Instant::now();
        let code_hash = self.compute_code_hash(code);

        // Fast path: serve a still-fresh cached result (expiry is handled
        // inside get_cached_analysis). The clone releases the cache borrow
        // before the monitor is mutated.
        if self.config.enable_caching {
            if let Some(cached) = self.get_cached_analysis(&code_hash) {
                let result = cached.result.clone();
                self.performance_monitor.record_analysis(start_time.elapsed(), true);
                return Ok(result);
            }
        }

        info!(
            "Starting AI code analysis for {} lines of code",
            code.lines().count()
        );

        let mut result = CodeAnalysisResult::new();

        // Each pass is independently gated by configuration and fills its
        // own section of the result.
        if self.config.enable_pattern_recognition {
            let patterns = self.detect_code_patterns(code, &context).await?;
            result.detected_patterns = patterns;
        }

        if self.config.enable_deep_analysis {
            let issues = self.perform_deep_analysis(code, &context).await?;
            result.identified_issues = issues;
        }

        if self.config.enable_optimization_suggestions {
            let optimizations = self.generate_optimization_suggestions(code, &context).await?;
            result.optimization_suggestions = optimizations;
        }

        if self.config.enable_vulnerability_detection {
            let vulnerabilities = self.detect_vulnerabilities(code, &context).await?;
            result.security_issues = vulnerabilities;
        }

        if self.config.enable_performance_prediction {
            let predictions = self.predict_performance_characteristics(code, &context).await?;
            result.performance_predictions = predictions;
        }

        // Aggregate scores are derived from the findings collected above.
        result.quality_score = self.calculate_quality_score(&result);
        result.analysis_metadata = AnalysisMetadata {
            analysis_duration: start_time.elapsed(),
            confidence_score: self.calculate_confidence_score(&result),
            analyzer_version: "1.0.0".to_string(),
            timestamp: std::time::SystemTime::now(),
        };

        if self.config.enable_caching {
            self.cache_analysis(code_hash, &result);
        }

        // Cache misses are recorded with the full analysis duration.
        self.performance_monitor.record_analysis(start_time.elapsed(), false);

        info!(
            "AI code analysis completed in {:?} with quality score: {:.2}",
            start_time.elapsed(),
            result.quality_score
        );

        Ok(result)
    }
207
208 pub async fn analyze_tensor_operations(
210 &self,
211 operations: &[TensorOperation],
212 ) -> Result<TensorOptimizationReport> {
213 debug!("Analyzing {} tensor operations", operations.len());
214
215 let mut report = TensorOptimizationReport::new();
216
217 report.fusion_opportunities = self.detect_fusion_opportunities(operations).await?;
219 report.memory_optimizations = self.detect_memory_optimizations(operations).await?;
220 report.parallelization_opportunities =
221 self.detect_parallelization_opportunities(operations).await?;
222 report.redundant_operations = self.detect_redundant_operations(operations).await?;
223
224 report.estimated_speedup = self.estimate_optimization_speedup(&report);
226 report.estimated_memory_savings = self.estimate_memory_savings(&report);
227
228 Ok(report)
229 }
230
231 pub async fn automated_debugging_assistance(
233 &self,
234 error_context: &ErrorContext,
235 ) -> Result<DebuggingAssistance> {
236 info!(
237 "Providing automated debugging assistance for error: {}",
238 error_context.error_type
239 );
240
241 let mut assistance = DebuggingAssistance::new();
242
243 assistance.probable_causes = self.analyze_error_patterns(error_context).await?;
245 assistance.suggested_fixes = self.generate_suggested_fixes(error_context).await?;
246 assistance.debugging_steps = self.generate_debugging_steps(error_context).await?;
247 assistance.related_documentation = self.find_related_documentation(error_context).await?;
248
249 assistance.confidence_score = self.calculate_debugging_confidence(&assistance);
251
252 Ok(assistance)
253 }
254
255 pub fn get_performance_metrics(&self) -> AnalysisPerformanceMetrics {
257 AnalysisPerformanceMetrics {
258 total_analyses: self.performance_monitor.analysis_count,
259 average_analysis_time: self.performance_monitor.average_analysis_time(),
260 cache_hit_rate: self.performance_monitor.cache_hit_rate(),
261 cached_results: self.analysis_cache.len(),
262 }
263 }
264
265 async fn detect_code_patterns(
268 &self,
269 code: &str,
270 context: &ModelContext,
271 ) -> Result<Vec<DetectedPattern>> {
272 debug!("Detecting code patterns");
273
274 let mut patterns = Vec::new();
275
276 if code.contains("torch.cuda.empty_cache()") && context.model_type == ModelType::Production
278 {
279 patterns.push(DetectedPattern {
280 pattern_type: PatternType::AntiPattern,
281 name: "Frequent CUDA Cache Clearing".to_string(),
282 description: "Frequent CUDA cache clearing can hurt performance".to_string(),
283 severity: Severity::Medium,
284 confidence: 0.85,
285 recommendations: vec![
286 "Consider using gradient accumulation instead".to_string(),
287 "Review memory management strategy".to_string(),
288 ],
289 });
290 }
291
292 if code.contains("grad_norm") && code.contains("clip") {
294 patterns.push(DetectedPattern {
295 pattern_type: PatternType::GoodPattern,
296 name: "Gradient Clipping".to_string(),
297 description: "Proper gradient clipping implementation detected".to_string(),
298 severity: Severity::Info,
299 confidence: 0.9,
300 recommendations: vec!["Consider adaptive gradient clipping".to_string()],
301 });
302 }
303
304 if code.contains("detach()") && code.contains("requires_grad") {
306 patterns.push(DetectedPattern {
307 pattern_type: PatternType::OptimizationOpportunity,
308 name: "Gradient Computation Inefficiency".to_string(),
309 description: "Potential inefficient gradient computation detected".to_string(),
310 severity: Severity::Medium,
311 confidence: 0.75,
312 recommendations: vec![
313 "Consider using torch.no_grad() context".to_string(),
314 "Review gradient requirements".to_string(),
315 ],
316 });
317 }
318
319 Ok(patterns)
320 }
321
322 async fn perform_deep_analysis(
323 &self,
324 code: &str,
325 _context: &ModelContext,
326 ) -> Result<Vec<IdentifiedIssue>> {
327 debug!("Performing deep AI analysis");
328
329 let mut issues = Vec::new();
330
331 tokio::time::sleep(Duration::from_millis(100)).await;
333
334 if code.contains("log") && !code.contains("log1p") && code.contains("softmax") {
336 issues.push(IdentifiedIssue {
337 issue_type: IssueType::NumericalStability,
338 title: "Potential Numerical Instability in Log-Softmax".to_string(),
339 description: "Using log(softmax(x)) can cause numerical instability. Consider using log_softmax directly.".to_string(),
340 severity: Severity::High,
341 confidence: 0.88,
342 suggested_fix: "Replace log(softmax(x)) with log_softmax(x)".to_string(),
343 code_location: None, });
345 }
346
347 if code.contains("attention") && code.contains("matmul") && !code.contains("flash") {
349 issues.push(IdentifiedIssue {
350 issue_type: IssueType::Performance,
351 title: "Inefficient Attention Implementation".to_string(),
352 description:
353 "Standard attention implementation may be inefficient for large sequences."
354 .to_string(),
355 severity: Severity::Medium,
356 confidence: 0.75,
357 suggested_fix:
358 "Consider using Flash Attention or other optimized attention mechanisms"
359 .to_string(),
360 code_location: None,
361 });
362 }
363
364 if code.contains("accumulate") && !code.contains("zero_grad") {
366 issues.push(IdentifiedIssue {
367 issue_type: IssueType::MemoryLeak,
368 title: "Potential Gradient Accumulation Memory Leak".to_string(),
369 description: "Gradient accumulation without zero_grad() can cause memory leaks."
370 .to_string(),
371 severity: Severity::High,
372 confidence: 0.82,
373 suggested_fix: "Ensure optimizer.zero_grad() is called appropriately".to_string(),
374 code_location: None,
375 });
376 }
377
378 Ok(issues)
379 }
380
381 async fn generate_optimization_suggestions(
382 &self,
383 code: &str,
384 context: &ModelContext,
385 ) -> Result<Vec<OptimizationSuggestion>> {
386 debug!("Generating optimization suggestions");
387
388 let mut suggestions = Vec::new();
389
390 if context.model_type == ModelType::Training && !code.contains("autocast") {
392 suggestions.push(OptimizationSuggestion {
393 optimization_type: OptimizationType::MixedPrecision,
394 title: "Enable Mixed Precision Training".to_string(),
395 description: "Mixed precision training can significantly speed up training and reduce memory usage.".to_string(),
396 potential_speedup: 1.5,
397 memory_savings: 0.4,
398 implementation_effort: ImplementationEffort::Low,
399 confidence: 0.9,
400 code_example: Some("with torch.autocast(device_type='cuda', dtype=torch.float16):".to_string()),
401 });
402 }
403
404 if context.model_type == ModelType::Production && !code.contains("compile") {
406 suggestions.push(OptimizationSuggestion {
407 optimization_type: OptimizationType::ModelCompilation,
408 title: "Enable Model Compilation".to_string(),
409 description: "Model compilation can provide significant inference speedups."
410 .to_string(),
411 potential_speedup: 2.0,
412 memory_savings: 0.0,
413 implementation_effort: ImplementationEffort::Low,
414 confidence: 0.85,
415 code_example: Some("model = torch.compile(model)".to_string()),
416 });
417 }
418
419 if context.model_size > 1_000_000_000 && !code.contains("checkpoint") {
421 suggestions.push(OptimizationSuggestion {
422 optimization_type: OptimizationType::MemoryOptimization,
423 title: "Enable Gradient Checkpointing".to_string(),
424 description:
425 "Gradient checkpointing can significantly reduce memory usage for large models."
426 .to_string(),
427 potential_speedup: 0.8, memory_savings: 0.6,
429 implementation_effort: ImplementationEffort::Medium,
430 confidence: 0.88,
431 code_example: Some("torch.utils.checkpoint.checkpoint(layer, x)".to_string()),
432 });
433 }
434
435 Ok(suggestions)
436 }
437
438 async fn detect_vulnerabilities(
439 &self,
440 code: &str,
441 context: &ModelContext,
442 ) -> Result<Vec<SecurityIssue>> {
443 debug!("Detecting security vulnerabilities");
444
445 let mut vulnerabilities = Vec::new();
446
447 if code.contains("pickle.load") && !code.contains("safe_load") {
449 vulnerabilities.push(SecurityIssue {
450 vulnerability_type: VulnerabilityType::CodeExecution,
451 title: "Unsafe Pickle Loading".to_string(),
452 description:
453 "Loading pickle files can execute arbitrary code. Use safe alternatives."
454 .to_string(),
455 severity: Severity::Critical,
456 confidence: 0.95,
457 mitigation: "Use torch.load with weights_only=True or safetensors".to_string(),
458 cve_references: vec!["CWE-502".to_string()],
459 });
460 }
461
462 if code.contains("state_dict")
464 && code.contains("save")
465 && context.model_type == ModelType::Production
466 {
467 vulnerabilities.push(SecurityIssue {
468 vulnerability_type: VulnerabilityType::DataExposure,
469 title: "Potential Model Parameter Exposure".to_string(),
470 description: "Saving full model state may expose sensitive parameters.".to_string(),
471 severity: Severity::Medium,
472 confidence: 0.7,
473 mitigation: "Consider differential privacy or parameter encryption".to_string(),
474 cve_references: vec![],
475 });
476 }
477
478 if code.contains("input") && !code.contains("validate") && !code.contains("sanitize") {
480 vulnerabilities.push(SecurityIssue {
481 vulnerability_type: VulnerabilityType::InputValidation,
482 title: "Missing Input Validation".to_string(),
483 description: "Input validation is important for preventing adversarial attacks."
484 .to_string(),
485 severity: Severity::Medium,
486 confidence: 0.65,
487 mitigation: "Implement input validation and sanitization".to_string(),
488 cve_references: vec![],
489 });
490 }
491
492 Ok(vulnerabilities)
493 }
494
495 async fn predict_performance_characteristics(
496 &self,
497 code: &str,
498 context: &ModelContext,
499 ) -> Result<PerformancePredictions> {
500 debug!("Predicting performance characteristics");
501
502 tokio::time::sleep(Duration::from_millis(50)).await;
504
505 let mut predictions = PerformancePredictions::new();
506
507 predictions.estimated_memory_usage = self.estimate_memory_usage(code, context);
509 predictions.estimated_training_time = self.estimate_training_time(code, context);
510 predictions.estimated_inference_latency = self.estimate_inference_latency(code, context);
511 predictions.scaling_characteristics = self.predict_scaling_behavior(code, context);
512
513 predictions.predicted_bottlenecks = vec![
515 "Attention computation may become bottleneck for long sequences".to_string(),
516 "Memory bandwidth may limit performance for large batch sizes".to_string(),
517 ];
518
519 predictions.confidence_score = 0.75;
520
521 Ok(predictions)
522 }
523
524 async fn detect_fusion_opportunities(
525 &self,
526 operations: &[TensorOperation],
527 ) -> Result<Vec<FusionOpportunity>> {
528 let mut opportunities = Vec::new();
529
530 for window in operations.windows(2) {
532 if let [op1, op2] = window {
533 if matches!(op1.op_type, OperationType::MatMul)
534 && matches!(op2.op_type, OperationType::Add)
535 {
536 opportunities.push(FusionOpportunity {
537 operations: vec![op1.clone(), op2.clone()],
538 fusion_type: FusionType::GEMM,
539 estimated_speedup: 1.3,
540 description: "MatMul + Add can be fused into GEMM operation".to_string(),
541 });
542 }
543 }
544 }
545
546 for window in operations.windows(2) {
548 if let [op1, op2] = window {
549 if matches!(op1.op_type, OperationType::Linear)
550 && matches!(op2.op_type, OperationType::Activation)
551 {
552 opportunities.push(FusionOpportunity {
553 operations: vec![op1.clone(), op2.clone()],
554 fusion_type: FusionType::LinearActivation,
555 estimated_speedup: 1.2,
556 description: "Linear + Activation can be fused".to_string(),
557 });
558 }
559 }
560 }
561
562 Ok(opportunities)
563 }
564
565 async fn detect_memory_optimizations(
566 &self,
567 operations: &[TensorOperation],
568 ) -> Result<Vec<MemoryOptimization>> {
569 let mut optimizations = Vec::new();
570
571 for op in operations {
573 if op.can_be_inplace() && !op.is_inplace {
574 optimizations.push(MemoryOptimization {
575 operation: op.clone(),
576 optimization_type: MemoryOptimizationType::InPlace,
577 memory_savings: op.output_size_bytes,
578 description: format!("Operation {} can be performed in-place", op.name),
579 });
580 }
581 }
582
583 let mut tensor_usage = HashMap::new();
585 for op in operations {
586 for input in &op.inputs {
587 *tensor_usage.entry(input.clone()).or_insert(0) += 1;
588 }
589 }
590
591 for (tensor, usage_count) in tensor_usage {
592 if usage_count == 1 {
593 optimizations.push(MemoryOptimization {
594 operation: TensorOperation::default(),
595 optimization_type: MemoryOptimizationType::TensorReuse,
596 memory_savings: 0, description: format!("Tensor {} can be reused", tensor),
598 });
599 }
600 }
601
602 Ok(optimizations)
603 }
604
605 async fn detect_parallelization_opportunities(
606 &self,
607 operations: &[TensorOperation],
608 ) -> Result<Vec<ParallelizationOpportunity>> {
609 let mut opportunities = Vec::new();
610
611 for (i, op1) in operations.iter().enumerate() {
613 for op2 in operations.iter().skip(i + 1) {
614 if self.operations_are_independent(op1, op2) {
615 opportunities.push(ParallelizationOpportunity {
616 operations: vec![op1.clone(), op2.clone()],
617 parallelization_type: ParallelizationType::DataParallel,
618 estimated_speedup: 1.8,
619 description: "Operations can run in parallel".to_string(),
620 });
621 }
622 }
623 }
624
625 Ok(opportunities)
626 }
627
628 async fn detect_redundant_operations(
629 &self,
630 operations: &[TensorOperation],
631 ) -> Result<Vec<RedundantOperation>> {
632 let mut redundant = Vec::new();
633
634 for (i, op1) in operations.iter().enumerate() {
636 for (_j, op2) in operations.iter().enumerate().skip(i + 1) {
637 if self.operations_are_equivalent(op1, op2) {
638 redundant.push(RedundantOperation {
639 original_operation: op1.clone(),
640 redundant_operation: op2.clone(),
641 redundancy_type: RedundancyType::Duplicate,
642 description: "Operations produce identical results".to_string(),
643 });
644 }
645 }
646 }
647
648 Ok(redundant)
649 }
650
651 fn operations_are_independent(&self, op1: &TensorOperation, op2: &TensorOperation) -> bool {
654 for input1 in &op1.inputs {
656 for output2 in &op2.outputs {
657 if input1 == output2 {
658 return false;
659 }
660 }
661 }
662 for input2 in &op2.inputs {
663 for output1 in &op1.outputs {
664 if input2 == output1 {
665 return false;
666 }
667 }
668 }
669 true
670 }
671
    /// Two operations are treated as equivalent (hence redundant) when they
    /// apply the same op type to the same inputs with the same parameters.
    /// Output names are deliberately not compared.
    fn operations_are_equivalent(&self, op1: &TensorOperation, op2: &TensorOperation) -> bool {
        op1.op_type == op2.op_type && op1.inputs == op2.inputs && op1.parameters == op2.parameters
    }
675
676 fn compute_code_hash(&self, code: &str) -> String {
677 use std::collections::hash_map::DefaultHasher;
678 use std::hash::{Hash, Hasher};
679
680 let mut hasher = DefaultHasher::new();
681 code.hash(&mut hasher);
682 format!("{:x}", hasher.finish())
683 }
684
685 fn get_cached_analysis(&self, code_hash: &str) -> Option<&CachedAnalysis> {
686 self.analysis_cache.get(code_hash).and_then(|cached| {
687 let age = std::time::SystemTime::now()
688 .duration_since(cached.timestamp)
689 .unwrap_or_default();
690
691 if age.as_secs() < self.config.cache_expiration_hours * 3600 {
692 Some(cached)
693 } else {
694 None
695 }
696 })
697 }
698
699 fn cache_analysis(&mut self, code_hash: String, result: &CodeAnalysisResult) {
700 self.analysis_cache.insert(
701 code_hash.clone(),
702 CachedAnalysis {
703 result: result.clone(),
704 timestamp: std::time::SystemTime::now(),
705 code_hash,
706 },
707 );
708 }
709
710 fn calculate_quality_score(&self, result: &CodeAnalysisResult) -> f64 {
711 let mut score: f64 = 100.0;
712
713 for issue in &result.identified_issues {
715 match issue.severity {
716 Severity::Critical => score -= 20.0,
717 Severity::High => score -= 10.0,
718 Severity::Medium => score -= 5.0,
719 Severity::Low => score -= 2.0,
720 Severity::Info => score -= 0.0,
721 }
722 }
723
724 for vulnerability in &result.security_issues {
726 match vulnerability.severity {
727 Severity::Critical => score -= 25.0,
728 Severity::High => score -= 15.0,
729 Severity::Medium => score -= 8.0,
730 Severity::Low => score -= 3.0,
731 Severity::Info => score -= 0.0,
732 }
733 }
734
735 for pattern in &result.detected_patterns {
737 if pattern.pattern_type == PatternType::GoodPattern {
738 score += 2.0;
739 }
740 }
741
742 score.max(0.0).min(100.0)
743 }
744
745 fn calculate_confidence_score(&self, result: &CodeAnalysisResult) -> f64 {
746 let mut total_confidence = 0.0;
747 let mut count = 0;
748
749 for issue in &result.identified_issues {
750 total_confidence += issue.confidence;
751 count += 1;
752 }
753
754 for pattern in &result.detected_patterns {
755 total_confidence += pattern.confidence;
756 count += 1;
757 }
758
759 if count > 0 {
760 total_confidence / count as f64
761 } else {
762 1.0
763 }
764 }
765
766 fn estimate_memory_usage(&self, code: &str, context: &ModelContext) -> f64 {
767 let base_memory = context.model_size as f64 * 4.0; let mut multiplier = 1.0;
771 if code.contains("gradient_accumulation") {
772 multiplier += 0.5;
773 }
774 if code.contains("mixed_precision") {
775 multiplier *= 0.7;
776 }
777
778 base_memory * multiplier / 1_000_000.0 }
780
781 fn estimate_training_time(&self, code: &str, context: &ModelContext) -> f64 {
782 let base_time = (context.model_size as f64).log10() * 10.0;
784
785 let mut multiplier = 1.0;
786 if code.contains("mixed_precision") {
787 multiplier *= 0.6;
788 }
789 if code.contains("gradient_checkpointing") {
790 multiplier *= 1.3;
791 }
792
793 base_time * multiplier
794 }
795
796 fn estimate_inference_latency(&self, code: &str, context: &ModelContext) -> f64 {
797 let base_latency = (context.model_size as f64).log10() * 5.0;
799
800 let mut multiplier = 1.0;
801 if code.contains("compile") {
802 multiplier *= 0.5;
803 }
804 if code.contains("quantization") {
805 multiplier *= 0.7;
806 }
807
808 base_latency * multiplier
809 }
810
811 fn predict_scaling_behavior(
812 &self,
813 _code: &str,
814 context: &ModelContext,
815 ) -> ScalingCharacteristics {
816 ScalingCharacteristics {
817 batch_size_scaling: if context.model_size > 1_000_000_000 {
818 ScalingBehavior::Sublinear
819 } else {
820 ScalingBehavior::Linear
821 },
822 sequence_length_scaling: ScalingBehavior::Quadratic, model_size_scaling: ScalingBehavior::Linear,
824 memory_scaling: ScalingBehavior::Linear,
825 }
826 }
827
828 fn estimate_optimization_speedup(&self, report: &TensorOptimizationReport) -> f64 {
829 let mut speedup = 1.0;
830
831 for fusion in &report.fusion_opportunities {
832 speedup *= fusion.estimated_speedup;
833 }
834
835 for parallel in &report.parallelization_opportunities {
836 speedup *= parallel.estimated_speedup;
837 }
838
839 speedup.min(10.0) }
841
842 fn estimate_memory_savings(&self, report: &TensorOptimizationReport) -> f64 {
843 let total_savings: u64 =
844 report.memory_optimizations.iter().map(|opt| opt.memory_savings).sum();
845
846 total_savings as f64 / 1_000_000.0 }
848
849 async fn analyze_error_patterns(
850 &self,
851 error_context: &ErrorContext,
852 ) -> Result<Vec<ProbableCause>> {
853 let mut causes = Vec::new();
854
855 match error_context.error_type.as_str() {
856 "OutOfMemoryError" => {
857 causes.push(ProbableCause {
858 cause: "Batch size too large".to_string(),
859 probability: 0.8,
860 evidence: vec!["GPU memory limit exceeded".to_string()],
861 });
862 causes.push(ProbableCause {
863 cause: "Model too large for available memory".to_string(),
864 probability: 0.6,
865 evidence: vec!["Model parameter count".to_string()],
866 });
867 },
868 "GradientExplosion" => {
869 causes.push(ProbableCause {
870 cause: "Learning rate too high".to_string(),
871 probability: 0.7,
872 evidence: vec!["Gradient norm increasing rapidly".to_string()],
873 });
874 },
875 _ => {
876 causes.push(ProbableCause {
877 cause: "Unknown error pattern".to_string(),
878 probability: 0.3,
879 evidence: vec![],
880 });
881 },
882 }
883
884 Ok(causes)
885 }
886
887 async fn generate_suggested_fixes(
888 &self,
889 error_context: &ErrorContext,
890 ) -> Result<Vec<SuggestedFix>> {
891 let mut fixes = Vec::new();
892
893 match error_context.error_type.as_str() {
894 "OutOfMemoryError" => {
895 fixes.push(SuggestedFix {
896 description: "Reduce batch size".to_string(),
897 implementation: "batch_size = batch_size // 2".to_string(),
898 confidence: 0.9,
899 estimated_impact: "Should free ~50% of memory".to_string(),
900 });
901 fixes.push(SuggestedFix {
902 description: "Enable gradient checkpointing".to_string(),
903 implementation: "model.gradient_checkpointing_enable()".to_string(),
904 confidence: 0.8,
905 estimated_impact: "Reduces memory by ~40% with 10-20% speed penalty"
906 .to_string(),
907 });
908 },
909 "GradientExplosion" => {
910 fixes.push(SuggestedFix {
911 description: "Add gradient clipping".to_string(),
912 implementation:
913 "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)"
914 .to_string(),
915 confidence: 0.95,
916 estimated_impact: "Prevents gradient explosion".to_string(),
917 });
918 },
919 _ => {},
920 }
921
922 Ok(fixes)
923 }
924
925 async fn generate_debugging_steps(
926 &self,
927 error_context: &ErrorContext,
928 ) -> Result<Vec<DebuggingStep>> {
929 let mut steps = Vec::new();
930
931 steps.push(DebuggingStep {
932 step_number: 1,
933 description: "Check system resources".to_string(),
934 command: Some("nvidia-smi".to_string()),
935 expected_output: "GPU memory usage and availability".to_string(),
936 });
937
938 steps.push(DebuggingStep {
939 step_number: 2,
940 description: "Verify model configuration".to_string(),
941 command: Some("print(model)".to_string()),
942 expected_output: "Model architecture and parameter count".to_string(),
943 });
944
945 match error_context.error_type.as_str() {
946 "OutOfMemoryError" => {
947 steps.push(DebuggingStep {
948 step_number: 3,
949 description: "Check tensor shapes and batch size".to_string(),
950 command: Some(
951 "print(f'Batch size: {batch_size}, Input shape: {input.shape}')"
952 .to_string(),
953 ),
954 expected_output: "Current batch size and input dimensions".to_string(),
955 });
956 },
957 _ => {},
958 }
959
960 Ok(steps)
961 }
962
963 async fn find_related_documentation(
964 &self,
965 error_context: &ErrorContext,
966 ) -> Result<Vec<DocumentationReference>> {
967 let mut references = Vec::new();
968
969 match error_context.error_type.as_str() {
970 "OutOfMemoryError" => {
971 references.push(DocumentationReference {
972 title: "Memory Management Best Practices".to_string(),
973 url: "https://docs.trustformers.ai/memory-management".to_string(),
974 relevance_score: 0.95,
975 });
976 references.push(DocumentationReference {
977 title: "Gradient Checkpointing Guide".to_string(),
978 url: "https://docs.trustformers.ai/gradient-checkpointing".to_string(),
979 relevance_score: 0.8,
980 });
981 },
982 "GradientExplosion" => {
983 references.push(DocumentationReference {
984 title: "Training Stability Guide".to_string(),
985 url: "https://docs.trustformers.ai/training-stability".to_string(),
986 relevance_score: 0.9,
987 });
988 },
989 _ => {},
990 }
991
992 Ok(references)
993 }
994
995 fn calculate_debugging_confidence(&self, assistance: &DebuggingAssistance) -> f64 {
996 let avg_cause_probability =
997 assistance.probable_causes.iter().map(|cause| cause.probability).sum::<f64>()
998 / assistance.probable_causes.len().max(1) as f64;
999
1000 let avg_fix_confidence =
1001 assistance.suggested_fixes.iter().map(|fix| fix.confidence).sum::<f64>()
1002 / assistance.suggested_fixes.len().max(1) as f64;
1003
1004 (avg_cause_probability + avg_fix_confidence) / 2.0
1005 }
1006}
1007
/// Built-in database of known code patterns, keyed by pattern id.
/// Held by `AICodeAnalyzer` but not consulted at runtime yet (hence allow).
#[derive(Debug)]
struct ModelPatternDatabase {
    #[allow(dead_code)]
    patterns: HashMap<String, PatternDefinition>,
}
1016
1017impl ModelPatternDatabase {
1018 fn new() -> Self {
1019 let mut patterns = HashMap::new();
1020
1021 patterns.insert(
1023 "gradient_clipping".to_string(),
1024 PatternDefinition {
1025 name: "Gradient Clipping".to_string(),
1026 pattern_type: PatternType::GoodPattern,
1027 keywords: vec![
1028 "clip_grad_norm".to_string(),
1029 "gradient".to_string(),
1030 "clip".to_string(),
1031 ],
1032 severity: Severity::Info,
1033 description: "Proper gradient clipping prevents gradient explosion".to_string(),
1034 },
1035 );
1036
1037 Self { patterns }
1038 }
1039}
1040
/// A keyword-based pattern description stored in `ModelPatternDatabase`.
/// Data-only for now: nothing in this file matches against `keywords`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct PatternDefinition {
    #[allow(dead_code)]
    name: String,
    pattern_type: PatternType,
    // Keywords whose presence would indicate this pattern.
    keywords: Vec<String>,
    severity: Severity,
    description: String,
}
1051
/// Caller-supplied context describing the model whose code is analyzed.
#[derive(Debug, Clone)]
pub struct ModelContext {
    pub model_type: ModelType,
    // model_size: parameter count (the memory estimator multiplies it by
    // 4 bytes — presumably fp32; confirm). framework/target_hardware are
    // free-form labels, unused by the heuristics in this file.
    pub model_size: u64, pub framework: String,
    pub target_hardware: String,
    pub training_stage: TrainingStage,
}
1061
/// Deployment/usage mode of the model; several heuristics branch on
/// `Production` and `Training`.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelType {
    Training,
    Inference,
    Production,
    Development,
}
1069
/// Lifecycle stage of the model's training; informational only — no
/// heuristic in this file reads it.
#[derive(Debug, Clone)]
pub enum TrainingStage {
    Training,
    Development,
    Pretraining,
    Finetuning,
    Evaluation,
    Inference,
}
1079
/// Aggregated output of one run of `analyze_model_code`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeAnalysisResult {
    // Overall 0-100 score derived from the findings below
    // (see `calculate_quality_score`).
    pub quality_score: f64,
    pub detected_patterns: Vec<DetectedPattern>,
    pub identified_issues: Vec<IdentifiedIssue>,
    pub optimization_suggestions: Vec<OptimizationSuggestion>,
    pub security_issues: Vec<SecurityIssue>,
    pub performance_predictions: PerformancePredictions,
    pub analysis_metadata: AnalysisMetadata,
}
1091
1092impl CodeAnalysisResult {
1093 fn new() -> Self {
1094 Self {
1095 quality_score: 0.0,
1096 detected_patterns: Vec::new(),
1097 identified_issues: Vec::new(),
1098 optimization_suggestions: Vec::new(),
1099 security_issues: Vec::new(),
1100 performance_predictions: PerformancePredictions::new(),
1101 analysis_metadata: AnalysisMetadata::default(),
1102 }
1103 }
1104}
1105
/// One pattern match found by `detect_code_patterns`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedPattern {
    pub pattern_type: PatternType,
    pub name: String,
    pub description: String,
    pub severity: Severity,
    // Heuristic confidence in [0, 1]; feeds the aggregate confidence score.
    pub confidence: f64,
    pub recommendations: Vec<String>,
}
1115
/// Classification of a detected pattern; `GoodPattern` earns a quality
/// bonus in `calculate_quality_score`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PatternType {
    GoodPattern,
    AntiPattern,
    OptimizationOpportunity,
    SecurityConcern,
}
1123
/// One finding from `perform_deep_analysis`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdentifiedIssue {
    pub issue_type: IssueType,
    pub title: String,
    pub description: String,
    pub severity: Severity,
    // Heuristic confidence in [0, 1].
    pub confidence: f64,
    pub suggested_fix: String,
    // Always None for the current keyword-based heuristics.
    pub code_location: Option<CodeLocation>,
}
1134
/// Category of an identified issue.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IssueType {
    NumericalStability,
    Performance,
    MemoryLeak,
    LogicError,
    TypeMismatch,
    ResourceLeak,
}
1144
/// Source position of a finding (file, 1-based line/column presumed —
/// nothing in this file populates it yet).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeLocation {
    pub file: String,
    pub line: u32,
    pub column: u32,
}
1151
/// One recommendation from `generate_optimization_suggestions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSuggestion {
    pub optimization_type: OptimizationType,
    pub title: String,
    pub description: String,
    // Multiplicative speedup estimate (values < 1.0 mean a slowdown traded
    // for memory, e.g. gradient checkpointing).
    pub potential_speedup: f64,
    // Fractional memory reduction estimate in [0, 1].
    pub memory_savings: f64,
    pub implementation_effort: ImplementationEffort,
    pub confidence: f64,
    // Optional snippet illustrating the change.
    pub code_example: Option<String>,
}
1163
/// Families of optimizations the analyzer can suggest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationType {
    MixedPrecision,
    ModelCompilation,
    MemoryOptimization,
    ComputationOptimization,
    IOOptimization,
    ParallelizationOptimization,
}
1173
/// Rough cost of implementing a suggestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ImplementationEffort {
    Low,
    Medium,
    High,
}
1180
/// A potential security vulnerability found by the vulnerability-detection pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityIssue {
    /// Class of vulnerability.
    pub vulnerability_type: VulnerabilityType,
    /// Short summary.
    pub title: String,
    /// Detailed description of the risk.
    pub description: String,
    /// How serious the vulnerability is.
    pub severity: Severity,
    /// Analyzer confidence that the vulnerability is real.
    pub confidence: f64,
    /// Recommended mitigation.
    pub mitigation: String,
    /// Related CVE identifiers, if any.
    pub cve_references: Vec<String>,
}
1191
/// Classes of security vulnerabilities the analyzer recognizes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum VulnerabilityType {
    /// Arbitrary code execution (e.g. unpickling untrusted data, per the tests).
    CodeExecution,
    DataExposure,
    InputValidation,
    AuthenticationBypass,
    PrivilegeEscalation,
}
1200
/// Severity levels, declared from most to least severe.
/// NOTE: only `PartialEq` is derived — variants carry no ordering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Severity {
    Critical,
    High,
    Medium,
    Low,
    Info,
}
1209
1210#[derive(Debug, Clone, Serialize, Deserialize)]
1211pub struct PerformancePredictions {
1212 pub estimated_memory_usage: f64, pub estimated_training_time: f64, pub estimated_inference_latency: f64, pub scaling_characteristics: ScalingCharacteristics,
1216 pub predicted_bottlenecks: Vec<String>,
1217 pub confidence_score: f64,
1218}
1219
1220impl PerformancePredictions {
1221 fn new() -> Self {
1222 Self {
1223 estimated_memory_usage: 0.0,
1224 estimated_training_time: 0.0,
1225 estimated_inference_latency: 0.0,
1226 scaling_characteristics: ScalingCharacteristics::default(),
1227 predicted_bottlenecks: Vec::new(),
1228 confidence_score: 0.0,
1229 }
1230 }
1231}
1232
/// Expected scaling behavior of the model along several dimensions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalingCharacteristics {
    /// How cost scales with batch size.
    pub batch_size_scaling: ScalingBehavior,
    /// How cost scales with sequence length.
    pub sequence_length_scaling: ScalingBehavior,
    /// How cost scales with model size.
    pub model_size_scaling: ScalingBehavior,
    /// How memory consumption scales.
    pub memory_scaling: ScalingBehavior,
}
1240
1241impl Default for ScalingCharacteristics {
1242 fn default() -> Self {
1243 Self {
1244 batch_size_scaling: ScalingBehavior::Linear,
1245 sequence_length_scaling: ScalingBehavior::Linear,
1246 model_size_scaling: ScalingBehavior::Linear,
1247 memory_scaling: ScalingBehavior::Linear,
1248 }
1249 }
1250}
1251
/// Asymptotic growth classes used to describe scaling behavior.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScalingBehavior {
    Constant,
    Linear,
    Quadratic,
    Exponential,
    Sublinear,
}
1260
/// Bookkeeping attached to every analysis result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnalysisMetadata {
    /// Wall-clock time the analysis took.
    pub analysis_duration: Duration,
    /// Overall confidence in the result.
    pub confidence_score: f64,
    /// Version string of the analyzer that produced the result.
    pub analyzer_version: String,
    /// When the analysis ran.
    pub timestamp: std::time::SystemTime,
}
1268
1269impl Default for AnalysisMetadata {
1270 fn default() -> Self {
1271 Self {
1272 analysis_duration: Duration::from_secs(0),
1273 confidence_score: 0.0,
1274 analyzer_version: "1.0.0".to_string(),
1275 timestamp: std::time::SystemTime::now(),
1276 }
1277 }
1278}
1279
/// A single node in a tensor computation graph, as consumed by the
/// tensor-operation analysis.
#[derive(Debug, Clone)]
pub struct TensorOperation {
    /// Name of the op within the graph.
    pub name: String,
    /// Kind of computation this node performs.
    pub op_type: OperationType,
    /// Names of the input tensors.
    pub inputs: Vec<String>,
    /// Names of the output tensors.
    pub outputs: Vec<String>,
    /// Free-form, stringly-typed op attributes.
    pub parameters: HashMap<String, String>,
    /// Size of the produced output, in bytes.
    pub output_size_bytes: u64,
    /// Whether the op already writes its result in place.
    pub is_inplace: bool,
}

impl Default for TensorOperation {
    /// An anonymous placeholder op: unknown type, no tensors, zero-sized
    /// output, not in-place.
    fn default() -> Self {
        Self {
            name: String::default(),
            op_type: OperationType::Unknown,
            inputs: vec![],
            outputs: vec![],
            parameters: HashMap::default(),
            output_size_bytes: 0,
            is_inplace: false,
        }
    }
}

impl TensorOperation {
    /// Whether this op is a candidate for an in-place rewrite.
    /// Only add, multiply, and activation ops qualify.
    fn can_be_inplace(&self) -> bool {
        use OperationType::{Activation, Add, Mul};
        matches!(self.op_type, Add | Mul | Activation)
    }
}

/// The kinds of tensor operations the analyzer recognizes.
#[derive(Debug, Clone, PartialEq)]
pub enum OperationType {
    MatMul,
    Add,
    Mul,
    Conv2D,
    Linear,
    Activation,
    Pooling,
    BatchNorm,
    LayerNorm,
    Attention,
    Unknown,
}
1330
/// Aggregate output of the tensor-operation analysis.
#[derive(Debug, Clone)]
pub struct TensorOptimizationReport {
    /// Groups of ops that could be fused into one kernel.
    pub fusion_opportunities: Vec<FusionOpportunity>,
    /// Per-op memory-saving rewrites.
    pub memory_optimizations: Vec<MemoryOptimization>,
    /// Groups of ops that could run in parallel.
    pub parallelization_opportunities: Vec<ParallelizationOpportunity>,
    /// Ops whose work duplicates another op's.
    pub redundant_operations: Vec<RedundantOperation>,
    /// Overall estimated speedup factor (1.0 = no change).
    pub estimated_speedup: f64,
    /// Overall estimated memory savings. NOTE(review): units not
    /// established here — confirm against the producer.
    pub estimated_memory_savings: f64,
}
1340
1341impl TensorOptimizationReport {
1342 fn new() -> Self {
1343 Self {
1344 fusion_opportunities: Vec::new(),
1345 memory_optimizations: Vec::new(),
1346 parallelization_opportunities: Vec::new(),
1347 redundant_operations: Vec::new(),
1348 estimated_speedup: 1.0,
1349 estimated_memory_savings: 0.0,
1350 }
1351 }
1352}
1353
/// A group of operations that could be fused into a single kernel.
#[derive(Debug, Clone)]
pub struct FusionOpportunity {
    /// The candidate operations, in graph order.
    pub operations: Vec<TensorOperation>,
    /// Which fusion pattern applies.
    pub fusion_type: FusionType,
    /// Estimated speedup factor from fusing.
    pub estimated_speedup: f64,
    /// Human-readable explanation of the opportunity.
    pub description: String,
}
1361
/// Fusion patterns the analyzer can recognize.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FusionType {
    /// Matrix-multiply followed by add (per the MatMul+Add test below).
    GEMM,
    LinearActivation,
    ConvBatchNorm,
    AttentionQKV,
}
1369
/// A memory-saving rewrite for a single operation.
#[derive(Debug, Clone)]
pub struct MemoryOptimization {
    /// The operation the optimization applies to.
    pub operation: TensorOperation,
    /// Which memory-optimization strategy applies.
    pub optimization_type: MemoryOptimizationType,
    /// Estimated savings in bytes.
    pub memory_savings: u64,
    /// Human-readable explanation.
    pub description: String,
}
1377
/// Memory-optimization strategies the analyzer can recommend.
///
/// Derives `PartialEq` for parity with the sibling report enum
/// `FusionType`, so callers and tests can compare variants directly.
#[derive(Debug, Clone, PartialEq)]
pub enum MemoryOptimizationType {
    InPlace,
    TensorReuse,
    MemoryPool,
    GradientCheckpointing,
}
1385
/// A group of operations that could execute in parallel.
#[derive(Debug, Clone)]
pub struct ParallelizationOpportunity {
    /// The candidate operations.
    pub operations: Vec<TensorOperation>,
    /// Which parallelism strategy applies.
    pub parallelization_type: ParallelizationType,
    /// Estimated speedup factor from parallelizing.
    pub estimated_speedup: f64,
    /// Human-readable explanation.
    pub description: String,
}
1393
/// Parallelism strategies the analyzer can recommend.
///
/// Derives `PartialEq` for parity with the sibling report enum
/// `FusionType`, so callers and tests can compare variants directly.
#[derive(Debug, Clone, PartialEq)]
pub enum ParallelizationType {
    DataParallel,
    ModelParallel,
    PipelineParallel,
    TensorParallel,
}
1401
/// A pair of operations where the second duplicates or is subsumed by the first.
#[derive(Debug, Clone)]
pub struct RedundantOperation {
    /// The operation whose result should be kept.
    pub original_operation: TensorOperation,
    /// The operation flagged as redundant.
    pub redundant_operation: TensorOperation,
    /// How the redundancy was classified.
    pub redundancy_type: RedundancyType,
    /// Human-readable explanation.
    pub description: String,
}
1409
/// Ways an operation can be redundant.
///
/// Derives `PartialEq` for parity with the sibling report enum
/// `FusionType`, so callers and tests can compare variants directly.
#[derive(Debug, Clone, PartialEq)]
pub enum RedundancyType {
    Duplicate,
    Subsumed,
    Unnecessary,
}
1416
/// Everything known about a runtime failure, fed to the debugging assistant.
#[derive(Debug, Clone)]
pub struct ErrorContext {
    /// Error class name (e.g. "OutOfMemoryError", per the tests below).
    pub error_type: String,
    /// The raw error message.
    pub error_message: String,
    /// Stack trace, when one was captured.
    pub stack_trace: Option<String>,
    /// Hardware/OS state at the time of failure.
    pub system_info: SystemInfo,
    /// Model context, when the failure occurred during model work.
    pub model_info: Option<ModelContext>,
}
1427
/// Hardware resource snapshot accompanying an error report.
/// NOTE(review): the `*_total`/`*_used` fields are presumably bytes
/// (test fixtures use values like 8_000_000_000) — confirm at call sites.
#[derive(Debug, Clone)]
pub struct SystemInfo {
    /// Total GPU memory.
    pub gpu_memory_total: u64,
    /// GPU memory in use.
    pub gpu_memory_used: u64,
    /// Number of CPU cores.
    pub cpu_count: u32,
    /// Total system RAM.
    pub ram_total: u64,
    /// System RAM in use.
    pub ram_used: u64,
}
1436
/// Output of the automated debugging assistant for one error.
#[derive(Debug, Clone)]
pub struct DebuggingAssistance {
    /// Candidate root causes, each with a probability and evidence.
    pub probable_causes: Vec<ProbableCause>,
    /// Concrete fixes to try.
    pub suggested_fixes: Vec<SuggestedFix>,
    /// Ordered diagnostic steps.
    pub debugging_steps: Vec<DebuggingStep>,
    /// Documentation links relevant to this error.
    pub related_documentation: Vec<DocumentationReference>,
    /// Overall confidence in the assistance.
    pub confidence_score: f64,
}
1445
1446impl DebuggingAssistance {
1447 fn new() -> Self {
1448 Self {
1449 probable_causes: Vec::new(),
1450 suggested_fixes: Vec::new(),
1451 debugging_steps: Vec::new(),
1452 related_documentation: Vec::new(),
1453 confidence_score: 0.0,
1454 }
1455 }
1456}
1457
/// A candidate root cause for an error.
#[derive(Debug, Clone)]
pub struct ProbableCause {
    /// Description of the suspected cause.
    pub cause: String,
    /// Estimated probability that this is the actual cause.
    pub probability: f64,
    /// Observations supporting this hypothesis.
    pub evidence: Vec<String>,
}
1464
/// A concrete fix the user could apply.
#[derive(Debug, Clone)]
pub struct SuggestedFix {
    /// What the fix does.
    pub description: String,
    /// How to apply it (e.g. code or configuration change).
    pub implementation: String,
    /// Confidence that the fix resolves the error.
    pub confidence: f64,
    /// Expected effect of applying the fix, in prose.
    pub estimated_impact: String,
}
1472
/// One step in a suggested diagnostic procedure.
#[derive(Debug, Clone)]
pub struct DebuggingStep {
    /// Position of this step in the procedure.
    pub step_number: u32,
    /// What to do in this step.
    pub description: String,
    /// Shell/tool command to run, when applicable.
    pub command: Option<String>,
    /// What the user should see if the step succeeds.
    pub expected_output: String,
}
1480
/// A pointer to external documentation relevant to an error.
#[derive(Debug, Clone)]
pub struct DocumentationReference {
    /// Title of the referenced page.
    pub title: String,
    /// Link to the page.
    pub url: String,
    /// How relevant the page is to the current error.
    pub relevance_score: f64,
}
1487
/// Snapshot of the analyzer's own performance counters
/// (derived from `AnalysisPerformanceMonitor`).
#[derive(Debug, Serialize, Deserialize)]
pub struct AnalysisPerformanceMetrics {
    /// Total number of analyses performed.
    pub total_analyses: u64,
    /// Mean wall-clock time per analysis.
    pub average_analysis_time: Duration,
    /// Share of analyses served from the cache.
    pub cache_hit_rate: f64,
    /// Number of entries currently held in the analysis cache.
    pub cached_results: usize,
}
1497
/// Convenience macro: runs a one-off analysis of `$code` in `$context`
/// with a freshly constructed, default-configured analyzer.
///
/// NOTE(review): the expansion contains `.await`, so this macro is only
/// usable inside an async context. It also names `AICodeAnalyzer` and
/// `AIAnalysisConfig` unqualified, so both must be in scope at the call
/// site — consider `$crate::…`-qualifying them for macro hygiene, since
/// `#[macro_export]` makes this callable from other crates.
#[macro_export]
macro_rules! ai_analyze {
    ($code:expr, $context:expr) => {{
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());
        analyzer.analyze_model_code($code, $context).await
    }};
}
1506
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing an analyzer keeps the supplied configuration.
    #[tokio::test]
    async fn test_ai_code_analyzer_creation() {
        let analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());
        assert!(analyzer.config.enable_deep_analysis);
    }

    /// Pattern recognition should flag both good patterns and anti-patterns
    /// in a snippet that contains one of each.
    #[tokio::test]
    async fn test_pattern_detection() {
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        let code = r#"
            import torch

            def train_step(model, data):
                torch.cuda.empty_cache() # Should trigger anti-pattern
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # Good pattern
                return grad_norm
        "#;

        let context = ModelContext {
            model_type: ModelType::Production,
            model_size: 1_000_000,
            framework: "PyTorch".to_string(),
            target_hardware: "CUDA".to_string(),
            training_stage: TrainingStage::Training,
        };

        let result = analyzer.analyze_model_code(code, context).await.unwrap();
        assert!(!result.detected_patterns.is_empty());
    }

    /// `pickle.load` on an external path should be reported as a
    /// code-execution vulnerability.
    #[tokio::test]
    async fn test_security_vulnerability_detection() {
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        let code = r#"
            import pickle

            def load_model(path):
                with open(path, 'rb') as f:
                    model = pickle.load(f) # Should trigger security warning
                return model
        "#;

        let context = ModelContext {
            model_type: ModelType::Production,
            model_size: 1_000_000,
            framework: "PyTorch".to_string(),
            target_hardware: "CUDA".to_string(),
            training_stage: TrainingStage::Inference,
        };

        let result = analyzer.analyze_model_code(code, context).await.unwrap();
        assert!(!result.security_issues.is_empty());
        assert_eq!(
            result.security_issues[0].vulnerability_type,
            VulnerabilityType::CodeExecution
        );
    }

    /// A MatMul immediately followed by an Add should be reported as a
    /// GEMM fusion opportunity.
    #[tokio::test]
    async fn test_tensor_operation_analysis() {
        let analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        let operations = vec![
            TensorOperation {
                name: "matmul1".to_string(),
                op_type: OperationType::MatMul,
                inputs: vec!["a".to_string(), "b".to_string()],
                outputs: vec!["c".to_string()],
                parameters: HashMap::new(),
                output_size_bytes: 1024,
                is_inplace: false,
            },
            TensorOperation {
                name: "add1".to_string(),
                op_type: OperationType::Add,
                inputs: vec!["c".to_string(), "bias".to_string()],
                outputs: vec!["d".to_string()],
                parameters: HashMap::new(),
                output_size_bytes: 1024,
                is_inplace: false,
            },
        ];

        let report = analyzer.analyze_tensor_operations(&operations).await.unwrap();
        assert!(!report.fusion_opportunities.is_empty());
        assert_eq!(report.fusion_opportunities[0].fusion_type, FusionType::GEMM);
    }

    /// Analyzing the same code twice should count two analyses and record
    /// at least one cache hit.
    #[tokio::test]
    async fn test_performance_metrics() {
        let mut analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        let code = "print('hello')";
        let context = ModelContext {
            model_type: ModelType::Development,
            model_size: 1000,
            framework: "PyTorch".to_string(),
            target_hardware: "CPU".to_string(),
            training_stage: TrainingStage::Development,
        };

        // First call populates the cache; the second should hit it.
        analyzer.analyze_model_code(code, context.clone()).await.unwrap();
        analyzer.analyze_model_code(code, context).await.unwrap();

        let metrics = analyzer.get_performance_metrics();
        assert_eq!(metrics.total_analyses, 2);
        assert!(metrics.cache_hit_rate > 0.0);
    }

    /// A CUDA OOM error with nearly full GPU memory should yield causes,
    /// fixes, and a non-zero confidence score.
    #[tokio::test]
    async fn test_debugging_assistance() {
        let analyzer = AICodeAnalyzer::new(AIAnalysisConfig::default());

        let error_context = ErrorContext {
            error_type: "OutOfMemoryError".to_string(),
            error_message: "CUDA out of memory".to_string(),
            stack_trace: None,
            system_info: SystemInfo {
                gpu_memory_total: 8_000_000_000,
                gpu_memory_used: 7_500_000_000,
                cpu_count: 8,
                ram_total: 32_000_000_000,
                ram_used: 16_000_000_000,
            },
            model_info: None,
        };

        let assistance = analyzer.automated_debugging_assistance(&error_context).await.unwrap();
        assert!(!assistance.probable_causes.is_empty());
        assert!(!assistance.suggested_fixes.is_empty());
        assert!(assistance.confidence_score > 0.0);
    }
}