tirea_contract/runtime/tool_call/
lifecycle.rs1use crate::runtime::plugin::phase::SuspendTicket;
2use crate::runtime::state_paths::{SUSPENDED_TOOL_CALLS_STATE_PATH, TOOL_CALL_STATES_STATE_PATH};
3use crate::thread::ToolCall;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use tirea_state::State;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11#[serde(rename_all = "snake_case")]
12pub enum ResumeDecisionAction {
13 Resume,
14 Cancel,
15}
16
17#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
22#[serde(rename_all = "snake_case")]
23pub enum ToolCallResumeMode {
24 ReplayToolCall,
26 UseDecisionAsToolResult,
28 PassDecisionToTool,
30}
31
32impl Default for ToolCallResumeMode {
33 fn default() -> Self {
34 Self::ReplayToolCall
35 }
36}
37
38#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
40pub struct PendingToolCall {
41 pub id: String,
42 pub name: String,
43 pub arguments: Value,
44}
45
46impl PendingToolCall {
47 pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: Value) -> Self {
48 Self {
49 id: id.into(),
50 name: name.into(),
51 arguments,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57pub struct SuspendedCall {
58 pub call_id: String,
60 pub tool_name: String,
62 pub arguments: Value,
64 #[serde(flatten)]
66 pub ticket: SuspendTicket,
67}
68
69impl SuspendedCall {
70 pub fn new(call: &ToolCall, ticket: SuspendTicket) -> Self {
72 Self {
73 call_id: call.id.clone(),
74 tool_name: call.name.clone(),
75 arguments: call.arguments.clone(),
76 ticket,
77 }
78 }
79}
80
/// Registry of suspended tool calls, stored as a map of [`SuspendedCall`] records.
///
/// Derives `State` with `path = "__suspended_tool_calls"`, while
/// [`suspended_calls_from_state`] reads the same data through
/// `SUSPENDED_TOOL_CALLS_STATE_PATH`.
/// NOTE(review): the attribute literal and that constant are maintained separately —
/// confirm they stay equal.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
#[tirea(path = "__suspended_tool_calls")]
pub struct SuspendedToolCallsState {
    /// Suspended calls keyed by a `String`, presumably the call id (TODO confirm).
    /// A missing `calls` field deserializes to an empty map.
    #[serde(default)]
    #[tirea(default = "HashMap::new()")]
    pub calls: HashMap<String, SuspendedCall>,
}
90
91#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
93#[serde(rename_all = "snake_case")]
94pub enum ToolCallStatus {
95 #[default]
97 New,
98 Running,
100 Suspended,
102 Resuming,
104 Succeeded,
106 Failed,
108 Cancelled,
110}
111
112impl ToolCallStatus {
113 pub const ASCII_STATE_MACHINE: &str = r#"new ------------> running
115 | |
116 | v
117 +------------> suspended -----> resuming
118 | |
119 +---------------+
120
121running/resuming ---> succeeded
122running/resuming ---> failed
123running/suspended/resuming ---> cancelled"#;
124
125 pub fn is_terminal(self) -> bool {
127 matches!(
128 self,
129 ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled
130 )
131 }
132
133 pub fn can_transition_to(self, next: Self) -> bool {
135 if self == next {
136 return true;
137 }
138
139 match self {
140 ToolCallStatus::New => true,
141 ToolCallStatus::Running => matches!(
142 next,
143 ToolCallStatus::Suspended
144 | ToolCallStatus::Succeeded
145 | ToolCallStatus::Failed
146 | ToolCallStatus::Cancelled
147 ),
148 ToolCallStatus::Suspended => {
149 matches!(next, ToolCallStatus::Resuming | ToolCallStatus::Cancelled)
150 }
151 ToolCallStatus::Resuming => matches!(
152 next,
153 ToolCallStatus::Running
154 | ToolCallStatus::Suspended
155 | ToolCallStatus::Succeeded
156 | ToolCallStatus::Failed
157 | ToolCallStatus::Cancelled
158 ),
159 ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled => false,
160 }
161 }
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
166pub struct ToolCallResume {
167 #[serde(default)]
169 pub decision_id: String,
170 pub action: ResumeDecisionAction,
172 #[serde(default, skip_serializing_if = "Value::is_null")]
174 pub result: Value,
175 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub reason: Option<String>,
178 #[serde(default)]
180 pub updated_at: u64,
181}
182
/// Per-call lifecycle record for a single tool call.
///
/// Every field is `#[serde(default)]`, so an empty JSON object deserializes to
/// `ToolCallState::default()`. On serialization, empty strings, `null` values and
/// `None` options are omitted; `status` and `updated_at` are always written.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, State)]
pub struct ToolCallState {
    /// Identifier of the tracked call; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub call_id: String,
    /// Name of the tool; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub tool_name: String,
    /// Call arguments; omitted from JSON when `null`.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub arguments: Value,
    /// Current lifecycle status; defaults to `ToolCallStatus::New`.
    #[serde(default)]
    pub status: ToolCallStatus,
    /// Token associated with resuming this call.
    /// NOTE(review): its meaning is not visible in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resume_token: Option<String>,
    /// Resume decision recorded for this call, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resume: Option<ToolCallResume>,
    /// Free-form per-call data; omitted from JSON when `null`.
    /// NOTE(review): owner and schema not visible here.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub scratch: Value,
    /// Last-update timestamp. NOTE(review): units not established here — confirm.
    #[serde(default)]
    pub updated_at: u64,
}
211
/// Registry of per-call [`ToolCallState`] records.
///
/// Derives `State` with `path = "__tool_call_states"`, while
/// [`tool_call_states_from_state`] reads the same data through
/// `TOOL_CALL_STATES_STATE_PATH`.
/// NOTE(review): the attribute literal and that constant are maintained separately —
/// confirm they stay equal.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
#[tirea(path = "__tool_call_states")]
pub struct ToolCallStatesMap {
    /// Records keyed by a `String`, presumably the call id (TODO confirm).
    /// A missing `calls` field deserializes to an empty map.
    #[serde(default)]
    #[tirea(default = "HashMap::new()")]
    pub calls: HashMap<String, ToolCallState>,
}
221
222pub fn suspended_calls_from_state(state: &Value) -> HashMap<String, SuspendedCall> {
224 state
225 .get(SUSPENDED_TOOL_CALLS_STATE_PATH)
226 .and_then(|value| value.get("calls"))
227 .cloned()
228 .and_then(|value| serde_json::from_value(value).ok())
229 .unwrap_or_default()
230}
231
232pub fn tool_call_states_from_state(state: &Value) -> HashMap<String, ToolCallState> {
234 state
235 .get(TOOL_CALL_STATES_STATE_PATH)
236 .and_then(|value| value.get("calls"))
237 .cloned()
238 .and_then(|value| serde_json::from_value(value).ok())
239 .unwrap_or_default()
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the suspended-call fixture shared by the serde tests.
    fn sample_suspended_call() -> SuspendedCall {
        use crate::runtime::tool_call::Suspension;

        SuspendedCall {
            call_id: "call_1".into(),
            tool_name: "my_tool".into(),
            arguments: serde_json::json!({"key": "val"}),
            ticket: SuspendTicket::new(
                Suspension::new("susp_1", "confirm"),
                PendingToolCall::new("pending_1", "my_tool", serde_json::json!({"key": "val"})),
                ToolCallResumeMode::UseDecisionAsToolResult,
            ),
        }
    }

    #[test]
    fn suspended_tool_calls_state_defaults_to_empty() {
        let suspended = SuspendedToolCallsState::default();
        assert!(suspended.calls.is_empty());
    }

    #[test]
    fn tool_call_status_transitions_match_lifecycle() {
        assert!(ToolCallStatus::New.can_transition_to(ToolCallStatus::Running));
        assert!(ToolCallStatus::Running.can_transition_to(ToolCallStatus::Suspended));
        assert!(ToolCallStatus::Suspended.can_transition_to(ToolCallStatus::Resuming));
        assert!(ToolCallStatus::Resuming.can_transition_to(ToolCallStatus::Running));
        assert!(ToolCallStatus::Resuming.can_transition_to(ToolCallStatus::Failed));
        assert!(ToolCallStatus::Running.can_transition_to(ToolCallStatus::Succeeded));
        assert!(ToolCallStatus::Running.can_transition_to(ToolCallStatus::Failed));
        assert!(ToolCallStatus::Suspended.can_transition_to(ToolCallStatus::Cancelled));
    }

    #[test]
    fn tool_call_status_rejects_terminal_reopen_transitions() {
        assert!(!ToolCallStatus::Succeeded.can_transition_to(ToolCallStatus::Running));
        assert!(!ToolCallStatus::Failed.can_transition_to(ToolCallStatus::Resuming));
        assert!(!ToolCallStatus::Cancelled.can_transition_to(ToolCallStatus::Suspended));
    }

    #[test]
    fn tool_call_status_is_terminal_only_for_final_states() {
        for status in [
            ToolCallStatus::Succeeded,
            ToolCallStatus::Failed,
            ToolCallStatus::Cancelled,
        ] {
            assert!(status.is_terminal(), "{:?} should be terminal", status);
        }
        for status in [
            ToolCallStatus::New,
            ToolCallStatus::Running,
            ToolCallStatus::Suspended,
            ToolCallStatus::Resuming,
        ] {
            assert!(!status.is_terminal(), "{:?} should not be terminal", status);
        }
    }

    #[test]
    fn tool_call_status_allows_self_transitions() {
        for status in [
            ToolCallStatus::New,
            ToolCallStatus::Running,
            ToolCallStatus::Suspended,
            ToolCallStatus::Resuming,
            ToolCallStatus::Succeeded,
            ToolCallStatus::Failed,
            ToolCallStatus::Cancelled,
        ] {
            assert!(status.can_transition_to(status), "{:?} -> itself", status);
        }
    }

    #[test]
    fn tool_call_resume_mode_defaults_to_replay() {
        assert_eq!(
            ToolCallResumeMode::default(),
            ToolCallResumeMode::ReplayToolCall
        );
    }

    #[test]
    fn suspended_call_serde_flatten_roundtrip() {
        let call = sample_suspended_call();

        let json = serde_json::to_value(&call).unwrap();

        assert!(json.get("ticket").is_none(), "ticket should be flattened");
        assert!(json.get("suspension").is_some(), "suspension should be at top level");
        assert!(json.get("pending").is_some(), "pending should be at top level");
        assert!(json.get("resume_mode").is_some(), "resume_mode should be at top level");
        assert_eq!(json["call_id"], "call_1");
        assert_eq!(json["suspension"]["id"], "susp_1");
        assert_eq!(json["pending"]["id"], "pending_1");

        let deserialized: SuspendedCall = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized, call);
    }

    #[test]
    fn state_readers_return_empty_when_path_missing() {
        let state = serde_json::json!({});
        assert!(suspended_calls_from_state(&state).is_empty());
        assert!(tool_call_states_from_state(&state).is_empty());
        assert!(suspended_calls_from_state(&Value::Null).is_empty());
        assert!(tool_call_states_from_state(&Value::Null).is_empty());
    }

    #[test]
    fn tool_call_states_from_state_reads_calls_map() {
        let mut state = serde_json::json!({});
        state[TOOL_CALL_STATES_STATE_PATH] = serde_json::json!({
            "calls": {
                "call_1": {"call_id": "call_1", "tool_name": "search", "status": "running"}
            }
        });

        let states = tool_call_states_from_state(&state);
        assert_eq!(states.len(), 1);
        let entry = &states["call_1"];
        assert_eq!(entry.call_id, "call_1");
        assert_eq!(entry.tool_name, "search");
        assert_eq!(entry.status, ToolCallStatus::Running);
        assert!(entry.arguments.is_null());
    }

    #[test]
    fn suspended_calls_from_state_reads_serialized_registry() {
        let call = sample_suspended_call();
        let mut registry = SuspendedToolCallsState::default();
        registry.calls.insert(call.call_id.clone(), call.clone());

        let mut state = serde_json::json!({});
        state[SUSPENDED_TOOL_CALLS_STATE_PATH] = serde_json::to_value(&registry).unwrap();

        let calls = suspended_calls_from_state(&state);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls.get("call_1"), Some(&call));
    }

    #[test]
    fn tool_call_ascii_state_machine_contains_all_states() {
        let diagram = ToolCallStatus::ASCII_STATE_MACHINE;
        assert!(diagram.contains("new"));
        assert!(diagram.contains("running"));
        assert!(diagram.contains("suspended"));
        assert!(diagram.contains("resuming"));
        assert!(diagram.contains("succeeded"));
        assert!(diagram.contains("failed"));
        assert!(diagram.contains("cancelled"));
    }
}