1use crate::{
2 context::{ContextFile, EnvironmentContext, ProjectContext, SessionContextBuilder},
3 message_bridge,
4 sandbox::{SandboxConfig, SandboxMode, SandboxedMcpServer},
5 state::AppState,
6 types::{RunConfig, SessionHandle},
7};
8use async_trait::async_trait;
9use rmcp::model::{
10 CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
11 CancelledNotificationParam, ServerResult,
12};
13use serde_json::json;
14use stakai::{ContentPart, Message, MessageContent, Role};
15use stakpak_agent_core::{
16 AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, BudgetAwareContextReducer,
17 CheckpointEnvelopeV1, CompactionConfig, PassthroughCompactionEngine, ProposedToolCall,
18 RetryConfig, ToolExecutionResult, ToolExecutor, run_agent,
19};
20use stakpak_api::CreateCheckpointRequest;
21use stakpak_mcp_client::McpClient;
22use stakpak_shared::utils::sanitize_text_output;
23use std::{path::Path, sync::Arc};
24use tokio::sync::{Mutex, mpsc};
25use tokio_util::sync::CancellationToken;
26use uuid::Uuid;
27
/// Period of the background task that flushes the in-flight checkpoint while a run is active.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Checkpoint-envelope metadata key holding the active model as `"<provider>/<model id>"`.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
30
31pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
32 AgentRunContext { run_id, session_id }
33}
34
35pub fn build_checkpoint_envelope(
36 run_id: Uuid,
37 messages: Vec<stakai::Message>,
38 metadata: serde_json::Value,
39) -> CheckpointEnvelopeV1 {
40 CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
41}
42
43pub fn spawn_session_actor(
44 state: AppState,
45 session_id: Uuid,
46 run_id: Uuid,
47 run_config: RunConfig,
48 user_message: Message,
49 caller_context: Vec<ContextFile>,
50 sandbox_config: Option<SandboxConfig>,
51) -> Result<SessionHandle, String> {
52 let (command_tx, command_rx) = mpsc::channel(128);
53 let cancel = CancellationToken::new();
54
55 let handle = SessionHandle::new(command_tx, cancel.clone());
56
57 let state_for_task = state.clone();
58 tokio::spawn(async move {
59 let actor_result = run_session_actor(
60 state_for_task.clone(),
61 session_id,
62 run_id,
63 run_config,
64 user_message,
65 caller_context,
66 command_rx,
67 cancel,
68 sandbox_config,
69 )
70 .await;
71
72 let finish_result = actor_result.map(|_| ());
73 let _ = state_for_task
74 .run_manager
75 .mark_run_finished(session_id, run_id, finish_result)
76 .await;
77 });
78
79 Ok(handle)
80}
81
82#[allow(clippy::too_many_arguments)]
83async fn run_session_actor(
84 state: AppState,
85 session_id: Uuid,
86 run_id: Uuid,
87 run_config: RunConfig,
88 mut user_message: Message,
89 caller_context: Vec<ContextFile>,
90 command_rx: mpsc::Receiver<AgentCommand>,
91 cancel: CancellationToken,
92 sandbox_config: Option<SandboxConfig>,
93) -> Result<(), String> {
94 let active_checkpoint = state
95 .session_store
96 .get_active_checkpoint(session_id)
97 .await
98 .ok();
99 let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
100
101 let (initial_messages, mut initial_metadata) =
102 match state.checkpoint_store.load_latest(session_id).await {
103 Ok(Some(envelope)) => (envelope.messages, envelope.metadata),
104 Ok(None) => {
105 let messages = active_checkpoint
106 .as_ref()
107 .map(|checkpoint| {
108 message_bridge::chat_to_stakai(checkpoint.state.messages.clone())
109 })
110 .unwrap_or_default();
111 let metadata = active_checkpoint
112 .as_ref()
113 .and_then(|checkpoint| checkpoint.state.metadata.clone())
114 .unwrap_or_else(|| json!({}));
115 (messages, metadata)
116 }
117 Err(error) => {
118 return Err(format!("Failed to load checkpoint envelope: {error}"));
119 }
120 };
121
122 let mut ephemeral_sandbox: Option<SandboxedMcpServer> = None;
129
130 let (run_tools, tool_executor): (Vec<stakai::Tool>, Box<dyn ToolExecutor + Send + Sync>) =
131 if let Some(ref sandbox_cfg) = sandbox_config {
132 if let Some(ref persistent) = state.persistent_sandbox {
133 tracing::info!(session_id = %session_id, "Using persistent sandbox for session");
135 (
136 persistent.tools().await,
137 Box::new(SandboxedToolExecutor {
138 mcp_client: persistent.client().await,
139 }),
140 )
141 } else if sandbox_cfg.mode == SandboxMode::Persistent {
142 return Err(format!(
147 "Sandbox mode is 'persistent' but no persistent sandbox is available for session {session_id}. \
148 This indicates the server started without a healthy sandbox. Restart the autopilot to fix."
149 ));
150 } else {
151 tracing::info!(session_id = %session_id, image = %sandbox_cfg.image, "Spawning ephemeral sandbox container for session");
153 let sandbox = SandboxedMcpServer::spawn(sandbox_cfg).await.map_err(|e| {
154 format!("Failed to start sandbox for session {session_id}: {e}")
155 })?;
156 let tools = sandbox.tools.clone();
157 let client = sandbox.client.clone();
158 ephemeral_sandbox = Some(sandbox);
159 (
160 tools,
161 Box::new(SandboxedToolExecutor { mcp_client: client }),
162 )
163 }
164 } else {
165 (
166 state.current_mcp_tools().await,
167 Box::new(ServerToolExecutor {
168 state: state.clone(),
169 }),
170 )
171 };
172
173 let is_new_session = is_new_session_history(&initial_messages);
174 let session_cwd = resolve_session_cwd(&state, session_id).await;
175 let environment = EnvironmentContext::snapshot(&session_cwd).await;
176
177 let has_runtime_caller_context = !caller_context.is_empty();
181 let mut all_caller_context = caller_context;
182 all_caller_context.extend(state.current_skills().await);
183
184 let project =
185 ProjectContext::discover(Path::new(&session_cwd)).with_caller_context(all_caller_context);
186
187 let session_context = SessionContextBuilder::new()
188 .base_system_prompt(
189 run_config
190 .system_prompt
191 .clone()
192 .or_else(|| state.base_system_prompt.clone())
193 .unwrap_or_default(),
194 )
195 .environment(environment)
196 .project(project)
197 .tools(&run_tools)
198 .budget(state.context_budget.clone())
199 .build();
200
201 if (is_new_session || has_runtime_caller_context)
202 && let Some(context_block) = session_context.user_context_block.as_deref()
203 {
204 user_message = prepend_context_to_user_message(user_message, context_block);
205 }
206
207 let mut baseline_messages = initial_messages.clone();
208 baseline_messages.push(user_message.clone());
209
210 let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
211 state.clone(),
212 session_id,
213 run_id,
214 run_config.model.clone(),
215 parent_checkpoint_id,
216 baseline_messages,
217 initial_metadata.clone(),
218 ));
219
220 checkpoint_runtime
221 .persist_snapshot()
222 .await
223 .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
224
225 let periodic_checkpoint_cancel = CancellationToken::new();
226 let periodic_checkpoint_runtime = checkpoint_runtime.clone();
227 let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
228 let periodic_task = tokio::spawn(async move {
229 let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
230 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
231
232 loop {
233 tokio::select! {
234 _ = periodic_checkpoint_cancel_task.cancelled() => break,
235 _ = interval.tick() => {
236 let _ = periodic_checkpoint_runtime.persist_snapshot().await;
237 }
238 }
239 }
240 });
241
242 let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
243
244 let event_state = state.clone();
245 let event_forwarder = tokio::spawn(async move {
246 while let Some(event) = core_event_rx.recv().await {
247 handle_core_event(&event_state, session_id, run_id, event).await;
248 }
249 });
250
251 let max_output_tokens = run_config.model.limit.output as u32;
255 let agent_config = AgentConfig {
256 model: run_config.model.clone(),
257 system_prompt: session_context.system_prompt,
258 max_turns: run_config.max_turns,
259 max_output_tokens,
260 provider_options: None,
261 tool_approval: run_config.tool_approval_policy.clone(),
262 retry: RetryConfig::default(),
263 compaction: CompactionConfig::default(),
264 tools: run_tools,
265 };
266
267 let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
268 checkpoint_runtime: checkpoint_runtime.clone(),
269 })];
270
271 let compactor = PassthroughCompactionEngine;
272 let context_reducer = BudgetAwareContextReducer::new(5, 0.8);
273 let run_context = build_run_context(session_id, run_id);
274
275 let run_result = run_agent(
276 run_context,
277 run_config.inference.as_ref(),
278 &agent_config,
279 initial_messages,
280 &mut initial_metadata,
281 user_message,
282 tool_executor.as_ref(),
283 &hooks,
284 core_event_tx,
285 command_rx,
286 cancel,
287 &compactor,
288 &context_reducer,
289 )
290 .await;
291
292 periodic_checkpoint_cancel.cancel();
293 let _ = periodic_task.await;
294
295 if let Some(sandbox) = ephemeral_sandbox {
298 sandbox.shutdown().await;
299 }
300
301 state.clear_pending_tools(session_id, run_id).await;
302
303 match &run_result {
304 Ok(result) => {
305 checkpoint_runtime.update_messages(&result.messages).await;
306 checkpoint_runtime.update_metadata(&result.metadata).await;
307 checkpoint_runtime
308 .persist_snapshot()
309 .await
310 .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
311 }
312 Err(_) => {
313 checkpoint_runtime.update_metadata(&initial_metadata).await;
314 let _ = checkpoint_runtime.persist_snapshot().await;
315 }
316 }
317
318 let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
319
320 run_result
321 .map(|_| ())
322 .map_err(|error| format!("Agent run failed: {error}"))
323}
324
325fn is_new_session_history(messages: &[Message]) -> bool {
326 !messages
327 .iter()
328 .any(|message| matches!(message.role, Role::User | Role::Assistant | Role::Tool))
329}
330
331async fn resolve_session_cwd(state: &AppState, session_id: Uuid) -> String {
332 if let Ok(session) = state.session_store.get_session(session_id).await
334 && let Some(cwd) = session.cwd
335 && !cwd.trim().is_empty()
336 {
337 return cwd;
338 }
339
340 if let Some(project_dir) = &state.project_dir {
342 return project_dir.clone();
343 }
344
345 std::env::current_dir()
347 .ok()
348 .map(|path| path.to_string_lossy().to_string())
349 .unwrap_or_else(|| ".".to_string())
350}
351
352fn prepend_context_to_user_message(mut message: Message, context_block: &str) -> Message {
353 if context_block.trim().is_empty() {
354 return message;
355 }
356
357 match &mut message.content {
358 MessageContent::Text(text) => {
359 let existing = std::mem::take(text);
360 *text = if existing.trim().is_empty() {
361 context_block.to_string()
362 } else {
363 format!("{context_block}\n\n{existing}")
364 };
365 }
366 MessageContent::Parts(parts) => {
367 let mut prefixed = Vec::with_capacity(parts.len() + 1);
368 prefixed.push(ContentPart::text(context_block));
369 prefixed.append(parts);
370 *parts = prefixed;
371 }
372 }
373
374 message
375}
376
377async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
378 match &event {
379 AgentEvent::ToolCallsProposed { tool_calls, .. } => {
380 state
381 .set_pending_tools(session_id, run_id, tool_calls.clone())
382 .await;
383 }
384 AgentEvent::TurnCompleted { .. }
385 | AgentEvent::RunCompleted { .. }
386 | AgentEvent::RunError { .. } => {
387 state.clear_pending_tools(session_id, run_id).await;
388 }
389 _ => {}
390 }
391
392 state.events.publish(session_id, Some(run_id), event).await;
393}
394
395#[derive(Clone)]
396struct ServerToolExecutor {
397 state: AppState,
398}
399
400#[async_trait]
401impl ToolExecutor for ServerToolExecutor {
402 async fn execute_tool_call(
403 &self,
404 run: &AgentRunContext,
405 tool_call: &ProposedToolCall,
406 cancel: &CancellationToken,
407 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
408 Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
409 }
410}
411
412#[derive(Clone)]
414struct SandboxedToolExecutor {
415 mcp_client: Arc<McpClient>,
416}
417
418#[async_trait]
419impl ToolExecutor for SandboxedToolExecutor {
420 async fn execute_tool_call(
421 &self,
422 run: &AgentRunContext,
423 tool_call: &ProposedToolCall,
424 cancel: &CancellationToken,
425 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
426 Ok(execute_mcp_tool_call_with_client(
427 &self.mcp_client,
428 run.session_id,
429 run.run_id,
430 tool_call,
431 cancel,
432 )
433 .await)
434 }
435}
436
437struct CheckpointRuntime {
438 state: AppState,
439 session_id: Uuid,
440 run_id: Uuid,
441 active_model: stakai::Model,
442 inner: Mutex<CheckpointRuntimeInner>,
443}
444
445struct CheckpointRuntimeInner {
446 parent_checkpoint_id: Option<Uuid>,
447 latest_messages: Vec<Message>,
448 latest_metadata: serde_json::Value,
449 last_persisted_signature: Option<String>,
450 dirty: bool,
451}
452
453impl CheckpointRuntime {
454 fn new(
455 state: AppState,
456 session_id: Uuid,
457 run_id: Uuid,
458 active_model: stakai::Model,
459 parent_checkpoint_id: Option<Uuid>,
460 latest_messages: Vec<Message>,
461 latest_metadata: serde_json::Value,
462 ) -> Self {
463 Self {
464 state,
465 session_id,
466 run_id,
467 active_model,
468 inner: Mutex::new(CheckpointRuntimeInner {
469 parent_checkpoint_id,
470 latest_messages,
471 latest_metadata,
472 last_persisted_signature: None,
473 dirty: true,
474 }),
475 }
476 }
477
478 async fn update_messages(&self, messages: &[Message]) {
479 let mut guard = self.inner.lock().await;
480 guard.latest_messages = messages.to_vec();
481 guard.dirty = true;
482 }
483
484 async fn update_metadata(&self, metadata: &serde_json::Value) {
485 let mut guard = self.inner.lock().await;
486 guard.latest_metadata = metadata.clone();
487 guard.dirty = true;
488 }
489
490 async fn persist_snapshot(&self) -> Result<Uuid, String> {
491 let mut guard = self.inner.lock().await;
492 self.persist_if_needed(&mut guard).await
493 }
494
495 async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
496 if !guard.dirty
497 && let Some(checkpoint_id) = guard.parent_checkpoint_id
498 {
499 return Ok(checkpoint_id);
500 }
501
502 let signature = checkpoint_signature(&guard.latest_messages, &guard.latest_metadata)?;
503 let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
504 let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
505
506 if !should_persist {
507 guard.dirty = false;
508 if let Some(checkpoint_id) = guard.parent_checkpoint_id {
509 return Ok(checkpoint_id);
510 }
511 }
512
513 let checkpoint_id = persist_checkpoint(
514 &self.state,
515 self.session_id,
516 self.run_id,
517 &self.active_model,
518 guard.parent_checkpoint_id,
519 &guard.latest_messages,
520 &guard.latest_metadata,
521 )
522 .await?;
523
524 guard.parent_checkpoint_id = Some(checkpoint_id);
525 guard.last_persisted_signature = Some(signature);
526 guard.dirty = false;
527
528 Ok(checkpoint_id)
529 }
530}
531
532struct ServerCheckpointHook {
533 checkpoint_runtime: Arc<CheckpointRuntime>,
534}
535
536#[async_trait]
537impl AgentHook for ServerCheckpointHook {
538 async fn before_inference(
539 &self,
540 _run: &AgentRunContext,
541 messages: &[Message],
542 _model: &stakai::Model,
543 ) -> Result<(), stakpak_agent_core::AgentError> {
544 self.checkpoint_runtime.update_messages(messages).await;
545 Ok(())
546 }
547
548 async fn after_inference(
549 &self,
550 _run: &AgentRunContext,
551 messages: &[Message],
552 _model: &stakai::Model,
553 ) -> Result<(), stakpak_agent_core::AgentError> {
554 self.checkpoint_runtime.update_messages(messages).await;
555 Ok(())
556 }
557
558 async fn after_tool_execution(
559 &self,
560 _run: &AgentRunContext,
561 _tool_call: &ProposedToolCall,
562 messages: &[Message],
563 ) -> Result<(), stakpak_agent_core::AgentError> {
564 self.checkpoint_runtime.update_messages(messages).await;
565 Ok(())
566 }
567
568 async fn on_error(
569 &self,
570 _run: &AgentRunContext,
571 _error: &stakpak_agent_core::AgentError,
572 messages: &[Message],
573 ) -> Result<(), stakpak_agent_core::AgentError> {
574 self.checkpoint_runtime.update_messages(messages).await;
575 let _ = self.checkpoint_runtime.persist_snapshot().await;
576 Ok(())
577 }
578}
579
580async fn execute_mcp_tool_call(
581 state: &AppState,
582 session_id: Uuid,
583 run_id: Uuid,
584 tool_call: &ProposedToolCall,
585 cancel: &CancellationToken,
586) -> ToolExecutionResult {
587 let Some(mcp_client) = state.mcp_client.as_ref() else {
588 return ToolExecutionResult::Completed {
589 result: "MCP client is not initialized".to_string(),
590 is_error: true,
591 };
592 };
593
594 execute_mcp_tool_call_with_client(mcp_client, session_id, run_id, tool_call, cancel).await
595}
596
597async fn execute_mcp_tool_call_with_client(
598 mcp_client: &McpClient,
599 session_id: Uuid,
600 run_id: Uuid,
601 tool_call: &ProposedToolCall,
602 cancel: &CancellationToken,
603) -> ToolExecutionResult {
604 let metadata = Some(serde_json::Map::from_iter([
605 (
606 "session_id".to_string(),
607 serde_json::Value::String(session_id.to_string()),
608 ),
609 (
610 "run_id".to_string(),
611 serde_json::Value::String(run_id.to_string()),
612 ),
613 (
614 "tool_call_id".to_string(),
615 serde_json::Value::String(tool_call.id.clone()),
616 ),
617 ]));
618
619 let arguments = match &tool_call.arguments {
620 serde_json::Value::Object(map) => Some(map.clone()),
621 serde_json::Value::Null => None,
622 other => Some(serde_json::Map::from_iter([(
623 "input".to_string(),
624 other.clone(),
625 )])),
626 };
627
628 let request_handle = match stakpak_mcp_client::call_tool(
629 mcp_client,
630 CallToolRequestParam {
631 name: tool_call.name.clone().into(),
632 arguments,
633 },
634 metadata,
635 )
636 .await
637 {
638 Ok(handle) => handle,
639 Err(error) => {
640 return ToolExecutionResult::Completed {
641 result: format!("MCP tool call failed: {error}"),
642 is_error: true,
643 };
644 }
645 };
646
647 let peer_for_cancel = request_handle.peer.clone();
648 let request_id = request_handle.id.clone();
649
650 tokio::select! {
651 _ = cancel.cancelled() => {
652 let notification = CancelledNotification {
653 method: CancelledNotificationMethod,
654 params: CancelledNotificationParam {
655 request_id,
656 reason: Some("user cancel".to_string()),
657 },
658 extensions: Default::default(),
659 };
660
661 let _ = peer_for_cancel.send_notification(notification.into()).await;
662 ToolExecutionResult::Cancelled
663 }
664 server_result = request_handle.await_response() => {
665 match server_result {
666 Ok(ServerResult::CallToolResult(result)) => {
667 ToolExecutionResult::Completed {
668 result: render_call_tool_result(&result),
669 is_error: result.is_error.unwrap_or(false),
670 }
671 }
672 Ok(_) => ToolExecutionResult::Completed {
673 result: "Unexpected MCP response type".to_string(),
674 is_error: true,
675 },
676 Err(error) => ToolExecutionResult::Completed {
677 result: format!("MCP tool execution error: {error}"),
678 is_error: true,
679 },
680 }
681 }
682 }
683}
684
685fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
686 let rendered = result
687 .content
688 .iter()
689 .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
690 .collect::<Vec<_>>()
691 .join("\n");
692
693 if !rendered.is_empty() {
694 return sanitize_text_output(&rendered);
695 }
696
697 if result.content.is_empty() {
698 return "<empty tool result>".to_string();
699 }
700
701 "<non-text tool result omitted for safety>".to_string()
702}
703
704fn checkpoint_signature(
705 messages: &[Message],
706 metadata: &serde_json::Value,
707) -> Result<String, String> {
708 serde_json::to_string(&(messages, metadata))
709 .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
710}
711
712async fn persist_checkpoint(
713 state: &AppState,
714 session_id: Uuid,
715 run_id: Uuid,
716 active_model: &stakai::Model,
717 parent_id: Option<Uuid>,
718 messages: &[Message],
719 metadata: &serde_json::Value,
720) -> Result<Uuid, String> {
721 let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages))
724 .with_metadata(metadata.clone());
725
726 if let Some(parent_id) = parent_id {
727 request = request.with_parent(parent_id);
728 }
729
730 let checkpoint = state
731 .session_store
732 .create_checkpoint(session_id, &request)
733 .await
734 .map_err(|error| error.to_string())?;
735
736 let mut envelope_metadata = if metadata.is_object() {
737 metadata.clone()
738 } else {
739 json!({})
740 };
741
742 if let Some(obj) = envelope_metadata.as_object_mut() {
743 obj.insert(
744 "session_id".to_string(),
745 serde_json::Value::String(session_id.to_string()),
746 );
747 obj.insert(
748 "checkpoint_id".to_string(),
749 serde_json::Value::String(checkpoint.id.to_string()),
750 );
751 obj.insert(
752 ACTIVE_MODEL_METADATA_KEY.to_string(),
753 serde_json::Value::String(format!("{}/{}", active_model.provider, active_model.id)),
754 );
755 }
756
757 let envelope = build_checkpoint_envelope(run_id, messages.to_vec(), envelope_metadata);
758
759 state
760 .checkpoint_store
761 .save_latest(session_id, &envelope)
762 .await
763 .map_err(|error| {
764 format!(
765 "Failed to persist checkpoint envelope for session {}: {}",
766 session_id, error
767 )
768 })?;
769
770 Ok(checkpoint.id)
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{ContentPart, Message, MessageContent, Role};

    /// Computes a checkpoint signature, failing the test on a serialization error.
    fn signature_of(messages: &[Message], metadata: &serde_json::Value) -> String {
        checkpoint_signature(messages, metadata)
            .unwrap_or_else(|error| panic!("signature failed: {error}"))
    }

    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let (session_id, run_id) = (Uuid::new_v4(), Uuid::new_v4());

        let context = build_run_context(session_id, run_id);

        assert_eq!(context.session_id, session_id);
        assert_eq!(context.run_id, run_id);
    }

    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let history = vec![Message::new(Role::User, "hello")];

        let envelope = build_checkpoint_envelope(run_id, history, json!({"turn": 1}));

        assert_eq!(envelope.run_id, Some(run_id));
    }

    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let noisy = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);

        assert_eq!(render_call_tool_result(&noisy), "okdone");
    }

    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let image_only = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);

        assert_eq!(
            render_call_tool_result(&image_only),
            "<non-text tool result omitted for safety>"
        );
    }

    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        let shorter = vec![Message::new(Role::User, "hello")];
        let longer = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        assert_ne!(
            signature_of(&shorter, &json!({})),
            signature_of(&longer, &json!({}))
        );
    }

    #[test]
    fn checkpoint_signature_changes_when_metadata_changes() {
        let history = vec![Message::new(Role::User, "hello")];

        assert_ne!(
            signature_of(&history, &json!({})),
            signature_of(&history, &json!({"trimmed_up_to_message_index": 5}))
        );
    }

    #[test]
    fn is_new_session_empty_history() {
        assert!(is_new_session_history(&[]));
    }

    #[test]
    fn is_new_session_system_only() {
        let history = [Message::new(Role::System, "you are an agent")];
        assert!(is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_user_message() {
        let history = [Message::new(Role::User, "hello")];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_system_and_user() {
        let history = [
            Message::new(Role::System, "system"),
            Message::new(Role::User, "hello"),
        ];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_assistant() {
        let history = [Message::new(Role::Assistant, "hi there")];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn prepend_context_to_text_message() {
        let question = Message::new(Role::User, "how do I deploy?");

        let combined = prepend_context_to_user_message(question, "<context>env info</context>")
            .text()
            .unwrap_or_default();

        assert!(
            combined.starts_with("<context>env info</context>"),
            "context should be prepended"
        );
        assert!(
            combined.contains("how do I deploy?"),
            "original text should be preserved"
        );
    }

    #[test]
    fn prepend_context_to_empty_text_message() {
        let blank = Message::new(Role::User, " ");

        let combined = prepend_context_to_user_message(blank, "<context>env info</context>")
            .text()
            .unwrap_or_default();

        assert_eq!(combined, "<context>env info</context>");
    }

    #[test]
    fn prepend_context_to_parts_message() {
        let multipart = Message {
            role: Role::User,
            content: MessageContent::Parts(vec![ContentPart::text("original text")]),
            name: None,
            provider_options: None,
        };

        let result = prepend_context_to_user_message(multipart, "<context>env info</context>");

        let MessageContent::Parts(parts) = &result.content else {
            panic!("expected Parts content");
        };
        assert_eq!(parts.len(), 2, "should have context part + original part");

        let ContentPart::Text { text, .. } = &parts[0] else {
            panic!("first part should be text");
        };
        assert_eq!(text, "<context>env info</context>");
    }

    #[test]
    fn prepend_empty_context_is_noop() {
        let greeting = Message::new(Role::User, "hello");

        let unchanged = prepend_context_to_user_message(greeting, " ");

        assert_eq!(unchanged.text().unwrap_or_default(), "hello");
    }
}