Skip to main content

stakpak_server/
session_actor.rs

1use crate::{
2    context::{ContextFile, EnvironmentContext, ProjectContext, SessionContextBuilder},
3    message_bridge,
4    sandbox::{SandboxConfig, SandboxedMcpServer},
5    state::AppState,
6    types::SessionHandle,
7};
8use async_trait::async_trait;
9use rmcp::model::{
10    CallToolRequestParam, CancelledNotification, CancelledNotificationMethod,
11    CancelledNotificationParam, ServerResult,
12};
13use serde_json::json;
14use stakai::{ContentPart, Message, MessageContent, Role};
15use stakpak_agent_core::{
16    AgentCommand, AgentConfig, AgentEvent, AgentHook, AgentRunContext, BudgetAwareContextReducer,
17    CheckpointEnvelopeV1, CompactionConfig, PassthroughCompactionEngine, ProposedToolCall,
18    RetryConfig, ToolExecutionResult, ToolExecutor, run_agent,
19};
20use stakpak_api::CreateCheckpointRequest;
21use stakpak_mcp_client::McpClient;
22use stakpak_shared::utils::sanitize_text_output;
23use std::{path::Path, sync::Arc};
24use tokio::sync::{Mutex, mpsc};
25use tokio_util::sync::CancellationToken;
26use uuid::Uuid;
27
/// Maximum number of agent turns allowed per run (passed as `AgentConfig::max_turns`).
const MAX_TURNS: usize = 64;

/// Period of the background task that flushes dirty checkpoint state during a run.
const CHECKPOINT_FLUSH_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5_000);

/// Envelope-metadata key holding the active model as `"provider/id"`.
pub(crate) const ACTIVE_MODEL_METADATA_KEY: &str = "active_model";
31
32pub fn build_run_context(session_id: Uuid, run_id: Uuid) -> AgentRunContext {
33    AgentRunContext { run_id, session_id }
34}
35
36pub fn build_checkpoint_envelope(
37    run_id: Uuid,
38    messages: Vec<stakai::Message>,
39    metadata: serde_json::Value,
40) -> CheckpointEnvelopeV1 {
41    CheckpointEnvelopeV1::new(Some(run_id), messages, metadata)
42}
43
44pub fn spawn_session_actor(
45    state: AppState,
46    session_id: Uuid,
47    run_id: Uuid,
48    model: stakai::Model,
49    user_message: Message,
50    caller_context: Vec<ContextFile>,
51    sandbox_config: Option<SandboxConfig>,
52) -> Result<SessionHandle, String> {
53    let (command_tx, command_rx) = mpsc::channel(128);
54    let cancel = CancellationToken::new();
55
56    let handle = SessionHandle::new(command_tx, cancel.clone());
57
58    let state_for_task = state.clone();
59    tokio::spawn(async move {
60        let actor_result = run_session_actor(
61            state_for_task.clone(),
62            session_id,
63            run_id,
64            model,
65            user_message,
66            caller_context,
67            command_rx,
68            cancel,
69            sandbox_config,
70        )
71        .await;
72
73        let finish_result = actor_result.map(|_| ());
74        let _ = state_for_task
75            .run_manager
76            .mark_run_finished(session_id, run_id, finish_result)
77            .await;
78    });
79
80    Ok(handle)
81}
82
83#[allow(clippy::too_many_arguments)]
84async fn run_session_actor(
85    state: AppState,
86    session_id: Uuid,
87    run_id: Uuid,
88    model: stakai::Model,
89    mut user_message: Message,
90    caller_context: Vec<ContextFile>,
91    command_rx: mpsc::Receiver<AgentCommand>,
92    cancel: CancellationToken,
93    sandbox_config: Option<SandboxConfig>,
94) -> Result<(), String> {
95    let active_checkpoint = state
96        .session_store
97        .get_active_checkpoint(session_id)
98        .await
99        .ok();
100    let parent_checkpoint_id = active_checkpoint.as_ref().map(|checkpoint| checkpoint.id);
101
102    let (initial_messages, mut initial_metadata) =
103        match state.checkpoint_store.load_latest(session_id).await {
104            Ok(Some(envelope)) => (envelope.messages, envelope.metadata),
105            Ok(None) => {
106                let messages = active_checkpoint
107                    .as_ref()
108                    .map(|checkpoint| {
109                        message_bridge::chat_to_stakai(checkpoint.state.messages.clone())
110                    })
111                    .unwrap_or_default();
112                let metadata = active_checkpoint
113                    .as_ref()
114                    .and_then(|checkpoint| checkpoint.state.metadata.clone())
115                    .unwrap_or_else(|| json!({}));
116                (messages, metadata)
117            }
118            Err(error) => {
119                return Err(format!("Failed to load checkpoint envelope: {error}"));
120            }
121        };
122
123    // If sandbox is requested, spawn a sandboxed MCP server for this session.
124    // Otherwise, use the shared in-process MCP client.
125    let sandbox = if let Some(sandbox_config) = sandbox_config {
126        tracing::info!(session_id = %session_id, image = %sandbox_config.image, "Spawning sandbox container for session");
127        Some(
128            SandboxedMcpServer::spawn(&sandbox_config)
129                .await
130                .map_err(|e| format!("Failed to start sandbox for session {session_id}: {e}"))?,
131        )
132    } else {
133        None
134    };
135
136    let (run_tools, tool_executor): (Vec<stakai::Tool>, Box<dyn ToolExecutor + Send + Sync>) =
137        if let Some(ref sandbox) = sandbox {
138            (
139                sandbox.tools.clone(),
140                Box::new(SandboxedToolExecutor {
141                    mcp_client: sandbox.client.clone(),
142                }),
143            )
144        } else {
145            (
146                state.current_mcp_tools().await,
147                Box::new(ServerToolExecutor {
148                    state: state.clone(),
149                }),
150            )
151        };
152
153    let is_new_session = is_new_session_history(&initial_messages);
154    let session_cwd = resolve_session_cwd(&state, session_id).await;
155    let environment = EnvironmentContext::snapshot(&session_cwd).await;
156
157    // Combine caller context with pre-loaded remote skills context from AppState.
158    // Explicit caller context should force per-turn injection, even on resumed
159    // sessions, while startup remote skills remain baseline context.
160    let has_runtime_caller_context = !caller_context.is_empty();
161    let mut all_caller_context = caller_context;
162    all_caller_context.extend(state.current_skills().await);
163
164    let project =
165        ProjectContext::discover(Path::new(&session_cwd)).with_caller_context(all_caller_context);
166
167    let session_context = SessionContextBuilder::new()
168        .base_system_prompt(state.base_system_prompt.clone().unwrap_or_default())
169        .environment(environment)
170        .project(project)
171        .tools(&run_tools)
172        .budget(state.context_budget.clone())
173        .build();
174
175    if (is_new_session || has_runtime_caller_context)
176        && let Some(context_block) = session_context.user_context_block.as_deref()
177    {
178        user_message = prepend_context_to_user_message(user_message, context_block);
179    }
180
181    let mut baseline_messages = initial_messages.clone();
182    baseline_messages.push(user_message.clone());
183
184    let checkpoint_runtime = Arc::new(CheckpointRuntime::new(
185        state.clone(),
186        session_id,
187        run_id,
188        model.clone(),
189        parent_checkpoint_id,
190        baseline_messages,
191        initial_metadata.clone(),
192    ));
193
194    checkpoint_runtime
195        .persist_snapshot()
196        .await
197        .map_err(|error| format!("Failed to persist baseline checkpoint: {error}"))?;
198
199    let periodic_checkpoint_cancel = CancellationToken::new();
200    let periodic_checkpoint_runtime = checkpoint_runtime.clone();
201    let periodic_checkpoint_cancel_task = periodic_checkpoint_cancel.clone();
202    let periodic_task = tokio::spawn(async move {
203        let mut interval = tokio::time::interval(CHECKPOINT_FLUSH_INTERVAL);
204        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
205
206        loop {
207            tokio::select! {
208                _ = periodic_checkpoint_cancel_task.cancelled() => break,
209                _ = interval.tick() => {
210                    let _ = periodic_checkpoint_runtime.persist_snapshot().await;
211                }
212            }
213        }
214    });
215
216    let (core_event_tx, mut core_event_rx) = mpsc::channel::<AgentEvent>(256);
217
218    let event_state = state.clone();
219    let event_forwarder = tokio::spawn(async move {
220        while let Some(event) = core_event_rx.recv().await {
221            handle_core_event(&event_state, session_id, run_id, event).await;
222        }
223    });
224
225    // Use the model's maximum output capacity as the output budget for context
226    // window calculations. This is conservative — the actual response may be shorter,
227    // but reserving the full limit avoids mid-response context truncation.
228    let max_output_tokens = model.limit.output as u32;
229    let agent_config = AgentConfig {
230        model,
231        system_prompt: session_context.system_prompt,
232        max_turns: MAX_TURNS,
233        max_output_tokens,
234        provider_options: None,
235        tool_approval: state.tool_approval_policy.clone(),
236        retry: RetryConfig::default(),
237        compaction: CompactionConfig::default(),
238        tools: run_tools,
239    };
240
241    let hooks: Vec<Box<dyn AgentHook>> = vec![Box::new(ServerCheckpointHook {
242        checkpoint_runtime: checkpoint_runtime.clone(),
243    })];
244
245    let compactor = PassthroughCompactionEngine;
246    let context_reducer = BudgetAwareContextReducer::new(5, 0.8);
247    let run_context = build_run_context(session_id, run_id);
248
249    let run_result = run_agent(
250        run_context,
251        state.inference.as_ref(),
252        &agent_config,
253        initial_messages,
254        &mut initial_metadata,
255        user_message,
256        tool_executor.as_ref(),
257        &hooks,
258        core_event_tx,
259        command_rx,
260        cancel,
261        &compactor,
262        &context_reducer,
263    )
264    .await;
265
266    periodic_checkpoint_cancel.cancel();
267    let _ = periodic_task.await;
268
269    // Shut down sandbox container if one was started
270    if let Some(sandbox) = sandbox {
271        sandbox.shutdown().await;
272    }
273
274    state.clear_pending_tools(session_id, run_id).await;
275
276    match &run_result {
277        Ok(result) => {
278            checkpoint_runtime.update_messages(&result.messages).await;
279            checkpoint_runtime.update_metadata(&result.metadata).await;
280            checkpoint_runtime
281                .persist_snapshot()
282                .await
283                .map_err(|error| format!("Failed to persist terminal checkpoint: {error}"))?;
284        }
285        Err(_) => {
286            checkpoint_runtime.update_metadata(&initial_metadata).await;
287            let _ = checkpoint_runtime.persist_snapshot().await;
288        }
289    }
290
291    let _ = tokio::time::timeout(std::time::Duration::from_secs(2), event_forwarder).await;
292
293    run_result
294        .map(|_| ())
295        .map_err(|error| format!("Agent run failed: {error}"))
296}
297
298fn is_new_session_history(messages: &[Message]) -> bool {
299    !messages
300        .iter()
301        .any(|message| matches!(message.role, Role::User | Role::Assistant | Role::Tool))
302}
303
304async fn resolve_session_cwd(state: &AppState, session_id: Uuid) -> String {
305    // 1. Session-specific cwd (set by API caller)
306    if let Ok(session) = state.session_store.get_session(session_id).await
307        && let Some(cwd) = session.cwd
308        && !cwd.trim().is_empty()
309    {
310        return cwd;
311    }
312
313    // 2. Configured project directory (set at server startup, e.g. from `stakpak up`)
314    if let Some(project_dir) = &state.project_dir {
315        return project_dir.clone();
316    }
317
318    // 3. Process working directory
319    std::env::current_dir()
320        .ok()
321        .map(|path| path.to_string_lossy().to_string())
322        .unwrap_or_else(|| ".".to_string())
323}
324
325fn prepend_context_to_user_message(mut message: Message, context_block: &str) -> Message {
326    if context_block.trim().is_empty() {
327        return message;
328    }
329
330    match &mut message.content {
331        MessageContent::Text(text) => {
332            let existing = std::mem::take(text);
333            *text = if existing.trim().is_empty() {
334                context_block.to_string()
335            } else {
336                format!("{context_block}\n\n{existing}")
337            };
338        }
339        MessageContent::Parts(parts) => {
340            let mut prefixed = Vec::with_capacity(parts.len() + 1);
341            prefixed.push(ContentPart::text(context_block));
342            prefixed.append(parts);
343            *parts = prefixed;
344        }
345    }
346
347    message
348}
349
350async fn handle_core_event(state: &AppState, session_id: Uuid, run_id: Uuid, event: AgentEvent) {
351    match &event {
352        AgentEvent::ToolCallsProposed { tool_calls, .. } => {
353            state
354                .set_pending_tools(session_id, run_id, tool_calls.clone())
355                .await;
356        }
357        AgentEvent::TurnCompleted { .. }
358        | AgentEvent::RunCompleted { .. }
359        | AgentEvent::RunError { .. } => {
360            state.clear_pending_tools(session_id, run_id).await;
361        }
362        _ => {}
363    }
364
365    state.events.publish(session_id, Some(run_id), event).await;
366}
367
368#[derive(Clone)]
369struct ServerToolExecutor {
370    state: AppState,
371}
372
373#[async_trait]
374impl ToolExecutor for ServerToolExecutor {
375    async fn execute_tool_call(
376        &self,
377        run: &AgentRunContext,
378        tool_call: &ProposedToolCall,
379        cancel: &CancellationToken,
380    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
381        Ok(execute_mcp_tool_call(&self.state, run.session_id, run.run_id, tool_call, cancel).await)
382    }
383}
384
385/// Tool executor that routes calls through a per-session sandboxed MCP client.
386#[derive(Clone)]
387struct SandboxedToolExecutor {
388    mcp_client: Arc<McpClient>,
389}
390
391#[async_trait]
392impl ToolExecutor for SandboxedToolExecutor {
393    async fn execute_tool_call(
394        &self,
395        run: &AgentRunContext,
396        tool_call: &ProposedToolCall,
397        cancel: &CancellationToken,
398    ) -> Result<ToolExecutionResult, stakpak_agent_core::AgentError> {
399        Ok(execute_mcp_tool_call_with_client(
400            &self.mcp_client,
401            run.session_id,
402            run.run_id,
403            tool_call,
404            cancel,
405        )
406        .await)
407    }
408}
409
410struct CheckpointRuntime {
411    state: AppState,
412    session_id: Uuid,
413    run_id: Uuid,
414    active_model: stakai::Model,
415    inner: Mutex<CheckpointRuntimeInner>,
416}
417
418struct CheckpointRuntimeInner {
419    parent_checkpoint_id: Option<Uuid>,
420    latest_messages: Vec<Message>,
421    latest_metadata: serde_json::Value,
422    last_persisted_signature: Option<String>,
423    dirty: bool,
424}
425
426impl CheckpointRuntime {
427    fn new(
428        state: AppState,
429        session_id: Uuid,
430        run_id: Uuid,
431        active_model: stakai::Model,
432        parent_checkpoint_id: Option<Uuid>,
433        latest_messages: Vec<Message>,
434        latest_metadata: serde_json::Value,
435    ) -> Self {
436        Self {
437            state,
438            session_id,
439            run_id,
440            active_model,
441            inner: Mutex::new(CheckpointRuntimeInner {
442                parent_checkpoint_id,
443                latest_messages,
444                latest_metadata,
445                last_persisted_signature: None,
446                dirty: true,
447            }),
448        }
449    }
450
451    async fn update_messages(&self, messages: &[Message]) {
452        let mut guard = self.inner.lock().await;
453        guard.latest_messages = messages.to_vec();
454        guard.dirty = true;
455    }
456
457    async fn update_metadata(&self, metadata: &serde_json::Value) {
458        let mut guard = self.inner.lock().await;
459        guard.latest_metadata = metadata.clone();
460        guard.dirty = true;
461    }
462
463    async fn persist_snapshot(&self) -> Result<Uuid, String> {
464        let mut guard = self.inner.lock().await;
465        self.persist_if_needed(&mut guard).await
466    }
467
468    async fn persist_if_needed(&self, guard: &mut CheckpointRuntimeInner) -> Result<Uuid, String> {
469        if !guard.dirty
470            && let Some(checkpoint_id) = guard.parent_checkpoint_id
471        {
472            return Ok(checkpoint_id);
473        }
474
475        let signature = checkpoint_signature(&guard.latest_messages, &guard.latest_metadata)?;
476        let changed = guard.last_persisted_signature.as_deref() != Some(signature.as_str());
477        let should_persist = guard.parent_checkpoint_id.is_none() || (guard.dirty && changed);
478
479        if !should_persist {
480            guard.dirty = false;
481            if let Some(checkpoint_id) = guard.parent_checkpoint_id {
482                return Ok(checkpoint_id);
483            }
484        }
485
486        let checkpoint_id = persist_checkpoint(
487            &self.state,
488            self.session_id,
489            self.run_id,
490            &self.active_model,
491            guard.parent_checkpoint_id,
492            &guard.latest_messages,
493            &guard.latest_metadata,
494        )
495        .await?;
496
497        guard.parent_checkpoint_id = Some(checkpoint_id);
498        guard.last_persisted_signature = Some(signature);
499        guard.dirty = false;
500
501        Ok(checkpoint_id)
502    }
503}
504
505struct ServerCheckpointHook {
506    checkpoint_runtime: Arc<CheckpointRuntime>,
507}
508
509#[async_trait]
510impl AgentHook for ServerCheckpointHook {
511    async fn before_inference(
512        &self,
513        _run: &AgentRunContext,
514        messages: &[Message],
515        _model: &stakai::Model,
516    ) -> Result<(), stakpak_agent_core::AgentError> {
517        self.checkpoint_runtime.update_messages(messages).await;
518        Ok(())
519    }
520
521    async fn after_inference(
522        &self,
523        _run: &AgentRunContext,
524        messages: &[Message],
525        _model: &stakai::Model,
526    ) -> Result<(), stakpak_agent_core::AgentError> {
527        self.checkpoint_runtime.update_messages(messages).await;
528        Ok(())
529    }
530
531    async fn after_tool_execution(
532        &self,
533        _run: &AgentRunContext,
534        _tool_call: &ProposedToolCall,
535        messages: &[Message],
536    ) -> Result<(), stakpak_agent_core::AgentError> {
537        self.checkpoint_runtime.update_messages(messages).await;
538        Ok(())
539    }
540
541    async fn on_error(
542        &self,
543        _run: &AgentRunContext,
544        _error: &stakpak_agent_core::AgentError,
545        messages: &[Message],
546    ) -> Result<(), stakpak_agent_core::AgentError> {
547        self.checkpoint_runtime.update_messages(messages).await;
548        let _ = self.checkpoint_runtime.persist_snapshot().await;
549        Ok(())
550    }
551}
552
553async fn execute_mcp_tool_call(
554    state: &AppState,
555    session_id: Uuid,
556    run_id: Uuid,
557    tool_call: &ProposedToolCall,
558    cancel: &CancellationToken,
559) -> ToolExecutionResult {
560    let Some(mcp_client) = state.mcp_client.as_ref() else {
561        return ToolExecutionResult::Completed {
562            result: "MCP client is not initialized".to_string(),
563            is_error: true,
564        };
565    };
566
567    execute_mcp_tool_call_with_client(mcp_client, session_id, run_id, tool_call, cancel).await
568}
569
570async fn execute_mcp_tool_call_with_client(
571    mcp_client: &McpClient,
572    session_id: Uuid,
573    run_id: Uuid,
574    tool_call: &ProposedToolCall,
575    cancel: &CancellationToken,
576) -> ToolExecutionResult {
577    let metadata = Some(serde_json::Map::from_iter([
578        (
579            "session_id".to_string(),
580            serde_json::Value::String(session_id.to_string()),
581        ),
582        (
583            "run_id".to_string(),
584            serde_json::Value::String(run_id.to_string()),
585        ),
586        (
587            "tool_call_id".to_string(),
588            serde_json::Value::String(tool_call.id.clone()),
589        ),
590    ]));
591
592    let arguments = match &tool_call.arguments {
593        serde_json::Value::Object(map) => Some(map.clone()),
594        serde_json::Value::Null => None,
595        other => Some(serde_json::Map::from_iter([(
596            "input".to_string(),
597            other.clone(),
598        )])),
599    };
600
601    let request_handle = match stakpak_mcp_client::call_tool(
602        mcp_client,
603        CallToolRequestParam {
604            name: tool_call.name.clone().into(),
605            arguments,
606        },
607        metadata,
608    )
609    .await
610    {
611        Ok(handle) => handle,
612        Err(error) => {
613            return ToolExecutionResult::Completed {
614                result: format!("MCP tool call failed: {error}"),
615                is_error: true,
616            };
617        }
618    };
619
620    let peer_for_cancel = request_handle.peer.clone();
621    let request_id = request_handle.id.clone();
622
623    tokio::select! {
624        _ = cancel.cancelled() => {
625            let notification = CancelledNotification {
626                method: CancelledNotificationMethod,
627                params: CancelledNotificationParam {
628                    request_id,
629                    reason: Some("user cancel".to_string()),
630                },
631                extensions: Default::default(),
632            };
633
634            let _ = peer_for_cancel.send_notification(notification.into()).await;
635            ToolExecutionResult::Cancelled
636        }
637        server_result = request_handle.await_response() => {
638            match server_result {
639                Ok(ServerResult::CallToolResult(result)) => {
640                    ToolExecutionResult::Completed {
641                        result: render_call_tool_result(&result),
642                        is_error: result.is_error.unwrap_or(false),
643                    }
644                }
645                Ok(_) => ToolExecutionResult::Completed {
646                    result: "Unexpected MCP response type".to_string(),
647                    is_error: true,
648                },
649                Err(error) => ToolExecutionResult::Completed {
650                    result: format!("MCP tool execution error: {error}"),
651                    is_error: true,
652                },
653            }
654        }
655    }
656}
657
658fn render_call_tool_result(result: &rmcp::model::CallToolResult) -> String {
659    let rendered = result
660        .content
661        .iter()
662        .filter_map(|content| content.raw.as_text().map(|text| text.text.clone()))
663        .collect::<Vec<_>>()
664        .join("\n");
665
666    if !rendered.is_empty() {
667        return sanitize_text_output(&rendered);
668    }
669
670    if result.content.is_empty() {
671        return "<empty tool result>".to_string();
672    }
673
674    "<non-text tool result omitted for safety>".to_string()
675}
676
677fn checkpoint_signature(
678    messages: &[Message],
679    metadata: &serde_json::Value,
680) -> Result<String, String> {
681    serde_json::to_string(&(messages, metadata))
682        .map_err(|error| format!("Failed to serialize checkpoint messages: {error}"))
683}
684
685async fn persist_checkpoint(
686    state: &AppState,
687    session_id: Uuid,
688    run_id: Uuid,
689    active_model: &stakai::Model,
690    parent_id: Option<Uuid>,
691    messages: &[Message],
692    metadata: &serde_json::Value,
693) -> Result<Uuid, String> {
694    // TODO(ahmed): Migrate server/session checkpoint storage to `Vec<stakai::Message>` directly
695    // and remove the ChatMessage adapter conversion (`message_bridge::stakai_to_chat`).
696    let mut request = CreateCheckpointRequest::new(message_bridge::stakai_to_chat(messages))
697        .with_metadata(metadata.clone());
698
699    if let Some(parent_id) = parent_id {
700        request = request.with_parent(parent_id);
701    }
702
703    let checkpoint = state
704        .session_store
705        .create_checkpoint(session_id, &request)
706        .await
707        .map_err(|error| error.to_string())?;
708
709    let mut envelope_metadata = if metadata.is_object() {
710        metadata.clone()
711    } else {
712        json!({})
713    };
714
715    if let Some(obj) = envelope_metadata.as_object_mut() {
716        obj.insert(
717            "session_id".to_string(),
718            serde_json::Value::String(session_id.to_string()),
719        );
720        obj.insert(
721            "checkpoint_id".to_string(),
722            serde_json::Value::String(checkpoint.id.to_string()),
723        );
724        obj.insert(
725            ACTIVE_MODEL_METADATA_KEY.to_string(),
726            serde_json::Value::String(format!("{}/{}", active_model.provider, active_model.id)),
727        );
728    }
729
730    let envelope = build_checkpoint_envelope(run_id, messages.to_vec(), envelope_metadata);
731
732    state
733        .checkpoint_store
734        .save_latest(session_id, &envelope)
735        .await
736        .map_err(|error| {
737            format!(
738                "Failed to persist checkpoint envelope for session {}: {}",
739                session_id, error
740            )
741        })?;
742
743    Ok(checkpoint.id)
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{CallToolResult, Content};
    use serde_json::json;
    use stakai::{ContentPart, Message, MessageContent, Role};

    /// Context block shared by the prepend tests.
    const CONTEXT: &str = "<context>env info</context>";

    /// Computes a checkpoint signature, failing the test on serializer errors.
    fn signature_of(messages: &[Message], metadata: serde_json::Value) -> String {
        checkpoint_signature(messages, &metadata)
            .unwrap_or_else(|error| panic!("signature failed: {error}"))
    }

    #[test]
    fn run_id_is_not_regenerated_when_building_run_context() {
        let (session_id, run_id) = (Uuid::new_v4(), Uuid::new_v4());

        let context = build_run_context(session_id, run_id);

        assert_eq!((context.session_id, context.run_id), (session_id, run_id));
    }

    #[test]
    fn checkpoint_envelope_carries_same_run_id() {
        let run_id = Uuid::new_v4();
        let messages = vec![Message::new(Role::User, "hello")];

        let envelope = build_checkpoint_envelope(run_id, messages, json!({"turn": 1}));

        assert_eq!(envelope.run_id, Some(run_id));
    }

    #[test]
    fn render_call_tool_result_sanitizes_text_blocks() {
        let noisy = CallToolResult::success(vec![Content::text("ok\u{0007}done")]);
        assert_eq!(render_call_tool_result(&noisy), "okdone");
    }

    #[test]
    fn render_call_tool_result_omits_non_text_blocks() {
        let image_only = CallToolResult::success(vec![Content::image("dGVzdA==", "image/png")]);
        assert_eq!(
            render_call_tool_result(&image_only),
            "<non-text tool result omitted for safety>"
        );
    }

    #[test]
    fn checkpoint_signature_changes_when_messages_change() {
        let shorter = [Message::new(Role::User, "hello")];
        let longer = [
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi"),
        ];

        assert_ne!(
            signature_of(&shorter, json!({})),
            signature_of(&longer, json!({}))
        );
    }

    #[test]
    fn checkpoint_signature_changes_when_metadata_changes() {
        let messages = [Message::new(Role::User, "hello")];

        assert_ne!(
            signature_of(&messages, json!({})),
            signature_of(&messages, json!({"trimmed_up_to_message_index": 5}))
        );
    }

    #[test]
    fn is_new_session_empty_history() {
        assert!(is_new_session_history(&[]));
    }

    #[test]
    fn is_new_session_system_only() {
        assert!(is_new_session_history(&[Message::new(
            Role::System,
            "you are an agent"
        )]));
    }

    #[test]
    fn is_not_new_session_with_user_message() {
        assert!(!is_new_session_history(&[Message::new(Role::User, "hello")]));
    }

    #[test]
    fn is_not_new_session_with_system_and_user() {
        let history = [
            Message::new(Role::System, "system"),
            Message::new(Role::User, "hello"),
        ];
        assert!(!is_new_session_history(&history));
    }

    #[test]
    fn is_not_new_session_with_assistant() {
        assert!(!is_new_session_history(&[Message::new(
            Role::Assistant,
            "hi there"
        )]));
    }

    #[test]
    fn prepend_context_to_text_message() {
        let original = Message::new(Role::User, "how do I deploy?");

        let text = prepend_context_to_user_message(original, CONTEXT)
            .text()
            .unwrap_or_default();

        assert!(text.starts_with(CONTEXT), "context should be prepended");
        assert!(
            text.contains("how do I deploy?"),
            "original text should be preserved"
        );
    }

    #[test]
    fn prepend_context_to_empty_text_message() {
        let blank = Message::new(Role::User, "  ");

        let text = prepend_context_to_user_message(blank, CONTEXT)
            .text()
            .unwrap_or_default();

        assert_eq!(text, CONTEXT);
    }

    #[test]
    fn prepend_context_to_parts_message() {
        let multipart = Message {
            role: Role::User,
            content: MessageContent::Parts(vec![ContentPart::text("original text")]),
            name: None,
            provider_options: None,
        };

        let result = prepend_context_to_user_message(multipart, CONTEXT);

        let MessageContent::Parts(parts) = &result.content else {
            panic!("expected Parts content");
        };
        assert_eq!(parts.len(), 2, "should have context part + original part");
        let ContentPart::Text { text, .. } = &parts[0] else {
            panic!("first part should be text");
        };
        assert_eq!(text, CONTEXT);
    }

    #[test]
    fn prepend_empty_context_is_noop() {
        let untouched = prepend_context_to_user_message(Message::new(Role::User, "hello"), "   ");
        assert_eq!(untouched.text().unwrap_or_default(), "hello");
    }
}