// skill_runtime/generation/evaluation.rs

1//! Accuracy evaluation metrics for AI-generated examples
2//!
3//! Provides metrics for measuring the quality, accuracy, and diversity
4//! of generated examples against tool schemas.
5
6use std::collections::HashMap;
7use crate::skill_md::ToolDocumentation;
8use super::streaming::GeneratedExample;
9use super::validator::ExampleValidator;
10
11// =============================================================================
12// Accuracy Metrics
13// =============================================================================
14
/// Comprehensive accuracy metrics for a batch of generated examples.
///
/// All counter fields are absolute counts out of `total_generated`; the
/// rate methods on the impl convert them to normalized 0.0-1.0 values.
#[derive(Debug, Clone, Default)]
pub struct AccuracyMetrics {
    /// Total number of examples generated
    pub total_generated: usize,

    /// Number that passed schema validation
    pub schema_valid: usize,

    /// Number with all required parameters present
    pub required_params_present: usize,

    /// Number with correct parameter types
    pub type_correct: usize,

    /// Number with non-empty explanations
    pub has_explanation: usize,

    /// Diversity score across all examples (Jaccard-based, 0.0-1.0),
    /// computed by the validator rather than from the counters above
    pub diversity_score: f32,

    /// Per-tool breakdown, keyed by tool name
    pub per_tool: HashMap<String, ToolMetrics>,

    /// Validation errors aggregated across all tools, keyed by the
    /// category string produced by `categorize_error`
    pub error_breakdown: HashMap<String, usize>,
}
42
43impl AccuracyMetrics {
44    /// Create new empty metrics
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Calculate schema validation rate (0.0-1.0)
50    pub fn validation_rate(&self) -> f32 {
51        if self.total_generated == 0 {
52            return 0.0;
53        }
54        self.schema_valid as f32 / self.total_generated as f32
55    }
56
57    /// Calculate required parameter compliance rate (0.0-1.0)
58    pub fn param_compliance_rate(&self) -> f32 {
59        if self.total_generated == 0 {
60            return 0.0;
61        }
62        self.required_params_present as f32 / self.total_generated as f32
63    }
64
65    /// Calculate type correctness rate (0.0-1.0)
66    pub fn type_correctness_rate(&self) -> f32 {
67        if self.total_generated == 0 {
68            return 0.0;
69        }
70        self.type_correct as f32 / self.total_generated as f32
71    }
72
73    /// Calculate explanation coverage rate (0.0-1.0)
74    pub fn explanation_rate(&self) -> f32 {
75        if self.total_generated == 0 {
76            return 0.0;
77        }
78        self.has_explanation as f32 / self.total_generated as f32
79    }
80
81    /// Calculate overall quality score (weighted average, 0.0-1.0)
82    pub fn overall_quality(&self) -> f32 {
83        let weights = [
84            (self.validation_rate(), 0.4),       // Schema validation is most important
85            (self.param_compliance_rate(), 0.25), // Required params
86            (self.type_correctness_rate(), 0.15), // Type correctness
87            (self.explanation_rate(), 0.1),       // Has explanation
88            (self.diversity_score, 0.1),          // Diversity
89        ];
90
91        weights.iter().map(|(rate, weight)| rate * weight).sum()
92    }
93
94    /// Check if metrics meet minimum quality threshold
95    pub fn meets_threshold(&self, threshold: f32) -> bool {
96        self.validation_rate() >= threshold
97    }
98
99    /// Add metrics for a tool
100    pub fn add_tool_metrics(&mut self, tool_name: &str, metrics: ToolMetrics) {
101        self.total_generated += metrics.total_generated;
102        self.schema_valid += metrics.schema_valid;
103        self.required_params_present += metrics.required_params_present;
104        self.type_correct += metrics.type_correct;
105        self.has_explanation += metrics.has_explanation;
106
107        // Aggregate error breakdown
108        for (error_type, count) in &metrics.error_breakdown {
109            *self.error_breakdown.entry(error_type.clone()).or_insert(0) += count;
110        }
111
112        self.per_tool.insert(tool_name.to_string(), metrics);
113    }
114
115    /// Format as a summary string
116    pub fn summary(&self) -> String {
117        format!(
118            "Accuracy Metrics:\n\
119             - Total Generated: {}\n\
120             - Schema Valid: {} ({:.1}%)\n\
121             - Param Compliance: {:.1}%\n\
122             - Type Correct: {:.1}%\n\
123             - Has Explanation: {:.1}%\n\
124             - Diversity: {:.2}\n\
125             - Overall Quality: {:.2}",
126            self.total_generated,
127            self.schema_valid,
128            self.validation_rate() * 100.0,
129            self.param_compliance_rate() * 100.0,
130            self.type_correctness_rate() * 100.0,
131            self.explanation_rate() * 100.0,
132            self.diversity_score,
133            self.overall_quality()
134        )
135    }
136}
137
/// Metrics for a single tool's generated examples.
///
/// Produced by `AccuracyEvaluator::evaluate_tool` and aggregated into
/// `AccuracyMetrics` via `add_tool_metrics`.
#[derive(Debug, Clone, Default)]
pub struct ToolMetrics {
    /// Tool name
    pub tool_name: String,

    /// Total examples generated for this tool
    pub total_generated: usize,

    /// Examples that passed validation
    pub schema_valid: usize,

    /// Examples with all required params
    pub required_params_present: usize,

    /// Examples with correct types
    pub type_correct: usize,

    /// Examples with non-empty explanations
    pub has_explanation: usize,

    /// Error counts for this tool, keyed by the category string
    /// produced by `categorize_error`
    pub error_breakdown: HashMap<String, usize>,

    /// Average confidence score over this tool's examples
    /// (0.0 when no examples were generated)
    pub avg_confidence: f32,
}
165
166impl ToolMetrics {
167    /// Create new metrics for a tool
168    pub fn new(tool_name: &str) -> Self {
169        Self {
170            tool_name: tool_name.to_string(),
171            ..Default::default()
172        }
173    }
174
175    /// Calculate validation rate
176    pub fn validation_rate(&self) -> f32 {
177        if self.total_generated == 0 {
178            return 0.0;
179        }
180        self.schema_valid as f32 / self.total_generated as f32
181    }
182
183    /// Calculate type correctness rate
184    pub fn type_correctness_rate(&self) -> f32 {
185        if self.total_generated == 0 {
186            return 0.0;
187        }
188        self.type_correct as f32 / self.total_generated as f32
189    }
190
191    /// Calculate required param compliance rate
192    pub fn param_compliance_rate(&self) -> f32 {
193        if self.total_generated == 0 {
194            return 0.0;
195        }
196        self.required_params_present as f32 / self.total_generated as f32
197    }
198}
199
200// =============================================================================
201// Accuracy Evaluator
202// =============================================================================
203
/// Evaluator for measuring accuracy of generated examples.
///
/// Thin wrapper around an [`ExampleValidator`]; construct with `new` for
/// default validation or `strict` for strict validation.
pub struct AccuracyEvaluator {
    // Used for schema validation, command parsing, and diversity scoring.
    validator: ExampleValidator,
}
208
209impl AccuracyEvaluator {
210    /// Create a new evaluator
211    pub fn new() -> Self {
212        Self {
213            validator: ExampleValidator::new(),
214        }
215    }
216
217    /// Create with strict validation
218    pub fn strict() -> Self {
219        Self {
220            validator: ExampleValidator::strict(),
221        }
222    }
223
224    /// Evaluate a batch of examples for a single tool
225    pub fn evaluate_tool(
226        &self,
227        tool: &ToolDocumentation,
228        examples: &[GeneratedExample],
229    ) -> ToolMetrics {
230        let mut metrics = ToolMetrics::new(&tool.name);
231        metrics.total_generated = examples.len();
232
233        let mut total_confidence = 0.0;
234
235        for example in examples {
236            // Check for explanation
237            if !example.explanation.trim().is_empty() {
238                metrics.has_explanation += 1;
239            }
240
241            // Validate example
242            let validation = self.validator.validate_example(example, tool);
243
244            if validation.valid {
245                metrics.schema_valid += 1;
246            }
247
248            // Check required params (more specific than full validation)
249            let parsed = self.validator.parse_command(&example.command);
250            if let Ok(parsed) = parsed {
251                let has_all_required = tool.parameters.iter()
252                    .filter(|p| p.required)
253                    .all(|p| parsed.has_param(&p.name));
254
255                if has_all_required {
256                    metrics.required_params_present += 1;
257                }
258
259                // Type checking would require running validation on each param
260                // For now, count valid examples as type-correct
261                if validation.valid {
262                    metrics.type_correct += 1;
263                }
264            }
265
266            // Track errors
267            for error in &validation.errors {
268                let error_type = categorize_error(error);
269                *metrics.error_breakdown.entry(error_type).or_insert(0) += 1;
270            }
271
272            total_confidence += example.confidence;
273        }
274
275        if !examples.is_empty() {
276            metrics.avg_confidence = total_confidence / examples.len() as f32;
277        }
278
279        metrics
280    }
281
282    /// Evaluate examples for multiple tools
283    pub fn evaluate_batch(
284        &self,
285        tools: &[ToolDocumentation],
286        examples_by_tool: &HashMap<String, Vec<GeneratedExample>>,
287    ) -> AccuracyMetrics {
288        let mut metrics = AccuracyMetrics::new();
289
290        for tool in tools {
291            if let Some(examples) = examples_by_tool.get(&tool.name) {
292                let tool_metrics = self.evaluate_tool(tool, examples);
293                metrics.add_tool_metrics(&tool.name, tool_metrics);
294            }
295        }
296
297        // Calculate diversity across all examples
298        let all_examples: Vec<_> = examples_by_tool.values()
299            .flat_map(|v| v.iter())
300            .cloned()
301            .collect();
302        metrics.diversity_score = self.validator.calculate_diversity(&all_examples);
303
304        metrics
305    }
306
307    /// Evaluate a single tool and return pass/fail with detailed results
308    pub fn evaluate_with_threshold(
309        &self,
310        tool: &ToolDocumentation,
311        examples: &[GeneratedExample],
312        threshold: f32,
313    ) -> (bool, ToolMetrics) {
314        let metrics = self.evaluate_tool(tool, examples);
315        let passes = metrics.validation_rate() >= threshold;
316        (passes, metrics)
317    }
318}
319
320impl Default for AccuracyEvaluator {
321    fn default() -> Self {
322        Self::new()
323    }
324}
325
/// Categorize a validation error message into a coarse error-type key.
///
/// Matching is case-insensitive and checked in priority order, so a message
/// mentioning both "required" and "type" is bucketed as `missing_required`.
/// Messages matching no rule fall through to `other`.
fn categorize_error(error: &str) -> String {
    let msg = error.to_lowercase();

    // (category, keywords) pairs, checked in order; first keyword hit wins.
    let rules: [(&str, &[&str]); 4] = [
        ("missing_required", &["required", "missing"]),
        ("type_mismatch", &["type", "expected"]),
        ("parse_error", &["parse"]),
        ("empty_explanation", &["explanation"]),
    ];

    rules
        .iter()
        .find(|(_, keywords)| keywords.iter().any(|kw| msg.contains(kw)))
        .map(|(category, _)| (*category).to_string())
        .unwrap_or_else(|| "other".to_string())
}
341
342// =============================================================================
343// Performance Metrics
344// =============================================================================
345
/// Performance metrics for generation.
///
/// All timings are in milliseconds; counters default to zero and
/// `time_to_first_event_ms` to `None` via `Default`.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
    /// Total time for all generation (ms)
    pub total_time_ms: u64,

    /// Time per tool (ms), keyed by tool name
    pub per_tool_time_ms: HashMap<String, u64>,

    /// Time to first event (ms); `None` if no event has been observed
    pub time_to_first_event_ms: Option<u64>,

    /// Events per second
    pub events_per_second: f32,

    /// Total events emitted
    pub total_events: usize,
}
364
365impl PerformanceMetrics {
366    /// Create new empty metrics
367    pub fn new() -> Self {
368        Self::default()
369    }
370
371    /// Calculate average time per tool
372    pub fn avg_time_per_tool(&self) -> u64 {
373        if self.per_tool_time_ms.is_empty() {
374            return 0;
375        }
376        let total: u64 = self.per_tool_time_ms.values().sum();
377        total / self.per_tool_time_ms.len() as u64
378    }
379
380    /// Check if meets latency threshold
381    pub fn meets_latency_threshold(&self, max_ms_per_tool: u64) -> bool {
382        self.per_tool_time_ms.values().all(|&t| t <= max_ms_per_tool)
383    }
384
385    /// Format as summary string
386    pub fn summary(&self) -> String {
387        format!(
388            "Performance Metrics:\n\
389             - Total Time: {}ms\n\
390             - Avg per Tool: {}ms\n\
391             - Time to First Event: {:?}ms\n\
392             - Events/sec: {:.1}\n\
393             - Total Events: {}",
394            self.total_time_ms,
395            self.avg_time_per_tool(),
396            self.time_to_first_event_ms,
397            self.events_per_second,
398            self.total_events
399        )
400    }
401}
402
403// =============================================================================
404// Tests
405// =============================================================================
406
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): fixtures such as `kubernetes_apply_tool()` and
    // `simple_tool()` are defined in the sibling fixtures module; the
    // commands below imply `apply` requires a `file` parameter — confirm
    // against that module.
    use super::super::fixtures::*;

    // Rate methods should reflect the raw counters exactly, and the
    // weighted overall quality should land above 0.8 for these values.
    #[test]
    fn test_accuracy_metrics_calculation() {
        let mut metrics = AccuracyMetrics::new();
        metrics.total_generated = 10;
        metrics.schema_valid = 9;
        metrics.required_params_present = 10;
        metrics.type_correct = 8;
        metrics.has_explanation = 10;
        metrics.diversity_score = 0.75;

        assert!((metrics.validation_rate() - 0.9).abs() < 0.01);
        assert!((metrics.param_compliance_rate() - 1.0).abs() < 0.01);
        assert!((metrics.type_correctness_rate() - 0.8).abs() < 0.01);
        assert!(metrics.overall_quality() > 0.8);
    }

    // With zero examples generated, every rate must be 0.0 (no division
    // by zero).
    #[test]
    fn test_empty_metrics() {
        let metrics = AccuracyMetrics::new();
        assert_eq!(metrics.validation_rate(), 0.0);
        assert_eq!(metrics.param_compliance_rate(), 0.0);
        assert_eq!(metrics.overall_quality(), 0.0);
    }

    // meets_threshold compares only the schema validation rate:
    // 95/100 passes at exactly 0.95 (>=) but not at 0.96.
    #[test]
    fn test_meets_threshold() {
        let mut metrics = AccuracyMetrics::new();
        metrics.total_generated = 100;
        metrics.schema_valid = 95;

        assert!(metrics.meets_threshold(0.95));
        assert!(!metrics.meets_threshold(0.96));
    }

    // ToolMetrics::new stores the name; validation rate is 4/5 = 0.8.
    #[test]
    fn test_tool_metrics() {
        let mut metrics = ToolMetrics::new("apply");
        metrics.total_generated = 5;
        metrics.schema_valid = 4;

        assert_eq!(metrics.tool_name, "apply");
        assert!((metrics.validation_rate() - 0.8).abs() < 0.01);
    }

    // Well-formed examples (required --file present, non-empty
    // explanations) should validate and be counted.
    #[test]
    fn test_evaluator_with_valid_examples() {
        let evaluator = AccuracyEvaluator::new();
        let tool = kubernetes_apply_tool();

        let examples = vec![
            GeneratedExample::new(
                "skill run kubernetes:apply --file=deploy.yaml",
                "Apply deployment manifest"
            ).with_confidence(0.9),
            GeneratedExample::new(
                "skill run kubernetes:apply --file=service.yaml --namespace=prod",
                "Apply to production"
            ).with_confidence(0.85),
        ];

        let metrics = evaluator.evaluate_tool(&tool, &examples);

        assert_eq!(metrics.total_generated, 2);
        assert!(metrics.validation_rate() > 0.0);
        assert!(metrics.has_explanation > 0);
    }

    // Defective examples: one missing the required param, one with an
    // empty explanation — at least one must fail validation.
    #[test]
    fn test_evaluator_with_invalid_examples() {
        let evaluator = AccuracyEvaluator::new();
        let tool = kubernetes_apply_tool();

        let examples = vec![
            // Missing required 'file' parameter
            GeneratedExample::new(
                "skill run kubernetes:apply --namespace=prod",
                "Missing file param"
            ),
            // Empty explanation
            GeneratedExample::new(
                "skill run kubernetes:apply --file=test.yaml",
                ""
            ),
        ];

        let metrics = evaluator.evaluate_tool(&tool, &examples);

        assert_eq!(metrics.total_generated, 2);
        // Both should fail - one missing required param, one empty explanation
        assert!(metrics.schema_valid < 2);
        assert_eq!(metrics.has_explanation, 1); // Only first has explanation
    }

    // One message per category, plus the fall-through case.
    #[test]
    fn test_error_categorization() {
        assert_eq!(categorize_error("Missing required parameter: file"), "missing_required");
        assert_eq!(categorize_error("expected integer, got 'abc'"), "type_mismatch");
        assert_eq!(categorize_error("Failed to parse command"), "parse_error");
        assert_eq!(categorize_error("explanation is empty"), "empty_explanation");
        assert_eq!(categorize_error("unknown error"), "other");
    }

    // avg = (1000 + 2000) / 2 = 1500; threshold check is an all() over
    // recorded per-tool times, so 1500 fails against the 2000ms entry.
    #[test]
    fn test_performance_metrics() {
        let mut metrics = PerformanceMetrics::new();
        metrics.total_time_ms = 5000;
        metrics.per_tool_time_ms.insert("apply".to_string(), 1000);
        metrics.per_tool_time_ms.insert("get".to_string(), 2000);
        metrics.total_events = 50;
        metrics.events_per_second = 10.0;

        assert_eq!(metrics.avg_time_per_tool(), 1500);
        assert!(metrics.meets_latency_threshold(2000));
        assert!(!metrics.meets_latency_threshold(1500));
    }

    // Two tools, one example each: aggregation should count both,
    // produce two per-tool entries, and yield a nonzero diversity score
    // for two distinct commands.
    #[test]
    fn test_batch_evaluation() {
        let evaluator = AccuracyEvaluator::new();

        let tools = vec![
            kubernetes_apply_tool(),
            simple_tool(),
        ];

        let mut examples_by_tool = HashMap::new();
        examples_by_tool.insert(
            "apply".to_string(),
            vec![GeneratedExample::new("skill run kubernetes:apply --file=test.yaml", "Test")],
        );
        examples_by_tool.insert(
            "list".to_string(),
            vec![GeneratedExample::new("skill run tool:list --type=pods", "List pods")],
        );

        let metrics = evaluator.evaluate_batch(&tools, &examples_by_tool);

        assert_eq!(metrics.total_generated, 2);
        assert_eq!(metrics.per_tool.len(), 2);
        assert!(metrics.diversity_score > 0.0);
    }
}