tirea_contract/runtime/phase/
contexts.rs1use crate::runtime::inference::{InferenceError, StreamResult};
2use crate::runtime::run::RunIdentity;
3use crate::runtime::run::{RunAction, TerminationReason};
4use crate::runtime::tool_call::gate::{SuspendTicket, ToolCallAction};
5use crate::runtime::tool_call::{ToolCallResume, ToolResult};
6use crate::thread::Message;
7use crate::RunPolicy;
8use serde_json::Value;
9use std::sync::Arc;
10use tirea_state::State;
11
12use super::step::StepContext;
13use super::types::Phase;
14
15pub trait PhaseContext {
17 fn phase(&self) -> Phase;
18 fn thread_id(&self) -> &str;
19 fn messages(&self) -> &[Arc<Message>];
20 fn run_policy(&self) -> &RunPolicy;
21 fn run_identity(&self) -> &RunIdentity;
22 fn state_of<T: State>(&self) -> T::Ref<'_>;
23 fn snapshot(&self) -> serde_json::Value;
24}
25
/// Generates the boilerplate shared by every phase context type.
///
/// Given a struct `$ctx<'s, 'a>` with a `step: &'s mut StepContext<'a>` field,
/// this emits:
/// - an inherent `new` constructor (and, with the `test-support` feature,
///   `step_mut_for_tests`);
/// - a `PhaseContext` impl that reports `$phase_value` and forwards every
///   other accessor to `step`.
macro_rules! impl_phase_context {
    ($ctx:ident, $phase_value:expr) => {
        impl<'s, 'a> $ctx<'s, 'a> {
            /// Wraps `step` in this phase context.
            pub fn new(step: &'s mut StepContext<'a>) -> Self {
                Self { step }
            }

            /// Exposes the wrapped step context mutably (test builds only).
            #[cfg(feature = "test-support")]
            pub fn step_mut_for_tests(&mut self) -> &mut StepContext<'a> {
                self.step
            }
        }

        impl PhaseContext for $ctx<'_, '_> {
            fn phase(&self) -> Phase {
                $phase_value
            }

            fn run_identity(&self) -> &RunIdentity {
                self.step.run_identity()
            }

            fn run_policy(&self) -> &RunPolicy {
                self.step.run_policy()
            }

            fn thread_id(&self) -> &str {
                self.step.thread_id()
            }

            fn messages(&self) -> &[Arc<Message>] {
                self.step.messages()
            }

            fn state_of<T: State>(&self) -> T::Ref<'_> {
                self.step.state_of::<T>()
            }

            fn snapshot(&self) -> serde_json::Value {
                self.step.snapshot()
            }
        }
    };
}
70
71pub struct RunStartContext<'s, 'a> {
72 step: &'s mut StepContext<'a>,
73}
74impl_phase_context!(RunStartContext, Phase::RunStart);
75
76pub struct StepStartContext<'s, 'a> {
77 step: &'s mut StepContext<'a>,
78}
79impl_phase_context!(StepStartContext, Phase::StepStart);
80
81pub struct BeforeInferenceContext<'s, 'a> {
82 step: &'s mut StepContext<'a>,
83}
84impl_phase_context!(BeforeInferenceContext, Phase::BeforeInference);
85
86impl<'s, 'a> BeforeInferenceContext<'s, 'a> {
87 pub fn add_system_context(&mut self, text: impl Into<String>) {
89 self.step.inference.system_context.push(text.into());
90 }
91
92 pub fn add_session_message(&mut self, text: impl Into<String>) {
94 self.step.inference.session_context.push(text.into());
95 }
96
97 pub fn exclude_tool(&mut self, tool_id: &str) {
99 self.step.inference.tools.retain(|t| t.id != tool_id);
100 }
101
102 pub fn include_only(&mut self, tool_ids: &[&str]) {
104 self.step
105 .inference
106 .tools
107 .retain(|t| tool_ids.contains(&t.id.as_str()));
108 }
109
110 pub fn terminate_behavior_requested(&mut self) {
112 self.step.flow.run_action =
113 Some(RunAction::Terminate(TerminationReason::BehaviorRequested));
114 }
115
116 pub fn request_termination(&mut self, reason: TerminationReason) {
118 self.step.flow.run_action = Some(RunAction::Terminate(reason));
119 }
120}
121
122pub struct AfterInferenceContext<'s, 'a> {
123 step: &'s mut StepContext<'a>,
124}
125impl_phase_context!(AfterInferenceContext, Phase::AfterInference);
126
127impl<'s, 'a> AfterInferenceContext<'s, 'a> {
128 pub fn response_opt(&self) -> Option<&StreamResult> {
129 self.step
130 .llm_response
131 .as_ref()
132 .and_then(|r| r.outcome.as_ref().ok())
133 }
134
135 pub fn response(&self) -> &StreamResult {
136 self.step
137 .llm_response
138 .as_ref()
139 .expect("AfterInferenceContext.response() requires LLMResponse to be set")
140 .outcome
141 .as_ref()
142 .expect("AfterInferenceContext.response() requires a successful outcome")
143 }
144
145 pub fn inference_error(&self) -> Option<&InferenceError> {
146 self.step
147 .llm_response
148 .as_ref()
149 .and_then(|r| r.outcome.as_ref().err())
150 }
151
152 pub fn request_termination(&mut self, reason: TerminationReason) {
154 self.step.flow.run_action = Some(RunAction::Terminate(reason));
155 }
156}
157
158pub struct BeforeToolExecuteContext<'s, 'a> {
159 step: &'s mut StepContext<'a>,
160}
161impl_phase_context!(BeforeToolExecuteContext, Phase::BeforeToolExecute);
162
163impl<'s, 'a> BeforeToolExecuteContext<'s, 'a> {
164 pub fn tool_name(&self) -> Option<&str> {
165 self.step.tool_name()
166 }
167
168 pub fn tool_call_id(&self) -> Option<&str> {
169 self.step.tool_call_id()
170 }
171
172 pub fn tool_args(&self) -> Option<&Value> {
173 self.step.tool_args()
174 }
175
176 pub fn resume_input(&self) -> Option<ToolCallResume> {
178 let gate = self.step.gate.as_ref()?;
179 self.step.ctx().resume_input_for(&gate.id).ok().flatten()
180 }
181
182 pub fn decision(&self) -> ToolCallAction {
183 self.step.tool_action()
184 }
185
186 pub fn set_decision(&mut self, decision: ToolCallAction) {
187 if let Some(gate) = self.step.gate.as_mut() {
188 match decision {
189 ToolCallAction::Proceed => {
190 gate.blocked = false;
191 gate.block_reason = None;
192 gate.pending = false;
193 gate.suspend_ticket = None;
194 }
195 ToolCallAction::Suspend(ticket) => {
196 gate.blocked = false;
197 gate.block_reason = None;
198 gate.pending = true;
199 gate.suspend_ticket = Some(*ticket);
200 }
201 ToolCallAction::Block { reason } => {
202 gate.blocked = true;
203 gate.block_reason = Some(reason);
204 gate.pending = false;
205 gate.suspend_ticket = None;
206 }
207 }
208 }
209 }
210
211 pub fn block(&mut self, reason: impl Into<String>) {
212 if let Some(gate) = self.step.gate.as_mut() {
213 gate.blocked = true;
214 gate.block_reason = Some(reason.into());
215 gate.pending = false;
216 gate.suspend_ticket = None;
217 }
218 }
219
220 pub fn allow(&mut self) {
224 if let Some(gate) = self.step.gate.as_mut() {
225 gate.blocked = false;
226 gate.block_reason = None;
227 gate.pending = false;
228 gate.suspend_ticket = None;
229 }
230 }
231
232 pub fn set_tool_result(&mut self, result: ToolResult) {
237 if let Some(gate) = self.step.gate.as_mut() {
238 gate.result = Some(result);
239 }
240 }
241
242 pub fn suspend(&mut self, ticket: SuspendTicket) {
243 if let Some(gate) = self.step.gate.as_mut() {
244 gate.blocked = false;
245 gate.block_reason = None;
246 gate.pending = true;
247 gate.suspend_ticket = Some(ticket);
248 }
249 }
250}
251
252pub struct AfterToolExecuteContext<'s, 'a> {
253 step: &'s mut StepContext<'a>,
254}
255impl_phase_context!(AfterToolExecuteContext, Phase::AfterToolExecute);
256
257impl<'s, 'a> AfterToolExecuteContext<'s, 'a> {
258 pub fn tool_name(&self) -> Option<&str> {
259 self.step.tool_name()
260 }
261
262 pub fn tool_call_id(&self) -> Option<&str> {
263 self.step.tool_call_id()
264 }
265
266 pub fn tool_result(&self) -> &ToolResult {
267 self.step
268 .gate
269 .as_ref()
270 .and_then(|g| g.result.as_ref())
271 .expect("AfterToolExecuteContext.tool_result() requires tool result")
272 }
273
274 pub fn add_system_reminder(&mut self, text: impl Into<String>) {
275 self.step.messaging.reminders.push(text.into());
276 }
277}
278
279pub struct StepEndContext<'s, 'a> {
280 step: &'s mut StepContext<'a>,
281}
282impl_phase_context!(StepEndContext, Phase::StepEnd);
283
284pub struct RunEndContext<'s, 'a> {
285 step: &'s mut StepContext<'a>,
286}
287impl_phase_context!(RunEndContext, Phase::RunEnd);