// tirea_contract/runtime/tool_call/lifecycle.rs

1use crate::runtime::phase::SuspendTicket;
2use crate::thread::ToolCall;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use tirea_state::State;
7
/// Action to apply for a suspended tool call.
///
/// Variants are serialized in `snake_case` (`"resume"` / `"cancel"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResumeDecisionAction {
    /// Resume the suspended tool call.
    Resume,
    /// Cancel the suspended tool call.
    Cancel,
}
15
/// Strategy describing how a suspended tool call is resumed.
///
/// Variants are serialized in `snake_case`. The default is
/// [`ToolCallResumeMode::ReplayToolCall`].
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallResumeMode {
    /// Resume by replaying the original backend tool call.
    #[default]
    ReplayToolCall,
    /// Resume by turning external decision payload into tool result directly.
    UseDecisionAsToolResult,
    /// Resume by passing external payload back into tool-call arguments.
    PassDecisionToTool,
}
31
/// External pending tool-call projection emitted to event streams.
///
/// `Default` yields empty `id`/`name` strings and a `Value::Null` argument payload.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct PendingToolCall {
    /// Identifier of the projected call.
    pub id: String,
    /// Tool name of the projected call.
    pub name: String,
    /// JSON arguments of the projected call.
    pub arguments: Value,
}
39
40impl PendingToolCall {
41    pub fn new(id: impl Into<String>, name: impl Into<String>, arguments: Value) -> Self {
42        Self {
43            id: id.into(),
44            name: name.into(),
45            arguments,
46        }
47    }
48}
49
/// A tool call that has been suspended, awaiting external resolution.
///
/// The core loop stores stable call identity, pending interaction payload,
/// and explicit resume behavior for deterministic replay.
///
/// The `ticket` fields are flattened into the same JSON object as the call
/// identity fields, so no `ticket` key appears in the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct SuspendedCall {
    /// Original backend call identity.
    #[serde(default)]
    pub call_id: String,
    /// Original backend tool name.
    #[serde(default)]
    pub tool_name: String,
    /// Original backend tool arguments.
    #[serde(default)]
    pub arguments: Value,
    /// Suspension ticket carrying interaction payload, pending projection, and resume strategy.
    #[serde(flatten)]
    pub ticket: SuspendTicket,
}
65
66impl SuspendedCall {
67    /// Create a suspended call from a tool call and a suspend ticket.
68    pub fn new(call: &ToolCall, ticket: SuspendTicket) -> Self {
69        Self {
70            call_id: call.id.clone(),
71            tool_name: call.name.clone(),
72            arguments: call.arguments.clone(),
73            ticket,
74        }
75    }
76
77    /// Convert into a type-erased state action targeting this call's scope.
78    ///
79    /// Equivalent to `AnyStateAction::new_for_call::<SuspendedCallState>(Set(self), call_id)`
80    /// but hides the internal `SuspendedCallState` / `SuspendedCallAction` types.
81    pub fn into_state_action(self) -> crate::runtime::state::AnyStateAction {
82        let call_id = self.call_id.clone();
83        crate::runtime::state::AnyStateAction::new_for_call::<SuspendedCallState>(
84            SuspendedCallAction::Set(self),
85            call_id,
86        )
87    }
88}
89
/// Per-tool-call suspended state stored at `__tool_call_scope.<call_id>.suspended_call`.
///
/// When a tool call is suspended, this state holds the suspension ticket, pending
/// interaction payload, and resume strategy. It is automatically deleted when the
/// tool call reaches a terminal outcome (Succeeded/Failed/Cancelled).
///
/// This is a thin wrapper: its single field is flattened, so the serialized
/// shape is the same as that of [`SuspendedCall`]. Updates go through
/// `SuspendedCallAction`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
#[tirea(
    path = "suspended_call",
    action = "SuspendedCallAction",
    scope = "tool_call"
)]
pub struct SuspendedCallState {
    /// The suspended call data (flattened for serialization).
    #[serde(flatten)]
    pub call: SuspendedCall,
}
106
107/// Action type for `SuspendedCallState` reducer.
108#[derive(Serialize, Deserialize)]
109pub enum SuspendedCallAction {
110    /// Set the suspended call state.
111    Set(SuspendedCall),
112}
113
114impl SuspendedCallState {
115    fn reduce(&mut self, action: SuspendedCallAction) {
116        match action {
117            SuspendedCallAction::Set(call) => {
118                self.call = call;
119            }
120        }
121    }
122}
123
124/// Action type for `ToolCallState` reducer.
125#[derive(Serialize, Deserialize)]
126pub enum ToolCallStateAction {
127    /// Set the full tool call state (used by recovery and normal updates).
128    Set(ToolCallState),
129}
130
/// Tool call lifecycle status for suspend/resume capable execution.
///
/// Variants are serialized in `snake_case`. The default is
/// [`ToolCallStatus::New`]. `Succeeded`, `Failed`, and `Cancelled` are the
/// terminal statuses.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Newly observed call that has not started execution yet.
    #[default]
    New,
    /// Call is currently executing.
    Running,
    /// Call is suspended waiting for a resume decision.
    Suspended,
    /// Call is resuming with external decision input.
    Resuming,
    /// Call finished successfully.
    Succeeded,
    /// Call finished with failure.
    Failed,
    /// Call was cancelled.
    Cancelled,
}
151
152impl ToolCallStatus {
153    /// Canonical tool-call lifecycle state machine used by runtime tests.
154    pub const ASCII_STATE_MACHINE: &str = r#"new ------------> running
155 |                  |
156 |                  v
157 +------------> suspended -----> resuming
158                    |               |
159                    +---------------+
160
161running/resuming ---> succeeded
162running/resuming ---> failed
163running/suspended/resuming ---> cancelled"#;
164
165    /// Whether this status is terminal (no further lifecycle transition expected).
166    pub fn is_terminal(self) -> bool {
167        matches!(
168            self,
169            ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled
170        )
171    }
172
173    /// Validate lifecycle transition from `self` to `next`.
174    pub fn can_transition_to(self, next: Self) -> bool {
175        if self == next {
176            return true;
177        }
178
179        match self {
180            ToolCallStatus::New => true,
181            ToolCallStatus::Running => matches!(
182                next,
183                ToolCallStatus::Suspended
184                    | ToolCallStatus::Succeeded
185                    | ToolCallStatus::Failed
186                    | ToolCallStatus::Cancelled
187            ),
188            ToolCallStatus::Suspended => {
189                matches!(next, ToolCallStatus::Resuming | ToolCallStatus::Cancelled)
190            }
191            ToolCallStatus::Resuming => matches!(
192                next,
193                ToolCallStatus::Running
194                    | ToolCallStatus::Suspended
195                    | ToolCallStatus::Succeeded
196                    | ToolCallStatus::Failed
197                    | ToolCallStatus::Cancelled
198            ),
199            ToolCallStatus::Succeeded | ToolCallStatus::Failed | ToolCallStatus::Cancelled => false,
200        }
201    }
202}
203
/// Resume input payload attached to a suspended tool call.
///
/// On deserialization only `action` is required; every other field falls back
/// to its default when absent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResume {
    /// Idempotency key for the decision submission.
    #[serde(default)]
    pub decision_id: String,
    /// Resume or cancel action.
    pub action: ResumeDecisionAction,
    /// Raw response payload from suspension/frontend.
    ///
    /// Omitted from serialized output when it is JSON `null`.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub result: Value,
    /// Optional human-readable reason.
    ///
    /// Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Decision update timestamp (unix millis).
    #[serde(default)]
    pub updated_at: u64,
}
222
/// Durable per-tool-call runtime state.
///
/// Stored under `__tool_call_scope.<call_id>.tool_call_state` (ToolCall-scoped).
///
/// Every field has a serde default, so a partial (or empty) JSON object
/// deserializes successfully. Empty strings, JSON `null` values, and `None`
/// options are left out of the serialized form; `status` and `updated_at`
/// are always written. Updates go through `ToolCallStateAction`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, State)]
#[tirea(
    path = "tool_call_state",
    action = "ToolCallStateAction",
    scope = "tool_call"
)]
pub struct ToolCallState {
    /// Stable tool call id.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub call_id: String,
    /// Tool name.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub tool_name: String,
    /// Tool arguments snapshot.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub arguments: Value,
    /// Lifecycle status.
    #[serde(default)]
    pub status: ToolCallStatus,
    /// Token used by external actor to resume this call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resume_token: Option<String>,
    /// Resume payload written by external decision handling.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resume: Option<ToolCallResume>,
    /// Plugin/tool scratch data for this call.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub scratch: Value,
    /// Last update timestamp (unix millis).
    #[serde(default)]
    pub updated_at: u64,
}
258
259impl ToolCallState {
260    /// Convert into a type-erased state action targeting this call's scope.
261    ///
262    /// Equivalent to `AnyStateAction::new_for_call::<ToolCallState>(Set(self), call_id)`
263    /// but hides the internal `ToolCallStateAction` type.
264    pub fn into_state_action(self) -> crate::runtime::state::AnyStateAction {
265        let call_id = self.call_id.clone();
266        crate::runtime::state::AnyStateAction::new_for_call::<ToolCallState>(
267            ToolCallStateAction::Set(self),
268            call_id,
269        )
270    }
271}
272
273impl ToolCallState {
274    fn reduce(&mut self, action: ToolCallStateAction) {
275        match action {
276            ToolCallStateAction::Set(s) => *self = s,
277        }
278    }
279}
280
281/// Parse suspended tool calls from a rebuilt state snapshot.
282pub fn suspended_calls_from_state(state: &Value) -> HashMap<String, SuspendedCall> {
283    let Some(Value::Object(scopes)) = state.get("__tool_call_scope") else {
284        return HashMap::new();
285    };
286    scopes
287        .iter()
288        .filter_map(|(call_id, scope_val)| {
289            scope_val
290                .get("suspended_call")
291                .and_then(|v| SuspendedCallState::from_value(v).ok())
292                .map(|s| (call_id.clone(), s.call))
293        })
294        .collect()
295}
296
297/// Parse persisted tool call runtime states from a rebuilt state snapshot.
298///
299/// Iterates `__tool_call_scope.*["tool_call_state"]` to enumerate all call states.
300pub fn tool_call_states_from_state(state: &Value) -> HashMap<String, ToolCallState> {
301    let Some(Value::Object(scopes)) = state.get("__tool_call_scope") else {
302        return HashMap::new();
303    };
304    scopes
305        .iter()
306        .filter_map(|(call_id, scope_val)| {
307            scope_val
308                .get("tool_call_state")
309                .and_then(|v| ToolCallState::from_value(v).ok())
310                .map(|s| (call_id.clone(), s))
311        })
312        .collect()
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `SuspendedCallState` carries empty call identity fields.
    #[test]
    fn suspended_call_state_default() {
        let SuspendedCallState { call } = SuspendedCallState::default();
        assert_eq!(call.call_id, "");
        assert_eq!(call.tool_name, "");
    }

    /// Each step of the documented lifecycle is accepted.
    #[test]
    fn tool_call_status_transitions_match_lifecycle() {
        use ToolCallStatus as S;

        let allowed = [
            (S::New, S::Running),
            (S::Running, S::Suspended),
            (S::Suspended, S::Resuming),
            (S::Resuming, S::Running),
            (S::Resuming, S::Failed),
            (S::Running, S::Succeeded),
            (S::Running, S::Failed),
            (S::Suspended, S::Cancelled),
        ];
        for (from, to) in allowed.iter().copied() {
            assert!(
                from.can_transition_to(to),
                "{:?} -> {:?} should be allowed",
                from,
                to
            );
        }
    }

    /// Terminal statuses cannot be reopened.
    #[test]
    fn tool_call_status_rejects_terminal_reopen_transitions() {
        use ToolCallStatus as S;

        let rejected = [
            (S::Succeeded, S::Running),
            (S::Failed, S::Resuming),
            (S::Cancelled, S::Suspended),
        ];
        for (from, to) in rejected.iter().copied() {
            assert!(
                !from.can_transition_to(to),
                "{:?} -> {:?} should be rejected",
                from,
                to
            );
        }
    }

    /// The ticket is flattened into the parent object and survives a round trip.
    #[test]
    fn suspended_call_serde_flatten_roundtrip() {
        use crate::runtime::tool_call::Suspension;

        let arguments = serde_json::json!({"key": "val"});
        let ticket = SuspendTicket::new(
            Suspension::new("susp_1", "confirm"),
            PendingToolCall::new("pending_1", "my_tool", arguments.clone()),
            ToolCallResumeMode::UseDecisionAsToolResult,
        );
        let call = SuspendedCall {
            call_id: "call_1".into(),
            tool_name: "my_tool".into(),
            arguments,
            ticket,
        };

        let json = serde_json::to_value(&call).unwrap();

        // Flattened fields should appear at top level, not nested under "ticket"
        assert!(json.get("ticket").is_none(), "ticket should be flattened");
        for key in ["suspension", "pending", "resume_mode"].iter() {
            assert!(json.get(*key).is_some(), "{} should be at top level", key);
        }
        assert_eq!(json["call_id"], "call_1");
        assert_eq!(json["suspension"]["id"], "susp_1");
        assert_eq!(json["pending"]["id"], "pending_1");

        // Roundtrip: deserialize back
        let deserialized: SuspendedCall = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized, call);
    }

    /// The ASCII diagram mentions every lifecycle status by its snake_case name.
    #[test]
    fn tool_call_ascii_state_machine_contains_all_states() {
        let diagram = ToolCallStatus::ASCII_STATE_MACHINE;
        let names = [
            "new",
            "running",
            "suspended",
            "resuming",
            "succeeded",
            "failed",
            "cancelled",
        ];
        for name in names.iter() {
            assert!(diagram.contains(name), "diagram is missing `{}`", name);
        }
    }
}