1use crate::runtime::activity::ActivityManager;
2use crate::runtime::run::delta::RunDelta;
3use crate::runtime::run::RunIdentity;
4use crate::runtime::state::SerializedStateAction;
5use crate::runtime::suspended_calls_from_state;
6use crate::runtime::tool_call::ToolCallContext;
7use crate::runtime::tool_call::{CallerContext, SuspendedCall};
8use crate::thread::Message;
9use crate::RunPolicy;
10use serde_json::Value;
11use std::sync::{Arc, Mutex};
12use tirea_state::{
13 apply_patches_with_registry, get_at_path, parse_path, DeltaTracked, DocCell, LatticeRegistry,
14 Op, State, TireaResult, TrackedPatch,
15};
16
/// Per-run working view of a thread.
///
/// Holds the thread's base state plus everything accumulated during the run
/// (messages, state patches, serialized state actions). Each accumulator is a
/// `DeltaTracked`, so callers can drain "what changed since last time" via
/// `take_delta` while still reading the full accumulated view.
pub struct RunContext {
    /// Thread state the run started from; run patches are applied on top of it
    /// by `snapshot`, never written back into this value.
    thread_base: Value,
    /// Full message history: messages supplied at construction are baseline,
    /// messages added afterwards are reported as delta.
    messages: DeltaTracked<Arc<Message>>,
    /// State patches added during this run.
    thread_patches: DeltaTracked<TrackedPatch>,
    /// Serialized state actions added during this run.
    serialized_state_actions: DeltaTracked<SerializedStateAction>,
    /// Policy handed to tool call contexts.
    run_policy: RunPolicy,
    /// Thread / run / agent identity for this run.
    run_identity: RunIdentity,
    /// Document handed to tool call contexts; seeded from the base state at
    /// construction. NOTE(review): not refreshed when thread patches are added
    /// in this type — confirm that is intended.
    doc: DocCell,
    /// Persisted thread version, `None` until `set_version` is called.
    version: Option<u64>,
    /// Timestamp that accompanied the last version update that supplied one.
    version_timestamp: Option<u64>,
    /// Registry used when applying patches to produce a snapshot.
    lattice_registry: Arc<LatticeRegistry>,
}
39
40impl RunContext {
41 pub fn new(
48 thread_id: impl Into<String>,
49 state: Value,
50 messages: Vec<Arc<Message>>,
51 run_policy: RunPolicy,
52 ) -> Self {
53 let thread_id = thread_id.into();
54 Self::with_registry_and_identity(
55 state,
56 messages,
57 run_policy,
58 RunIdentity::for_thread(thread_id),
59 Arc::new(LatticeRegistry::new()),
60 )
61 }
62
63 pub fn with_registry(
65 thread_id: impl Into<String>,
66 state: Value,
67 messages: Vec<Arc<Message>>,
68 run_policy: RunPolicy,
69 lattice_registry: Arc<LatticeRegistry>,
70 ) -> Self {
71 let thread_id = thread_id.into();
72 Self::with_registry_and_identity(
73 state,
74 messages,
75 run_policy,
76 RunIdentity::for_thread(thread_id),
77 lattice_registry,
78 )
79 }
80
81 pub fn with_registry_and_identity(
82 state: Value,
83 messages: Vec<Arc<Message>>,
84 run_policy: RunPolicy,
85 run_identity: RunIdentity,
86 lattice_registry: Arc<LatticeRegistry>,
87 ) -> Self {
88 let doc = DocCell::new(state.clone());
89 Self {
90 thread_base: state,
91 messages: DeltaTracked::new(messages),
92 thread_patches: DeltaTracked::empty(),
93 serialized_state_actions: DeltaTracked::empty(),
94 run_policy,
95 run_identity,
96 doc,
97 version: None,
98 version_timestamp: None,
99 lattice_registry,
100 }
101 }
102
103 pub fn thread_id(&self) -> &str {
109 &self.run_identity.thread_id
110 }
111
112 pub fn run_policy(&self) -> &RunPolicy {
113 &self.run_policy
114 }
115
116 pub fn run_identity(&self) -> &RunIdentity {
117 &self.run_identity
118 }
119
120 pub fn set_run_identity(&mut self, run_identity: RunIdentity) {
121 self.run_identity = run_identity;
122 }
123
124 pub fn version(&self) -> u64 {
130 self.version.unwrap_or(0)
131 }
132
133 pub fn set_version(&mut self, version: u64, timestamp: Option<u64>) {
135 self.version = Some(version);
136 if let Some(ts) = timestamp {
137 self.version_timestamp = Some(ts);
138 }
139 }
140
141 pub fn version_timestamp(&self) -> Option<u64> {
143 self.version_timestamp
144 }
145
146 pub fn suspended_calls(&self) -> std::collections::HashMap<String, SuspendedCall> {
152 self.snapshot()
153 .map(|s| suspended_calls_from_state(&s))
154 .unwrap_or_default()
155 }
156
157 pub fn messages(&self) -> &[Arc<Message>] {
163 self.messages.as_slice()
164 }
165
166 pub fn initial_message_count(&self) -> usize {
168 self.messages.initial_count()
169 }
170
171 pub fn add_message(&mut self, msg: Arc<Message>) {
173 self.messages.push(msg);
174 }
175
176 pub fn add_messages(&mut self, msgs: Vec<Arc<Message>>) {
178 self.messages.extend(msgs);
179 }
180
181 pub fn thread_base(&self) -> &Value {
187 &self.thread_base
188 }
189
190 pub fn add_thread_patch(&mut self, patch: TrackedPatch) {
192 self.thread_patches.push(patch);
193 }
194
195 pub fn add_thread_patches(&mut self, patches: Vec<TrackedPatch>) {
197 self.thread_patches.extend(patches);
198 }
199
200 pub fn thread_patches(&self) -> &[TrackedPatch] {
202 self.thread_patches.as_slice()
203 }
204
205 pub fn add_serialized_state_actions(&mut self, state_actions: Vec<SerializedStateAction>) {
211 self.serialized_state_actions.extend(state_actions);
212 }
213
214 pub fn snapshot(&self) -> TireaResult<Value> {
223 let patches = self.thread_patches.as_slice();
224 if patches.is_empty() {
225 Ok(self.thread_base.clone())
226 } else {
227 apply_patches_with_registry(
228 &self.thread_base,
229 patches.iter().map(|p| p.patch()),
230 &self.lattice_registry,
231 )
232 }
233 }
234
235 pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
239 let val = self.snapshot()?;
240 let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
241 T::from_value(at)
242 }
243
244 pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
248 let val = self.snapshot()?;
249 let at = get_at_path(&val, &parse_path(path)).unwrap_or(&Value::Null);
250 T::from_value(at)
251 }
252
253 pub fn take_delta(&mut self) -> RunDelta {
260 RunDelta {
261 messages: self.messages.take_delta(),
262 patches: self.thread_patches.take_delta(),
263 state_actions: self.serialized_state_actions.take_delta(),
264 }
265 }
266
267 pub fn has_delta(&self) -> bool {
269 self.messages.has_delta()
270 || self.thread_patches.has_delta()
271 || self.serialized_state_actions.has_delta()
272 }
273
274 pub fn tool_call_context<'ctx>(
280 &'ctx self,
281 ops: &'ctx Mutex<Vec<Op>>,
282 call_id: impl Into<String>,
283 source: impl Into<String>,
284 pending_messages: &'ctx Mutex<Vec<Arc<Message>>>,
285 activity_manager: Arc<dyn ActivityManager>,
286 ) -> ToolCallContext<'ctx> {
287 let caller_context = CallerContext::new(
288 Some(self.thread_id().to_string()),
289 self.run_identity.run_id_opt().map(ToOwned::to_owned),
290 self.run_identity.agent_id_opt().map(ToOwned::to_owned),
291 self.messages().to_vec(),
292 );
293 ToolCallContext::new(
294 &self.doc,
295 ops,
296 call_id,
297 source,
298 &self.run_policy,
299 pending_messages,
300 activity_manager,
301 )
302 .with_run_identity(self.run_identity.clone())
303 .with_caller_context(caller_context)
304 }
305}
306
307impl RunContext {
308 pub fn from_thread(
314 thread: &crate::thread::Thread,
315 run_policy: RunPolicy,
316 ) -> Result<Self, tirea_state::TireaError> {
317 Self::from_thread_with_registry_and_identity(
318 thread,
319 run_policy,
320 RunIdentity::for_thread(thread.id.clone()),
321 Arc::new(LatticeRegistry::new()),
322 )
323 }
324
325 pub fn from_thread_with_registry(
327 thread: &crate::thread::Thread,
328 run_policy: RunPolicy,
329 lattice_registry: Arc<LatticeRegistry>,
330 ) -> Result<Self, tirea_state::TireaError> {
331 Self::from_thread_with_registry_and_identity(
332 thread,
333 run_policy,
334 RunIdentity::for_thread(thread.id.clone()),
335 lattice_registry,
336 )
337 }
338
339 pub fn from_thread_with_registry_and_identity(
340 thread: &crate::thread::Thread,
341 run_policy: RunPolicy,
342 mut run_identity: RunIdentity,
343 lattice_registry: Arc<LatticeRegistry>,
344 ) -> Result<Self, tirea_state::TireaError> {
345 if run_identity.thread_id_opt().is_none() {
346 run_identity.thread_id = thread.id.clone();
347 }
348 if run_identity.parent_thread_id_opt().is_none() {
349 run_identity.parent_thread_id = thread.parent_thread_id.clone();
350 }
351 let state = thread.rebuild_state()?;
352 let messages: Vec<Arc<Message>> = thread.messages.clone();
353 let mut ctx = Self::with_registry_and_identity(
354 state,
355 messages,
356 run_policy,
357 run_identity,
358 lattice_registry,
359 );
360 if let Some(v) = thread.metadata.version {
361 ctx.set_version(v, thread.metadata.version_timestamp);
362 }
363 Ok(ctx)
364 }
365
366 pub fn lattice_registry(&self) -> &Arc<LatticeRegistry> {
368 &self.lattice_registry
369 }
370}
371
372impl std::fmt::Debug for RunContext {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 f.debug_struct("RunContext")
375 .field("thread_id", &self.thread_id())
376 .field("messages", &self.messages.len())
377 .field("thread_patches", &self.thread_patches.len())
378 .field("has_delta", &self.has_delta())
379 .finish()
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Patch};

    /// Fresh context for thread `t-1` with `state`, no messages, and the default policy.
    fn ctx_with_state(state: Value) -> RunContext {
        RunContext::new("t-1", state, vec![], RunPolicy::default())
    }

    /// Wraps a single operation into a tracked patch.
    fn single_op_patch(op: Op) -> TrackedPatch {
        TrackedPatch::new(Patch::new().with_op(op))
    }

    #[test]
    fn new_context_has_no_delta() {
        let seed = vec![Arc::new(Message::user("hi"))];
        let mut ctx = RunContext::new("t-1", json!({}), seed, RunPolicy::default());
        assert!(!ctx.has_delta());
        assert!(ctx.take_delta().is_empty());
        assert_eq!(ctx.messages().len(), 1);
    }

    #[test]
    fn add_message_creates_delta() {
        let mut ctx = ctx_with_state(json!({}));
        ctx.add_message(Arc::new(Message::user("hello")));
        ctx.add_message(Arc::new(Message::assistant("hi")));
        assert!(ctx.has_delta());

        let drained = ctx.take_delta();
        assert_eq!(drained.messages.len(), 2);
        assert!(drained.patches.is_empty());

        // Draining resets the delta but keeps the full history.
        assert!(!ctx.has_delta());
        assert_eq!(ctx.messages().len(), 2);
    }

    #[test]
    fn add_patch_creates_delta() {
        let mut ctx = ctx_with_state(json!({"a": 1}));
        ctx.add_thread_patch(single_op_patch(Op::set(path!("a"), json!(2))));
        assert!(ctx.has_delta());
        assert_eq!(ctx.take_delta().patches.len(), 1);
        assert!(!ctx.has_delta());
    }

    #[test]
    fn multiple_deltas() {
        let mut ctx = ctx_with_state(json!({}));

        ctx.add_message(Arc::new(Message::user("a")));
        assert_eq!(ctx.take_delta().messages.len(), 1);

        ctx.add_message(Arc::new(Message::user("b")));
        ctx.add_message(Arc::new(Message::user("c")));
        assert_eq!(ctx.take_delta().messages.len(), 2);

        assert!(ctx.take_delta().is_empty());
    }

    #[test]
    fn initial_messages_excluded_from_delta() {
        let history = vec![
            Arc::new(Message::user("pre-existing-1")),
            Arc::new(Message::assistant("pre-existing-2")),
        ];
        let mut ctx = RunContext::new("t-1", json!({}), history, RunPolicy::default());

        // Messages supplied at construction are baseline, not delta.
        assert!(!ctx.has_delta());
        let first = ctx.take_delta();
        assert!(first.messages.is_empty());
        assert_eq!(ctx.messages().len(), 2);

        // Only messages added during the run are reported.
        ctx.add_message(Arc::new(Message::user("run-added")));
        let second = ctx.take_delta();
        assert_eq!(second.messages.len(), 1);
        assert_eq!(second.messages[0].content, "run-added");
        assert_eq!(ctx.messages().len(), 3);
    }

    #[test]
    fn all_patches_are_delta() {
        let mut ctx = ctx_with_state(json!({"a": 0}));
        ctx.add_thread_patch(single_op_patch(Op::set(path!("a"), json!(1))));
        ctx.add_thread_patch(single_op_patch(Op::set(path!("a"), json!(2))));
        let drained = ctx.take_delta();
        assert_eq!(drained.patches.len(), 2, "all run patches should be in delta");
    }

    #[test]
    fn consecutive_take_delta_non_overlapping() {
        let mut ctx = ctx_with_state(json!({}));

        // Round 1: one message, one patch.
        ctx.add_message(Arc::new(Message::user("m1")));
        ctx.add_thread_patch(single_op_patch(Op::set(path!("x"), json!(1))));
        let round1 = ctx.take_delta();
        assert_eq!(round1.messages.len(), 1);
        assert_eq!(round1.patches.len(), 1);

        // Round 2: only what was added after round 1.
        ctx.add_message(Arc::new(Message::user("m2")));
        ctx.add_message(Arc::new(Message::user("m3")));
        ctx.add_thread_patch(single_op_patch(Op::set(path!("y"), json!(2))));
        let round2 = ctx.take_delta();
        assert_eq!(round2.messages.len(), 2);
        assert_eq!(round2.patches.len(), 1);

        // Round 3: nothing new.
        assert!(ctx.take_delta().is_empty());

        // Accumulated views still hold everything from the run.
        assert_eq!(ctx.messages().len(), 3);
        assert_eq!(ctx.thread_patches().len(), 2);
    }

    #[test]
    fn snapshot_of_deserializes_at_canonical_path() {
        use crate::testing::TestFixtureState;

        let ctx = ctx_with_state(json!({"__test_fixture": {"label": null}}));
        let fixture: TestFixtureState = ctx.snapshot_of().unwrap();
        assert!(fixture.label.is_none());
    }

    #[test]
    fn snapshot_at_deserializes_at_explicit_path() {
        use crate::testing::TestFixtureState;

        let ctx = ctx_with_state(json!({"custom": {"label": null}}));
        let fixture: TestFixtureState = ctx.snapshot_at("custom").unwrap();
        assert!(fixture.label.is_none());
    }

    #[test]
    fn snapshot_of_returns_error_for_missing_path() {
        use crate::testing::TestFixtureState;

        let ctx = ctx_with_state(json!({}));
        assert!(ctx.snapshot_of::<TestFixtureState>().is_err());
    }

    #[test]
    fn from_thread_rebuilds_existing_patches() {
        use crate::thread::Thread;

        let mut stored = Thread::with_initial_state("t-1", json!({"counter": 0}));
        stored
            .patches
            .push(single_op_patch(Op::set(path!("counter"), json!(5))));

        let ctx = RunContext::from_thread(&stored, RunPolicy::default()).unwrap();
        // Pre-existing thread patches are folded into the base, not replayed as run patches.
        assert_eq!(ctx.thread_base()["counter"], 5);
        assert!(ctx.thread_patches().is_empty());
        assert_eq!(ctx.snapshot().unwrap()["counter"], 5);
    }

    #[test]
    fn from_thread_carries_version_metadata() {
        use crate::thread::Thread;

        let mut stored = Thread::new("t-1");
        stored.metadata.version = Some(42);
        stored.metadata.version_timestamp = Some(1700000000);

        let ctx = RunContext::from_thread(&stored, RunPolicy::default()).unwrap();
        assert_eq!(ctx.version(), 42);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
    }

    #[test]
    fn from_thread_broken_patch_returns_error() {
        use crate::thread::Thread;

        let mut stored = Thread::with_initial_state("t-1", json!({"x": 1}));
        // Appending to a number is invalid, so rebuilding the state must fail.
        let bad_append = Op::Append {
            path: path!("x"),
            value: json!(999),
        };
        stored
            .patches
            .push(TrackedPatch::new(Patch::with_ops(vec![bad_append])));

        let outcome = RunContext::from_thread(&stored, RunPolicy::default());
        assert!(
            outcome.is_err(),
            "broken patch should cause from_thread to fail"
        );
    }

    #[test]
    fn version_defaults_to_zero() {
        let ctx = ctx_with_state(json!({}));
        assert_eq!(ctx.version(), 0);
        assert_eq!(ctx.version_timestamp(), None);
    }

    #[test]
    fn set_version_updates_correctly() {
        let mut ctx = ctx_with_state(json!({}));
        ctx.set_version(5, Some(1700000000));
        assert_eq!(ctx.version(), 5);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));

        // A `None` timestamp keeps the previously recorded one.
        ctx.set_version(6, None);
        assert_eq!(ctx.version(), 6);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
    }
}
635}