// File: sh_layer3/builtin_tools/workflow_tools.rs

//! # Workflow Tools
//!
//! Workflow-control tool set providing checkpoint creation, restoration, and listing.
//!
//! Persistence is implemented on top of the Layer2 CheckpointSystem.

use std::fmt::Write as _;
use std::path::PathBuf;
use std::sync::Arc;

use async_trait::async_trait;
use chrono::Utc;

// Layer2 CheckpointSystem integration
use sh_layer2::{CheckpointData, CheckpointId, CheckpointSystemTrait, CheckpointWriter, SessionId};

use crate::builtin_tools::BuiltinTool;
use crate::types::{Layer3Result, ToolCategory};

/// Default checkpoint storage path: `continuum_checkpoints` inside the system temp dir.
///
/// NOTE(review): the system temp dir is typically shared and may be cleared by the OS;
/// confirm this is an acceptable default location for persisted checkpoints.
fn default_checkpoint_path() -> PathBuf {
    let mut dir = std::env::temp_dir();
    dir.push("continuum_checkpoints");
    dir
}

22/// Create Checkpoint Tool
23///
24/// 保存当前会话状态到检查点文件。
25/// 使用 Layer2 CheckpointWriter 实现原子写入和完整性验证。
26pub struct CreateCheckpointTool {
27    writer: Arc<CheckpointWriter>,
28}
29
30impl CreateCheckpointTool {
31    /// 创建新工具,使用默认存储路径
32    pub fn new() -> Self {
33        Self {
34            writer: Arc::new(CheckpointWriter::new(default_checkpoint_path())),
35        }
36    }
37
38    /// 使用自定义存储路径
39    pub fn with_path(path: PathBuf) -> Self {
40        Self {
41            writer: Arc::new(CheckpointWriter::new(path)),
42        }
43    }
44}
45
46impl Default for CreateCheckpointTool {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52#[async_trait]
53impl BuiltinTool for CreateCheckpointTool {
54    fn name(&self) -> &str {
55        "create_checkpoint"
56    }
57
58    fn description(&self) -> &str {
59        "Create a checkpoint to save current agent state to a file."
60    }
61
62    fn parameters_schema(&self) -> serde_json::Value {
63        serde_json::json!({
64            "type": "object",
65            "properties": {
66                "session_id": {
67                    "type": "string",
68                    "description": "The session ID to checkpoint"
69                },
70                "trigger": {
71                    "type": "string",
72                    "description": "Optional: trigger reason for the checkpoint (default: 'manual')"
73                },
74                "messages": {
75                    "type": "array",
76                    "description": "Optional: message history to save",
77                    "items": {
78                        "type": "object",
79                        "properties": {
80                            "role": { "type": "string" },
81                            "content": { "type": "string" }
82                        }
83                    }
84                },
85                "iteration": {
86                    "type": "integer",
87                    "description": "Optional: current iteration number (default: 0)"
88                },
89                "tokens_used": {
90                    "type": "integer",
91                    "description": "Optional: tokens used so far (default: 0)"
92                }
93            },
94            "required": ["session_id"]
95        })
96    }
97
98    fn category(&self) -> ToolCategory {
99        ToolCategory::Workflow
100    }
101
102    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
103        let session_id_str = args["session_id"]
104            .as_str()
105            .ok_or_else(|| anyhow::anyhow!("Missing session_id parameter"))?;
106
107        let session_id = SessionId::from(session_id_str);
108        let trigger = args["trigger"].as_str().unwrap_or("manual");
109        let iteration = args["iteration"].as_i64().unwrap_or(0) as i32;
110        let tokens_used = args["tokens_used"].as_i64().unwrap_or(0);
111
112        let messages = args["messages"].as_array().cloned().unwrap_or_default();
113
114        let tool_calls_pending = args["tool_calls_pending"]
115            .as_array()
116            .cloned()
117            .unwrap_or_default();
118
119        let tool_results = args
120            .get("tool_results")
121            .cloned()
122            .unwrap_or(serde_json::Value::Null);
123
124        // 构建 CheckpointData (Layer2 格式)
125        let checkpoint_data = CheckpointData {
126            checkpoint_id: CheckpointId::new(),
127            session_id: session_id.clone(),
128            created_at: Utc::now(),
129            trigger: trigger.to_string(),
130            iteration,
131            messages,
132            tool_calls_pending,
133            tool_results,
134            tokens_used,
135            cost_estimate: 0.0,
136            resume_hint: None,
137        };
138
139        // 使用 Layer2 CheckpointWriter 保存
140        let checkpoint_id = self.writer.save(&checkpoint_data).await?;
141
142        Ok(format!(
143            "Checkpoint created: {}\nSession: {}\nTrigger: {}\nIteration: {}",
144            checkpoint_id, session_id, trigger, iteration
145        ))
146    }
147}
148
149/// Restore Checkpoint Tool
150///
151/// 从检查点文件恢复会话状态。
152/// 使用 Layer2 CheckpointWriter 实现读取和验证。
153pub struct RestoreCheckpointTool {
154    writer: Arc<CheckpointWriter>,
155}
156
157impl RestoreCheckpointTool {
158    /// 创建新工具,使用默认存储路径
159    pub fn new() -> Self {
160        Self {
161            writer: Arc::new(CheckpointWriter::new(default_checkpoint_path())),
162        }
163    }
164
165    /// 使用自定义存储路径
166    pub fn with_path(path: PathBuf) -> Self {
167        Self {
168            writer: Arc::new(CheckpointWriter::new(path)),
169        }
170    }
171}
172
173impl Default for RestoreCheckpointTool {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179#[async_trait]
180impl BuiltinTool for RestoreCheckpointTool {
181    fn name(&self) -> &str {
182        "restore_checkpoint"
183    }
184
185    fn description(&self) -> &str {
186        "Restore agent state from a checkpoint file."
187    }
188
189    fn parameters_schema(&self) -> serde_json::Value {
190        serde_json::json!({
191            "type": "object",
192            "properties": {
193                "session_id": {
194                    "type": "string",
195                    "description": "The session ID to restore"
196                },
197                "checkpoint_id": {
198                    "type": "string",
199                    "description": "Optional: specific checkpoint ID (default: latest)"
200                }
201            },
202            "required": ["session_id"]
203        })
204    }
205
206    fn category(&self) -> ToolCategory {
207        ToolCategory::Workflow
208    }
209
210    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
211        let session_id_str = args["session_id"]
212            .as_str()
213            .ok_or_else(|| anyhow::anyhow!("Missing session_id parameter"))?;
214
215        let session_id = SessionId::from(session_id_str);
216        let checkpoint_id_opt = args["checkpoint_id"]
217            .as_str()
218            .map(|s| CheckpointId(s.to_string()));
219
220        // 使用 Layer2 CheckpointWriter 加载
221        let result = self
222            .writer
223            .load(&session_id, checkpoint_id_opt.as_ref())
224            .await?;
225
226        match result {
227            Some(checkpoint) => {
228                Ok(format!(
229                    "Checkpoint restored: {}\nSession: {}\nTrigger: {}\nIteration: {}\nMessages: {}\nTokens used: {}",
230                    checkpoint.checkpoint_id,
231                    checkpoint.session_id,
232                    checkpoint.trigger,
233                    checkpoint.iteration,
234                    checkpoint.messages.len(),
235                    checkpoint.tokens_used
236                ))
237            }
238            None => Err(anyhow::anyhow!(
239                "No checkpoints found for session: {}",
240                session_id_str
241            )),
242        }
243    }
244}
245
246/// List Checkpoints Tool
247///
248/// 列出会话的所有检查点。
249/// 使用 Layer2 CheckpointWriter 实现。
250pub struct ListCheckpointsTool {
251    writer: Arc<CheckpointWriter>,
252}
253
254impl ListCheckpointsTool {
255    pub fn new() -> Self {
256        Self {
257            writer: Arc::new(CheckpointWriter::new(default_checkpoint_path())),
258        }
259    }
260
261    pub fn with_path(path: PathBuf) -> Self {
262        Self {
263            writer: Arc::new(CheckpointWriter::new(path)),
264        }
265    }
266}
267
268impl Default for ListCheckpointsTool {
269    fn default() -> Self {
270        Self::new()
271    }
272}
273
274#[async_trait]
275impl BuiltinTool for ListCheckpointsTool {
276    fn name(&self) -> &str {
277        "list_checkpoints"
278    }
279
280    fn description(&self) -> &str {
281        "List all checkpoints for a session."
282    }
283
284    fn parameters_schema(&self) -> serde_json::Value {
285        serde_json::json!({
286            "type": "object",
287            "properties": {
288                "session_id": {
289                    "type": "string",
290                    "description": "The session ID to list checkpoints for"
291                }
292            },
293            "required": ["session_id"]
294        })
295    }
296
297    fn category(&self) -> ToolCategory {
298        ToolCategory::Workflow
299    }
300
301    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
302        let session_id_str = args["session_id"]
303            .as_str()
304            .ok_or_else(|| anyhow::anyhow!("Missing session_id parameter"))?;
305
306        let session_id = SessionId::from(session_id_str);
307
308        // 使用 Layer2 CheckpointWriter 列出检查点
309        let checkpoints = self.writer.list(&session_id).await?;
310
311        if checkpoints.is_empty() {
312            return Ok(format!(
313                "No checkpoints found for session: {}",
314                session_id_str
315            ));
316        }
317
318        let mut result = format!("Checkpoints for session {}:\n", session_id_str);
319        for (i, meta) in checkpoints.iter().enumerate() {
320            result.push_str(&format!(
321                "  {}. {} (created: {})\n",
322                i + 1,
323                meta.checkpoint_id,
324                meta.created_at.format("%Y-%m-%d %H:%M:%S")
325            ));
326        }
327
328        Ok(result)
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tempfile::TempDir;

    /// Extracts the checkpoint id from `create_checkpoint` output, whose first line
    /// has the form `Checkpoint created: <id>`.
    fn created_id(output: &str) -> String {
        output
            .lines()
            .next()
            .and_then(|line| line.strip_prefix("Checkpoint created: "))
            .expect("create_checkpoint output should start with 'Checkpoint created: '")
            .to_string()
    }

    #[test]
    fn test_checkpoint_tool_category() {
        let tool = CreateCheckpointTool::new();
        assert_eq!(tool.category(), ToolCategory::Workflow);
    }

    #[test]
    fn test_restore_checkpoint_tool_category() {
        let tool = RestoreCheckpointTool::new();
        assert_eq!(tool.category(), ToolCategory::Workflow);
    }

    #[test]
    fn test_list_checkpoints_tool_category() {
        let tool = ListCheckpointsTool::new();
        assert_eq!(tool.category(), ToolCategory::Workflow);
    }

    #[test]
    fn test_tool_names() {
        let create = CreateCheckpointTool::new();
        let restore = RestoreCheckpointTool::new();
        let list = ListCheckpointsTool::new();
        assert_eq!(create.name(), "create_checkpoint");
        assert_eq!(restore.name(), "restore_checkpoint");
        assert_eq!(list.name(), "list_checkpoints");
    }

    #[tokio::test]
    async fn test_missing_session_id_is_rejected() {
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().to_path_buf();
        let create = CreateCheckpointTool::with_path(path.clone());
        let restore = RestoreCheckpointTool::with_path(path.clone());
        let list = ListCheckpointsTool::with_path(path);

        let results = vec![
            create.execute(json!({})).await,
            restore.execute(json!({})).await,
            list.execute(json!({})).await,
        ];

        for result in results {
            let err = result.unwrap_err().to_string();
            assert!(
                err.contains("Missing session_id parameter"),
                "unexpected error: {}",
                err
            );
        }
    }

    #[tokio::test]
    async fn test_create_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        let result = tool
            .execute(json!({
                "session_id": "test_session",
                "trigger": "manual",
                "messages": [{"role": "user", "content": "hello"}],
                "iteration": 1
            }))
            .await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.contains("Checkpoint created"));
        assert!(output.contains("test_session"));
        assert!(output.contains("Trigger: manual"));
        assert!(output.contains("Iteration: 1"));
    }

    #[tokio::test]
    async fn test_restore_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let create_tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        // Create a checkpoint with distinctive values so the round trip can be verified.
        create_tool
            .execute(json!({
                "session_id": "test_session",
                "trigger": "before_refactor",
                "messages": [{"role": "user", "content": "test"}],
                "iteration": 3,
                "tokens_used": 42
            }))
            .await
            .unwrap();

        // Restore the latest checkpoint and check that the saved fields come back.
        let restore_tool = RestoreCheckpointTool::with_path(temp_dir.path().to_path_buf());
        let output = restore_tool
            .execute(json!({"session_id": "test_session"}))
            .await
            .unwrap();

        assert!(output.contains("Checkpoint restored"));
        assert!(output.contains("test_session"));
        assert!(output.contains("Trigger: before_refactor"));
        assert!(output.contains("Iteration: 3"));
        assert!(output.contains("Messages: 1"));
        assert!(output.contains("Tokens used: 42"));
    }

    #[tokio::test]
    async fn test_restore_specific_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let create_tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        let created = create_tool
            .execute(json!({"session_id": "test_session"}))
            .await
            .unwrap();
        let id = created_id(&created);

        let restore_tool = RestoreCheckpointTool::with_path(temp_dir.path().to_path_buf());
        let output = restore_tool
            .execute(json!({"session_id": "test_session", "checkpoint_id": id.clone()}))
            .await
            .unwrap();

        assert!(output.contains(&format!("Checkpoint restored: {}", id)));
    }

    #[tokio::test]
    async fn test_restore_nonexistent_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let tool = RestoreCheckpointTool::with_path(temp_dir.path().to_path_buf());

        let result = tool
            .execute(json!({"session_id": "nonexistent_session"}))
            .await;

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No checkpoints found"));
    }

    #[tokio::test]
    async fn test_list_checkpoints() {
        let temp_dir = TempDir::new().unwrap();
        let create_tool = CreateCheckpointTool::with_path(temp_dir.path().to_path_buf());

        // Create one checkpoint and check that the listing includes its id.
        let created = create_tool
            .execute(json!({"session_id": "test_session"}))
            .await
            .unwrap();
        let id = created_id(&created);

        let list_tool = ListCheckpointsTool::with_path(temp_dir.path().to_path_buf());
        let result = list_tool
            .execute(json!({"session_id": "test_session"}))
            .await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.contains("Checkpoints for session test_session"));
        assert!(output.contains(&id), "listing should include {}: {}", id, output);
    }
}