zeph_core/
skill_loader.rs1use std::sync::{Arc, RwLock};
5
6use schemars::JsonSchema;
7use serde::Deserialize;
8use zeph_skills::registry::SkillRegistry;
9use zeph_tools::executor::{
10 ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params, truncate_tool_output,
11};
12use zeph_tools::registry::{InvocationHint, ToolDef};
13
14#[derive(Debug, Deserialize, JsonSchema)]
15pub struct LoadSkillParams {
16 pub skill_name: String,
18}
19
20#[derive(Clone, Debug)]
22pub struct SkillLoaderExecutor {
23 registry: Arc<RwLock<SkillRegistry>>,
24}
25
26impl SkillLoaderExecutor {
27 #[must_use]
28 pub fn new(registry: Arc<RwLock<SkillRegistry>>) -> Self {
29 Self { registry }
30 }
31}
32
33impl ToolExecutor for SkillLoaderExecutor {
34 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
35 Ok(None)
36 }
37
38 fn tool_definitions(&self) -> Vec<ToolDef> {
39 vec![ToolDef {
40 id: "load_skill".into(),
41 description: "Load the full body of a skill by name when you see a relevant entry in the <other_skills> catalog.\n\nParameters: name (string, required) - exact skill name from the <other_skills> catalog\nReturns: complete skill instructions (SKILL.md body), or error if skill not found\nErrors: InvalidParams if name is empty; Execution if skill not found in registry\nExample: {\"name\": \"code-review\"}".into(),
42 schema: schemars::schema_for!(LoadSkillParams),
43 invocation: InvocationHint::ToolCall,
44 }]
45 }
46
47 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
48 if call.tool_id != "load_skill" {
49 return Ok(None);
50 }
51 let params: LoadSkillParams = deserialize_params(&call.params)?;
52 let skill_name: String = params.skill_name.chars().take(128).collect();
53 let body = {
54 let guard = self.registry.read().map_err(|_| ToolError::InvalidParams {
55 message: "registry lock poisoned".into(),
56 })?;
57 guard.get_body(&skill_name).map(str::to_owned)
58 };
59
60 let summary = match body {
61 Ok(b) => truncate_tool_output(&b),
62 Err(_) => format!("skill not found: {skill_name}"),
63 };
64
65 Ok(Some(ToolOutput {
66 tool_name: "load_skill".to_owned(),
67 summary,
68 blocks_executed: 1,
69 filter_stats: None,
70 diff: None,
71 streamed: false,
72 terminal_id: None,
73 locations: None,
74 raw_response: None,
75 claim_source: None,
76 }))
77 }
78}
79
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::*;

    /// Creates `<dir>/<name>/SKILL.md` with minimal frontmatter followed by
    /// `body`, then loads a registry rooted at `dir`.
    fn make_registry_with_skill(dir: &Path, name: &str, body: &str) -> SkillRegistry {
        let skill_dir = dir.join(name);
        std::fs::create_dir_all(&skill_dir).unwrap();
        let contents = format!("---\nname: {name}\ndescription: test skill\n---\n{body}");
        std::fs::write(skill_dir.join("SKILL.md"), contents).unwrap();
        SkillRegistry::load(&[dir.to_path_buf()])
    }

    /// Loads a registry from `dir` without writing any skills into it.
    fn empty_registry(dir: &Path) -> SkillRegistry {
        SkillRegistry::load(&[dir.to_path_buf()])
    }

    /// Wraps `registry` in the shared lock the executor expects.
    fn executor_for(registry: SkillRegistry) -> SkillLoaderExecutor {
        SkillLoaderExecutor::new(Arc::new(RwLock::new(registry)))
    }

    /// Builds a `load_skill` call whose only parameter is `skill_name`.
    fn load_skill_call(skill_name: &str) -> ToolCall {
        let params = serde_json::json!({ "skill_name": skill_name });
        ToolCall {
            tool_id: "load_skill".to_owned(),
            params: params.as_object().unwrap().clone(),
        }
    }

    /// Runs a `load_skill` call for `skill_name` and unwraps the output.
    async fn load(executor: &SkillLoaderExecutor, skill_name: &str) -> ToolOutput {
        let call = load_skill_call(skill_name);
        executor.execute_tool_call(&call).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn load_existing_skill_returns_body() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(make_registry_with_skill(
            dir.path(),
            "git-commit",
            "## Instructions\nDo git stuff",
        ));
        let output = load(&executor, "git-commit").await;
        assert!(output.summary.contains("## Instructions"));
        assert!(output.summary.contains("Do git stuff"));
    }

    #[tokio::test]
    async fn load_nonexistent_skill_returns_error_message() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let output = load(&executor, "nonexistent").await;
        assert!(output.summary.contains("skill not found"));
        assert!(output.summary.contains("nonexistent"));
    }

    #[test]
    fn tool_definitions_returns_load_skill() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let definitions = executor.tool_definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].id.as_ref(), "load_skill");
    }

    #[tokio::test]
    async fn execute_returns_none_for_wrong_tool_id() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let foreign_call = ToolCall {
            tool_id: "bash".to_owned(),
            params: serde_json::Map::new(),
        };
        let outcome = executor.execute_tool_call(&foreign_call).await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn long_skill_body_is_truncated() {
        use zeph_tools::executor::MAX_TOOL_OUTPUT_CHARS;

        let dir = tempfile::tempdir().unwrap();
        let long_body = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let executor =
            executor_for(make_registry_with_skill(dir.path(), "big-skill", &long_body));
        let output = load(&executor, "big-skill").await;
        assert!(output.summary.contains("truncated"));
        assert!(output.summary.len() < long_body.len() + 200);
    }

    #[tokio::test]
    async fn empty_registry_returns_error_message() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let output = load(&executor, "any").await;
        assert!(output.summary.contains("skill not found"));
    }

    #[tokio::test]
    async fn execute_always_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let outcome = executor.execute("any response text").await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn concurrent_execute_tool_call_succeeds() {
        let dir = tempfile::tempdir().unwrap();
        let executor = Arc::new(executor_for(make_registry_with_skill(
            dir.path(),
            "shared-skill",
            "## Concurrent test body",
        )));

        let mut handles = Vec::with_capacity(8);
        for _ in 0..8 {
            let shared = Arc::clone(&executor);
            handles.push(tokio::spawn(async move {
                let call = load_skill_call("shared-skill");
                shared.execute_tool_call(&call).await
            }));
        }

        for handle in handles {
            let output = handle.await.unwrap().unwrap().unwrap();
            assert!(output.summary.contains("## Concurrent test body"));
        }
    }

    #[tokio::test]
    async fn empty_skill_name_returns_not_found() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let output = load(&executor, "").await;
        assert!(output.summary.contains("skill not found"));
    }

    #[tokio::test]
    async fn missing_skill_name_field_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let executor = executor_for(empty_registry(dir.path()));
        let call_without_name = ToolCall {
            tool_id: "load_skill".to_owned(),
            params: serde_json::Map::new(),
        };
        let outcome = executor.execute_tool_call(&call_without_name).await;
        assert!(outcome.is_err());
    }
}