1use crate::runtime::inference::response::{InferenceError, LLMResponse, StreamResult};
2use crate::runtime::phase::step::StepContext;
3use crate::runtime::phase::Phase;
4use crate::runtime::phase::{
5 ActionSet, AfterInferenceAction, AfterToolExecuteAction, BeforeInferenceAction,
6 BeforeToolExecuteAction, LifecycleAction,
7};
8use crate::runtime::run::RunIdentity;
9use crate::runtime::state::StateScopeRegistry;
10use crate::runtime::state::{ScopeContext, StateActionDeserializerRegistry, StateScope, StateSpec};
11use crate::runtime::tool_call::{ToolCallResume, ToolResult};
12use crate::thread::Message;
13use crate::RunPolicy;
14use async_trait::async_trait;
15use serde_json::Value;
16use std::sync::Arc;
17use tirea_state::{get_at_path, parse_path, DocCell, LatticeRegistry, State, TireaResult};
18
/// Immutable view of run/step state handed to [`AgentBehavior`] hooks.
///
/// Most data is borrowed for `'a`. The run identity, resume input and scope context are owned
/// copies. Optional pieces (LLM response, tool information, resume input) start as `None` and
/// are attached through the `with_*` builder methods.
pub struct ReadOnlyContext<'a> {
    /// Lifecycle phase this context was built for.
    phase: Phase,
    /// Identifier of the thread being run.
    thread_id: &'a str,
    /// Messages of the thread, borrowed from the caller.
    messages: &'a [Arc<Message>],
    /// Run policy in effect, borrowed from the caller.
    run_policy: &'a RunPolicy,
    /// Identity of the current run; `RunIdentity::default()` unless set via `with_run_identity`.
    run_identity: RunIdentity,
    /// State document; snapshots of it back `snapshot`, `snapshot_of` and `scoped_state_of`.
    doc: &'a DocCell,
    /// LLM response attached via `with_llm_response`, if any.
    llm_response: Option<&'a LLMResponse>,
    /// Tool name attached via `with_tool_info`, if any.
    tool_name: Option<&'a str>,
    /// Tool call id attached via `with_tool_info`, if any.
    tool_call_id: Option<&'a str>,
    /// Tool arguments attached via `with_tool_info`, if any.
    tool_args: Option<&'a Value>,
    /// Tool result attached via `with_tool_result`, if any.
    tool_result: Option<&'a ToolResult>,
    /// Resume input attached via `with_resume_input`, if any.
    resume_input: Option<ToolCallResume>,
    /// Scope used by `scoped_state_of` to resolve state paths; `ScopeContext::run()` by default.
    scope_ctx: ScopeContext,
    /// Message count copied from the step; `0` by default.
    /// NOTE(review): presumably the number of messages present before the current step began — confirm.
    initial_message_count: usize,
}
40
41impl<'a> ReadOnlyContext<'a> {
42 pub fn new(
43 phase: Phase,
44 thread_id: &'a str,
45 messages: &'a [Arc<Message>],
46 run_policy: &'a RunPolicy,
47 doc: &'a DocCell,
48 ) -> Self {
49 Self {
50 phase,
51 thread_id,
52 messages,
53 run_policy,
54 run_identity: RunIdentity::default(),
55 doc,
56 llm_response: None,
57 tool_name: None,
58 tool_call_id: None,
59 tool_args: None,
60 tool_result: None,
61 resume_input: None,
62 scope_ctx: ScopeContext::run(),
63 initial_message_count: 0,
64 }
65 }
66
67 #[must_use]
68 pub fn with_llm_response(mut self, response: &'a LLMResponse) -> Self {
69 self.llm_response = Some(response);
70 self
71 }
72
73 #[must_use]
74 pub fn with_tool_info(
75 mut self,
76 name: &'a str,
77 call_id: &'a str,
78 args: Option<&'a Value>,
79 ) -> Self {
80 self.tool_name = Some(name);
81 self.tool_call_id = Some(call_id);
82 self.tool_args = args;
83 self
84 }
85
86 #[must_use]
87 pub fn with_tool_result(mut self, result: &'a ToolResult) -> Self {
88 self.tool_result = Some(result);
89 self
90 }
91
92 #[must_use]
93 pub fn with_resume_input(mut self, resume: ToolCallResume) -> Self {
94 self.resume_input = Some(resume);
95 self
96 }
97
98 #[must_use]
99 pub fn with_scope_ctx(mut self, scope_ctx: ScopeContext) -> Self {
100 self.scope_ctx = scope_ctx;
101 self
102 }
103
104 pub fn phase(&self) -> Phase {
105 self.phase
106 }
107
108 pub fn thread_id(&self) -> &str {
109 self.thread_id
110 }
111
112 pub fn messages(&self) -> &[Arc<Message>] {
113 self.messages
114 }
115
116 pub fn initial_message_count(&self) -> usize {
118 self.initial_message_count
119 }
120
121 pub fn run_policy(&self) -> &RunPolicy {
122 self.run_policy
123 }
124
125 pub fn run_identity(&self) -> &RunIdentity {
126 &self.run_identity
127 }
128
129 pub fn doc(&self) -> &DocCell {
130 self.doc
131 }
132
133 pub fn response(&self) -> Option<&StreamResult> {
134 self.llm_response.and_then(|r| r.outcome.as_ref().ok())
135 }
136
137 pub fn inference_error(&self) -> Option<&InferenceError> {
138 self.llm_response.and_then(|r| r.outcome.as_ref().err())
139 }
140
141 pub fn tool_name(&self) -> Option<&str> {
142 self.tool_name
143 }
144
145 pub fn tool_call_id(&self) -> Option<&str> {
146 self.tool_call_id
147 }
148
149 pub fn tool_args(&self) -> Option<&Value> {
150 self.tool_args
151 }
152
153 pub fn tool_result(&self) -> Option<&ToolResult> {
154 self.tool_result
155 }
156
157 pub fn resume_input(&self) -> Option<&ToolCallResume> {
158 self.resume_input.as_ref()
159 }
160
161 pub fn snapshot(&self) -> Value {
162 self.doc.snapshot()
163 }
164
165 pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
166 let val = self.doc.snapshot();
167 let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
168 T::from_value(at)
169 }
170
171 pub fn scoped_state_of<T: StateSpec>(&self, scope: StateScope) -> TireaResult<T> {
173 let path = self.scope_ctx.resolve_path(scope, T::PATH);
174 let val = self.doc.snapshot();
175 let at = get_at_path(&val, &parse_path(&path)).unwrap_or(&Value::Null);
176 T::from_value(at).or_else(|e| {
177 if at.is_null() {
178 T::from_value(&Value::Object(Default::default())).map_err(|_| e)
179 } else {
180 Err(e)
181 }
182 })
183 }
184
185 pub fn scope_ctx(&self) -> &ScopeContext {
186 &self.scope_ctx
187 }
188
189 #[must_use]
190 pub fn with_run_identity(mut self, run_identity: &RunIdentity) -> Self {
191 self.run_identity = run_identity.clone();
192 self
193 }
194}
195
196#[async_trait]
205pub trait AgentBehavior: Send + Sync {
206 fn id(&self) -> &str;
207
208 fn behavior_ids(&self) -> Vec<&str> {
209 vec![self.id()]
210 }
211
212 fn register_lattice_paths(&self, _registry: &mut LatticeRegistry) {}
214
215 fn register_state_scopes(&self, _registry: &mut StateScopeRegistry) {}
217
218 fn register_state_action_deserializers(&self, _registry: &mut StateActionDeserializerRegistry) {
220 }
221
222 async fn run_start(&self, _ctx: &ReadOnlyContext<'_>) -> ActionSet<LifecycleAction> {
223 ActionSet::empty()
224 }
225
226 async fn step_start(&self, _ctx: &ReadOnlyContext<'_>) -> ActionSet<LifecycleAction> {
227 ActionSet::empty()
228 }
229
230 async fn before_inference(
231 &self,
232 _ctx: &ReadOnlyContext<'_>,
233 ) -> ActionSet<BeforeInferenceAction> {
234 ActionSet::empty()
235 }
236
237 async fn after_inference(&self, _ctx: &ReadOnlyContext<'_>) -> ActionSet<AfterInferenceAction> {
238 ActionSet::empty()
239 }
240
241 async fn before_tool_execute(
242 &self,
243 _ctx: &ReadOnlyContext<'_>,
244 ) -> ActionSet<BeforeToolExecuteAction> {
245 ActionSet::empty()
246 }
247
248 async fn after_tool_execute(
249 &self,
250 _ctx: &ReadOnlyContext<'_>,
251 ) -> ActionSet<AfterToolExecuteAction> {
252 ActionSet::empty()
253 }
254
255 async fn step_end(&self, _ctx: &ReadOnlyContext<'_>) -> ActionSet<LifecycleAction> {
256 ActionSet::empty()
257 }
258
259 async fn run_end(&self, _ctx: &ReadOnlyContext<'_>) -> ActionSet<LifecycleAction> {
260 ActionSet::empty()
261 }
262}
263
264pub struct NoOpBehavior;
266
267#[async_trait]
268impl AgentBehavior for NoOpBehavior {
269 fn id(&self) -> &str {
270 "noop"
271 }
272}
273
274pub fn build_read_only_context_from_step<'a>(
276 phase: Phase,
277 step: &'a StepContext<'a>,
278 doc: &'a DocCell,
279) -> ReadOnlyContext<'a> {
280 let mut ctx = ReadOnlyContext::new(
281 phase,
282 step.thread_id(),
283 step.messages(),
284 step.run_policy(),
285 doc,
286 )
287 .with_run_identity(step.ctx().run_identity());
288 ctx.initial_message_count = step.initial_message_count();
289 if let Some(llm) = step.llm_response.as_ref() {
290 ctx = ctx.with_llm_response(llm);
291 }
292 if let Some(gate) = step.gate.as_ref() {
293 ctx = ctx.with_tool_info(&gate.name, &gate.id, Some(&gate.args));
294 if let Some(result) = gate.result.as_ref() {
295 ctx = ctx.with_tool_result(result);
296 }
297 if matches!(phase, Phase::BeforeToolExecute | Phase::AfterToolExecute) {
298 ctx = ctx.with_scope_ctx(ScopeContext::for_call(&gate.id));
299 }
300 }
301 if phase == Phase::BeforeToolExecute {
302 if let Some(call_id) = step.tool_call_id() {
303 if let Ok(Some(resume)) = step.ctx().resume_input_for(call_id) {
304 ctx = ctx.with_resume_input(resume);
305 }
306 }
307 }
308 ctx
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn default_agent_all_phases_noop() {
        let behavior = NoOpBehavior;
        let policy = RunPolicy::new();
        let state = DocCell::new(json!({}));

        let run_start_ctx = ReadOnlyContext::new(Phase::RunStart, "t1", &[], &policy, &state);
        assert!(behavior.run_start(&run_start_ctx).await.is_empty());

        let inference_ctx =
            ReadOnlyContext::new(Phase::BeforeInference, "t1", &[], &policy, &state);
        assert!(behavior.before_inference(&inference_ctx).await.is_empty());
    }

    #[tokio::test]
    async fn agent_returns_actions() {
        struct SystemContextBehavior;

        #[async_trait]
        impl AgentBehavior for SystemContextBehavior {
            fn id(&self) -> &str {
                "ctx"
            }

            async fn before_inference(
                &self,
                _context: &ReadOnlyContext<'_>,
            ) -> ActionSet<BeforeInferenceAction> {
                let action = BeforeInferenceAction::AddSystemContext("from agent".into());
                ActionSet::single(action)
            }
        }

        let policy = RunPolicy::new();
        let state = DocCell::new(json!({}));
        let context = ReadOnlyContext::new(Phase::BeforeInference, "t1", &[], &policy, &state);

        let behavior = SystemContextBehavior;
        let produced = behavior.before_inference(&context).await;
        assert_eq!(produced.len(), 1);
    }

    #[tokio::test]
    async fn read_only_context_accessors() {
        let policy = RunPolicy::new();
        let state = DocCell::new(json!({"key": "val"}));
        let context =
            ReadOnlyContext::new(Phase::AfterToolExecute, "thread_42", &[], &policy, &state);

        assert_eq!(context.phase(), Phase::AfterToolExecute);
        assert_eq!(context.thread_id(), "thread_42");
        assert!(context.messages().is_empty());
        assert!(context.tool_name().is_none());
        assert!(context.tool_result().is_none());
        assert!(context.response().is_none());
        assert!(context.resume_input().is_none());

        assert_eq!(context.snapshot()["key"], "val");
    }

    #[tokio::test]
    async fn read_only_context_with_tool_info() {
        let policy = RunPolicy::new();
        let state = DocCell::new(json!({}));
        let tool_args = json!({"x": 1});

        let context = ReadOnlyContext::new(Phase::BeforeToolExecute, "t1", &[], &policy, &state)
            .with_tool_info("my_tool", "call_1", Some(&tool_args));

        assert_eq!(context.tool_name(), Some("my_tool"));
        assert_eq!(context.tool_call_id(), Some("call_1"));
        let args = context.tool_args().unwrap();
        assert_eq!(args["x"], 1);
    }
}