1use crate::{
2 context::{ContextFile, EnvironmentContext, ProjectContext, SessionContextBuilder},
3 message_bridge,
4 sandbox::{SandboxConfig, SandboxedMcpServer},
5 state::AppState,
6 types::SessionHandle,
7};
8use async_trait::async_trait;
9use rmcp::model::{
10 CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
11 CancelledNotificationParam, ServerResult,
12};
13use serde_json::json;
14use stakai::{ContentPart, Message, MessageContent, Role};
15use stakpak_agent_core::{
16 AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, BudgetAwareContextReducer,
17 CheckpointEnvelopeV1, CompactionConfig, PassthroughCompactionEngine, ProposedToolCall,
18 RetryConfig, ToolExecutionResult, ToolExecutor, run_agent,
19};
20use stakpak_api::CreateCheckpointRequest;
21use stakpak_mcp_client::McpClient;
22use stakpak_shared::utils::sanitize_text_output;
23use std::{path::Path, sync::Arc};
24use tokio::sync::{Mutex, mpsc};
25use tokio_util::sync::CancellationToken;
26use uuid::Uuid;
27
/// Passed to the agent loop as `AgentConfig::max_turns` for every run.
const MAX_TURNS: usize = 64;
/// Tick period of the background task that flushes dirty checkpoint state during a run.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);
/// Checkpoint-envelope metadata key holding the run's `"{provider}/{model_id}"` string.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
31
32pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
33 AgentRunContext { run_id, session_id }
34}
35
36pub fn build_checkpoint_envelope(
37 run_id: Uuid,
38 messages: Vec<stakai::Message>,
39 metadata: serde_json::Value,
40) -> CheckpointEnvelopeV1 {
41 CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
42}
43
44pub fn spawn_session_actor(
45 state: AppState,
46 session_id: Uuid,
47 run_id: Uuid,
48 model: stakai::Model,
49 user_message: Message,
50 caller_context: Vec<ContextFile>,
51 sandbox_config: Option<SandboxConfig>,
52) -> Result<SessionHandle, String> {
53 let (command_tx, command_rx) = mpsc::channel(128);
54 let cancel = CancellationToken::new();
55
56 let handle = SessionHandle::new(command_tx, cancel.clone());
57
58 let state_for_task = state.clone();
59 tokio::spawn(async move {
60 let actor_result = run_session_actor(
61 state_for_task.clone(),
62 session_id,
63 run_id,
64 model,
65 user_message,
66 caller_context,
67 command_rx,
68 cancel,
69 sandbox_config,
70 )
71 .await;
72
73 let finish_result = actor_result.map(|_| ());
74 let _ = state_for_task
75 .run_manager
76 .mark_run_finished(session_id, run_id, finish_result)
77 .await;
78 });
79
80 Ok(handle)
81}
82
83#[allow(clippy::too_many_arguments)]
84async fn run_session_actor(
85 state: AppState,
86 session_id: Uuid,
87 run_id: Uuid,
88 model: stakai::Model,
89 mut user_message: Message,
90 caller_context: Vec<ContextFile>,
91 command_rx: mpsc::Receiver<AgentCommand>,
92 cancel: CancellationToken,
93 sandbox_config: Option<SandboxConfig>,
94) -> Result<(), String> {
95 let active_checkpoint = state
96 .session_store
97 .get_active_checkpoint(session_id)
98 .await
99 .ok();
100 let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
101
102 let (initial_messages, mut initial_metadata) =
103 match state.checkpoint_store.load_latest(session_id).await {
104 Ok(Some(envelope)) => (envelope.messages, envelope.metadata),
105 Ok(None) => {
106 let messages = active_checkpoint
107 .as_ref()
108 .map(|checkpoint| {
109 message_bridge::chat_to_stakai(checkpoint.state.messages.clone())
110 })
111 .unwrap_or_default();
112 let metadata = active_checkpoint
113 .as_ref()
114 .and_then(|checkpoint| checkpoint.state.metadata.clone())
115 .unwrap_or_else(|| json!({}));
116 (messages, metadata)
117 }
118 Err(error) => {
119 return Err(format!("Failed to load checkpoint envelope: {error}"));
120 }
121 };
122
123 let sandbox = if let Some(sandbox_config) = sandbox_config {
126 tracing::info!(session_id = %session_id, image = %sandbox_config.image, "Spawning sandbox container for session");
127 Some(
128 SandboxedMcpServer::spawn(&sandbox_config)
129 .await
130 .map_err(|e| format!("Failed to start sandbox for session {session_id}: {e}"))?,
131 )
132 } else {
133 None
134 };
135
136 let (run_tools, tool_executor): (Vec<stakai::Tool>, Box<dyn ToolExecutor + Send + Sync>) =
137 if let Some(ref sandbox) = sandbox {
138 (
139 sandbox.tools.clone(),
140 Box::new(SandboxedToolExecutor {
141 mcp_client: sandbox.client.clone(),
142 }),
143 )
144 } else {
145 (
146 state.current_mcp_tools().await,
147 Box::new(ServerToolExecutor {
148 state: state.clone(),
149 }),
150 )
151 };
152
153 let is_new_session = is_new_session_history(&initial_messages);
154 let session_cwd = resolve_session_cwd(&state, session_id).await;
155 let environment = EnvironmentContext::snapshot(&session_cwd).await;
156
157 let has_runtime_caller_context = !caller_context.is_empty();
161 let mut all_caller_context = caller_context;
162 all_caller_context.extend(state.current_skills().await);
163
164 let project =
165 ProjectContext::discover(Path::new(&session_cwd)).with_caller_context(all_caller_context);
166
167 let session_context = SessionContextBuilder::new()
168 .base_system_prompt(state.base_system_prompt.clone().unwrap_or_default())
169 .environment(environment)
170 .project(project)
171 .tools(&run_tools)
172 .budget(state.context_budget.clone())
173 .build();
174
175 if (is_new_session || has_runtime_caller_context)
176 && let Some(context_block) = session_context.user_context_block.as_deref()
177 {
178 user_message = prepend_context_to_user_message(user_message, context_block);
179 }
180
181 let mut baseline_messages = initial_messages.clone();
182 baseline_messages.push(user_message.clone());
183
184 let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
185 state.clone(),
186 session_id,
187 run_id,
188 model.clone(),
189 parent_checkpoint_id,
190 baseline_messages,
191 initial_metadata.clone(),
192 ));
193
194 checkpoint_runtime
195 .persist_snapshot()
196 .await
197 .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
198
199 let periodic_checkpoint_cancel = CancellationToken::new();
200 let periodic_checkpoint_runtime = checkpoint_runtime.clone();
201 let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
202 let periodic_task = tokio::spawn(async move {
203 let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
204 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
205
206 loop {
207 tokio::select! {
208 _ = periodic_checkpoint_cancel_task.cancelled() => break,
209 _ = interval.tick() => {
210 let _ = periodic_checkpoint_runtime.persist_snapshot().await;
211 }
212 }
213 }
214 });
215
216 let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
217
218 let event_state = state.clone();
219 let event_forwarder = tokio::spawn(async move {
220 while let Some(event) = core_event_rx.recv().await {
221 handle_core_event(&event_state, session_id, run_id, event).await;
222 }
223 });
224
225 let max_output_tokens = model.limit.output as u32;
229 let agent_config = AgentConfig {
230 model,
231 system_prompt: session_context.system_prompt,
232 max_turns: MAX_TURNS,
233 max_output_tokens,
234 provider_options: None,
235 tool_approval: state.tool_approval_policy.clone(),
236 retry: RetryConfig::default(),
237 compaction: CompactionConfig::default(),
238 tools: run_tools,
239 };
240
241 let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
242 checkpoint_runtime: checkpoint_runtime.clone(),
243 })];
244
245 let compactor = PassthroughCompactionEngine;
246 let context_reducer = BudgetAwareContextReducer::new(5, 0.8);
247 let run_context = build_run_context(session_id, run_id);
248
249 let run_result = run_agent(
250 run_context,
251 state.inference.as_ref(),
252 &agent_config,
253 initial_messages,
254 &mut initial_metadata,
255 user_message,
256 tool_executor.as_ref(),
257 &hooks,
258 core_event_tx,
259 command_rx,
260 cancel,
261 &compactor,
262 &context_reducer,
263 )
264 .await;
265
266 periodic_checkpoint_cancel.cancel();
267 let _ = periodic_task.await;
268
269 if let Some(sandbox) = sandbox {
271 sandbox.shutdown().await;
272 }
273
274 state.clear_pending_tools(session_id, run_id).await;
275
276 match &run_result {
277 Ok(result) => {
278 checkpoint_runtime.update_messages(&result.messages).await;
279 checkpoint_runtime.update_metadata(&result.metadata).await;
280 checkpoint_runtime
281 .persist_snapshot()
282 .await
283 .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
284 }
285 Err(_) => {
286 checkpoint_runtime.update_metadata(&initial_metadata).await;
287 let _ = checkpoint_runtime.persist_snapshot().await;
288 }
289 }
290
291 let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
292
293 run_result
294 .map(|_| ())
295 .map_err(|error| format!("Agent run failed: {error}"))
296}
297
298fn is_new_session_history(messages: &[Message]) -> bool {
299 !messages
300 .iter()
301 .any(|message| matches!(message.role, Role::User | Role::Assistant | Role::Tool))
302}
303
304async fn resolve_session_cwd(state: &AppState, session_id: Uuid) -> String {
305 if let Ok(session) = state.session_store.get_session(session_id).await
307 && let Some(cwd) = session.cwd
308 && !cwd.trim().is_empty()
309 {
310 return cwd;
311 }
312
313 if let Some(project_dir) = &state.project_dir {
315 return project_dir.clone();
316 }
317
318 std::env::current_dir()
320 .ok()
321 .map(|path| path.to_string_lossy().to_string())
322 .unwrap_or_else(|| ".".to_string())
323}
324
325fn prepend_context_to_user_message(mut message: Message, context_block: &str) -> Message {
326 if context_block.trim().is_empty() {
327 return message;
328 }
329
330 match &mut message.content {
331 MessageContent::Text(text) => {
332 let existing = std::mem::take(text);
333 *text = if existing.trim().is_empty() {
334 context_block.to_string()
335 } else {
336 format!("{context_block}\n\n{existing}")
337 };
338 }
339 MessageContent::Parts(parts) => {
340 let mut prefixed = Vec::with_capacity(parts.len() + 1);
341 prefixed.push(ContentPart::text(context_block));
342 prefixed.append(parts);
343 *parts = prefixed;
344 }
345 }
346
347 message
348}
349
350async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
351 match &event {
352 AgentEvent::ToolCallsProposed { tool_calls, .. } => {
353 state
354 .set_pending_tools(session_id, run_id, tool_calls.clone())
355 .await;
356 }
357 AgentEvent::TurnCompleted { .. }
358 | AgentEvent::RunCompleted { .. }
359 | AgentEvent::RunError { .. } => {
360 state.clear_pending_tools(session_id, run_id).await;
361 }
362 _ => {}
363 }
364
365 state.events.publish(session_id, Some(run_id), event).await;
366}
367
368#[derive(Clone)]
369struct ServerToolExecutor {
370 state: AppState,
371}
372
373#[async_trait]
374impl ToolExecutor for ServerToolExecutor {
375 async fn execute_tool_call(
376 &self,
377 run: &AgentRunContext,
378 tool_call: &ProposedToolCall,
379 cancel: &CancellationToken,
380 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
381 Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
382 }
383}
384
385#[derive(Clone)]
387struct SandboxedToolExecutor {
388 mcp_client: Arc<McpClient>,
389}
390
391#[async_trait]
392impl ToolExecutor for SandboxedToolExecutor {
393 async fn execute_tool_call(
394 &self,
395 run: &AgentRunContext,
396 tool_call: &ProposedToolCall,
397 cancel: &CancellationToken,
398 ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
399 Ok(execute_mcp_tool_call_with_client(
400 &self.mcp_client,
401 run.session_id,
402 run.run_id,
403 tool_call,
404 cancel,
405 )
406 .await)
407 }
408}
409
410struct CheckpointRuntime {
411 state: AppState,
412 session_id: Uuid,
413 run_id: Uuid,
414 active_model: stakai::Model,
415 inner: Mutex<CheckpointRuntimeInner>,
416}
417
418struct CheckpointRuntimeInner {
419 parent_checkpoint_id: Option<Uuid>,
420 latest_messages: Vec<Message>,
421 latest_metadata: serde_json::Value,
422 last_persisted_signature: Option<String>,
423 dirty: bool,
424}
425
426impl CheckpointRuntime {
427 fn new(
428 state: AppState,
429 session_id: Uuid,
430 run_id: Uuid,
431 active_model: stakai::Model,
432 parent_checkpoint_id: Option<Uuid>,
433 latest_messages: Vec<Message>,
434 latest_metadata: serde_json::Value,
435 ) -> Self {
436 Self {
437 state,
438 session_id,
439 run_id,
440 active_model,
441 inner: Mutex::new(CheckpointRuntimeInner {
442 parent_checkpoint_id,
443 latest_messages,
444 latest_metadata,
445 last_persisted_signature: None,
446 dirty: true,
447 }),
448 }
449 }
450
451 async fn update_messages(&self, messages: &[Message]) {
452 let mut guard = self.inner.lock().await;
453 guard.latest_messages = messages.to_vec();
454 guard.dirty = true;
455 }
456
457 async fn update_metadata(&self, metadata: &serde_json::Value) {
458 let mut guard = self.inner.lock().await;
459 guard.latest_metadata = metadata.clone();
460 guard.dirty = true;
461 }
462
463 async fn persist_snapshot(&self) -> Result<Uuid, String> {
464 let mut guard = self.inner.lock().await;
465 self.persist_if_needed(&mut guard).await
466 }
467
468 async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
469 if !guard.dirty
470 && let Some(checkpoint_id) = guard.parent_checkpoint_id
471 {
472 return Ok(checkpoint_id);
473 }
474
475 let signature = checkpoint_signature(&guard.latest_messages, &guard.latest_metadata)?;
476 let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
477 let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
478
479 if !should_persist {
480 guard.dirty = false;
481 if let Some(checkpoint_id) = guard.parent_checkpoint_id {
482 return Ok(checkpoint_id);
483 }
484 }
485
486 let checkpoint_id = persist_checkpoint(
487 &self.state,
488 self.session_id,
489 self.run_id,
490 &self.active_model,
491 guard.parent_checkpoint_id,
492 &guard.latest_messages,
493 &guard.latest_metadata,
494 )
495 .await?;
496
497 guard.parent_checkpoint_id = Some(checkpoint_id);
498 guard.last_persisted_signature = Some(signature);
499 guard.dirty = false;
500
501 Ok(checkpoint_id)
502 }
503}
504
505struct ServerCheckpointHook {
506 checkpoint_runtime: Arc<CheckpointRuntime>,
507}
508
509#[async_trait]
510impl AgentHook for ServerCheckpointHook {
511 async fn before_inference(
512 &self,
513 _run: &AgentRunContext,
514 messages: &[Message],
515 _model: &stakai::Model,
516 ) -> Result<(), stakpak_agent_core::AgentError> {
517 self.checkpoint_runtime.update_messages(messages).await;
518 Ok(())
519 }
520
521 async fn after_inference(
522 &self,
523 _run: &AgentRunContext,
524 messages: &[Message],
525 _model: &stakai::Model,
526 ) -> Result<(), stakpak_agent_core::AgentError> {
527 self.checkpoint_runtime.update_messages(messages).await;
528 Ok(())
529 }
530
531 async fn after_tool_execution(
532 &self,
533 _run: &AgentRunContext,
534 _tool_call: &ProposedToolCall,
535 messages: &[Message],
536 ) -> Result<(), stakpak_agent_core::AgentError> {
537 self.checkpoint_runtime.update_messages(messages).await;
538 Ok(())
539 }
540
541 async fn on_error(
542 &self,
543 _run: &AgentRunContext,
544 _error: &stakpak_agent_core::AgentError,
545 messages: &[Message],
546 ) -> Result<(), stakpak_agent_core::AgentError> {
547 self.checkpoint_runtime.update_messages(messages).await;
548 let _ = self.checkpoint_runtime.persist_snapshot().await;
549 Ok(())
550 }
551}
552
553async fn execute_mcp_tool_call(
554 state: &AppState,
555 session_id: Uuid,
556 run_id: Uuid,
557 tool_call: &ProposedToolCall,
558 cancel: &CancellationToken,
559) -> ToolExecutionResult {
560 let Some(mcp_client) = state.mcp_client.as_ref() else {
561 return ToolExecutionResult::Completed {
562 result: "MCP client is not initialized".to_string(),
563 is_error: true,
564 };
565 };
566
567 execute_mcp_tool_call_with_client(mcp_client, session_id, run_id, tool_call, cancel).await
568}
569
570async fn execute_mcp_tool_call_with_client(
571 mcp_client: &McpClient,
572 session_id: Uuid,
573 run_id: Uuid,
574 tool_call: &ProposedToolCall,
575 cancel: &CancellationToken,
576) -> ToolExecutionResult {
577 let metadata = Some(serde_json::Map::from_iter([
578 (
579 "session_id".to_string(),
580 serde_json::Value::String(session_id.to_string()),
581 ),
582 (
583 "run_id".to_string(),
584 serde_json::Value::String(run_id.to_string()),
585 ),
586 (
587 "tool_call_id".to_string(),
588 serde_json::Value::String(tool_call.id.clone()),
589 ),
590 ]));
591
592 let arguments = match &tool_call.arguments {
593 serde_json::Value::Object(map) => Some(map.clone()),
594 serde_json::Value::Null => None,
595 other => Some(serde_json::Map::from_iter([(
596 "input".to_string(),
597 other.clone(),
598 )])),
599 };
600
601 let request_handle = match stakpak_mcp_client::call_tool(
602 mcp_client,
603 CallToolRequestParam {
604 name: tool_call.name.clone().into(),
605 arguments,
606 },
607 metadata,
608 )
609 .await
610 {
611 Ok(handle) => handle,
612 Err(error) => {
613 return ToolExecutionResult::Completed {
614 result: format!("MCP tool call failed: {error}"),
615 is_error: true,
616 };
617 }
618 };
619
620 let peer_for_cancel = request_handle.peer.clone();
621 let request_id = request_handle.id.clone();
622
623 tokio::select! {
624 _ = cancel.cancelled() => {
625 let notification = CancelledNotification {
626 method: CancelledNotificationMethod,
627 params: CancelledNotificationParam {
628 request_id,
629 reason: Some("user cancel".to_string()),
630 },
631 extensions: Default::default(),
632 };
633
634 let _ = peer_for_cancel.send_notification(notification.into()).await;
635 ToolExecutionResult::Cancelled
636 }
637 server_result = request_handle.await_response() => {
638 match server_result {
639 Ok(ServerResult::CallToolResult(result)) => {
640 ToolExecutionResult::Completed {
641 result: render_call_tool_result(&result),
642 is_error: result.is_error.unwrap_or(false),
643 }
644 }
645 Ok(_) => ToolExecutionResult::Completed {
646 result: "Unexpected MCP response type".to_string(),
647 is_error: true,
648 },
649 Err(error) => ToolExecutionResult::Completed {
650 result: format!("MCP tool execution error: {error}"),
651 is_error: true,
652 },
653 }
654 }
655 }
656}
657
658fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
659 let rendered = result
660 .content
661 .iter()
662 .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
663 .collect::<Vec<_>>()
664 .join("\n");
665
666 if !rendered.is_empty() {
667 return sanitize_text_output(&rendered);
668 }
669
670 if result.content.is_empty() {
671 return "<empty tool result>".to_string();
672 }
673
674 "<non-text tool result omitted for safety>".to_string()
675}
676
677fn checkpoint_signature(
678 messages: &[Message],
679 metadata: &serde_json::Value,
680) -> Result<String, String> {
681 serde_json::to_string(&(messages, metadata))
682 .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
683}
684
685async fn persist_checkpoint(
686 state: &AppState,
687 session_id: Uuid,
688 run_id: Uuid,
689 active_model: &stakai::Model,
690 parent_id: Option<Uuid>,
691 messages: &[Message],
692 metadata: &serde_json::Value,
693) -> Result<Uuid, String> {
694 let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages))
697 .with_metadata(metadata.clone());
698
699 if let Some(parent_id) = parent_id {
700 request = request.with_parent(parent_id);
701 }
702
703 let checkpoint = state
704 .session_store
705 .create_checkpoint(session_id, &request)
706 .await
707 .map_err(|error| error.to_string())?;
708
709 let mut envelope_metadata = if metadata.is_object() {
710 metadata.clone()
711 } else {
712 json!({})
713 };
714
715 if let Some(obj) = envelope_metadata.as_object_mut() {
716 obj.insert(
717 "session_id".to_string(),
718 serde_json::Value::String(session_id.to_string()),
719 );
720 obj.insert(
721 "checkpoint_id".to_string(),
722 serde_json::Value::String(checkpoint.id.to_string()),
723 );
724 obj.insert(
725 ACTIVE_MODEL_METADATA_KEY.to_string(),
726 serde_json::Value::String(format!("{}/{}", active_model.provider, active_model.id)),
727 );
728 }
729
730 let envelope = build_checkpoint_envelope(run_id, messages.to_vec(), envelope_metadata);
731
732 state
733 .checkpoint_store
734 .save_latest(session_id, &envelope)
735 .await
736 .map_err(|error| {
737 format!(
738 "Failed to persist checkpoint envelope for session {}: {}",
739 session_id, error
740 )
741 })?;
742
743 Ok(checkpoint.id)
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{ContentPart, Message, MessageContent, Role};

    /// Computes a checkpoint signature, failing the test with the error text if it can't.
    fn signature_of(messages: &[Message], metadata: &serde_json::Value) -> String {
        checkpoint_signature(messages, metadata)
            .unwrap_or_else(|error| panic!("signature failed: {error}"))
    }

    /// Prepends the fixed context block used by the prepend tests.
    fn with_env_context(message: Message) -> Message {
        prepend_context_to_user_message(message, "<context>env info</context>")
    }

    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let (session_id, run_id) = (Uuid::new_v4(), Uuid::new_v4());

        let context = build_run_context(session_id, run_id);

        assert_eq!(context.session_id, session_id);
        assert_eq!(context.run_id, run_id);
    }

    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let history = vec![Message::new(Role::User, "hello")];

        let envelope = build_checkpoint_envelope(run_id, history, json!({"turn": 1}));

        assert_eq!(envelope.run_id, Some(run_id));
    }

    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let tool_result = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);
        assert_eq!(render_call_tool_result(&tool_result), "okdone");
    }

    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let tool_result = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);
        assert_eq!(
            render_call_tool_result(&tool_result),
            "<non-text tool result omitted for safety>"
        );
    }

    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        let shorter = vec![Message::new(Role::User, "hello")];
        let longer = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        assert_ne!(
            signature_of(&shorter, &json!({})),
            signature_of(&longer, &json!({}))
        );
    }

    #[test]
    fn checkpoint_signature_changes_when_metadata_changes() {
        let history = vec![Message::new(Role::User, "hello")];

        assert_ne!(
            signature_of(&history, &json!({})),
            signature_of(&history, &json!({"trimmed_up_to_message_index": 5}))
        );
    }

    #[test]
    fn is_new_session_empty_history() {
        assert!(is_new_session_history(&[]));
    }

    #[test]
    fn is_new_session_system_only() {
        let history = [Message::new(Role::System, "you are an agent")];
        assert!(is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_user_message() {
        let history = [Message::new(Role::User, "hello")];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_system_and_user() {
        let history = [
            Message::new(Role::System, "system"),
            Message::new(Role::User, "hello"),
        ];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_assistant() {
        let history = [Message::new(Role::Assistant, "hi there")];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn prepend_context_to_text_message() {
        let combined = with_env_context(Message::new(Role::User, "how do I deploy?"))
            .text()
            .unwrap_or_default();

        assert!(
            combined.starts_with("<context>env info</context>"),
            "context should be prepended"
        );
        assert!(
            combined.contains("how do I deploy?"),
            "original text should be preserved"
        );
    }

    #[test]
    fn prepend_context_to_empty_text_message() {
        let combined = with_env_context(Message::new(Role::User, " "))
            .text()
            .unwrap_or_default();

        assert_eq!(combined, "<context>env info</context>");
    }

    #[test]
    fn prepend_context_to_parts_message() {
        let original = Message {
            role: Role::User,
            content: MessageContent::Parts(vec![ContentPart::text("original text")]),
            name: None,
            provider_options: None,
        };

        let prefixed = with_env_context(original);

        let MessageContent::Parts(parts) = &prefixed.content else {
            panic!("expected Parts content");
        };
        assert_eq!(parts.len(), 2, "should have context part + original part");

        let ContentPart::Text { text, .. } = &parts[0] else {
            panic!("first part should be text");
        };
        assert_eq!(text, "<context>env info</context>");
    }

    #[test]
    fn prepend_empty_context_is_noop() {
        let unchanged = prepend_context_to_user_message(Message::new(Role::User, "hello"), " ");
        assert_eq!(unchanged.text().unwrap_or_default(), "hello");
    }
}