Skip to main content

xz_skill/runtime/
default.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3
4use crate::error::SkillError;
5use crate::security::permissions::PermissionValidator;
6use crate::traits::{SkillRegistry, SkillRuntime};
7use crate::types::context::ExecutionContext;
8use crate::types::filter::{PreflightWarning, WarningSeverity};
9use crate::types::output::{SkillOutput, TokenUsage, ToolCallRecord};
10use crate::types::skill::{Skill, ToolType};
11
12/// Default skill runtime — LLM prompt injection + tool-calling loop.
13#[derive(Debug)]
14pub struct DefaultSkillRuntime {
15    registry: Option<Arc<dyn SkillRegistry>>,
16    permission_validator: PermissionValidator,
17}
18
19impl DefaultSkillRuntime {
20    pub fn new(registry: Option<Arc<dyn SkillRegistry>>) -> Self {
21        Self {
22            registry,
23            permission_validator: PermissionValidator::new(false, vec![]),
24        }
25    }
26
27    pub fn with_permissions(mut self, allowed_network: bool, allowed_paths: Vec<std::path::PathBuf>) -> Self {
28        self.permission_validator = PermissionValidator::new(allowed_network, allowed_paths);
29        self
30    }
31
32    async fn resolve_skill(&self, skill_id: &str) -> Result<Skill, SkillError> {
33        if let Some(ref reg) = self.registry {
34            reg.get(skill_id)
35                .await?
36                .ok_or_else(|| SkillError::NotFound(skill_id.to_string()))
37        } else {
38            Err(SkillError::ConfigValidation(
39                "No skill registry configured".into(),
40            ))
41        }
42    }
43}
44
45#[async_trait]
46impl SkillRuntime for DefaultSkillRuntime {
47    async fn execute(
48        &self,
49        skill_id: &str,
50        input: &str,
51        context: &ExecutionContext,
52    ) -> Result<SkillOutput, SkillError> {
53        let skill = self.resolve_skill(skill_id).await?;
54
55        if !skill.enabled {
56            return Err(SkillError::Disabled(skill_id.to_string()));
57        }
58
59        // Validate permissions
60        self.validate_permissions(&skill, context).await?;
61
62        // Validate version constraint
63        if let Some(ref min_version) = skill.min_agent_version {
64            // Simple version check: compare "0.1.0" format
65            if let Err(e) = check_version(min_version, "0.1.0") {
66                return Err(SkillError::VersionMismatch { required: e });
67            }
68        }
69
70        let start = std::time::Instant::now();
71        let mut tool_calls = Vec::new();
72
73        // Simple tool execution loop — in production this would be LLM-driven
74        for tool_def in &skill.tools {
75            let tool_start = std::time::Instant::now();
76            match &tool_def.tool_type {
77                ToolType::Builtin { handler } => {
78                    let result = execute_builtin_tool(handler, &serde_json::Value::String(input.to_string()))
79                        .await;
80                    tool_calls.push(ToolCallRecord {
81                        tool_name: tool_def.name.clone(),
82                        args: serde_json::json!({"input": input}),
83                        result: result.as_ref().ok().cloned(),
84                        error: result.as_ref().err().map(|e| e.to_string()),
85                        duration_ms: tool_start.elapsed().as_millis() as u64,
86                    });
87                }
88                #[cfg(feature = "http-tool")]
89                ToolType::Http {
90                    url,
91                    method,
92                    headers,
93                    timeout_ms,
94                } => {
95                    let executor = crate::runtime::http::HttpToolExecutor::new();
96                    let result = executor
97                        .execute(
98                            url,
99                            method,
100                            headers,
101                            *timeout_ms,
102                            &serde_json::json!({"input": input}),
103                        )
104                        .await;
105                    tool_calls.push(ToolCallRecord {
106                        tool_name: tool_def.name.clone(),
107                        args: serde_json::json!({"input": input}),
108                        result: result.as_ref().ok().cloned(),
109                        error: result.as_ref().err().map(|e| e.to_string()),
110                        duration_ms: tool_start.elapsed().as_millis() as u64,
111                    });
112                }
113                #[cfg(feature = "wasm-runtime")]
114                ToolType::Wasm {
115                    module,
116                    timeout_ms: _timeout_ms,
117                    ..
118                } => {
119                    let runtime = crate::runtime::wasm::WasmRuntime::new(
120                        crate::runtime::wasm::WasmConfig::default(),
121                    )?;
122                    let result = runtime
123                        .execute(module, &tool_def.name, serde_json::json!({"input": input}))
124                        .await;
125                    tool_calls.push(ToolCallRecord {
126                        tool_name: tool_def.name.clone(),
127                        args: serde_json::json!({"input": input}),
128                        result: result.as_ref().ok().cloned(),
129                        error: result.as_ref().err().map(|e| e.to_string()),
130                        duration_ms: tool_start.elapsed().as_millis() as u64,
131                    });
132                }
133                _ => {
134                    tool_calls.push(ToolCallRecord {
135                        tool_name: tool_def.name.clone(),
136                        args: serde_json::Value::Null,
137                        result: None,
138                        error: Some("Tool type not supported (missing feature flag)".into()),
139                        duration_ms: 0,
140                    });
141                }
142            }
143        }
144
145        let total_ms = start.elapsed().as_millis() as u64;
146
147        Ok(SkillOutput {
148            content: format!("Skill '{}' executed with {} tools", skill.name, tool_calls.len()),
149            tool_calls,
150            token_usage: TokenUsage::default(),
151            duration_ms: total_ms,
152        })
153    }
154
155    async fn execute_tool(
156        &self,
157        tool_name: &str,
158        args: serde_json::Value,
159    ) -> Result<serde_json::Value, SkillError> {
160        execute_builtin_tool(tool_name, &args).await
161    }
162
163    async fn validate_permissions(
164        &self,
165        skill: &Skill,
166        _context: &ExecutionContext,
167    ) -> Result<(), SkillError> {
168        for perm in &skill.permissions {
169            self.permission_validator.check(perm)?;
170        }
171        Ok(())
172    }
173
174    async fn preflight_check(&self, skill: &Skill) -> Result<Vec<PreflightWarning>, SkillError> {
175        let mut warnings = Vec::new();
176
177        // Check for missing prompt
178        if skill.prompt.trim().is_empty() {
179            warnings.push(PreflightWarning {
180                severity: WarningSeverity::Error,
181                message: "Skill has no prompt defined".into(),
182            });
183        }
184
185        // Check WASM modules
186        #[cfg(feature = "wasm-runtime")]
187        for tool in &skill.tools {
188            if let ToolType::Wasm {
189                module, timeout_ms, ..
190            } = &tool.tool_type
191            {
192                if module.is_empty() {
193                    warnings.push(PreflightWarning {
194                        severity: WarningSeverity::Error,
195                        message: format!("WASM module is empty for tool '{}'", tool.name),
196                    });
197                }
198                if *timeout_ms > 30_000 {
199                    warnings.push(PreflightWarning {
200                        severity: WarningSeverity::Warning,
201                        message: format!(
202                            "WASM tool '{}' timeout is > 30s ({}ms)",
203                            tool.name, timeout_ms
204                        ),
205                    });
206                }
207            }
208        }
209
210        Ok(warnings)
211    }
212}
213
214/// Simple built-in tool execution — placeholder for real implementations.
215async fn execute_builtin_tool(
216    handler: &str,
217    args: &serde_json::Value,
218) -> Result<serde_json::Value, SkillError> {
219    match handler {
220        "echo" => Ok(args.clone()),
221        "now" => Ok(serde_json::json!({
222            "timestamp": std::time::SystemTime::now()
223                .duration_since(std::time::UNIX_EPOCH)
224                .unwrap()
225                .as_secs()
226        })),
227        "uuid" => {
228            let id = uuid::Uuid::new_v4().to_string();
229            Ok(serde_json::json!({"uuid": id}))
230        }
231        "json_path" => {
232            let path = args.get("path")
233                .and_then(|v| v.as_str())
234                .unwrap_or("");
235            let data = args.get("data").cloned().unwrap_or(serde_json::Value::Null);
236            let result = extract_json_path(&data, path);
237            Ok(result)
238        }
239        "base64_encode" => {
240            let text = args.get("text")
241                .and_then(|v| v.as_str())
242                .unwrap_or("");
243            Ok(serde_json::json!({
244                "encoded": base64_encode(text)
245            }))
246        }
247        "base64_decode" => {
248            let text = args.get("encoded")
249                .and_then(|v| v.as_str())
250                .unwrap_or("");
251            let decoded = base64_decode(text)
252                .map_err(|e| SkillError::ToolExecution(e))?;
253            Ok(serde_json::json!({
254                "decoded": decoded
255            }))
256        }
257        _ => Err(SkillError::ToolExecution(format!(
258            "Unknown builtin handler: {}",
259            handler
260        ))),
261    }
262}
263
264fn extract_json_path(data: &serde_json::Value, path: &str) -> serde_json::Value {
265    if path.is_empty() || path == "$" {
266        return data.clone();
267    }
268    let segments: Vec<&str> = path
269        .trim_start_matches("$.")
270        .split('.')
271        .collect();
272    let mut current = data;
273    for seg in segments {
274        match current {
275            serde_json::Value::Object(map) => {
276                current = map.get(seg).unwrap_or(&serde_json::Value::Null);
277            }
278            serde_json::Value::Array(arr) => {
279                if let Ok(idx) = seg.parse::<usize>() {
280                    current = arr.get(idx).unwrap_or(&serde_json::Value::Null);
281                } else {
282                    return serde_json::Value::Null;
283                }
284            }
285            _ => return serde_json::Value::Null,
286        }
287    }
288    current.clone()
289}
290
291fn base64_encode(text: &str) -> String {
292    use base64::Engine;
293    base64::engine::general_purpose::STANDARD.encode(text)
294}
295
296fn base64_decode(encoded: &str) -> Result<String, String> {
297    use base64::Engine;
298    let bytes = base64::engine::general_purpose::STANDARD
299        .decode(encoded.trim())
300        .map_err(|e| format!("Base64 decode failed: {}", e))?;
301    String::from_utf8(bytes).map_err(|e| format!("Invalid UTF-8: {}", e))
302}
303
/// Simple minimum-version check: succeeds when `actual` is at least `required`.
///
/// `required` may start with lower-bound operator characters (`>=`, `>`, `=`, `^`, `~`). All
/// of them are treated as "at least this version", so upper bounds are not enforced. An
/// optional `v`/`V` prefix is also accepted (e.g. `">= v0.2.0"`).
///
/// Only the numeric `major.minor.patch…` core is compared:
/// - pre-release (`-rc1`) and build-metadata (`+build.5`) suffixes are ignored;
/// - missing components count as `0`;
/// - components keep their position even when they contain non-digits.
///
/// # Errors
///
/// Returns `required`, unchanged, when `actual` is lower.
fn check_version(required: &str, actual: &str) -> Result<(), String> {
    let lower_bound = required
        .trim()
        .trim_start_matches(|c: char| matches!(c, '>' | '=' | '^' | '~'));
    let mut req = version_components(lower_bound);
    let mut act = version_components(actual);

    // Pad the shorter list with zeros so that e.g. "0.1" compares equal to "0.1.0".
    let len = req.len().max(act.len());
    req.resize(len, 0);
    act.resize(len, 0);

    // `Vec` ordering is lexicographic: the first differing component decides.
    if act < req {
        Err(required.to_string())
    } else {
        Ok(())
    }
}

/// Splits a version string into its numeric `major.minor.patch…` components.
///
/// The input is normalised before splitting:
/// - surrounding whitespace and one leading `v`/`V` are ignored;
/// - everything from the first `-` (pre-release) or `+` (build metadata) onward is dropped.
///
/// Each remaining dot-separated part contributes its leading decimal digits. A part with no
/// leading digits counts as `0`, and digits that overflow `u32` become `u32::MAX`, so later
/// components never shift position.
fn version_components(version: &str) -> Vec<u32> {
    let version = version.trim();
    let version = version
        .strip_prefix(|c: char| c == 'v' || c == 'V')
        .unwrap_or(version);
    // `split` always yields at least one item, so `next()` is `Some`.
    let core = version
        .split(|c: char| c == '-' || c == '+')
        .next()
        .unwrap_or("");
    core.split('.')
        .map(|part| {
            let end = part
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(part.len());
            match &part[..end] {
                "" => 0,
                digits => digits.parse().unwrap_or(u32::MAX),
            }
        })
        .collect()
}