skill_runtime/generation/validator.rs

1//! Example validator for AI-generated tool examples
2//!
3//! Validates generated examples against tool parameter schemas and
4//! checks for diversity across examples.
5
6use std::collections::HashMap;
7use crate::skill_md::{ToolDocumentation, ParameterDoc, ParameterType};
8use super::streaming::GeneratedExample;
9
/// Outcome of checking one generated example against its tool's documentation.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the example passed validation
    pub valid: bool,
    /// Problems that make the example unusable
    pub errors: Vec<String>,
    /// Non-fatal issues worth surfacing to the caller
    pub warnings: Vec<String>,
    /// Confidence score after validation adjustments
    pub confidence: f32,
}

impl ValidationResult {
    /// Build a passing result with no errors or warnings and the given confidence.
    pub fn valid(confidence: f32) -> Self {
        Self {
            confidence,
            valid: true,
            warnings: Vec::new(),
            errors: Vec::new(),
        }
    }

    /// Build a failing result carrying `errors`; the confidence is forced to `0.0`.
    pub fn invalid(errors: Vec<String>) -> Self {
        Self {
            valid: false,
            errors,
            ..Self::valid(0.0)
        }
    }

    /// Append a warning and hand the result back, for builder-style chaining.
    pub fn with_warning(self, warning: impl Into<String>) -> Self {
        let mut result = self;
        result.warnings.push(warning.into());
        result
    }
}
50
/// A command string broken into its skill, tool, and argument parts.
#[derive(Debug, Clone, Default)]
pub struct ParsedCommand {
    /// Skill name taken from a `skill:tool` identifier (e.g., "kubernetes")
    pub skill: Option<String>,
    /// Tool name (e.g., "apply")
    pub tool: Option<String>,
    /// Positional arguments
    pub positional: Vec<String>,
    /// Named parameters (`--param=value`, `--param value`, `-p value`, or `param=value`)
    pub parameters: HashMap<String, String>,
    /// Options given without a value (`--flag`, `-f`)
    pub flags: Vec<String>,
}

impl ParsedCommand {
    /// Whether `name` was supplied, either as a parameter with a value or as a bare flag.
    pub fn has_param(&self, name: &str) -> bool {
        // Compare as `&str` so a lookup never allocates a temporary `String`.
        self.parameters.contains_key(name) || self.flags.iter().any(|flag| flag.as_str() == name)
    }

    /// Value of the named parameter, or `None` if it was absent or given as a bare flag.
    pub fn get_param(&self, name: &str) -> Option<&String> {
        self.parameters.get(name)
    }
}
77
/// Checks generated examples against tool documentation and scores how varied a batch
/// of examples is.
#[derive(Debug, Clone)]
pub struct ExampleValidator {
    /// Minimum diversity score a batch must reach for `check_diversity` to pass
    pub diversity_threshold: f32,
    /// Strict mode: any warning makes an example invalid
    pub strict: bool,
}
85
86impl Default for ExampleValidator {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl ExampleValidator {
93    /// Create a new validator with default settings
94    pub fn new() -> Self {
95        Self {
96            diversity_threshold: 0.7,
97            strict: false,
98        }
99    }
100
101    /// Create a strict validator
102    pub fn strict() -> Self {
103        Self {
104            diversity_threshold: 0.8,
105            strict: true,
106        }
107    }
108
109    /// Set diversity threshold
110    pub fn with_diversity_threshold(mut self, threshold: f32) -> Self {
111        self.diversity_threshold = threshold.clamp(0.0, 1.0);
112        self
113    }
114
115    /// Validate a single example against a tool's documentation
116    pub fn validate_example(
117        &self,
118        example: &GeneratedExample,
119        tool: &ToolDocumentation,
120    ) -> ValidationResult {
121        let mut errors = Vec::new();
122        let mut warnings = Vec::new();
123
124        // Parse the command
125        let parsed = match self.parse_command(&example.command) {
126            Ok(p) => p,
127            Err(e) => {
128                return ValidationResult::invalid(vec![format!("Failed to parse command: {}", e)]);
129            }
130        };
131
132        // Validate tool name matches (if extractable)
133        if let Some(ref tool_name) = parsed.tool {
134            let expected_name = &tool.name;
135            if !tool_name.eq_ignore_ascii_case(expected_name) &&
136               !tool_name.contains(expected_name) &&
137               !expected_name.contains(tool_name) {
138                warnings.push(format!(
139                    "Tool name mismatch: expected '{}', got '{}'",
140                    expected_name, tool_name
141                ));
142            }
143        }
144
145        // Validate required parameters are present
146        for param in &tool.parameters {
147            if param.required && !parsed.has_param(&param.name) {
148                // Check for common aliases
149                let has_alias = param.name.chars().next()
150                    .map(|c| parsed.flags.contains(&c.to_string()))
151                    .unwrap_or(false);
152
153                if !has_alias {
154                    errors.push(format!("Missing required parameter: {}", param.name));
155                }
156            }
157        }
158
159        // Validate parameter types if possible
160        for (name, value) in &parsed.parameters {
161            if let Some(param) = tool.parameters.iter().find(|p| p.name == *name) {
162                if let Err(e) = self.validate_param_type(value, &param.param_type) {
163                    warnings.push(format!("Parameter '{}': {}", name, e));
164                }
165            }
166        }
167
168        // Check for unknown parameters
169        for name in parsed.parameters.keys() {
170            if !tool.parameters.iter().any(|p| p.name == *name) {
171                // Not necessarily an error, could be a valid flag
172                warnings.push(format!("Unknown parameter: {}", name));
173            }
174        }
175
176        // Validate explanation is not empty
177        if example.explanation.trim().is_empty() {
178            errors.push("Example explanation is empty".to_string());
179        }
180
181        // Calculate final validity
182        let valid = errors.is_empty() && (!self.strict || warnings.is_empty());
183
184        // Adjust confidence based on warnings
185        let confidence = if valid {
186            let warning_penalty = 0.1 * warnings.len() as f32;
187            (example.confidence - warning_penalty).max(0.1)
188        } else {
189            0.0
190        };
191
192        ValidationResult {
193            valid,
194            errors,
195            warnings,
196            confidence,
197        }
198    }
199
200    /// Validate multiple examples and return batch results
201    pub fn validate_batch(
202        &self,
203        examples: &[GeneratedExample],
204        tool: &ToolDocumentation,
205    ) -> Vec<ValidationResult> {
206        examples
207            .iter()
208            .map(|e| self.validate_example(e, tool))
209            .collect()
210    }
211
212    /// Calculate diversity score for a set of examples
213    /// Returns a score from 0.0 (all identical) to 1.0 (completely diverse)
214    pub fn calculate_diversity(&self, examples: &[GeneratedExample]) -> f32 {
215        if examples.len() < 2 {
216            return 1.0; // Single example is "diverse" by default
217        }
218
219        // Use simple command similarity as a proxy for diversity
220        let mut total_similarity = 0.0;
221        let mut pairs = 0;
222
223        for i in 0..examples.len() {
224            for j in (i + 1)..examples.len() {
225                let similarity = self.command_similarity(&examples[i].command, &examples[j].command);
226                total_similarity += similarity;
227                pairs += 1;
228            }
229        }
230
231        if pairs == 0 {
232            return 1.0;
233        }
234
235        // Average similarity, converted to diversity (1 - similarity)
236        1.0 - (total_similarity / pairs as f32)
237    }
238
239    /// Check if diversity meets threshold
240    pub fn check_diversity(&self, examples: &[GeneratedExample]) -> bool {
241        self.calculate_diversity(examples) >= self.diversity_threshold
242    }
243
244    /// Calculate simple command similarity using Jaccard index
245    fn command_similarity(&self, cmd1: &str, cmd2: &str) -> f32 {
246        let tokens1: std::collections::HashSet<_> = cmd1.split_whitespace().collect();
247        let tokens2: std::collections::HashSet<_> = cmd2.split_whitespace().collect();
248
249        let intersection = tokens1.intersection(&tokens2).count();
250        let union = tokens1.union(&tokens2).count();
251
252        if union == 0 {
253            return 1.0;
254        }
255
256        intersection as f32 / union as f32
257    }
258
259    /// Parse a command string into components
260    ///
261    /// Supports formats:
262    /// - `skill run tool:name --param=value`
263    /// - `skill run skill:tool param=value`
264    /// - `tool --flag --param value`
265    pub fn parse_command(&self, command: &str) -> Result<ParsedCommand, String> {
266        let mut parsed = ParsedCommand {
267            skill: None,
268            tool: None,
269            positional: Vec::new(),
270            parameters: HashMap::new(),
271            flags: Vec::new(),
272        };
273
274        let tokens: Vec<&str> = command.split_whitespace().collect();
275
276        if tokens.is_empty() {
277            return Err("Empty command".to_string());
278        }
279
280        let mut i = 0;
281
282        // Skip "skill run" prefix if present
283        if tokens.get(0) == Some(&"skill") {
284            i += 1;
285            if tokens.get(i) == Some(&"run") {
286                i += 1;
287            }
288        }
289
290        // Parse tool identifier (skill:tool or just tool)
291        if let Some(tool_part) = tokens.get(i) {
292            if tool_part.contains(':') {
293                let parts: Vec<&str> = tool_part.splitn(2, ':').collect();
294                parsed.skill = Some(parts[0].to_string());
295                parsed.tool = Some(parts.get(1).unwrap_or(&"").to_string());
296            } else if !tool_part.starts_with('-') {
297                parsed.tool = Some(tool_part.to_string());
298            }
299            i += 1;
300        }
301
302        // Parse remaining arguments
303        while i < tokens.len() {
304            let token = tokens[i];
305
306            if token.starts_with("--") {
307                // Long parameter
308                let param = &token[2..];
309                if let Some((name, value)) = param.split_once('=') {
310                    parsed.parameters.insert(name.to_string(), value.to_string());
311                } else if i + 1 < tokens.len() && !tokens[i + 1].starts_with('-') {
312                    // Next token is the value
313                    parsed.parameters.insert(param.to_string(), tokens[i + 1].to_string());
314                    i += 1;
315                } else {
316                    // Flag without value
317                    parsed.flags.push(param.to_string());
318                }
319            } else if token.starts_with('-') && token.len() == 2 {
320                // Short flag
321                let flag = &token[1..];
322                if i + 1 < tokens.len() && !tokens[i + 1].starts_with('-') {
323                    parsed.parameters.insert(flag.to_string(), tokens[i + 1].to_string());
324                    i += 1;
325                } else {
326                    parsed.flags.push(flag.to_string());
327                }
328            } else if token.contains('=') {
329                // key=value format (without --)
330                if let Some((name, value)) = token.split_once('=') {
331                    parsed.parameters.insert(name.to_string(), value.to_string());
332                }
333            } else {
334                // Positional argument
335                parsed.positional.push(token.to_string());
336            }
337
338            i += 1;
339        }
340
341        Ok(parsed)
342    }
343
344    /// Validate a parameter value against a ParameterType
345    fn validate_param_type(&self, value: &str, param_type: &ParameterType) -> Result<(), String> {
346        match param_type {
347            ParameterType::String => Ok(()),
348            ParameterType::Integer => {
349                value.parse::<i64>()
350                    .map(|_| ())
351                    .map_err(|_| format!("expected integer, got '{}'", value))
352            }
353            ParameterType::Number => {
354                value.parse::<f64>()
355                    .map(|_| ())
356                    .map_err(|_| format!("expected number, got '{}'", value))
357            }
358            ParameterType::Boolean => {
359                match value.to_lowercase().as_str() {
360                    "true" | "false" | "yes" | "no" | "1" | "0" => Ok(()),
361                    _ => Err(format!("expected boolean, got '{}'", value)),
362                }
363            }
364            ParameterType::Array => Ok(()), // Can't easily validate array syntax
365            ParameterType::Object => Ok(()), // Can't easily validate object syntax
366        }
367    }
368
369    /// Validate a parameter value against a type hint string (for tests)
370    #[allow(dead_code)]
371    fn validate_type(&self, value: &str, type_hint: &str) -> Result<(), String> {
372        let type_lower = type_hint.to_lowercase();
373
374        match type_lower.as_str() {
375            "int" | "integer" | "number" => {
376                value.parse::<i64>()
377                    .map(|_| ())
378                    .map_err(|_| format!("expected integer, got '{}'", value))
379            }
380            "float" | "decimal" => {
381                value.parse::<f64>()
382                    .map(|_| ())
383                    .map_err(|_| format!("expected number, got '{}'", value))
384            }
385            "bool" | "boolean" => {
386                match value.to_lowercase().as_str() {
387                    "true" | "false" | "yes" | "no" | "1" | "0" => Ok(()),
388                    _ => Err(format!("expected boolean, got '{}'", value)),
389                }
390            }
391            "path" | "file" => {
392                // Basic path validation
393                if value.is_empty() {
394                    Err("empty path".to_string())
395                } else {
396                    Ok(())
397                }
398            }
399            "url" => {
400                if value.starts_with("http://") || value.starts_with("https://") {
401                    Ok(())
402                } else {
403                    Err(format!("expected URL, got '{}'", value))
404                }
405            }
406            _ => Ok(()), // Unknown types pass
407        }
408    }
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one documented parameter with an empty allowed-value list.
    fn param_doc(
        name: &str,
        param_type: ParameterType,
        description: &str,
        required: bool,
        default: Option<&str>,
    ) -> ParameterDoc {
        ParameterDoc {
            name: name.to_string(),
            param_type,
            description: description.to_string(),
            required,
            default: default.map(|d| d.to_string()),
            allowed_values: vec![],
        }
    }

    /// Documentation for a Kubernetes-style `apply` tool with three parameters:
    /// - `file` (required);
    /// - `namespace` (optional, default "default");
    /// - `dry-run` (optional boolean).
    fn create_test_tool() -> ToolDocumentation {
        ToolDocumentation {
            name: "apply".to_string(),
            description: "Apply a Kubernetes manifest".to_string(),
            usage: None,
            parameters: vec![
                param_doc("file", ParameterType::String, "Path to manifest file", true, None),
                param_doc(
                    "namespace",
                    ParameterType::String,
                    "Target namespace",
                    false,
                    Some("default"),
                ),
                param_doc("dry-run", ParameterType::Boolean, "Perform dry run", false, None),
            ],
            examples: vec![],
        }
    }

    /// Build an unvalidated example with no category or parameters.
    fn example(command: &str, explanation: &str, confidence: f32) -> GeneratedExample {
        GeneratedExample {
            command: command.to_string(),
            explanation: explanation.to_string(),
            confidence,
            validated: false,
            category: None,
            parameters: None,
        }
    }

    #[test]
    fn test_parse_command_basic() {
        let parsed = ExampleValidator::new()
            .parse_command("skill run k8s:apply --file=deploy.yaml")
            .expect("command should parse");

        assert_eq!(parsed.skill.as_deref(), Some("k8s"));
        assert_eq!(parsed.tool.as_deref(), Some("apply"));
        assert_eq!(parsed.get_param("file").map(String::as_str), Some("deploy.yaml"));
    }

    #[test]
    fn test_parse_command_separate_value() {
        let parsed = ExampleValidator::new()
            .parse_command("skill run apply --file deploy.yaml --namespace prod")
            .expect("command should parse");

        assert_eq!(parsed.tool.as_deref(), Some("apply"));
        assert_eq!(parsed.get_param("file").map(String::as_str), Some("deploy.yaml"));
        assert_eq!(parsed.get_param("namespace").map(String::as_str), Some("prod"));
    }

    #[test]
    fn test_parse_command_flags() {
        let parsed = ExampleValidator::new()
            .parse_command("apply --dry-run --file=test.yaml")
            .expect("command should parse");

        assert!(parsed.flags.iter().any(|f| f.as_str() == "dry-run"));
        assert!(parsed.has_param("dry-run"));
    }

    #[test]
    fn test_parse_command_key_value() {
        let parsed = ExampleValidator::new()
            .parse_command("skill run tool namespace=default file=app.yaml")
            .expect("command should parse");

        assert_eq!(parsed.get_param("namespace").map(String::as_str), Some("default"));
        assert_eq!(parsed.get_param("file").map(String::as_str), Some("app.yaml"));
    }

    #[test]
    fn test_validate_example_valid() {
        let tool = create_test_tool();
        let sample = example(
            "skill run k8s:apply --file=deploy.yaml",
            "Apply deployment manifest",
            0.9,
        );

        let result = ExampleValidator::new().validate_example(&sample, &tool);
        assert!(result.valid);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_validate_example_missing_required() {
        let tool = create_test_tool();
        let sample = example(
            "skill run k8s:apply --namespace=prod",
            "Apply to prod namespace",
            0.8,
        );

        let result = ExampleValidator::new().validate_example(&sample, &tool);
        assert!(!result.valid);
        assert!(result.errors.iter().any(|e| e.contains("file")));
    }

    #[test]
    fn test_validate_example_empty_explanation() {
        let tool = create_test_tool();
        let sample = example("skill run k8s:apply --file=test.yaml", "  ", 0.9);

        let result = ExampleValidator::new().validate_example(&sample, &tool);
        assert!(!result.valid);
        assert!(result.errors.iter().any(|e| e.contains("explanation")));
    }

    #[test]
    fn test_diversity_identical() {
        let examples = vec![
            GeneratedExample::new("skill run apply --file=a.yaml", "Apply a"),
            GeneratedExample::new("skill run apply --file=a.yaml", "Apply a"),
        ];

        // Identical commands should score as barely diverse.
        let diversity = ExampleValidator::new().calculate_diversity(&examples);
        assert!(diversity < 0.5);
    }

    #[test]
    fn test_diversity_different() {
        let examples = vec![
            GeneratedExample::new("skill run apply --file=deploy.yaml", "Deploy app"),
            GeneratedExample::new("skill run delete --namespace=prod --all", "Delete all in prod"),
            GeneratedExample::new("skill run get pods --output=json", "List pods as JSON"),
        ];

        // Mostly disjoint commands should score as highly diverse.
        let diversity = ExampleValidator::new().calculate_diversity(&examples);
        assert!(diversity > 0.5);
    }

    #[test]
    fn test_validate_type_integer() {
        let validator = ExampleValidator::new();
        let accepted = [("123", "integer"), ("-42", "int")];
        for &(value, hint) in accepted.iter() {
            assert!(validator.validate_type(value, hint).is_ok());
        }
        assert!(validator.validate_type("abc", "integer").is_err());
    }

    #[test]
    fn test_validate_type_boolean() {
        let validator = ExampleValidator::new();
        let accepted = [("true", "boolean"), ("false", "bool"), ("yes", "boolean")];
        for &(value, hint) in accepted.iter() {
            assert!(validator.validate_type(value, hint).is_ok());
        }
        assert!(validator.validate_type("maybe", "boolean").is_err());
    }

    #[test]
    fn test_validate_type_url() {
        let validator = ExampleValidator::new();
        let accepted = ["https://example.com", "http://localhost:8080"];
        for &url in accepted.iter() {
            assert!(validator.validate_type(url, "url").is_ok());
        }
        assert!(validator.validate_type("not-a-url", "url").is_err());
    }

    #[test]
    fn test_batch_validation() {
        let tool = create_test_tool();
        let examples = vec![
            GeneratedExample::new("skill run apply --file=a.yaml", "Apply a"),
            GeneratedExample::new("skill run apply --namespace=prod", "Missing file"),
        ];

        let results = ExampleValidator::new().validate_batch(&examples, &tool);
        let validity: Vec<bool> = results.iter().map(|r| r.valid).collect();
        assert_eq!(validity, vec![true, false]);
    }
}