sh_layer3/builtin_tools/
workflow_tools.rs1use crate::builtin_tools::BuiltinTool;
8use crate::types::{Layer3Result, ToolCategory};
9use async_trait::async_trait;
10use chrono::Utc;
11use std::path::PathBuf;
12use std::sync::Arc;
13
14use sh_layer2::{CheckpointData, CheckpointId, CheckpointSystemTrait, CheckpointWriter, SessionId};
16
/// Directory used by the `new()` constructors of the checkpoint tools when no
/// explicit path is supplied: `continuum_checkpoints` inside the OS temp directory.
fn default_checkpoint_path() -> PathBuf {
    let mut dir = std::env::temp_dir();
    dir.push("continuum_checkpoints");
    dir
}
21
22pub struct CreateCheckpointTool {
27 writer: Arc<CheckpointWriter>,
28}
29
30impl CreateCheckpointTool {
31 pub fn new() -> Self {
33 Self {
34 writer: Arc::new(CheckpointWriter::new(default_checkpoint_path())),
35 }
36 }
37
38 pub fn with_path(path: PathBuf) -> Self {
40 Self {
41 writer: Arc::new(CheckpointWriter::new(path)),
42 }
43 }
44}
45
46impl Default for CreateCheckpointTool {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52#[async_trait]
53impl BuiltinTool for CreateCheckpointTool {
54 fn name(&self) -> &str {
55 "create_checkpoint"
56 }
57
58 fn description(&self) -> &str {
59 "Create a checkpoint to save current agent state to a file."
60 }
61
62 fn parameters_schema(&self) -> serde_json::Value {
63 serde_json::json!({
64 "type": "object",
65 "properties": {
66 "session_id": {
67 "type": "string",
68 "description": "The session ID to checkpoint"
69 },
70 "trigger": {
71 "type": "string",
72 "description": "Optional: trigger reason for the checkpoint (default: 'manual')"
73 },
74 "messages": {
75 "type": "array",
76 "description": "Optional: message history to save",
77 "items": {
78 "type": "object",
79 "properties": {
80 "role": { "type": "string" },
81 "content": { "type": "string" }
82 }
83 }
84 },
85 "iteration": {
86 "type": "integer",
87 "description": "Optional: current iteration number (default: 0)"
88 },
89 "tokens_used": {
90 "type": "integer",
91 "description": "Optional: tokens used so far (default: 0)"
92 }
93 },
94 "required": ["session_id"]
95 })
96 }
97
98 fn category(&self) -> ToolCategory {
99 ToolCategory::Workflow
100 }
101
102 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
103 let session_id_str = args["session_id"]
104 .as_str()
105 .ok_or_else(|| anyhow::anyhow!("Missing session_id parameter"))?;
106
107 let session_id = SessionId::from(session_id_str);
108 let trigger = args["trigger"].as_str().unwrap_or("manual");
109 let iteration = args["iteration"].as_i64().unwrap_or(0) as i32;
110 let tokens_used = args["tokens_used"].as_i64().unwrap_or(0);
111
112 let messages = args["messages"].as_array().cloned().unwrap_or_default();
113
114 let tool_calls_pending = args["tool_calls_pending"]
115 .as_array()
116 .cloned()
117 .unwrap_or_default();
118
119 let tool_results = args
120 .get("tool_results")
121 .cloned()
122 .unwrap_or(serde_json::Value::Null);
123
124 let checkpoint_data = CheckpointData {
126 checkpoint_id: CheckpointId::new(),
127 session_id: session_id.clone(),
128 created_at: Utc::now(),
129 trigger: trigger.to_string(),
130 iteration,
131 messages,
132 tool_calls_pending,
133 tool_results,
134 tokens_used,
135 cost_estimate: 0.0,
136 resume_hint: None,
137 };
138
139 let checkpoint_id = self.writer.save(&checkpoint_data).await?;
141
142 Ok(format!(
143 "Checkpoint created: {}\nSession: {}\nTrigger: {}\nIteration: {}",
144 checkpoint_id, session_id, trigger, iteration
145 ))
146 }
147}
148
149pub struct RestoreCheckpointTool {
154 writer: Arc<CheckpointWriter>,
155}
156
157impl RestoreCheckpointTool {
158 pub fn new() -> Self {
160 Self {
161 writer: Arc::new(CheckpointWriter::new(default_checkpoint_path())),
162 }
163 }
164
165 pub fn with_path(path: PathBuf) -> Self {
167 Self {
168 writer: Arc::new(CheckpointWriter::new(path)),
169 }
170 }
171}
172
173impl Default for RestoreCheckpointTool {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179#[async_trait]
180impl BuiltinTool for RestoreCheckpointTool {
181 fn name(&self) -> &str {
182 "restore_checkpoint"
183 }
184
185 fn description(&self) -> &str {
186 "Restore agent state from a checkpoint file."
187 }
188
189 fn parameters_schema(&self) -> serde_json::Value {
190 serde_json::json!({
191 "type": "object",
192 "properties": {
193 "session_id": {
194 "type": "string",
195 "description": "The session ID to restore"
196 },
197 "checkpoint_id": {
198 "type": "string",
199 "description": "Optional: specific checkpoint ID (default: latest)"
200 }
201 },
202 "required": ["session_id"]
203 })
204 }
205
206 fn category(&self) -> ToolCategory {
207 ToolCategory::Workflow
208 }
209
210 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
211 let session_id_str = args["session_id"]
212 .as_str()
213 .ok_or_else(|| anyhow::anyhow!("Missing session_id parameter"))?;
214
215 let session_id = SessionId::from(session_id_str);
216 let checkpoint_id_opt = args["checkpoint_id"]
217 .as_str()
218 .map(|s| CheckpointId(s.to_string()));
219
220 let result = self
222 .writer
223 .load(&session_id, checkpoint_id_opt.as_ref())
224 .await?;
225
226 match result {
227 Some(checkpoint) => {
228 Ok(format!(
229 "Checkpoint restored: {}\nSession: {}\nTrigger: {}\nIteration: {}\nMessages: {}\nTokens used: {}",
230 checkpoint.checkpoint_id,
231 checkpoint.session_id,
232 checkpoint.trigger,
233 checkpoint.iteration,
234 checkpoint.messages.len(),
235 checkpoint.tokens_used
236 ))
237 }
238 None => Err(anyhow::anyhow!(
239 "No checkpoints found for session: {}",
240 session_id_str
241 )),
242 }
243 }
244}
245
246pub struct ListCheckpointsTool {
251 writer: Arc<CheckpointWriter>,
252}
253
254impl ListCheckpointsTool {
255 pub fn new() -> Self {
256 Self {
257 writer: Arc::new(CheckpointWriter::new(default_checkpoint_path())),
258 }
259 }
260
261 pub fn with_path(path: PathBuf) -> Self {
262 Self {
263 writer: Arc::new(CheckpointWriter::new(path)),
264 }
265 }
266}
267
268impl Default for ListCheckpointsTool {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
274#[async_trait]
275impl BuiltinTool for ListCheckpointsTool {
276 fn name(&self) -> &str {
277 "list_checkpoints"
278 }
279
280 fn description(&self) -> &str {
281 "List all checkpoints for a session."
282 }
283
284 fn parameters_schema(&self) -> serde_json::Value {
285 serde_json::json!({
286 "type": "object",
287 "properties": {
288 "session_id": {
289 "type": "string",
290 "description": "The session ID to list checkpoints for"
291 }
292 },
293 "required": ["session_id"]
294 })
295 }
296
297 fn category(&self) -> ToolCategory {
298 ToolCategory::Workflow
299 }
300
301 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
302 let session_id_str = args["session_id"]
303 .as_str()
304 .ok_or_else(|| anyhow::anyhow!("Missing session_id parameter"))?;
305
306 let session_id = SessionId::from(session_id_str);
307
308 let checkpoints = self.writer.list(&session_id).await?;
310
311 if checkpoints.is_empty() {
312 return Ok(format!(
313 "No checkpoints found for session: {}",
314 session_id_str
315 ));
316 }
317
318 let mut result = format!("Checkpoints for session {}:\n", session_id_str);
319 for (i, meta) in checkpoints.iter().enumerate() {
320 result.push_str(&format!(
321 " {}. {} (created: {})\n",
322 i + 1,
323 meta.checkpoint_id,
324 meta.created_at.format("%Y-%m-%d %H:%M:%S")
325 ));
326 }
327
328 Ok(result)
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tempfile::TempDir;

    #[test]
    fn test_checkpoint_tool_category() {
        let tool = CreateCheckpointTool::new();
        assert_eq!(tool.category(), ToolCategory::Workflow);
    }

    #[test]
    fn test_restore_checkpoint_tool_category() {
        let tool = RestoreCheckpointTool::new();
        assert_eq!(tool.category(), ToolCategory::Workflow);
    }

    #[test]
    fn test_list_checkpoints_tool_category() {
        let tool = ListCheckpointsTool::new();
        assert_eq!(tool.category(), ToolCategory::Workflow);
    }

    #[tokio::test]
    async fn test_create_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        let output = tool
            .execute(json!({
                "session_id": "test_session",
                "trigger": "manual",
                "messages": [{"role": "user", "content": "hello"}],
                "iteration": 1
            }))
            .await
            .expect("create_checkpoint should succeed");

        assert!(output.contains("Checkpoint created"));
        assert!(output.contains("test_session"));
        assert!(output.contains("Iteration: 1"));
    }

    #[tokio::test]
    async fn test_restore_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let create_tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        create_tool
            .execute(json!({
                "session_id": "test_session",
                "messages": [{"role": "user", "content": "test"}]
            }))
            .await
            .expect("create_checkpoint should succeed");

        let restore_tool = RestoreCheckpointTool::with_path(temp_dir.path().to_path_buf());
        let output = restore_tool
            .execute(json!({"session_id": "test_session"}))
            .await
            .expect("restore_checkpoint should find the checkpoint just created");

        assert!(output.contains("Checkpoint restored"));
        assert!(output.contains("Messages: 1"));
    }

    #[tokio::test]
    async fn test_restore_nonexistent_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let tool = RestoreCheckpointTool::with_path(temp_dir.path().to_path_buf());

        let err = tool
            .execute(json!({"session_id": "nonexistent_session"}))
            .await
            .expect_err("restoring an empty session should fail");

        assert!(err.to_string().contains("No checkpoints found"));
    }

    #[tokio::test]
    async fn test_list_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let create_tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        create_tool
            .execute(json!({"session_id": "test_session"}))
            .await
            .expect("create_checkpoint should succeed");

        let list_tool = ListCheckpointsTool::with_path(temp_dir.path().to_path_buf());
        let output = list_tool
            .execute(json!({"session_id": "test_session"}))
            .await
            .expect("list_checkpoints should succeed");

        assert!(output.contains("Checkpoints for session"));
    }

    /// Every tool validates `session_id` before touching the writer, so a missing
    /// or non-string value must be rejected with the same message.
    #[tokio::test]
    async fn test_missing_session_id_is_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().to_path_buf();

        let create_err = CreateCheckpointTool::with_path(path.clone())
            .execute(json!({}))
            .await
            .expect_err("create_checkpoint without session_id should fail");
        assert!(create_err
            .to_string()
            .contains("Missing session_id parameter"));

        let restore_err = RestoreCheckpointTool::with_path(path.clone())
            .execute(json!({"session_id": 42}))
            .await
            .expect_err("restore_checkpoint with non-string session_id should fail");
        assert!(restore_err
            .to_string()
            .contains("Missing session_id parameter"));

        let list_err = ListCheckpointsTool::with_path(path)
            .execute(json!({}))
            .await
            .expect_err("list_checkpoints without session_id should fail");
        assert!(list_err
            .to_string()
            .contains("Missing session_id parameter"));
    }
}