// sh_layer2/types.rs

//! # Layer 2 Core Types
//!
//! Core types used by Layer 2 and shared by all of its modules.

use std::fmt;

use serde::{Deserialize, Serialize};
use sh_layer1::generate_short_id;

9/// 会话 ID
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct SessionId(pub String);
12
13impl SessionId {
14    pub fn new() -> Self {
15        Self(generate_short_id())
16    }
17}
18
19impl Default for SessionId {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl fmt::Display for SessionId {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        write!(f, "{}", self.0)
28    }
29}
30
31impl From<&str> for SessionId {
32    fn from(s: &str) -> Self {
33        Self(s.to_string())
34    }
35}
36
37/// Agent ID
38#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
39pub struct AgentId(pub String);
40
41impl AgentId {
42    pub fn new() -> Self {
43        Self("default".to_string())
44    }
45}
46
47impl Default for AgentId {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl fmt::Display for AgentId {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "{}", self.0)
56    }
57}
58
59/// 任务 ID
60#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
61pub struct TaskId(pub String);
62
63impl TaskId {
64    pub fn new() -> Self {
65        Self(generate_short_id())
66    }
67}
68
69impl Default for TaskId {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl fmt::Display for TaskId {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        write!(f, "{}", self.0)
78    }
79}
80
81/// 检查点 ID
82#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
83pub struct CheckpointId(pub String);
84
85impl CheckpointId {
86    pub fn new() -> Self {
87        Self(generate_short_id())
88    }
89}
90
91impl Default for CheckpointId {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl fmt::Display for CheckpointId {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        write!(f, "{}", self.0)
100    }
101}
102
103/// Agent 执行状态机
104#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
105#[serde(rename_all = "snake_case")]
106pub enum AgentState {
107    #[default]
108    Idle,
109    Running,
110    ToolCalling,
111    WaitingTool,
112    Stopped,
113    Error,
114    Completed,
115}
116
117impl fmt::Display for AgentState {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        match self {
120            Self::Idle => write!(f, "idle"),
121            Self::Running => write!(f, "running"),
122            Self::ToolCalling => write!(f, "tool_calling"),
123            Self::WaitingTool => write!(f, "waiting_tool"),
124            Self::Stopped => write!(f, "stopped"),
125            Self::Error => write!(f, "error"),
126            Self::Completed => write!(f, "completed"),
127        }
128    }
129}
130
131impl std::str::FromStr for AgentState {
132    type Err = String;
133
134    fn from_str(s: &str) -> Result<Self, Self::Err> {
135        match s {
136            "idle" => Ok(Self::Idle),
137            "running" => Ok(Self::Running),
138            "tool_calling" => Ok(Self::ToolCalling),
139            "waiting_tool" => Ok(Self::WaitingTool),
140            "stopped" => Ok(Self::Stopped),
141            "error" => Ok(Self::Error),
142            "completed" => Ok(Self::Completed),
143            _ => Err(format!("Unknown agent state: {}", s)),
144        }
145    }
146}
147
148/// 消息角色
149#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
150#[serde(rename_all = "lowercase")]
151pub enum MessageRole {
152    System,
153    User,
154    Assistant,
155    Tool,
156}
157
158/// OpenAI 格式消息
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct Message {
161    pub role: MessageRole,
162    pub content: String,
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub name: Option<String>,
165    #[serde(skip_serializing_if = "Option::is_none")]
166    pub tool_call_id: Option<String>,
167}
168
169impl Message {
170    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
171        Self {
172            role,
173            content: content.into(),
174            name: None,
175            tool_call_id: None,
176        }
177    }
178
179    pub fn user(content: impl Into<String>) -> Self {
180        Self::new(MessageRole::User, content)
181    }
182
183    pub fn assistant(content: impl Into<String>) -> Self {
184        Self::new(MessageRole::Assistant, content)
185    }
186
187    pub fn system(content: impl Into<String>) -> Self {
188        Self::new(MessageRole::System, content)
189    }
190}
191
192/// 工具调用
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct ToolCall {
195    pub id: String,
196    pub name: String,
197    pub arguments: String,
198}
199
200/// 工具结果
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct ToolResult {
203    pub tool_call_id: String,
204    pub name: String,
205    pub content: String,
206    pub is_error: bool,
207}
208
/// Hook event kinds.
///
/// `Copy + Eq + Hash`, so values can be passed by value and used as
/// `HashMap`/`HashSet` keys. Unlike most types here, it is not serde-enabled.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum HookEvent {
    // Agent lifecycle.
    BeforeAgentStart,
    AfterAgentStop,
    // Tool calls.
    BeforeToolCall,
    AfterToolCall,
    // Checkpoints.
    BeforeCheckpoint,
    AfterCheckpoint,
    // Cross-cutting.
    OnError,
    OnStateChange,
}

222/// 工作流节点
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct WorkflowNode {
225    pub id: String,
226    pub name: String,
227    #[serde(default)]
228    pub dependencies: Vec<String>,
229}
230
231/// 检查点元数据
232#[derive(Debug, Clone, Serialize, Deserialize)]
233pub struct CheckpointMeta {
234    pub checkpoint_id: CheckpointId,
235    pub session_id: SessionId,
236    pub created_at: chrono::DateTime<chrono::Utc>,
237    pub trigger: String,
238    pub iteration: i32,
239    #[serde(rename = "_checksum")]
240    pub checksum: String,
241}
242
243/// 会话元数据(用于列表显示)
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct SessionMeta {
246    pub session_id: SessionId,
247    pub agent_id: AgentId,
248    pub state: AgentState,
249    pub created_at: chrono::DateTime<chrono::Utc>,
250    pub last_updated: chrono::DateTime<chrono::Utc>,
251    pub message_count: usize,
252    pub checkpoint_count: i32,
253}
254
255/// 统一的 Layer 2 Result 类型
256pub type Layer2Result<T> = anyhow::Result<T>;
257
/// Error type shared by all Layer 2 modules.
///
/// `Io` and `Serialization` are produced automatically by `?` from
/// `std::io::Error` and `serde_json::Error` (via `#[from]`). Every variant
/// also converts into `anyhow::Error`, the error side of [`Layer2Result`].
#[derive(Debug, thiserror::Error)]
pub enum Layer2Error {
    /// A session lookup failed.
    #[error("Session not found: {0}")]
    SessionNotFound(SessionId),
    /// A checkpoint lookup failed.
    #[error("Checkpoint not found: {0}")]
    CheckpointNotFound(CheckpointId),
    /// A tool lookup failed.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),
    /// A task lookup failed.
    #[error("Task not found: {0}")]
    TaskNotFound(TaskId),
    /// A lock could not be acquired in time.
    #[error("Lock acquisition timeout")]
    LockTimeout,
    /// The requested `AgentState` change is not allowed.
    #[error("Invalid state transition: from {from} to {to}")]
    InvalidStateTransition {
        /// State the agent was in.
        from: AgentState,
        /// State that was requested.
        to: AgentState,
    },
    /// Checkpoint data failed validation.
    #[error("Checkpoint corrupted: {0}")]
    CheckpointCorrupted(String),
    /// Wrapped I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Wrapped JSON (de)serialization failure.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// The iteration limit (carried value) was hit.
    #[error("Max iterations reached: {0}")]
    MaxIterations(i32),
    /// Free-form agent failure.
    #[error("Agent error: {0}")]
    AgentError(String),
    /// No LLM client has been configured.
    #[error("LLM client not configured")]
    LlmNotConfigured,
    /// The session limit (carried value) was hit.
    #[error("Max sessions reached: {0}")]
    MaxSessionsReached(usize),
    /// The operation was not permitted.
    #[error("Permission denied: {0}")]
    PermissionDenied(String),
}

// Note: there is no need to implement `From<Layer2Error> for anyhow::Error`
// by hand. `Layer2Error` implements `std::error::Error + Send + Sync + 'static`,
// so anyhow's blanket `From` impl already provides the conversion.

#[cfg(test)]
mod tests {
    use super::*;

    /// Every `AgentState` variant, for exhaustive round-trip checks.
    const ALL_STATES: [AgentState; 7] = [
        AgentState::Idle,
        AgentState::Running,
        AgentState::ToolCalling,
        AgentState::WaitingTool,
        AgentState::Stopped,
        AgentState::Error,
        AgentState::Completed,
    ];

    #[test]
    fn test_session_id_creation() {
        let id1 = SessionId::new();
        let id2 = SessionId::new();
        assert_ne!(id1, id2);
        // Pins the id length produced by `sh_layer1::generate_short_id`.
        assert_eq!(id1.0.len(), 8);
    }

    #[test]
    fn test_session_id_from_str_and_display() {
        let id = SessionId::from("abc123");
        assert_eq!(id, SessionId("abc123".to_string()));
        assert_eq!(id.to_string(), "abc123");
    }

    #[test]
    fn test_agent_id_default() {
        assert_eq!(AgentId::default(), AgentId::new());
        assert_eq!(AgentId::new().to_string(), "default");
    }

    #[test]
    fn test_agent_state_display() {
        assert_eq!(format!("{}", AgentState::Running), "running");
        assert_eq!(format!("{}", AgentState::ToolCalling), "tool_calling");
    }

    #[test]
    fn test_agent_state_from_str() {
        let state: AgentState = "running".parse().unwrap();
        assert_eq!(state, AgentState::Running);
    }

    #[test]
    fn test_agent_state_default_is_idle() {
        assert_eq!(AgentState::default(), AgentState::Idle);
    }

    #[test]
    fn test_agent_state_round_trips_display_from_str_and_serde() {
        for &state in &ALL_STATES {
            let label = state.to_string();
            assert_eq!(label.parse::<AgentState>(), Ok(state));
            // serde's `snake_case` names must agree with Display/FromStr.
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, format!("\"{}\"", label));
            assert_eq!(serde_json::from_str::<AgentState>(&json).unwrap(), state);
        }
    }

    #[test]
    fn test_agent_state_from_str_rejects_unknown() {
        assert_eq!(
            "Running".parse::<AgentState>(),
            Err("Unknown agent state: Running".to_string())
        );
        assert!("".parse::<AgentState>().is_err());
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "Hello");
        assert!(msg.name.is_none());
        assert!(msg.tool_call_id.is_none());
        assert_eq!(Message::assistant("a").role, MessageRole::Assistant);
        assert_eq!(Message::system("s").role, MessageRole::System);
    }

    #[test]
    fn test_message_serialization_omits_none_fields() {
        let value = serde_json::to_value(Message::system("be brief")).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "role": "system", "content": "be brief" })
        );
    }

    #[test]
    fn test_checkpoint_meta_serializes_checksum_as_underscore_key() {
        let meta = CheckpointMeta {
            checkpoint_id: CheckpointId("cp-1".to_string()),
            session_id: SessionId::from("s-1"),
            created_at: chrono::Utc::now(),
            trigger: "manual".to_string(),
            iteration: 3,
            checksum: "deadbeef".to_string(),
        };
        let value = serde_json::to_value(&meta).unwrap();
        assert_eq!(value["_checksum"], "deadbeef");
        assert!(value.get("checksum").is_none());
    }
}