tirea_contract/runtime/tool_call/
lifecycle.rs1use crate::runtime::phase::SuspendTicket;
2use crate::thread::ToolCall;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use tirea_state::State;
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10#[serde(rename_all = "snake_case")]
11pub enum ResumeDecisionAction {
12 Resume,
13 Cancel,
14}
15
16#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
21#[serde(rename_all = "snake_case")]
22pub enum ToolCallResumeMode {
23 #[default]
25 ReplayToolCall,
26 UseDecisionAsToolResult,
28 PassDecisionToTool,
30}
31
32#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
34pub struct PendingToolCall {
35 pub id: String,
36 pub name: String,
37 pub arguments: Value,
38}
39
40impl PendingToolCall {
41 pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: Value) -> Self {
42 Self {
43 id: id.into(),
44 name: name.into(),
45 arguments,
46 }
47 }
48}
49
50#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
51pub struct SuspendedCall {
52 #[serde(default)]
54 pub call_id: String,
55 #[serde(default)]
57 pub tool_name: String,
58 #[serde(default)]
60 pub arguments: Value,
61 #[serde(flatten)]
63 pub ticket: SuspendTicket,
64}
65
66impl SuspendedCall {
67 pub fn new(call: &ToolCall, ticket: SuspendTicket) -> Self {
69 Self {
70 call_id: call.id.clone(),
71 tool_name: call.name.clone(),
72 arguments: call.arguments.clone(),
73 ticket,
74 }
75 }
76
77 pub fn into_state_action(self) -> crate::runtime::state::AnyStateAction {
82 let call_id = self.call_id.clone();
83 crate::runtime::state::AnyStateAction::new_for_call::<SuspendedCallState>(
84 SuspendedCallAction::Set(self),
85 call_id,
86 )
87 }
88}
89
90#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
96#[tirea(
97 path = "suspended_call",
98 action = "SuspendedCallAction",
99 scope = "tool_call"
100)]
101pub struct SuspendedCallState {
102 #[serde(flatten)]
104 pub call: SuspendedCall,
105}
106
107#[derive(Serialize, Deserialize)]
109pub enum SuspendedCallAction {
110 Set(SuspendedCall),
112}
113
114impl SuspendedCallState {
115 fn reduce(&mut self, action: SuspendedCallAction) {
116 match action {
117 SuspendedCallAction::Set(call) => {
118 self.call = call;
119 }
120 }
121 }
122}
123
124#[derive(Serialize, Deserialize)]
126pub enum ToolCallStateAction {
127 Set(ToolCallState),
129}
130
131#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
133#[serde(rename_all = "snake_case")]
134pub enum ToolCallStatus {
135 #[default]
137 New,
138 Running,
140 Suspended,
142 Resuming,
144 Succeeded,
146 Failed,
148 Cancelled,
150}
151
152impl ToolCallStatus {
153 pub const ASCII_STATE_MACHINE: &str = r#"new ------------> running
155 | |
156 | v
157 +------------> suspended -----> resuming
158 | |
159 +---------------+
160
161running/resuming ---> succeeded
162running/resuming ---> failed
163running/suspended/resuming ---> cancelled"#;
164
165 pub fn is_terminal(self) -> bool {
167 matches!(
168 self,
169 ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled
170 )
171 }
172
173 pub fn can_transition_to(self, next: Self) -> bool {
175 if self == next {
176 return true;
177 }
178
179 match self {
180 ToolCallStatus::New => true,
181 ToolCallStatus::Running => matches!(
182 next,
183 ToolCallStatus::Suspended
184 | ToolCallStatus::Succeeded
185 | ToolCallStatus::Failed
186 | ToolCallStatus::Cancelled
187 ),
188 ToolCallStatus::Suspended => {
189 matches!(next, ToolCallStatus::Resuming | ToolCallStatus::Cancelled)
190 }
191 ToolCallStatus::Resuming => matches!(
192 next,
193 ToolCallStatus::Running
194 | ToolCallStatus::Suspended
195 | ToolCallStatus::Succeeded
196 | ToolCallStatus::Failed
197 | ToolCallStatus::Cancelled
198 ),
199 ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled => false,
200 }
201 }
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
206pub struct ToolCallResume {
207 #[serde(default)]
209 pub decision_id: String,
210 pub action: ResumeDecisionAction,
212 #[serde(default, skip_serializing_if = "Value::is_null")]
214 pub result: Value,
215 #[serde(default, skip_serializing_if = "Option::is_none")]
217 pub reason: Option<String>,
218 #[serde(default)]
220 pub updated_at: u64,
221}
222
223#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, State)]
227#[tirea(
228 path = "tool_call_state",
229 action = "ToolCallStateAction",
230 scope = "tool_call"
231)]
232pub struct ToolCallState {
233 #[serde(default, skip_serializing_if = "String::is_empty")]
235 pub call_id: String,
236 #[serde(default, skip_serializing_if = "String::is_empty")]
238 pub tool_name: String,
239 #[serde(default, skip_serializing_if = "Value::is_null")]
241 pub arguments: Value,
242 #[serde(default)]
244 pub status: ToolCallStatus,
245 #[serde(default, skip_serializing_if = "Option::is_none")]
247 pub resume_token: Option<String>,
248 #[serde(default, skip_serializing_if = "Option::is_none")]
250 pub resume: Option<ToolCallResume>,
251 #[serde(default, skip_serializing_if = "Value::is_null")]
253 pub scratch: Value,
254 #[serde(default)]
256 pub updated_at: u64,
257}
258
259impl ToolCallState {
260 pub fn into_state_action(self) -> crate::runtime::state::AnyStateAction {
265 let call_id = self.call_id.clone();
266 crate::runtime::state::AnyStateAction::new_for_call::<ToolCallState>(
267 ToolCallStateAction::Set(self),
268 call_id,
269 )
270 }
271}
272
273impl ToolCallState {
274 fn reduce(&mut self, action: ToolCallStateAction) {
275 match action {
276 ToolCallStateAction::Set(s) => *self = s,
277 }
278 }
279}
280
281pub fn suspended_calls_from_state(state: &Value) -> HashMap<String, SuspendedCall> {
283 let Some(Value::Object(scopes)) = state.get("__tool_call_scope") else {
284 return HashMap::new();
285 };
286 scopes
287 .iter()
288 .filter_map(|(call_id, scope_val)| {
289 scope_val
290 .get("suspended_call")
291 .and_then(|v| SuspendedCallState::from_value(v).ok())
292 .map(|s| (call_id.clone(), s.call))
293 })
294 .collect()
295}
296
297pub fn tool_call_states_from_state(state: &Value) -> HashMap<String, ToolCallState> {
301 let Some(Value::Object(scopes)) = state.get("__tool_call_scope") else {
302 return HashMap::new();
303 };
304 scopes
305 .iter()
306 .filter_map(|(call_id, scope_val)| {
307 scope_val
308 .get("tool_call_state")
309 .and_then(|v| ToolCallState::from_value(v).ok())
310 .map(|s| (call_id.clone(), s))
311 })
312 .collect()
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Every status, for table-driven transition checks.
    const ALL_STATUSES: [ToolCallStatus; 7] = [
        ToolCallStatus::New,
        ToolCallStatus::Running,
        ToolCallStatus::Suspended,
        ToolCallStatus::Resuming,
        ToolCallStatus::Succeeded,
        ToolCallStatus::Failed,
        ToolCallStatus::Cancelled,
    ];

    #[test]
    fn suspended_call_state_default() {
        let suspended = SuspendedCallState::default();
        assert_eq!(suspended.call.call_id, "");
        assert_eq!(suspended.call.tool_name, "");
    }

    #[test]
    fn tool_call_status_transitions_match_lifecycle() {
        assert!(ToolCallStatus::New.can_transition_to(ToolCallStatus::Running));
        assert!(ToolCallStatus::Running.can_transition_to(ToolCallStatus::Suspended));
        assert!(ToolCallStatus::Suspended.can_transition_to(ToolCallStatus::Resuming));
        assert!(ToolCallStatus::Resuming.can_transition_to(ToolCallStatus::Running));
        assert!(ToolCallStatus::Resuming.can_transition_to(ToolCallStatus::Failed));
        assert!(ToolCallStatus::Running.can_transition_to(ToolCallStatus::Succeeded));
        assert!(ToolCallStatus::Running.can_transition_to(ToolCallStatus::Failed));
        assert!(ToolCallStatus::Suspended.can_transition_to(ToolCallStatus::Cancelled));
    }

    #[test]
    fn tool_call_status_rejects_terminal_reopen_transitions() {
        assert!(!ToolCallStatus::Succeeded.can_transition_to(ToolCallStatus::Running));
        assert!(!ToolCallStatus::Failed.can_transition_to(ToolCallStatus::Resuming));
        assert!(!ToolCallStatus::Cancelled.can_transition_to(ToolCallStatus::Suspended));
    }

    #[test]
    fn tool_call_status_is_terminal_only_for_final_states() {
        for status in [
            ToolCallStatus::Succeeded,
            ToolCallStatus::Failed,
            ToolCallStatus::Cancelled,
        ] {
            assert!(status.is_terminal(), "{status:?} should be terminal");
        }
        for status in [
            ToolCallStatus::New,
            ToolCallStatus::Running,
            ToolCallStatus::Suspended,
            ToolCallStatus::Resuming,
        ] {
            assert!(!status.is_terminal(), "{status:?} should not be terminal");
        }
    }

    #[test]
    fn tool_call_status_allows_self_transitions() {
        for status in ALL_STATUSES {
            assert!(status.can_transition_to(status), "{status:?} -> itself");
        }
    }

    #[test]
    fn terminal_statuses_cannot_leave() {
        for from in ALL_STATUSES.iter().copied().filter(|s| s.is_terminal()) {
            for to in ALL_STATUSES.iter().copied().filter(|&s| s != from) {
                assert!(!from.can_transition_to(to), "{from:?} -> {to:?}");
            }
        }
    }

    #[test]
    fn tool_call_status_resume_loop_edges() {
        assert!(ToolCallStatus::Resuming.can_transition_to(ToolCallStatus::Suspended));
        assert!(!ToolCallStatus::Suspended.can_transition_to(ToolCallStatus::Running));
        assert!(!ToolCallStatus::Running.can_transition_to(ToolCallStatus::Resuming));
    }

    #[test]
    fn scope_readers_tolerate_missing_or_malformed_scope() {
        for state in [
            serde_json::json!({}),
            serde_json::json!(null),
            serde_json::json!({ "__tool_call_scope": [] }),
            serde_json::json!({ "__tool_call_scope": { "call_1": {} } }),
        ] {
            assert!(suspended_calls_from_state(&state).is_empty(), "{state}");
            assert!(tool_call_states_from_state(&state).is_empty(), "{state}");
        }
    }

    #[test]
    fn suspended_call_serde_flatten_roundtrip() {
        use crate::runtime::tool_call::Suspension;

        let call = SuspendedCall {
            call_id: "call_1".into(),
            tool_name: "my_tool".into(),
            arguments: serde_json::json!({"key": "val"}),
            ticket: SuspendTicket::new(
                Suspension::new("susp_1", "confirm"),
                PendingToolCall::new("pending_1", "my_tool", serde_json::json!({"key": "val"})),
                ToolCallResumeMode::UseDecisionAsToolResult,
            ),
        };

        let json = serde_json::to_value(&call).unwrap();

        // `ticket` is flattened, so its fields must appear at the top level.
        assert!(json.get("ticket").is_none(), "ticket should be flattened");
        assert!(
            json.get("suspension").is_some(),
            "suspension should be at top level"
        );
        assert!(
            json.get("pending").is_some(),
            "pending should be at top level"
        );
        assert!(
            json.get("resume_mode").is_some(),
            "resume_mode should be at top level"
        );
        assert_eq!(json["call_id"], "call_1");
        assert_eq!(json["suspension"]["id"], "susp_1");
        assert_eq!(json["pending"]["id"], "pending_1");

        let deserialized: SuspendedCall = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized, call);
    }

    #[test]
    fn tool_call_ascii_state_machine_contains_all_states() {
        let diagram = ToolCallStatus::ASCII_STATE_MACHINE;
        assert!(diagram.contains("new"));
        assert!(diagram.contains("running"));
        assert!(diagram.contains("suspended"));
        assert!(diagram.contains("resuming"));
        assert!(diagram.contains("succeeded"));
        assert!(diagram.contains("failed"));
        assert!(diagram.contains("cancelled"));
    }
}