// skill_runtime/executor.rs

use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;

use anyhow::{Context, Result};
use wasmtime::{
    component::{Component, Linker},
    Store,
};

use crate::engine::SkillEngine;
use crate::instance::InstanceConfig;
use crate::sandbox::SandboxBuilder;
use crate::types::{ExecutionResult, Parameter, ParameterType, SkillMetadata, ToolDefinition};

// Generate WIT bindings for the skill interface.
//
// `bindgen!` expands the inline WIT world below into a `Skill` type. This file uses
// `Skill::instantiate_async` plus one `call_*` method per export:
// `call_get_metadata`, `call_get_tools`, `call_execute_tool` and
// `call_validate_config`. Every export returns a WIT `string`, and the executor
// parses each returned string as JSON. `async: true` makes the generated methods
// `async`, which is why every call site in this file `.await`s them.
//
// TODO: Add host function imports for configuration access
wasmtime::component::bindgen!({
    inline: "
        package skill-engine:skill@1.0.0;

        world skill {
            export get-metadata: func() -> string;
            export get-tools: func() -> string;
            export execute-tool: func(tool-name: string, args: string) -> string;
            export validate-config: func(config: string) -> string;
        }
    ",
    // Generate async host-side call/instantiate methods.
    async: true,
});

31/// High-level executor for running skills
32pub struct SkillExecutor {
33    engine: Arc<SkillEngine>,
34    skill_name: String,
35    instance_name: String,
36    config: InstanceConfig,
37    component: Component,
38}
39
40impl SkillExecutor {
41    /// Load a skill and prepare for execution
42    pub async fn load(
43        engine: Arc<SkillEngine>,
44        skill_path: impl AsRef<Path>,
45        skill_name: String,
46        instance_name: String,
47        config: InstanceConfig,
48    ) -> Result<Self> {
49        tracing::info!(
50            skill = %skill_name,
51            instance = %instance_name,
52            path = %skill_path.as_ref().display(),
53            "Loading skill"
54        );
55
56        let start = Instant::now();
57
58        // Load the component
59        let component = engine.load_component(skill_path.as_ref()).await?;
60
61        // Validate the component
62        engine.validate_component(&component).await?;
63
64        let duration = start.elapsed();
65        tracing::info!(
66            skill = %skill_name,
67            instance = %instance_name,
68            duration_ms = duration.as_millis(),
69            "Skill loaded successfully"
70        );
71
72        Ok(Self {
73            engine,
74            skill_name,
75            instance_name,
76            config,
77            component,
78        })
79    }
80
81    /// Create an executor from an already-loaded component
82    pub fn from_component(
83        engine: Arc<SkillEngine>,
84        component: Component,
85        skill_name: String,
86        instance_name: String,
87        config: InstanceConfig,
88    ) -> Result<Self> {
89        Ok(Self {
90            engine,
91            skill_name,
92            instance_name,
93            config,
94            component,
95        })
96    }
97
98    /// Get skill metadata
99    pub async fn get_metadata(&self) -> Result<SkillMetadata> {
100        // Create a store for this execution
101        let instance_dir = InstanceConfig::instance_dir(&self.skill_name, &self.instance_name)?;
102
103        let sandbox = SandboxBuilder::new(&self.instance_name, instance_dir)
104            .env_from_config(&self.config)
105            .build()?;
106
107        let mut store = Store::new(self.engine.wasmtime_engine(), sandbox);
108
109        // Create linker and instantiate component
110        let mut linker = Linker::new(self.engine.wasmtime_engine());
111        wasmtime_wasi::add_to_linker_async(&mut linker)?;
112
113        let skill = Skill::instantiate_async(&mut store, &self.component, &linker).await?;
114
115        // Call get-metadata export
116        let metadata_json = skill.call_get_metadata(&mut store).await?;
117
118        // Parse JSON metadata
119        let metadata: serde_json::Value = serde_json::from_str(&metadata_json)
120            .context("Failed to parse skill metadata JSON")?;
121
122        Ok(SkillMetadata {
123            name: metadata["name"].as_str().unwrap_or(&self.skill_name).to_string(),
124            version: metadata["version"].as_str().unwrap_or("0.0.0").to_string(),
125            description: metadata["description"].as_str().unwrap_or("").to_string(),
126            author: metadata["author"].as_str().unwrap_or("").to_string(),
127            repository: metadata["repository"].as_str().map(|s| s.to_string()),
128            license: metadata["license"].as_str().map(|s| s.to_string()),
129        })
130    }
131
132    /// Get list of tools provided by this skill
133    pub async fn get_tools(&self) -> Result<Vec<ToolDefinition>> {
134        // Create a store for this execution
135        let instance_dir = InstanceConfig::instance_dir(&self.skill_name, &self.instance_name)?;
136
137        let sandbox = SandboxBuilder::new(&self.instance_name, instance_dir)
138            .env_from_config(&self.config)
139            .build()?;
140
141        let mut store = Store::new(self.engine.wasmtime_engine(), sandbox);
142
143        // Create linker and instantiate component
144        let mut linker = Linker::new(self.engine.wasmtime_engine());
145        wasmtime_wasi::add_to_linker_async(&mut linker)?;
146
147        let skill = Skill::instantiate_async(&mut store, &self.component, &linker).await?;
148
149        // Call get-tools export
150        let tools_json = skill.call_get_tools(&mut store).await?;
151
152        // Parse JSON tools list
153        let tools: Vec<serde_json::Value> = serde_json::from_str(&tools_json)
154            .context("Failed to parse tools JSON")?;
155
156        // Convert to ToolDefinition structs
157        let mut tool_defs = Vec::new();
158        let empty_params = Vec::new();
159        for tool in tools {
160            let params_json = tool["parameters"].as_array().unwrap_or(&empty_params);
161            let mut parameters = Vec::new();
162
163            for param in params_json {
164                let param_type_str = param["paramType"].as_str().unwrap_or("string");
165                let param_type = match param_type_str {
166                    "number" => ParameterType::Number,
167                    "boolean" => ParameterType::Boolean,
168                    "file" => ParameterType::File,
169                    "json" => ParameterType::Json,
170                    "array" => ParameterType::Array,
171                    _ => ParameterType::String,
172                };
173
174                parameters.push(Parameter {
175                    name: param["name"].as_str().unwrap_or("").to_string(),
176                    param_type,
177                    description: param["description"].as_str().unwrap_or("").to_string(),
178                    required: param["required"].as_bool().unwrap_or(false),
179                    default_value: param["defaultValue"].as_str().map(|s| s.to_string()),
180                });
181            }
182
183            tool_defs.push(ToolDefinition {
184                name: tool["name"].as_str().unwrap_or("").to_string(),
185                description: tool["description"].as_str().unwrap_or("").to_string(),
186                parameters,
187                streaming: false, // TODO: Support streaming tools
188            });
189        }
190
191        Ok(tool_defs)
192    }
193
194    /// Execute a tool
195    pub async fn execute_tool(
196        &self,
197        tool_name: &str,
198        args: Vec<(String, String)>,
199    ) -> Result<ExecutionResult> {
200        let start = Instant::now();
201
202        tracing::info!(
203            skill = %self.skill_name,
204            instance = %self.instance_name,
205            tool = %tool_name,
206            args_count = args.len(),
207            "Executing tool"
208        );
209
210        // Create sandbox environment
211        let instance_dir = InstanceConfig::instance_dir(&self.skill_name, &self.instance_name)?;
212
213        let sandbox = SandboxBuilder::new(&self.instance_name, instance_dir)
214            .env_from_config(&self.config)
215            .args(vec![tool_name.to_string()])
216            .build()?;
217
218        let mut store = Store::new(self.engine.wasmtime_engine(), sandbox);
219
220        // Create linker and instantiate component
221        let mut linker = Linker::new(self.engine.wasmtime_engine());
222        wasmtime_wasi::add_to_linker_async(&mut linker)?;
223
224        let skill = Skill::instantiate_async(&mut store, &self.component, &linker).await?;
225
226        // Convert args to JSON string
227        let args_json = serde_json::to_string(&serde_json::Map::from_iter(
228            args.into_iter().map(|(k, v)| (k, serde_json::Value::String(v)))
229        ))?;
230
231        // Call execute-tool export
232        let result_json = skill.call_execute_tool(&mut store, tool_name, args_json.as_str()).await?;
233
234        // Parse JSON result
235        let result_value: serde_json::Value = serde_json::from_str(&result_json)
236            .context("Failed to parse execution result JSON")?;
237
238        let result = if let Some(ok) = result_value.get("ok") {
239            // Success case
240            ExecutionResult {
241                success: ok["success"].as_bool().unwrap_or(true),
242                output: ok["output"].as_str().unwrap_or("").to_string(),
243                error_message: ok["errorMessage"].as_str().map(|s| s.to_string()),
244                metadata: None,
245            }
246        } else if let Some(err) = result_value.get("err") {
247            // Error case
248            ExecutionResult {
249                success: false,
250                output: String::new(),
251                error_message: Some(err.as_str().unwrap_or("Unknown error").to_string()),
252                metadata: None,
253            }
254        } else {
255            // Fallback
256            ExecutionResult {
257                success: false,
258                output: String::new(),
259                error_message: Some("Invalid result format".to_string()),
260                metadata: None,
261            }
262        };
263
264        let duration = start.elapsed();
265        tracing::info!(
266            skill = %self.skill_name,
267            instance = %self.instance_name,
268            tool = %tool_name,
269            success = result.success,
270            duration_ms = duration.as_millis(),
271            "Tool execution completed"
272        );
273
274        Ok(result)
275    }
276
277    /// Validate configuration
278    pub async fn validate_config(&self) -> Result<()> {
279        // Create a store for this execution
280        let instance_dir = InstanceConfig::instance_dir(&self.skill_name, &self.instance_name)?;
281
282        let sandbox = SandboxBuilder::new(&self.instance_name, instance_dir)
283            .env_from_config(&self.config)
284            .build()?;
285
286        let mut store = Store::new(self.engine.wasmtime_engine(), sandbox);
287
288        // Create linker and instantiate component
289        let mut linker = Linker::new(self.engine.wasmtime_engine());
290        wasmtime_wasi::add_to_linker_async(&mut linker)?;
291
292        let skill = Skill::instantiate_async(&mut store, &self.component, &linker).await?;
293
294        // Convert config to JSON string
295        let config_json = serde_json::to_string(&self.config.config)?;
296
297        // Call validate-config export
298        let result_json = skill.call_validate_config(&mut store, config_json.as_str()).await?;
299
300        // Parse result
301        let result: serde_json::Value = serde_json::from_str(&result_json)
302            .context("Failed to parse validate-config result")?;
303
304        if let Some(err) = result.get("err") {
305            anyhow::bail!("Configuration validation failed: {}", err.as_str().unwrap_or("Unknown error"));
306        }
307
308        Ok(())
309    }
310
311    /// Get the underlying component
312    pub fn component(&self) -> &Component {
313        &self.component
314    }
315
316    /// Get skill name
317    pub fn skill_name(&self) -> &str {
318        &self.skill_name
319    }
320
321    /// Get instance name
322    pub fn instance_name(&self) -> &str {
323        &self.instance_name
324    }
325
326    /// Get configuration
327    pub fn config(&self) -> &InstanceConfig {
328        &self.config
329    }
330}
331
332/// Cache for compiled components
333pub struct ComponentCache {
334    cache_dir: std::path::PathBuf,
335}
336
337impl ComponentCache {
338    /// Create a new component cache
339    pub fn new(cache_dir: impl AsRef<Path>) -> Result<Self> {
340        let cache_dir = cache_dir.as_ref().to_path_buf();
341        std::fs::create_dir_all(&cache_dir)?;
342
343        Ok(Self { cache_dir })
344    }
345
346    /// Get cache key for a component
347    fn cache_key(&self, skill_name: &str, version: &str) -> String {
348        // Include wasmtime version in cache key (hardcoded for now)
349        format!("{}_{}_wasmtime_26", skill_name, version)
350    }
351
352    /// Get cached component path
353    pub fn cache_path(&self, skill_name: &str, version: &str) -> std::path::PathBuf {
354        self.cache_dir.join(format!("{}.cwasm", self.cache_key(skill_name, version)))
355    }
356
357    /// Check if component is cached
358    pub fn is_cached(&self, skill_name: &str, version: &str) -> bool {
359        self.cache_path(skill_name, version).exists()
360    }
361
362    /// Load component from cache
363    pub fn load(&self, skill_name: &str, version: &str) -> Result<Vec<u8>> {
364        let path = self.cache_path(skill_name, version);
365        std::fs::read(&path)
366            .with_context(|| format!("Failed to read cached component: {}", path.display()))
367    }
368
369    /// Save component to cache
370    pub fn save(&self, skill_name: &str, version: &str, data: &[u8]) -> Result<()> {
371        let path = self.cache_path(skill_name, version);
372        std::fs::write(&path, data)
373            .with_context(|| format!("Failed to write cached component: {}", path.display()))?;
374
375        tracing::debug!(
376            skill = %skill_name,
377            version = %version,
378            size = data.len(),
379            "Saved component to cache"
380        );
381
382        Ok(())
383    }
384
385    /// Clear cache for a specific skill
386    pub fn clear_skill(&self, skill_name: &str) -> Result<()> {
387        for entry in std::fs::read_dir(&self.cache_dir)? {
388            let entry = entry?;
389            let filename = entry.file_name();
390            if let Some(name) = filename.to_str() {
391                if name.starts_with(&format!("{}_", skill_name)) {
392                    std::fs::remove_file(entry.path())?;
393                    tracing::debug!(file = %name, "Removed cached component");
394                }
395            }
396        }
397        Ok(())
398    }
399
400    /// Clear entire cache
401    pub fn clear_all(&self) -> Result<()> {
402        for entry in std::fs::read_dir(&self.cache_dir)? {
403            let entry = entry?;
404            if entry.file_type()?.is_file() {
405                std::fs::remove_file(entry.path())?;
406            }
407        }
408        tracing::info!("Cleared component cache");
409        Ok(())
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a cache rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the cache so the directory stays alive
    /// until the caller drops it at the end of the test.
    fn fresh_cache() -> (TempDir, ComponentCache) {
        let tmp = TempDir::new().unwrap();
        let cache = ComponentCache::new(tmp.path()).unwrap();
        (tmp, cache)
    }

    #[test]
    fn test_component_cache() {
        let (_tmp, cache) = fresh_cache();
        let (skill, version) = ("test-skill", "1.0.0");
        let payload: Vec<u8> = vec![1, 2, 3, 4, 5];

        // A brand-new cache holds nothing for this skill.
        assert!(!cache.is_cached(skill, version));

        // Round trip: save, observe the entry, and read back identical bytes.
        cache.save(skill, version, &payload).unwrap();
        assert!(cache.is_cached(skill, version));
        assert_eq!(cache.load(skill, version).unwrap(), payload);

        // Wiping the cache removes the entry again.
        cache.clear_all().unwrap();
        assert!(!cache.is_cached(skill, version));
    }

    #[test]
    fn test_cache_key_includes_wasmtime_version() {
        let (_tmp, cache) = fresh_cache();

        let key = cache.cache_key("test", "1.0.0");
        for needle in ["wasmtime", "wasmtime_26"] {
            assert!(key.contains(needle));
        }
    }
}