Skip to main content

tirea_contract/runtime/run/
context.rs

1use crate::runtime::activity::ActivityManager;
2use crate::runtime::run::delta::RunDelta;
3use crate::runtime::run::RunIdentity;
4use crate::runtime::state::SerializedStateAction;
5use crate::runtime::suspended_calls_from_state;
6use crate::runtime::tool_call::ToolCallContext;
7use crate::runtime::tool_call::{CallerContext, SuspendedCall};
8use crate::thread::Message;
9use crate::RunPolicy;
10use serde_json::Value;
11use std::sync::{Arc, Mutex};
12use tirea_state::{
13    apply_patches_with_registry, get_at_path, parse_path, DeltaTracked, DocCell, LatticeRegistry,
14    Op, State, TireaResult, TrackedPatch,
15};
16
17/// Run-scoped workspace that holds mutable state for a single agent run.
18///
19/// `RunContext` is constructed from a `Thread`'s persisted data at the start
20/// of a run and accumulates messages, patches, and overlay ops as the run
21/// progresses. It owns the `DocCell` (live document) and provides delta
22/// extraction via `take_delta()`.
23///
24/// It does **not** hold the `Thread` itself — only the data needed for
25/// execution. Run identity lives in `RunIdentity`; this workspace only owns
26/// live execution data.
27pub struct RunContext {
28    thread_base: Value,
29    messages: DeltaTracked<Arc<Message>>,
30    thread_patches: DeltaTracked<TrackedPatch>,
31    serialized_state_actions: DeltaTracked<SerializedStateAction>,
32    run_policy: RunPolicy,
33    run_identity: RunIdentity,
34    doc: DocCell,
35    version: Option<u64>,
36    version_timestamp: Option<u64>,
37    lattice_registry: Arc<LatticeRegistry>,
38}
39
40impl RunContext {
41    /// Build a run workspace from thread data.
42    ///
43    /// - `thread_id`: thread identifier (owned)
44    /// - `state`: already-rebuilt state (base + patches)
45    /// - `messages`: initial messages (cursor set to end — no delta)
46    /// - `run_policy`: typed per-run run policy
47    pub fn new(
48        thread_id: impl Into<String>,
49        state: Value,
50        messages: Vec<Arc<Message>>,
51        run_policy: RunPolicy,
52    ) -> Self {
53        let thread_id = thread_id.into();
54        Self::with_registry_and_identity(
55            state,
56            messages,
57            run_policy,
58            RunIdentity::for_thread(thread_id),
59            Arc::new(LatticeRegistry::new()),
60        )
61    }
62
63    /// Build a run workspace with a pre-populated lattice registry.
64    pub fn with_registry(
65        thread_id: impl Into<String>,
66        state: Value,
67        messages: Vec<Arc<Message>>,
68        run_policy: RunPolicy,
69        lattice_registry: Arc<LatticeRegistry>,
70    ) -> Self {
71        let thread_id = thread_id.into();
72        Self::with_registry_and_identity(
73            state,
74            messages,
75            run_policy,
76            RunIdentity::for_thread(thread_id),
77            lattice_registry,
78        )
79    }
80
81    pub fn with_registry_and_identity(
82        state: Value,
83        messages: Vec<Arc<Message>>,
84        run_policy: RunPolicy,
85        run_identity: RunIdentity,
86        lattice_registry: Arc<LatticeRegistry>,
87    ) -> Self {
88        let doc = DocCell::new(state.clone());
89        Self {
90            thread_base: state,
91            messages: DeltaTracked::new(messages),
92            thread_patches: DeltaTracked::empty(),
93            serialized_state_actions: DeltaTracked::empty(),
94            run_policy,
95            run_identity,
96            doc,
97            version: None,
98            version_timestamp: None,
99            lattice_registry,
100        }
101    }
102
103    // =========================================================================
104    // Identity
105    // =========================================================================
106
107    /// Thread identifier.
108    pub fn thread_id(&self) -> &str {
109        &self.run_identity.thread_id
110    }
111
112    pub fn run_policy(&self) -> &RunPolicy {
113        &self.run_policy
114    }
115
116    pub fn run_identity(&self) -> &RunIdentity {
117        &self.run_identity
118    }
119
120    pub fn set_run_identity(&mut self, run_identity: RunIdentity) {
121        self.run_identity = run_identity;
122    }
123
124    // =========================================================================
125    // Version
126    // =========================================================================
127
128    /// Current committed version (0 if never committed).
129    pub fn version(&self) -> u64 {
130        self.version.unwrap_or(0)
131    }
132
133    /// Update version after a successful state commit.
134    pub fn set_version(&mut self, version: u64, timestamp: Option<u64>) {
135        self.version = Some(version);
136        if let Some(ts) = timestamp {
137            self.version_timestamp = Some(ts);
138        }
139    }
140
141    /// Timestamp of the last committed version.
142    pub fn version_timestamp(&self) -> Option<u64> {
143        self.version_timestamp
144    }
145
146    // =========================================================================
147    // Suspended calls
148    // =========================================================================
149
150    /// Read all suspended calls from durable control state.
151    pub fn suspended_calls(&self) -> std::collections::HashMap<String, SuspendedCall> {
152        self.snapshot()
153            .map(|s| suspended_calls_from_state(&s))
154            .unwrap_or_default()
155    }
156
157    // =========================================================================
158    // Messages
159    // =========================================================================
160
161    /// All messages (initial + accumulated during run).
162    pub fn messages(&self) -> &[Arc<Message>] {
163        self.messages.as_slice()
164    }
165
166    /// Number of messages that existed before this run started.
167    pub fn initial_message_count(&self) -> usize {
168        self.messages.initial_count()
169    }
170
171    /// Add a single message to the run.
172    pub fn add_message(&mut self, msg: Arc<Message>) {
173        self.messages.push(msg);
174    }
175
176    /// Add multiple messages to the run.
177    pub fn add_messages(&mut self, msgs: Vec<Arc<Message>>) {
178        self.messages.extend(msgs);
179    }
180
181    // =========================================================================
182    // State / Patches
183    // =========================================================================
184
185    /// The initial rebuilt state (base + thread patches).
186    pub fn thread_base(&self) -> &Value {
187        &self.thread_base
188    }
189
190    /// Add a tracked patch from this run.
191    pub fn add_thread_patch(&mut self, patch: TrackedPatch) {
192        self.thread_patches.push(patch);
193    }
194
195    /// Add multiple tracked patches from this run.
196    pub fn add_thread_patches(&mut self, patches: Vec<TrackedPatch>) {
197        self.thread_patches.extend(patches);
198    }
199
200    /// All patches accumulated during this run.
201    pub fn thread_patches(&self) -> &[TrackedPatch] {
202        self.thread_patches.as_slice()
203    }
204
205    // =========================================================================
206    // Serialized State Actions (intent log)
207    // =========================================================================
208
209    /// Add serialized state actions captured during tool/phase execution.
210    pub fn add_serialized_state_actions(&mut self, state_actions: Vec<SerializedStateAction>) {
211        self.serialized_state_actions.extend(state_actions);
212    }
213
214    // =========================================================================
215    // Doc (live document)
216    // =========================================================================
217
218    /// Rebuild the current run-visible state (thread_base + thread_patches).
219    ///
220    /// This is a pure computation that returns a new `Value` without
221    /// touching the `DocCell`.
222    pub fn snapshot(&self) -> TireaResult<Value> {
223        let patches = self.thread_patches.as_slice();
224        if patches.is_empty() {
225            Ok(self.thread_base.clone())
226        } else {
227            apply_patches_with_registry(
228                &self.thread_base,
229                patches.iter().map(|p| p.patch()),
230                &self.lattice_registry,
231            )
232        }
233    }
234
235    /// Typed snapshot at the type's canonical path.
236    ///
237    /// Rebuilds state and deserializes the value at `T::PATH`.
238    pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
239        let val = self.snapshot()?;
240        let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
241        T::from_value(at)
242    }
243
244    /// Typed snapshot at an explicit path.
245    ///
246    /// Rebuilds state and deserializes the value at the given path.
247    pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
248        let val = self.snapshot()?;
249        let at = get_at_path(&val, &parse_path(path)).unwrap_or(&Value::Null);
250        T::from_value(at)
251    }
252
253    // =========================================================================
254    // Delta output
255    // =========================================================================
256
257    /// Extract the incremental delta (new messages + patches + serialized state actions) since
258    /// the last `take_delta()` call.
259    pub fn take_delta(&mut self) -> RunDelta {
260        RunDelta {
261            messages: self.messages.take_delta(),
262            patches: self.thread_patches.take_delta(),
263            state_actions: self.serialized_state_actions.take_delta(),
264        }
265    }
266
267    /// Whether there are un-consumed messages, patches, or serialized state actions.
268    pub fn has_delta(&self) -> bool {
269        self.messages.has_delta()
270            || self.thread_patches.has_delta()
271            || self.serialized_state_actions.has_delta()
272    }
273
274    // =========================================================================
275    // ToolCallContext derivation
276    // =========================================================================
277
278    /// Create a `ToolCallContext` scoped to a specific tool call.
279    pub fn tool_call_context<'ctx>(
280        &'ctx self,
281        ops: &'ctx Mutex<Vec<Op>>,
282        call_id: impl Into<String>,
283        source: impl Into<String>,
284        pending_messages: &'ctx Mutex<Vec<Arc<Message>>>,
285        activity_manager: Arc<dyn ActivityManager>,
286    ) -> ToolCallContext<'ctx> {
287        let caller_context = CallerContext::new(
288            Some(self.thread_id().to_string()),
289            self.run_identity.run_id_opt().map(ToOwned::to_owned),
290            self.run_identity.agent_id_opt().map(ToOwned::to_owned),
291            self.messages().to_vec(),
292        );
293        ToolCallContext::new(
294            &self.doc,
295            ops,
296            call_id,
297            source,
298            &self.run_policy,
299            pending_messages,
300            activity_manager,
301        )
302        .with_run_identity(self.run_identity.clone())
303        .with_caller_context(caller_context)
304    }
305}
306
307impl RunContext {
308    /// Convenience constructor from a `Thread`.
309    ///
310    /// Rebuilds state from the thread's base state + patches, then wraps
311    /// the thread's messages and the given `run_policy` into a `RunContext`.
312    /// Version metadata is carried over from thread metadata.
313    pub fn from_thread(
314        thread: &crate::thread::Thread,
315        run_policy: RunPolicy,
316    ) -> Result<Self, tirea_state::TireaError> {
317        Self::from_thread_with_registry_and_identity(
318            thread,
319            run_policy,
320            RunIdentity::for_thread(thread.id.clone()),
321            Arc::new(LatticeRegistry::new()),
322        )
323    }
324
325    /// Convenience constructor from a `Thread` with a lattice registry.
326    pub fn from_thread_with_registry(
327        thread: &crate::thread::Thread,
328        run_policy: RunPolicy,
329        lattice_registry: Arc<LatticeRegistry>,
330    ) -> Result<Self, tirea_state::TireaError> {
331        Self::from_thread_with_registry_and_identity(
332            thread,
333            run_policy,
334            RunIdentity::for_thread(thread.id.clone()),
335            lattice_registry,
336        )
337    }
338
339    pub fn from_thread_with_registry_and_identity(
340        thread: &crate::thread::Thread,
341        run_policy: RunPolicy,
342        mut run_identity: RunIdentity,
343        lattice_registry: Arc<LatticeRegistry>,
344    ) -> Result<Self, tirea_state::TireaError> {
345        if run_identity.thread_id_opt().is_none() {
346            run_identity.thread_id = thread.id.clone();
347        }
348        if run_identity.parent_thread_id_opt().is_none() {
349            run_identity.parent_thread_id = thread.parent_thread_id.clone();
350        }
351        let state = thread.rebuild_state()?;
352        let messages: Vec<Arc<Message>> = thread.messages.clone();
353        let mut ctx = Self::with_registry_and_identity(
354            state,
355            messages,
356            run_policy,
357            run_identity,
358            lattice_registry,
359        );
360        if let Some(v) = thread.metadata.version {
361            ctx.set_version(v, thread.metadata.version_timestamp);
362        }
363        Ok(ctx)
364    }
365
366    /// The lattice registry used by this context for CRDT-aware operations.
367    pub fn lattice_registry(&self) -> &Arc<LatticeRegistry> {
368        &self.lattice_registry
369    }
370}
371
372impl std::fmt::Debug for RunContext {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        f.debug_struct("RunContext")
375            .field("thread_id", &self.thread_id())
376            .field("messages", &self.messages.len())
377            .field("thread_patches", &self.thread_patches.len())
378            .field("has_delta", &self.has_delta())
379            .finish()
380    }
381}
382
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::{path, Patch};

    #[test]
    fn new_context_has_no_delta() {
        let msgs = vec![Arc::new(Message::user("hi"))];
        let mut ctx = RunContext::new("t-1", json!({}), msgs, RunPolicy::default());
        assert!(!ctx.has_delta());
        let delta = ctx.take_delta();
        assert!(delta.is_empty());
        assert_eq!(ctx.messages().len(), 1);
    }

    #[test]
    fn add_message_creates_delta() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.add_message(Arc::new(Message::user("hello")));
        ctx.add_message(Arc::new(Message::assistant("hi")));
        assert!(ctx.has_delta());
        let delta = ctx.take_delta();
        assert_eq!(delta.messages.len(), 2);
        assert!(delta.patches.is_empty());
        assert!(!ctx.has_delta());
        assert_eq!(ctx.messages().len(), 2);
    }

    #[test]
    fn add_patch_creates_delta() {
        let mut ctx = RunContext::new("t-1", json!({"a": 1}), vec![], RunPolicy::default());
        let patch = TrackedPatch::new(Patch::new().with_op(Op::set(path!("a"), json!(2))));
        ctx.add_thread_patch(patch);
        assert!(ctx.has_delta());
        let delta = ctx.take_delta();
        assert_eq!(delta.patches.len(), 1);
        assert!(!ctx.has_delta());
    }

    #[test]
    fn multiple_deltas() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.add_message(Arc::new(Message::user("a")));
        let d1 = ctx.take_delta();
        assert_eq!(d1.messages.len(), 1);

        ctx.add_message(Arc::new(Message::user("b")));
        ctx.add_message(Arc::new(Message::user("c")));
        let d2 = ctx.take_delta();
        assert_eq!(d2.messages.len(), 2);

        let d3 = ctx.take_delta();
        assert!(d3.is_empty());
    }

    // =========================================================================
    // Category 1: Delta extraction incremental semantics
    // =========================================================================

    /// Initial messages passed to `new()` are NOT part of the delta.
    /// Only run-added messages appear in `take_delta()`.
    #[test]
    fn initial_messages_excluded_from_delta() {
        let initial = vec![
            Arc::new(Message::user("pre-existing-1")),
            Arc::new(Message::assistant("pre-existing-2")),
        ];
        let mut ctx = RunContext::new("t-1", json!({}), initial, RunPolicy::default());

        // No delta despite having 2 messages
        assert!(!ctx.has_delta());
        let delta = ctx.take_delta();
        assert!(delta.messages.is_empty());
        assert_eq!(ctx.messages().len(), 2);

        // Now add a run message — only that one appears
        ctx.add_message(Arc::new(Message::user("run-added")));
        let delta = ctx.take_delta();
        assert_eq!(delta.messages.len(), 1);
        assert_eq!(delta.messages[0].content, "run-added");
        // Total messages still include initial
        assert_eq!(ctx.messages().len(), 3);
    }

    /// `initial_message_count()` reflects only the messages supplied at
    /// construction, and `add_messages` produces a delta.
    #[test]
    fn initial_message_count_ignores_run_added_messages() {
        let initial = vec![Arc::new(Message::user("a")), Arc::new(Message::user("b"))];
        let mut ctx = RunContext::new("t-42", json!({}), initial, RunPolicy::default());
        assert_eq!(ctx.thread_id(), "t-42");
        assert_eq!(ctx.initial_message_count(), 2);

        ctx.add_messages(vec![Arc::new(Message::user("c"))]);
        assert!(ctx.has_delta());
        assert_eq!(ctx.initial_message_count(), 2);
        assert_eq!(ctx.messages().len(), 3);
    }

    /// All patches are delta (cursor starts at 0) — every patch added during
    /// a run is considered new.
    #[test]
    fn all_patches_are_delta() {
        let mut ctx = RunContext::new("t-1", json!({"a": 0}), vec![], RunPolicy::default());
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("a"), json!(1))),
        ));
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("a"), json!(2))),
        ));
        let delta = ctx.take_delta();
        assert_eq!(delta.patches.len(), 2, "all run patches should be in delta");
    }

    /// Multiple take_delta calls produce non-overlapping results.
    #[test]
    fn consecutive_take_delta_non_overlapping() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());

        // Round 1: 1 message + 1 patch
        ctx.add_message(Arc::new(Message::user("m1")));
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("x"), json!(1))),
        ));
        let d1 = ctx.take_delta();
        assert_eq!(d1.messages.len(), 1);
        assert_eq!(d1.patches.len(), 1);

        // Round 2: 2 messages + 1 patch (no overlap with d1)
        ctx.add_message(Arc::new(Message::user("m2")));
        ctx.add_message(Arc::new(Message::user("m3")));
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("y"), json!(2))),
        ));
        let d2 = ctx.take_delta();
        assert_eq!(d2.messages.len(), 2);
        assert_eq!(d2.patches.len(), 1);

        // Round 3: nothing added
        let d3 = ctx.take_delta();
        assert!(d3.is_empty());

        // Total accumulated
        assert_eq!(ctx.messages().len(), 3);
        assert_eq!(ctx.thread_patches().len(), 2);
    }

    // =========================================================================
    // Snapshot reflects run patches
    // =========================================================================

    /// `snapshot()` applies run patches on top of the base, while
    /// `thread_base()` stays at the starting state.
    #[test]
    fn snapshot_applies_run_patches_without_mutating_base() {
        let mut ctx = RunContext::new("t-1", json!({"a": 1}), vec![], RunPolicy::default());
        ctx.add_thread_patch(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("a"), json!(2))),
        ));
        assert_eq!(ctx.snapshot().unwrap()["a"], 2);
        assert_eq!(ctx.thread_base()["a"], 1);

        // Taking the delta does not drop accumulated patches from the snapshot.
        let _ = ctx.take_delta();
        assert_eq!(ctx.snapshot().unwrap()["a"], 2);
    }

    // =========================================================================
    // Category 6: Typed snapshot (snapshot_of / snapshot_at)
    // =========================================================================

    #[test]
    fn snapshot_of_deserializes_at_canonical_path() {
        use crate::testing::TestFixtureState;

        let ctx = RunContext::new(
            "t-1",
            json!({"__test_fixture": {"label": null}}),
            vec![],
            RunPolicy::default(),
        );
        let ctrl: TestFixtureState = ctx.snapshot_of().unwrap();
        assert!(ctrl.label.is_none());
    }

    #[test]
    fn snapshot_at_deserializes_at_explicit_path() {
        use crate::testing::TestFixtureState;

        let ctx = RunContext::new(
            "t-1",
            json!({"custom": {"label": null}}),
            vec![],
            RunPolicy::default(),
        );
        let ctrl: TestFixtureState = ctx.snapshot_at("custom").unwrap();
        assert!(ctrl.label.is_none());
    }

    #[test]
    fn snapshot_of_returns_error_for_missing_path() {
        use crate::testing::TestFixtureState;

        let ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        assert!(ctx.snapshot_of::<TestFixtureState>().is_err());
    }

    // =========================================================================
    // Category 5: from_thread boundary conditions
    // =========================================================================

    #[test]
    fn from_thread_rebuilds_existing_patches() {
        use crate::thread::Thread;

        let mut thread = Thread::with_initial_state("t-1", json!({"counter": 0}));
        thread.patches.push(TrackedPatch::new(
            Patch::new().with_op(Op::set(path!("counter"), json!(5))),
        ));

        let ctx = RunContext::from_thread(&thread, RunPolicy::default()).unwrap();
        // thread_base is pre-rebuilt (includes thread patches)
        assert_eq!(ctx.thread_base()["counter"], 5);
        // No run patches yet
        assert!(ctx.thread_patches().is_empty());
        // snapshot() is consistent with thread_base()
        assert_eq!(ctx.snapshot().unwrap()["counter"], 5);
    }

    #[test]
    fn from_thread_carries_version_metadata() {
        use crate::thread::Thread;

        let mut thread = Thread::new("t-1");
        thread.metadata.version = Some(42);
        thread.metadata.version_timestamp = Some(1700000000);

        let ctx = RunContext::from_thread(&thread, RunPolicy::default()).unwrap();
        assert_eq!(ctx.version(), 42);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
    }

    /// Messages persisted on the thread are initial messages, not run delta.
    #[test]
    fn from_thread_messages_are_not_delta() {
        use crate::thread::Thread;

        let mut thread = Thread::new("t-1");
        thread.messages.push(Arc::new(Message::user("persisted")));

        let ctx = RunContext::from_thread(&thread, RunPolicy::default()).unwrap();
        assert_eq!(ctx.thread_id(), "t-1");
        assert_eq!(ctx.messages().len(), 1);
        assert!(!ctx.has_delta());
    }

    #[test]
    fn from_thread_broken_patch_returns_error() {
        use crate::thread::Thread;

        let mut thread = Thread::with_initial_state("t-1", json!({"x": 1}));
        // Append to a non-array path — this will fail during rebuild_state
        thread.patches.push(TrackedPatch::new(Patch::with_ops(vec![
            tirea_state::Op::Append {
                path: path!("x"),
                value: json!(999),
            },
        ])));

        let result = RunContext::from_thread(&thread, RunPolicy::default());
        assert!(
            result.is_err(),
            "broken patch should cause from_thread to fail"
        );
    }

    // =========================================================================
    // Version tracking
    // =========================================================================

    #[test]
    fn version_defaults_to_zero() {
        let ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        assert_eq!(ctx.version(), 0);
        assert_eq!(ctx.version_timestamp(), None);
    }

    #[test]
    fn set_version_updates_correctly() {
        let mut ctx = RunContext::new("t-1", json!({}), vec![], RunPolicy::default());
        ctx.set_version(5, Some(1700000000));
        assert_eq!(ctx.version(), 5);
        assert_eq!(ctx.version_timestamp(), Some(1700000000));

        // Update again
        ctx.set_version(6, None);
        assert_eq!(ctx.version(), 6);
        // Timestamp unchanged when None passed
        assert_eq!(ctx.version_timestamp(), Some(1700000000));
    }
}