Skip to main content

tirea_contract/runtime/phase/
contexts.rs

1use crate::runtime::inference::{InferenceError, StreamResult};
2use crate::runtime::run::RunIdentity;
3use crate::runtime::run::{RunAction, TerminationReason};
4use crate::runtime::tool_call::gate::{SuspendTicket, ToolCallAction};
5use crate::runtime::tool_call::{ToolCallResume, ToolResult};
6use crate::thread::Message;
7use crate::RunPolicy;
8use serde_json::Value;
9use std::sync::Arc;
10use tirea_state::State;
11
12use super::step::StepContext;
13use super::types::Phase;
14
15/// Shared read access available to all phase contexts.
16pub trait PhaseContext {
17    fn phase(&self) -> Phase;
18    fn thread_id(&self) -> &str;
19    fn messages(&self) -> &[Arc<Message>];
20    fn run_policy(&self) -> &RunPolicy;
21    fn run_identity(&self) -> &RunIdentity;
22    fn state_of<T: State>(&self) -> T::Ref<'_>;
23    fn snapshot(&self) -> serde_json::Value;
24}
25
/// Generates the boilerplate shared by every phase context type.
///
/// `impl_phase_context!(Name, phase_expr)` expects `Name<'s, 'a>` to be a
/// struct with a single `step: &'s mut StepContext<'a>` field and emits:
///
/// * `Name::new`, which wraps a mutable borrow of the step context;
/// * `Name::step_mut_for_tests`, compiled only with the `test-support` feature;
/// * a [`PhaseContext`] impl whose `phase()` evaluates `phase_expr` and whose
///   other accessors forward to the wrapped step context.
macro_rules! impl_phase_context {
    ($ctx:ident, $phase_value:expr) => {
        impl<'s, 'a> $ctx<'s, 'a> {
            /// Wraps `step` for the lifetime of the borrow `'s`.
            pub fn new(step: &'s mut StepContext<'a>) -> Self {
                $ctx { step }
            }

            /// Direct mutable access to the wrapped step context (test builds only).
            #[cfg(feature = "test-support")]
            pub fn step_mut_for_tests(&mut self) -> &mut StepContext<'a> {
                &mut *self.step
            }
        }

        impl<'s, 'a> PhaseContext for $ctx<'s, 'a> {
            fn phase(&self) -> Phase {
                $phase_value
            }

            fn thread_id(&self) -> &str {
                self.step.thread_id()
            }

            fn messages(&self) -> &[Arc<Message>] {
                self.step.messages()
            }

            fn run_policy(&self) -> &RunPolicy {
                self.step.run_policy()
            }

            fn run_identity(&self) -> &RunIdentity {
                self.step.run_identity()
            }

            fn state_of<T: State>(&self) -> T::Ref<'_> {
                self.step.state_of::<T>()
            }

            fn snapshot(&self) -> serde_json::Value {
                self.step.snapshot()
            }
        }
    };
}
70
71pub struct RunStartContext<'s, 'a> {
72    step: &'s mut StepContext<'a>,
73}
74impl_phase_context!(RunStartContext, Phase::RunStart);
75
76pub struct StepStartContext<'s, 'a> {
77    step: &'s mut StepContext<'a>,
78}
79impl_phase_context!(StepStartContext, Phase::StepStart);
80
81pub struct BeforeInferenceContext<'s, 'a> {
82    step: &'s mut StepContext<'a>,
83}
84impl_phase_context!(BeforeInferenceContext, Phase::BeforeInference);
85
86impl<'s, 'a> BeforeInferenceContext<'s, 'a> {
87    /// Append a system context line.
88    pub fn add_system_context(&mut self, text: impl Into<String>) {
89        self.step.inference.system_context.push(text.into());
90    }
91
92    /// Append a session message.
93    pub fn add_session_message(&mut self, text: impl Into<String>) {
94        self.step.inference.session_context.push(text.into());
95    }
96
97    /// Exclude tool by id.
98    pub fn exclude_tool(&mut self, tool_id: &str) {
99        self.step.inference.tools.retain(|t| t.id != tool_id);
100    }
101
102    /// Keep only listed tools.
103    pub fn include_only(&mut self, tool_ids: &[&str]) {
104        self.step
105            .inference
106            .tools
107            .retain(|t| tool_ids.contains(&t.id.as_str()));
108    }
109
110    /// Terminate current run as behavior-requested before inference.
111    pub fn terminate_behavior_requested(&mut self) {
112        self.step.flow.run_action =
113            Some(RunAction::Terminate(TerminationReason::BehaviorRequested));
114    }
115
116    /// Request run termination with a specific reason.
117    pub fn request_termination(&mut self, reason: TerminationReason) {
118        self.step.flow.run_action = Some(RunAction::Terminate(reason));
119    }
120}
121
122pub struct AfterInferenceContext<'s, 'a> {
123    step: &'s mut StepContext<'a>,
124}
125impl_phase_context!(AfterInferenceContext, Phase::AfterInference);
126
127impl<'s, 'a> AfterInferenceContext<'s, 'a> {
128    pub fn response_opt(&self) -> Option<&StreamResult> {
129        self.step
130            .llm_response
131            .as_ref()
132            .and_then(|r| r.outcome.as_ref().ok())
133    }
134
135    pub fn response(&self) -> &StreamResult {
136        self.step
137            .llm_response
138            .as_ref()
139            .expect("AfterInferenceContext.response() requires LLMResponse to be set")
140            .outcome
141            .as_ref()
142            .expect("AfterInferenceContext.response() requires a successful outcome")
143    }
144
145    pub fn inference_error(&self) -> Option<&InferenceError> {
146        self.step
147            .llm_response
148            .as_ref()
149            .and_then(|r| r.outcome.as_ref().err())
150    }
151
152    /// Request run termination with a specific reason after inference has completed.
153    pub fn request_termination(&mut self, reason: TerminationReason) {
154        self.step.flow.run_action = Some(RunAction::Terminate(reason));
155    }
156}
157
158pub struct BeforeToolExecuteContext<'s, 'a> {
159    step: &'s mut StepContext<'a>,
160}
161impl_phase_context!(BeforeToolExecuteContext, Phase::BeforeToolExecute);
162
163impl<'s, 'a> BeforeToolExecuteContext<'s, 'a> {
164    pub fn tool_name(&self) -> Option<&str> {
165        self.step.tool_name()
166    }
167
168    pub fn tool_call_id(&self) -> Option<&str> {
169        self.step.tool_call_id()
170    }
171
172    pub fn tool_args(&self) -> Option<&Value> {
173        self.step.tool_args()
174    }
175
176    /// Resume payload attached to current tool call, if present.
177    pub fn resume_input(&self) -> Option<ToolCallResume> {
178        let gate = self.step.gate.as_ref()?;
179        self.step.ctx().resume_input_for(&gate.id).ok().flatten()
180    }
181
182    pub fn decision(&self) -> ToolCallAction {
183        self.step.tool_action()
184    }
185
186    pub fn set_decision(&mut self, decision: ToolCallAction) {
187        if let Some(gate) = self.step.gate.as_mut() {
188            match decision {
189                ToolCallAction::Proceed => {
190                    gate.blocked = false;
191                    gate.block_reason = None;
192                    gate.pending = false;
193                    gate.suspend_ticket = None;
194                }
195                ToolCallAction::Suspend(ticket) => {
196                    gate.blocked = false;
197                    gate.block_reason = None;
198                    gate.pending = true;
199                    gate.suspend_ticket = Some(*ticket);
200                }
201                ToolCallAction::Block { reason } => {
202                    gate.blocked = true;
203                    gate.block_reason = Some(reason);
204                    gate.pending = false;
205                    gate.suspend_ticket = None;
206                }
207            }
208        }
209    }
210
211    pub fn block(&mut self, reason: impl Into<String>) {
212        if let Some(gate) = self.step.gate.as_mut() {
213            gate.blocked = true;
214            gate.block_reason = Some(reason.into());
215            gate.pending = false;
216            gate.suspend_ticket = None;
217        }
218    }
219
220    /// Explicitly allow tool execution.
221    ///
222    /// This clears any previous block/suspend state set by earlier plugins.
223    pub fn allow(&mut self) {
224        if let Some(gate) = self.step.gate.as_mut() {
225            gate.blocked = false;
226            gate.block_reason = None;
227            gate.pending = false;
228            gate.suspend_ticket = None;
229        }
230    }
231
232    /// Override current call result directly from plugin logic.
233    ///
234    /// Useful for resumed frontend interactions where the external payload
235    /// should become the tool result without executing a backend tool.
236    pub fn set_tool_result(&mut self, result: ToolResult) {
237        if let Some(gate) = self.step.gate.as_mut() {
238            gate.result = Some(result);
239        }
240    }
241
242    pub fn suspend(&mut self, ticket: SuspendTicket) {
243        if let Some(gate) = self.step.gate.as_mut() {
244            gate.blocked = false;
245            gate.block_reason = None;
246            gate.pending = true;
247            gate.suspend_ticket = Some(ticket);
248        }
249    }
250}
251
252pub struct AfterToolExecuteContext<'s, 'a> {
253    step: &'s mut StepContext<'a>,
254}
255impl_phase_context!(AfterToolExecuteContext, Phase::AfterToolExecute);
256
257impl<'s, 'a> AfterToolExecuteContext<'s, 'a> {
258    pub fn tool_name(&self) -> Option<&str> {
259        self.step.tool_name()
260    }
261
262    pub fn tool_call_id(&self) -> Option<&str> {
263        self.step.tool_call_id()
264    }
265
266    pub fn tool_result(&self) -> &ToolResult {
267        self.step
268            .gate
269            .as_ref()
270            .and_then(|g| g.result.as_ref())
271            .expect("AfterToolExecuteContext.tool_result() requires tool result")
272    }
273
274    pub fn add_system_reminder(&mut self, text: impl Into<String>) {
275        self.step.messaging.reminders.push(text.into());
276    }
277}
278
279pub struct StepEndContext<'s, 'a> {
280    step: &'s mut StepContext<'a>,
281}
282impl_phase_context!(StepEndContext, Phase::StepEnd);
283
284pub struct RunEndContext<'s, 'a> {
285    step: &'s mut StepContext<'a>,
286}
287impl_phase_context!(RunEndContext, Phase::RunEnd);