// skill_runtime/generation/example_generator.rs

1//! Example Generator - Core component for AI-powered example synthesis
2//!
3//! Generates realistic usage examples from tool schemas using LLMs,
4//! with streaming output and validation.
5
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use anyhow::{Context, Result};
10use futures_util::Stream;
11use tokio_stream::StreamExt;
12
13use crate::skill_md::ToolDocumentation;
14use crate::search_config::AiIngestionConfig;
15use super::llm_provider::{LlmProvider, CompletionRequest};
16use super::validator::ExampleValidator;
17use super::streaming::{GenerationEvent, GeneratedExample, GenerationStreamBuilder};
18
/// Configuration for the example generator
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
    /// Number of examples to request from the LLM per tool
    /// (embedded into the prompt built by `build_prompt`)
    pub examples_per_tool: usize,
    /// Whether to run each generated example through `ExampleValidator`
    /// before emitting it; invalid examples are dropped, not retried
    pub validate_examples: bool,
    /// Maximum retries for failed generation (LLM call or response parsing)
    pub max_retries: usize,
    /// Timeout for generation
    /// NOTE(review): not enforced anywhere in this file — presumably consumed
    /// by the LlmProvider implementation; confirm before relying on it
    pub timeout: Duration,
    /// Temperature for LLM generation (sampling randomness)
    pub temperature: f32,
    /// Maximum tokens for the LLM response
    pub max_tokens: u32,
}
35
36impl Default for GeneratorConfig {
37    fn default() -> Self {
38        Self {
39            examples_per_tool: 5,
40            validate_examples: true,
41            max_retries: 2,
42            timeout: Duration::from_secs(30),
43            temperature: 0.7,
44            max_tokens: 2048,
45        }
46    }
47}
48
49impl From<&AiIngestionConfig> for GeneratorConfig {
50    fn from(config: &AiIngestionConfig) -> Self {
51        Self {
52            examples_per_tool: config.examples_per_tool,
53            validate_examples: config.validate_examples,
54            max_retries: 2,
55            timeout: Duration::from_secs(config.timeout_secs),
56            temperature: 0.7,
57            max_tokens: 2048,
58        }
59    }
60}
61
/// AI-powered example generator
///
/// Drives an [`LlmProvider`] to synthesize CLI usage examples from a tool's
/// documented schema, optionally validating each candidate example before
/// emitting it.
pub struct ExampleGenerator {
    /// LLM provider used for completion requests
    llm: Arc<dyn LlmProvider>,
    /// Validates generated examples against the tool's parameter schema
    validator: ExampleValidator,
    /// Generation settings (counts, retries, sampling parameters)
    config: GeneratorConfig,
}
71
72impl ExampleGenerator {
73    /// Create a new example generator
74    pub fn new(llm: Arc<dyn LlmProvider>, config: GeneratorConfig) -> Self {
75        Self {
76            llm,
77            validator: ExampleValidator::new(),
78            config,
79        }
80    }
81
82    /// Create from AI ingestion config
83    pub fn from_config(llm: Arc<dyn LlmProvider>, config: &AiIngestionConfig) -> Self {
84        Self::new(llm, GeneratorConfig::from(config))
85    }
86
87    /// Generate examples for a tool (non-streaming)
88    pub async fn generate(&self, tool: &ToolDocumentation) -> Result<Vec<GeneratedExample>> {
89        let mut results = Vec::new();
90        let mut stream = Box::pin(self.generate_stream(tool, 1, 1));
91
92        while let Some(event) = stream.next().await {
93            if let GenerationEvent::Example { example } = event {
94                results.push(example);
95            }
96        }
97
98        Ok(results)
99    }
100
101    /// Generate examples for multiple tools (non-streaming)
102    pub async fn generate_batch(
103        &self,
104        tools: &[ToolDocumentation],
105    ) -> Result<Vec<(String, Vec<GeneratedExample>)>> {
106        let mut results = Vec::new();
107
108        for tool in tools {
109            let examples = self.generate(tool).await?;
110            results.push((tool.name.clone(), examples));
111        }
112
113        Ok(results)
114    }
115
116    /// Generate examples with streaming events
117    pub fn generate_stream<'a>(
118        &'a self,
119        tool: &'a ToolDocumentation,
120        current_index: usize,
121        total_tools: usize,
122    ) -> impl Stream<Item = GenerationEvent> + 'a {
123        async_stream::stream! {
124            let start_time = Instant::now();
125            let builder = GenerationStreamBuilder::new(&tool.name, total_tools, current_index);
126
127            // Emit started event
128            yield builder.started();
129
130            // Build the prompt
131            yield builder.thinking(format!("Building prompt for {} parameters...", tool.parameters.len()));
132
133            let prompt = self.build_prompt(tool);
134
135            // Create completion request
136            let request = CompletionRequest::with_system(
137                SYSTEM_PROMPT,
138                &prompt,
139            )
140            .temperature(self.config.temperature)
141            .max_tokens(self.config.max_tokens);
142
143            yield builder.thinking("Generating examples with LLM...");
144
145            // Generate with retries
146            let mut attempts = 0;
147            let mut examples = Vec::new();
148
149            loop {
150                attempts += 1;
151
152                match self.llm.complete(&request).await {
153                    Ok(response) => {
154                        yield builder.thinking("Parsing LLM response...");
155
156                        // Parse examples from response
157                        match self.parse_examples(&response.content) {
158                            Ok(parsed) => {
159                                examples = parsed;
160                                break;
161                            }
162                            Err(e) => {
163                                if attempts >= self.config.max_retries {
164                                    yield builder.error(
165                                        format!("Failed to parse response after {} attempts: {}", attempts, e),
166                                        false,
167                                    );
168                                    return;
169                                }
170                                yield builder.thinking(format!("Retrying ({}/{}): {}", attempts, self.config.max_retries, e));
171                            }
172                        }
173                    }
174                    Err(e) => {
175                        if attempts >= self.config.max_retries {
176                            yield builder.error(
177                                format!("LLM generation failed after {} attempts: {}", attempts, e),
178                                false,
179                            );
180                            return;
181                        }
182                        yield builder.thinking(format!("Retrying ({}/{}): {}", attempts, self.config.max_retries, e));
183                    }
184                }
185            }
186
187            // Process and validate each example
188            let total_examples = examples.len();
189            let mut valid_count = 0;
190
191            for (idx, mut example) in examples.into_iter().enumerate() {
192                // Validate if enabled
193                if self.config.validate_examples {
194                    let validation = self.validator.validate_example(&example, tool);
195
196                    yield builder.validation(
197                        validation.valid,
198                        validation.errors.clone(),
199                        idx,
200                    );
201
202                    if validation.valid {
203                        example.validated = true;
204                        example.confidence = validation.confidence;
205                        valid_count += 1;
206                        yield builder.example(example);
207                    }
208                } else {
209                    yield builder.example(example);
210                    valid_count += 1;
211                }
212
213                // Progress update
214                yield GenerationEvent::progress(
215                    idx + 1,
216                    total_examples,
217                    Some(format!("Processed {}/{} examples", idx + 1, total_examples)),
218                );
219            }
220
221            // Emit completion
222            let duration = start_time.elapsed();
223            yield builder.tool_completed(total_examples, valid_count, duration);
224        }
225    }
226
227    /// Build the prompt for example generation
228    fn build_prompt(&self, tool: &ToolDocumentation) -> String {
229        let params_desc = self.format_parameters(tool);
230        let existing_examples = self.format_existing_examples(tool);
231
232        format!(
233            r#"Generate {count} realistic CLI usage examples for the following tool:
234
235## Tool Information
236- **Name**: {name}
237- **Description**: {description}
238
239## Parameters
240{parameters}
241
242{existing}
243
244## Requirements
2451. Each example must use valid parameter values
2462. Cover diverse use cases (common operations, edge cases, real-world scenarios)
2473. Include a brief explanation for each example
2484. Use the format: `skill run {name} [options]`
249
250## Output Format
251Return a JSON array with exactly {count} examples:
252```json
253[
254  {{"command": "skill run {name} --param=value", "explanation": "Brief description of what this does"}},
255  ...
256]
257```
258
259Generate {count} diverse, realistic examples now:"#,
260            count = self.config.examples_per_tool,
261            name = tool.name,
262            description = tool.description,
263            parameters = params_desc,
264            existing = existing_examples,
265        )
266    }
267
268    /// Format parameters for the prompt
269    fn format_parameters(&self, tool: &ToolDocumentation) -> String {
270        if tool.parameters.is_empty() {
271            return "No parameters defined.".to_string();
272        }
273
274        tool.parameters
275            .iter()
276            .map(|p| {
277                let required = if p.required { " (required)" } else { "" };
278                let default = p.default.as_ref()
279                    .map(|d| format!(" [default: {}]", d))
280                    .unwrap_or_default();
281                let allowed = if !p.allowed_values.is_empty() {
282                    format!(" [values: {}]", p.allowed_values.join(", "))
283                } else {
284                    String::new()
285                };
286
287                format!(
288                    "- `--{name}` ({type}){required}{default}{allowed}: {desc}",
289                    name = p.name,
290                    type = format!("{:?}", p.param_type).to_lowercase(),
291                    required = required,
292                    default = default,
293                    allowed = allowed,
294                    desc = p.description,
295                )
296            })
297            .collect::<Vec<_>>()
298            .join("\n")
299    }
300
301    /// Format existing examples (if any) to avoid duplicates
302    fn format_existing_examples(&self, tool: &ToolDocumentation) -> String {
303        if tool.examples.is_empty() {
304            return String::new();
305        }
306
307        let examples = tool.examples
308            .iter()
309            .take(3)
310            .map(|e| format!("- `{}`", e.code.lines().next().unwrap_or(&e.code)))
311            .collect::<Vec<_>>()
312            .join("\n");
313
314        format!(
315            "\n## Existing Examples (do not duplicate)\n{}\n",
316            examples
317        )
318    }
319
320    /// Parse examples from LLM response
321    fn parse_examples(&self, response: &str) -> Result<Vec<GeneratedExample>> {
322        // Try to find JSON array in response
323        let json_str = self.extract_json_array(response)?;
324
325        // Parse the JSON
326        let parsed: serde_json::Value = serde_json::from_str(&json_str)
327            .with_context(|| format!("Failed to parse JSON: {}", &json_str[..json_str.len().min(100)]))?;
328
329        let array = parsed.as_array()
330            .context("Expected JSON array")?;
331
332        let examples: Vec<GeneratedExample> = array
333            .iter()
334            .filter_map(|item| {
335                let command = item.get("command")?.as_str()?;
336                let explanation = item.get("explanation")?.as_str()?;
337
338                Some(GeneratedExample::new(command, explanation))
339            })
340            .collect();
341
342        if examples.is_empty() {
343            anyhow::bail!("No valid examples found in response");
344        }
345
346        Ok(examples)
347    }
348
349    /// Extract JSON array from response text
350    fn extract_json_array(&self, response: &str) -> Result<String> {
351        // Try to find JSON array directly
352        if let Some(start) = response.find('[') {
353            if let Some(end) = response.rfind(']') {
354                if end > start {
355                    return Ok(response[start..=end].to_string());
356                }
357            }
358        }
359
360        // Try to find JSON in code block
361        if let Some(start) = response.find("```json") {
362            let after_marker = &response[start + 7..];
363            if let Some(end) = after_marker.find("```") {
364                let json_content = &after_marker[..end];
365                if let Some(arr_start) = json_content.find('[') {
366                    if let Some(arr_end) = json_content.rfind(']') {
367                        return Ok(json_content[arr_start..=arr_end].to_string());
368                    }
369                }
370            }
371        }
372
373        // Try to find any code block
374        if let Some(start) = response.find("```") {
375            let after_marker = &response[start + 3..];
376            // Skip optional language identifier
377            let content_start = after_marker.find('\n').unwrap_or(0) + 1;
378            let after_newline = &after_marker[content_start..];
379            if let Some(end) = after_newline.find("```") {
380                let json_content = &after_newline[..end];
381                if let Some(arr_start) = json_content.find('[') {
382                    if let Some(arr_end) = json_content.rfind(']') {
383                        return Ok(json_content[arr_start..=arr_end].to_string());
384                    }
385                }
386            }
387        }
388
389        anyhow::bail!("Could not find JSON array in response")
390    }
391
    /// Get the LLM provider name, as reported by the provider implementation
    pub fn provider_name(&self) -> &str {
        self.llm.name()
    }

    /// Get the model name, as reported by the provider implementation
    pub fn model_name(&self) -> &str {
        self.llm.model()
    }
401}
402
/// System prompt for example generation.
///
/// Sent with every completion request built in `generate_stream`. It
/// instructs the model to emit a JSON array of objects with `command` and
/// `explanation` fields — the exact shape `parse_examples` expects.
const SYSTEM_PROMPT: &str = r#"You are a CLI tool documentation expert who generates realistic usage examples.

Your task is to create diverse, practical examples that demonstrate various use cases for command-line tools.

Guidelines:
- Generate valid commands with proper parameter syntax
- Cover common use cases, edge cases, and real-world scenarios
- Include meaningful explanations that help users understand each example
- Use realistic parameter values (not placeholders like "value1", "example")
- Ensure examples are syntactically correct and would execute successfully

Output your examples as a JSON array with "command" and "explanation" fields."#;
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::skill_md::{ParameterDoc, ParameterType};

    /// Stub provider shared by every test in this module. Previously each
    /// test declared its own identical copy of this struct and impl; the
    /// completion methods are never reached because these tests exercise
    /// only the pure helper methods.
    struct MockProvider;

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        fn name(&self) -> &str { "mock" }
        fn model(&self) -> &str { "test" }
        async fn complete(&self, _: &CompletionRequest) -> Result<super::super::llm_provider::LlmResponse> {
            unimplemented!()
        }
        async fn complete_stream(&self, _: &CompletionRequest) -> Result<Pin<Box<dyn Stream<Item = Result<super::super::llm_provider::LlmChunk>> + Send>>> {
            unimplemented!()
        }
    }

    /// A generator wired to the stub provider with default settings.
    fn test_generator() -> ExampleGenerator {
        ExampleGenerator::new(Arc::new(MockProvider), GeneratorConfig::default())
    }

    /// A small two-parameter tool: one required string, one optional string
    /// with a default value.
    fn create_test_tool() -> ToolDocumentation {
        ToolDocumentation {
            name: "apply".to_string(),
            description: "Apply a Kubernetes manifest".to_string(),
            usage: None,
            parameters: vec![
                ParameterDoc {
                    name: "file".to_string(),
                    param_type: ParameterType::String,
                    description: "Path to manifest file".to_string(),
                    required: true,
                    default: None,
                    allowed_values: vec![],
                },
                ParameterDoc {
                    name: "namespace".to_string(),
                    param_type: ParameterType::String,
                    description: "Target namespace".to_string(),
                    required: false,
                    default: Some("default".to_string()),
                    allowed_values: vec![],
                },
            ],
            examples: vec![],
        }
    }

    #[test]
    fn test_build_prompt() {
        let generator = test_generator();
        let tool = create_test_tool();

        let prompt = generator.build_prompt(&tool);

        assert!(prompt.contains("apply"));
        assert!(prompt.contains("Kubernetes manifest"));
        assert!(prompt.contains("--file"));
        assert!(prompt.contains("--namespace"));
        assert!(prompt.contains("(required)"));
        assert!(prompt.contains("[default: default]"));
    }

    #[test]
    fn test_parse_examples_json() {
        let generator = test_generator();

        let response = r#"
Here are the examples:
[
  {"command": "skill run apply --file=deploy.yaml", "explanation": "Apply deployment"},
  {"command": "skill run apply --file=service.yaml --namespace=prod", "explanation": "Apply to prod"}
]
        "#;

        let examples = generator.parse_examples(response).unwrap();
        assert_eq!(examples.len(), 2);
        assert!(examples[0].command.contains("deploy.yaml"));
        assert!(examples[1].command.contains("namespace=prod"));
    }

    #[test]
    fn test_parse_examples_code_block() {
        let generator = test_generator();

        let response = r#"
Here are some examples:

```json
[
  {"command": "skill run test --param=value", "explanation": "Test command"}
]
```
        "#;

        let examples = generator.parse_examples(response).unwrap();
        assert_eq!(examples.len(), 1);
    }

    #[test]
    fn test_config_from_ai_ingestion() {
        let ai_config = AiIngestionConfig {
            enabled: true,
            examples_per_tool: 3,
            timeout_secs: 60,
            ..Default::default()
        };

        let config = GeneratorConfig::from(&ai_config);
        assert_eq!(config.examples_per_tool, 3);
        assert_eq!(config.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_extract_json_array_direct() {
        let generator = test_generator();

        let input = r#"[{"a": 1}]"#;
        let result = generator.extract_json_array(input).unwrap();
        assert_eq!(result, r#"[{"a": 1}]"#);
    }
}