// skill_runtime/generation/evaluation.rs
use std::collections::HashMap;
7use crate::skill_md::ToolDocumentation;
8use super::streaming::GeneratedExample;
9use super::validator::ExampleValidator;
10
/// Aggregate accuracy metrics across every generated example, over all tools.
///
/// Fields hold raw counts; derived rates (validation, param compliance, etc.)
/// are computed by the methods on this type.
#[derive(Debug, Clone, Default)]
pub struct AccuracyMetrics {
    /// Total number of examples generated across all tools.
    pub total_generated: usize,

    /// Examples whose command passed full schema validation.
    pub schema_valid: usize,

    /// Examples whose parsed command contained every required parameter.
    pub required_params_present: usize,

    /// Examples counted as type-correct (see `AccuracyEvaluator::evaluate_tool`).
    pub type_correct: usize,

    /// Examples carrying a non-empty explanation string.
    pub has_explanation: usize,

    /// Diversity score over the whole example pool, computed by the validator.
    pub diversity_score: f32,

    /// Per-tool breakdown, keyed by tool name.
    pub per_tool: HashMap<String, ToolMetrics>,

    /// Count of validation errors bucketed by category (see `categorize_error`).
    pub error_breakdown: HashMap<String, usize>,
}
42
43impl AccuracyMetrics {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn validation_rate(&self) -> f32 {
51 if self.total_generated == 0 {
52 return 0.0;
53 }
54 self.schema_valid as f32 / self.total_generated as f32
55 }
56
57 pub fn param_compliance_rate(&self) -> f32 {
59 if self.total_generated == 0 {
60 return 0.0;
61 }
62 self.required_params_present as f32 / self.total_generated as f32
63 }
64
65 pub fn type_correctness_rate(&self) -> f32 {
67 if self.total_generated == 0 {
68 return 0.0;
69 }
70 self.type_correct as f32 / self.total_generated as f32
71 }
72
73 pub fn explanation_rate(&self) -> f32 {
75 if self.total_generated == 0 {
76 return 0.0;
77 }
78 self.has_explanation as f32 / self.total_generated as f32
79 }
80
81 pub fn overall_quality(&self) -> f32 {
83 let weights = [
84 (self.validation_rate(), 0.4), (self.param_compliance_rate(), 0.25), (self.type_correctness_rate(), 0.15), (self.explanation_rate(), 0.1), (self.diversity_score, 0.1), ];
90
91 weights.iter().map(|(rate, weight)| rate * weight).sum()
92 }
93
94 pub fn meets_threshold(&self, threshold: f32) -> bool {
96 self.validation_rate() >= threshold
97 }
98
99 pub fn add_tool_metrics(&mut self, tool_name: &str, metrics: ToolMetrics) {
101 self.total_generated += metrics.total_generated;
102 self.schema_valid += metrics.schema_valid;
103 self.required_params_present += metrics.required_params_present;
104 self.type_correct += metrics.type_correct;
105 self.has_explanation += metrics.has_explanation;
106
107 for (error_type, count) in &metrics.error_breakdown {
109 *self.error_breakdown.entry(error_type.clone()).or_insert(0) += count;
110 }
111
112 self.per_tool.insert(tool_name.to_string(), metrics);
113 }
114
115 pub fn summary(&self) -> String {
117 format!(
118 "Accuracy Metrics:\n\
119 - Total Generated: {}\n\
120 - Schema Valid: {} ({:.1}%)\n\
121 - Param Compliance: {:.1}%\n\
122 - Type Correct: {:.1}%\n\
123 - Has Explanation: {:.1}%\n\
124 - Diversity: {:.2}\n\
125 - Overall Quality: {:.2}",
126 self.total_generated,
127 self.schema_valid,
128 self.validation_rate() * 100.0,
129 self.param_compliance_rate() * 100.0,
130 self.type_correctness_rate() * 100.0,
131 self.explanation_rate() * 100.0,
132 self.diversity_score,
133 self.overall_quality()
134 )
135 }
136}
137
/// Accuracy metrics for a single tool's generated examples.
#[derive(Debug, Clone, Default)]
pub struct ToolMetrics {
    /// Name of the tool these metrics describe.
    pub tool_name: String,

    /// Number of examples generated for this tool.
    pub total_generated: usize,

    /// Examples that passed full schema validation.
    pub schema_valid: usize,

    /// Examples whose parsed command contained every required parameter.
    pub required_params_present: usize,

    /// Examples counted as type-correct.
    pub type_correct: usize,

    /// Examples carrying a non-empty explanation string.
    pub has_explanation: usize,

    /// Count of validation errors bucketed by category (see `categorize_error`).
    pub error_breakdown: HashMap<String, usize>,

    /// Mean confidence over the examples (0.0 when no examples were scored).
    pub avg_confidence: f32,
}
165
166impl ToolMetrics {
167 pub fn new(tool_name: &str) -> Self {
169 Self {
170 tool_name: tool_name.to_string(),
171 ..Default::default()
172 }
173 }
174
175 pub fn validation_rate(&self) -> f32 {
177 if self.total_generated == 0 {
178 return 0.0;
179 }
180 self.schema_valid as f32 / self.total_generated as f32
181 }
182
183 pub fn type_correctness_rate(&self) -> f32 {
185 if self.total_generated == 0 {
186 return 0.0;
187 }
188 self.type_correct as f32 / self.total_generated as f32
189 }
190
191 pub fn param_compliance_rate(&self) -> f32 {
193 if self.total_generated == 0 {
194 return 0.0;
195 }
196 self.required_params_present as f32 / self.total_generated as f32
197 }
198}
199
/// Evaluates generated examples against tool documentation, producing
/// per-tool and aggregate accuracy metrics.
pub struct AccuracyEvaluator {
    // Validator used for schema checks, command parsing, and diversity scoring.
    validator: ExampleValidator,
}
208
209impl AccuracyEvaluator {
210 pub fn new() -> Self {
212 Self {
213 validator: ExampleValidator::new(),
214 }
215 }
216
217 pub fn strict() -> Self {
219 Self {
220 validator: ExampleValidator::strict(),
221 }
222 }
223
224 pub fn evaluate_tool(
226 &self,
227 tool: &ToolDocumentation,
228 examples: &[GeneratedExample],
229 ) -> ToolMetrics {
230 let mut metrics = ToolMetrics::new(&tool.name);
231 metrics.total_generated = examples.len();
232
233 let mut total_confidence = 0.0;
234
235 for example in examples {
236 if !example.explanation.trim().is_empty() {
238 metrics.has_explanation += 1;
239 }
240
241 let validation = self.validator.validate_example(example, tool);
243
244 if validation.valid {
245 metrics.schema_valid += 1;
246 }
247
248 let parsed = self.validator.parse_command(&example.command);
250 if let Ok(parsed) = parsed {
251 let has_all_required = tool.parameters.iter()
252 .filter(|p| p.required)
253 .all(|p| parsed.has_param(&p.name));
254
255 if has_all_required {
256 metrics.required_params_present += 1;
257 }
258
259 if validation.valid {
262 metrics.type_correct += 1;
263 }
264 }
265
266 for error in &validation.errors {
268 let error_type = categorize_error(error);
269 *metrics.error_breakdown.entry(error_type).or_insert(0) += 1;
270 }
271
272 total_confidence += example.confidence;
273 }
274
275 if !examples.is_empty() {
276 metrics.avg_confidence = total_confidence / examples.len() as f32;
277 }
278
279 metrics
280 }
281
282 pub fn evaluate_batch(
284 &self,
285 tools: &[ToolDocumentation],
286 examples_by_tool: &HashMap<String, Vec<GeneratedExample>>,
287 ) -> AccuracyMetrics {
288 let mut metrics = AccuracyMetrics::new();
289
290 for tool in tools {
291 if let Some(examples) = examples_by_tool.get(&tool.name) {
292 let tool_metrics = self.evaluate_tool(tool, examples);
293 metrics.add_tool_metrics(&tool.name, tool_metrics);
294 }
295 }
296
297 let all_examples: Vec<_> = examples_by_tool.values()
299 .flat_map(|v| v.iter())
300 .cloned()
301 .collect();
302 metrics.diversity_score = self.validator.calculate_diversity(&all_examples);
303
304 metrics
305 }
306
307 pub fn evaluate_with_threshold(
309 &self,
310 tool: &ToolDocumentation,
311 examples: &[GeneratedExample],
312 threshold: f32,
313 ) -> (bool, ToolMetrics) {
314 let metrics = self.evaluate_tool(tool, examples);
315 let passes = metrics.validation_rate() >= threshold;
316 (passes, metrics)
317 }
318}
319
320impl Default for AccuracyEvaluator {
321 fn default() -> Self {
322 Self::new()
323 }
324}
325
/// Buckets a validation error message into a coarse category used as an
/// `error_breakdown` key. Matching is case-insensitive and the first
/// matching bucket (in the order listed) wins.
fn categorize_error(error: &str) -> String {
    let lower = error.to_lowercase();
    let category = if lower.contains("required") || lower.contains("missing") {
        "missing_required"
    } else if lower.contains("type") || lower.contains("expected") {
        "type_mismatch"
    } else if lower.contains("parse") {
        "parse_error"
    } else if lower.contains("explanation") {
        "empty_explanation"
    } else {
        "other"
    };
    category.to_string()
}
341
/// Timing and throughput metrics for a generation run.
#[derive(Debug, Clone, Default)]
pub struct PerformanceMetrics {
    /// Wall-clock time for the whole run, in milliseconds.
    pub total_time_ms: u64,

    /// Per-tool generation time in milliseconds, keyed by tool name.
    pub per_tool_time_ms: HashMap<String, u64>,

    /// Milliseconds until the first streamed event, if one was observed.
    pub time_to_first_event_ms: Option<u64>,

    /// Observed event throughput (events per second).
    pub events_per_second: f32,

    /// Total number of streamed events.
    pub total_events: usize,
}
364
365impl PerformanceMetrics {
366 pub fn new() -> Self {
368 Self::default()
369 }
370
371 pub fn avg_time_per_tool(&self) -> u64 {
373 if self.per_tool_time_ms.is_empty() {
374 return 0;
375 }
376 let total: u64 = self.per_tool_time_ms.values().sum();
377 total / self.per_tool_time_ms.len() as u64
378 }
379
380 pub fn meets_latency_threshold(&self, max_ms_per_tool: u64) -> bool {
382 self.per_tool_time_ms.values().all(|&t| t <= max_ms_per_tool)
383 }
384
385 pub fn summary(&self) -> String {
387 format!(
388 "Performance Metrics:\n\
389 - Total Time: {}ms\n\
390 - Avg per Tool: {}ms\n\
391 - Time to First Event: {:?}ms\n\
392 - Events/sec: {:.1}\n\
393 - Total Events: {}",
394 self.total_time_ms,
395 self.avg_time_per_tool(),
396 self.time_to_first_event_ms,
397 self.events_per_second,
398 self.total_events
399 )
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::fixtures::*;

    // Derived rates must reflect the raw counters exactly, and the weighted
    // overall quality should land high when every rate is high.
    #[test]
    fn test_accuracy_metrics_calculation() {
        let mut metrics = AccuracyMetrics::new();
        metrics.total_generated = 10;
        metrics.schema_valid = 9;
        metrics.required_params_present = 10;
        metrics.type_correct = 8;
        metrics.has_explanation = 10;
        metrics.diversity_score = 0.75;

        assert!((metrics.validation_rate() - 0.9).abs() < 0.01);
        assert!((metrics.param_compliance_rate() - 1.0).abs() < 0.01);
        assert!((metrics.type_correctness_rate() - 0.8).abs() < 0.01);
        assert!(metrics.overall_quality() > 0.8);
    }

    // With nothing generated, every rate must be 0.0 (no division by zero).
    #[test]
    fn test_empty_metrics() {
        let metrics = AccuracyMetrics::new();
        assert_eq!(metrics.validation_rate(), 0.0);
        assert_eq!(metrics.param_compliance_rate(), 0.0);
        assert_eq!(metrics.overall_quality(), 0.0);
    }

    // Threshold comparison is inclusive: exactly 0.95 passes, 0.96 does not.
    #[test]
    fn test_meets_threshold() {
        let mut metrics = AccuracyMetrics::new();
        metrics.total_generated = 100;
        metrics.schema_valid = 95;

        assert!(metrics.meets_threshold(0.95));
        assert!(!metrics.meets_threshold(0.96));
    }

    // ToolMetrics stores the tool name and derives its rate from counters.
    #[test]
    fn test_tool_metrics() {
        let mut metrics = ToolMetrics::new("apply");
        metrics.total_generated = 5;
        metrics.schema_valid = 4;

        assert_eq!(metrics.tool_name, "apply");
        assert!((metrics.validation_rate() - 0.8).abs() < 0.01);
    }

    // Well-formed examples (required --file present, non-empty explanations)
    // should validate and count toward explanation coverage.
    #[test]
    fn test_evaluator_with_valid_examples() {
        let evaluator = AccuracyEvaluator::new();
        let tool = kubernetes_apply_tool();

        let examples = vec![
            GeneratedExample::new(
                "skill run kubernetes:apply --file=deploy.yaml",
                "Apply deployment manifest"
            ).with_confidence(0.9),
            GeneratedExample::new(
                "skill run kubernetes:apply --file=service.yaml --namespace=prod",
                "Apply to production"
            ).with_confidence(0.85),
        ];

        let metrics = evaluator.evaluate_tool(&tool, &examples);

        assert_eq!(metrics.total_generated, 2);
        assert!(metrics.validation_rate() > 0.0);
        assert!(metrics.has_explanation > 0);
    }

    // One example misses the required --file param, the other has an empty
    // explanation; both defects must show up in the counters.
    #[test]
    fn test_evaluator_with_invalid_examples() {
        let evaluator = AccuracyEvaluator::new();
        let tool = kubernetes_apply_tool();

        let examples = vec![
            GeneratedExample::new(
                "skill run kubernetes:apply --namespace=prod",
                "Missing file param"
            ),
            GeneratedExample::new(
                "skill run kubernetes:apply --file=test.yaml",
                ""
            ),
        ];

        let metrics = evaluator.evaluate_tool(&tool, &examples);

        assert_eq!(metrics.total_generated, 2);
        assert!(metrics.schema_valid < 2);
        assert_eq!(metrics.has_explanation, 1); }

    // Each message lands in the expected bucket; unmatched text falls to "other".
    #[test]
    fn test_error_categorization() {
        assert_eq!(categorize_error("Missing required parameter: file"), "missing_required");
        assert_eq!(categorize_error("expected integer, got 'abc'"), "type_mismatch");
        assert_eq!(categorize_error("Failed to parse command"), "parse_error");
        assert_eq!(categorize_error("explanation is empty"), "empty_explanation");
        assert_eq!(categorize_error("unknown error"), "other");
    }

    // Average is over per-tool entries only; the latency check is inclusive
    // at the boundary (2000 passes, 1500 fails because "get" took 2000ms).
    #[test]
    fn test_performance_metrics() {
        let mut metrics = PerformanceMetrics::new();
        metrics.total_time_ms = 5000;
        metrics.per_tool_time_ms.insert("apply".to_string(), 1000);
        metrics.per_tool_time_ms.insert("get".to_string(), 2000);
        metrics.total_events = 50;
        metrics.events_per_second = 10.0;

        assert_eq!(metrics.avg_time_per_tool(), 1500);
        assert!(metrics.meets_latency_threshold(2000));
        assert!(!metrics.meets_latency_threshold(1500));
    }

    // Batch evaluation aggregates per-tool results keyed by tool name and
    // computes a non-zero diversity score over the pooled examples.
    #[test]
    fn test_batch_evaluation() {
        let evaluator = AccuracyEvaluator::new();

        let tools = vec![
            kubernetes_apply_tool(),
            simple_tool(),
        ];

        let mut examples_by_tool = HashMap::new();
        examples_by_tool.insert(
            "apply".to_string(),
            vec![GeneratedExample::new("skill run kubernetes:apply --file=test.yaml", "Test")],
        );
        examples_by_tool.insert(
            "list".to_string(),
            vec![GeneratedExample::new("skill run tool:list --type=pods", "List pods")],
        );

        let metrics = evaluator.evaluate_batch(&tools, &examples_by_tool);

        assert_eq!(metrics.total_generated, 2);
        assert_eq!(metrics.per_tool.len(), 2);
        assert!(metrics.diversity_score > 0.0);
    }
}