Skip to main content

tirea_contract/runtime/tool_call/
context.rs

1//! Execution context types for tools and plugins.
2//!
3//! `ToolCallContext` provides state access, run policy, and identity for tool execution.
4//! It replaces direct `&Thread` usage in tool signatures, keeping the persistent
5//! entity (`Thread`) invisible to tools and plugins.
6
7use crate::runtime::activity::ActivityManager;
8use crate::runtime::run::RunIdentity;
9use crate::runtime::{ToolCallResume, ToolCallState};
10use crate::thread::Message;
11use crate::RunPolicy;
12use futures::future::pending;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::sync::{Arc, Mutex};
16use std::time::{SystemTime, UNIX_EPOCH};
17use tirea_state::{
18    get_at_path, parse_path, DocCell, Op, Patch, PatchSink, Path, State, TireaError, TireaResult,
19    TrackedPatch,
20};
21use tokio_util::sync::CancellationToken;
22
23type PatchHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
/// Prefix prepended to a call id to form its progress stream / node id.
const TOOL_PROGRESS_STREAM_PREFIX: &str = "tool_call:";
/// Activity type under which tool-call progress updates are published.
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";
/// Older public name for [`TOOL_CALL_PROGRESS_ACTIVITY_TYPE`], retained for
/// backward compatibility.
pub const TOOL_PROGRESS_ACTIVITY_TYPE: &str = TOOL_CALL_PROGRESS_ACTIVITY_TYPE;
/// Legacy activity type that consumers still accept.
pub const TOOL_PROGRESS_ACTIVITY_TYPE_LEGACY: &str = "progress";
/// Value of the payload `type` field on canonical tool-call progress events.
pub const TOOL_CALL_PROGRESS_TYPE: &str = "tool-call-progress";
/// Value of the payload `schema` field on canonical tool-call progress events.
pub const TOOL_CALL_PROGRESS_SCHEMA: &str = "tool-call-progress.v1";
35
36/// Status marker for a tool-call progress node.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum ToolCallProgressStatus {
40    Pending,
41    #[default]
42    Running,
43    Done,
44    Failed,
45    Cancelled,
46}
47
48/// Canonical tree-node payload for tool-call progress updates.
49#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
50pub struct ToolCallProgressState {
51    /// Payload type identifier.
52    #[serde(rename = "type")]
53    pub event_type: String,
54    /// Payload schema version.
55    pub schema: String,
56    /// Stable node id.
57    pub node_id: String,
58    /// Optional parent node id in the progress tree.
59    #[serde(default)]
60    pub parent_node_id: Option<String>,
61    /// Optional parent tool call id when this node belongs to a nested run.
62    #[serde(default)]
63    pub parent_call_id: Option<String>,
64    /// Tool call id that owns this node.
65    pub call_id: String,
66    /// Optional tool name.
67    #[serde(default)]
68    pub tool_name: Option<String>,
69    /// Current status.
70    pub status: ToolCallProgressStatus,
71    /// Normalized progress ratio when available.
72    #[serde(default)]
73    pub progress: Option<f64>,
74    /// Optional absolute loaded counter.
75    #[serde(default)]
76    pub loaded: Option<f64>,
77    /// Optional absolute total counter.
78    #[serde(default)]
79    pub total: Option<f64>,
80    /// Optional human-readable message.
81    #[serde(default)]
82    pub message: Option<String>,
83    /// Current run id.
84    #[serde(default)]
85    pub run_id: Option<String>,
86    /// Parent run id.
87    #[serde(default)]
88    pub parent_run_id: Option<String>,
89    /// Current thread id when available.
90    #[serde(default)]
91    pub thread_id: Option<String>,
92    /// Last update timestamp in unix milliseconds.
93    pub updated_at_ms: u64,
94}
95
96/// Input shape for publishing tool-call progress updates.
97#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct ToolCallProgressUpdate {
99    #[serde(default)]
100    pub status: ToolCallProgressStatus,
101    #[serde(default, skip_serializing_if = "Option::is_none")]
102    pub progress: Option<f64>,
103    #[serde(default, skip_serializing_if = "Option::is_none")]
104    pub loaded: Option<f64>,
105    #[serde(default, skip_serializing_if = "Option::is_none")]
106    pub total: Option<f64>,
107    #[serde(default, skip_serializing_if = "Option::is_none")]
108    pub message: Option<String>,
109}
110
111/// Canonical activity state shape for tool progress updates.
112#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
113pub struct ToolProgressState {
114    /// Normalized progress value.
115    pub progress: f64,
116    /// Optional absolute total if the source has one.
117    #[serde(default, skip_serializing_if = "Option::is_none")]
118    pub total: Option<f64>,
119    /// Optional human-readable progress message.
120    #[serde(default, skip_serializing_if = "Option::is_none")]
121    pub message: Option<String>,
122}
123
124/// Sink interface for tool-call progress events.
125///
126/// Tools report progress through [`ToolCallContext::report_tool_call_progress`], and
127/// the context forwards canonical payloads into this sink. The sink implementation
128/// decides how payloads are emitted/transported.
129pub trait ToolCallProgressSink: Send + Sync {
130    /// Consume a canonical tool-call progress payload.
131    fn report(
132        &self,
133        stream_id: &str,
134        activity_type: &str,
135        payload: &ToolCallProgressState,
136    ) -> TireaResult<()>;
137}
138
139#[derive(Clone)]
140struct ActivityManagerProgressSink {
141    manager: Arc<dyn ActivityManager>,
142}
143
144impl ActivityManagerProgressSink {
145    fn new(manager: Arc<dyn ActivityManager>) -> Self {
146        Self { manager }
147    }
148}
149
150/// Typed caller metadata exposed to tool executions.
151#[derive(Clone, Debug, Default)]
152pub struct CallerContext {
153    thread_id: Option<String>,
154    run_id: Option<String>,
155    agent_id: Option<String>,
156    messages: Arc<[Arc<Message>]>,
157}
158
159impl CallerContext {
160    pub fn new(
161        thread_id: Option<String>,
162        run_id: Option<String>,
163        agent_id: Option<String>,
164        messages: Vec<Arc<Message>>,
165    ) -> Self {
166        Self {
167            thread_id: thread_id
168                .map(|value| value.trim().to_string())
169                .filter(|value| !value.is_empty()),
170            run_id: run_id
171                .map(|value| value.trim().to_string())
172                .filter(|value| !value.is_empty()),
173            agent_id: agent_id
174                .map(|value| value.trim().to_string())
175                .filter(|value| !value.is_empty()),
176            messages: Arc::<[Arc<Message>]>::from(messages),
177        }
178    }
179
180    pub fn thread_id(&self) -> Option<&str> {
181        self.thread_id.as_deref()
182    }
183
184    pub fn run_id(&self) -> Option<&str> {
185        self.run_id.as_deref()
186    }
187
188    pub fn agent_id(&self) -> Option<&str> {
189        self.agent_id.as_deref()
190    }
191
192    pub fn messages(&self) -> &[Arc<Message>] {
193        self.messages.as_ref()
194    }
195}
196
197impl ToolCallProgressSink for ActivityManagerProgressSink {
198    fn report(
199        &self,
200        stream_id: &str,
201        activity_type: &str,
202        payload: &ToolCallProgressState,
203    ) -> TireaResult<()> {
204        let Value::Object(fields) = serde_json::to_value(payload)? else {
205            return Err(TireaError::invalid_operation(
206                "tool-call-progress payload must serialize as object",
207            ));
208        };
209        for (key, value) in fields {
210            let op = Op::set(Path::root().key(key), value);
211            self.manager.on_activity_op(stream_id, activity_type, &op);
212        }
213        Ok(())
214    }
215}
216
217/// Execution context for tool invocations.
218///
219/// Provides typed state access (read/write), run policy access, identity,
220/// message queuing, and activity tracking. Tools receive `&ToolCallContext`
221/// instead of `&Thread`.
222pub struct ToolCallContext<'a> {
223    doc: &'a DocCell,
224    ops: &'a Mutex<Vec<Op>>,
225    call_id: String,
226    source: String,
227    run_policy: &'a RunPolicy,
228    run_identity: RunIdentity,
229    caller_context: CallerContext,
230    pending_messages: &'a Mutex<Vec<Arc<Message>>>,
231    activity_manager: Arc<dyn ActivityManager>,
232    tool_call_progress_sink: Arc<dyn ToolCallProgressSink>,
233    cancellation_token: Option<&'a CancellationToken>,
234}
235
236impl<'a> ToolCallContext<'a> {
237    fn tool_call_state_path(call_id: &str) -> Path {
238        Path::root()
239            .key("__tool_call_scope")
240            .key(call_id)
241            .key("tool_call_state")
242    }
243
244    fn apply_op(&self, op: Op) -> TireaResult<()> {
245        self.doc.apply(&op)?;
246        self.ops.lock().unwrap().push(op);
247        Ok(())
248    }
249
250    /// Create a new tool call context.
251    pub fn new(
252        doc: &'a DocCell,
253        ops: &'a Mutex<Vec<Op>>,
254        call_id: impl Into<String>,
255        source: impl Into<String>,
256        run_policy: &'a RunPolicy,
257        pending_messages: &'a Mutex<Vec<Arc<Message>>>,
258        activity_manager: Arc<dyn ActivityManager>,
259    ) -> Self {
260        let tool_call_progress_sink: Arc<dyn ToolCallProgressSink> =
261            Arc::new(ActivityManagerProgressSink::new(activity_manager.clone()));
262        Self {
263            doc,
264            ops,
265            call_id: call_id.into(),
266            source: source.into(),
267            run_policy,
268            run_identity: RunIdentity::default(),
269            caller_context: CallerContext::default(),
270            pending_messages,
271            activity_manager,
272            tool_call_progress_sink,
273            cancellation_token: None,
274        }
275    }
276
277    /// Attach cancellation token.
278    #[must_use]
279    pub fn with_cancellation_token(mut self, token: &'a CancellationToken) -> Self {
280        self.cancellation_token = Some(token);
281        self
282    }
283
284    #[must_use]
285    pub fn with_run_identity(mut self, run_identity: RunIdentity) -> Self {
286        self.run_identity = run_identity;
287        self
288    }
289
290    #[must_use]
291    pub fn with_caller_context(mut self, caller_context: CallerContext) -> Self {
292        self.caller_context = caller_context;
293        self
294    }
295
296    /// Override the sink used for tool-call progress payload forwarding.
297    ///
298    /// This allows runtime integrations to decouple progress collection from
299    /// activity transport details.
300    #[must_use]
301    pub fn with_tool_call_progress_sink(mut self, sink: Arc<dyn ToolCallProgressSink>) -> Self {
302        self.tool_call_progress_sink = sink;
303        self
304    }
305
306    // =========================================================================
307    // Identity
308    // =========================================================================
309
310    /// Borrow the underlying document cell.
311    pub fn doc(&self) -> &DocCell {
312        self.doc
313    }
314
315    /// Current call id (typically the `tool_call_id`).
316    pub fn call_id(&self) -> &str {
317        &self.call_id
318    }
319
320    /// Stable idempotency key for the current tool invocation.
321    ///
322    /// Tools should use this value when implementing idempotent side effects.
323    pub fn idempotency_key(&self) -> &str {
324        self.call_id()
325    }
326
327    /// Source identifier used for tracked patches.
328    pub fn source(&self) -> &str {
329        &self.source
330    }
331
332    /// Whether the run cancellation token has already been cancelled.
333    pub fn is_cancelled(&self) -> bool {
334        self.cancellation_token
335            .is_some_and(CancellationToken::is_cancelled)
336    }
337
338    /// Await cancellation for this context.
339    ///
340    /// If no cancellation token is available, this future never resolves.
341    pub async fn cancelled(&self) {
342        if let Some(token) = self.cancellation_token {
343            token.cancelled().await;
344        } else {
345            pending::<()>().await;
346        }
347    }
348
349    /// Borrow the cancellation token when present.
350    pub fn cancellation_token(&self) -> Option<&CancellationToken> {
351        self.cancellation_token
352    }
353
354    // =========================================================================
355    // Run policy / identity
356    // =========================================================================
357
358    /// Borrow the run policy.
359    pub fn run_policy(&self) -> &RunPolicy {
360        self.run_policy
361    }
362
363    pub fn run_identity(&self) -> &RunIdentity {
364        &self.run_identity
365    }
366
367    pub fn caller_context(&self) -> &CallerContext {
368        &self.caller_context
369    }
370
371    // =========================================================================
372    // State access
373    // =========================================================================
374
375    /// Typed state reference at path.
376    pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
377        let base = parse_path(path);
378        let doc = self.doc;
379        let hook: PatchHook<'_> = Arc::new(|op: &Op| {
380            doc.apply(op)?;
381            Ok(())
382        });
383        T::state_ref(doc, base, PatchSink::new_with_hook(self.ops, hook))
384    }
385
386    /// Typed state reference at the type's canonical path.
387    ///
388    /// Panics if `T::PATH` is empty (no bound path via `#[tirea(path = "...")]`).
389    pub fn state_of<T: State>(&self) -> T::Ref<'_> {
390        assert!(
391            !T::PATH.is_empty(),
392            "State type has no bound path; use state::<T>(path) instead"
393        );
394        self.state::<T>(T::PATH)
395    }
396
397    /// Typed state reference for current call (`tool_calls.<call_id>`).
398    pub fn call_state<T: State>(&self) -> T::Ref<'_> {
399        let path = format!("tool_calls.{}", self.call_id);
400        self.state::<T>(&path)
401    }
402
403    /// Read persisted runtime state for a specific tool call.
404    pub fn tool_call_state_for(&self, call_id: &str) -> TireaResult<Option<ToolCallState>> {
405        if call_id.trim().is_empty() {
406            return Ok(None);
407        }
408        let val = self.doc.snapshot();
409        let path = Self::tool_call_state_path(call_id);
410        let at = get_at_path(&val, &path);
411        match at {
412            Some(v) if !v.is_null() => {
413                let state = ToolCallState::from_value(v)?;
414                Ok(Some(state))
415            }
416            _ => Ok(None),
417        }
418    }
419
420    /// Read persisted runtime state for current `call_id`.
421    pub fn tool_call_state(&self) -> TireaResult<Option<ToolCallState>> {
422        self.tool_call_state_for(self.call_id())
423    }
424
425    /// Upsert persisted runtime state for a specific tool call.
426    pub fn set_tool_call_state_for(&self, call_id: &str, state: ToolCallState) -> TireaResult<()> {
427        if call_id.trim().is_empty() {
428            return Err(TireaError::invalid_operation(
429                "tool_call_state requires non-empty call_id",
430            ));
431        }
432        let value = serde_json::to_value(state)?;
433        self.apply_op(Op::set(Self::tool_call_state_path(call_id), value))
434    }
435
436    /// Upsert persisted runtime state for current `call_id`.
437    pub fn set_tool_call_state(&self, state: ToolCallState) -> TireaResult<()> {
438        self.set_tool_call_state_for(self.call_id(), state)
439    }
440
441    /// Remove persisted runtime state for a specific tool call.
442    pub fn clear_tool_call_state_for(&self, call_id: &str) -> TireaResult<()> {
443        if call_id.trim().is_empty() {
444            return Ok(());
445        }
446        if self.tool_call_state_for(call_id)?.is_some() {
447            self.apply_op(Op::delete(Self::tool_call_state_path(call_id)))?;
448        }
449        Ok(())
450    }
451
452    /// Remove persisted runtime state for current `call_id`.
453    pub fn clear_tool_call_state(&self) -> TireaResult<()> {
454        self.clear_tool_call_state_for(self.call_id())
455    }
456
457    /// Read resume payload for a specific tool call.
458    pub fn resume_input_for(&self, call_id: &str) -> TireaResult<Option<ToolCallResume>> {
459        Ok(self
460            .tool_call_state_for(call_id)?
461            .and_then(|state| state.resume))
462    }
463
464    /// Read resume payload for current `call_id`.
465    pub fn resume_input(&self) -> TireaResult<Option<ToolCallResume>> {
466        self.resume_input_for(self.call_id())
467    }
468
469    // =========================================================================
470    // Messages
471    // =========================================================================
472
473    /// Queue a message addition in this operation.
474    pub fn add_message(&self, message: Message) {
475        self.pending_messages
476            .lock()
477            .unwrap()
478            .push(Arc::new(message));
479    }
480
481    /// Queue multiple messages in this operation.
482    pub fn add_messages(&self, messages: impl IntoIterator<Item = Message>) {
483        self.pending_messages
484            .lock()
485            .unwrap()
486            .extend(messages.into_iter().map(Arc::new));
487    }
488
489    // =========================================================================
490    // Activity
491    // =========================================================================
492
493    /// Create an activity context for a stream/type pair.
494    pub fn activity(
495        &self,
496        stream_id: impl Into<String>,
497        activity_type: impl Into<String>,
498    ) -> ActivityContext {
499        let stream_id = stream_id.into();
500        let activity_type = activity_type.into();
501        let snapshot = self.activity_manager.snapshot(&stream_id);
502
503        ActivityContext::new(
504            snapshot,
505            stream_id,
506            activity_type,
507            self.activity_manager.clone(),
508        )
509    }
510
511    /// Stable stream id used by default for this tool call's progress activity.
512    pub fn progress_stream_id(&self) -> String {
513        format!("{TOOL_PROGRESS_STREAM_PREFIX}{}", self.call_id)
514    }
515
516    fn source_tool_name(&self) -> Option<String> {
517        self.source
518            .strip_prefix("tool:")
519            .filter(|name| !name.trim().is_empty())
520            .map(ToOwned::to_owned)
521    }
522
523    fn validate_progress_value(name: &str, value: Option<f64>) -> TireaResult<()> {
524        let Some(value) = value else {
525            return Ok(());
526        };
527        if !value.is_finite() {
528            return Err(TireaError::invalid_operation(format!(
529                "{name} must be a finite number"
530            )));
531        }
532        if value < 0.0 {
533            return Err(TireaError::invalid_operation(format!(
534                "{name} must be non-negative"
535            )));
536        }
537        Ok(())
538    }
539
540    /// Publish a typed tool-call progress node update.
541    ///
542    /// The update is written to `activity(progress_stream_id(), "tool-call-progress")`
543    /// with payload schema `tool-call-progress.v1`.
544    pub fn report_tool_call_progress(&self, update: ToolCallProgressUpdate) -> TireaResult<()> {
545        Self::validate_progress_value("progress value", update.progress)?;
546        Self::validate_progress_value("progress loaded", update.loaded)?;
547        Self::validate_progress_value("progress total", update.total)?;
548
549        let run_id = self.run_identity.run_id_opt().map(ToOwned::to_owned);
550        let parent_run_id = self.run_identity.parent_run_id_opt().map(ToOwned::to_owned);
551        let thread_id = self.caller_context.thread_id().map(ToOwned::to_owned);
552        let parent_call_id = self.run_identity.parent_tool_call_id_opt().and_then(|id| {
553            if id == self.call_id {
554                None
555            } else {
556                Some(id.to_string())
557            }
558        });
559        let parent_node_id = parent_call_id
560            .as_ref()
561            .map(|id| format!("{TOOL_PROGRESS_STREAM_PREFIX}{id}"))
562            .or_else(|| run_id.as_ref().map(|id| format!("run:{id}")));
563        let stream_id = self.progress_stream_id();
564        let payload = ToolCallProgressState {
565            event_type: TOOL_CALL_PROGRESS_TYPE.to_string(),
566            schema: TOOL_CALL_PROGRESS_SCHEMA.to_string(),
567            node_id: stream_id.clone(),
568            parent_node_id,
569            parent_call_id,
570            call_id: self.call_id.clone(),
571            tool_name: self.source_tool_name(),
572            status: update.status,
573            progress: update.progress,
574            loaded: update.loaded,
575            total: update.total,
576            message: update.message,
577            run_id,
578            parent_run_id,
579            thread_id,
580            updated_at_ms: current_unix_millis(),
581        };
582
583        self.tool_call_progress_sink
584            .report(&stream_id, TOOL_CALL_PROGRESS_ACTIVITY_TYPE, &payload)
585    }
586
587    // =========================================================================
588    // State snapshot
589    // =========================================================================
590
591    /// Snapshot the current document state.
592    ///
593    /// Returns the current state including all write-through updates.
594    /// Equivalent to `Thread::rebuild_state()` in transient contexts.
595    pub fn snapshot(&self) -> Value {
596        self.doc.snapshot()
597    }
598
599    /// Typed snapshot at the type's canonical path.
600    ///
601    /// Reads current doc state and deserializes the value at `T::PATH`.
602    pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
603        let val = self.doc.snapshot();
604        let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
605        T::from_value(at)
606    }
607
608    /// Typed snapshot at an explicit path.
609    ///
610    /// Reads current doc state and deserializes the value at the given path.
611    pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
612        let val = self.doc.snapshot();
613        let at = get_at_path(&val, &parse_path(path)).unwrap_or(&Value::Null);
614        T::from_value(at)
615    }
616
617    // =========================================================================
618    // Patch extraction
619    // =========================================================================
620
621    /// Extract accumulated patch with context source metadata.
622    pub fn take_patch(&self) -> TrackedPatch {
623        let ops = std::mem::take(&mut *self.ops.lock().unwrap());
624        TrackedPatch::new(Patch::with_ops(ops)).with_source(self.source.clone())
625    }
626
627    /// Whether state has pending transient changes.
628    pub fn has_changes(&self) -> bool {
629        !self.ops.lock().unwrap().is_empty()
630    }
631
632    /// Number of queued transient operations.
633    pub fn ops_count(&self) -> usize {
634        self.ops.lock().unwrap().len()
635    }
636}
637
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch, and
/// saturates at `u64::MAX` if the millisecond count does not fit in a `u64`.
fn current_unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
643
644/// Activity-scoped state context.
645pub struct ActivityContext {
646    doc: DocCell,
647    stream_id: String,
648    activity_type: String,
649    ops: Mutex<Vec<Op>>,
650    manager: Arc<dyn ActivityManager>,
651}
652
653impl ActivityContext {
654    pub(crate) fn new(
655        doc: Value,
656        stream_id: String,
657        activity_type: String,
658        manager: Arc<dyn ActivityManager>,
659    ) -> Self {
660        Self {
661            doc: DocCell::new(doc),
662            stream_id,
663            activity_type,
664            ops: Mutex::new(Vec::new()),
665            manager,
666        }
667    }
668
669    /// Typed activity state reference at the type's canonical path.
670    ///
671    /// Panics if `T::PATH` is empty.
672    pub fn state_of<T: State>(&self) -> T::Ref<'_> {
673        assert!(
674            !T::PATH.is_empty(),
675            "State type has no bound path; use state::<T>(path) instead"
676        );
677        self.state::<T>(T::PATH)
678    }
679
680    /// Get a typed activity state reference at the specified path.
681    ///
682    /// All modifications are automatically collected and immediately reported
683    /// to the activity manager. Writes are applied to the shared doc for
684    /// immediate read-back.
685    pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
686        let base = parse_path(path);
687        let manager = self.manager.clone();
688        let stream_id = self.stream_id.clone();
689        let activity_type = self.activity_type.clone();
690        let doc = &self.doc;
691        let hook: PatchHook<'_> = Arc::new(move |op: &Op| {
692            doc.apply(op)?;
693            manager.on_activity_op(&stream_id, &activity_type, op);
694            Ok(())
695        });
696        T::state_ref(&self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
697    }
698}
699
700#[cfg(test)]
701mod tests {
702    use super::*;
703    use crate::io::ResumeDecisionAction;
704    use crate::runtime::activity::{ActivityManager, NoOpActivityManager};
705    use crate::testing::TestFixtureState;
706    use serde_json::json;
707    use std::sync::Arc;
708    use tirea_state::apply_patch;
709    use tokio::time::{timeout, Duration};
710    use tokio_util::sync::CancellationToken;
711
712    fn make_ctx<'a>(
713        doc: &'a DocCell,
714        ops: &'a Mutex<Vec<Op>>,
715        run_policy: &'a RunPolicy,
716        pending: &'a Mutex<Vec<Arc<Message>>>,
717    ) -> ToolCallContext<'a> {
718        ToolCallContext::new(
719            doc,
720            ops,
721            "call-1",
722            "test",
723            run_policy,
724            pending,
725            NoOpActivityManager::arc(),
726        )
727    }
728
729    fn run_identity(run_id: &str) -> RunIdentity {
730        RunIdentity::new(
731            "thread-child".to_string(),
732            None,
733            run_id.to_string(),
734            None,
735            "agent".to_string(),
736            crate::storage::RunOrigin::Internal,
737        )
738    }
739
740    fn caller_context(thread_id: &str) -> CallerContext {
741        CallerContext::new(
742            Some(thread_id.to_string()),
743            Some("run-parent".to_string()),
744            Some("caller".to_string()),
745            vec![Arc::new(Message::user("seed"))],
746        )
747    }
748
749    #[test]
750    fn test_identity() {
751        let doc = DocCell::new(json!({}));
752        let ops = Mutex::new(Vec::new());
753        let scope = RunPolicy::default();
754        let pending = Mutex::new(Vec::new());
755
756        let ctx = make_ctx(&doc, &ops, &scope, &pending);
757        assert_eq!(ctx.call_id(), "call-1");
758        assert_eq!(ctx.idempotency_key(), "call-1");
759        assert_eq!(ctx.source(), "test");
760    }
761
762    #[test]
763    fn test_typed_context_access() {
764        let doc = DocCell::new(json!({}));
765        let ops = Mutex::new(Vec::new());
766        let scope = RunPolicy::new();
767        let pending = Mutex::new(Vec::new());
768
769        let ctx = make_ctx(&doc, &ops, &scope, &pending)
770            .with_run_identity(run_identity("run-1").with_parent_tool_call_id("call-parent"))
771            .with_caller_context(caller_context("thread-1"));
772
773        assert_eq!(
774            ctx.run_identity().parent_tool_call_id_opt(),
775            Some("call-parent")
776        );
777        assert_eq!(ctx.run_identity().run_id_opt(), Some("run-1"));
778        assert_eq!(ctx.caller_context().thread_id(), Some("thread-1"));
779        assert_eq!(ctx.caller_context().agent_id(), Some("caller"));
780        assert_eq!(ctx.caller_context().messages().len(), 1);
781    }
782
783    #[test]
784    fn test_state_of_read_write() {
785        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
786        let ops = Mutex::new(Vec::new());
787        let scope = RunPolicy::default();
788        let pending = Mutex::new(Vec::new());
789
790        let ctx = make_ctx(&doc, &ops, &scope, &pending);
791
792        // Write
793        let ctrl = ctx.state_of::<TestFixtureState>();
794        ctrl.set_label(Some("rate_limit".into()))
795            .expect("failed to set label");
796
797        // Read back from same ref
798        let val = ctrl.label().unwrap();
799        assert!(val.is_some());
800        assert_eq!(val.unwrap(), "rate_limit");
801
802        // Ops captured in thread ops
803        assert!(!ops.lock().unwrap().is_empty());
804    }
805
806    #[test]
807    fn test_write_through_read_cross_ref() {
808        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
809        let ops = Mutex::new(Vec::new());
810        let scope = RunPolicy::default();
811        let pending = Mutex::new(Vec::new());
812
813        let ctx = make_ctx(&doc, &ops, &scope, &pending);
814
815        // Write via first ref
816        ctx.state_of::<TestFixtureState>()
817            .set_label(Some("timeout".into()))
818            .expect("failed to set label");
819
820        // Read via second ref
821        let val = ctx.state_of::<TestFixtureState>().label().unwrap();
822        assert_eq!(val.unwrap(), "timeout");
823    }
824
825    #[test]
826    fn test_take_patch() {
827        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
828        let ops = Mutex::new(Vec::new());
829        let scope = RunPolicy::default();
830        let pending = Mutex::new(Vec::new());
831
832        let ctx = make_ctx(&doc, &ops, &scope, &pending);
833
834        ctx.state_of::<TestFixtureState>()
835            .set_label(Some("test".into()))
836            .expect("failed to set label");
837
838        assert!(ctx.has_changes());
839        assert!(ctx.ops_count() > 0);
840
841        let patch = ctx.take_patch();
842        assert!(!patch.patch().is_empty());
843        assert_eq!(patch.source.as_deref(), Some("test"));
844        assert!(!ctx.has_changes());
845        assert_eq!(ctx.ops_count(), 0);
846    }
847
848    #[test]
849    fn test_add_messages() {
850        let doc = DocCell::new(json!({}));
851        let ops = Mutex::new(Vec::new());
852        let scope = RunPolicy::default();
853        let pending = Mutex::new(Vec::new());
854
855        let ctx = make_ctx(&doc, &ops, &scope, &pending);
856
857        ctx.add_message(Message::user("hello"));
858        ctx.add_messages(vec![Message::assistant("hi"), Message::user("bye")]);
859
860        assert_eq!(pending.lock().unwrap().len(), 3);
861    }
862
863    #[test]
864    fn test_call_state() {
865        let doc = DocCell::new(json!({"tool_calls": {}}));
866        let ops = Mutex::new(Vec::new());
867        let scope = RunPolicy::default();
868        let pending = Mutex::new(Vec::new());
869
870        let ctx = make_ctx(&doc, &ops, &scope, &pending);
871
872        let ctrl = ctx.call_state::<TestFixtureState>();
873        ctrl.set_label(Some("call_scoped".into()))
874            .expect("failed to set label");
875
876        assert!(ctx.has_changes());
877    }
878
879    #[test]
880    fn test_tool_call_state_roundtrip_and_resume_input() {
881        let doc = DocCell::new(json!({}));
882        let ops = Mutex::new(Vec::new());
883        let scope = RunPolicy::default();
884        let pending = Mutex::new(Vec::new());
885        let ctx = make_ctx(&doc, &ops, &scope, &pending);
886
887        let state = ToolCallState {
888            call_id: "call.1".to_string(),
889            tool_name: "confirm".to_string(),
890            arguments: json!({"value": 1}),
891            status: crate::runtime::ToolCallStatus::Resuming,
892            resume_token: Some("resume.1".to_string()),
893            resume: Some(crate::runtime::ToolCallResume {
894                decision_id: "decision_1".to_string(),
895                action: ResumeDecisionAction::Resume,
896                result: json!({"approved": true}),
897                reason: None,
898                updated_at: 123,
899            }),
900            scratch: json!({"k": "v"}),
901            updated_at: 124,
902        };
903
904        ctx.set_tool_call_state_for("call.1", state.clone())
905            .expect("state should be persisted");
906
907        let loaded = ctx
908            .tool_call_state_for("call.1")
909            .expect("state read should succeed");
910        assert_eq!(loaded, Some(state.clone()));
911
912        let resume = ctx
913            .resume_input_for("call.1")
914            .expect("resume read should succeed");
915        assert_eq!(resume, state.resume);
916    }
917
918    #[test]
919    fn test_clear_tool_call_state_for_removes_entry() {
920        let doc = DocCell::new(json!({}));
921        let ops = Mutex::new(Vec::new());
922        let scope = RunPolicy::default();
923        let pending = Mutex::new(Vec::new());
924        let ctx = make_ctx(&doc, &ops, &scope, &pending);
925
926        ctx.set_tool_call_state_for(
927            "call-1",
928            ToolCallState {
929                call_id: "call-1".to_string(),
930                tool_name: "echo".to_string(),
931                arguments: json!({"x": 1}),
932                status: crate::runtime::ToolCallStatus::Running,
933                resume_token: None,
934                resume: None,
935                scratch: Value::Null,
936                updated_at: 1,
937            },
938        )
939        .expect("state should be set");
940
941        ctx.clear_tool_call_state_for("call-1")
942            .expect("clear should succeed");
943        assert_eq!(
944            ctx.tool_call_state_for("call-1")
945                .expect("state read should succeed"),
946            None
947        );
948    }
949
950    #[test]
951    fn test_cancellation_token_absent_by_default() {
952        let doc = DocCell::new(json!({}));
953        let ops = Mutex::new(Vec::new());
954        let scope = RunPolicy::default();
955        let pending = Mutex::new(Vec::new());
956        let ctx = make_ctx(&doc, &ops, &scope, &pending);
957
958        assert!(!ctx.is_cancelled());
959        assert!(ctx.cancellation_token().is_none());
960    }
961
962    #[tokio::test]
963    async fn test_cancelled_waits_for_attached_token() {
964        let doc = DocCell::new(json!({}));
965        let ops = Mutex::new(Vec::new());
966        let scope = RunPolicy::default();
967        let pending = Mutex::new(Vec::new());
968        let token = CancellationToken::new();
969
970        let ctx = ToolCallContext::new(
971            &doc,
972            &ops,
973            "call-1",
974            "test",
975            &scope,
976            &pending,
977            NoOpActivityManager::arc(),
978        )
979        .with_cancellation_token(&token);
980
981        let token_for_task = token.clone();
982        tokio::spawn(async move {
983            tokio::time::sleep(Duration::from_millis(20)).await;
984            token_for_task.cancel();
985        });
986
987        timeout(Duration::from_millis(300), ctx.cancelled())
988            .await
989            .expect("cancelled() should resolve after token cancellation");
990    }
991
992    #[tokio::test]
993    async fn test_cancelled_without_token_never_resolves() {
994        let doc = DocCell::new(json!({}));
995        let ops = Mutex::new(Vec::new());
996        let scope = RunPolicy::default();
997        let pending = Mutex::new(Vec::new());
998        let ctx = make_ctx(&doc, &ops, &scope, &pending);
999
1000        let timed_out = timeout(Duration::from_millis(30), ctx.cancelled())
1001            .await
1002            .is_err();
1003        assert!(timed_out, "cancelled() without token should remain pending");
1004    }
1005
1006    #[derive(Default)]
1007    struct RecordingActivityManager {
1008        events: Mutex<Vec<(String, String, Op)>>,
1009    }
1010
1011    impl ActivityManager for RecordingActivityManager {
1012        fn snapshot(&self, _stream_id: &str) -> Value {
1013            json!({})
1014        }
1015
1016        fn on_activity_op(&self, stream_id: &str, activity_type: &str, op: &Op) {
1017            self.events.lock().unwrap().push((
1018                stream_id.to_string(),
1019                activity_type.to_string(),
1020                op.clone(),
1021            ));
1022        }
1023    }
1024
1025    fn rebuild_activity_state(events: &[(String, String, Op)]) -> Value {
1026        let mut value = json!({});
1027        for (_, _, op) in events {
1028            value = apply_patch(&value, &Patch::with_ops(vec![op.clone()]))
1029                .expect("activity op should apply");
1030        }
1031        value
1032    }
1033
1034    #[derive(Default)]
1035    struct RecordingProgressSink {
1036        events: Mutex<Vec<(String, String, ToolCallProgressState)>>,
1037    }
1038
1039    impl ToolCallProgressSink for RecordingProgressSink {
1040        fn report(
1041            &self,
1042            stream_id: &str,
1043            activity_type: &str,
1044            payload: &ToolCallProgressState,
1045        ) -> TireaResult<()> {
1046            self.events.lock().unwrap().push((
1047                stream_id.to_string(),
1048                activity_type.to_string(),
1049                payload.clone(),
1050            ));
1051            Ok(())
1052        }
1053    }
1054
1055    struct FailingProgressSink;
1056
1057    impl ToolCallProgressSink for FailingProgressSink {
1058        fn report(
1059            &self,
1060            _stream_id: &str,
1061            _activity_type: &str,
1062            _payload: &ToolCallProgressState,
1063        ) -> TireaResult<()> {
1064            Err(TireaError::invalid_operation("sink failed"))
1065        }
1066    }
1067
1068    #[test]
1069    fn test_report_tool_call_progress_emits_tool_call_progress_activity() {
1070        let doc = DocCell::new(json!({}));
1071        let ops = Mutex::new(Vec::new());
1072        let scope = RunPolicy::default();
1073        let pending = Mutex::new(Vec::new());
1074        let activity_manager = Arc::new(RecordingActivityManager::default());
1075
1076        let ctx = ToolCallContext::new(
1077            &doc,
1078            &ops,
1079            "call-1",
1080            "test",
1081            &scope,
1082            &pending,
1083            activity_manager.clone(),
1084        );
1085
1086        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1087            status: ToolCallProgressStatus::Running,
1088            progress: Some(0.5),
1089            loaded: None,
1090            total: Some(10.0),
1091            message: Some("half way".to_string()),
1092        })
1093        .expect("progress should be emitted");
1094
1095        let events = activity_manager.events.lock().unwrap();
1096        assert!(!events.is_empty());
1097        assert!(events.iter().all(|(stream_id, activity_type, _)| {
1098            stream_id == "tool_call:call-1" && activity_type == TOOL_CALL_PROGRESS_ACTIVITY_TYPE
1099        }));
1100        let state = rebuild_activity_state(&events);
1101        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
1102        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
1103        assert_eq!(state["node_id"], "tool_call:call-1");
1104        assert_eq!(state["call_id"], "call-1");
1105        assert_eq!(state["status"], "running");
1106        assert_eq!(state["progress"], json!(0.5));
1107        assert_eq!(state["total"], json!(10.0));
1108        assert_eq!(state["message"], json!("half way"));
1109    }
1110
1111    #[test]
1112    fn test_report_tool_call_progress_rejects_non_finite_values() {
1113        let doc = DocCell::new(json!({}));
1114        let ops = Mutex::new(Vec::new());
1115        let scope = RunPolicy::default();
1116        let pending = Mutex::new(Vec::new());
1117        let ctx = make_ctx(&doc, &ops, &scope, &pending);
1118
1119        assert!(ctx
1120            .report_tool_call_progress(ToolCallProgressUpdate {
1121                status: ToolCallProgressStatus::Running,
1122                progress: Some(f64::NAN),
1123                loaded: None,
1124                total: None,
1125                message: None,
1126            })
1127            .is_err());
1128        assert!(ctx
1129            .report_tool_call_progress(ToolCallProgressUpdate {
1130                status: ToolCallProgressStatus::Running,
1131                progress: Some(0.5),
1132                loaded: None,
1133                total: Some(f64::INFINITY),
1134                message: None,
1135            })
1136            .is_err());
1137        assert!(ctx
1138            .report_tool_call_progress(ToolCallProgressUpdate {
1139                status: ToolCallProgressStatus::Running,
1140                progress: Some(0.5),
1141                loaded: Some(-1.0),
1142                total: None,
1143                message: None,
1144            })
1145            .is_err());
1146    }
1147
1148    #[test]
1149    fn test_report_tool_call_progress_writes_lineage_and_metadata() {
1150        let doc = DocCell::new(json!({}));
1151        let ops = Mutex::new(Vec::new());
1152        let scope = RunPolicy::new();
1153        let pending = Mutex::new(Vec::new());
1154        let activity_manager = Arc::new(RecordingActivityManager::default());
1155        let run_identity = RunIdentity::new(
1156            "thread-abc".to_string(),
1157            None,
1158            "run-123".to_string(),
1159            Some("run-parent".to_string()),
1160            "agent".to_string(),
1161            crate::storage::RunOrigin::Internal,
1162        )
1163        .with_parent_tool_call_id("call-parent");
1164        let caller_context = CallerContext::new(
1165            Some("thread-abc".to_string()),
1166            Some("run-parent".to_string()),
1167            Some("caller".to_string()),
1168            vec![],
1169        );
1170
1171        let ctx = ToolCallContext::new(
1172            &doc,
1173            &ops,
1174            "call-1",
1175            "tool:echo",
1176            &scope,
1177            &pending,
1178            activity_manager.clone(),
1179        )
1180        .with_run_identity(run_identity)
1181        .with_caller_context(caller_context);
1182
1183        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1184            status: ToolCallProgressStatus::Done,
1185            progress: Some(1.0),
1186            loaded: Some(5.0),
1187            total: Some(5.0),
1188            message: Some("done".to_string()),
1189        })
1190        .expect("tool call progress should be emitted");
1191
1192        let events = activity_manager.events.lock().unwrap();
1193        let state = rebuild_activity_state(&events);
1194        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
1195        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
1196        assert_eq!(state["node_id"], "tool_call:call-1");
1197        assert_eq!(state["parent_node_id"], "tool_call:call-parent");
1198        assert_eq!(state["parent_call_id"], "call-parent");
1199        assert_eq!(state["tool_name"], "echo");
1200        assert_eq!(state["status"], "done");
1201        assert_eq!(state["run_id"], "run-123");
1202        assert_eq!(state["parent_run_id"], "run-parent");
1203        assert_eq!(state["thread_id"], "thread-abc");
1204        assert!(state["updated_at_ms"].as_u64().unwrap_or_default() > 0);
1205    }
1206
1207    #[test]
1208    fn test_report_tool_call_progress_without_parent_tool_call_anchors_to_run_node() {
1209        let doc = DocCell::new(json!({}));
1210        let ops = Mutex::new(Vec::new());
1211        let scope = RunPolicy::new();
1212        let pending = Mutex::new(Vec::new());
1213        let activity_manager = Arc::new(RecordingActivityManager::default());
1214        let run_identity = run_identity("run-123");
1215        let ctx = ToolCallContext::new(
1216            &doc,
1217            &ops,
1218            "call-1",
1219            "tool:echo",
1220            &scope,
1221            &pending,
1222            activity_manager.clone(),
1223        )
1224        .with_run_identity(run_identity);
1225
1226        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1227            status: ToolCallProgressStatus::Running,
1228            progress: Some(0.3),
1229            loaded: None,
1230            total: None,
1231            message: Some("working".to_string()),
1232        })
1233        .expect("tool call progress should be emitted");
1234
1235        let events = activity_manager.events.lock().unwrap();
1236        let state = rebuild_activity_state(&events);
1237        assert_eq!(state["parent_node_id"], "run:run-123");
1238        assert!(state["parent_call_id"].is_null());
1239    }
1240
1241    #[test]
1242    fn test_report_tool_call_progress_uses_injected_sink_instead_of_activity_manager() {
1243        let doc = DocCell::new(json!({}));
1244        let ops = Mutex::new(Vec::new());
1245        let scope = RunPolicy::default();
1246        let pending = Mutex::new(Vec::new());
1247        let activity_manager = Arc::new(RecordingActivityManager::default());
1248        let sink = Arc::new(RecordingProgressSink::default());
1249        let ctx = ToolCallContext::new(
1250            &doc,
1251            &ops,
1252            "call-1",
1253            "tool:echo",
1254            &scope,
1255            &pending,
1256            activity_manager.clone(),
1257        )
1258        .with_tool_call_progress_sink(sink.clone());
1259
1260        ctx.report_tool_call_progress(ToolCallProgressUpdate {
1261            status: ToolCallProgressStatus::Running,
1262            progress: Some(0.2),
1263            loaded: None,
1264            total: Some(10.0),
1265            message: Some("working".to_string()),
1266        })
1267        .expect("tool call progress should be reported");
1268
1269        let sink_events = sink.events.lock().unwrap();
1270        assert_eq!(sink_events.len(), 1);
1271        let (stream_id, activity_type, payload) = &sink_events[0];
1272        assert_eq!(stream_id, "tool_call:call-1");
1273        assert_eq!(activity_type, TOOL_CALL_PROGRESS_ACTIVITY_TYPE);
1274        assert_eq!(payload.call_id, "call-1");
1275        assert_eq!(payload.progress, Some(0.2));
1276
1277        let activity_events = activity_manager.events.lock().unwrap();
1278        assert!(
1279            activity_events.is_empty(),
1280            "injected sink should bypass default activity manager sink"
1281        );
1282    }
1283
1284    #[test]
1285    fn test_report_tool_call_progress_propagates_sink_error() {
1286        let doc = DocCell::new(json!({}));
1287        let ops = Mutex::new(Vec::new());
1288        let scope = RunPolicy::default();
1289        let pending = Mutex::new(Vec::new());
1290        let ctx = ToolCallContext::new(
1291            &doc,
1292            &ops,
1293            "call-1",
1294            "tool:echo",
1295            &scope,
1296            &pending,
1297            NoOpActivityManager::arc(),
1298        )
1299        .with_tool_call_progress_sink(Arc::new(FailingProgressSink));
1300
1301        let result = ctx.report_tool_call_progress(ToolCallProgressUpdate {
1302            status: ToolCallProgressStatus::Running,
1303            progress: Some(0.1),
1304            loaded: None,
1305            total: None,
1306            message: None,
1307        });
1308        assert!(result.is_err());
1309    }
1310}