1use crate::runtime::activity::ActivityManager;
8use crate::runtime::run::RunIdentity;
9use crate::runtime::{ToolCallResume, ToolCallState};
10use crate::thread::Message;
11use crate::RunPolicy;
12use futures::future::pending;
13use serde::{Deserialize, Serialize};
14use serde_json::Value;
15use std::sync::{Arc, Mutex};
16use std::time::{SystemTime, UNIX_EPOCH};
17use tirea_state::{
18 get_at_path, parse_path, DocCell, Op, Patch, PatchSink, Path, State, TireaError, TireaResult,
19 TrackedPatch,
20};
21use tokio_util::sync::CancellationToken;
22
/// Callback type handed to `PatchSink::new_with_hook`: it receives an [`Op`] by reference and
/// returns a [`TireaResult`], so it may report a failure for that op.
type PatchHook<'a> = Arc<dyn Fn(&Op) -> TireaResult<()> + Send + Sync + 'a>;
/// Prefix joined with a call id to build progress stream / node ids (`tool_call:<call_id>`).
const TOOL_PROGRESS_STREAM_PREFIX: &str = "tool_call:";
/// Activity type passed to the progress sink by `ToolCallContext::report_tool_call_progress`.
pub const TOOL_CALL_PROGRESS_ACTIVITY_TYPE: &str = "tool-call-progress";
/// Alias of [`TOOL_CALL_PROGRESS_ACTIVITY_TYPE`].
pub const TOOL_PROGRESS_ACTIVITY_TYPE: &str = TOOL_CALL_PROGRESS_ACTIVITY_TYPE;
/// NOTE(review): presumably the activity type of an older progress format; it is not
/// referenced by the code visible in this file — confirm against callers.
pub const TOOL_PROGRESS_ACTIVITY_TYPE_LEGACY: &str = "progress";
/// Value written to the `type` field of [`ToolCallProgressState`] payloads.
pub const TOOL_CALL_PROGRESS_TYPE: &str = "tool-call-progress";
/// Value written to the `schema` field of [`ToolCallProgressState`] payloads.
pub const TOOL_CALL_PROGRESS_SCHEMA: &str = "tool-call-progress.v1";
35
/// Status carried by a tool-call progress report.
///
/// Variants are (de)serialized in lowercase (`"pending"`, `"running"`, ...). The default
/// variant is [`ToolCallProgressStatus::Running`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ToolCallProgressStatus {
    Pending,
    // Default: an update that omits `status` deserializes as `Running`.
    #[default]
    Running,
    Done,
    Failed,
    Cancelled,
}
47
/// Full progress payload handed to a [`ToolCallProgressSink`].
///
/// `ToolCallContext::report_tool_call_progress` builds one of these per report, combining the
/// caller-supplied [`ToolCallProgressUpdate`] with identity/lineage data from the context.
/// Optional fields fall back to `None` when absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
pub struct ToolCallProgressState {
    /// Serialized as `type`; set to [`TOOL_CALL_PROGRESS_TYPE`] by the reporter.
    #[serde(rename = "type")]
    pub event_type: String,
    /// Set to [`TOOL_CALL_PROGRESS_SCHEMA`] by the reporter.
    pub schema: String,
    /// Node id of this call; the reporter uses the progress stream id (`tool_call:<call_id>`).
    pub node_id: String,
    /// Parent node: `tool_call:<parent_call_id>` when a parent call is known, otherwise
    /// `run:<run_id>` when a run id is known, otherwise `None`.
    #[serde(default)]
    pub parent_node_id: Option<String>,
    /// Parent tool call id from the run identity; `None` when absent or equal to `call_id`.
    #[serde(default)]
    pub parent_call_id: Option<String>,
    /// Id of the tool call being reported on.
    pub call_id: String,
    /// Tool name, taken from the context source when it has the form `tool:<name>`.
    #[serde(default)]
    pub tool_name: Option<String>,
    /// Status supplied by the update.
    pub status: ToolCallProgressStatus,
    /// Progress value supplied by the update (validated as finite and non-negative).
    #[serde(default)]
    pub progress: Option<f64>,
    /// Loaded amount supplied by the update (validated as finite and non-negative).
    #[serde(default)]
    pub loaded: Option<f64>,
    /// Total amount supplied by the update (validated as finite and non-negative).
    #[serde(default)]
    pub total: Option<f64>,
    /// Free-form message supplied by the update.
    #[serde(default)]
    pub message: Option<String>,
    /// Run id from the context's run identity, if any.
    #[serde(default)]
    pub run_id: Option<String>,
    /// Parent run id from the context's run identity, if any.
    #[serde(default)]
    pub parent_run_id: Option<String>,
    /// Thread id from the context's caller context, if any.
    #[serde(default)]
    pub thread_id: Option<String>,
    /// Wall-clock time of the report in milliseconds since the Unix epoch.
    pub updated_at_ms: u64,
}
95
/// Caller-supplied part of a progress report, consumed by
/// `ToolCallContext::report_tool_call_progress`.
///
/// Every field is optional on input (`status` defaults to `Running`), and `None` values are
/// omitted when the update itself is serialized.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolCallProgressUpdate {
    /// Current status of the call.
    #[serde(default)]
    pub status: ToolCallProgressStatus,
    /// Progress value; must be finite and non-negative to be accepted by the reporter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
    /// Amount processed so far; must be finite and non-negative to be accepted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub loaded: Option<f64>,
    /// Total amount expected; must be finite and non-negative to be accepted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total: Option<f64>,
    /// Optional human-readable message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
110
/// Minimal progress state: a required `progress` value plus optional `total` and `message`.
///
/// NOTE(review): not used by the code visible in this file; presumably the payload of the
/// legacy `"progress"` activity type — confirm against callers.
#[derive(Debug, Clone, Default, Serialize, Deserialize, State)]
pub struct ToolProgressState {
    /// Progress value.
    pub progress: f64,
    /// Optional total; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub total: Option<f64>,
    /// Optional message; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
123
/// Destination for tool-call progress reports.
///
/// `ToolCallContext` calls [`ToolCallProgressSink::report`] once per
/// `report_tool_call_progress` invocation and returns the sink's result to the caller.
pub trait ToolCallProgressSink: Send + Sync {
    /// Delivers one progress payload.
    ///
    /// * `stream_id` - progress stream the payload belongs to (`tool_call:<call_id>`).
    /// * `activity_type` - activity type label for the payload.
    /// * `payload` - the complete progress state to publish.
    ///
    /// # Errors
    /// Implementations return an error when the payload cannot be delivered.
    fn report(
        &self,
        stream_id: &str,
        activity_type: &str,
        payload: &ToolCallProgressState,
    ) -> TireaResult<()>;
}
138
/// Default progress sink: forwards progress payloads to an [`ActivityManager`].
#[derive(Clone)]
struct ActivityManagerProgressSink {
    /// Manager that receives the activity ops produced from each payload.
    manager: Arc<dyn ActivityManager>,
}

impl ActivityManagerProgressSink {
    /// Wraps `manager` in a sink.
    fn new(manager: Arc<dyn ActivityManager>) -> Self {
        Self { manager }
    }
}
149
150#[derive(Clone, Debug, Default)]
152pub struct CallerContext {
153 thread_id: Option<String>,
154 run_id: Option<String>,
155 agent_id: Option<String>,
156 messages: Arc<[Arc<Message>]>,
157}
158
159impl CallerContext {
160 pub fn new(
161 thread_id: Option<String>,
162 run_id: Option<String>,
163 agent_id: Option<String>,
164 messages: Vec<Arc<Message>>,
165 ) -> Self {
166 Self {
167 thread_id: thread_id
168 .map(|value| value.trim().to_string())
169 .filter(|value| !value.is_empty()),
170 run_id: run_id
171 .map(|value| value.trim().to_string())
172 .filter(|value| !value.is_empty()),
173 agent_id: agent_id
174 .map(|value| value.trim().to_string())
175 .filter(|value| !value.is_empty()),
176 messages: Arc::<[Arc<Message>]>::from(messages),
177 }
178 }
179
180 pub fn thread_id(&self) -> Option<&str> {
181 self.thread_id.as_deref()
182 }
183
184 pub fn run_id(&self) -> Option<&str> {
185 self.run_id.as_deref()
186 }
187
188 pub fn agent_id(&self) -> Option<&str> {
189 self.agent_id.as_deref()
190 }
191
192 pub fn messages(&self) -> &[Arc<Message>] {
193 self.messages.as_ref()
194 }
195}
196
197impl ToolCallProgressSink for ActivityManagerProgressSink {
198 fn report(
199 &self,
200 stream_id: &str,
201 activity_type: &str,
202 payload: &ToolCallProgressState,
203 ) -> TireaResult<()> {
204 let Value::Object(fields) = serde_json::to_value(payload)? else {
205 return Err(TireaError::invalid_operation(
206 "tool-call-progress payload must serialize as object",
207 ));
208 };
209 for (key, value) in fields {
210 let op = Op::set(Path::root().key(key), value);
211 self.manager.on_activity_op(stream_id, activity_type, &op);
212 }
213 Ok(())
214 }
215}
216
/// Execution context handed to a single tool call.
///
/// It gives the tool write-through access to the shared document, records every applied op
/// so it can later be collected as a patch, and exposes run/caller identity, message
/// queuing, activity streams, progress reporting and cooperative cancellation.
pub struct ToolCallContext<'a> {
    /// Shared document that state reads and writes go through.
    doc: &'a DocCell,
    /// Ops applied through this context, drained by `take_patch`.
    ops: &'a Mutex<Vec<Op>>,
    /// Id of the tool call this context belongs to.
    call_id: String,
    /// Source label attached to patches; a `tool:<name>` value also yields the tool name.
    source: String,
    /// Run policy made available to the tool.
    run_policy: &'a RunPolicy,
    /// Identity of the current run (run id, parent run id, parent tool call id, ...).
    run_identity: RunIdentity,
    /// Information about the caller (thread id, run id, agent id, messages).
    caller_context: CallerContext,
    /// Messages queued by the tool via `add_message` / `add_messages`.
    pending_messages: &'a Mutex<Vec<Arc<Message>>>,
    /// Manager backing activity streams created with `activity`.
    activity_manager: Arc<dyn ActivityManager>,
    /// Sink receiving progress reports; defaults to one backed by `activity_manager`.
    tool_call_progress_sink: Arc<dyn ToolCallProgressSink>,
    /// Optional token used for cooperative cancellation.
    cancellation_token: Option<&'a CancellationToken>,
}
235
236impl<'a> ToolCallContext<'a> {
237 fn tool_call_state_path(call_id: &str) -> Path {
238 Path::root()
239 .key("__tool_call_scope")
240 .key(call_id)
241 .key("tool_call_state")
242 }
243
244 fn apply_op(&self, op: Op) -> TireaResult<()> {
245 self.doc.apply(&op)?;
246 self.ops.lock().unwrap().push(op);
247 Ok(())
248 }
249
250 pub fn new(
252 doc: &'a DocCell,
253 ops: &'a Mutex<Vec<Op>>,
254 call_id: impl Into<String>,
255 source: impl Into<String>,
256 run_policy: &'a RunPolicy,
257 pending_messages: &'a Mutex<Vec<Arc<Message>>>,
258 activity_manager: Arc<dyn ActivityManager>,
259 ) -> Self {
260 let tool_call_progress_sink: Arc<dyn ToolCallProgressSink> =
261 Arc::new(ActivityManagerProgressSink::new(activity_manager.clone()));
262 Self {
263 doc,
264 ops,
265 call_id: call_id.into(),
266 source: source.into(),
267 run_policy,
268 run_identity: RunIdentity::default(),
269 caller_context: CallerContext::default(),
270 pending_messages,
271 activity_manager,
272 tool_call_progress_sink,
273 cancellation_token: None,
274 }
275 }
276
277 #[must_use]
279 pub fn with_cancellation_token(mut self, token: &'a CancellationToken) -> Self {
280 self.cancellation_token = Some(token);
281 self
282 }
283
284 #[must_use]
285 pub fn with_run_identity(mut self, run_identity: RunIdentity) -> Self {
286 self.run_identity = run_identity;
287 self
288 }
289
290 #[must_use]
291 pub fn with_caller_context(mut self, caller_context: CallerContext) -> Self {
292 self.caller_context = caller_context;
293 self
294 }
295
296 #[must_use]
301 pub fn with_tool_call_progress_sink(mut self, sink: Arc<dyn ToolCallProgressSink>) -> Self {
302 self.tool_call_progress_sink = sink;
303 self
304 }
305
306 pub fn doc(&self) -> &DocCell {
312 self.doc
313 }
314
315 pub fn call_id(&self) -> &str {
317 &self.call_id
318 }
319
320 pub fn idempotency_key(&self) -> &str {
324 self.call_id()
325 }
326
327 pub fn source(&self) -> &str {
329 &self.source
330 }
331
332 pub fn is_cancelled(&self) -> bool {
334 self.cancellation_token
335 .is_some_and(CancellationToken::is_cancelled)
336 }
337
338 pub async fn cancelled(&self) {
342 if let Some(token) = self.cancellation_token {
343 token.cancelled().await;
344 } else {
345 pending::<()>().await;
346 }
347 }
348
349 pub fn cancellation_token(&self) -> Option<&CancellationToken> {
351 self.cancellation_token
352 }
353
354 pub fn run_policy(&self) -> &RunPolicy {
360 self.run_policy
361 }
362
363 pub fn run_identity(&self) -> &RunIdentity {
364 &self.run_identity
365 }
366
367 pub fn caller_context(&self) -> &CallerContext {
368 &self.caller_context
369 }
370
371 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
377 let base = parse_path(path);
378 let doc = self.doc;
379 let hook: PatchHook<'_> = Arc::new(|op: &Op| {
380 doc.apply(op)?;
381 Ok(())
382 });
383 T::state_ref(doc, base, PatchSink::new_with_hook(self.ops, hook))
384 }
385
386 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
390 assert!(
391 !T::PATH.is_empty(),
392 "State type has no bound path; use state::<T>(path) instead"
393 );
394 self.state::<T>(T::PATH)
395 }
396
397 pub fn call_state<T: State>(&self) -> T::Ref<'_> {
399 let path = format!("tool_calls.{}", self.call_id);
400 self.state::<T>(&path)
401 }
402
403 pub fn tool_call_state_for(&self, call_id: &str) -> TireaResult<Option<ToolCallState>> {
405 if call_id.trim().is_empty() {
406 return Ok(None);
407 }
408 let val = self.doc.snapshot();
409 let path = Self::tool_call_state_path(call_id);
410 let at = get_at_path(&val, &path);
411 match at {
412 Some(v) if !v.is_null() => {
413 let state = ToolCallState::from_value(v)?;
414 Ok(Some(state))
415 }
416 _ => Ok(None),
417 }
418 }
419
420 pub fn tool_call_state(&self) -> TireaResult<Option<ToolCallState>> {
422 self.tool_call_state_for(self.call_id())
423 }
424
425 pub fn set_tool_call_state_for(&self, call_id: &str, state: ToolCallState) -> TireaResult<()> {
427 if call_id.trim().is_empty() {
428 return Err(TireaError::invalid_operation(
429 "tool_call_state requires non-empty call_id",
430 ));
431 }
432 let value = serde_json::to_value(state)?;
433 self.apply_op(Op::set(Self::tool_call_state_path(call_id), value))
434 }
435
436 pub fn set_tool_call_state(&self, state: ToolCallState) -> TireaResult<()> {
438 self.set_tool_call_state_for(self.call_id(), state)
439 }
440
441 pub fn clear_tool_call_state_for(&self, call_id: &str) -> TireaResult<()> {
443 if call_id.trim().is_empty() {
444 return Ok(());
445 }
446 if self.tool_call_state_for(call_id)?.is_some() {
447 self.apply_op(Op::delete(Self::tool_call_state_path(call_id)))?;
448 }
449 Ok(())
450 }
451
452 pub fn clear_tool_call_state(&self) -> TireaResult<()> {
454 self.clear_tool_call_state_for(self.call_id())
455 }
456
457 pub fn resume_input_for(&self, call_id: &str) -> TireaResult<Option<ToolCallResume>> {
459 Ok(self
460 .tool_call_state_for(call_id)?
461 .and_then(|state| state.resume))
462 }
463
464 pub fn resume_input(&self) -> TireaResult<Option<ToolCallResume>> {
466 self.resume_input_for(self.call_id())
467 }
468
469 pub fn add_message(&self, message: Message) {
475 self.pending_messages
476 .lock()
477 .unwrap()
478 .push(Arc::new(message));
479 }
480
481 pub fn add_messages(&self, messages: impl IntoIterator<Item = Message>) {
483 self.pending_messages
484 .lock()
485 .unwrap()
486 .extend(messages.into_iter().map(Arc::new));
487 }
488
489 pub fn activity(
495 &self,
496 stream_id: impl Into<String>,
497 activity_type: impl Into<String>,
498 ) -> ActivityContext {
499 let stream_id = stream_id.into();
500 let activity_type = activity_type.into();
501 let snapshot = self.activity_manager.snapshot(&stream_id);
502
503 ActivityContext::new(
504 snapshot,
505 stream_id,
506 activity_type,
507 self.activity_manager.clone(),
508 )
509 }
510
511 pub fn progress_stream_id(&self) -> String {
513 format!("{TOOL_PROGRESS_STREAM_PREFIX}{}", self.call_id)
514 }
515
516 fn source_tool_name(&self) -> Option<String> {
517 self.source
518 .strip_prefix("tool:")
519 .filter(|name| !name.trim().is_empty())
520 .map(ToOwned::to_owned)
521 }
522
523 fn validate_progress_value(name: &str, value: Option<f64>) -> TireaResult<()> {
524 let Some(value) = value else {
525 return Ok(());
526 };
527 if !value.is_finite() {
528 return Err(TireaError::invalid_operation(format!(
529 "{name} must be a finite number"
530 )));
531 }
532 if value < 0.0 {
533 return Err(TireaError::invalid_operation(format!(
534 "{name} must be non-negative"
535 )));
536 }
537 Ok(())
538 }
539
540 pub fn report_tool_call_progress(&self, update: ToolCallProgressUpdate) -> TireaResult<()> {
545 Self::validate_progress_value("progress value", update.progress)?;
546 Self::validate_progress_value("progress loaded", update.loaded)?;
547 Self::validate_progress_value("progress total", update.total)?;
548
549 let run_id = self.run_identity.run_id_opt().map(ToOwned::to_owned);
550 let parent_run_id = self.run_identity.parent_run_id_opt().map(ToOwned::to_owned);
551 let thread_id = self.caller_context.thread_id().map(ToOwned::to_owned);
552 let parent_call_id = self.run_identity.parent_tool_call_id_opt().and_then(|id| {
553 if id == self.call_id {
554 None
555 } else {
556 Some(id.to_string())
557 }
558 });
559 let parent_node_id = parent_call_id
560 .as_ref()
561 .map(|id| format!("{TOOL_PROGRESS_STREAM_PREFIX}{id}"))
562 .or_else(|| run_id.as_ref().map(|id| format!("run:{id}")));
563 let stream_id = self.progress_stream_id();
564 let payload = ToolCallProgressState {
565 event_type: TOOL_CALL_PROGRESS_TYPE.to_string(),
566 schema: TOOL_CALL_PROGRESS_SCHEMA.to_string(),
567 node_id: stream_id.clone(),
568 parent_node_id,
569 parent_call_id,
570 call_id: self.call_id.clone(),
571 tool_name: self.source_tool_name(),
572 status: update.status,
573 progress: update.progress,
574 loaded: update.loaded,
575 total: update.total,
576 message: update.message,
577 run_id,
578 parent_run_id,
579 thread_id,
580 updated_at_ms: current_unix_millis(),
581 };
582
583 self.tool_call_progress_sink
584 .report(&stream_id, TOOL_CALL_PROGRESS_ACTIVITY_TYPE, &payload)
585 }
586
587 pub fn snapshot(&self) -> Value {
596 self.doc.snapshot()
597 }
598
599 pub fn snapshot_of<T: State>(&self) -> TireaResult<T> {
603 let val = self.doc.snapshot();
604 let at = get_at_path(&val, &parse_path(T::PATH)).unwrap_or(&Value::Null);
605 T::from_value(at)
606 }
607
608 pub fn snapshot_at<T: State>(&self, path: &str) -> TireaResult<T> {
612 let val = self.doc.snapshot();
613 let at = get_at_path(&val, &parse_path(path)).unwrap_or(&Value::Null);
614 T::from_value(at)
615 }
616
617 pub fn take_patch(&self) -> TrackedPatch {
623 let ops = std::mem::take(&mut *self.ops.lock().unwrap());
624 TrackedPatch::new(Patch::with_ops(ops)).with_source(self.source.clone())
625 }
626
627 pub fn has_changes(&self) -> bool {
629 !self.ops.lock().unwrap().is_empty()
630 }
631
632 pub fn ops_count(&self) -> usize {
634 self.ops.lock().unwrap().len()
635 }
636}
637
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock is set before the epoch and saturates at `u64::MAX`
/// if the millisecond count does not fit in a `u64`.
fn current_unix_millis() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
643
/// Context for writing typed state to one activity stream.
///
/// It holds its own document cell seeded from a snapshot, and forwards every op written
/// through it to the activity manager under the stream id and activity type it was created
/// with.
pub struct ActivityContext {
    /// Local document the activity state is read from and written to.
    doc: DocCell,
    /// Stream the ops are reported under.
    stream_id: String,
    /// Activity type the ops are reported under.
    activity_type: String,
    /// Op list handed to the patch sink of each state reference.
    ops: Mutex<Vec<Op>>,
    /// Manager notified of each op.
    manager: Arc<dyn ActivityManager>,
}
652
653impl ActivityContext {
654 pub(crate) fn new(
655 doc: Value,
656 stream_id: String,
657 activity_type: String,
658 manager: Arc<dyn ActivityManager>,
659 ) -> Self {
660 Self {
661 doc: DocCell::new(doc),
662 stream_id,
663 activity_type,
664 ops: Mutex::new(Vec::new()),
665 manager,
666 }
667 }
668
669 pub fn state_of<T: State>(&self) -> T::Ref<'_> {
673 assert!(
674 !T::PATH.is_empty(),
675 "State type has no bound path; use state::<T>(path) instead"
676 );
677 self.state::<T>(T::PATH)
678 }
679
680 pub fn state<T: State>(&self, path: &str) -> T::Ref<'_> {
686 let base = parse_path(path);
687 let manager = self.manager.clone();
688 let stream_id = self.stream_id.clone();
689 let activity_type = self.activity_type.clone();
690 let doc = &self.doc;
691 let hook: PatchHook<'_> = Arc::new(move |op: &Op| {
692 doc.apply(op)?;
693 manager.on_activity_op(&stream_id, &activity_type, op);
694 Ok(())
695 });
696 T::state_ref(&self.doc, base, PatchSink::new_with_hook(&self.ops, hook))
697 }
698}
699
// Unit tests for `ToolCallContext`: identity accessors, typed state access, patch
// collection, message queuing, tool-call state persistence, cancellation and progress
// reporting.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::ResumeDecisionAction;
    use crate::runtime::activity::{ActivityManager, NoOpActivityManager};
    use crate::testing::TestFixtureState;
    use serde_json::json;
    use std::sync::Arc;
    use tirea_state::apply_patch;
    use tokio::time::{timeout, Duration};
    use tokio_util::sync::CancellationToken;

    // Builds a context for call id "call-1" with source "test" and a no-op activity manager.
    fn make_ctx<'a>(
        doc: &'a DocCell,
        ops: &'a Mutex<Vec<Op>>,
        run_policy: &'a RunPolicy,
        pending: &'a Mutex<Vec<Arc<Message>>>,
    ) -> ToolCallContext<'a> {
        ToolCallContext::new(
            doc,
            ops,
            "call-1",
            "test",
            run_policy,
            pending,
            NoOpActivityManager::arc(),
        )
    }

    // Builds an internal-origin run identity with the given run id and no parent run.
    fn run_identity(run_id: &str) -> RunIdentity {
        RunIdentity::new(
            "thread-child".to_string(),
            None,
            run_id.to_string(),
            None,
            "agent".to_string(),
            crate::storage::RunOrigin::Internal,
        )
    }

    // Builds a caller context for the given thread with one seed message.
    fn caller_context(thread_id: &str) -> CallerContext {
        CallerContext::new(
            Some(thread_id.to_string()),
            Some("run-parent".to_string()),
            Some("caller".to_string()),
            vec![Arc::new(Message::user("seed"))],
        )
    }

    // The call id, idempotency key and source are exposed as given at construction.
    #[test]
    fn test_identity() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending);
        assert_eq!(ctx.call_id(), "call-1");
        assert_eq!(ctx.idempotency_key(), "call-1");
        assert_eq!(ctx.source(), "test");
    }

    // Run identity and caller context set through the builders are readable afterwards.
    #[test]
    fn test_typed_context_access() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::new();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending)
            .with_run_identity(run_identity("run-1").with_parent_tool_call_id("call-parent"))
            .with_caller_context(caller_context("thread-1"));

        assert_eq!(
            ctx.run_identity().parent_tool_call_id_opt(),
            Some("call-parent")
        );
        assert_eq!(ctx.run_identity().run_id_opt(), Some("run-1"));
        assert_eq!(ctx.caller_context().thread_id(), Some("thread-1"));
        assert_eq!(ctx.caller_context().agent_id(), Some("caller"));
        assert_eq!(ctx.caller_context().messages().len(), 1);
    }

    // A write through a state reference is readable from the same reference and records ops.
    #[test]
    fn test_state_of_read_write() {
        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        let ctrl = ctx.state_of::<TestFixtureState>();
        ctrl.set_label(Some("rate_limit".into()))
            .expect("failed to set label");

        let val = ctrl.label().unwrap();
        assert!(val.is_some());
        assert_eq!(val.unwrap(), "rate_limit");

        assert!(!ops.lock().unwrap().is_empty());
    }

    // A write through one state reference is visible through a second, separate reference.
    #[test]
    fn test_write_through_read_cross_ref() {
        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        ctx.state_of::<TestFixtureState>()
            .set_label(Some("timeout".into()))
            .expect("failed to set label");

        let val = ctx.state_of::<TestFixtureState>().label().unwrap();
        assert_eq!(val.unwrap(), "timeout");
    }

    // `take_patch` returns the recorded ops tagged with the source and empties the op list.
    #[test]
    fn test_take_patch() {
        let doc = DocCell::new(json!({"__test_fixture": {"label": null}}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        ctx.state_of::<TestFixtureState>()
            .set_label(Some("test".into()))
            .expect("failed to set label");

        assert!(ctx.has_changes());
        assert!(ctx.ops_count() > 0);

        let patch = ctx.take_patch();
        assert!(!patch.patch().is_empty());
        assert_eq!(patch.source.as_deref(), Some("test"));
        assert!(!ctx.has_changes());
        assert_eq!(ctx.ops_count(), 0);
    }

    // Single and batched message additions all land in the pending list.
    #[test]
    fn test_add_messages() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        ctx.add_message(Message::user("hello"));
        ctx.add_messages(vec![Message::assistant("hi"), Message::user("bye")]);

        assert_eq!(pending.lock().unwrap().len(), 3);
    }

    // Writing call-scoped state records a change.
    #[test]
    fn test_call_state() {
        let doc = DocCell::new(json!({"tool_calls": {}}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());

        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        let ctrl = ctx.call_state::<TestFixtureState>();
        ctrl.set_label(Some("call_scoped".into()))
            .expect("failed to set label");

        assert!(ctx.has_changes());
    }

    // Tool-call state round-trips for a call id containing a dot, and its resume payload
    // is returned by `resume_input_for`.
    #[test]
    fn test_tool_call_state_roundtrip_and_resume_input() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        let state = ToolCallState {
            call_id: "call.1".to_string(),
            tool_name: "confirm".to_string(),
            arguments: json!({"value": 1}),
            status: crate::runtime::ToolCallStatus::Resuming,
            resume_token: Some("resume.1".to_string()),
            resume: Some(crate::runtime::ToolCallResume {
                decision_id: "decision_1".to_string(),
                action: ResumeDecisionAction::Resume,
                result: json!({"approved": true}),
                reason: None,
                updated_at: 123,
            }),
            scratch: json!({"k": "v"}),
            updated_at: 124,
        };

        ctx.set_tool_call_state_for("call.1", state.clone())
            .expect("state should be persisted");

        let loaded = ctx
            .tool_call_state_for("call.1")
            .expect("state read should succeed");
        assert_eq!(loaded, Some(state.clone()));

        let resume = ctx
            .resume_input_for("call.1")
            .expect("resume read should succeed");
        assert_eq!(resume, state.resume);
    }

    // Clearing a previously set tool-call state makes subsequent reads return `None`.
    #[test]
    fn test_clear_tool_call_state_for_removes_entry() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        ctx.set_tool_call_state_for(
            "call-1",
            ToolCallState {
                call_id: "call-1".to_string(),
                tool_name: "echo".to_string(),
                arguments: json!({"x": 1}),
                status: crate::runtime::ToolCallStatus::Running,
                resume_token: None,
                resume: None,
                scratch: Value::Null,
                updated_at: 1,
            },
        )
        .expect("state should be set");

        ctx.clear_tool_call_state_for("call-1")
            .expect("clear should succeed");
        assert_eq!(
            ctx.tool_call_state_for("call-1")
                .expect("state read should succeed"),
            None
        );
    }

    // Without a token the context is never cancelled and exposes no token.
    #[test]
    fn test_cancellation_token_absent_by_default() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        assert!(!ctx.is_cancelled());
        assert!(ctx.cancellation_token().is_none());
    }

    // `cancelled()` resolves once the attached token is cancelled from another task.
    #[tokio::test]
    async fn test_cancelled_waits_for_attached_token() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let token = CancellationToken::new();

        let ctx = ToolCallContext::new(
            &doc,
            &ops,
            "call-1",
            "test",
            &scope,
            &pending,
            NoOpActivityManager::arc(),
        )
        .with_cancellation_token(&token);

        let token_for_task = token.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            token_for_task.cancel();
        });

        timeout(Duration::from_millis(300), ctx.cancelled())
            .await
            .expect("cancelled() should resolve after token cancellation");
    }

    // Without a token `cancelled()` stays pending, so the surrounding timeout fires.
    #[tokio::test]
    async fn test_cancelled_without_token_never_resolves() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        let timed_out = timeout(Duration::from_millis(30), ctx.cancelled())
            .await
            .is_err();
        assert!(timed_out, "cancelled() without token should remain pending");
    }

    // Activity manager that records every op it receives and always snapshots as `{}`.
    #[derive(Default)]
    struct RecordingActivityManager {
        events: Mutex<Vec<(String, String, Op)>>,
    }

    impl ActivityManager for RecordingActivityManager {
        fn snapshot(&self, _stream_id: &str) -> Value {
            json!({})
        }

        fn on_activity_op(&self, stream_id: &str, activity_type: &str, op: &Op) {
            self.events.lock().unwrap().push((
                stream_id.to_string(),
                activity_type.to_string(),
                op.clone(),
            ));
        }
    }

    // Replays the recorded ops, in order, onto an empty object to rebuild the activity state.
    fn rebuild_activity_state(events: &[(String, String, Op)]) -> Value {
        let mut value = json!({});
        for (_, _, op) in events {
            value = apply_patch(&value, &Patch::with_ops(vec![op.clone()]))
                .expect("activity op should apply");
        }
        value
    }

    // Progress sink that records every payload it receives.
    #[derive(Default)]
    struct RecordingProgressSink {
        events: Mutex<Vec<(String, String, ToolCallProgressState)>>,
    }

    impl ToolCallProgressSink for RecordingProgressSink {
        fn report(
            &self,
            stream_id: &str,
            activity_type: &str,
            payload: &ToolCallProgressState,
        ) -> TireaResult<()> {
            self.events.lock().unwrap().push((
                stream_id.to_string(),
                activity_type.to_string(),
                payload.clone(),
            ));
            Ok(())
        }
    }

    // Progress sink that always fails.
    struct FailingProgressSink;

    impl ToolCallProgressSink for FailingProgressSink {
        fn report(
            &self,
            _stream_id: &str,
            _activity_type: &str,
            _payload: &ToolCallProgressState,
        ) -> TireaResult<()> {
            Err(TireaError::invalid_operation("sink failed"))
        }
    }

    // With the default sink, a report reaches the activity manager on the call's stream
    // and rebuilds into the expected payload fields.
    #[test]
    fn test_report_tool_call_progress_emits_tool_call_progress_activity() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let activity_manager = Arc::new(RecordingActivityManager::default());

        let ctx = ToolCallContext::new(
            &doc,
            &ops,
            "call-1",
            "test",
            &scope,
            &pending,
            activity_manager.clone(),
        );

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Running,
            progress: Some(0.5),
            loaded: None,
            total: Some(10.0),
            message: Some("half way".to_string()),
        })
        .expect("progress should be emitted");

        let events = activity_manager.events.lock().unwrap();
        assert!(!events.is_empty());
        assert!(events.iter().all(|(stream_id, activity_type, _)| {
            stream_id == "tool_call:call-1" && activity_type == TOOL_CALL_PROGRESS_ACTIVITY_TYPE
        }));
        let state = rebuild_activity_state(&events);
        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
        assert_eq!(state["node_id"], "tool_call:call-1");
        assert_eq!(state["call_id"], "call-1");
        assert_eq!(state["status"], "running");
        assert_eq!(state["progress"], json!(0.5));
        assert_eq!(state["total"], json!(10.0));
        assert_eq!(state["message"], json!("half way"));
    }

    // NaN, infinite and negative progress numbers are rejected.
    #[test]
    fn test_report_tool_call_progress_rejects_non_finite_values() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let ctx = make_ctx(&doc, &ops, &scope, &pending);

        assert!(ctx
            .report_tool_call_progress(ToolCallProgressUpdate {
                status: ToolCallProgressStatus::Running,
                progress: Some(f64::NAN),
                loaded: None,
                total: None,
                message: None,
            })
            .is_err());
        assert!(ctx
            .report_tool_call_progress(ToolCallProgressUpdate {
                status: ToolCallProgressStatus::Running,
                progress: Some(0.5),
                loaded: None,
                total: Some(f64::INFINITY),
                message: None,
            })
            .is_err());
        assert!(ctx
            .report_tool_call_progress(ToolCallProgressUpdate {
                status: ToolCallProgressStatus::Running,
                progress: Some(0.5),
                loaded: Some(-1.0),
                total: None,
                message: None,
            })
            .is_err());
    }

    // With a parent tool call, the payload carries parent lineage, tool name, run ids,
    // thread id and a timestamp.
    #[test]
    fn test_report_tool_call_progress_writes_lineage_and_metadata() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::new();
        let pending = Mutex::new(Vec::new());
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let run_identity = RunIdentity::new(
            "thread-abc".to_string(),
            None,
            "run-123".to_string(),
            Some("run-parent".to_string()),
            "agent".to_string(),
            crate::storage::RunOrigin::Internal,
        )
        .with_parent_tool_call_id("call-parent");
        let caller_context = CallerContext::new(
            Some("thread-abc".to_string()),
            Some("run-parent".to_string()),
            Some("caller".to_string()),
            vec![],
        );

        let ctx = ToolCallContext::new(
            &doc,
            &ops,
            "call-1",
            "tool:echo",
            &scope,
            &pending,
            activity_manager.clone(),
        )
        .with_run_identity(run_identity)
        .with_caller_context(caller_context);

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Done,
            progress: Some(1.0),
            loaded: Some(5.0),
            total: Some(5.0),
            message: Some("done".to_string()),
        })
        .expect("tool call progress should be emitted");

        let events = activity_manager.events.lock().unwrap();
        let state = rebuild_activity_state(&events);
        assert_eq!(state["type"], TOOL_CALL_PROGRESS_TYPE);
        assert_eq!(state["schema"], TOOL_CALL_PROGRESS_SCHEMA);
        assert_eq!(state["node_id"], "tool_call:call-1");
        assert_eq!(state["parent_node_id"], "tool_call:call-parent");
        assert_eq!(state["parent_call_id"], "call-parent");
        assert_eq!(state["tool_name"], "echo");
        assert_eq!(state["status"], "done");
        assert_eq!(state["run_id"], "run-123");
        assert_eq!(state["parent_run_id"], "run-parent");
        assert_eq!(state["thread_id"], "thread-abc");
        assert!(state["updated_at_ms"].as_u64().unwrap_or_default() > 0);
    }

    // Without a parent tool call, the parent node falls back to `run:<run_id>`.
    #[test]
    fn test_report_tool_call_progress_without_parent_tool_call_anchors_to_run_node() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::new();
        let pending = Mutex::new(Vec::new());
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let run_identity = run_identity("run-123");
        let ctx = ToolCallContext::new(
            &doc,
            &ops,
            "call-1",
            "tool:echo",
            &scope,
            &pending,
            activity_manager.clone(),
        )
        .with_run_identity(run_identity);

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Running,
            progress: Some(0.3),
            loaded: None,
            total: None,
            message: Some("working".to_string()),
        })
        .expect("tool call progress should be emitted");

        let events = activity_manager.events.lock().unwrap();
        let state = rebuild_activity_state(&events);
        assert_eq!(state["parent_node_id"], "run:run-123");
        assert!(state["parent_call_id"].is_null());
    }

    // An injected sink receives the report and the activity manager receives nothing.
    #[test]
    fn test_report_tool_call_progress_uses_injected_sink_instead_of_activity_manager() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let activity_manager = Arc::new(RecordingActivityManager::default());
        let sink = Arc::new(RecordingProgressSink::default());
        let ctx = ToolCallContext::new(
            &doc,
            &ops,
            "call-1",
            "tool:echo",
            &scope,
            &pending,
            activity_manager.clone(),
        )
        .with_tool_call_progress_sink(sink.clone());

        ctx.report_tool_call_progress(ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Running,
            progress: Some(0.2),
            loaded: None,
            total: Some(10.0),
            message: Some("working".to_string()),
        })
        .expect("tool call progress should be reported");

        let sink_events = sink.events.lock().unwrap();
        assert_eq!(sink_events.len(), 1);
        let (stream_id, activity_type, payload) = &sink_events[0];
        assert_eq!(stream_id, "tool_call:call-1");
        assert_eq!(activity_type, TOOL_CALL_PROGRESS_ACTIVITY_TYPE);
        assert_eq!(payload.call_id, "call-1");
        assert_eq!(payload.progress, Some(0.2));

        let activity_events = activity_manager.events.lock().unwrap();
        assert!(
            activity_events.is_empty(),
            "injected sink should bypass default activity manager sink"
        );
    }

    // An error returned by the sink is returned to the caller.
    #[test]
    fn test_report_tool_call_progress_propagates_sink_error() {
        let doc = DocCell::new(json!({}));
        let ops = Mutex::new(Vec::new());
        let scope = RunPolicy::default();
        let pending = Mutex::new(Vec::new());
        let ctx = ToolCallContext::new(
            &doc,
            &ops,
            "call-1",
            "tool:echo",
            &scope,
            &pending,
            NoOpActivityManager::arc(),
        )
        .with_tool_call_progress_sink(Arc::new(FailingProgressSink));

        let result = ctx.report_tool_call_progress(ToolCallProgressUpdate {
            status: ToolCallProgressStatus::Running,
            progress: Some(0.1),
            loaded: None,
            total: None,
            message: None,
        });
        assert!(result.is_err());
    }
}