tirea_contract/runtime/phase/
step.rs1use crate::runtime::inference::{InferenceContext, LLMResponse, MessagingContext};
2use crate::runtime::run::RunIdentity;
3use crate::runtime::run::{FlowControl, RunAction};
4use crate::runtime::state::{AnyStateAction, SerializedStateAction};
5use crate::runtime::tool_call::gate::ToolGate;
6use crate::runtime::tool_call::{ToolCallContext, ToolDescriptor, ToolResult};
7use crate::thread::Message;
8use crate::RunPolicy;
9use serde_json::Value;
10use std::sync::Arc;
11use tirea_state::{State, TireaResult, TrackedPatch};
12
13use super::types::{StepOutcome, ToolCallAction};
14
15pub struct StepContext<'a> {
27 ctx: ToolCallContext<'a>,
29
30 thread_id: &'a str,
32
33 messages: &'a [Arc<Message>],
35
36 initial_message_count: usize,
38
39 pub inference: InferenceContext,
43
44 pub llm_response: Option<LLMResponse>,
47
48 pub gate: Option<ToolGate>,
51
52 pub messaging: MessagingContext,
55
56 pub flow: FlowControl,
59
60 pub pending_state_actions: Vec<AnyStateAction>,
63
64 pub pending_patches: Vec<TrackedPatch>,
67
68 pub(crate) pending_serialized_state_actions: Vec<SerializedStateAction>,
71}
72
73impl<'a> StepContext<'a> {
74 pub fn new(
76 ctx: ToolCallContext<'a>,
77 thread_id: &'a str,
78 messages: &'a [Arc<Message>],
79 tools: Vec<ToolDescriptor>,
80 ) -> Self {
81 Self {
82 ctx,
83 thread_id,
84 messages,
85 initial_message_count: 0,
86 inference: InferenceContext {
87 tools,
88 ..Default::default()
89 },
90 llm_response: None,
91 gate: None,
92 messaging: MessagingContext::default(),
93 flow: FlowControl::default(),
94 pending_state_actions: Vec::new(),
95 pending_patches: Vec::new(),
96 pending_serialized_state_actions: Vec::new(),
97 }
98 }
99
100 pub fn ctx(&self) -> &ToolCallContext<'a> {
105 &self.ctx
106 }
107
108 pub fn thread_id(&self) -> &str {
109 self.thread_id
110 }
111
112 pub fn messages(&self) -> &[Arc<Message>] {
113 self.messages
114 }
115
116 pub fn initial_message_count(&self) -> usize {
118 self.initial_message_count
119 }
120
121 pub fn set_initial_message_count(&mut self, count: usize) {
123 self.initial_message_count = count;
124 }
125
126 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
127 self.ctx.state_of::<T>()
128 }
129
130 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
131 self.ctx.state::<T>(path)
132 }
133
134 pub fn run_policy(&self) -> &RunPolicy {
135 self.ctx.run_policy()
136 }
137
138 pub fn run_identity(&self) -> &RunIdentity {
139 self.ctx.run_identity()
140 }
141
142 pub fn snapshot(&self) -> Value {
143 self.ctx.snapshot()
144 }
145
146 pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
147 self.ctx.snapshot_of::<T>()
148 }
149
150 pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
151 self.ctx.snapshot_at::<T>(path)
152 }
153
154 pub fn reset(&mut self) {
158 let tools = std::mem::take(&mut self.inference.tools);
159 self.inference = InferenceContext {
160 tools,
161 ..Default::default()
162 };
163 self.llm_response = None;
164 self.gate = None;
165 self.messaging = MessagingContext::default();
166 self.flow = FlowControl::default();
167 self.pending_state_actions.clear();
168 self.pending_patches.clear();
169 self.pending_serialized_state_actions.clear();
170 }
171
172 pub fn tool_name(&self) -> Option<&str> {
177 self.gate.as_ref().map(|g| g.name.as_str())
178 }
179
180 pub fn tool_call_id(&self) -> Option<&str> {
181 self.gate.as_ref().map(|g| g.id.as_str())
182 }
183
184 pub fn tool_idempotency_key(&self) -> Option<&str> {
185 self.tool_call_id()
186 }
187
188 pub fn tool_args(&self) -> Option<&Value> {
189 self.gate.as_ref().map(|g| &g.args)
190 }
191
192 pub fn tool_result(&self) -> Option<&ToolResult> {
193 self.gate.as_ref().and_then(|g| g.result.as_ref())
194 }
195
196 pub fn tool_blocked(&self) -> bool {
197 self.gate.as_ref().map(|g| g.blocked).unwrap_or(false)
198 }
199
200 pub fn tool_pending(&self) -> bool {
201 self.gate.as_ref().map(|g| g.pending).unwrap_or(false)
202 }
203
204 pub fn emit_patch(&mut self, patch: TrackedPatch) {
210 self.pending_patches.push(patch);
211 }
212
213 pub fn emit_state_action(&mut self, action: AnyStateAction) {
215 self.pending_state_actions.push(action);
216 }
217
218 pub fn emit_serialized_state_action(&mut self, action: SerializedStateAction) {
220 self.pending_serialized_state_actions.push(action);
221 }
222
223 pub fn take_pending_serialized_state_actions(&mut self) -> Vec<SerializedStateAction> {
225 std::mem::take(&mut self.pending_serialized_state_actions)
226 }
227
228 pub fn run_action(&self) -> RunAction {
234 self.flow.run_action.clone().unwrap_or(RunAction::Continue)
235 }
236
237 pub fn tool_action(&self) -> ToolCallAction {
239 if let Some(gate) = &self.gate {
240 if gate.blocked {
241 return ToolCallAction::Block {
242 reason: gate.block_reason.clone().unwrap_or_default(),
243 };
244 }
245 if gate.pending {
246 if let Some(ticket) = gate.suspend_ticket.as_ref() {
247 return ToolCallAction::suspend(ticket.clone());
248 }
249 return ToolCallAction::Block {
250 reason: "invalid pending tool state: missing suspend ticket".to_string(),
251 };
252 }
253 }
254 ToolCallAction::Proceed
255 }
256
257 pub fn result(&self) -> StepOutcome {
263 if let Some(gate) = &self.gate {
264 if gate.pending {
265 if let Some(ticket) = gate.suspend_ticket.as_ref() {
266 return StepOutcome::Pending(Box::new(ticket.clone()));
267 }
268 return StepOutcome::Continue;
269 }
270 }
271
272 if let Some(llm) = &self.llm_response {
273 if let Ok(result) = &llm.outcome {
274 if result.tool_calls.is_empty() && !result.text.is_empty() {
275 return StepOutcome::Complete;
276 }
277 }
278 }
279
280 StepOutcome::Continue
281 }
282}